Itoyori  v0.0.1
parallel_invoke.hpp
Go to the documentation of this file.
1 #pragma once
2 
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

#include "ityr/common/util.hpp"
#include "ityr/ito/ito.hpp"
#include "ityr/ori/ori.hpp"

11 namespace ityr {
12 
// User-facing work hint for parallel_invoke(): place it right after a task (and after that
// task's argument tuple, if any). `value` is the task's work estimate. Within one call the
// hints of the remaining tasks are summed and passed, together with the current task's value,
// to ito::workhint.
template <typename W>
struct workhint {
  W value;

  // Explicit, so that a bare number is never silently taken as a hint.
  constexpr explicit workhint(W w) : value(w) {}
};
18 
19 namespace internal {
20 
// Counts the tasks in a parallel_invoke() argument list (without work hints). Each task is a
// callable, optionally followed by a std::tuple holding its arguments.
//
// The matching is done on decayed types. The argument lists seen by parallel_invoke_with_args
// come from forwarding references, so a tuple may show up as `std::tuple<...>&` or
// `const std::tuple<...>&`. Without decaying, such a tuple would not match the tuple
// specializations and would be miscounted as an extra task.
template <typename... Args>
struct count_num_tasks_impl;

// A lone callable: one task.
template <typename Fn>
struct count_num_tasks_impl<Fn> {
  static constexpr int value = 1;
};

// A callable without an argument tuple, followed by more tasks.
template <typename Fn1, typename Fn2, typename... Rest>
struct count_num_tasks_impl<Fn1, Fn2, Rest...> {
  static constexpr int value = 1 + count_num_tasks_impl<Fn2, Rest...>::value;
};

// A callable with its argument tuple, as the last task.
template <typename Fn, typename... Args>
struct count_num_tasks_impl<Fn, std::tuple<Args...>> {
  static constexpr int value = 1;
};

// A callable with its argument tuple, followed by more tasks.
template <typename Fn, typename... Args, typename... Rest>
struct count_num_tasks_impl<Fn, std::tuple<Args...>, Rest...> {
  static constexpr int value = 1 + count_num_tasks_impl<Rest...>::value;
};

// Public entry point: strips references/cv-qualifiers before matching.
template <typename... Args>
struct count_num_tasks : count_num_tasks_impl<std::decay_t<Args>...> {};

static_assert(count_num_tasks<void (*)()>::value == 1);
static_assert(count_num_tasks<void (*)(int), std::tuple<int>>::value == 1);
static_assert(count_num_tasks<void (*)(), int (*)()>::value == 2);
static_assert(count_num_tasks<void (*)(int, int), std::tuple<int, int>, void (*)()>::value == 2);
static_assert(count_num_tasks<void (*)(int), std::tuple<int>, int (*)(), void (*)()>::value == 3);
// Forwarded (reference-qualified) argument tuples are still recognized as tuples.
static_assert(count_num_tasks<void (*)(int), const std::tuple<int>&, void (*)()>::value == 2);
static_assert(count_num_tasks<void (*)(int), std::tuple<int>&, int (*)(), std::tuple<>&&>::value == 2);
49 
50 template <typename Fn, typename W>
51 constexpr inline W get_total_work(const Fn&, workhint<W> wh) {
52  return wh.value;
53 }
54 
55 template <typename Fn, typename... Args, typename W>
56 constexpr inline W get_total_work(const Fn&, const std::tuple<Args...>&, workhint<W> wh) {
57  return wh.value;
58 }
59 
60 template <typename Fn, typename W, typename... Rest>
61 constexpr inline W get_total_work(const Fn&, workhint<W> wh, const Rest&... rest) {
62  return wh.value + get_total_work(rest...);
63 }
64 
65 template <typename Fn, typename... Args, typename W, typename... Rest>
66 constexpr inline W get_total_work(const Fn&, const std::tuple<Args...>&, workhint<W> wh, const Rest&... rest) {
67  return wh.value + get_total_work(rest...);
68 }
69 
70 template <typename ReleaseHandler>
71 struct parallel_invoke_state {
72 public:
73  parallel_invoke_state(ReleaseHandler rh) : rh_(rh) {}
74 
75  bool all_serialized() const { return all_serialized_; }
76 
77  inline auto parallel_invoke_aux() {
78  return std::make_tuple();
79  }
80 
81  template <typename Fn>
82  inline auto parallel_invoke_aux(Fn&& fn) {
83  // insert an empty arg tuple
84  return parallel_invoke_with_args(std::forward<Fn>(fn), std::make_tuple());
85  }
86 
87  template <typename Fn1, typename Fn2, typename... Rest>
88  inline auto parallel_invoke_aux(Fn1&& fn1, Fn2&& fn2, Rest&&... rest) {
89  // insert an empty arg tuple
90  return parallel_invoke_with_args(std::forward<Fn1>(fn1), std::make_tuple(),
91  std::forward<Fn2>(fn2), std::forward<Rest>(rest)...);
92  }
93 
94  template <typename Fn, typename... Args, typename... Rest>
95  inline auto parallel_invoke_aux(Fn&& fn, const std::tuple<Args...>& args, Rest&&... rest) {
96  // universal reference is not available for concrete types (std::tuple in this case)
97  return parallel_invoke_with_args(std::forward<Fn>(fn), args, std::forward<Rest>(rest)...);
98  }
99 
100  template <typename Fn, typename... Args, typename... Rest>
101  inline auto parallel_invoke_aux(Fn&& fn, std::tuple<Args...>&& args, Rest&&... rest) {
102  return parallel_invoke_with_args(std::forward<Fn>(fn), std::move(args), std::forward<Rest>(rest)...);
103  }
104 
105  /* with work hints */
106 
107  template <typename Fn, typename W, typename... Rest>
108  inline auto parallel_invoke_aux(Fn&& fn, workhint<W> wh, Rest&&... rest) {
109  return parallel_invoke_with_args(std::forward<Fn>(fn), std::make_tuple(), wh, std::forward<Rest>(rest)...);
110  }
111 
112  template <typename Fn, typename... Args, typename W, typename... Rest>
113  inline auto parallel_invoke_aux(Fn&& fn, const std::tuple<Args...>& args, workhint<W> wh, Rest&&... rest) {
114  return parallel_invoke_with_args(std::forward<Fn>(fn), args, wh, std::forward<Rest>(rest)...);
115  }
116 
117  template <typename Fn, typename... Args, typename W, typename... Rest>
118  inline auto parallel_invoke_aux(Fn&& fn, std::tuple<Args...>&& args, workhint<W> wh, Rest&&... rest) {
119  return parallel_invoke_with_args(std::forward<Fn>(fn), std::move(args), wh, std::forward<Rest>(rest)...);
120  }
121 
122 private:
123  template <typename Fn, typename ArgsTuple>
124  inline auto parallel_invoke_with_args(Fn&& fn, ArgsTuple&& args_tuple) {
125  return do_parallel_invoke(std::forward<Fn>(fn), std::forward<ArgsTuple>(args_tuple));
126  }
127 
128  template <typename Fn, typename ArgsTuple, typename... Rest>
129  inline auto parallel_invoke_with_args(Fn&& fn, ArgsTuple&& args_tuple, Rest&&... rest) {
130  constexpr int n_rest_tasks = count_num_tasks<Rest...>::value;
131  static_assert(n_rest_tasks > 0);
132  return do_parallel_invoke(std::forward<Fn>(fn), std::forward<ArgsTuple>(args_tuple),
133  ito::workhint(1, n_rest_tasks), std::forward<Rest>(rest)...);
134  }
135 
136  template <typename Fn, typename ArgsTuple, typename W>
137  inline auto parallel_invoke_with_args(Fn&& fn, ArgsTuple&& args_tuple, workhint<W>) {
138  return do_parallel_invoke(std::forward<Fn>(fn), std::forward<ArgsTuple>(args_tuple));
139  }
140 
141  template <typename Fn, typename ArgsTuple, typename W, typename... Rest>
142  inline auto parallel_invoke_with_args(Fn&& fn, ArgsTuple&& args_tuple, workhint<W> wh, Rest&&... rest) {
143  W wh_rest = get_total_work(rest...);
144  return do_parallel_invoke(std::forward<Fn>(fn), std::forward<ArgsTuple>(args_tuple),
145  ito::workhint(wh.value, wh_rest), std::forward<Rest>(rest)...);
146  }
147 
148  template <typename Fn, typename ArgsTuple>
149  inline auto do_parallel_invoke(Fn&& fn, ArgsTuple&& args_tuple) {
150  using retval_t = std::invoke_result_t<decltype(std::apply<Fn, ArgsTuple>), Fn, ArgsTuple>;
151 
152  if constexpr (std::is_void_v<retval_t>) {
153  std::apply(std::forward<Fn>(fn), std::forward<ArgsTuple>(args_tuple));
154  return std::make_tuple(std::monostate{});
155 
156  } else {
157  auto&& ret = std::apply(std::forward<Fn>(fn), std::forward<ArgsTuple>(args_tuple));
158  return std::make_tuple(std::forward<decltype(ret)>(ret));
159  }
160  }
161 
162  template <typename Fn, typename ArgsTuple, typename W, typename... Rest>
163  inline auto do_parallel_invoke(Fn&& fn, ArgsTuple&& args_tuple, ito::workhint<W> iwh, Rest&&... rest) {
164  using retval_t = std::invoke_result_t<decltype(std::apply<Fn, ArgsTuple>), Fn, ArgsTuple>;
165 
166  ori::poll();
167 
168  // for immediately executing cross-worker tasks in ADWS
169  // TODO: remove one of these two acquire calls?
170  ito::poll([&]() { all_serialized_ = false; return ori::release_lazy(); },
171  [&](ori::release_handler rh) { ori::acquire(rh); ori::acquire(rh_); });
172 
173  ito::thread<retval_t> th(
174  ito::with_callback, [rh = rh_] { ori::acquire(rh); }, [] { ori::release(); }, iwh,
175  [fn = std::forward<Fn>(fn),
176  args_tuple = std::forward<ArgsTuple>(args_tuple)]() mutable {
177  return std::apply(std::forward<decltype(fn)>(fn),
178  std::forward<decltype(args_tuple)>(args_tuple));
179  });
180  all_serialized_ &= th.serialized();
181 
182  auto&& ret_rest = parallel_invoke_aux(std::forward<Rest>(rest)...);
183 
184  if constexpr (std::is_void_v<retval_t>) {
185  if (!th.serialized()) {
186  ori::release();
187  }
188 
189  th.join();
190  return std::tuple_cat(std::make_tuple(std::monostate{}),
191  std::move(ret_rest));
192 
193  } else {
194  if (!th.serialized()) {
195  ori::release();
196  }
197 
198  auto&& ret = th.join();
199  return std::tuple_cat(std::make_tuple(std::forward<decltype(ret)>(ret)),
200  std::move(ret_rest));
201  }
202  }
203 
204  ReleaseHandler rh_;
205  bool all_serialized_ = true;
206 };
207 
208 }
209 
237 template <typename... Args>
238 inline auto parallel_invoke(Args&&... args) {
239  auto rh = ori::release_lazy();
240 
241  ito::task_group_data tgdata;
242  ito::task_group_begin(&tgdata);
243 
244  internal::parallel_invoke_state s(rh);
245  auto&& ret = s.parallel_invoke_aux(std::forward<Args>(args)...);
246 
247  // No lazy release here because the suspended thread (cross-worker tasks in ADWS) is
248  // always resumed by another process.
249  ito::task_group_end([] { ori::release(); }, [] { ori::acquire(); });
250 
251  // TODO: avoid duplicated acquire calls
252  if (!s.all_serialized()) {
253  ori::acquire();
254  }
255  return std::move(ret);
256 }
257 
258 ITYR_TEST_CASE("[ityr::pattern::parallel_invoke] parallel invoke") {
259  ito::init();
260  ori::init();
261 
262  ITYR_SUBCASE("with functions") {
263  ito::root_exec([=] {
264  auto [x, y] = parallel_invoke(
265  []() { return 1; },
266  []() { return 2; }
267  );
268  ITYR_CHECK(x == 1);
269  ITYR_CHECK(y == 2);
270  });
271 
272  ito::root_exec([=] {
273  auto [x, y, z] = parallel_invoke(
274  []() { return 1; },
275  []() { return 2; },
276  []() { return 4.8; }
277  );
278  ITYR_CHECK(x == 1);
279  ITYR_CHECK(y == 2);
280  ITYR_CHECK(z == 4.8);
281  });
282  }
283 
284  ITYR_SUBCASE("with args") {
285  ito::root_exec([=] {
286  auto [x, y] = parallel_invoke(
287  [](int i) { return i ; }, std::make_tuple(1),
288  [](int i) { return i * 2; }, std::make_tuple(2)
289  );
290  ITYR_CHECK(x == 1);
291  ITYR_CHECK(y == 4);
292  });
293 
294  ito::root_exec([=] {
295  auto [x, y] = parallel_invoke(
296  [](int i, int j ) { return i + j ; }, std::make_tuple(1, 2),
297  [](int i, int j, int k) { return i + j + k; }, std::make_tuple(3, 4, 5)
298  );
299  ITYR_CHECK(x == 3);
300  ITYR_CHECK(y == 12);
301  });
302  }
303 
304  ITYR_SUBCASE("with work hints") {
305  ito::root_exec([=] {
306  auto [x, y] = parallel_invoke(
307  []() { return 1; }, workhint(1),
308  []() { return 2; }, workhint(1)
309  );
310  ITYR_CHECK(x == 1);
311  ITYR_CHECK(y == 2);
312  });
313 
314  ito::root_exec([=] {
315  auto [x, y] = parallel_invoke(
316  [](int i) { return i ; }, std::make_tuple(1), workhint(1.0),
317  [](int i) { return i * 2; }, std::make_tuple(2), workhint(2.0)
318  );
319  ITYR_CHECK(x == 1);
320  ITYR_CHECK(y == 4);
321  });
322 
323  ito::root_exec([=] {
324  auto [x, y, z] = parallel_invoke(
325  []( ) { return 1; }, workhint(1),
326  [](int i, int j) { return i + j + 1; }, std::make_tuple(2, 4), workhint(2),
327  [](int i ) { return i * 2; }, std::make_tuple(2) , workhint(3)
328  );
329  ITYR_CHECK(x == 1);
330  ITYR_CHECK(y == 7);
331  ITYR_CHECK(z == 4);
332  });
333  }
334 
335  ITYR_SUBCASE("corner cases") {
336  ito::root_exec([=] {
337  ITYR_CHECK(parallel_invoke() == std::make_tuple());
338 
339  // The following is not allowed before C++20: Lambda expression in an unevaluated operand
340  // ITYR_CHECK(std::make_tuple(std::monostate{}) == parallel_invoke([]{}));
341 
342  {
343  auto ret = parallel_invoke([]{});
344  ITYR_CHECK(ret == std::make_tuple(std::monostate{}));
345  }
346 
347  {
348  auto ret = parallel_invoke([]{}, []{});
349  ITYR_CHECK(ret == std::make_tuple(std::monostate{}, std::monostate{}));
350  }
351 
352  {
353  auto ret = parallel_invoke([]{}, []{}, []{ return 1; });
354  ITYR_CHECK(ret == std::make_tuple(std::monostate{}, std::monostate{}, 1));
355  }
356 
357  {
358  auto ret = parallel_invoke([]{ return 1; }, []{}, []{});
359  ITYR_CHECK(ret == std::make_tuple(1, std::monostate{}, std::monostate{}));
360  }
361 
362  {
363  auto ret = parallel_invoke([](int){}, std::make_tuple(1), []{ return 1; }, []{});
364  ITYR_CHECK(ret == std::make_tuple(std::monostate{}, 1, std::monostate{}));
365  }
366  });
367  }
368 
369  ori::fini();
370  ito::fini();
371 }
372 
373 }
#define ITYR_SUBCASE(name)
Definition: util.hpp:41
#define ITYR_CHECK(cond)
Definition: util.hpp:48
va_list args
Definition: util.hpp:76
void fini()
Definition: ito.hpp:45
auto root_exec(Fn &&fn, Args &&... args)
Definition: ito.hpp:50
void task_group_begin(task_group_data *tgdata)
Definition: ito.hpp:105
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
void poll(PreSuspendCallback &&pre_suspend_cb, PostSuspendCallback &&post_suspend_cb)
Definition: ito.hpp:96
constexpr with_callback_t with_callback
Definition: thread.hpp:11
void task_group_end(PreSuspendCallback &&pre_suspend_cb, PostSuspendCallback &&post_suspend_cb)
Definition: ito.hpp:112
scheduler::task_group_data task_group_data
Definition: ito.hpp:103
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
auto release_lazy()
Definition: ori.hpp:200
core::instance::instance_type::release_handler release_handler
Definition: ori.hpp:204
void poll()
Definition: ori.hpp:224
void release()
Definition: ori.hpp:196
void acquire()
Definition: ori.hpp:206
Definition: allocator.hpp:16
auto parallel_invoke(Args &&... args)
Fork parallel tasks and join them.
Definition: parallel_invoke.hpp:238
ForwardIteratorD move(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d)
Move a range to another.
Definition: parallel_loop.hpp:934
Definition: parallel_invoke.hpp:14
W value
Definition: parallel_invoke.hpp:16
constexpr workhint(W w)
Definition: parallel_invoke.hpp:15