Itoyori  v0.0.1
randws.hpp
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
4 #include "ityr/common/mpi_util.hpp"
5 #include "ityr/common/topology.hpp"
6 #include "ityr/common/wallclock.hpp"
7 #include "ityr/common/logger.hpp"
8 #include "ityr/common/allocator.hpp"
9 #include "ityr/common/profiler.hpp"
10 #include "ityr/ito/util.hpp"
11 #include "ityr/ito/options.hpp"
12 #include "ityr/ito/context.hpp"
13 #include "ityr/ito/callstack.hpp"
14 #include "ityr/ito/wsqueue.hpp"
15 #include "ityr/ito/prof_events.hpp"
16 #include "ityr/ito/sched/util.hpp"
17 
18 namespace ityr::ito {
19 
20 class scheduler_randws {
21 public:
22  struct suspended_state {
23  void* evacuation_ptr;
24  void* frame_base;
25  std::size_t frame_size;
26  };
27 
28  template <typename T>
29  struct thread_retval {
30  T value;
31  dag_profiler dag_prof;
32  };
33 
34  template <typename T>
35  struct thread_state {
36  thread_retval<T> retval;
37  int resume_flag = 0;
38  suspended_state suspended;
39  };
40 
41  template <typename T>
42  struct thread_handler {
43  thread_state<T>* state = nullptr;
44  bool serialized = false;
45  thread_retval<T> retval_ser; // return the result by value if the thread is serialized
46  };
47 
48  struct task_group_data {
49  task_group_data* parent = nullptr;
50  dag_profiler dag_prof_before;
51  dag_profiler dag_prof_acc;
52  };
53 
54  struct thread_local_storage {
55  task_group_data* tgdata = nullptr;
56  dag_profiler dag_prof;
57  };
58 
59  scheduler_randws()
60  : stack_(stack_size_option::value()),
61  // Add a margin of sizeof(context_frame) to the bottom of the stack, because
62  // this region can be accessed by the clear_parent_frame() function later.
63  // This stack base is updated only in coll_exec().
64  stack_base_(reinterpret_cast<context_frame*>(stack_.bottom()) - 1),
65  wsq_(wsqueue_capacity_option::value()),
66  thread_state_allocator_(thread_state_allocator_size_option::value()),
67  suspended_thread_allocator_(suspended_thread_allocator_size_option::value()) {}
68 
69  template <typename T, typename SchedLoopCallback, typename Fn, typename... Args>
70  T root_exec(SchedLoopCallback cb, Fn&& fn, Args&&... args) {
71  common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_fork>();
72 
73  thread_state<T>* ts = new (thread_state_allocator_.allocate(sizeof(thread_state<T>))) thread_state<T>;
74 
75  auto prev_sched_cf = sched_cf_;
76 
77  suspend([&](context_frame* cf) {
78  sched_cf_ = cf;
79  root_on_stack([&, ts, fn = std::forward<Fn>(fn),
80  args_tuple = std::make_tuple(std::forward<Args>(args)...)]() mutable {
81  common::verbose("Starting root thread %p", ts);
82 
83  tls_ = new (alloca(sizeof(thread_local_storage))) thread_local_storage{};
84 
85  tls_->dag_prof.start();
86  tls_->dag_prof.increment_thread_count();
87  tls_->dag_prof.increment_strand_count();
88 
89  common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
90 
91  T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
92 
93  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
94  common::verbose("Root thread %p is completed", ts);
95 
96  tls_->dag_prof.stop();
97 
98  on_root_die(ts, std::move(ret));
99  });
100  });
101 
102  sched_loop(cb);
103 
104  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_join>();
105 
106  thread_retval<T> retval = std::move(ts->retval);
107  std::destroy_at(ts);
108  thread_state_allocator_.deallocate(ts, sizeof(thread_state<T>));
109 
110  if (dag_prof_enabled_) {
111  if (tls_) {
112  // nested root/coll_exec()
113  tls_->dag_prof.merge_serial(retval.dag_prof);
114  } else {
115  dag_prof_result_.merge_serial(retval.dag_prof);
116  }
117  }
118 
119  sched_cf_ = prev_sched_cf;
120 
121  common::profiler::switch_phase<prof_phase_sched_join, prof_phase_spmd>();
122 
123  return std::move(retval.value);
124  }
125 
126  template <typename T, typename OnDriftForkCallback, typename OnDriftDieCallback,
127  typename WorkHint, typename Fn, typename... Args>
128  void fork(thread_handler<T>& th,
129  OnDriftForkCallback on_drift_fork_cb, OnDriftDieCallback on_drift_die_cb,
130  WorkHint, WorkHint, Fn&& fn, Args&&... args) {
131  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_fork>();
132 
133  thread_state<T>* ts = new (thread_state_allocator_.allocate(sizeof(thread_state<T>))) thread_state<T>;
134  th.state = ts;
135  th.serialized = false;
136 
137  suspend([&, ts, fn = std::forward<Fn>(fn),
138  args_tuple = std::make_tuple(std::forward<Args>(args)...)](context_frame* cf) mutable {
139  common::verbose<2>("push context frame [%p, %p) into task queue", cf, cf->parent_frame);
140 
141  tls_ = new (alloca(sizeof(thread_local_storage))) thread_local_storage{};
142 
143  std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
144  wsq_.push(wsqueue_entry{cf, cf_size});
145 
146  tls_->dag_prof.start();
147  tls_->dag_prof.increment_thread_count();
148  tls_->dag_prof.increment_strand_count();
149 
150  common::verbose<2>("Starting new thread %p", ts);
151  common::profiler::switch_phase<prof_phase_sched_fork, prof_phase_thread>();
152 
153  T&& ret = invoke_fn<T>(std::forward<decltype(fn)>(fn), std::forward<decltype(args_tuple)>(args_tuple));
154 
155  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_die>();
156  common::verbose<2>("Thread %p is completed", ts);
157 
158  on_task_die();
159  on_die(ts, std::move(ret), on_drift_die_cb);
160 
161  common::verbose<2>("Thread %p is serialized (fast path)", ts);
162 
163  // The following is executed only when the thread is serialized
164  std::destroy_at(ts);
165  thread_state_allocator_.deallocate(ts, sizeof(thread_state<T>));
166  th.state = nullptr;
167  th.serialized = true;
168  th.retval_ser = {std::move(ret), tls_->dag_prof};
169 
170  common::verbose<2>("Resume parent context frame [%p, %p) (fast path)", cf, cf->parent_frame);
171 
172  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_popped>();
173  });
174 
175  if (th.serialized) {
176  common::profiler::switch_phase<prof_phase_sched_resume_popped, prof_phase_thread>();
177  } else {
178  call_with_prof_events<prof_phase_sched_resume_stolen,
179  prof_phase_cb_drift_fork,
180  prof_phase_thread>(on_drift_fork_cb);
181  }
182 
183  // restart to count only the last task in the task group
184  tls_->dag_prof.clear();
185  tls_->dag_prof.start();
186  tls_->dag_prof.increment_strand_count();
187  }
188 
189  template <typename T>
190  T join(thread_handler<T>& th) {
191  common::profiler::switch_phase<prof_phase_thread, prof_phase_sched_join>();
192 
193  thread_retval<T> retval;
194  if (th.serialized) {
195  common::verbose<2>("Skip join for serialized thread (fast path)");
196  // We can skip deallocation of its thread state because it was already deallocated
197  // when the thread was serialized (i.e., at a fork)
198  retval = std::move(th.retval_ser);
199 
200  } else {
201  on_task_die();
202 
203  ITYR_CHECK(th.state != nullptr);
204  thread_state<T>* ts = th.state;
205 
206  if (remote_get_value(thread_state_allocator_, &ts->resume_flag) >= 1) {
207  common::verbose("Thread %p is already joined", ts);
208  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
209  retval = get_retval_remote(ts);
210  }
211 
212  } else {
213  bool migrated = true;
214  suspend([&](context_frame* cf) {
215  suspended_state ss = evacuate(cf);
216 
217  remote_put_value(thread_state_allocator_, ss, &ts->suspended);
218 
219  // race
220  if (remote_faa_value(thread_state_allocator_, 1, &ts->resume_flag) == 0) {
221  common::verbose("Win the join race for thread %p (joining thread)", ts);
222  common::profiler::switch_phase<prof_phase_sched_join, prof_phase_sched_loop>();
223  resume_sched();
224  } else {
225  common::verbose("Lose the join race for thread %p (joining thread)", ts);
226  suspended_thread_allocator_.deallocate(ss.evacuation_ptr, ss.frame_size);
227  migrated = false;
228  }
229  });
230 
231  common::verbose("Resume continuation of join for thread %p", ts);
232 
233  if (migrated) {
234  common::profiler::switch_phase<prof_phase_sched_resume_join, prof_phase_sched_join>();
235  }
236 
237  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
238  retval = get_retval_remote(ts);
239  }
240  }
241 
242  // TODO: correctly destroy T remotely if nontrivially destructible
243  /* std::destroy_at(ts); */
244 
245  thread_state_allocator_.deallocate(ts, sizeof(thread_state<T>));
246  th.state = nullptr;
247  }
248 
249  if (tls_->tgdata) {
250  tls_->tgdata->dag_prof_acc.merge_parallel(retval.dag_prof);
251  }
252 
253  common::profiler::switch_phase<prof_phase_sched_join, prof_phase_thread>();
254  return std::move(retval.value);
255  }
256 
257  template <typename SchedLoopCallback>
258  void sched_loop(SchedLoopCallback cb) {
259  common::verbose("Enter scheduling loop");
260 
261  while (!should_exit_sched_loop()) {
262  auto mte = migration_mailbox_.pop();
263  if (mte.has_value()) {
264  execute_migrated_task(*mte);
265  continue;
266  }
267 
268  steal();
269 
270  if constexpr (!std::is_null_pointer_v<std::remove_reference_t<SchedLoopCallback>>) {
271  cb();
272  }
273  }
274 
275  common::verbose("Exit scheduling loop");
276  }
277 
278  template <typename PreSuspendCallback, typename PostSuspendCallback>
279  void poll(PreSuspendCallback&&, PostSuspendCallback&&) {}
280 
281  template <typename PreSuspendCallback, typename PostSuspendCallback>
282  void migrate_to(common::topology::rank_t target_rank,
283  PreSuspendCallback&& pre_suspend_cb,
284  PostSuspendCallback&& post_suspend_cb) {
285  // Currently only for the root thread
286  ITYR_CHECK(is_executing_root());
287 
288  if (target_rank == common::topology::my_rank()) return;
289 
290  auto cb_ret = call_with_prof_events<prof_phase_thread,
291  prof_phase_cb_pre_suspend,
292  prof_phase_sched_migrate>(
293  std::forward<PreSuspendCallback>(pre_suspend_cb));
294 
295  suspend([&](context_frame* cf) {
296  suspended_state ss = evacuate(cf);
297 
298  common::verbose("Migrate continuation of the root thread to process %d",
299  target_rank);
300 
301  migration_mailbox_.put(ss, target_rank);
302 
303  common::profiler::switch_phase<prof_phase_sched_migrate, prof_phase_sched_loop>();
304  resume_sched();
305  });
306 
307  call_with_prof_events<prof_phase_sched_resume_migrate,
308  prof_phase_cb_post_suspend,
309  prof_phase_thread>(
310  std::forward<PostSuspendCallback>(post_suspend_cb), cb_ret);
311  }
312 
313  template <typename Fn>
314  void coll_exec(const Fn& fn) {
315  common::profiler::switch_phase<prof_phase_thread, prof_phase_spmd>();
316 
317  tls_->dag_prof.stop();
318  // TODO: consider dag prof for inside coll tasks
319 
320  using callable_task_t = callable_task<Fn>;
321 
322  size_t task_size = sizeof(callable_task_t);
323  void* task_ptr = suspended_thread_allocator_.allocate(task_size);
324 
325  auto t = new (task_ptr) callable_task_t(fn);
326 
327  coll_task ct {task_ptr, task_size, common::topology::my_rank()};
328  execute_coll_task(t, ct);
329 
330  suspended_thread_allocator_.deallocate(t, task_size);
331 
332  tls_->dag_prof.start();
333  tls_->dag_prof.increment_strand_count();
334 
335  common::profiler::switch_phase<prof_phase_spmd, prof_phase_thread>();
336  }
337 
338  bool is_executing_root() const {
339  return cf_top_ && cf_top_ == stack_base_;
340  }
341 
342  template <typename T>
343  static bool is_serialized(const thread_handler<T>& th) {
344  return th.serialized;
345  }
346 
347  void task_group_begin(task_group_data* tgdata) {
348  tls_->dag_prof.stop();
349 
350  tgdata->parent = tls_->tgdata;
351  tgdata->dag_prof_before = tls_->dag_prof;
352 
353  tls_->tgdata = tgdata;
354 
355  tls_->dag_prof.clear();
356  tls_->dag_prof.start();
357  tls_->dag_prof.increment_strand_count();
358  }
359 
360  template <typename PreSuspendCallback, typename PostSuspendCallback>
361  void task_group_end(PreSuspendCallback&&, PostSuspendCallback&&) {
362  on_task_die();
363 
364  task_group_data* tgdata = tls_->tgdata;
365  ITYR_CHECK(tgdata);
366 
367  tls_->dag_prof = tgdata->dag_prof_before;
368  tls_->dag_prof.merge_serial(tgdata->dag_prof_acc);
369 
370  tls_->tgdata = tls_->tgdata->parent;
371 
372  tls_->dag_prof.start();
373  tls_->dag_prof.increment_strand_count();
374  }
375 
376  void dag_prof_begin() {
377  dag_prof_enabled_ = true;
378  dag_prof_result_.clear();
379  if (tls_) {
380  // nested root/coll_exec()
381  tls_->dag_prof.clear();
382  tls_->dag_prof.increment_thread_count();
383  }
384  }
385 
386  void dag_prof_end() {
387  dag_prof_enabled_ = false;
388  if constexpr (dag_profiler::enabled) {
389  common::topology::rank_t result_owner = 0;
390  if (tls_) {
391  // nested root/coll_exec()
392  dag_prof_result_ = tls_->dag_prof;
393  result_owner = common::topology::my_rank();
394  }
395  result_owner = common::mpi_allreduce_value(result_owner, common::topology::mpicomm(), MPI_MAX);
396  dag_prof_result_ = common::mpi_bcast_value(dag_prof_result_, result_owner, common::topology::mpicomm());
397  }
398  }
399 
400  void dag_prof_print() const {
401  if (common::topology::my_rank() == 0) {
402  dag_prof_result_.print();
403  }
404  }
405 
406 private:
407  struct coll_task {
408  void* task_ptr;
409  std::size_t task_size;
410  common::topology::rank_t master_rank;
411  };
412 
413  void on_task_die() {
414  if (!tls_->dag_prof.is_stopped()) {
415  tls_->dag_prof.stop();
416  if (tls_->tgdata) {
417  tls_->tgdata->dag_prof_acc.merge_parallel(tls_->dag_prof);
418  }
419  }
420  }
421 
422  template <typename T, typename OnDriftDieCallback>
423  void on_die(thread_state<T>* ts, T&& ret, OnDriftDieCallback on_drift_die_cb) {
424  auto qe = wsq_.pop();
425  bool serialized = qe.has_value();
426 
427  if (serialized) {
428  return;
429  }
430 
431  call_with_prof_events<prof_phase_sched_die,
432  prof_phase_cb_drift_die,
433  prof_phase_sched_die>(on_drift_die_cb);
434 
435  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
436  put_retval_remote(ts, {std::move(ret), tls_->dag_prof});
437  }
438 
439  // race
440  if (remote_faa_value(thread_state_allocator_, 1, &ts->resume_flag) == 0) {
441  common::verbose("Win the join race for thread %p (joined thread)", ts);
442  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
443  resume_sched();
444  } else {
445  common::verbose("Lose the join race for thread %p (joined thread)", ts);
446  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_resume_join>();
447  suspended_state ss = remote_get_value(thread_state_allocator_, &ts->suspended);
448  resume(ss);
449  }
450  }
451 
452  template <typename T>
453  void on_root_die(thread_state<T>* ts, T&& ret) {
454  if constexpr (!std::is_same_v<T, no_retval_t> || dag_profiler::enabled) {
455  put_retval_remote(ts, {std::move(ret), tls_->dag_prof});
456  }
457  remote_put_value(thread_state_allocator_, 1, &ts->resume_flag);
458 
459  exit_request_mailbox_.put(0);
460 
461  common::profiler::switch_phase<prof_phase_sched_die, prof_phase_sched_loop>();
462  resume_sched();
463  }
464 
465  void steal() {
466  auto target_rank = get_random_rank(0, common::topology::n_ranks() - 1);
467 
468  auto ibd = common::profiler::interval_begin<prof_event_sched_steal>(target_rank);
469 
470  if (wsq_.empty(target_rank)) {
471  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
472  return;
473  }
474 
475  if (!wsq_.lock().trylock(target_rank)) {
476  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
477  return;
478  }
479 
480  auto we = wsq_.steal_nolock(target_rank);
481  if (!we.has_value()) {
482  wsq_.lock().unlock(target_rank);
483  common::profiler::interval_end<prof_event_sched_steal>(ibd, false);
484  return;
485  }
486 
487  common::verbose("Steal context frame [%p, %p) from rank %d",
488  we->frame_base, reinterpret_cast<std::byte*>(we->frame_base) + we->frame_size, target_rank);
489 
490  stack_.direct_copy_from(we->frame_base, we->frame_size, target_rank);
491 
492  wsq_.lock().unlock(target_rank);
493 
494  common::profiler::interval_end<prof_event_sched_steal>(ibd, true);
495 
496  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_stolen>();
497 
498  context_frame* next_cf = reinterpret_cast<context_frame*>(we->frame_base);
499  suspend([&](context_frame* cf) {
500  sched_cf_ = cf;
501  context::clear_parent_frame(next_cf);
502  resume(next_cf);
503  });
504  }
505 
506  template <typename Fn>
507  void suspend(Fn&& fn) {
508  context_frame* prev_cf_top = cf_top_;
509  thread_local_storage* prev_tls = tls_;
510 
511  context::save_context_with_call(prev_cf_top,
512  [](context_frame* cf, void* cf_top_p, void* fn_p) {
513  context_frame*& cf_top = *reinterpret_cast<context_frame**>(cf_top_p);
514  Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_p)); // copy closure to the new stack frame
515  cf_top = cf;
516  fn(cf);
517  }, &cf_top_, &fn, prev_tls);
518 
519  cf_top_ = prev_cf_top;
520  tls_ = prev_tls;
521  }
522 
523  void resume(context_frame* cf) {
524  common::verbose("Resume context frame [%p, %p) in the stack", cf, cf->parent_frame);
525  context::resume(cf);
526  }
527 
528  void resume(suspended_state ss) {
529  common::verbose("Resume context frame [%p, %p) evacuated at %p",
530  ss.frame_base, reinterpret_cast<std::byte*>(ss.frame_base) + ss.frame_size, ss.evacuation_ptr);
531 
532  // We pass the suspended thread states *by value* because the current local variables can be overwritten by the
533  // new stack we will bring from remote nodes.
534  context::jump_to_stack(ss.frame_base, [](void* allocator_, void* evacuation_ptr, void* frame_base, void* frame_size_) {
535  common::remotable_resource& allocator = *reinterpret_cast<common::remotable_resource*>(allocator_);
536  std::size_t frame_size = reinterpret_cast<std::size_t>(frame_size_);
537  common::remote_get(allocator,
538  reinterpret_cast<std::byte*>(frame_base),
539  reinterpret_cast<std::byte*>(evacuation_ptr),
540  frame_size);
541  allocator.deallocate(evacuation_ptr, frame_size);
542 
543  context_frame* cf = reinterpret_cast<context_frame*>(frame_base);
544  context::clear_parent_frame(cf);
545  context::resume(cf);
546  }, &suspended_thread_allocator_, ss.evacuation_ptr, ss.frame_base, reinterpret_cast<void*>(ss.frame_size));
547  }
548 
549  void resume_sched() {
550  common::verbose("Resume scheduler context");
551  context::resume(sched_cf_);
552  }
553 
554  void execute_migrated_task(const suspended_state& ss) {
555  ITYR_CHECK(ss.evacuation_ptr);
556  common::verbose("Received a continuation of the root thread");
557  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_sched_resume_migrate>();
558 
559  suspend([&](context_frame* cf) {
560  sched_cf_ = cf;
561  resume(ss);
562  });
563  }
564 
565  suspended_state evacuate(context_frame* cf) {
566  std::size_t cf_size = reinterpret_cast<uintptr_t>(cf->parent_frame) - reinterpret_cast<uintptr_t>(cf);
567  void* evacuation_ptr = suspended_thread_allocator_.allocate(cf_size);
568  std::memcpy(evacuation_ptr, cf, cf_size);
569 
570  common::verbose("Evacuate suspended thread context [%p, %p) to %p",
571  cf, cf->parent_frame, evacuation_ptr);
572 
573  return {evacuation_ptr, cf, cf_size};
574  }
575 
576  template <typename Fn>
577  void root_on_stack(Fn&& fn) {
578  cf_top_ = stack_base_;
579  std::size_t stack_size_bytes = reinterpret_cast<std::byte*>(stack_base_) -
580  reinterpret_cast<std::byte*>(stack_.top());
581  context::call_on_stack(stack_.top(), stack_size_bytes,
582  [](void* fn_, void*, void*, void*) {
583  Fn fn = std::forward<Fn>(*reinterpret_cast<Fn*>(fn_)); // copy closure to the new stack frame
584  fn();
585  }, &fn, nullptr, nullptr, nullptr);
586  }
587 
588  void execute_coll_task(task_general* t, coll_task ct) {
589  // TODO: consider copy semantics for tasks
590  coll_task ct_ {t, ct.task_size, ct.master_rank};
591 
592  // pass coll task to other processes in a binary tree form
593  auto my_rank = common::topology::my_rank();
594  auto n_ranks = common::topology::n_ranks();
595  auto my_rank_shifted = (my_rank + n_ranks - ct.master_rank) % n_ranks;
596  for (common::topology::rank_t i = common::next_pow2(n_ranks); i > 1; i /= 2) {
597  if (my_rank_shifted % i == 0) {
598  auto target_rank_shifted = my_rank_shifted + i / 2;
599  if (target_rank_shifted < n_ranks) {
600  auto target_rank = (target_rank_shifted + ct.master_rank) % n_ranks;
601  coll_task_mailbox_.put(ct_, target_rank);
602  }
603  }
604  }
605 
606  auto prev_stack_base = stack_base_;
607  if (my_rank == ct.master_rank) {
608  // Allocate half of the remaining stack space for nested root/coll_exec()
609  stack_base_ = cf_top_ - (cf_top_ - reinterpret_cast<context_frame*>(stack_.top())) / 2;
610  }
611 
612  // In addition, collectively set the next stack base for nested root_exec() calls because
613  // the stack frame of the scheduler of the master worker is in the RDMA-capable stack region.
614  // TODO: check if the scheduler's stack frame and nested root_exec()'s stack frame do not overlap
615  stack_base_ = common::mpi_bcast_value(stack_base_, ct.master_rank, common::topology::mpicomm());
616 
617  // Ensure all processes have finished coll task execution before deallocation.
618  common::mpi_barrier(common::topology::mpicomm());
619 
620  t->execute();
621 
622  stack_base_ = prev_stack_base;
623 
624  // Ensure all processes have finished coll task execution before deallocation
625  common::mpi_barrier(common::topology::mpicomm());
626  }
627 
628  void execute_coll_task_if_arrived() {
629  auto ct = coll_task_mailbox_.pop();
630  if (ct.has_value()) {
631  task_general* t = reinterpret_cast<task_general*>(
632  suspended_thread_allocator_.allocate(ct->task_size));
633 
634  common::remote_get(suspended_thread_allocator_,
635  reinterpret_cast<std::byte*>(t),
636  reinterpret_cast<std::byte*>(ct->task_ptr),
637  ct->task_size);
638 
639  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
640 
641  execute_coll_task(t, *ct);
642 
643  common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
644 
645  suspended_thread_allocator_.deallocate(t, ct->task_size);
646  }
647  }
648 
649  bool should_exit_sched_loop() {
650  if (sched_loop_make_mpi_progress_option::value()) {
651  common::mpi_make_progress();
652  }
653 
654  execute_coll_task_if_arrived();
655 
656  if (exit_request_mailbox_.pop()) {
657  auto my_rank = common::topology::my_rank();
658  auto n_ranks = common::topology::n_ranks();
659  for (common::topology::rank_t i = common::next_pow2(n_ranks); i > 1; i /= 2) {
660  if (my_rank % i == 0) {
661  auto target_rank = my_rank + i / 2;
662  if (target_rank < n_ranks) {
663  exit_request_mailbox_.put(target_rank);
664  }
665  }
666  }
667  return true;
668  }
669 
670  return false;
671  }
672 
673  template <typename T>
674  thread_retval<T> get_retval_remote(thread_state<T>* ts) {
675  if constexpr (std::is_trivially_copyable_v<T>) {
676  return remote_get_value(thread_state_allocator_, &ts->retval);
677  } else {
678  // TODO: Fix this ugly hack of avoiding object destruction by using checkout/checkin
679  thread_retval<T> retval;
680  remote_get(thread_state_allocator_, reinterpret_cast<std::byte*>(&retval), reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
681  return retval;
682  }
683  }
684 
685  template <typename T>
686  void put_retval_remote(thread_state<T>* ts, thread_retval<T>&& retval) {
687  if constexpr (std::is_trivially_copyable_v<T>) {
688  remote_put_value(thread_state_allocator_, retval, &ts->retval);
689  } else {
690  // TODO: Fix this ugly hack of avoiding object destruction by using checkout/checkin
691  std::byte* retvalp = reinterpret_cast<std::byte*>(new (alloca(sizeof(thread_retval<T>))) thread_retval<T>{std::move(retval)});
692  remote_put(thread_state_allocator_, retvalp, reinterpret_cast<std::byte*>(&ts->retval), sizeof(thread_retval<T>));
693  }
694  }
695 
696  struct wsqueue_entry {
697  void* frame_base;
698  std::size_t frame_size;
699  };
700 
701  callstack stack_;
702  context_frame* stack_base_;
703  oneslot_mailbox<void> exit_request_mailbox_;
704  oneslot_mailbox<coll_task> coll_task_mailbox_;
705  oneslot_mailbox<suspended_state> migration_mailbox_;
706  wsqueue<wsqueue_entry> wsq_;
707  common::remotable_resource thread_state_allocator_;
708  common::remotable_resource suspended_thread_allocator_;
709  context_frame* cf_top_ = nullptr;
710  context_frame* sched_cf_ = nullptr;
711  thread_local_storage* tls_ = nullptr;
712  bool dag_prof_enabled_ = false;
713  dag_profiler dag_prof_result_;
714 };
715 
716 }
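
For orientation, here is a minimal usage sketch of the fork/join interface defined above. It is an illustration, not part of randws.hpp: it assumes an already-initialized ito runtime, passes nullptr for the drift and scheduler-loop callbacks, uses 1 as a placeholder work hint, and introduces a hypothetical sched() accessor standing in for however the surrounding worker layer exposes the per-process scheduler instance.

// Hypothetical accessor: assumed to return the per-process scheduler_randws
// instance (in Itoyori this wiring is handled by the surrounding worker layer).
ityr::ito::scheduler_randws& sched();

int fib(int n) {
  if (n <= 1) return n;

  ityr::ito::scheduler_randws::thread_handler<int> th;

  // Fork a child thread computing fib(n - 1); the parent continuation is
  // pushed to the work-stealing queue and may be stolen by another process.
  sched().fork<int>(th, nullptr, nullptr, 1, 1, [=] { return fib(n - 1); });

  int y = fib(n - 2);        // continue on the current strand
  int x = sched().join(th);  // fast path if the child thread was serialized
  return x + y;
}

// The root thread is typically started with root_exec(), which forks the root
// thread and then runs sched_loop() until on_root_die() requests exit:
// int r = sched().root_exec<int>(nullptr, [] { return fib(30); });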