   38     return (val_ & mask) == (f.value() & mask);
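
The masked comparison on line 38 checks only the low-order bits of two version words, so versions recorded for deeper levels do not disturb a match at a shallower depth. A minimal standalone sketch of the same idea (the helper name and mask width are illustrative assumptions, not the library's flipper API):

#include <cassert>
#include <cstdint>

// Compare two version words on their low (until + 1) bits only; higher bits
// are ignored. Hypothetical helper, assuming a mask of (1 << (until + 1)) - 1.
bool match_low_bits(uint64_t a, uint64_t b, int until) {
  uint64_t mask = (uint64_t{1} << (until + 1)) - 1;
  return (a & mask) == (b & mask);
}

int main() {
  assert(match_low_bits(0b1011, 0b0011, 2));   // low 3 bits agree
  assert(!match_low_bits(0b1011, 0b1001, 2));  // bit 1 differs
}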
 
   75   std::pair<dist_range, dist_range> divide(T r1, T r2) const {
 
   76     value_type at = begin_ + (end_ - begin_) * r1 / (r1 + r2);
 
   84       if (at < begin_) at = begin_;
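
Line 76 splits the distribution range [begin_, end_) in proportion to the weights r1 : r2, and line 84 clamps the split point so rounding cannot push it below the range. A small self-contained sketch of that arithmetic on plain doubles (not the dist_range class; the upper clamp here is an assumption):

#include <cassert>
#include <utility>

// Split [begin, end) at the point dividing it in ratio r1 : r2, guarding
// against floating-point drift at both boundaries. Illustrative helper only.
std::pair<std::pair<double, double>, std::pair<double, double>>
divide_range(double begin, double end, double r1, double r2) {
  double at = begin + (end - begin) * r1 / (r1 + r2);
  if (at < begin) at = begin;
  if (at > end)   at = end;
  return {{begin, at}, {at, end}};
}

int main() {
  auto [left, right] = divide_range(0.0, 4.0, 3.0, 1.0);
  assert(left.second == 3.0 && right.first == 3.0);  // 3:1 split of [0, 4)
}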
 
  112   using version_t = int;
 
  132     : max_depth_(max_depth),
 
  133       node_win_(common::topology::mpicomm(), max_depth_),

  134       dominant_flag_win_(common::topology::mpicomm(), max_depth_, 0),

  135       versions_(max_depth_, common::topology::my_rank() + 1) {}
 
  138     int depth = parent.depth + 1;
 
  146     node& new_node = local_node(depth);
 
  158     version_t value = (dominant ? 1 : -1) * local_node(nr.depth).version;
 
  160     local_dominant_flag(nr.depth).store(value, std::memory_order_relaxed);
 
  163       std::size_t disp_dominant = nr.depth * sizeof(version_t);
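
Lines 158-160 fold two pieces of information into a single atomic word: the magnitude is the node's version and the sign records whether the node is dominant, so one relaxed store publishes both. A sketch of that encoding with a plain atomic int (not the dist_tree class):

#include <atomic>
#include <cassert>

// +v means "dominant at version v", -v means "known non-dominant at
// version v", 0 means no information yet. Illustrative encoding only.
std::atomic<int> dominant_flag{0};

void set_dominant(int version, bool dominant) {
  dominant_flag.store((dominant ? 1 : -1) * version, std::memory_order_relaxed);
}

bool is_dominant(int version) {
  return dominant_flag.load(std::memory_order_relaxed) == version;
}

int main() {
  set_dominant(3, true);
  assert(is_dominant(3));
  set_dominant(3, false);   // stores -3
  assert(!is_dominant(3));
}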
 
  173     if (nr.depth < 0) return std::nullopt;

  177     for (int d = 0; d <= nr.depth; d++) {
 
  181       node& n = local_node(d);
 
  182       auto& dominant_flag = local_dominant_flag(d);
 
  188           dominant_flag.load(std::memory_order_relaxed) != -n.version) {
 
  194         if (target_rank != owner_rank &&
 
  195             dominant_flag.load(std::memory_order_relaxed) == n.version) {
 
  197           std::size_t disp_dominant = d * sizeof(version_t);

  199               target_rank, disp_dominant, dominant_flag_win_.win());
 
  201           if (dominant_val == -n.version) {
 
  202             dominant_flag.store(dominant_val, std::memory_order_relaxed);
 
  206           std::size_t disp_dominant = d * sizeof(version_t);
 
  207           version_t dominant_val = common::mpi_atomic_get_value<version_t>(
 
  208               target_rank, disp_dominant, dominant_flag_win_.win());
 
  211             dominant_flag.store(dominant_val, std::memory_order_relaxed);
 
  216       if (dominant_flag.load(std::memory_order_relaxed) == n.version) {
 
  226     for (int d = 0; d <= nr.depth; d++) {
 
  228       local_dominant_flag(d).store(0, std::memory_order_relaxed);
 
  235     return local_node(nr.depth);

  239   node& local_node(int depth) {
 
  242     return node_win_.local_buf()[depth];
 
  245   std::atomic<version_t>& local_dominant_flag(int depth) {

  248     return dominant_flag_win_.local_buf()[depth];
 
  252   common::mpi_win_manager<node>                   node_win_;
 
  253   common::mpi_win_manager<std::atomic<version_t>> dominant_flag_win_;
 
  254   std::vector<version_t>                          versions_;
 
  265   template <typename T>

  271   template <typename T>

  278   template <typename T>
 
  308       stack_base_(reinterpret_cast<context_frame*>(stack_.bottom()) - 1),
 
  313       dtree_(max_depth_) {}
 
  315   template <typename T, typename SchedLoopCallback, typename Fn, typename... Args>
 
  317     common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_fork>();
 
  321     auto prev_sched_cf = sched_cf_;
 
  323     suspend([&](context_frame* cf) {
 
  325       root_on_stack([&, ts, fn = std::forward<Fn>(fn),
 
  326                      args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {

  336         tls_->dag_prof.increment_thread_count();

  337         tls_->dag_prof.increment_strand_count();
 
  339         common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
 
  341         T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
 
  343         common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
 
  354     common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_join>();
 
  360     if (dag_prof_enabled_) {
 
  365         dag_prof_result_.merge_serial(retval.dag_prof);
 
  369     sched_cf_ = prev_sched_cf;
 
  371     common::profiler::switch_phase<prof_phase_sched_join, prof_phase_spmd>();
 
  395       common::verbose("Begin a cross-worker task group of distribution range [%f, %f) at depth %d",

  401     tls_->dag_prof.increment_strand_count();
 
  404   template <typename PreSuspendCallback, typename PostSuspendCallback>
 
  406                       PostSuspendCallback&& post_suspend_cb) {
 
  419       common::verbose("End a cross-worker task group of distribution range [%f, %f) at depth %d",
 
  424                  std::forward<PreSuspendCallback>(pre_suspend_cb),
 
  425                  std::forward<PostSuspendCallback>(post_suspend_cb));
 
  444     std::destroy_at(tgdata);
 
  447     tls_->dag_prof.increment_strand_count();
 
  450   template <typename T, typename OnDriftForkCallback, typename OnDriftDieCallback,

  451             typename WorkHint, typename Fn, typename... Args>
 
  453             OnDriftForkCallback on_drift_fork_cb, OnDriftDieCallback on_drift_die_cb,
 
  454             WorkHint w_new, WorkHint w_rest, Fn&& fn, Args&&... args) {
 
  455     common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_fork>();
 
  471       auto [dr_rest, dr_new] = tls_->drange.divide(w_rest, w_new);

  473       common::verbose("Distribution range [%f, %f) is divided into [%f, %f) and [%f, %f)",
 
  475                       dr_rest.begin(), dr_rest.end(), dr_new.begin(), dr_new.end());
 
  479       target_rank = dr_new.owner();

  483       new_drange = tls_->drange;
 
  492       suspend([&, ts, fn = std::forward<Fn>(fn),
 
  493                args_tuple = std::make_tuple(std::forward<Args>(args)...)](context_frame* cf) mutable {

  494         common::verbose<3>("push context frame [%p, %p) into task queue", cf, cf->parent_frame);

  500         std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
 
  502         if (use_primary_wsq_) {
 
  506           migration_wsq_.push({true, nullptr, cf, cf_size, tls_->tg_version},

  511         tls_->dag_prof.increment_thread_count();

  512         tls_->dag_prof.increment_strand_count();

  514         common::verbose<3>("Starting new thread %p", ts);
 
  515         common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
 
  517         T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
 
  519         common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
 
  520         common::verbose<3>("Thread %p is completed", ts);

  523         on_die_workfirst(ts, std::move(ret), on_drift_die_cb);

  525         common::verbose<3>("Thread %p is serialized (fast path)", ts);

  534         common::verbose<3>("Resume parent context frame [%p, %p) (fast path)", cf, cf->parent_frame);
 
  536         common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_popped>();
 
  541         common::profiler::switch_phase<prof_phase_sched_resume_popped, prof_phase_thread>();
 
  551       auto new_task_fn = [&, my_rank, ts, new_drange,
 
  554                           on_drift_fork_cb, on_drift_die_cb, fn = std::forward<Fn>(fn),
 
  555                           args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
 
  557                         ts, new_drange.begin(), new_drange.end());
 
  561                                     tg_version, true, {}};
 
  563         if (new_drange.is_cross_worker()) {
 
  565           dtree_local_bottom_ref_ = dtree_node_ref;
 
  569         tls_->dag_prof.increment_thread_count();

  570         tls_->dag_prof.increment_strand_count();
 
  578           common::profiler::switch_phase<prof_phase_sched_start_new, prof_phase_thread>();
 
  581         T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
 
  583         common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
 
  585                         ts, new_drange.begin(), new_drange.end());
 
  588         on_die_drifted(ts, std::move(ret), on_drift_die_cb);

  591       using callable_task_t = callable_task<decltype(new_task_fn)>;

  593       size_t task_size = sizeof(callable_task_t);
 
  594       void* task_ptr = suspended_thread_allocator_.allocate(task_size);
 
  596       auto t = new (task_ptr) callable_task_t(std::move(new_task_fn));
 
  598       if (new_drange.is_cross_worker()) {
 
  600                         ts, new_drange.begin(), new_drange.end(), target_rank);
 
  602         migration_mailbox_.put({nullptr, t, task_size}, target_rank);

  604         common::verbose("Migrate non-cross-worker-task %p [%f, %f) to process %d",
 
  605                         ts, new_drange.begin(), new_drange.end(), target_rank);
 
  607         migration_wsq_.pass({false, nullptr, t, task_size, tls_->tg_version},
 
  611       common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
 
  617     tls_->dag_prof.increment_strand_count();

  620   template <typename T>
 
  622     common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_join>();
 
  626       common::verbose<3>("Skip join for serialized thread (fast path)");
 
  642           retval = get_retval_remote(ts);
 
  646         bool migrated = true;
 
  647         suspend([&](context_frame* cf) {
 
  654             common::verbose("Win the join race for thread %p (joining thread)", ts);
 
  656             common::profiler::switch_phase<prof_phase_sched_join, prof_phase_sched_loop>();
 
  660             common::verbose("Lose the join race for thread %p (joining thread)", ts);
 
  669           common::profiler::switch_phase<prof_phase_sched_resume_join, prof_phase_sched_join>();
 
  673           retval = get_retval_remote(ts);
 
  687     common::profiler::switch_phase<prof_phase_sched_join, prof_phase_thread>();
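
The "win/lose the join race" messages around lines 654-660 (and 975-985 below) describe the joining parent and the dying child racing on a per-thread flag: roughly, whichever side arrives second learns that the other has already passed through and takes over responsibility for resuming the suspended continuation. A shared-memory sketch of that arrival protocol (std::thread instead of distributed workers, and a plain boolean instead of an actual context resume):

#include <atomic>
#include <cassert>
#include <thread>

struct join_state {
  std::atomic<int> arrivals{0};  // bumped once by each side
  bool resumed = false;          // set by whichever side arrives second
};

// Called exactly once by the joining parent and once by the dying child.
void arrive(join_state& js) {
  if (js.arrivals.fetch_add(1, std::memory_order_acq_rel) == 1) {
    js.resumed = true;  // the other side already arrived: we do the resume
  }
}

int main() {
  join_state js;
  std::thread child([&] { arrive(js); });  // the dying thread
  arrive(js);                              // the joining thread
  child.join();
  assert(js.resumed);  // exactly one side performed the resume step
}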
 
  691   template <typename SchedLoopCallback>
 
  695     while (!should_exit_sched_loop()) {
 
  696       auto mte = migration_mailbox_.pop();
 
  697       if (mte.has_value()) {
 
  698         execute_migrated_task(*mte);
 
  702       auto pwe = pop_from_primary_queues(primary_wsq_.n_queues() - 1);
 
  703       if (pwe.has_value()) {
 
  704         common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_popped>();
 
  708         suspend([&](context_frame* cf) {
 
  715       auto mwe = pop_from_migration_queues();
 
  716       if (mwe.has_value()) {
 
  717         use_primary_wsq_ = false;
 
  718         execute_migrated_task(*mwe);
 
  719         use_primary_wsq_ = true;
 
  727       if constexpr (!std::is_null_pointer_v<std::remove_reference_t<SchedLoopCallback>>) {
 
  732     dtree_local_bottom_ref_ = {};
 
  737   template <typename PreSuspendCallback, typename PostSuspendCallback>
 
  738   void poll(PreSuspendCallback&&  pre_suspend_cb,
 
  739             PostSuspendCallback&& post_suspend_cb) {
 
  740     check_cross_worker_task_arrival<prof_phase_thread, prof_phase_thread>(
 
  741         std::forward<PreSuspendCallback>(pre_suspend_cb),
 
  742         std::forward<PostSuspendCallback>(post_suspend_cb));
 
  745   template <typename PreSuspendCallback, typename PostSuspendCallback>
 
  747                   PreSuspendCallback&&     pre_suspend_cb,
 
  748                   PostSuspendCallback&&    post_suspend_cb) {
 
  754         std::forward<PreSuspendCallback>(pre_suspend_cb));
 
  756     suspend([&](context_frame* cf) {
 
  759       common::verbose("Migrate continuation of cross-worker-task [%f, %f) to process %d",
 
  762       migration_mailbox_.put(ss, target_rank);
 
  765       common::profiler::switch_phase<prof_phase_sched_migrate, prof_phase_sched_loop>();
 
  772         std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
 
  775   template <typename Fn>
 
  777     common::profiler::switch_phase<prof_phase_thread, prof_phase_spmd>();
 
  784     size_t task_size = sizeof(callable_task_t);
 
  785     void* task_ptr = suspended_thread_allocator_.allocate(task_size);
 
  787     auto t = new (task_ptr) callable_task_t(fn);
 
  790     execute_coll_task(t, ct);
 
  792     suspended_thread_allocator_.deallocate(t, task_size);
 
  795     tls_->dag_prof.increment_strand_count();
 
  797     common::profiler::switch_phase<prof_phase_spmd, prof_phase_thread>();
 
  801     return cf_top_ && cf_top_ == stack_base_;
 
  804   template <typename T>

  810     dag_prof_enabled_ = true;
 
  811     dag_prof_result_.clear();
 
  815       tls_->dag_prof.increment_thread_count();

  820     dag_prof_enabled_ = false;
 
  835       dag_prof_result_.print();
 
  842     std::size_t              task_size;
 
  846   struct primary_wsq_entry {
 
  847     void*       evacuation_ptr;
 
  849     std::size_t frame_size;
 
  853   struct migration_wsq_entry {
 
  854     bool        is_continuation;
 
  855     void*       evacuation_ptr;
 
  857     std::size_t frame_size;
 
  874       common::verbose("Distribution tree node (owner=%d, depth=%d) becomes dominant",
 
  881         std::vector<std::pair<suspended_state, common::topology::rank_t>> tasks;
 
  887              target_rank < tls_->drange.end_rank();
 
  893             dtree_local_bottom_ref_ = dtree_node_ref;
 
  895             common::profiler::switch_phase<prof_phase_sched_start_new, prof_phase_sched_loop>();
 
  899           using callable_task_t = callable_task<decltype(new_task_fn)>;
 
  901           size_t task_size = sizeof(callable_task_t);
 
  902           void* task_ptr = suspended_thread_allocator_.allocate(task_size);
 
  904           auto t = new (task_ptr) callable_task_t(new_task_fn);

  905           tasks.push_back({{nullptr, t, task_size}, target_rank});

  909         for (auto [t, target_rank] : tasks) {
 
  910           migration_mailbox_.put(t, target_rank);
 
  915         for (auto [t, target_rank] : tasks) {

  929   template <typename T, typename OnDriftDieCallback>
 
  930   void on_die_workfirst(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
 
  931     if (use_primary_wsq_) {
 
  933       if (qe.has_value()) {
 
  934         if (!qe->evacuation_ptr) {
 
  948       if (qe.has_value()) {
 
  949         if (qe->is_continuation && !qe->evacuation_ptr) {
 
  958     on_die_drifted(ts, std::move(ret), on_drift_die_cb);

  961   template <typename T, typename OnDriftDieCallback>
 
  962   void on_die_drifted(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
 
  963     if constexpr (!std::is_null_pointer_v<std::remove_reference_t<OnDriftDieCallback>>) {
 
  965                             prof_phase_cb_drift_die,
 
  966                             prof_phase_sched_die>(on_drift_die_cb);
 
  975       common::verbose("Win the join race for thread %p (joined thread)", ts);
 
  981       common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
 
  985       common::verbose("Lose the join race for thread %p (joined thread)", ts);
 
  986       common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_join>();
 
  987       suspended_state ss = remote_get_value(thread_state_allocator_, &ts->suspended);

  992   template <typename T>
 
  993   void on_root_die(thread_state<T>* ts, T&& ret) {
 
  999     exit_request_mailbox_.put(0);
 
 1001     common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
 
 1007     if (!ne.has_value()) {
 
 1008       common::verbose<2>("Dominant dist_tree node not found");
 
 1011     dist_range steal_range = ne->drange;
 
 1012     flipper    tg_version  = ne->tg_version;
 
 1013     int        depth       = ne->depth();
 
 1015     common::verbose<2>("Dominant dist_tree node found: drange=[%f, %f), depth=%d",
 
 1016                        steal_range.begin(), steal_range.end(), depth);
 
 1020     auto begin_rank = steal_range.begin_rank();
 
 1021     auto end_rank   = steal_range.end_rank();
 
 1023     if (steal_range.is_at_end_boundary()) {
 
 1027     if (begin_rank == end_rank) {
 
 1033     common::verbose<2>("Start work stealing for dominant task group [%f, %f)",
 
 1034                        steal_range.begin(), steal_range.end());
 
 1038     for (int i = 0; i < max_reuse; i++) {

 1041       common::verbose<2>("Target rank: %d", target_rank);
 
 1043       if (target_rank != begin_rank) {
 
 1044         bool success = steal_from_migration_queues(target_rank, depth, migration_wsq_.n_queues(),
 
 1045             [=](migration_wsq_entry& mwe) { return mwe.tg_version.match(tg_version, depth); });
 
 1051       if (target_rank != end_rank || (target_rank == end_rank && steal_range.is_at_end_boundary())) {
 
 1052         bool success = steal_from_primary_queues(target_rank, depth, primary_wsq_.n_queues(),
 
 1053             [=](primary_wsq_entry& pwe) { return pwe.tg_version.match(tg_version, depth); });
 
 1060       auto mte = migration_mailbox_.pop();
 
 1061       if (mte.has_value()) {
 
 1062         execute_migrated_task(*mte);
 
 1068   template <typename StealCondFn>

 1070                                  int min_depth, int max_depth, StealCondFn&& steal_cond_fn) {

 1071     bool steal_success = false;
 
 1074       auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
 
 1076       if (!primary_wsq_.lock().trylock(target_rank, d)) {

 1077         common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
 1082       if (!pwe.has_value()) {
 
 1084         common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
 1088       if (!steal_cond_fn(*pwe)) {
 
 1091         common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
 1096       if (pwe->evacuation_ptr) {
 
 1098         common::verbose("Steal an evacuated context frame [%p, %p) from primary wsqueue (depth=%d) on rank %d",

 1099                         pwe->frame_base, reinterpret_cast<std::byte*>(pwe->frame_base) + pwe->frame_size,

 1104         common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
 
 1106         common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
 
 1108         suspend([&](context_frame* cf) {
 
 1110           resume(suspended_state{pwe->evacuation_ptr, pwe->frame_base, pwe->frame_size});
 
 1115         common::verbose("Steal context frame [%p, %p) from primary wsqueue (depth=%d) on rank %d",

 1116                         pwe->frame_base, reinterpret_cast<std::byte*>(pwe->frame_base) + pwe->frame_size,

 1123         common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
 
 1125         common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
 
 1127         context_frame* next_cf = reinterpret_cast<context_frame*>(pwe->frame_base);
 
 1128         suspend([&](context_frame* cf) {
 
 1130           context::clear_parent_frame(next_cf);
 
 1135       steal_success = true;
 
 1139     if (!steal_success) {
 
 1140       common::verbose<2>("Steal failed for primary queues on rank %d", target_rank);
 
 1142     return steal_success;
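
The steal path above follows one pattern throughout: try-lock the victim's queue and give up immediately on contention, pop an entry, re-check a caller-supplied condition (here, the task-group version), and only then commit, recording the profiler interval as a success or failure. A shared-memory sketch of that control flow, with std::mutex and std::deque standing in for the RMA-locked remote work-stealing queue:

#include <deque>
#include <mutex>
#include <optional>

struct victim_queue {
  std::mutex mtx;
  std::deque<int> entries;  // stand-in for wsqueue entries
};

// One steal attempt: returns the stolen entry, or nullopt if the lock is
// contended, the queue is empty, or the entry fails the condition.
template <typename CondFn>
std::optional<int> try_steal(victim_queue& q, CondFn&& cond) {
  std::unique_lock lk(q.mtx, std::try_to_lock);
  if (!lk.owns_lock()) return std::nullopt;  // contended: abort immediately
  if (q.entries.empty()) return std::nullopt;
  int e = q.entries.front();                 // steal from the opposite end
  if (!cond(e)) return std::nullopt;         // e.g. task-group version mismatch
  q.entries.pop_front();
  return e;
}

int main() {
  victim_queue q;
  q.entries = {1, 2, 3};
  auto e = try_steal(q, [](int v) { return v > 0; });
  return e.has_value() ? 0 : 1;
}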
 
 1145   template <typename StealCondFn>

 1147                                    int min_depth, int max_depth, StealCondFn&& steal_cond_fn) {

 1148     bool steal_success = false;
 
 1151       auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
 
 1153       if (!migration_wsq_.lock().trylock(target_rank, d)) {

 1154         common::profiler::interval_end<prof_event_sched_steal>(ibd, false);

 1158       auto mwe = migration_wsq_.steal_nolock(target_rank, d);
 
 1159       if (!mwe.has_value()) {
 
 1160         migration_wsq_.lock().unlock(target_rank, d);

 1161         common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
 1165       if (!steal_cond_fn(*mwe)) {
 
 1167         migration_wsq_.lock().unlock(target_rank, d);

 1168         common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
 1172       if (!mwe->is_continuation) {
 
 1174         common::verbose("Steal a new task from migration wsqueue (depth=%d) on rank %d",

 1177         migration_wsq_.lock().unlock(target_rank, d);

 1179         common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
 
 1181         common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
 
 1183         suspend([&](context_frame* cf) {
 
 1185           start_new_task(mwe->frame_base, mwe->frame_size);
 
 1188       } else if (mwe->evacuation_ptr) {

 1190         common::verbose("Steal an evacuated context frame [%p, %p) from migration wsqueue (depth=%d) on rank %d",

 1191                         mwe->frame_base, reinterpret_cast<std::byte*>(mwe->frame_base) + mwe->frame_size,

 1194         migration_wsq_.lock().unlock(target_rank, d);

 1196         common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
 
 1198         common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
 
 1200         suspend([&](context_frame* cf) {
 
 1202           resume(suspended_state{mwe->evacuation_ptr, mwe->frame_base, mwe->frame_size});
 
 1207         common::verbose("Steal a context frame [%p, %p) from migration wsqueue (depth=%d) on rank %d",

 1208                         mwe->frame_base, reinterpret_cast<std::byte*>(mwe->frame_base) + mwe->frame_size,

 1213         migration_wsq_.lock().unlock(target_rank, d);

 1215         common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
 
 1217         common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
 
 1219         suspend([&](context_frame* cf) {
 
 1221           context_frame* next_cf = reinterpret_cast<context_frame*>(mwe->frame_base);

 1226       steal_success = true;
 
 1230     if (!steal_success) {
 
 1231       common::verbose<2>("Steal failed for migration queues on rank %d", target_rank);
 
 1233     return steal_success;
 
 1236   template <typename Fn>
 
 1237   void suspend(Fn&& fn) {
 
 1238     context_frame*        prev_cf_top = cf_top_;
 
 1239     thread_local_storage* prev_tls    = tls_;
 
 1241     context::save_context_with_call(prev_cf_top,
 
 1242         [](context_frame* cf, void* cf_top_p, void* fn_p) {

 1243       context_frame*& cf_top = *reinterpret_cast<context_frame**>(cf_top_p);

 1244       Fn              fn     = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_p));
 
 1247     }, &cf_top_, &fn, prev_tls);
 
 1249     cf_top_ = prev_cf_top;
 
 1253   void resume(context_frame* cf) {
 
 1254     common::verbose("Resume context frame [%p, %p) in the stack", cf, cf->parent_frame);
 
 1255     context::resume(cf);
 
 1258   void resume(suspended_state ss) {
 
 1260                     ss.frame_base, reinterpret_cast<std::byte*>(ss.frame_base) + ss.frame_size, ss.evacuation_ptr);

 1264     context::jump_to_stack(ss.frame_base, [](void* this_, void* evacuation_ptr, void* frame_base, void* frame_size_) {
 
 1265       scheduler_adws& this_sched = *reinterpret_cast<scheduler_adws*>(this_);
 
 1266       std::size_t     frame_size = reinterpret_cast<std::size_t>(frame_size_);
 
 1268       common::remote_get(this_sched.suspended_thread_allocator_,
 
 1269                          reinterpret_cast<std::byte*>(frame_base),
 
 1270                          reinterpret_cast<std::byte*>(evacuation_ptr),
 
 1272       this_sched.suspended_thread_allocator_.deallocate(evacuation_ptr, frame_size);
 
 1274       context_frame* cf = reinterpret_cast<context_frame*>(frame_base);
 
 1276       context::resume(cf);
 
 1277     }, this, ss.evacuation_ptr, ss.frame_base, reinterpret_cast<void*>(ss.frame_size));
 
 1280   void resume_sched() {
 
 1282     context::resume(sched_cf_);
 
 1285   void start_new_task(void* task_ptr, std::size_t task_size) {
 
 1286     root_on_stack([&]() {
 
 1287       task_general* t = reinterpret_cast<task_general*>(alloca(task_size));

 1290                          reinterpret_cast<std::byte*>(t),

 1291                          reinterpret_cast<std::byte*>(task_ptr),
 
 1293       suspended_thread_allocator_.deallocate(task_ptr, task_size);
 
 1299   void execute_migrated_task(const suspended_state& ss) {

 1300     if (ss.evacuation_ptr == nullptr) {
 
 1303       common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
 
 1305       suspend([&](context_frame* cf) {
 
 1307         start_new_task(ss.frame_base, ss.frame_size);
 
 1313       common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_migrate>();
 
 1315       suspend([&](context_frame* cf) {
 
 1322   void execute_migrated_task(const migration_wsq_entry& mwe) {
 
 1323     if (!mwe.is_continuation) {
 
 1326       common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
 
 1328       suspend([&](context_frame* cf) {
 
 1330         start_new_task(mwe.frame_base, mwe.frame_size);
 
 1333     } else if (mwe.evacuation_ptr) {

 1335       common::verbose("Popped an evacuated continuation from local migration queues");
 
 1336       common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_popped>();
 
 1338       suspend([&](context_frame* cf) {
 
 1340         resume(suspended_state{mwe.evacuation_ptr, mwe.frame_base, mwe.frame_size});
 
 1345       common::die("On-stack threads cannot remain after switching to the scheduler. Something went wrong.");

 1349   std::optional<primary_wsq_entry> pop_from_primary_queues(int depth_from) {

 1351     for (int d = depth_from; d >= 0; d--) {

 1352       auto pwe = primary_wsq_.pop<false>(d);
 
 1353       if (pwe.has_value()) {
 
 1357     return std::nullopt;
 
 1360   std::optional<migration_wsq_entry> pop_from_migration_queues() {
 
 1361     for (int d = 0; d < migration_wsq_.n_queues(); d++) {

 1362       auto mwe = migration_wsq_.pop<false>(d);
 
 1363       if (mwe.has_value()) {
 
 1367     return std::nullopt;
 
 1370   suspended_state evacuate(context_frame* cf) {
 
 1371     std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
 
 1372     void* evacuation_ptr = suspended_thread_allocator_.allocate(cf_size);
 
 1373     std::memcpy(evacuation_ptr, cf, cf_size);
 
 1376                     cf, cf->parent_frame, evacuation_ptr);
 
 1378     return {evacuation_ptr, cf, cf_size};
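
evacuate() measures the live frame as the byte span [cf, cf->parent_frame), copies it into separately allocated storage, and returns a descriptor from which the frame can later be restored. A minimal local sketch of that save/restore step using plain heap allocation (the real code uses a remotable allocator so the evacuated bytes can also be fetched by other processes):

#include <cstddef>
#include <cstdlib>
#include <cstring>

struct evacuated {
  void*       storage;     // heap copy of the frame
  void*       frame_base;  // original location on the stack
  std::size_t size;
};

// Copy the byte range [base, limit) out of the stack.
evacuated evacuate_frame(void* base, void* limit) {
  std::size_t size = static_cast<std::size_t>(
      static_cast<std::byte*>(limit) - static_cast<std::byte*>(base));
  void* storage = std::malloc(size);
  std::memcpy(storage, base, size);
  return {storage, base, size};
}

// Copy the bytes back to their original address before resuming.
void restore_frame(const evacuated& ev) {
  std::memcpy(ev.frame_base, ev.storage, ev.size);
  std::free(ev.storage);
}

int main() {
  unsigned char frame[64] = {42};  // stand-in for a context frame
  evacuated ev = evacuate_frame(frame, frame + sizeof(frame));
  frame[0] = 0;                    // the stack region gets reused...
  restore_frame(ev);               // ...and restored before resumption
  return frame[0] == 42 ? 0 : 1;
}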
 
 1381   void evacuate_all() {
 
 1382     if (use_primary_wsq_) {
 
 1385           if (!pwe.evacuation_ptr) {
 
 1386             context_frame* cf = reinterpret_cast<context_frame*>(pwe.frame_base);
 
 1387             suspended_state ss = evacuate(cf);
 
 1388             pwe = {ss.evacuation_ptr, ss.frame_base, ss.frame_size, pwe.tg_version};
 
 1394         if (mwe.is_continuation && !mwe.evacuation_ptr) {
 
 1395           context_frame* cf = reinterpret_cast<context_frame*>(mwe.frame_base);
 
 1396           suspended_state ss = evacuate(cf);
 
 1397           mwe = {true, ss.evacuation_ptr, ss.frame_base, ss.frame_size, mwe.tg_version};
 
 1399       }, tls_->dtree_node_ref.depth);
 
 1403   template <typename PhaseFrom, typename PhaseTo,

 1404             typename PreSuspendCallback, typename PostSuspendCallback>
 
 1405   bool check_cross_worker_task_arrival(PreSuspendCallback&&  pre_suspend_cb,
 
 1406                                        PostSuspendCallback&& post_suspend_cb) {
 
 1407     if (migration_mailbox_.arrived()) {
 
 1408       tls_->dag_prof.stop();
 
 1411                                           prof_phase_cb_pre_suspend,
 
 1412                                           prof_phase_sched_evacuate>(
 
 1413           std::forward<PreSuspendCallback>(pre_suspend_cb));
 
 1419       suspend([&](context_frame* cf) {
 
 1420         suspended_state ss = evacuate(cf);
 
 1422         if (use_primary_wsq_) {
 
 1423           primary_wsq_.push({ss.evacuation_ptr, ss.frame_base, ss.frame_size, tls_->tg_version},
 
 1424                             tls_->dtree_node_ref.depth);
 
 1426           migration_wsq_.push({true, ss.evacuation_ptr, ss.frame_base, ss.frame_size, tls_->tg_version},
 
 1427                               tls_->dtree_node_ref.depth);
 
 1430         common::profiler::switch_phase<prof_phase_sched_evacuate, prof_phase_sched_loop>();
 
 1436                               prof_phase_cb_post_suspend,
 
 1438             std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
 
 1441                               prof_phase_cb_post_suspend,
 
 1443             std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
 
 1446       tls_->dag_prof.start();
 
 1450     } else if constexpr (!std::is_same_v<PhaseTo, PhaseFrom>) {
 
 1451       common::profiler::switch_phase<PhaseFrom, PhaseTo>();
 
 1457   template <typename Fn>
 
 1458   void root_on_stack(Fn&& fn) {
 
 1459     cf_top_ = stack_base_;
 
 1460     std::size_t stack_size_bytes = reinterpret_cast<std::byte*>(stack_base_) -

 1461                                    reinterpret_cast<std::byte*>(stack_.top());
 
 1462     context::call_on_stack(stack_.top(), stack_size_bytes,
 
 1463                            [](void* fn_, void*, void*, void*) {
 
 1464       Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_)); 
 
 1466     }, &fn, nullptr, nullptr, nullptr);
 
 1469   void execute_coll_task(task_general* t, coll_task ct) {
 
 1471     coll_task ct_ {t, ct.task_size, ct.master_rank};
 
 1478       if (my_rank_shifted % i == 0) {
 
 1479         auto target_rank_shifted = my_rank_shifted + i / 2;
 
 1480         if (target_rank_shifted < n_ranks) {

 1481           auto target_rank = (target_rank_shifted + ct.master_rank) % n_ranks;
 
 1482           coll_task_mailbox_.put(ct_, target_rank);
 
 1487     auto prev_stack_base = stack_base_;
 
 1488     if (my_rank == ct.master_rank) {

 1490       stack_base_ = cf_top_ - (cf_top_ - reinterpret_cast<context_frame*>(stack_.top())) / 2;
 
 1501     stack_base_ = prev_stack_base;
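
execute_coll_task() spreads a collective task with a binomial-tree fan-out: ranks are shifted so the master becomes rank 0, and a rank whose shifted index is a multiple of the current stride i forwards the task to index + i/2 before the stride shrinks. A standalone sketch that simulates the forwarding pattern and checks every rank receives the task exactly once (the initial stride being the smallest power of two >= n_ranks is an assumption here):

#include <cassert>
#include <vector>

// Simulate a binomial-tree broadcast over n ranks rooted at rank 0.
std::vector<int> reach_counts(int n) {
  std::vector<int> recv(n, 0);
  recv[0] = 1;  // the root already holds the task
  int stride = 1;
  while (stride < n) stride *= 2;  // smallest power of two >= n (assumed)
  for (int i = stride; i > 1; i /= 2) {
    for (int rank = 0; rank < n; rank++) {
      if (recv[rank] && rank % i == 0 && rank + i / 2 < n) {
        recv[rank + i / 2] += 1;  // forward to the rank halfway along the stride
      }
    }
  }
  return recv;
}

int main() {
  for (int n = 1; n <= 16; n++) {
    for (int c : reach_counts(n)) assert(c == 1);  // each rank reached exactly once
  }
}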
 
 1507   void execute_coll_task_if_arrived() {
 
 1508     auto ct = coll_task_mailbox_.pop();
 
 1509     if (ct.has_value()) {
 
 1510       task_general* t = reinterpret_cast<task_general*>(
 
 1511           suspended_thread_allocator_.allocate(ct->task_size));
 
 1514                          reinterpret_cast<std::byte*>(t),

 1515                          reinterpret_cast<std::byte*>(ct->task_ptr),
 
 1518       common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
 
 1520       execute_coll_task(t, *ct);
 
 1522       common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
 
 1524       suspended_thread_allocator_.deallocate(t, ct->task_size);
 
 1528   bool should_exit_sched_loop() {
 
 1529     if (sched_loop_make_mpi_progress_option::value()) {
 
 1533     execute_coll_task_if_arrived();
 
 1535     if (exit_request_mailbox_.pop()) {
 
 1540           auto target_rank = my_rank + i / 2;
 
 1542             exit_request_mailbox_.put(target_rank);
 
 1552   template <typename T>
 
 1553   thread_retval<T> get_retval_remote(thread_state<T>* ts) {
 
 1554     if constexpr (std::is_trivially_copyable_v<T>) {
 
 1558       thread_retval<T> retval;
 
 1559       remote_get(thread_state_allocator_, reinterpret_cast<std::byte*>(&retval), reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));

 1564   template <typename T>
 
 1565   void put_retval_remote(thread_state<T>* ts, thread_retval<T>&& retval) {
 
 1566     if constexpr (std::is_trivially_copyable_v<T>) {
 
 1570       std::byte* retvalp = reinterpret_cast<std::byte*>(new (alloca(sizeof(thread_retval<T>))) thread_retval<T>{std::move(retval)});

 1571       remote_put(thread_state_allocator_, retvalp, reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
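
get_retval_remote and put_retval_remote special-case trivially copyable return types, which can cross process boundaries as raw bytes; other types take a different path not shown in this excerpt. A local sketch of the byte-copy case using the standard trait, with memcpy standing in for the allocator's remote_get/remote_put:

#include <cstring>
#include <type_traits>

template <typename T>
struct retval_box {  // hypothetical stand-in for thread_retval<T>
  T value;
};

// A raw byte copy is only valid for trivially copyable payloads, so enforce
// that at compile time instead of branching at run time.
template <typename T>
void copy_retval_bytes(retval_box<T>* dst, const retval_box<T>* src) {
  static_assert(std::is_trivially_copyable_v<retval_box<T>>,
                "non-trivially-copyable results need a serialization path");
  std::memcpy(dst, src, sizeof(retval_box<T>));
}

int main() {
  retval_box<int> a{42}, b{0};
  copy_retval_bytes(&b, &a);
  return b.value == 42 ? 0 : 1;
}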
 
 1577   context_frame*                     stack_base_;
 
 1578   oneslot_mailbox<void>              exit_request_mailbox_;
 
 1579   oneslot_mailbox<coll_task>         coll_task_mailbox_;
 
 1580   oneslot_mailbox<suspended_state>   migration_mailbox_;
 
 1581   wsqueue<primary_wsq_entry, false>  primary_wsq_;
 
 1582   wsqueue<migration_wsq_entry, true> migration_wsq_;
 
 1583   common::remotable_resource         thread_state_allocator_;
 
 1584   common::remotable_resource         suspended_thread_allocator_;
 
 1585   context_frame*                     cf_top_          = nullptr;

 1586   context_frame*                     sched_cf_        = nullptr;

 1587   thread_local_storage*              tls_             = nullptr;

 1588   bool                               use_primary_wsq_ = true;
 
 1590   dist_tree::node_ref                dtree_local_bottom_ref_;
 
 1591   bool                               dag_prof_enabled_ = false;
 