64       stack_base_(reinterpret_cast<context_frame*>(stack_.bottom()) - 1),
 
   69   template <typename T, typename SchedLoopCallback, typename Fn, typename... Args>
 
   71     common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_fork>();
 
   75     auto prev_sched_cf = sched_cf_;
 
   77     suspend([&](context_frame* cf) {
 
   79       root_on_stack([&, ts, fn = std::forward<Fn>(fn),
 
   80                      args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
 
   86         tls_->dag_prof.increment_thread_count();

   87         tls_->dag_prof.increment_strand_count();
 
   89         common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
 
   91         T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
 
   93         common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
 
  104     common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_join>();
 
  110     if (dag_prof_enabled_) {
 
  115         dag_prof_result_.merge_serial(retval.dag_prof);
 
  119     sched_cf_ = prev_sched_cf;
 
  121     common::profiler::switch_phase<prof_phase_sched_join, prof_phase_spmd>();
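
A note on the pattern above: the capture at lines 79-80 moves the user callable and its arguments into a lambda so they survive the switch onto the scheduler-managed stack, and invoke_fn (line 91) applies the stored tuple; fork repeats the same trick at lines 137-138. A minimal, self-contained illustration of that pattern (capture_and_run is a hypothetical name, not part of randws.hpp):

    #include <tuple>
    #include <utility>

    template <typename T, typename Fn, typename... Args>
    T capture_and_run(Fn&& fn, Args&&... args) {
      auto task = [fn = std::forward<Fn>(fn),
                   args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable -> T {
        return std::apply(std::forward<Fn>(fn), std::move(args_tuple));  // invoke with the stored arguments
      };
      return task();  // root_exec instead runs this body on a separate call stack via root_on_stack
    }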
 
  126   template <typename T, typename OnDriftForkCallback, typename OnDriftDieCallback,

  127             typename WorkHint, typename Fn, typename... Args>
 
  129             OnDriftForkCallback on_drift_fork_cb, OnDriftDieCallback on_drift_die_cb,
 
  130             WorkHint, WorkHint, Fn&& fn, Args&&... args) {
 
  131     common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_fork>();
 
  137     suspend([&, ts, fn = std::forward<Fn>(fn),
 
  138              args_tuple = std::make_tuple(std::forward<Args>(args)...)](context_frame* cf) mutable {
 
  139       common::verbose<2>("push context frame [%p, %p) into task queue", cf, cf->parent_frame);
 
  143       std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
 
  144       wsq_.push(wsqueue_entry{cf, cf_size});
 
  147       tls_->dag_prof.increment_thread_count();

  148       tls_->dag_prof.increment_strand_count();
 
  150       common::verbose<2>("Starting new thread %p", ts);
 
  151       common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
 
  153       T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
 
  155       common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
 
  156       common::verbose<2>("Thread %p is completed", ts);
 
  159       on_die(ts, std::move(ret), on_drift_die_cb);
 
  161       common::verbose<2>("Thread %p is serialized (fast path)", ts);
 
  170       common::verbose<2>("Resume parent context frame [%p, %p) (fast path)", cf, cf->parent_frame);
 
  172       common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_popped>();
 
  176       common::profiler::switch_phase<prof_phase_sched_resume_popped, prof_phase_thread>();
 
  186     tls_->dag_prof.increment_strand_count();
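
Lines 139-176 are the fork fast path: the parent's context frame is pushed onto the work-stealing queue before the child body runs, and on_die (lines 424-425) later pops that queue; a non-empty pop means no other worker stole the continuation, so the parent is resumed in place. A small sketch of that protocol under hypothetical types (frame_desc, Deque), not taken from randws.hpp:

    #include <cstddef>

    struct frame_desc { void* base; std::size_t size; };

    // Deque is any work-stealing queue whose pop() returns an optional-like handle.
    template <typename Deque, typename Child>
    bool run_child_first(Deque& wsq, frame_desc parent, Child&& child) {
      wsq.push(parent);              // expose the parent's continuation to thieves
      child();                       // run the child task immediately on the same stack
      return wsq.pop().has_value();  // non-empty pop: not stolen, parent resumes inline (fast path)
    }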
 
  189   template <typename T>
 
  191     common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_join>();
 
  195       common::verbose<2>("Skip join for serialized thread (fast path)");
 
  209           retval = get_retval_remote(ts);
 
  213         bool migrated = true;
 
  214         suspend([&](context_frame* cf) {
 
  221             common::verbose("Win the join race for thread %p (joining thread)", ts);
 
  222             common::profiler::switch_phase<prof_phase_sched_join, prof_phase_sched_loop>();
 
  225             common::verbose("Lose the join race for thread %p (joining thread)", ts);
 
  234           common::profiler::switch_phase<prof_phase_sched_resume_join, prof_phase_sched_join>();
 
  238           retval = get_retval_remote(ts);
 
  253     common::profiler::switch_phase<prof_phase_sched_join, prof_phase_thread>();
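
The "win/lose the join race" messages (lines 221-225 here, 441-445 in on_die) describe a race between the joining parent and the finishing child over who suspends and who resumes the other. The exact mechanism is not visible in this excerpt; one common arbitration, shown purely as an assumed sketch, is an atomic counter on the shared thread state:

    #include <atomic>

    struct join_state {
      std::atomic<int> arrived{0};
    };

    // The first arriver (parent or child) "wins" and goes back to scheduling; the second arriver
    // "loses" and must resume the winner's suspended continuation (or hand over the return value).
    inline bool arrive_first(join_state& js) {
      return js.arrived.fetch_add(1, std::memory_order_acq_rel) == 0;
    }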
 
  257   template <typename SchedLoopCallback>
 
  261     while (!should_exit_sched_loop()) {
 
  262       auto mte = migration_mailbox_.pop();
 
  263       if (mte.has_value()) {
 
  264         execute_migrated_task(*mte);
 
  270       if constexpr (!std::is_null_pointer_v<std::remove_reference_t<SchedLoopCallback>>) {
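
Line 270 compiles the scheduler-loop callback away when nullptr is passed for it. The same idiom in isolation (call_if_not_null is a hypothetical name, not part of randws.hpp):

    #include <type_traits>
    #include <utility>

    template <typename Cb, typename... Args>
    void call_if_not_null(Cb&& cb, Args&&... args) {
      if constexpr (!std::is_null_pointer_v<std::remove_reference_t<Cb>>) {
        std::forward<Cb>(cb)(std::forward<Args>(args)...);
      }
    }
    // call_if_not_null(nullptr);             // the call disappears at compile time
    // call_if_not_null([] { /* work */ });   // invoked normally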
 
  278   template <typename PreSuspendCallback, typename PostSuspendCallback>
 
  279   void poll(PreSuspendCallback&&, PostSuspendCallback&&) {}
 
  281   template <typename PreSuspendCallback, typename PostSuspendCallback>
 
  283                   PreSuspendCallback&&     pre_suspend_cb,
 
  284                   PostSuspendCallback&&    post_suspend_cb) {
 
  293         std::forward<PreSuspendCallback>(pre_suspend_cb));
 
  295     suspend([&](context_frame* cf) {
 
  298       common::verbose("Migrate continuation of the root thread to process %d",
 
  301       migration_mailbox_.put(ss, target_rank);
 
  303       common::profiler::switch_phase<prof_phase_sched_migrate, prof_phase_sched_loop>();
 
  310         std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
 
  313   template <typename Fn>
 
  315     common::profiler::switch_phase<prof_phase_thread, prof_phase_spmd>();
 
  322     size_t task_size = sizeof(callable_task_t);
 
  323     void* task_ptr = suspended_thread_allocator_.allocate(task_size);
 
  325     auto t = new (task_ptr) callable_task_t(fn);
 
  328     execute_coll_task(t, ct);
 
  330     suspended_thread_allocator_.deallocate(t, task_size);
 
  333     tls_->dag_prof.increment_strand_count();
 
  335     common::profiler::switch_phase<prof_phase_spmd, prof_phase_thread>();
 
  339     return cf_top_ && cf_top_ == stack_base_;
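
Note on line 339: is_executing_root holds only when the innermost context frame sits exactly at stack_base_, the slot one context_frame below the stack bottom reserved in the constructor (line 64) and installed for the root task by root_on_stack (line 578).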
 
  342   template <typename T>
 
  357     tls_->dag_prof.increment_strand_count();
 
  360   template <typename PreSuspendCallback, typename PostSuspendCallback>
 
  373     tls_->dag_prof.increment_strand_count();
 
  377     dag_prof_enabled_ = true;
 
  378     dag_prof_result_.clear();
 
  382       tls_->dag_prof.increment_thread_count();
 
  387     dag_prof_enabled_ = false;
 
  402       dag_prof_result_.print();
 
  409     std::size_t              task_size;
 
  422   template <typename T, typename OnDriftDieCallback>
 
  423   void on_die(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
 
  424     auto qe = wsq_.pop();
 
  425     bool serialized = qe.has_value();
 
  432                           prof_phase_cb_drift_die,
 
  433                           prof_phase_sched_die>(on_drift_die_cb);
 
  441       common::verbose("Win the join race for thread %p (joined thread)", ts);
 
  442       common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
 
  445       common::verbose("Lose the join race for thread %p (joined thread)", ts);
 
  446       common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_join>();
 
  447       suspended_state ss = remote_get_value(thread_state_allocator_, &ts->suspended);
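
Note on lines 423-447: on_die first pops the work-stealing queue (line 424); a non-empty pop means the parent's continuation was never stolen, so the thread counts as serialized (line 425) and fork's fast path (lines 161-176) resumes the parent directly. Otherwise the continuation drifted to another worker, the on-drift-die callback runs wrapped in its profiler phases (lines 432-433), and the join race at lines 441-447 decides whether this worker must fetch the parent's suspended state remotely and resume it.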
 
  452   template <typename T>
 
  453   void on_root_die(thread_state<T>* ts, T&& ret) {
 
  459     exit_request_mailbox_.put(0);
 
  461     common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
 
  468     auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
 
  470     if (wsq_.empty(target_rank)) {
 
  471       common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
  476       common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
  481     if (!we.has_value()) {
 
  483       common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
 
  488                     we->frame_base, reinterpret_cast<std::byte*>(we->frame_base) + we->frame_size, target_rank);
 
  494     common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
 
  496     common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
 
  498     context_frame* next_cf = reinterpret_cast<context_frame*>(we->frame_base);
 
  499     suspend([&](context_frame* cf) {
 
  501       context::clear_parent_frame(next_cf);
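
Note on lines 468-501: a steal attempt first checks the victim's queue remotely (line 470) and records a failed-steal interval on each early exit (lines 471, 476, 483, e.g. when the stolen entry turns out to be empty, line 481). On success, the entry names the frame range [frame_base, frame_base + frame_size) on the victim (line 488), which is apparently copied into the same addresses on the thief's identically laid-out stack; the thief then suspends its scheduler context, severs the stolen frame's parent link (line 501), and resumes it.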
 
  506   template <typename Fn>
 
  507   void suspend(Fn&& fn) {
 
  508     context_frame*        prev_cf_top = cf_top_;
 
  509     thread_local_storage* prev_tls    = tls_;
 
  511     context::save_context_with_call(prev_cf_top,
 
  512         [](context_frame* cf, void* cf_top_p, void* fn_p) {
 
  513       context_frame*& cf_top = *reinterpret_cast<context_frame**>(cf_top_p);
 
  514       Fn              fn     = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_p));
 
  517     }, &cf_top_, &fn, prev_tls);
 
  519     cf_top_ = prev_cf_top;
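
Lines 511-517 pass a C++ callable through a C-style context-switch callback by smuggling typed pointers as void*. The same pattern in isolation (call_with_context is a hypothetical stand-in, not the actual context::save_context_with_call API):

    #include <type_traits>
    #include <utility>

    using raw_callback = void (*)(void* arg);

    inline void call_with_context(raw_callback cb, void* arg) { cb(arg); }  // stand-in for the C-style API

    template <typename Fn>
    void invoke_via_void_ptr(Fn&& fn) {
      using F = std::remove_reference_t<Fn>;
      call_with_context([](void* fn_p) {
        F& f = *static_cast<F*>(fn_p);  // recover the typed callable from the void* argument
        std::forward<Fn>(f)();          // invoke it
      }, &fn);
    }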
 
  523   void resume(context_frame* cf) {
 
  524     common::verbose("Resume context frame [%p, %p) in the stack", cf, cf->parent_frame);
 
  528   void resume(suspended_state ss) {
 
  530                     ss.frame_base, ss.frame_size, ss.evacuation_ptr);
 
  534     context::jump_to_stack(ss.frame_base, [](void* allocator_, void* evacuation_ptr, void* frame_base, void* frame_size_) {
 
  535       common::remotable_resource& allocator  = *reinterpret_cast<common::remotable_resource*>(allocator_);
 
  536       std::size_t                 frame_size = reinterpret_cast<std::size_t>(frame_size_);
 
  537       common::remote_get(allocator,
 
  538                          reinterpret_cast<std::byte*>(frame_base),
 
  539                          reinterpret_cast<std::byte*>(evacuation_ptr),
 
  541       allocator.deallocate(evacuation_ptr, frame_size);
 
  543       context_frame* cf = reinterpret_cast<context_frame*>(frame_base);
 
  544       context::clear_parent_frame(cf);
 
  546     }, &suspended_thread_allocator_, ss.evacuation_ptr, ss.frame_base, reinterpret_cast<void*>(ss.frame_size));
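
Note on lines 528-546: resuming an evacuated thread jumps onto the target stack region first, copies the saved frame image back from the suspended-thread allocator into its original address range (lines 537-539), releases the evacuation buffer (line 541), clears the restored frame's parent link (line 544), and only then switches into the frame.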
 
  549   void resume_sched() {
 
  551     context::resume(sched_cf_);
 
  554   void execute_migrated_task(const suspended_state& ss) {
 
  557     common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_migrate>();
 
  559     suspend([&](context_frame* cf) {
 
  565   suspended_state evacuate(context_frame* cf) {
 
  566     std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
 
  567     void* evacuation_ptr = suspended_thread_allocator_.allocate(cf_size);
 
  568     std::memcpy(evacuation_ptr, cf, cf_size);
 
  571                     cf, cf->parent_frame, evacuation_ptr);
 
  573     return {evacuation_ptr, cf, cf_size};
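
Lines 565-573 evacuate a live frame off the stack: its size is the byte distance to its parent frame, the bytes are copied into a separately allocated buffer, and the returned suspended_state records where they must be restored. A heap-based sketch of the same bookkeeping (hypothetical names; the real code uses suspended_thread_allocator_ so the copy stays remotely accessible):

    #include <cstddef>
    #include <cstdint>
    #include <cstdlib>
    #include <cstring>

    struct frame { frame* parent_frame; /* saved registers, locals, ... */ };

    struct evacuated {
      void*       copy;      // off-stack copy of the frame image
      frame*      original;  // address range to restore into before resuming
      std::size_t size;
    };

    inline evacuated evacuate_frame(frame* cf) {
      std::size_t size = reinterpret_cast<std::uintptr_t>(cf->parent_frame) -
                         reinterpret_cast<std::uintptr_t>(cf);  // the frame occupies [cf, cf->parent_frame)
      void* copy = std::malloc(size);
      std::memcpy(copy, cf, size);
      return {copy, cf, size};
    }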
 
  576   template <typename Fn>
 
  577   void root_on_stack(Fn&& fn) {
 
  578     cf_top_ = stack_base_;
 
  579     std::size_t stack_size_bytes = reinterpret_cast<std::byte*>(stack_base_) -

  580                                    reinterpret_cast<std::byte*>(stack_.top());
 
  581     context::call_on_stack(stack_.top(), stack_size_bytes,
 
  582                            [](void* fn_, void*, void*, void*) {
 
  583       Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_)); 
 
  585     }, &fn, nullptr, nullptr, nullptr);
 
  588   void execute_coll_task(task_general* t, coll_task ct) {
 
  590     coll_task ct_ {t, ct.task_size, ct.master_rank};
 
  597       if (my_rank_shifted % i == 0) {
 
  598         auto target_rank_shifted = my_rank_shifted + i / 2;
 
  599         if (target_rank_shifted < n_ranks) {
 
  600           auto target_rank = (target_rank_shifted + ct.master_rank) % n_ranks;
 
  601           coll_task_mailbox_.put(ct_, target_rank);
 
  606     auto prev_stack_base = stack_base_;
 
  607     if (my_rank == ct.master_rank) {
 
  609       stack_base_ = cf_top_ - (cf_top_ - reinterpret_cast<context_frame*>(stack_.top())) / 2;
 
  622     stack_base_ = prev_stack_base;
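
Lines 597-601 forward the collective task along a binomial tree rooted at ct.master_rank: ranks are renumbered relative to the master and, for each span i, the rank owning a block notifies the rank i/2 positions ahead. A standalone sketch of which ranks each process notifies, assuming the loop (not shown in this excerpt) iterates i over decreasing powers of two starting from the smallest power of two >= n_ranks:

    #include <vector>

    std::vector<int> broadcast_targets(int my_rank, int master_rank, int n_ranks) {
      std::vector<int> targets;
      int shifted = (my_rank - master_rank + n_ranks) % n_ranks;  // renumber so the master is rank 0
      int span = 1;
      while (span < n_ranks) span *= 2;                           // smallest power of two >= n_ranks
      for (int i = span; i > 1; i /= 2) {
        if (shifted % i == 0) {
          int target_shifted = shifted + i / 2;
          if (target_shifted < n_ranks) {
            targets.push_back((target_shifted + master_rank) % n_ranks);
          }
        }
      }
      return targets;
    }
    // e.g. n_ranks = 5, master_rank = 0: rank 0 notifies {4, 2, 1}, rank 2 notifies {3},
    // ranks 1, 3, 4 notify nobody -- every rank receives the task exactly once.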
 
  628   void execute_coll_task_if_arrived() {
 
  629     auto ct = coll_task_mailbox_.pop();
 
  630     if (ct.has_value()) {
 
  631       task_general* t = reinterpret_cast<task_general*>(
 
  632           suspended_thread_allocator_.allocate(ct->task_size));
 
  635                          reinterpret_cast<std::byte*>(t),
 
  636                          reinterpret_cast<std::byte*>(ct->task_ptr),
 
  639       common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
 
  641       execute_coll_task(t, *ct);
 
  643       common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
 
  645       suspended_thread_allocator_.deallocate(t, ct->task_size);
 
  649   bool should_exit_sched_loop() {
 
  654     execute_coll_task_if_arrived();
 
  656     if (exit_request_mailbox_.pop()) {
 
  661           auto target_rank = my_rank + i / 2;
 
  663             exit_request_mailbox_.put(target_rank);
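
Note on lines 649-663: should_exit_sched_loop also drains the collective-task mailbox (line 654). The exit request that on_root_die posts to rank 0 (line 459) is forwarded here along a similar tree-shaped pattern (line 661 picks the next target as my_rank + i / 2), so every process eventually observes the request and leaves sched_loop.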
 
  673   template <typename T>
 
  674   thread_retval<T> get_retval_remote(thread_state<T>* ts) {
 
  675     if constexpr (std::is_trivially_copyable_v<T>) {
 
  679       thread_retval<T> retval;
 
  680       remote_get(thread_state_allocator_, reinterpret_cast<std::byte*>(&retval), reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
 
  685   template <typename T>
 
  686   void put_retval_remote(thread_state<T>* ts, thread_retval<T>&& retval) {
 
  687     if constexpr (std::is_trivially_copyable_v<T>) {
 
  691       std::byte* retvalp = reinterpret_cast<std::byte*>(new (alloca(sizeof(thread_retval<T>))) thread_retval<T>{std::move(retval)});
 
  692       remote_put(thread_state_allocator_, retvalp, reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
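
Note on lines 673-692: return values travel between processes as raw bytes through the thread-state allocator. put_retval_remote materializes the value into an alloca'd buffer with placement new (line 691) so a byte-wise remote_put can ship it (line 692), and get_retval_remote performs the matching remote_get into a local thread_retval<T> (line 680); both functions branch on std::is_trivially_copyable_v<T> (lines 675, 687), with only one branch visible in this excerpt.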
 
  696   struct wsqueue_entry {
 
  698     std::size_t frame_size;
 
  702   context_frame*                   stack_base_;
 
  703   oneslot_mailbox<void>            exit_request_mailbox_;
 
  704   oneslot_mailbox<coll_task>       coll_task_mailbox_;
 
  705   oneslot_mailbox<suspended_state> migration_mailbox_;
 
  706   wsqueue<wsqueue_entry>           wsq_;
 
  707   common::remotable_resource       thread_state_allocator_;
 
  708   common::remotable_resource       suspended_thread_allocator_;
 
  709   context_frame*                   cf_top_           = nullptr;

  710   context_frame*                   sched_cf_         = nullptr;

  711   thread_local_storage*            tls_              = nullptr;

  712   bool                             dag_prof_enabled_ = false;
 