  // scheduler_randws() constructor (fragment of the member initializer list)
  stack_base_(reinterpret_cast<context_frame*>(stack_.bottom()) - 1),
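  // root_exec(): entry point for the root thread. It switches the profiler out of the
  // SPMD phase, suspends the current (SPMD) context, and runs fn on the scheduler's
  // call stack via root_on_stack(); once the root thread dies and the scheduler loop
  // finishes, the result is joined and its DAG profile is merged.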
  template <typename T, typename SchedLoopCallback, typename Fn, typename... Args>
  T root_exec(SchedLoopCallback cb, Fn&& fn, Args&&... args) {
    common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_fork>();
    // ...
    auto prev_sched_cf = sched_cf_;

    suspend([&](context_frame* cf) {
      // ...
      root_on_stack([&, ts, fn = std::forward<Fn>(fn),
                     args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
        // ...
        tls_->dag_prof.increment_thread_count();
        tls_->dag_prof.increment_strand_count();

        common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();

        T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));

        common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
        // ...
      });
    });
    // ...
    common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_join>();
    // ...
    if (dag_prof_enabled_) {
      // ...
      dag_prof_result_.merge_serial(retval.dag_prof);
    }

    sched_cf_ = prev_sched_cf;

    common::profiler::switch_phase<prof_phase_sched_join, prof_phase_spmd>();
    // ...
  }
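  // fork(): the parent's continuation (its context frame) is pushed into the local
  // work-stealing queue so that other workers can steal it, and the child thread is
  // executed immediately on the current stack. If the parent frame is still in the
  // queue when the child dies, the fork was not stolen and the parent is resumed
  // directly (the serialized "fast path").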
  template <typename T, typename OnDriftForkCallback, typename OnDriftDieCallback,
            typename WorkHint, typename Fn, typename... Args>
  void fork(thread_handler<T>& th,
            OnDriftForkCallback on_drift_fork_cb, OnDriftDieCallback on_drift_die_cb,
            WorkHint, WorkHint, Fn&& fn, Args&&... args) {
    common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_fork>();
    // ...
    suspend([&, ts, fn = std::forward<Fn>(fn),
             args_tuple = std::make_tuple(std::forward<Args>(args)...)](context_frame* cf) mutable {
      common::verbose<2>("push context frame [%p, %p) into task queue", cf, cf->parent_frame);
      // ...
      std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
      wsq_.push(wsqueue_entry{cf, cf_size});
      // ...
      tls_->dag_prof.increment_thread_count();
      tls_->dag_prof.increment_strand_count();
      // ...
      common::verbose<2>("Starting new thread %p", ts);
      common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();

      T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));

      common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
      common::verbose<2>("Thread %p is completed", ts);
      // ...
      on_die(ts, std::move(ret), on_drift_die_cb);
      // ...
      common::verbose<2>("Thread %p is serialized (fast path)", ts);
      // ...
      common::verbose<2>("Resume parent context frame [%p, %p) (fast path)", cf, cf->parent_frame);
      // ...
      common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_popped>();
      // ...
    });

    common::profiler::switch_phase<prof_phase_sched_resume_popped, prof_phase_thread>();
    // ...
    tls_->dag_prof.increment_strand_count();
  }
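  // join(): returns the child's result. A serialized child is joined without any
  // synchronization (fast path); otherwise the joining thread fetches the return
  // value remotely and races with the dying child, possibly suspending until the
  // child completes.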
  template <typename T>
  T join(thread_handler<T>& th) {
    common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_join>();
    // ...
    common::verbose<2>("Skip join for serialized thread (fast path)");
    // ...
    retval = get_retval_remote(ts);
    // ...
    bool migrated = true;
    suspend([&](context_frame* cf) {
      // ...
      common::verbose("Win the join race for thread %p (joining thread)", ts);
      common::profiler::switch_phase<prof_phase_sched_join, prof_phase_sched_loop>();
      // ...
      common::verbose("Lose the join race for thread %p (joining thread)", ts);
      // ...
    });
    // ...
    common::profiler::switch_phase<prof_phase_sched_resume_join, prof_phase_sched_join>();
    // ...
    retval = get_retval_remote(ts);
    // ...
    common::profiler::switch_phase<prof_phase_sched_join, prof_phase_thread>();
    // ...
  }
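  // sched_loop(): the scheduler loop executed by worker processes. Each iteration
  // checks the migration mailbox, runs any migrated task, invokes the optional
  // callback, and otherwise looks for work to steal, until an exit request arrives.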
  template <typename SchedLoopCallback>
  void sched_loop(SchedLoopCallback cb) {
    while (!should_exit_sched_loop()) {
      auto mte = migration_mailbox_.pop();
      if (mte.has_value()) {
        execute_migrated_task(*mte);
      }
      // ...
      if constexpr (!std::is_null_pointer_v<std::remove_reference_t<SchedLoopCallback>>) {
        // ...
      }
    }
  }
  template <typename PreSuspendCallback, typename PostSuspendCallback>
  void poll(PreSuspendCallback&&, PostSuspendCallback&&) {}
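  // migrate_to(): suspends the current (root) thread, sends its suspended state to
  // target_rank through the migration mailbox, and returns to the scheduler loop;
  // the pre/post-suspend callbacks are invoked around the suspension.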
  template <typename PreSuspendCallback, typename PostSuspendCallback>
  void migrate_to(common::topology::rank_t target_rank,
                  PreSuspendCallback&&  pre_suspend_cb,
                  PostSuspendCallback&& post_suspend_cb) {
    // ...
        std::forward<PreSuspendCallback>(pre_suspend_cb));
    // ...
    suspend([&](context_frame* cf) {
      // ...
      common::verbose("Migrate continuation of the root thread to process %d",
                      target_rank);
      // ...
      migration_mailbox_.put(ss, target_rank);
      // ...
      common::profiler::switch_phase<prof_phase_sched_migrate, prof_phase_sched_loop>();
      // ...
    });
    // ...
        std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
  }
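  // coll_exec(): collectively executes fn on all processes. The callable is placed
  // in the suspended-thread allocator, distributed to the other ranks as a
  // coll_task, and executed while the scheduler is temporarily in the SPMD phase.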
  template <typename Fn>
  void coll_exec(const Fn& fn) {
    common::profiler::switch_phase<prof_phase_thread, prof_phase_spmd>();
    // ...
    size_t task_size = sizeof(callable_task_t);
    void* task_ptr = suspended_thread_allocator_.allocate(task_size);
    // ...
    auto t = new (task_ptr) callable_task_t(fn);
    // ...
    execute_coll_task(t, ct);
    // ...
    suspended_thread_allocator_.deallocate(t, task_size);
    // ...
    tls_->dag_prof.increment_strand_count();
    // ...
    common::profiler::switch_phase<prof_phase_spmd, prof_phase_thread>();
  }
  bool is_executing_root() const {
    return cf_top_ && cf_top_ == stack_base_;
  }
  template <typename T>
  static bool is_serialized(const thread_handler<T>& th) { /* ... */ }

  void task_group_begin(task_group_data* tgdata) {
    // ...
    tls_->dag_prof.increment_strand_count();
  }
  template <typename PreSuspendCallback, typename PostSuspendCallback>
  void task_group_end(PreSuspendCallback&&, PostSuspendCallback&&) {
    // ...
    tls_->dag_prof.increment_strand_count();
  }
  void dag_prof_begin() {
    dag_prof_enabled_ = true;
    dag_prof_result_.clear();
    // ...
    tls_->dag_prof.increment_thread_count();
  }
  void dag_prof_end() {
    dag_prof_enabled_ = false;
    // ...
  }

  void dag_prof_print() const {
    // ...
    dag_prof_result_.print();
  }
    std::size_t task_size;  // member of the coll_task struct (task_ptr, task_size, master_rank)
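  // on_die(): called when a forked thread finishes. Popping the parent frame from
  // the local queue succeeds iff nobody stole it (the fork was serialized);
  // otherwise the dying thread races with the joining thread and, if it loses the
  // race, fetches and resumes the joiner's suspended state.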
  template <typename T, typename OnDriftDieCallback>
  void on_die(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
    auto qe = wsq_.pop();
    bool serialized = qe.has_value();
    // ...
                          prof_phase_cb_drift_die,
                          prof_phase_sched_die>(on_drift_die_cb);
    // ...
    common::verbose("Win the join race for thread %p (joined thread)", ts);
    common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
    // ...
    common::verbose("Lose the join race for thread %p (joined thread)", ts);
    common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_join>();
    suspended_state ss = remote_get_value(thread_state_allocator_, &ts->suspended);
    // ...
  }
  template <typename T>
  void on_root_die(thread_state<T>* ts, T&& ret) {
    // ...
    exit_request_mailbox_.put(0);
    // ...
    common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
    // ...
  }
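  // Steal path (fragment): probe the victim's work-stealing queue, record the
  // attempt with a profiler interval, copy the stolen context frame from the
  // victim's stack, and resume it like a locally suspended continuation.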
    auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);

    if (wsq_.empty(target_rank)) {
      common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
      // ...
    }
    // ...
    common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
    // ...
    if (!we.has_value()) {
      // ...
      common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
      // ...
    }
    // ...
        we->frame_base, reinterpret_cast<std::byte*>(we->frame_base) + we->frame_size, target_rank);
    // ...
    common::profiler::interval_end<prof_event_sched_steal>(ibd, true);

    common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
    // ...
    context_frame* next_cf = reinterpret_cast<context_frame*>(we->frame_base);
    suspend([&](context_frame* cf) {
      // ...
      context::clear_parent_frame(next_cf);
      // ...
    });
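  // suspend(): saves the current execution context into a context_frame and passes
  // it to fn; when the context is eventually resumed, cf_top_ (and the thread-local
  // storage) are restored to their values at the suspension point.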
  template <typename Fn>
  void suspend(Fn&& fn) {
    context_frame*        prev_cf_top = cf_top_;
    thread_local_storage* prev_tls    = tls_;
    context::save_context_with_call(prev_cf_top,
        [](context_frame* cf, void* cf_top_p, void* fn_p) {
          context_frame*& cf_top = *reinterpret_cast<context_frame**>(cf_top_p);
          Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_p));
          // ...
        }, &cf_top_, &fn, prev_tls);
    // ...
    cf_top_ = prev_cf_top;
    // ...
  }
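  // resume(context_frame*): jump back into a frame that still resides on this
  // process's call stack.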
  void resume(context_frame* cf) {
    common::verbose("Resume context frame [%p, %p) in the stack", cf, cf->parent_frame);
  void resume(suspended_state ss) {
    // ...
                    ss.frame_base, ss.frame_size, ss.evacuation_ptr);
    // ...
    context::jump_to_stack(ss.frame_base,
        [](void* allocator_, void* evacuation_ptr, void* frame_base, void* frame_size_) {
          common::remotable_resource& allocator = *reinterpret_cast<common::remotable_resource*>(allocator_);
          std::size_t frame_size = reinterpret_cast<std::size_t>(frame_size_);
          common::remote_get(allocator,
                             reinterpret_cast<std::byte*>(frame_base),
                             reinterpret_cast<std::byte*>(evacuation_ptr),
                             frame_size);
          allocator.deallocate(evacuation_ptr, frame_size);
          // ...
          context_frame* cf = reinterpret_cast<context_frame*>(frame_base);
          context::clear_parent_frame(cf);
          // ...
        }, &suspended_thread_allocator_, ss.evacuation_ptr, ss.frame_base,
           reinterpret_cast<void*>(ss.frame_size));
  }
  void resume_sched() {
    // ...
    context::resume(sched_cf_);
  }
  void execute_migrated_task(const suspended_state& ss) {
    // ...
    common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_migrate>();
    // ...
    suspend([&](context_frame* cf) {
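  // evacuate(): copy a suspended context frame into the suspended-thread allocator
  // so it can later be fetched (possibly by another process) and resumed elsewhere.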
  suspended_state evacuate(context_frame* cf) {
    std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
    void* evacuation_ptr = suspended_thread_allocator_.allocate(cf_size);
    std::memcpy(evacuation_ptr, cf, cf_size);
    // ...
             cf, cf->parent_frame, evacuation_ptr);
    // ...
    return {evacuation_ptr, cf, cf_size};
  }
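  // root_on_stack(): run fn on the scheduler-managed call stack, with cf_top_
  // placed at the stack base.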
  template <typename Fn>
  void root_on_stack(Fn&& fn) {
    cf_top_ = stack_base_;
    std::size_t stack_size_bytes = reinterpret_cast<std::byte*>(stack_base_) -
                                   reinterpret_cast<std::byte*>(stack_.top());
    context::call_on_stack(stack_.top(), stack_size_bytes,
        [](void* fn_, void*, void*, void*) {
          Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_));
          // ...
        }, &fn, nullptr, nullptr, nullptr);
  }
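  // execute_coll_task(): forwards the collective task to other ranks along a binary
  // tree rooted at the master rank and then executes it locally; while the task
  // runs, the master rank temporarily moves stack_base_ to the middle of the
  // remaining stack space.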
  void execute_coll_task(task_general* t, coll_task ct) {
    // ...
    coll_task ct_ {t, ct.task_size, ct.master_rank};
    // ...
    if (my_rank_shifted % i == 0) {
      auto target_rank_shifted = my_rank_shifted + i / 2;
      if (target_rank_shifted < n_ranks) {
        auto target_rank = (target_rank_shifted + ct.master_rank) % n_ranks;
        coll_task_mailbox_.put(ct_, target_rank);
      }
    }
    // ...
    auto prev_stack_base = stack_base_;
    if (my_rank == ct.master_rank) {
      // ...
      stack_base_ = cf_top_ - (cf_top_ - reinterpret_cast<context_frame*>(stack_.top())) / 2;
    }
    // ...
    stack_base_ = prev_stack_base;
    // ...
  }
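  // execute_coll_task_if_arrived(): if a collective task has been delivered to this
  // rank's mailbox, copy the task object from its owner, run it (switching to the
  // SPMD phase), and free the local copy.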
  void execute_coll_task_if_arrived() {
    auto ct = coll_task_mailbox_.pop();
    if (ct.has_value()) {
      task_general* t = reinterpret_cast<task_general*>(
          suspended_thread_allocator_.allocate(ct->task_size));
      // ...
                 reinterpret_cast<std::byte*>(t),
                 reinterpret_cast<std::byte*>(ct->task_ptr),
      // ...
      common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
      // ...
      execute_coll_task(t, *ct);
      // ...
      common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
      // ...
      suspended_thread_allocator_.deallocate(t, ct->task_size);
      // ...
    }
  }
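  // should_exit_sched_loop(): also drives collective-task execution; when an exit
  // request arrives in the mailbox, it is forwarded down the same tree pattern
  // before the scheduler loop terminates.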
  bool should_exit_sched_loop() {
    // ...
    execute_coll_task_if_arrived();
    // ...
    if (exit_request_mailbox_.pop()) {
      // ...
      auto target_rank = my_rank + i / 2;
      // ...
      exit_request_mailbox_.put(target_rank);
      // ...
    }
    // ...
  }
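  // get_retval_remote()/put_retval_remote(): read and write a thread's return value
  // in the owner's thread_state through the thread_state allocator; the
  // is_trivially_copyable check selects between two transfer strategies.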
  template <typename T>
  thread_retval<T> get_retval_remote(thread_state<T>* ts) {
    if constexpr (std::is_trivially_copyable_v<T>) {
      // ...
      thread_retval<T> retval;
      remote_get(thread_state_allocator_,
                 reinterpret_cast<std::byte*>(&retval),
                 reinterpret_cast<std::byte*>(&ts->retval),
                 sizeof(thread_retval<T>));
      // ...
  template <typename T>
  void put_retval_remote(thread_state<T>* ts, thread_retval<T>&& retval) {
    if constexpr (std::is_trivially_copyable_v<T>) {
      // ...
      std::byte* retvalp = reinterpret_cast<std::byte*>(
          new (alloca(sizeof(thread_retval<T>))) thread_retval<T>{std::move(retval)});
      remote_put(thread_state_allocator_, retvalp,
                 reinterpret_cast<std::byte*>(&ts->retval),
                 sizeof(thread_retval<T>));
  struct wsqueue_entry {
    void*       frame_base;
    std::size_t frame_size;
  };
  // ...
  context_frame*                   stack_base_;
  oneslot_mailbox<void>            exit_request_mailbox_;
  oneslot_mailbox<coll_task>       coll_task_mailbox_;
  oneslot_mailbox<suspended_state> migration_mailbox_;
  wsqueue<wsqueue_entry>           wsq_;
  common::remotable_resource       thread_state_allocator_;
  common::remotable_resource       suspended_thread_allocator_;
  context_frame*                   cf_top_           = nullptr;
  context_frame*                   sched_cf_         = nullptr;
  thread_local_storage*            tls_              = nullptr;
  bool                             dag_prof_enabled_ = false;