Itoyori  v0.0.1
worker.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
5 #include "ityr/ito/scheduler.hpp"
6 
7 namespace ityr::ito::worker {
8 
9 class worker {
10 public:
12  : sched_() {}
13 
/// @brief Leaves the SPMD state, lets the scheduler run `fn(args...)` on one rank, and
///        returns to the SPMD state afterwards.
///
/// The rank whose id equals `coll_master_` passes `cb`, `fn` and `args` to
/// `scheduler::root_exec`; every other rank calls `scheduler::sched_loop(cb)` instead.
///
/// @tparam SchedLoopCallback Type of the callback passed on to the scheduler.
/// @tparam Fn                Callable type; its result type is taken from
///                           `std::invoke_result_t<Fn, Args...>`.
/// @tparam Args              Argument types forwarded to the scheduler together with `fn`.
/// @param cb   Callback handed to `root_exec` / `sched_loop` (when it is invoked is decided by
///             the scheduler, which is not visible in this file).
/// @param fn   Callable forwarded to `scheduler::root_exec` on the master rank.
/// @param args Arguments forwarded along with `fn`.
/// @return Nothing when `fn` returns `void`; otherwise the result of
///         `common::mpi_bcast_value(retval, coll_master_, mpicomm())`.
///
/// NOTE(review): this is an extracted listing; source lines 31 and 45 are absent from it
/// (see the notes in the body). Compare with the original worker.hpp before changing logic.
14  template <typename SchedLoopCallback, typename Fn, typename... Args>
15  auto root_exec(SchedLoopCallback cb, Fn&& fn, Args&&... args) {
// Precondition checked with ITYR_CHECK: the worker must currently be flagged as SPMD.
// The flag is cleared for the duration of the call and set again before returning.
16  ITYR_CHECK(is_spmd_);
17  is_spmd_ = false;
18 
19  using retval_t = std::invoke_result_t<Fn, Args...>;
20  if constexpr (std::is_void_v<retval_t>) {
// void-returning callable: `no_retval_t` is passed as the scheduler's result type.
21  if (common::topology::my_rank() == coll_master_) {
22  sched_.root_exec<no_retval_t>(cb, std::forward<Fn>(fn), std::forward<Args>(args)...);
23  } else {
// Non-master ranks: switch the profiler phase spmd -> sched_loop, run the scheduler loop,
// then switch back sched_loop -> spmd.
24  common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
25  sched_.sched_loop(cb);
26  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
27  }
28 
29  is_spmd_ = true;
30 
// NOTE(review): the listing jumps from source line 30 to 32 here, so one statement may be
// missing from this extracted text — confirm against the original file.
32 
33  } else {
// Value-returning callable: `retval` is initialized from `{}` on every rank and only
// overwritten on the master rank, so `retval_t` must support that initialization.
34  retval_t retval {};
35  if (common::topology::my_rank() == coll_master_) {
36  retval = sched_.root_exec<retval_t>(cb, std::forward<Fn>(fn), std::forward<Args>(args)...);
37  } else {
38  common::profiler::switch_phase<prof_phase_spmd, prof_phase_sched_loop>();
39  sched_.sched_loop(cb);
40  common::profiler::switch_phase<prof_phase_sched_loop, prof_phase_spmd>();
41  }
42 
43  is_spmd_ = true;
44 
// NOTE(review): the listing jumps from source line 44 to 46 here, so one statement may be
// missing from this extracted text — confirm against the original file.
46 
// `coll_master_` is passed as the root rank of the broadcast; presumably every rank ends up
// returning the master's value — confirm in mpi_util.hpp.
47  return common::mpi_bcast_value(retval, coll_master_, common::topology::mpicomm());
48  }
49  }
50 
51  template <typename Fn, typename... Args>
52  auto coll_exec(const Fn& fn, const Args&... args) {
53  ITYR_CHECK(!is_spmd_);
54 
55  using retval_t = std::invoke_result_t<Fn, Args...>;
56 
57  auto next_master = common::topology::my_rank();
58  std::conditional_t<std::is_void_v<retval_t>, no_retval_t, retval_t> retv;
59 
60  auto coll_task_fn = [=, &retv]() {
61  is_spmd_ = true;
62  auto prev_coll_master = coll_master_;
63  coll_master_ = next_master;
64  if constexpr (std::is_void_v<retval_t>) {
65  fn(args...);
66  (void)retv;
67  } else {
68  auto&& ret = fn(args...);
69  if (common::topology::my_rank() == next_master) {
70  retv = std::forward<decltype(ret)>(ret);
71  }
72  }
73  coll_master_ = prev_coll_master;
74  is_spmd_ = false;
75  };
76 
77  sched_.coll_exec(coll_task_fn);
78 
79  if constexpr (!std::is_void_v<retval_t>) {
80  return retv;
81  }
82  }
83 
84  bool is_spmd() const { return is_spmd_; }
85 
86  scheduler& sched() { return sched_; }
87 
88 private:
// Scheduler instance owned by this worker; `root_exec`, `coll_exec` and `sched()` all go
// through it.
89  scheduler sched_;
// SPMD flag. Starts as true; `root_exec` requires it to be true and clears it while it runs,
// `coll_exec` requires it to be false and its task sets it to true while the callable runs.
90  bool is_spmd_ = true;
// Rank treated as the master: in `root_exec` it selects the rank that calls
// `scheduler::root_exec` and is passed as the broadcast root. Starts at 0 and is temporarily
// replaced by the calling rank inside the task built by `coll_exec`.
91  common::topology::rank_t coll_master_ = 0;
92 };
93 
95 
96 }
Definition: util.hpp:176
Definition: worker.hpp:9
worker()
Definition: worker.hpp:11
bool is_spmd() const
Definition: worker.hpp:84
auto root_exec(SchedLoopCallback cb, Fn &&fn, Args &&... args)
Definition: worker.hpp:15
auto coll_exec(const Fn &fn, const Args &... args)
Definition: worker.hpp:52
scheduler & sched()
Definition: worker.hpp:86
#define ITYR_CHECK(cond)
Definition: util.hpp:48
int rank_t
Definition: topology.hpp:12
MPI_Comm mpicomm()
Definition: topology.hpp:206
rank_t my_rank()
Definition: topology.hpp:207
va_list args
Definition: util.hpp:76
T mpi_bcast_value(const T &value, int root_rank, MPI_Comm comm)
Definition: mpi_util.hpp:145
void mpi_barrier(MPI_Comm comm)
Definition: mpi_util.hpp:42
Definition: worker.hpp:7
ITYR_CONCAT(scheduler_, ITYR_ITO_SCHEDULER) scheduler
Definition: scheduler.hpp:11
Definition: util.hpp:126