38 return (val_ & mask) == (f.value() & mask);
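// [Editor annotation, not part of the original source] flipper::match() appears to
// compare only the bits selected by mask, so two task-group versions are considered
// equal as long as they agree up to the requested depth; the mask construction itself
// is among the omitted lines.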
75 std::pair<dist_range, dist_range> divide(T r1, T r2) const {
76 value_type at = begin_ + (end_ - begin_) * r1 / (r1 + r2);
84 if (at < begin_) at = begin_;
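// [Editor annotation, not part of the original source] dist_range::divide() splits
// [begin_, end_) at a point proportional to r1 : r2 and clamps it at the lower end
// (line 84; the upper-end clamp is presumably among the omitted lines). For example,
// dividing [0.0, 4.0) with r1 = 1 and r2 = 3 gives at = 0 + 4 * 1 / 4 = 1.0, i.e. the
// pair {[0.0, 1.0), [1.0, 4.0)}, assuming the omitted return statement is
// {dist_range{begin_, at}, dist_range{at, end_}}.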
112 using version_t = int;
132 : max_depth_(max_depth),
133 node_win_(common::topology::mpicomm(), max_depth_),
134 dominant_flag_win_(common::topology::mpicomm(), max_depth_, 0),
135 versions_(max_depth_, common::topology::my_rank() + 1) {}
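// [Editor annotation, not part of the original source] The dist_tree constructor sets
// up two MPI RMA windows with one slot per tree depth (up to max_depth_): node_win_
// holds the distribution-tree nodes and dominant_flag_win_ holds the per-depth dominant
// flags (apparently zero-initialized by the extra constructor argument). versions_
// starts at my_rank() + 1 so that every rank begins with a nonzero version number.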
138 int depth = parent.depth + 1;
146 node& new_node = local_node(depth);
158 version_t value = (dominant ? 1 : -1) * local_node(nr.depth).version;
160 local_dominant_flag(nr.depth).store(value, std::memory_order_relaxed);
163 std::size_t disp_dominant = nr.depth * sizeof(version_t);
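// [Editor annotation, not part of the original source] set_dominant() encodes the flag
// as a signed version number: +version marks the node at nr.depth as dominant and
// -version marks it as non-dominant, which lets readers detect stale flags by comparing
// against the node's current version (see the checks against n.version and -n.version
// around lines 188-216).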
173 if (nr.depth < 0) return std::nullopt;
177 for (int d = 0; d <= nr.depth; d++) {
181 node& n = local_node(d);
182 auto& dominant_flag = local_dominant_flag(d);
188 dominant_flag.load(std::memory_order_relaxed) != -n.version) {
194 if (target_rank != owner_rank &&
195 dominant_flag.load(std::memory_order_relaxed) == n.version) {
197 std::size_t disp_dominant = d * sizeof(version_t);
199 target_rank, disp_dominant, dominant_flag_win_.win());
201 if (dominant_val == -n.version) {
202 dominant_flag.store(dominant_val, std::memory_order_relaxed);
206 std::size_t disp_dominant = d * sizeof(version_t);
207 version_t dominant_val = common::mpi_atomic_get_value<version_t>(
208 target_rank, disp_dominant, dominant_flag_win_.win());
211 dominant_flag.store(dominant_val, std::memory_order_relaxed);
216 if (dominant_flag.load(std::memory_order_relaxed) == n.version) {
226 for (int d = 0; d <= nr.depth; d++) {
228 local_dominant_flag(d).store(0, std::memory_order_relaxed);
235 return local_node(nr.depth);
239 node& local_node(int depth) {
242 return node_win_.local_buf()[depth];
245 std::atomic<version_t>& local_dominant_flag(int depth) {
248 return dominant_flag_win_.local_buf()[depth];
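// [Editor annotation, not part of the original source] local_node() and
// local_dominant_flag() index the local portion of the corresponding RMA window by
// depth, so each process keeps at most one distribution-tree node and one dominant flag
// per depth; remote accesses use the same depth-based displacement
// (disp_dominant = depth * sizeof(version_t) above).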
252 common::mpi_win_manager<node> node_win_;
253 common::mpi_win_manager<std::atomic<version_t>> dominant_flag_win_;
254 std::vector<version_t> versions_;
265 template <typename T>
271 template <typename T>
278 template <typename T>
308 stack_base_(reinterpret_cast<context_frame*>(stack_.bottom()) - 1),
313 dtree_(max_depth_) {}
315 template <typename T, typename SchedLoopCallback, typename Fn, typename... Args>
317 common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_fork>();
321 auto prev_sched_cf = sched_cf_;
323 suspend([&](context_frame* cf) {
325 root_on_stack([&, ts, fn = std::forward<Fn>(fn),
326 args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
336 tls_->dag_prof.increment_thread_count();
337 tls_->dag_prof.increment_strand_count();
339 common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
341 T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
343 common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
354 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_join>();
360 if (dag_prof_enabled_) {
365 dag_prof_result_.merge_serial(retval.dag_prof);
369 sched_cf_ = prev_sched_cf;
371 common::profiler::switch_phase<prof_phase_sched_join, prof_phase_spmd>();
395 common::verbose("Begin a cross-worker task group of distribution range [%f, %f) at depth %d",
401 tls_->dag_prof.increment_strand_count();
404 template <typename PreSuspendCallback, typename PostSuspendCallback>
406 PostSuspendCallback&& post_suspend_cb) {
419 common::verbose("End a cross-worker task group of distribution range [%f, %f) at depth %d",
424 std::forward<PreSuspendCallback>(pre_suspend_cb),
425 std::forward<PostSuspendCallback>(post_suspend_cb));
444 std::destroy_at(tgdata);
447 tls_->dag_prof.increment_strand_count();
450 template <typename T, typename OnDriftForkCallback, typename OnDriftDieCallback,
451 typename WorkHint, typename Fn, typename... Args>
453 OnDriftForkCallback on_drift_fork_cb, OnDriftDieCallback on_drift_die_cb,
454 WorkHint w_new, WorkHint w_rest, Fn&& fn, Args&&... args) {
455 common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_fork>();
471 auto [dr_rest, dr_new] = tls_->drange.divide(w_rest, w_new);
473 common::verbose("Distribution range [%f, %f) is divided into [%f, %f) and [%f, %f)",
475 dr_rest.begin(), dr_rest.end(), dr_new.begin(), dr_new.end());
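// [Editor annotation, not part of the original source] At fork time the current
// thread's distribution range is divided in proportion to the caller-supplied work
// hints, w_rest for the continuation and w_new for the forked child; dr_new then
// determines which rank should own the new task (dr_new.owner() below).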
479 target_rank = dr_new.owner();
483 new_drange = tls_->drange;
492 suspend([&, ts, fn = std::forward<Fn>(fn),
493 args_tuple = std::make_tuple(std::forward<Args>(args)...)](context_frame* cf) mutable {
494 common::verbose<3>("push context frame [%p, %p) into task queue", cf, cf->parent_frame);
500 std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
502 if (use_primary_wsq_) {
506 migration_wsq_.push({true, nullptr, cf, cf_size, tls_->tg_version},
511 tls_->dag_prof.increment_thread_count();
512 tls_->dag_prof.increment_strand_count();
514 common::verbose<3>("Starting new thread %p", ts);
515 common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
517 T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
519 common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
520 common::verbose<3>("Thread %p is completed", ts);
523 on_die_workfirst(ts, std::move(ret), on_drift_die_cb);
525 common::verbose<3>("Thread %p is serialized (fast path)", ts);
534 common::verbose<3>("Resume parent context frame [%p, %p) (fast path)", cf, cf->parent_frame);
536 common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_popped>();
541 common::profiler::switch_phase<prof_phase_sched_resume_popped, prof_phase_thread>();
551 auto new_task_fn = [&, my_rank, ts, new_drange,
554 on_drift_fork_cb, on_drift_die_cb, fn = std::forward<Fn>(fn),
555 args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
557 ts, new_drange.begin(), new_drange.end());
561 tg_version, true, {}};
563 if (new_drange.is_cross_worker()) {
565 dtree_local_bottom_ref_ = dtree_node_ref;
569 tls_->dag_prof.increment_thread_count();
570 tls_->dag_prof.increment_strand_count();
578 common::profiler::switch_phase<prof_phase_sched_start_new, prof_phase_thread>();
581 T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
583 common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
585 ts, new_drange.begin(), new_drange.end());
588 on_die_drifted(ts, std::move(ret), on_drift_die_cb);
591 using callable_task_t = callable_task<decltype(new_task_fn)>;
593 size_t task_size = sizeof(callable_task_t);
594 void* task_ptr = suspended_thread_allocator_.allocate(task_size);
596 auto t = new (task_ptr) callable_task_t(std::move(new_task_fn));
598 if (new_drange.is_cross_worker()) {
600 ts, new_drange.begin(), new_drange.end(), target_rank);
602 migration_mailbox_.put({nullptr, t, task_size}, target_rank);
604 common::verbose("Migrate non-cross-worker-task %p [%f, %f) to process %d",
605 ts, new_drange.begin(), new_drange.end(), target_rank);
607 migration_wsq_.pass({false, nullptr, t, task_size, tls_->tg_version},
611 common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
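// [Editor annotation, not part of the original source] In this slow path the parent
// keeps running locally while the newly created child task is handed off as a
// callable_task allocated from suspended_thread_allocator_: tasks whose distribution
// range is cross-worker go to the target rank's migration mailbox (line 602), and
// non-cross-worker tasks are passed into the migration work-stealing queue (line 607).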
617 tls_->dag_prof.increment_strand_count();
620 template <typename T>
622 common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_join>();
626 common::verbose<3>("Skip join for serialized thread (fast path)");
642 retval = get_retval_remote(ts);
646 bool migrated = true;
647 suspend([&](context_frame* cf) {
654 common::verbose("Win the join race for thread %p (joining thread)", ts);
656 common::profiler::switch_phase<prof_phase_sched_join, prof_phase_sched_loop>();
660 common::verbose("Lose the join race for thread %p (joining thread)", ts);
669 common::profiler::switch_phase<prof_phase_sched_resume_join, prof_phase_sched_join>();
673 retval = get_retval_remote(ts);
687 common::profiler::switch_phase<prof_phase_sched_join, prof_phase_thread>();
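// [Editor annotation, not part of the original source] join() takes a fast path for
// threads that were serialized into the parent; otherwise the joining thread and the
// dying thread appear to race (the "win/lose the join race" messages here and in
// on_die_drifted below). The side that arrives first suspends and hands control to the
// scheduler loop, and the later arrival either reads the finished thread's return value
// (get_retval_remote) or fetches and resumes the suspended counterpart (see line 987).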
691 template <typename SchedLoopCallback>
695 while (!should_exit_sched_loop()) {
696 auto mte = migration_mailbox_.pop();
697 if (mte.has_value()) {
698 execute_migrated_task(*mte);
702 auto pwe = pop_from_primary_queues(primary_wsq_.n_queues() - 1);
703 if (pwe.has_value()) {
704 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_popped>();
708 suspend([&](context_frame* cf) {
715 auto mwe = pop_from_migration_queues();
716 if (mwe.has_value()) {
717 use_primary_wsq_ = false;
718 execute_migrated_task(*mwe);
719 use_primary_wsq_ = true;
727 if constexpr (!std::is_null_pointer_v<std::remove_reference_t<SchedLoopCallback>>) {
732 dtree_local_bottom_ref_ = {};
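// [Editor annotation, not part of the original source] Each sched_loop() iteration
// appears to look for work in this order: a task arriving in the migration mailbox
// (line 696), the local primary work-stealing queues scanned from the deepest level
// (n_queues() - 1) downward (line 702), and finally the local migration queues
// (line 715); steal_from_primary_queues()/steal_from_migration_queues() below handle
// stealing from other ranks.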
737 template <typename PreSuspendCallback, typename PostSuspendCallback>
738 void poll(PreSuspendCallback&& pre_suspend_cb,
739 PostSuspendCallback&& post_suspend_cb) {
740 check_cross_worker_task_arrival<prof_phase_thread, prof_phase_thread>(
741 std::forward<PreSuspendCallback>(pre_suspend_cb),
742 std::forward<PostSuspendCallback>(post_suspend_cb));
745 template <typename PreSuspendCallback, typename PostSuspendCallback>
747 PreSuspendCallback&& pre_suspend_cb,
748 PostSuspendCallback&& post_suspend_cb) {
754 std::forward<PreSuspendCallback>(pre_suspend_cb));
756 suspend([&](context_frame* cf) {
759 common::verbose("Migrate continuation of cross-worker-task [%f, %f) to process %d",
762 migration_mailbox_.put(ss, target_rank);
765 common::profiler::switch_phase<prof_phase_sched_migrate, prof_phase_sched_loop>();
772 std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
775 template <typename Fn>
777 common::profiler::switch_phase<prof_phase_thread, prof_phase_spmd>();
784 size_t task_size = sizeof(callable_task_t);
785 void* task_ptr = suspended_thread_allocator_.allocate(task_size);
787 auto t = new (task_ptr) callable_task_t(fn);
790 execute_coll_task(t, ct);
792 suspended_thread_allocator_.deallocate(t, task_size);
795 tls_->dag_prof.increment_strand_count();
797 common::profiler::switch_phase<prof_phase_spmd, prof_phase_thread>();
801 return cf_top_ && cf_top_ == stack_base_;
804 template <typename T>
810 dag_prof_enabled_ = true;
811 dag_prof_result_.clear();
815 tls_->dag_prof.increment_thread_count();
820 dag_prof_enabled_ = false;
835 dag_prof_result_.print();
842 std::size_t task_size;
846 struct primary_wsq_entry {
847 void* evacuation_ptr;
849 std::size_t frame_size;
853 struct migration_wsq_entry {
854 bool is_continuation;
855 void* evacuation_ptr;
857 std::size_t frame_size;
874 common::verbose("Distribution tree node (owner=%d, depth=%d) becomes dominant",
881 std::vector<std::pair<suspended_state, common::topology::rank_t>> tasks;
887 target_rank < tls_->drange.end_rank();
893 dtree_local_bottom_ref_ = dtree_node_ref;
895 common::profiler::switch_phase<prof_phase_sched_start_new, prof_phase_sched_loop>();
899 using callable_task_t = callable_task<decltype(new_task_fn)>;
901 size_t task_size = sizeof(callable_task_t);
902 void* task_ptr = suspended_thread_allocator_.allocate(task_size);
904 auto t = new (task_ptr) callable_task_t(new_task_fn);
905 tasks.push_back({{nullptr, t, task_size}, target_rank});
909 for (auto [t, target_rank] : tasks) {
910 migration_mailbox_.put(t, target_rank);
915 for (auto [t, target_rank] : tasks) {
929 template <typename T, typename OnDriftDieCallback>
930 void on_die_workfirst(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
931 if (use_primary_wsq_) {
933 if (qe.has_value()) {
934 if (!qe->evacuation_ptr) {
948 if (qe.has_value()) {
949 if (qe->is_continuation && !qe->evacuation_ptr) {
958 on_die_drifted(ts, std::move(ret), on_drift_die_cb);
961 template <typename T, typename OnDriftDieCallback>
962 void on_die_drifted(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
963 if constexpr (!std::is_null_pointer_v<std::remove_reference_t<OnDriftDieCallback>>) {
965 prof_phase_cb_drift_die,
966 prof_phase_sched_die>(on_drift_die_cb);
975 common::verbose("Win the join race for thread %p (joined thread)", ts);
981 common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
985 common::verbose("Lose the join race for thread %p (joined thread)", ts);
986 common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_join>();
987 suspended_state ss = remote_get_value(thread_state_allocator_, &ts->suspended);
992 template <typename T>
993 void on_root_die(thread_state<T>* ts, T&& ret) {
999 exit_request_mailbox_.put(0);
1001 common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
1007 if (!ne.has_value()) {
1008 common::verbose<2>("Dominant dist_tree node not found");
1011 dist_range steal_range = ne->drange;
1012 flipper tg_version = ne->tg_version;
1013 int depth = ne->depth();
1015 common::verbose<2>("Dominant dist_tree node found: drange=[%f, %f), depth=%d",
1016 steal_range.begin(), steal_range.end(), depth);
1020 auto begin_rank = steal_range.begin_rank();
1021 auto end_rank = steal_range.end_rank();
1023 if (steal_range.is_at_end_boundary()) {
1027 if (begin_rank == end_rank) {
1033 common::verbose<2>("Start work stealing for dominant task group [%f, %f)",
1034 steal_range.begin(), steal_range.end());
1038 for (int i = 0; i < max_reuse; i++) {
1041 common::verbose<2>("Target rank: %d", target_rank);
1043 if (target_rank != begin_rank) {
1044 bool success = steal_from_migration_queues(target_rank, depth, migration_wsq_.n_queues(),
1045 [=](migration_wsq_entry& mwe) { return mwe.tg_version.match(tg_version, depth); });
1051 if (target_rank != end_rank || (target_rank == end_rank && steal_range.is_at_end_boundary())) {
1052 bool success = steal_from_primary_queues(target_rank, depth, primary_wsq_.n_queues(),
1053 [=](primary_wsq_entry& pwe) { return pwe.tg_version.match(tg_version, depth); });
1060 auto mte = migration_mailbox_.pop();
1061 if (mte.has_value()) {
1062 execute_migrated_task(*mte);
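// [Editor annotation, not part of the original source] Steal targets are confined to
// the dominant task group: victim ranks are drawn from the range covered by
// steal_range ([begin_rank, end_rank]), and a queue entry is accepted only if its
// task-group version matches tg_version up to the dominant depth (the match() lambdas
// at lines 1045 and 1053), presumably so entries belonging to unrelated task groups
// are never stolen.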
1068 template <typename StealCondFn>
1070 int min_depth, int max_depth, StealCondFn&& steal_cond_fn) {
1071 bool steal_success = false;
1074 auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
1076 if (!primary_wsq_.lock().trylock(target_rank, d)) {
1077 common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1082 if (!pwe.has_value()) {
1084 common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1088 if (!steal_cond_fn(*pwe)) {
1091 common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1096 if (pwe->evacuation_ptr) {
1098 common::verbose("Steal an evacuated context frame [%p, %p) from primary wsqueue (depth=%d) on rank %d",
1099 pwe->frame_base, reinterpret_cast<std::byte*>(pwe->frame_base) + pwe->frame_size,
1104 common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1106 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1108 suspend([&](context_frame* cf) {
1110 resume(suspended_state{pwe->evacuation_ptr, pwe->frame_base, pwe->frame_size});
1115 common::verbose("Steal context frame [%p, %p) from primary wsqueue (depth=%d) on rank %d",
1116 pwe->frame_base, reinterpret_cast<std::byte*>(pwe->frame_base) + pwe->frame_size,
1123 common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1125 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1127 context_frame* next_cf = reinterpret_cast<context_frame*>(pwe->frame_base);
1128 suspend([&](context_frame* cf) {
1130 context::clear_parent_frame(next_cf);
1135 steal_success = true;
1139 if (!steal_success) {
1140 common::verbose<2>("Steal failed for primary queues on rank %d", target_rank);
1142 return steal_success;
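// [Editor annotation, not part of the original source] A stolen primary-queue entry is
// handled in one of two ways: if it carries an evacuation_ptr, the evacuated frame is
// resumed through a suspended_state (line 1110); otherwise the thief resumes the
// context frame at frame_base directly and clears its parent link
// (context::clear_parent_frame, line 1130) so the stolen continuation no longer chains
// into the victim's stack. The omitted lines presumably copy the frame contents from
// the victim before the resume.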
1145 template <typename StealCondFn>
1147 int min_depth, int max_depth, StealCondFn&& steal_cond_fn) {
1148 bool steal_success = false;
1151 auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
1153 if (!migration_wsq_.lock().trylock(target_rank, d)) {
1154 common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1158 auto mwe = migration_wsq_.steal_nolock(target_rank, d);
1159 if (!mwe.has_value()) {
1160 migration_wsq_.lock().unlock(target_rank, d);
1161 common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1165 if (!steal_cond_fn(*mwe)) {
1167 migration_wsq_.lock().unlock(target_rank, d);
1168 common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
1172 if (!mwe->is_continuation) {
1174 common::verbose("Steal a new task from migration wsqueue (depth=%d) on rank %d",
1177 migration_wsq_.lock().unlock(target_rank, d);
1179 common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1181 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
1183 suspend([&](context_frame* cf) {
1185 start_new_task(mwe->frame_base, mwe->frame_size);
1188 } else if (mwe->evacuation_ptr) {
1190 common::verbose("Steal an evacuated context frame [%p, %p) from migration wsqueue (depth=%d) on rank %d",
1191 mwe->frame_base, reinterpret_cast<std::byte*>(mwe->frame_base) + mwe->frame_size,
1194 migration_wsq_.lock().unlock(target_rank, d);
1196 common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1198 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1200 suspend([&](context_frame* cf) {
1202 resume(suspended_state{mwe->evacuation_ptr, mwe->frame_base, mwe->frame_size});
1207 common::verbose("Steal a context frame [%p, %p) from migration wsqueue (depth=%d) on rank %d",
1208 mwe->frame_base, reinterpret_cast<std::byte*>(mwe->frame_base) + mwe->frame_size,
1213 migration_wsq_.lock().unlock(target_rank, d);
1215 common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
1217 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
1219 suspend([&](context_frame* cf) {
1221 context_frame* next_cf = reinterpret_cast<context_frame*>(mwe->frame_base);
1226 steal_success = true;
1230 if (!steal_success) {
1231 common::verbose<2>("Steal failed for migration queues on rank %d", target_rank);
1233 return steal_success;
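// [Editor annotation, not part of the original source] Migration-queue entries come in
// three flavors, matching the branches above: a fresh task (is_continuation == false),
// which is started on a new stack via start_new_task(); an evacuated continuation,
// which is resumed from its evacuation buffer; and an in-place continuation, whose
// frame at frame_base is resumed directly (presumably after the omitted lines copy it
// over from the victim).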
1236 template <typename Fn>
1237 void suspend(Fn&& fn) {
1238 context_frame* prev_cf_top = cf_top_;
1239 thread_local_storage* prev_tls = tls_;
1241 context::save_context_with_call(prev_cf_top,
1242 [](context_frame* cf, void* cf_top_p, void* fn_p) {
1243 context_frame*& cf_top = *reinterpret_cast<context_frame**>(cf_top_p);
1244 Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_p));
1247 }, &cf_top_, &fn, prev_tls);
1249 cf_top_ = prev_cf_top;
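// [Editor annotation, not part of the original source] suspend() captures the current
// execution state into a context_frame via context::save_context_with_call and passes
// that frame to fn, which typically enqueues, evacuates, or migrates it. When some
// worker later resumes the frame, control returns here and cf_top_ is restored to
// prev_cf_top (tls_ is presumably restored in an omitted line as well).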
1253 void resume(context_frame* cf) {
1254 common::verbose("Resume context frame [%p, %p) in the stack", cf, cf->parent_frame);
1255 context::resume(cf);
1258 void resume(suspended_state ss) {
1260 ss.frame_base, reinterpret_cast<std::byte*>(ss.frame_base) + ss.frame_size, ss.evacuation_ptr);
1264 context::jump_to_stack(ss.frame_base, [](void* this_, void* evacuation_ptr, void* frame_base, void* frame_size_) {
1265 scheduler_adws& this_sched = *reinterpret_cast<scheduler_adws*>(this_);
1266 std::size_t frame_size = reinterpret_cast<std::size_t>(frame_size_);
1268 common::remote_get(this_sched.suspended_thread_allocator_,
1269 reinterpret_cast<std::byte*>(frame_base),
1270 reinterpret_cast<std::byte*>(evacuation_ptr),
1272 this_sched.suspended_thread_allocator_.deallocate(evacuation_ptr, frame_size);
1274 context_frame* cf = reinterpret_cast<context_frame*>(frame_base);
1276 context::resume(cf);
1277 }, this, ss.evacuation_ptr, ss.frame_base, reinterpret_cast<void*>(ss.frame_size));
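// [Editor annotation, not part of the original source] resume(suspended_state) first
// jumps onto the stack region where the frame originally lived, copies the evacuated
// frame bytes back from the suspended-thread allocator (common::remote_get, line 1268),
// frees the evacuation buffer, and only then resumes the reconstructed context_frame,
// presumably so the copy cannot clobber the stack the scheduler is currently running on.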
1280 void resume_sched() {
1282 context::resume(sched_cf_);
1285 void start_new_task(void* task_ptr, std::size_t task_size) {
1286 root_on_stack([&]() {
1287 task_general* t = reinterpret_cast<task_general*>(alloca(task_size));
1290 reinterpret_cast<std::byte*>(t),
1291 reinterpret_cast<std::byte*>(task_ptr),
1293 suspended_thread_allocator_.deallocate(task_ptr, task_size);
1299 void execute_migrated_task(const suspended_state& ss) {
1300 if (ss.evacuation_ptr == nullptr) {
1303 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
1305 suspend([&](context_frame* cf) {
1307 start_new_task(ss.frame_base, ss.frame_size);
1313 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_migrate>();
1315 suspend([&](context_frame* cf) {
1322 void execute_migrated_task(const migration_wsq_entry& mwe) {
1323 if (!mwe.is_continuation) {
1326 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_start_new>();
1328 suspend([&](context_frame* cf) {
1330 start_new_task(mwe.frame_base, mwe.frame_size);
1333 } else if (mwe.evacuation_ptr) {
1335 common::verbose("Popped an evacuated continuation from local migration queues");
1336 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_popped>();
1338 suspend([&](context_frame* cf) {
1340 resume(suspended_state{mwe.evacuation_ptr, mwe.frame_base, mwe.frame_size});
1345 common::die("On-stack threads cannot remain after switching to the scheduler. Something went wrong.");
1349 std::optional<primary_wsq_entry> pop_from_primary_queues(int depth_from) {
1351 for (int d = depth_from; d >= 0; d--) {
1352 auto pwe = primary_wsq_.pop<false>(d);
1353 if (pwe.has_value()) {
1357 return std::nullopt;
1360 std::optional<migration_wsq_entry> pop_from_migration_queues() {
1361 for (int d = 0; d < migration_wsq_.n_queues(); d++) {
1362 auto mwe = migration_wsq_.pop<false>(d);
1363 if (mwe.has_value()) {
1367 return std::nullopt;
1370 suspended_state evacuate(context_frame* cf) {
1371 std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
1372 void* evacuation_ptr = suspended_thread_allocator_.allocate(cf_size);
1373 std::memcpy(evacuation_ptr, cf, cf_size);
1376 cf, cf->parent_frame, evacuation_ptr);
1378 return {evacuation_ptr, cf, cf_size};
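// [Editor annotation, not part of the original source] evacuate() measures the live
// frame as the bytes between cf and cf->parent_frame, copies them into a buffer
// obtained from suspended_thread_allocator_ so the continuation survives reuse of the
// stack region, and returns a suspended_state recording the evacuation buffer, the
// original frame base, and the frame size.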
1381 void evacuate_all() {
1382 if (use_primary_wsq_) {
1385 if (!pwe.evacuation_ptr) {
1386 context_frame* cf = reinterpret_cast<context_frame*>(pwe.frame_base);
1387 suspended_state ss = evacuate(cf);
1388 pwe = {ss.evacuation_ptr, ss.frame_base, ss.frame_size, pwe.tg_version};
1394 if (mwe.is_continuation && !mwe.evacuation_ptr) {
1395 context_frame* cf = reinterpret_cast<context_frame*>(mwe.frame_base);
1396 suspended_state ss = evacuate(cf);
1397 mwe = {true, ss.evacuation_ptr, ss.frame_base, ss.frame_size, mwe.tg_version};
1399 }, tls_->dtree_node_ref.depth);
1403 template <typename PhaseFrom, typename PhaseTo,
1404 typename PreSuspendCallback, typename PostSuspendCallback>
1405 bool check_cross_worker_task_arrival(PreSuspendCallback&& pre_suspend_cb,
1406 PostSuspendCallback&& post_suspend_cb) {
1407 if (migration_mailbox_.arrived()) {
1408 tls_->dag_prof.stop();
1411 prof_phase_cb_pre_suspend,
1412 prof_phase_sched_evacuate>(
1413 std::forward<PreSuspendCallback>(pre_suspend_cb));
1419 suspend([&](context_frame* cf) {
1420 suspended_state ss = evacuate(cf);
1422 if (use_primary_wsq_) {
1423 primary_wsq_.push({ss.evacuation_ptr, ss.frame_base, ss.frame_size, tls_->tg_version},
1424 tls_->dtree_node_ref.depth);
1426 migration_wsq_.push({true, ss.evacuation_ptr, ss.frame_base, ss.frame_size, tls_->tg_version},
1427 tls_->dtree_node_ref.depth);
1430 common::profiler::switch_phase<prof_phase_sched_evacuate, prof_phase_sched_loop>();
1436 prof_phase_cb_post_suspend,
1438 std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
1441 prof_phase_cb_post_suspend,
1443 std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
1446 tls_->dag_prof.start();
1450 } else if constexpr (!std::is_same_v<PhaseTo, PhaseFrom>) {
1451 common::profiler::switch_phase<PhaseFrom, PhaseTo>();
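// [Editor annotation, not part of the original source] check_cross_worker_task_arrival()
// reacts to a task showing up in the migration mailbox by evacuating the running
// continuation (line 1420) and pushing it back onto the local queue, primary or
// migration depending on use_primary_wsq_, at the current distribution-tree depth; the
// worker then drops into the scheduler loop so the arrived cross-worker task can run,
// with the pre/post-suspend callbacks invoked around the suspension.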
1457 template <typename Fn>
1458 void root_on_stack(Fn&& fn) {
1459 cf_top_ = stack_base_;
1460 std::size_t stack_size_bytes = reinterpret_cast<std::byte*>(stack_base_) -
1461 reinterpret_cast<std::byte*>(stack_.top());
1462 context::call_on_stack(stack_.top(), stack_size_bytes,
1463 [](void* fn_, void*, void*, void*) {
1464 Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_));
1466 }, &fn, nullptr, nullptr, nullptr);
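// [Editor annotation, not part of the original source] root_on_stack() runs fn on the
// runtime-managed callstack: cf_top_ is reset to stack_base_ and context::call_on_stack
// switches to stack_.top() with the size computed from stack_base_ down to stack_.top(),
// presumably so root tasks execute on the scheduler's stack rather than the OS thread
// stack.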
1469 void execute_coll_task(task_general* t, coll_task ct) {
1471 coll_task ct_ {t, ct.task_size, ct.master_rank};
1478 if (my_rank_shifted % i == 0) {
1479 auto target_rank_shifted = my_rank_shifted + i / 2;
1480 if (target_rank_shifted < n_ranks) {
1481 auto target_rank = (target_rank_shifted + ct.master_rank) % n_ranks;
1482 coll_task_mailbox_.put(ct_, target_rank);
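// [Editor annotation, not part of the original source] execute_coll_task() appears to
// forward the collective task along a binomial tree: ranks are shifted so that the
// master rank plays the role of rank 0, and a rank whose shifted id is divisible by i
// forwards the task to the rank i/2 positions ahead (modulo n_ranks), reaching all
// processes in a logarithmic number of rounds; the loop over i is among the omitted
// lines.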
1487 auto prev_stack_base = stack_base_;
1488 if (my_rank == ct.master_rank) {
1490 stack_base_ = cf_top_ - (cf_top_ - reinterpret_cast<context_frame*>(stack_.top())) / 2;
1501 stack_base_ = prev_stack_base;
1507 void execute_coll_task_if_arrived() {
1508 auto ct = coll_task_mailbox_.pop();
1509 if (ct.has_value()) {
1510 task_general* t = reinterpret_cast<task_general*>(
1511 suspended_thread_allocator_.allocate(ct->task_size));
1514 reinterpret_cast<std::byte*>(t),
1515 reinterpret_cast<std::byte*>(ct->task_ptr),
1518 common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
1520 execute_coll_task(t, *ct);
1522 common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
1524 suspended_thread_allocator_.deallocate(t, ct->task_size);
1528 bool should_exit_sched_loop() {
1529 if (sched_loop_make_mpi_progress_option::value()) {
1533 execute_coll_task_if_arrived();
1535 if (exit_request_mailbox_.pop()) {
1540 auto target_rank = my_rank + i / 2;
1542 exit_request_mailbox_.put(target_rank);
1552 template <typename T>
1553 thread_retval<T> get_retval_remote(thread_state<T>* ts) {
1554 if constexpr (std::is_trivially_copyable_v<T>) {
1558 thread_retval<T> retval;
1559 remote_get(thread_state_allocator_, reinterpret_cast<std::byte*>(&retval), reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
1564 template <typename T>
1565 void put_retval_remote(thread_state<T>* ts, thread_retval<T>&& retval) {
1566 if constexpr (std::is_trivially_copyable_v<T>) {
1570 std::byte* retvalp = reinterpret_cast<std::byte*>(new (alloca(sizeof(thread_retval<T>))) thread_retval<T>{std::move(retval)});
1571 remote_put(thread_state_allocator_, retvalp, reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
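// [Editor annotation, not part of the original source] get_retval_remote() and
// put_retval_remote() move thread return values as raw bytes with remote_get/remote_put,
// so both guard this path with std::is_trivially_copyable_v<T>; the branch taken for
// non-trivially-copyable types is among the omitted lines.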
1577 context_frame* stack_base_;
1578 oneslot_mailbox<void> exit_request_mailbox_;
1579 oneslot_mailbox<coll_task> coll_task_mailbox_;
1580 oneslot_mailbox<suspended_state> migration_mailbox_;
1581 wsqueue<primary_wsq_entry, false> primary_wsq_;
1582 wsqueue<migration_wsq_entry, true> migration_wsq_;
1583 common::remotable_resource thread_state_allocator_;
1584 common::remotable_resource suspended_thread_allocator_;
1585 context_frame* cf_top_ = nullptr;
1586 context_frame* sched_cf_ = nullptr;
1587 thread_local_storage* tls_ = nullptr;
1588 bool use_primary_wsq_ = true;
1590 dist_tree::node_ref dtree_local_bottom_ref_;
1591 bool dag_prof_enabled_ = false;