Itoyori v0.0.1
adws.hpp
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
7 #include "ityr/common/logger.hpp"
10 #include "ityr/ito/util.hpp"
11 #include "ityr/ito/options.hpp"
12 #include "ityr/ito/context.hpp"
13 #include "ityr/ito/callstack.hpp"
14 #include "ityr/ito/wsqueue.hpp"
15 #include "ityr/ito/prof_events.hpp"
16 #include "ityr/ito/sched/util.hpp"
17 
18 namespace ityr::ito {
19 
20 class flipper {
21 public:
22  using value_type = uint64_t;
23 
24  value_type value() const { return val_; }
25 
26  void flip(int at) {
27  ITYR_CHECK(0 <= at);
28  ITYR_CHECK(at < sizeof(value_type) * 8);
29 
30  val_ ^= (value_type(1) << at);
31  }
32 
33  bool match(flipper f, int until) const {
34  ITYR_CHECK(0 <= until);
35  ITYR_CHECK(until < sizeof(value_type) * 8);
36 
37  value_type mask = (value_type(1) << (until + 1)) - 1;
38  return (val_ & mask) == (f.value() & mask);
39  }
40 
41 private:
42  value_type val_ = 0;
43 };
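 // Illustrative usage sketch (not part of the original source): a flipper tracks
 // task-group versions as a bit pattern, one bit per distribution-tree depth.
 //
 //   flipper a, b;           // both start with value() == 0
 //   a.flip(2);              // toggle bit 2 -> a.value() == 0b100
 //   b.flip(2);
 //   assert(a.match(b, 2));  // bits 0..2 agree
 //   b.flip(5);
 //   assert(a.match(b, 2));  // still true: only bits 0..2 are compared
 //   assert(!a.match(b, 5)); // bit 5 now differs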
44 
45 class dist_range {
46 public:
47  using value_type = double;
48 
49  dist_range() {}
50  dist_range(common::topology::rank_t n_ranks)
51  : begin_(0), end_(static_cast<value_type>(n_ranks)) {}
52  dist_range(value_type begin, value_type end)
53  : begin_(begin), end_(end) {}
54 
55  value_type begin() const { return begin_; }
56  value_type end() const { return end_; }
57 
58  common::topology::rank_t begin_rank() const {
59  return static_cast<common::topology::rank_t>(begin_);
60  }
61 
62  common::topology::rank_t end_rank() const {
63  return static_cast<common::topology::rank_t>(end_);
64  }
65 
66  bool is_at_end_boundary() const {
67  return static_cast<value_type>(static_cast<common::topology::rank_t>(end_)) == end_;
68  }
69 
70  void move_to_end_boundary() {
71  end_ = static_cast<value_type>(static_cast<common::topology::rank_t>(end_));
72  }
73 
74  template <typename T>
75  std::pair<dist_range, dist_range> divide(T r1, T r2) const {
76  value_type at = begin_ + (end_ - begin_) * r1 / (r1 + r2);
77 
78  // Boundary condition for tasks at the very bottom of the task hierarchy.
79  // A task with range [P, P) such that P = #workers would be assigned to worker P,
80  // but worker P does not exist; thus we need to assign the task to worker P-1.
81  if (at == end_) {
82  constexpr value_type eps = 0.00001;
83  at -= eps;
84  if (at < begin_) at = begin_;
85  }
86 
87  return std::make_pair(dist_range{begin_, at}, dist_range{at, end_});
88  }
89 
90  common::topology::rank_t owner() const {
91  return static_cast<common::topology::rank_t>(begin_);
92  }
93 
94  bool is_cross_worker() const {
95  return static_cast<common::topology::rank_t>(begin_) != static_cast<common::topology::rank_t>(end_);
96  }
97 
98  void make_non_cross_worker() {
99  end_ = begin_;
100  }
101 
102  bool is_sufficiently_small() const {
103  return (end_ - begin_) < adws_min_drange_size_option::value();
104  }
105 
106 private:
107  value_type begin_;
108  value_type end_;
109 };
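 // Illustrative example (not from the original source) of how divide() splits a
 // distribution range in proportion to the work hints r1 : r2:
 //
 //   dist_range dr(4);                    // [0.0, 4.0) covering ranks 0..3
 //   auto [rest, next] = dr.divide(1, 3); // split in ratio 1:3
 //   // rest == [0.0, 1.0), next == [1.0, 4.0)
 //   // next.owner() == 1, next.begin_rank() == 1, next.end_rank() == 4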
110 
111 class dist_tree {
112  using version_t = int;
113 
114 public:
115  struct node_ref {
116  common::topology::rank_t owner_rank = -1;
117  int depth = -1;
118  };
119 
120  struct node {
121  node() {}
122 
123  int depth() const { return parent.depth + 1; }
124 
125  node_ref parent;
126  dist_range drange;
127  flipper tg_version;
128  version_t version = 0;
129  };
130 
131  dist_tree(int max_depth)
132  : max_depth_(max_depth),
133  node_win_(common::topology::mpicomm(), max_depth_),
134  dominant_flag_win_(common::topology::mpicomm(), max_depth_, 0),
135  versions_(max_depth_, common::topology::my_rank() + 1) {}
136 
137  node_ref append(node_ref parent, dist_range drange, flipper tg_version) {
138  int depth = parent.depth + 1;
139 
140  // handle version overflow
141  auto n_ranks = common::topology::n_ranks();
142  if (versions_[depth] >= std::numeric_limits<version_t>::max() - n_ranks) {
143  versions_[depth] = common::topology::my_rank() + 1;
144  }
145 
146  node& new_node = local_node(depth);
147  new_node.parent = parent;
148  new_node.drange = drange;
149  new_node.tg_version = tg_version;
150  new_node.version = (versions_[depth] += n_ranks);
151 
152  return {common::topology::my_rank(), depth};
153  }
154 
155  void set_dominant(node_ref nr, bool dominant) {
156  // Store the version as the flag if dominant
157  // To disable steals from this dist range, set -version as the special dominant flag value
158  version_t value = (dominant ? 1 : -1) * local_node(nr.depth).version;
159 
160  local_dominant_flag(nr.depth).store(value, std::memory_order_relaxed);
161 
162  if (nr.owner_rank != common::topology::my_rank()) {
163  std::size_t disp_dominant = nr.depth * sizeof(version_t);
164  common::mpi_atomic_put_value(value, nr.owner_rank, disp_dominant, dominant_flag_win_.win());
165  }
166  }
167 
168  // The meaning of a dominant flag value:
169  // 0 : undetermined
170  // version : the node with this "version" is dominant
171  // -version : the node with this "version" is removed and non-dominant
172  std::optional<node> get_topmost_dominant(node_ref nr) {
173  if (nr.depth < 0) return std::nullopt;
174 
176 
177  for (int d = 0; d <= nr.depth; d++) {
178  auto owner_rank = (d == nr.depth) ? nr.owner_rank
179  : local_node(d + 1).parent.owner_rank;
180 
181  node& n = local_node(d);
182  auto& dominant_flag = local_dominant_flag(d);
183 
184  ITYR_CHECK(n.parent.depth == d - 1);
185  ITYR_CHECK(n.version != 0);
186 
187  if (owner_rank != common::topology::my_rank() &&
188  dominant_flag.load(std::memory_order_relaxed) != -n.version) {
189  // To avoid network contention on the owner rank, we randomly choose a worker within the
190  // distribution range to query the dominant flag (decentralized dominant node propagation)
191  ITYR_CHECK(owner_rank == n.drange.begin_rank());
192  auto target_rank = get_random_rank(owner_rank, n.drange.end_rank() - 1);
193 
194  if (target_rank != owner_rank &&
195  dominant_flag.load(std::memory_order_relaxed) == n.version) {
196  // If the remote value is 0, propagate the dominant flag to remote
197  std::size_t disp_dominant = d * sizeof(version_t);
198  version_t dominant_val = common::mpi_atomic_cas_value(n.version, 0,
199  target_rank, disp_dominant, dominant_flag_win_.win());
200 
201  if (dominant_val == -n.version) {
202  dominant_flag.store(dominant_val, std::memory_order_relaxed);
203  }
204  } else {
205  // Read the remote dominant flag
206  std::size_t disp_dominant = d * sizeof(version_t);
207  version_t dominant_val = common::mpi_atomic_get_value<version_t>(
208  target_rank, disp_dominant, dominant_flag_win_.win());
209 
210  if (dominant_val == n.version || dominant_val == -n.version) {
211  dominant_flag.store(dominant_val, std::memory_order_relaxed);
212  }
213  }
214  }
215 
216  if (dominant_flag.load(std::memory_order_relaxed) == n.version) {
217  // return the topmost dominant node
218  return n;
219  }
220  }
221 
222  return std::nullopt;
223  }
224 
225  void copy_parents(node_ref nr) {
226  for (int d = 0; d <= nr.depth; d++) {
227  // non-owners write 0 as a non-dominant flag
228  local_dominant_flag(d).store(0, std::memory_order_relaxed);
229  }
230  common::mpi_get(&local_node(0), nr.depth + 1, nr.owner_rank, 0, node_win_.win());
231  }
232 
233  node& get_local_node(node_ref nr) {
235  return local_node(nr.depth);
236  }
237 
238 private:
239  node& local_node(int depth) {
240  ITYR_CHECK(0 <= depth);
241  ITYR_CHECK(depth < max_depth_);
242  return node_win_.local_buf()[depth];
243  }
244 
245  std::atomic<version_t>& local_dominant_flag(int depth) {
246  ITYR_CHECK(0 <= depth);
247  ITYR_CHECK(depth < max_depth_);
248  return dominant_flag_win_.local_buf()[depth];
249  }
250 
251  int max_depth_;
252  common::mpi_win_manager<node> node_win_;
253  common::mpi_win_manager<std::atomic<version_t>> dominant_flag_win_;
254  std::vector<version_t> versions_;
255 };
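 // Summary of how the scheduler below uses dist_tree (derived from the call sites in
 // this file):
 //  - task_group_begin() appends a node when a cross-worker task group starts
 //    (up to max_depth_).
 //  - on_task_die() marks that node dominant so workers inside its rank range may
 //    steal from the task group; task_group_end() clears the flag again.
 //  - steal() asks get_topmost_dominant() for the highest dominant ancestor of its
 //    local bottom node and restricts victim selection to that node's rank range.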
256 
257  class scheduler_adws {
258  public:
259  struct suspended_state {
260  void* evacuation_ptr;
261  void* frame_base;
262  std::size_t frame_size;
263  };
264 
265  template <typename T>
266  struct thread_retval {
267  T value;
268  dag_profiler dag_prof;
269  };
270 
271  template <typename T>
272  struct thread_state {
273  thread_retval<T> retval;
274  int resume_flag = 0;
275  suspended_state suspended;
276  };
277 
278  template <typename T>
279  struct thread_handler {
280  thread_state<T>* state = nullptr;
281  bool serialized = false;
282  thread_retval<T> retval_ser; // return the result by value if the thread is serialized
283  };
284 
285  struct task_group_data {
286  task_group_data* parent;
287  dist_range drange;
288  bool owns_dtree_node;
289  dag_profiler dag_prof_before;
290  dag_profiler dag_prof_acc;
291  };
292 
293  struct thread_local_storage {
294  task_group_data* tgdata;
295  dist_range drange; // distribution range of this thread
296  dist_tree::node_ref dtree_node_ref; // distribution tree node of the cross-worker task group that this thread belongs to
297  flipper tg_version;
298  bool undistributed;
299  dag_profiler dag_prof;
300  };
301 
302  scheduler_adws()
303  : max_depth_(adws_max_depth_option::value()),
304  stack_(stack_size_option::value()),
305  // Add a margin of sizeof(context_frame) to the bottom of the stack, because
306  // this region can be accessed by the clear_parent_frame() function later.
307  // This stack base is updated only in coll_exec().
308  stack_base_(reinterpret_cast<context_frame*>(stack_.bottom()) - 1),
309  primary_wsq_(adws_wsqueue_capacity_option::value(), max_depth_),
310  migration_wsq_(adws_wsqueue_capacity_option::value(), max_depth_),
311  thread_state_allocator_(thread_state_allocator_size_option::value()),
312  suspended_thread_allocator_(suspended_thread_allocator_size_option::value()),
313  dtree_(max_depth_) {}
314 
315  template <typename T, typename SchedLoopCallback, typename Fn, typename... Args>
316  T root_exec(SchedLoopCallback cb, Fn&& fn, Args&&... args) {
317  common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_fork>();
318 
319  thread_state<T>* ts = new (thread_state_allocator_.allocate(sizeof(thread_state<T>))) thread_state<T>;
320 
321  auto prev_sched_cf = sched_cf_;
322 
323  suspend([&](context_frame* cf) {
324  sched_cf_ = cf;
325  root_on_stack([&, ts, fn = std::forward<Fn>(fn),
326  args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
327  migrate_to(0, nullptr, nullptr);
328 
329  common::verbose("Starting root thread %p", ts);
330 
331  dist_range root_drange {common::topology::n_ranks()};
332  tls_ = new (alloca(sizeof(thread_local_storage)))
333  thread_local_storage{nullptr, root_drange, {}, {}, true, {}};
334 
335  tls_->dag_prof.start();
336  tls_->dag_prof.increment_thread_count();
337  tls_->dag_prof.increment_strand_count();
338 
339  common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
340 
341  T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
342 
343  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
344  common::verbose("Root thread %p is completed", ts);
345 
346  tls_->dag_prof.stop();
347 
348  on_root_die(ts, std::move(ret));
349  });
350  });
351 
352  sched_loop(cb);
353 
354  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_join>();
355 
356  thread_retval<T> retval = std::move(ts->retval);
357  std::destroy_at(ts);
358  thread_state_allocator_.deallocate(ts, sizeof(thread_state<T>));
359 
360  if (dag_prof_enabled_) {
361  if (tls_) {
362  // nested root/coll_exec()
363  tls_->dag_prof.merge_serial(retval.dag_prof);
364  } else {
365  dag_prof_result_.merge_serial(retval.dag_prof);
366  }
367  }
368 
369  sched_cf_ = prev_sched_cf;
370 
371  common::profiler::switch_phase<prof_phase_sched_join, prof_phase_spmd>();
372 
373  return retval.value;
374  }
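 // Illustrative call sketch (not from the original source; `sched` is an assumed
 // scheduler_adws instance, and the template/argument usage is inferred from the
 // signature above). root_exec() runs `fn` as the root thread and drives the
 // scheduling loop until the root thread dies:
 //
 //   scheduler_adws sched;
 //   int r = sched.root_exec<int>(nullptr, [](int x) { return x * 2; }, 21);  // r == 42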
375 
376  void task_group_begin(task_group_data* tgdata) {
377  tls_->dag_prof.stop();
378 
379  tgdata->parent = tls_->tgdata;
380  tgdata->drange = tls_->drange;
381  tgdata->owns_dtree_node = false;
382  tgdata->dag_prof_before = tls_->dag_prof;
383 
384  tls_->tgdata = tgdata;
385 
386  if (tls_->drange.is_cross_worker()) {
387  if (tls_->dtree_node_ref.depth + 1 < max_depth_) {
388  tls_->dtree_node_ref = dtree_.append(tls_->dtree_node_ref, tls_->drange, tls_->tg_version);
389  dtree_local_bottom_ref_ = tls_->dtree_node_ref;
390  tgdata->owns_dtree_node = true;
391  }
392 
393  tls_->undistributed = true;
394 
395  common::verbose("Begin a cross-worker task group of distribution range [%f, %f) at depth %d",
396  tls_->drange.begin(), tls_->drange.end(), tls_->dtree_node_ref.depth);
397  }
398 
399  tls_->dag_prof.clear();
400  tls_->dag_prof.start();
401  tls_->dag_prof.increment_strand_count();
402  }
403 
404  template <typename PreSuspendCallback, typename PostSuspendCallback>
405  void task_group_end(PreSuspendCallback&& pre_suspend_cb,
406  PostSuspendCallback&& post_suspend_cb) {
407  on_task_die();
408 
409  task_group_data* tgdata = tls_->tgdata;
410  ITYR_CHECK(tgdata);
411 
412  tls_->dag_prof = tgdata->dag_prof_before;
413  tls_->dag_prof.merge_serial(tgdata->dag_prof_acc);
414 
415  // restore the original distribution range of this thread at the beginning of the task group
416  tls_->drange = tgdata->drange;
417 
418  if (tls_->drange.is_cross_worker()) {
419  common::verbose("End a cross-worker task group of distribution range [%f, %f) at depth %d",
420  tls_->drange.begin(), tls_->drange.end(), tls_->dtree_node_ref.depth);
421 
422  // migrate the cross-worker-task to the owner
423  migrate_to(tls_->drange.owner(),
424  std::forward<PreSuspendCallback>(pre_suspend_cb),
425  std::forward<PostSuspendCallback>(post_suspend_cb));
426 
427  if (tgdata->owns_dtree_node) {
428  // Set the completed current task group as non-dominant to reduce steal requests
429  dtree_.set_dominant(tls_->dtree_node_ref, false);
430 
431  // Set the parent dist_tree node to the current thread
432  auto& dtree_node = dtree_.get_local_node(tls_->dtree_node_ref);
433  tls_->dtree_node_ref = dtree_node.parent;
434  dtree_local_bottom_ref_ = tls_->dtree_node_ref;
435 
436  // Flip the next version of the task group at this depth
437  tls_->tg_version.flip(dtree_node.depth());
438  }
439 
440  tls_->undistributed = false;
441  }
442 
443  tls_->tgdata = tls_->tgdata->parent;
444  std::destroy_at(tgdata);
445 
446  tls_->dag_prof.start();
447  tls_->dag_prof.increment_strand_count();
448  }
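 // Typical bracketing of a task group (illustrative sketch, not from the original
 // source; `sched` is an assumed instance and the code is assumed to run inside a
 // thread managed by this scheduler). The caller owns the task_group_data storage,
 // and nullptr callbacks mean no pre/post-suspend hooks:
 //
 //   scheduler_adws::task_group_data tgdata;
 //   sched.task_group_begin(&tgdata);
 //   // ... fork() / join() child threads here ...
 //   sched.task_group_end(nullptr, nullptr);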
449 
450  template <typename T, typename OnDriftForkCallback, typename OnDriftDieCallback,
451  typename WorkHint, typename Fn, typename... Args>
452  void fork(thread_handler<T>& th,
453  OnDriftForkCallback on_drift_fork_cb, OnDriftDieCallback on_drift_die_cb,
454  WorkHint w_new, WorkHint w_rest, Fn&& fn, Args&&... args) {
455  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_fork>();
456 
457  auto my_rank = common::topology::my_rank();
458 
459  thread_state<T>* ts = new (thread_state_allocator_.allocate(sizeof(thread_state<T>))) thread_state<T>;
460  th.state = ts;
461  th.serialized = false;
462 
463  dist_range new_drange;
464  common::topology::rank_t target_rank;
465  if (tls_->drange.is_cross_worker()) {
466  // Avoid too fine-grained task migration
467  if (tls_->drange.is_sufficiently_small()) {
468  tls_->drange.make_non_cross_worker();
469  }
470 
471  auto [dr_rest, dr_new] = tls_->drange.divide(w_rest, w_new);
472 
473  common::verbose("Distribution range [%f, %f) is divided into [%f, %f) and [%f, %f)",
474  tls_->drange.begin(), tls_->drange.end(),
475  dr_rest.begin(), dr_rest.end(), dr_new.begin(), dr_new.end());
476 
477  tls_->drange = dr_rest;
478  new_drange = dr_new;
479  target_rank = dr_new.owner();
480 
481  } else {
482  // quick path for non-cross-worker tasks (without dividing the distribution range)
483  new_drange = tls_->drange;
484  // Since this task may have been stolen by workers outside of this task group,
485  // the target rank should be itself.
486  target_rank = my_rank;
487  }
488 
489  if (target_rank == my_rank) {
490  /* Put the continuation into the local queue and execute the new task (work-first) */
491 
492  suspend([&, ts, fn = std::forward<Fn>(fn),
493  args_tuple = std::make_tuple(std::forward<Args>(args)...)](context_frame* cf) mutable {
494  common::verbose<3>("push context frame [%p, %p) into task queue", cf, cf->parent_frame);
495 
496  tls_ = new (alloca(sizeof(thread_local_storage)))
497  thread_local_storage{nullptr, new_drange, tls_->dtree_node_ref,
498  tls_->tg_version, true, {}};
499 
500  std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
501 
502  if (use_primary_wsq_) {
503  primary_wsq_.push({nullptr, cf, cf_size, tls_->tg_version},
504  tls_->dtree_node_ref.depth);
505  } else {
506  migration_wsq_.push({true, nullptr, cf, cf_size, tls_->tg_version},
507  tls_->dtree_node_ref.depth);
508  }
509 
510  tls_->dag_prof.start();
511  tls_->dag_prof.increment_thread_count();
512  tls_->dag_prof.increment_strand_count();
513 
514  common::verbose<3>("Starting new thread %p", ts);
515  common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
516 
517  T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
518 
519  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
520  common::verbose<3>("Thread %p is completed", ts);
521 
522  on_task_die();
523  on_die_workfirst(ts, std::move(ret), on_drift_die_cb);
524 
525  common::verbose<3>("Thread %p is serialized (fast path)", ts);
526 
527  // The following is executed only when the thread is serialized
528  std::destroy_at(ts);
529  thread_state_allocator_.deallocate(ts, sizeof(thread_state<T>));
530  th.state = nullptr;
531  th.serialized = true;
532  th.retval_ser = {std::move(ret), tls_->dag_prof};
533 
534  common::verbose<3>("Resume parent context frame [%p, %p) (fast path)", cf, cf->parent_frame);
535 
536  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_popped>();
537  });
538 
539  // reload my_rank because this thread might have been migrated
540  if (target_rank == common::topology::my_rank()) {
541  common::profiler::switch_phase<prof_phase_sched_resume_popped, prof_phase_thread>();
542  } else {
543  call_with_prof_events<prof_phase_sched_resume_stolen,
544  prof_phase_cb_drift_fork,
545  prof_phase_thread>(on_drift_fork_cb);
546  }
547 
548  } else {
549  /* Pass the new task to another worker and execute the continuation */
550 
551  auto new_task_fn = [&, my_rank, ts, new_drange,
552  dtree_node_ref = tls_->dtree_node_ref,
553  tg_version = tls_->tg_version,
554  on_drift_fork_cb, on_drift_die_cb, fn = std::forward<Fn>(fn),
555  args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
556  common::verbose("Starting a migrated thread %p [%f, %f)",
557  ts, new_drange.begin(), new_drange.end());
558 
559  tls_ = new (alloca(sizeof(thread_local_storage)))
560  thread_local_storage{nullptr, new_drange, dtree_node_ref,
561  tg_version, true, {}};
562 
563  if (new_drange.is_cross_worker()) {
564  dtree_.copy_parents(dtree_node_ref);
565  dtree_local_bottom_ref_ = dtree_node_ref;
566  }
567 
568  tls_->dag_prof.start();
569  tls_->dag_prof.increment_thread_count();
570  tls_->dag_prof.increment_strand_count();
571 
572  // If the new task is executed on another process
573  if (common::topology::my_rank() != my_rank) {
574  call_with_prof_events<prof_phase_sched_start_new,
575  prof_phase_cb_drift_fork,
576  prof_phase_thread>(on_drift_fork_cb);
577  } else {
578  common::profiler::switch_phase<prof_phase_sched_start_new, prof_phase_thread>();
579  }
580 
581  T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
582 
583  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
584  common::verbose("A migrated thread %p [%f, %f) is completed",
585  ts, new_drange.begin(), new_drange.end());
586 
587  on_task_die();
588  on_die_drifted(ts, std::move(ret), on_drift_die_cb);
589  };
590 
591  using callable_task_t = callable_task<decltype(new_task_fn)>;
592 
593  size_t task_size = sizeof(callable_task_t);
594  void* task_ptr = suspended_thread_allocator_.allocate(task_size);
595 
596  auto t = new (task_ptr) callable_task_t(std::move(new_task_fn));
597 
598  if (new_drange.is_cross_worker()) {
599  common::verbose("Migrate cross-worker-task %p [%f, %f) to process %d",
600  ts, new_drange.begin(), new_drange.end(), target_rank);
601 
602  migration_mailbox_.put({nullptr, t, task_size}, target_rank);
603  } else {
604  common::verbose("Migrate non-cross-worker-task %p [%f, %f) to process %d",
605  ts, new_drange.begin(), new_drange.end(), target_rank);
606 
607  migration_wsq_.pass({false, nullptr, t, task_size, tls_->tg_version},
608  target_rank, tls_->dtree_node_ref.depth);
609  }
610 
611  common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
612  }
613 
614  // restart to count only the last task in the task group
615  tls_->dag_prof.clear();
616  tls_->dag_prof.start();
617  tls_->dag_prof.increment_strand_count();
618  }
619 
620  template <typename T>
621  T join(thread_handler<T>& th) {
622  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_join>();
623 
624  thread_retval<T> retval;
625  if (th.serialized) {
626  common::verbose<3>("Skip join for serialized thread (fast path)");
627  // We can skip deallocaton for its thread state because it has been already deallocated
628  // when the thread is serialized (i.e., at a fork)
629  retval = std::move(th.retval_ser);
630 
631  } else {
632  // Note that this point is also considered the end of the last task of a task group
633  // (the last task of a task group may not be spawned as a thread)
634  on_task_die();
635 
636  ITYR_CHECK(th.state != nullptr);
637  thread_state<T>* ts = th.state;
638 
639  if (remote_get_value(thread_state_allocator_, &ts->resume_flag) >= 1) {
640  common::verbose("Thread %p is already joined", ts);
641  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
642  retval = get_retval_remote(ts);
643  }
644 
645  } else {
646  bool migrated = true;
647  suspend([&](context_frame* cf) {
648  suspended_state ss = evacuate(cf);
649 
650  remote_put_value(thread_state_allocator_, ss, &ts->suspended);
651 
652  // race
653  if (remote_faa_value(thread_state_allocator_, 1, &ts->resume_flag) == 0) {
654  common::verbose("Win the join race for thread %p (joining thread)", ts);
655  evacuate_all();
656  common::profiler::switch_phase<prof_phase_sched_join, prof_phase_sched_loop>();
657  resume_sched();
658 
659  } else {
660  common::verbose("Lose the join race for thread %p (joining thread)", ts);
661  suspended_thread_allocator_.deallocate(ss.evacuation_ptr, ss.frame_size);
662  migrated = false;
663  }
664  });
665 
666  common::verbose("Resume continuation of join for thread %p", ts);
667 
668  if (migrated) {
669  common::profiler::switch_phase<prof_phase_sched_resume_join, prof_phase_sched_join>();
670  }
671 
672  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
673  retval = get_retval_remote(ts);
674  }
675  }
676 
677  // TODO: correctly destroy T remotely if nontrivially destructible
678  /* std::destroy_at(ts); */
679 
680  thread_state_allocator_.deallocate(ts, sizeof(thread_state<T>));
681  th.state = nullptr;
682  }
683 
684  ITYR_CHECK(tls_->tgdata);
685  tls_->tgdata->dag_prof_acc.merge_parallel(retval.dag_prof);
686 
687  common::profiler::switch_phase<prof_phase_sched_join, prof_phase_thread>();
688  return std::move(retval.value);
689  }
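 // Illustrative fork/join pair (not from the original source; `sched` is an assumed
 // instance, and the call is assumed to run inside a thread spawned by root_exec()
 // with an open task group). w_new/w_rest are the work hints fed to
 // dist_range::divide() above:
 //
 //   scheduler_adws::thread_handler<int> th;
 //   sched.fork<int>(th, nullptr, nullptr, /*w_new=*/1, /*w_rest=*/1, [] { return 6 * 7; });
 //   // ... the parent continuation runs here (possibly on another worker) ...
 //   int result = sched.join(th);  // result == 42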
690 
691  template <typename SchedLoopCallback>
692  void sched_loop(SchedLoopCallback cb) {
693  common::verbose("Enter scheduling loop");
694 
695  while (!should_exit_sched_loop()) {
696  auto mte = migration_mailbox_.pop();
697  if (mte.has_value()) {
698  execute_migrated_task(*mte);
699  continue;
700  }
701 
702  auto pwe = pop_from_primary_queues(primary_wsq_.n_queues() - 1);
703  if (pwe.has_value()) {
704  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_popped>();
705 
706  // No on-stack thread can exist while the scheduler thread is running
707  ITYR_CHECK(pwe->evacuation_ptr);
708  suspend([&](context_frame* cf) {
709  sched_cf_ = cf;
710  resume(suspended_state{pwe->evacuation_ptr, pwe->frame_base, pwe->frame_size});
711  });
712  continue;
713  }
714 
715  auto mwe = pop_from_migration_queues();
716  if (mwe.has_value()) {
717  use_primary_wsq_ = false;
718  execute_migrated_task(*mwe);
719  use_primary_wsq_ = true;
720  continue;
721  }
722 
724  steal();
725  }
726 
727  if constexpr (!std::is_null_pointer_v<std::remove_reference_t<SchedLoopCallback>>) {
728  cb();
729  }
730  }
731 
732  dtree_local_bottom_ref_ = {};
733 
734  common::verbose("Exit scheduling loop");
735  }
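 // Note (summary derived from the loop above): each iteration services work in this
 // order: (1) cross-worker tasks arriving in migration_mailbox_, (2) continuations in
 // the local primary queues (deepest first), (3) entries in the local migration
 // queues, and only then (4) work stealing from other workers. The loop runs until
 // should_exit_sched_loop() observes an exit request.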
736 
737  template <typename PreSuspendCallback, typename PostSuspendCallback>
738  void poll(PreSuspendCallback&& pre_suspend_cb,
739  PostSuspendCallback&& post_suspend_cb) {
740  check_cross_worker_task_arrival<prof_phase_thread, prof_phase_thread>(
741  std::forward<PreSuspendCallback>(pre_suspend_cb),
742  std::forward<PostSuspendCallback>(post_suspend_cb));
743  }
744 
745  template <typename PreSuspendCallback, typename PostSuspendCallback>
746  void migrate_to(common::topology::rank_t target_rank,
747  PreSuspendCallback&& pre_suspend_cb,
748  PostSuspendCallback&& post_suspend_cb) {
749  if (target_rank == common::topology::my_rank()) return;
750 
751  auto cb_ret = call_with_prof_events<prof_phase_thread,
752  prof_phase_cb_pre_suspend,
753  prof_phase_sched_migrate>(
754  std::forward<PreSuspendCallback>(pre_suspend_cb));
755 
756  suspend([&](context_frame* cf) {
757  suspended_state ss = evacuate(cf);
758 
759  common::verbose("Migrate continuation of cross-worker-task [%f, %f) to process %d",
760  tls_->drange.begin(), tls_->drange.end(), target_rank);
761 
762  migration_mailbox_.put(ss, target_rank);
763 
764  evacuate_all();
765  common::profiler::switch_phase<prof_phase_sched_migrate, prof_phase_sched_loop>();
766  resume_sched();
767  });
768 
769  call_with_prof_events<prof_phase_sched_resume_migrate,
770  prof_phase_cb_post_suspend,
771  prof_phase_thread>(
772  std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
773  }
774 
775  template <typename Fn>
776  void coll_exec(const Fn& fn) {
777  common::profiler::switch_phase<prof_phase_thread, prof_phase_spmd>();
778 
779  tls_->dag_prof.stop();
780  // TODO: consider dag prof for inside coll tasks
781 
782  using callable_task_t = callable_task<Fn>;
783 
784  size_t task_size = sizeof(callable_task_t);
785  void* task_ptr = suspended_thread_allocator_.allocate(task_size);
786 
787  auto t = new (task_ptr) callable_task_t(fn);
788 
789  coll_task ct {task_ptr, task_size, common::topology::my_rank()};
790  execute_coll_task(t, ct);
791 
792  suspended_thread_allocator_.deallocate(t, task_size);
793 
794  tls_->dag_prof.start();
795  tls_->dag_prof.increment_strand_count();
796 
797  common::profiler::switch_phase<prof_phase_spmd, prof_phase_thread>();
798  }
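 // Illustrative call (not from the original source; `sched` is an assumed instance):
 // the calling thread runs `fn` directly, while every other process picks it up in
 // execute_coll_task_if_arrived() from its scheduler loop.
 //
 //   sched.coll_exec([] { /* runs once on every process */ });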
799 
800  bool is_executing_root() const {
801  return cf_top_ && cf_top_ == stack_base_;
802  }
803 
804  template <typename T>
805  static bool is_serialized(const thread_handler<T>& th) {
806  return th.serialized;
807  }
808 
809  void dag_prof_begin() {
810  dag_prof_enabled_ = true;
811  dag_prof_result_.clear();
812  if (tls_) {
813  // nested root/coll_exec()
814  tls_->dag_prof.clear();
815  tls_->dag_prof.increment_thread_count();
816  }
817  }
818 
819  void dag_prof_end() {
820  dag_prof_enabled_ = false;
821  if constexpr (dag_profiler::enabled) {
822  common::topology::rank_t result_owner = 0;
823  if (tls_) {
824  // nested root/coll_exec()
825  dag_prof_result_ = tls_->dag_prof;
826  result_owner = common::topology::my_rank();
827  }
828  result_owner = common::mpi_allreduce_value(result_owner, common::topology::mpicomm(), MPI_MAX);
829  dag_prof_result_ = common::mpi_bcast_value(dag_prof_result_, result_owner, common::topology::mpicomm());
830  }
831  }
832 
833  void dag_prof_print() const {
834  if (common::topology::my_rank() == 0) {
835  dag_prof_result_.print();
836  }
837  }
838 
839 private:
840  struct coll_task {
841  void* task_ptr;
842  std::size_t task_size;
843  common::topology::rank_t master_rank;
844  };
845 
846  struct primary_wsq_entry {
847  void* evacuation_ptr;
848  void* frame_base;
849  std::size_t frame_size;
850  flipper tg_version;
851  };
852 
853  struct migration_wsq_entry {
854  bool is_continuation;
855  void* evacuation_ptr;
856  void* frame_base;
857  std::size_t frame_size;
858  flipper tg_version;
859  };
860 
861  void on_task_die() {
862  if (!tls_->dag_prof.is_stopped()) {
863  tls_->dag_prof.stop();
864  if (tls_->tgdata) {
865  tls_->tgdata->dag_prof_acc.merge_parallel(tls_->dag_prof);
866  }
867  }
868 
869  // TODO: handle corner cases where cross-worker tasks finish without distributing
870  // child cross-worker tasks to their owners
871  if (tls_->drange.is_cross_worker()) {
872  // Set the parent cross-worker task group as "dominant" task group, which allows for
873  // work stealing within the range of workers within the task group.
874  common::verbose("Distribution tree node (owner=%d, depth=%d) becomes dominant",
875  tls_->dtree_node_ref.owner_rank, tls_->dtree_node_ref.depth);
876 
877  dtree_.set_dominant(tls_->dtree_node_ref, true);
878 
879  if (tls_->undistributed &&
880  tls_->drange.begin_rank() + 1 < tls_->drange.end_rank()) {
881  std::vector<std::pair<suspended_state, common::topology::rank_t>> tasks;
882 
883  // If a cross-worker task with range [i.xxx, j.xxx) is completed without distributing
884  // child cross-worker tasks to workers i+1, i+2, ..., j-1, it should pass the dist node
885  // tree reference to them, so that they can perform work stealing.
886  for (common::topology::rank_t target_rank = tls_->drange.begin_rank() + 1;
887  target_rank < tls_->drange.end_rank();
888  target_rank++) {
889  // Create a dummy task to set the parent dtree nodes
890  // TODO: we can reduce communication as only dtree_node_ref needs to be passed
891  auto new_task_fn = [&, dtree_node_ref = tls_->dtree_node_ref]() {
892  dtree_.copy_parents(dtree_node_ref);
893  dtree_local_bottom_ref_ = dtree_node_ref;
894 
895  common::profiler::switch_phase<prof_phase_sched_start_new, prof_phase_sched_loop>();
896  resume_sched();
897  };
898 
899  using callable_task_t = callable_task<decltype(new_task_fn)>;
900 
901  size_t task_size = sizeof(callable_task_t);
902  void* task_ptr = suspended_thread_allocator_.allocate(task_size);
903 
904  auto t = new (task_ptr) callable_task_t(new_task_fn);
905  tasks.push_back({{nullptr, t, task_size}, target_rank});
906  }
907 
908  // allocate memory then put
909  for (auto [t, target_rank] : tasks) {
910  migration_mailbox_.put(t, target_rank);
911  }
912 
913  // Wait until all tasks are completed on remote workers
914  // TODO: barrier is a better solution to avoid network contention when many workers are involved
915  for (auto [t, target_rank] : tasks) {
916  while (!suspended_thread_allocator_.is_remotely_freed(t.frame_base));
917  }
918  }
919 
920  // Temporarily make this thread a non-cross-worker task, so that the thread does not enter
921  // this scope multiple times. When a task group has multiple child tasks, the entering thread
922  // makes multiple join calls, which causes this function to be called multiple times.
923  // Even if we discard the current dist range, the task group's dist range is anyway restored
924  // when the task group is completed after those join calls.
925  tls_->drange.make_non_cross_worker();
926  }
927  }
928 
929  template <typename T, typename OnDriftDieCallback>
930  void on_die_workfirst(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
931  if (use_primary_wsq_) {
932  auto qe = primary_wsq_.pop(tls_->dtree_node_ref.depth);
933  if (qe.has_value()) {
934  if (!qe->evacuation_ptr) {
935  // parent is popped
936  ITYR_CHECK(qe->frame_base == cf_top_);
937  return;
938  } else {
939  // If it might not be its parent, return it to the queue.
940  // This is a conservative approach because the popped task can be its evacuated parent
941  // (if qe->frame_base == cf_top_), but it is not guaranteed because multiple threads
942  // can have the same base frame address due to the uni-address scheme.
943  primary_wsq_.push(*qe, tls_->dtree_node_ref.depth);
944  }
945  }
946  } else {
947  auto qe = migration_wsq_.pop(tls_->dtree_node_ref.depth);
948  if (qe.has_value()) {
949  if (qe->is_continuation && !qe->evacuation_ptr) {
950  ITYR_CHECK(qe->frame_base == cf_top_);
951  return;
952  } else {
953  migration_wsq_.push(*qe, tls_->dtree_node_ref.depth);
954  }
955  }
956  }
957 
958  on_die_drifted(ts, std::move(ret), on_drift_die_cb);
959  }
960 
961  template <typename T, typename OnDriftDieCallback>
962  void on_die_drifted(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
963  if constexpr (!std::is_null_pointer_v<std::remove_reference_t<OnDriftDieCallback>>) {
964  call_with_prof_events<prof_phase_sched_die,
965  prof_phase_cb_drift_die,
966  prof_phase_sched_die>(on_drift_die_cb);
967  }
968 
969  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
970  put_retval_remote(ts, {std::move(ret), tls_->dag_prof});
971  }
972 
973  // race
974  if (remote_faa_value(thread_state_allocator_, 1, &ts->resume_flag) == 0) {
975  common::verbose("Win the join race for thread %p (joined thread)", ts);
976  // Ancestor threads can remain on the stack here because ADWS no longer follows the work-first policy.
977  // Threads that are in the middle of the call stack can be stolen because of the task depth management.
978  // Therefore, we conservatively evacuate them before switching to the scheduler here.
979  // Note that a fast path exists when the immediate parent thread is popped from the queue.
980  evacuate_all();
981  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
982  resume_sched();
983 
984  } else {
985  common::verbose("Lose the join race for thread %p (joined thread)", ts);
986  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_join>();
987  suspended_state ss = remote_get_value(thread_state_allocator_, &ts->suspended);
988  resume(ss);
989  }
990  }
991 
992  template <typename T>
993  void on_root_die(thread_state<T>* ts, T&& ret) {
994  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
995  put_retval_remote(ts, {std::move(ret), tls_->dag_prof});
996  }
997  remote_put_value(thread_state_allocator_, 1, &ts->resume_flag);
998 
999  exit_request_mailbox_.put(0);
1000 
1001  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
1002  resume_sched();
1003  }
1004 
1005  void steal() {
1006  auto ne = dtree_.get_topmost_dominant(dtree_local_bottom_ref_);
1007  if (!ne.has_value()) {
1008  common::verbose<2>("Dominant dist_tree node not found");
1009  return;
1010  }
1011  dist_range steal_range = ne->drange;
1012  flipper tg_version = ne->tg_version;
1013  int depth = ne->depth();
1014 
1015  common::verbose<2>("Dominant dist_tree node found: drange=[%f, %f), depth=%d",
1016  steal_range.begin(), steal_range.end(), depth);
1017 
1018  auto my_rank = common::topology::my_rank();
1019 
1020  auto begin_rank = steal_range.begin_rank();
1021  auto end_rank = steal_range.end_rank();
1022 
1023  if (steal_range.is_at_end_boundary()) {
1024  end_rank--;
1025  }
1026 
1027  if (begin_rank == end_rank) {
1028  return;
1029  }
1030 
1031  ITYR_CHECK((begin_rank <= my_rank || my_rank <= end_rank));
1032 
1033  common::verbose<2>("Start work stealing for dominant task group [%f, %f)",
1034  steal_range.begin(), steal_range.end());
1035 
1036  // reuse the dist tree information multiple times
1037  int max_reuse = std::max(1, adws_max_dtree_reuse_option::value());
1038  for (int i = 0; i < max_reuse; i++) {
1039  auto target_rank = get_random_rank(begin_rank, end_rank);
1040 
1041  common::verbose<2>("Target rank: %d", target_rank);
1042 
1043  if (target_rank != begin_rank) {
1044  bool success = steal_from_migration_queues(target_rank, depth, migration_wsq_.n_queues(),
1045  [=](migration_wsq_entry& mwe) { return mwe.tg_version.match(tg_version, depth); });
1046  if (success) {
1047  return;
1048  }
1049  }
1050 
1051  if (target_rank != end_rank || (target_rank == end_rank && steal_range.is_at_end_boundary())) {
1052  bool success = steal_from_primary_queues(target_rank, depth, primary_wsq_.n_queues(),
1053  [=](primary_wsq_entry& pwe) { return pwe.tg_version.match(tg_version, depth); });
1054  if (success) {
1055  return;
1056  }
1057  }
1058 
1059  // Periodic check for cross-worker task arrival
1060  auto mte = migration_mailbox_.pop();
1061  if (mte.has_value()) {
1062  execute_migrated_task(*mte);
1063  return;
1064  }
1065  }
1066  }
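 // Note (summary derived from steal() above): victims are restricted to the rank
 // range of the topmost dominant task group, the dist-tree lookup is reused for up to
 // adws_max_dtree_reuse_option::value() attempts, migration queues are probed before
 // primary queues on each attempt, and the migration mailbox is checked between
 // attempts for newly arriving cross-worker tasks.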
1067 
1068  template <typename StealCondFn>
1069  bool steal_from_primary_queues(common::topology::rank_t target_rank,
1070  int min_depth, int max_depth, StealCondFn&& steal_cond_fn) {
1071  bool steal_success = false;
1072 
1073  primary_wsq_.for_each_nonempty_queue(target_rank, min_depth, max_depth, false, [&](int d) {
1074  auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
1075 
1076  if (!primary_wsq_.lock().trylock(target_rank, d)) {
1077  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1078  return false;
1079  }
1080 
1081  auto pwe = primary_wsq_.steal_nolock(target_rank, d);
1082  if (!pwe.has_value()) {
1083  primary_wsq_.lock().unlock(target_rank, d);
1084  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1085  return false;
1086  }
1087 
1088  if (!steal_cond_fn(*pwe)) {
1089  primary_wsq_.abort_steal(target_rank, d);
1090  primary_wsq_.lock().unlock(target_rank, d);
1091  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1092  return false;
1093  }
1094 
1095  // TODO: commonize implementation for primary and migration queues
1096  if (pwe->evacuation_ptr) {
1097  // This task is an evacuated continuation
1098  common::verbose("Steal an evacuated context frame [%p, %p) from primary wsqueue (depth=%d) on rank %d",
1099  pwe->frame_base, reinterpret_cast<std::byte*>(pwe->frame_base) + pwe->frame_size,
1100  d, target_rank);
1101 
1102  primary_wsq_.lock().unlock(target_rank, d);
1103 
1104  common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1105 
1106  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1107 
1108  suspend([&](context_frame* cf) {
1109  sched_cf_ = cf;
1110  resume(suspended_state{pwe->evacuation_ptr, pwe->frame_base, pwe->frame_size});
1111  });
1112 
1113  } else {
1114  // This task is a context frame on the stack
1115  common::verbose("Steal context frame [%p, %p) from primary wsqueue (depth=%d) on rank %d",
1116  pwe->frame_base, reinterpret_cast<std::byte*>(pwe->frame_base) + pwe->frame_size,
1117  d, target_rank);
1118 
1119  stack_.direct_copy_from(pwe->frame_base, pwe->frame_size, target_rank);
1120 
1121  primary_wsq_.lock().unlock(target_rank, d);
1122 
1123  common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1124 
1125  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1126 
1127  context_frame* next_cf = reinterpret_cast<context_frame*>(pwe->frame_base);
1128  suspend([&](context_frame* cf) {
1129  sched_cf_ = cf;
1130  context::clear_parent_frame(next_cf);
1131  resume(next_cf);
1132  });
1133  }
1134 
1135  steal_success = true;
1136  return true;
1137  });
1138 
1139  if (!steal_success) {
1140  common::verbose<2>("Steal failed for primary queues on rank %d", target_rank);
1141  }
1142  return steal_success;
1143  }
1144 
1145  template <typename StealCondFn>
1146  bool steal_from_migration_queues(common::topology::rank_t target_rank,
1147  int min_depth, int max_depth, StealCondFn&& steal_cond_fn) {
1148  bool steal_success = false;
1149 
1150  migration_wsq_.for_each_nonempty_queue(target_rank, min_depth, max_depth, true, [&](int d) {
1151  auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
1152 
1153  if (!migration_wsq_.lock().trylock(target_rank, d)) {
1154  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1155  return false;
1156  }
1157 
1158  auto mwe = migration_wsq_.steal_nolock(target_rank, d);
1159  if (!mwe.has_value()) {
1160  migration_wsq_.lock().unlock(target_rank, d);
1161  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1162  return false;
1163  }
1164 
1165  if (!steal_cond_fn(*mwe)) {
1166  migration_wsq_.abort_steal(target_rank, d);
1167  migration_wsq_.lock().unlock(target_rank, d);
1168  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1169  return false;
1170  }
1171 
1172  if (!mwe->is_continuation) {
1173  // This task is a new task
1174  common::verbose("Steal a new task from migration wsqueue (depth=%d) on rank %d",
1175  d, target_rank);
1176 
1177  migration_wsq_.lock().unlock(target_rank, d);
1178 
1179  common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1180 
1181  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
1182 
1183  suspend([&](context_frame* cf) {
1184  sched_cf_ = cf;
1185  start_new_task(mwe->frame_base, mwe->frame_size);
1186  });
1187 
1188  } else if (mwe->evacuation_ptr) {
1189  // This task is an evacuated continuation
1190  common::verbose("Steal an evacuated context frame [%p, %p) from migration wsqueue (depth=%d) on rank %d",
1191  mwe->frame_base, reinterpret_cast<std::byte*>(mwe->frame_base) + mwe->frame_size,
1192  d, target_rank);
1193 
1194  migration_wsq_.lock().unlock(target_rank, d);
1195 
1196  common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1197 
1198  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1199 
1200  suspend([&](context_frame* cf) {
1201  sched_cf_ = cf;
1202  resume(suspended_state{mwe->evacuation_ptr, mwe->frame_base, mwe->frame_size});
1203  });
1204 
1205  } else {
1206  // This task is a continuation on the stack
1207  common::verbose("Steal a context frame [%p, %p) from migration wsqueue (depth=%d) on rank %d",
1208  mwe->frame_base, reinterpret_cast<std::byte*>(mwe->frame_base) + mwe->frame_size,
1209  d, target_rank);
1210 
1211  stack_.direct_copy_from(mwe->frame_base, mwe->frame_size, target_rank);
1212 
1213  migration_wsq_.lock().unlock(target_rank, d);
1214 
1215  common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1216 
1217  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1218 
1219  suspend([&](context_frame* cf) {
1220  sched_cf_ = cf;
1221  context_frame* next_cf = reinterpret_cast<context_frame*>(mwe->frame_base);
1222  resume(next_cf);
1223  });
1224  }
1225 
1226  steal_success = true;
1227  return true;
1228  });
1229 
1230  if (!steal_success) {
1231  common::verbose<2>("Steal failed for migration queues on rank %d", target_rank);
1232  }
1233  return steal_success;
1234  }
1235 
1236  template <typename Fn>
1237  void suspend(Fn&& fn) {
1238  context_frame* prev_cf_top = cf_top_;
1239  thread_local_storage* prev_tls = tls_;
1240 
1241  context::save_context_with_call(prev_cf_top,
1242  [](context_frame* cf, void* cf_top_p, void* fn_p) {
1243  context_frame*& cf_top = *reinterpret_cast<context_frame**>(cf_top_p);
1244  Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_p)); // copy closure to the new stack frame
1245  cf_top = cf;
1246  fn(cf);
1247  }, &cf_top_, &fn, prev_tls);
1248 
1249  cf_top_ = prev_cf_top;
1250  tls_ = prev_tls;
1251  }
1252 
1253  void resume(context_frame* cf) {
1254  common::verbose("Resume context frame [%p, %p) in the stack", cf, cf->parent_frame);
1255  context::resume(cf);
1256  }
1257 
1258  void resume(suspended_state ss) {
1259  common::verbose("Resume context frame [%p, %p) evacuated at %p",
1260  ss.frame_base, reinterpret_cast<std::byte*>(ss.frame_base) + ss.frame_size, ss.evacuation_ptr);
1261 
1262  // We pass the suspended thread states *by value* because the current local variables can be overwritten by the
1263  // new stack we will bring from remote nodes.
1264  context::jump_to_stack(ss.frame_base, [](void* this_, void* evacuation_ptr, void* frame_base, void* frame_size_) {
1265  scheduler_adws& this_sched = *reinterpret_cast<scheduler_adws*>(this_);
1266  std::size_t frame_size = reinterpret_cast<std::size_t>(frame_size_);
1267 
1268  common::remote_get(this_sched.suspended_thread_allocator_,
1269  reinterpret_cast<std::byte*>(frame_base),
1270  reinterpret_cast<std::byte*>(evacuation_ptr),
1271  frame_size);
1272  this_sched.suspended_thread_allocator_.deallocate(evacuation_ptr, frame_size);
1273 
1274  context_frame* cf = reinterpret_cast<context_frame*>(frame_base);
1275  /* context::clear_parent_frame(cf); */
1276  context::resume(cf);
1277  }, this, ss.evacuation_ptr, ss.frame_base, reinterpret_cast<void*>(ss.frame_size));
1278  }
1279 
1280  void resume_sched() {
1281  common::verbose("Resume scheduler context");
1282  context::resume(sched_cf_);
1283  }
1284 
1285  void start_new_task(void* task_ptr, std::size_t task_size) {
1286  root_on_stack([&]() {
1287  task_general* t = reinterpret_cast<task_general*>(alloca(task_size));
1288 
1289  common::remote_get(suspended_thread_allocator_,
1290  reinterpret_cast<std::byte*>(t),
1291  reinterpret_cast<std::byte*>(task_ptr),
1292  task_size);
1293  suspended_thread_allocator_.deallocate(task_ptr, task_size);
1294 
1295  t->execute();
1296  });
1297  }
1298 
1299  void execute_migrated_task(const suspended_state& ss) {
1300  if (ss.evacuation_ptr == nullptr) {
1301  // This task is a new task
1302  common::verbose("Received a new cross-worker task");
1303  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
1304 
1305  suspend([&](context_frame* cf) {
1306  sched_cf_ = cf;
1307  start_new_task(ss.frame_base, ss.frame_size);
1308  });
1309 
1310  } else {
1311  // This task is an evacuated continuation
1312  common::verbose("Received a continuation of a cross-worker task");
1313  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_migrate>();
1314 
1315  suspend([&](context_frame* cf) {
1316  sched_cf_ = cf;
1317  resume(ss);
1318  });
1319  }
1320  }
1321 
1322  void execute_migrated_task(const migration_wsq_entry& mwe) {
1323  if (!mwe.is_continuation) {
1324  // This task is a new task
1325  common::verbose("Popped a new task from local migration queues");
1326  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
1327 
1328  suspend([&](context_frame* cf) {
1329  sched_cf_ = cf;
1330  start_new_task(mwe.frame_base, mwe.frame_size);
1331  });
1332 
1333  } else if (mwe.evacuation_ptr) {
1334  // This task is an evacuated continuation
1335  common::verbose("Popped an evacuated continuation from local migration queues");
1336  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_popped>();
1337 
1338  suspend([&](context_frame* cf) {
1339  sched_cf_ = cf;
1340  resume(suspended_state{mwe.evacuation_ptr, mwe.frame_base, mwe.frame_size});
1341  });
1342 
1343  } else {
1344  // This task is a continuation on the stack
1345  common::die("On-stack threads cannot remain after switching to the scheduler. Something went wrong.");
1346  }
1347  }
1348 
1349  std::optional<primary_wsq_entry> pop_from_primary_queues(int depth_from) {
1350  // TODO: upper bound for depth can be tracked
1351  for (int d = depth_from; d >= 0; d--) {
1352  auto pwe = primary_wsq_.pop<false>(d);
1353  if (pwe.has_value()) {
1354  return pwe;
1355  }
1356  }
1357  return std::nullopt;
1358  }
1359 
1360  std::optional<migration_wsq_entry> pop_from_migration_queues() {
1361  for (int d = 0; d < migration_wsq_.n_queues(); d++) {
1362  auto mwe = migration_wsq_.pop<false>(d);
1363  if (mwe.has_value()) {
1364  return mwe;
1365  }
1366  }
1367  return std::nullopt;
1368  }
1369 
1370  suspended_state evacuate(context_frame* cf) {
1371  std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
1372  void* evacuation_ptr = suspended_thread_allocator_.allocate(cf_size);
1373  std::memcpy(evacuation_ptr, cf, cf_size);
1374 
1375  common::verbose("Evacuate suspended thread context [%p, %p) to %p",
1376  cf, cf->parent_frame, evacuation_ptr);
1377 
1378  return {evacuation_ptr, cf, cf_size};
1379  }
1380 
1381  void evacuate_all() {
1382  if (use_primary_wsq_) {
1383  for (int d = tls_->dtree_node_ref.depth; d >= 0; d--) {
1384  primary_wsq_.for_each_entry([&](primary_wsq_entry& pwe) {
1385  if (!pwe.evacuation_ptr) {
1386  context_frame* cf = reinterpret_cast<context_frame*>(pwe.frame_base);
1387  suspended_state ss = evacuate(cf);
1388  pwe = {ss.evacuation_ptr, ss.frame_base, ss.frame_size, pwe.tg_version};
1389  }
1390  }, d);
1391  }
1392  } else {
1393  migration_wsq_.for_each_entry([&](migration_wsq_entry& mwe) {
1394  if (mwe.is_continuation && !mwe.evacuation_ptr) {
1395  context_frame* cf = reinterpret_cast<context_frame*>(mwe.frame_base);
1396  suspended_state ss = evacuate(cf);
1397  mwe = {true, ss.evacuation_ptr, ss.frame_base, ss.frame_size, mwe.tg_version};
1398  }
1399  }, tls_->dtree_node_ref.depth);
1400  }
1401  }
1402 
1403  template <typename PhaseFrom, typename PhaseTo,
1404  typename PreSuspendCallback, typename PostSuspendCallback>
1405  bool check_cross_worker_task_arrival(PreSuspendCallback&& pre_suspend_cb,
1406  PostSuspendCallback&& post_suspend_cb) {
1407  if (migration_mailbox_.arrived()) {
1408  tls_->dag_prof.stop();
1409 
1410  auto cb_ret = call_with_prof_events<PhaseFrom,
1411  prof_phase_cb_pre_suspend,
1412  prof_phase_sched_evacuate>(
1413  std::forward<PreSuspendCallback>(pre_suspend_cb));
1414 
1416 
1417  evacuate_all();
1418 
1419  suspend([&](context_frame* cf) {
1420  suspended_state ss = evacuate(cf);
1421 
1422  if (use_primary_wsq_) {
1423  primary_wsq_.push({ss.evacuation_ptr, ss.frame_base, ss.frame_size, tls_->tg_version},
1424  tls_->dtree_node_ref.depth);
1425  } else {
1426  migration_wsq_.push({true, ss.evacuation_ptr, ss.frame_base, ss.frame_size, tls_->tg_version},
1427  tls_->dtree_node_ref.depth);
1428  }
1429 
1430  common::profiler::switch_phase<prof_phase_sched_evacuate, prof_phase_sched_loop>();
1431  resume_sched();
1432  });
1433 
1435  call_with_prof_events<prof_phase_sched_resume_popped,
1436  prof_phase_cb_post_suspend,
1437  PhaseTo>(
1438  std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
1439  } else {
1440  call_with_prof_events<prof_phase_sched_resume_stolen,
1441  prof_phase_cb_post_suspend,
1442  PhaseTo>(
1443  std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
1444  }
1445 
1446  tls_->dag_prof.start();
1447 
1448  return true;
1449 
1450  } else if constexpr (!std::is_same_v<PhaseTo, PhaseFrom>) {
1451  common::profiler::switch_phase<PhaseFrom, PhaseTo>();
1452  }
1453 
1454  return false;
1455  }
1456 
1457  template <typename Fn>
1458  void root_on_stack(Fn&& fn) {
1459  cf_top_ = stack_base_;
1460  std::size_t stack_size_bytes = reinterpret_cast<std::byte*>(stack_base_) -
1461  reinterpret_cast<std::byte*>(stack_.top());
1462  context::call_on_stack(stack_.top(), stack_size_bytes,
1463  [](void* fn_, void*, void*, void*) {
1464  Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_)); // copy closure to the new stack frame
1465  fn();
1466  }, &fn, nullptr, nullptr, nullptr);
1467  }
1468 
1469  void execute_coll_task(task_general* t, coll_task ct) {
1470  // TODO: consider copy semantics for tasks
1471  coll_task ct_ {t, ct.task_size, ct.master_rank};
1472 
1473  // pass coll task to other processes in a binary tree form
1474  auto my_rank = common::topology::my_rank();
1475  auto n_ranks = common::topology::n_ranks();
1476  auto my_rank_shifted = (my_rank + n_ranks - ct.master_rank) % n_ranks;
1477  for (common::topology::rank_t i = common::next_pow2(n_ranks); i > 1; i /= 2) {
1478  if (my_rank_shifted % i == 0) {
1479  auto target_rank_shifted = my_rank_shifted + i / 2;
1480  if (target_rank_shifted < n_ranks) {
1481  auto target_rank = (target_rank_shifted + ct.master_rank) % n_ranks;
1482  coll_task_mailbox_.put(ct_, target_rank);
1483  }
1484  }
1485  }
1486 
1487  auto prev_stack_base = stack_base_;
1488  if (my_rank == ct.master_rank) {
1489  // Allocate half the rest of the stack space for nested root/coll_exec()
1490  stack_base_ = cf_top_ - (cf_top_ - reinterpret_cast<context_frame*>(stack_.top())) / 2;
1491  }
1492 
1493  // Ensure all processes have finished coll task execution before deallocation.
1494  // In addition, collectively set the next stack base for nested root_exec() calls because
1495  // the stack frame of the scheduler of the master worker is in the RDMA-capable stack region.
1496  // TODO: check if the scheduler's stack frame and nested root_exec()'s stack frame do not overlap
1497  stack_base_ = common::mpi_bcast_value(stack_base_, ct.master_rank, common::topology::mpicomm());
1498 
1499  t->execute();
1500 
1501  stack_base_ = prev_stack_base;
1502 
1503  // Ensure all processes have finished coll task execution before deallocation
1504  common::mpi_barrier(common::topology::mpicomm());
1505  }
1506 
1507  void execute_coll_task_if_arrived() {
1508  auto ct = coll_task_mailbox_.pop();
1509  if (ct.has_value()) {
1510  task_general* t = reinterpret_cast<task_general*>(
1511  suspended_thread_allocator_.allocate(ct->task_size));
1512 
1513  common::remote_get(suspended_thread_allocator_,
1514  reinterpret_cast<std::byte*>(t),
1515  reinterpret_cast<std::byte*>(ct->task_ptr),
1516  ct->task_size);
1517 
1518  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
1519 
1520  execute_coll_task(t, *ct);
1521 
1522  common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
1523 
1524  suspended_thread_allocator_.deallocate(t, ct->task_size);
1525  }
1526  }
1527 
1528  bool should_exit_sched_loop() {
1529  if (sched_loop_make_mpi_progress_option::value()) {
1530  common::mpi_make_progress();
1531  }
1532 
1533  execute_coll_task_if_arrived();
1534 
1535  if (exit_request_mailbox_.pop()) {
1536  auto my_rank = common::topology::my_rank();
1537  auto n_ranks = common::topology::n_ranks();
1538  for (common::topology::rank_t i = common::next_pow2(n_ranks); i > 1; i /= 2) {
1539  if (my_rank % i == 0) {
1540  auto target_rank = my_rank + i / 2;
1541  if (target_rank < n_ranks) {
1542  exit_request_mailbox_.put(target_rank);
1543  }
1544  }
1545  }
1546  return true;
1547  }
1548 
1549  return false;
1550  }
1551 
1552  template <typename T>
1553  thread_retval<T> get_retval_remote(thread_state<T>* ts) {
1554  if constexpr (std::is_trivially_copyable_v<T>) {
1555  return remote_get_value(thread_state_allocator_, &ts->retval);
1556  } else {
1557  // TODO: Fix this ugly hack of avoiding object destruction by using checkout/checkin
1558  thread_retval<T> retval;
1559  remote_get(thread_state_allocator_, reinterpret_cast<std::byte*>(&retval), reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
1560  return retval;
1561  }
1562  }
1563 
1564  template <typename T>
1565  void put_retval_remote(thread_state<T>* ts, thread_retval<T>&& retval) {
1566  if constexpr (std::is_trivially_copyable_v<T>) {
1567  remote_put_value(thread_state_allocator_, retval, &ts->retval);
1568  } else {
1569  // TODO: Fix this ugly hack of avoiding object destruction by using checkout/checkin
1570  std::byte* retvalp = reinterpret_cast<std::byte*>(new (alloca(sizeof(thread_retval<T>))) thread_retval<T>{std::move(retval)});
1571  remote_put(thread_state_allocator_, retvalp, reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
1572  }
1573  }
1574 
1575  int max_depth_;
1576  callstack stack_;
1577  context_frame* stack_base_;
1578  oneslot_mailbox<void> exit_request_mailbox_;
1579  oneslot_mailbox<coll_task> coll_task_mailbox_;
1580  oneslot_mailbox<suspended_state> migration_mailbox_;
1581  wsqueue<primary_wsq_entry, false> primary_wsq_;
1582  wsqueue<migration_wsq_entry, true> migration_wsq_;
1583  common::remotable_resource thread_state_allocator_;
1584  common::remotable_resource suspended_thread_allocator_;
1585  context_frame* cf_top_ = nullptr;
1586  context_frame* sched_cf_ = nullptr;
1587  thread_local_storage* tls_ = nullptr;
1588  bool use_primary_wsq_ = true;
1589  dist_tree dtree_;
1590  dist_tree::node_ref dtree_local_bottom_ref_;
1591  bool dag_prof_enabled_ = false;
1592  dag_profiler dag_prof_result_;
1593 };
1594 
1595 }