Itoyori  v0.0.1
workhint.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
4 #include "ityr/ito/ito.hpp"
5 #include "ityr/ori/ori.hpp"
8 
9 namespace ityr {
10 
// workhint_range<W>: owns a collectively allocated binary tree of work-hint nodes
// (bin_tree_node) with n_leaves leaves, i.e. 2 * n_leaves - 1 nodes (see size()).
// Non-copyable, movable; view() hands out a non-owning workhint_range_view over the tree.
// NOTE: source lines 12, 18, 20, 31, 33, 37 and 49 are missing from this listing; per the
// symbol index they hold the class head, the bin_tree_node alias, the default constructor,
// the move constructor / move assignment heads and the view() head (31 is presumably the
// deleted copy assignment).
11 template <typename W>
// Work hints are accumulated with += and + by create_workhint_range, so only arithmetic
// types are accepted.
13  static_assert(std::is_arithmetic_v<W>);
14 
15 public:
16  using value_type = W;
17 
19 
21 
// Allocates a tree with n_leaves leaves (size() nodes) via mem_alloc, i.e. only on the
// root thread or inside an SPMD region. n_leaves must be a power of two.
// bin_tree_'s initializer calls size(), which reads n_leaves_; this is only valid because
// n_leaves_ is declared before bin_tree_ (member declarations at the end of the class).
// NOTE(review): the power-of-two check runs *after* the allocation; with n_leaves == 0,
// size() computes 0 * 2 - 1 (wraps around) before the check can fire. Consider checking
// before allocating.
22  explicit workhint_range(std::size_t n_leaves)
23  : n_leaves_(n_leaves),
24  bin_tree_(mem_alloc(size())) {
25  ITYR_CHECK(common::is_pow2(n_leaves));
26  }
27 
// Frees the tree (collectively) unless ownership was moved out.
28  ~workhint_range() { destroy(); }
29 
30  workhint_range(const workhint_range&) = delete;
32 
// Move constructor (head on missing line 33) and move assignment (head on missing line 37):
// ownership of the tree is transferred and r.bin_tree_ is nulled so that r's destructor
// does not free it. Move assignment first frees the tree currently owned by *this.
// NOTE(review): move assignment is not guarded against self-move: `x = std::move(x)` frees
// the tree and then leaves bin_tree_ null. Neither move operation is marked noexcept.
34  : n_leaves_(r.n_leaves_), bin_tree_(r.bin_tree_) {
35  r.bin_tree_ = nullptr;
36  }
38  destroy();
39  n_leaves_ = r.n_leaves_;
40  bin_tree_ = r.bin_tree_;
41  r.bin_tree_ = nullptr;
42  return *this;
43  }
44 
// Number of nodes in a complete binary tree with n_leaves_ leaves.
45  std::size_t size() const {
46  return n_leaves_ * 2 - 1;
47  }
48 
// view() (head on missing line 49): non-owning view over all size() nodes of the tree.
50  return workhint_range_view<W>({bin_tree_, size()});
51  }
52 
53 private:
// Releases the tree if still owned. bin_tree_ is not reset here: the only callers are the
// destructor and move assignment, which overwrites bin_tree_ right afterwards.
54  void destroy() {
55  if (bin_tree_) {
56  mem_free(bin_tree_);
57  }
58  }
59 
// Collective allocation of `size` nodes: called directly inside an SPMD region, or wrapped
// in ito::coll_exec when called from the root thread; any other context is fatal.
// NOTE(review): there is no return after common::die; this relies on die() never
// returning (presumably [[noreturn]] — confirm).
60  ori::global_ptr<bin_tree_node> mem_alloc(std::size_t size) {
61  if (ito::is_spmd()) {
62  return ori::malloc_coll<bin_tree_node>(size);
63  } else if (ito::is_root()) {
64  return ito::coll_exec([=] { return ori::malloc_coll<bin_tree_node>(size); });
65  } else {
66  common::die("workhint_range must be created on the root thread or SPMD region.");
67  }
68  }
69 
// Collective counterpart of mem_alloc, with the same context rules. The two branches
// call free_coll with an explicit and a deduced template argument respectively; both
// resolve to free_coll<bin_tree_node>.
70  void mem_free(ori::global_ptr<bin_tree_node> p) {
71  if (ito::is_spmd()) {
72  ori::free_coll<bin_tree_node>(p);
73  } else if (ito::is_root()) {
74  ito::coll_exec([=] { ori::free_coll(p); });
75  } else {
76  common::die("workhint_range must be destroyed on the root thread or SPMD region.");
77  }
78  }
79 
// Number of leaves. Must stay declared before bin_tree_ (its initializer uses size()).
80  std::size_t n_leaves_;
// Root of the tree in global memory; null after ownership has been moved out.
81  ori::global_ptr<bin_tree_node> bin_tree_;
82 };
83 
84 template <>
85 class workhint_range<void> {};
86 
87 namespace internal {
88 
// Recursive worker for create_workhint_range.
//
// Evaluates `op` over [first, last), element-wise together with the ranges starting at
// `firsts...`, and returns the sum of the results.
// - If `target_wh` is non-empty, the range is halved. The left half is forked as an
//   ito::thread and the right half runs on the current thread. The two partial sums are
//   stored in the node with set_workhint(w1, w2), and the children of `target_wh` (if any)
//   are passed to the two recursive calls.
// - If `target_wh` is empty, the range is reduced sequentially with for_each_aux.
//
// Parameters:
//   target_wh      - tree node receiving the (left, right) partial sums.
//   checkout_count - checkout granularity for the sequenced_policy of the base case.
//   op             - per-element work estimate; its results are summed with += / +.
//   rh             - release handler obtained by the caller via ori::release_lazy();
//                    acquired by forked threads and in the poll callback below.
//   first, last    - input range; firsts... - start iterators of additional ranges.
// Returns the sum of op over the whole range.
89 template <typename W, typename Op, typename ReleaseHandler,
90  typename ForwardIterator, typename... ForwardIterators>
91 inline W create_workhint_range_aux(workhint_range_view<W> target_wh,
92  std::size_t checkout_count,
93  Op op,
94  ReleaseHandler rh,
95  ForwardIterator first,
96  ForwardIterator last,
97  ForwardIterators... firsts) {
98  ori::poll();
99 
// Pre-suspend callback: lazy release. Post-suspend callback: acquire both the caller's
// handler `rh` and the handler returned by the pre-suspend callback.
100  // for immediately executing cross-worker tasks in ADWS
101  ito::poll([] { return ori::release_lazy(); },
102  [&](ori::release_handler rh_) { ori::acquire(rh); ori::acquire(rh_); });
103 
// Base case: no tree node left to annotate -> plain sequential reduction of op.
104  if (target_wh.empty()) {
105  W w {};
106  for_each_aux(
107  execution::sequenced_policy(checkout_count),
108  [&](auto&&... refs) {
109  w += op(std::forward<decltype(refs)>(refs)...);
110  },
111  first, last, firsts...);
112  return w;
113  }
114 
// Split point: the left half gets d / 2 elements. std::distance yields a signed
// difference_type, which is stored into std::size_t here.
115  std::size_t d = std::distance(first, last);
116  auto mid = std::next(first, d / 2);
117 
118  ito::task_group_data tgdata;
119  ito::task_group_begin(&tgdata);
120 
// When target_wh has no children, c1/c2 stay default-constructed (presumably empty views),
// so both halves fall through to the sequential base case.
121  workhint_range_view<W> c1, c2;
122  if (target_wh.has_children()) {
123  auto children = target_wh.get_children();
124  c1 = children.first;
125  c2 = children.second;
126  }
127 
// Fork the left half with an equal work hint (1, 1). The with_callback lambdas acquire `rh`
// and release, presumably on thread start and on thread exit respectively.
128  ito::thread<W> th(
129  ito::with_callback, [=] { ori::acquire(rh); }, [] { ori::release(); },
130  ito::workhint(1, 1),
131  [=] {
132  return create_workhint_range_aux(c1, checkout_count, op, rh,
133  first, mid, firsts...);
134  });
135 
// The right half runs on the current thread; every extra range is advanced by the same d / 2.
136  W w2 = create_workhint_range_aux(c2, checkout_count, op, rh,
137  mid, last, std::next(firsts, d / 2)...);
138 
// If the child was not executed serially, release before joining and acquire after the
// task group ends (presumably to make the child's tree writes and ours mutually visible).
// The order of these calls matters; do not reorder.
139  if (!th.serialized()) {
140  ori::release();
141  }
142 
143  W w1 = th.join();
144 
145  ito::task_group_end([] { ori::release(); }, [] { ori::acquire(); });
146 
147  if (!th.serialized()) {
148  ori::acquire();
149  }
150 
// Record the (left, right) partial sums at this node and report the total upward.
151  target_wh.set_workhint(w1, w2);
152 
153  return w1 + w2;
154 }
155 
156 }
157 
// Build a workhint_range for [first, last). Evaluates `op` on every element, sums the
// results, and stores the (left, right) partial sums at each node of a binary tree with
// `n_leaves` leaves (see internal::create_workhint_range_aux).
//
//   policy     - in the visible code only policy.checkout_count is read.
//   first/last - input range; raw ori::global_ptr iterators are first wrapped into
//                read-mode global iterators and the function recurses on those.
//   op         - per-element work estimate; its result type is the work-hint type.
//   n_leaves   - number of tree leaves; must be a power of two (see the message below).
// Returns the workhint_range by value. Must be called inside an SPMD region or on the root
// thread; anywhere else is a fatal error.
//
// NOTE: source lines 164 and 178 are missing from this listing. Presumably 164 opens the
// power-of-two check whose message is on line 165, and 178 declares `workhint`
// (a workhint_range<workhint_t> with n_leaves leaves) — confirm against the source.
158 template <typename ExecutionPolicy, typename ForwardIterator, typename Op>
159 inline auto create_workhint_range(const ExecutionPolicy& policy,
160  ForwardIterator first,
161  ForwardIterator last,
162  Op op,
163  std::size_t n_leaves) {
165  "The number of leaves for workhint_range must be a power of two.");
166 
167  if constexpr (ori::is_global_ptr_v<ForwardIterator>) {
168  return create_workhint_range(
169  policy,
170  internal::convert_to_global_iterator(first, checkout_mode::read),
171  internal::convert_to_global_iterator(last , checkout_mode::read),
172  op,
173  n_leaves);
174 
175  } else {
// NOTE(review): if `op` returns a reference, workhint_t is a reference type and
// workhint_range's static_assert(std::is_arithmetic_v<W>) fails; wrapping the result in
// std::decay_t would accept such ops.
176  using value_type = typename std::iterator_traits<ForwardIterator>::value_type;
177  using workhint_t = std::invoke_result_t<Op, value_type>;
179 
// SPMD region: spawn the root thread. The lambda captures a view, not `workhint` itself
// (workhint_range is non-copyable). Root thread: fill the tree directly.
180  if (ito::is_spmd()) {
181  root_exec([=, wh = workhint.view()] {
182  auto rh = ori::release_lazy();
183  internal::create_workhint_range_aux(wh, policy.checkout_count, op, rh, first, last);
184  });
185  } else if (ito::is_root()) {
186  auto rh = ori::release_lazy();
187  internal::create_workhint_range_aux(workhint.view(), policy.checkout_count, op, rh, first, last);
188  } else {
189  common::die("workhint_range must be created on the root thread or SPMD region.");
190  }
191 
192  return workhint;
193  }
194 }
195 
196 template <typename ExecutionPolicy, typename ForwardIterator, typename Op>
197 inline auto create_workhint_range(const ExecutionPolicy& policy,
198  ForwardIterator first,
199  ForwardIterator last,
200  Op op) {
201  return create_workhint_range(policy, first, last, op,
202  common::next_pow2(std::distance(first, last)));
203 }
204 
// End-to-end test: fills p[i] = i for 0 <= i < n, builds two workhint_ranges with
// op(x) = x, checks the partial sums stored at the root and at its two children, then
// passes the ranges as hints to parallel loops.
// NOTE: source lines 217, 253-255 and 262 are missing from this listing (presumably the
// range arguments of the first for_each, the transform and the last for_each).
205 ITYR_TEST_CASE("[ityr::workhint] workhint range test") {
206  ito::init();
207  ori::init();
208 
209  long n = 100000;
210  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
211 
212  root_exec([=] {
213  for_each(
214  execution::parallel_policy(100),
215  count_iterator<long>(0),
216  count_iterator<long>(n),
218  [](long i, long& x) { x = i; });
219 
220  auto workhint = create_workhint_range(
221  execution::parallel_policy(100),
222  p, p + n,
223  [](long x) { return x; });
224 
// Expected values: the sum of m consecutive integers starting at a is m * (2a + m - 1) / 2.
// Root halves: [0, n/2) and [n/2, n). Children: [0, n/4), [n/4, n/2), [n/2, 3n/4) and
// [3n/4, n). These exact splits rely on n being divisible by 4.
225  auto check_workhint = [&](auto& wh) {
226  auto [w1, w2] = wh.view().get_workhint();
227  ITYR_CHECK(w1 == n / 2 * (n / 2 - 1) / 2);
228  ITYR_CHECK(w2 == n / 2 * (n / 2 * 3 - 1) / 2);
229 
230  auto [c1, c2] = wh.view().get_children();
231 
232  auto [w11, w12] = c1.get_workhint();
233  ITYR_CHECK(w11 == n / 4 * (n / 4 - 1) / 2);
234  ITYR_CHECK(w12 == n / 4 * (n / 4 * 3 - 1) / 2);
235 
236  auto [w21, w22] = c2.get_workhint();
237  ITYR_CHECK(w21 == n / 4 * (n / 4 * 5 - 1) / 2);
238  ITYR_CHECK(w22 == n / 4 * (n / 4 * 7 - 1) / 2);
239  };
240 
241  check_workhint(workhint);
242 
// Same input with an explicit leaf count (presumably next_pow2(1000) == 1024). Only the top
// two tree levels are inspected, so the expected sums are unchanged.
243  auto workhint2 = create_workhint_range(
244  execution::parallel_policy(100),
245  p, p + n,
246  [](long x) { return x; },
247  common::next_pow2(n / 100));
248 
249  check_workhint(workhint2);
250 
// Use the hints in parallel_policy; the final for_each expects x == 2 * i afterwards.
251  transform(
252  execution::parallel_policy(100, workhint),
256  [](long x) { return x * 2; });
257 
258  for_each(
259  execution::parallel_policy(100, workhint2),
260  count_iterator<long>(0),
261  count_iterator<long>(n),
263  [](long i, long x) { ITYR_CHECK(i * 2 == x); });
264  });
265 
266  ori::free_coll(p);
267 
268  ori::fini();
269  ito::fini();
270 }
271 
272 }
Definition: workhint_view.hpp:12
bool has_children() const
Definition: workhint_view.hpp:44
bool empty() const
Definition: workhint_view.hpp:42
void set_workhint(const value_type &v1, const value_type &v2)
Definition: workhint_view.hpp:30
std::pair< workhint_range_view, workhint_range_view > get_children() const
Definition: workhint_view.hpp:34
workhint_range
Definition: workhint.hpp:12
workhint_range(workhint_range &&r)
Definition: workhint.hpp:33
workhint_range & operator=(const workhint_range &)=delete
~workhint_range()
Definition: workhint.hpp:28
workhint_range & operator=(workhint_range &&r)
Definition: workhint.hpp:37
std::size_t size() const
Definition: workhint.hpp:45
workhint_range(std::size_t n_leaves)
Definition: workhint.hpp:22
typename workhint_range_view< W >::bin_tree_node bin_tree_node
Definition: workhint.hpp:18
workhint_range()
Definition: workhint.hpp:20
workhint_range_view< W > view() const
Definition: workhint.hpp:49
workhint_range(const workhint_range &)=delete
W value_type
Definition: workhint.hpp:16
#define ITYR_CHECK(cond)
Definition: util.hpp:48
#define ITYR_REQUIRE_MESSAGE(cond, msg,...)
Definition: util.hpp:43
constexpr read_t read
Read-only checkout mode.
Definition: checkout_span.hpp:19
constexpr write_t write
Write-only checkout mode.
Definition: checkout_span.hpp:29
bool is_pow2(T x)
Definition: util.hpp:125
uint64_t next_pow2(uint64_t x)
Definition: util.hpp:102
void fini()
Definition: ito.hpp:45
void task_group_begin(task_group_data *tgdata)
Definition: ito.hpp:105
auto coll_exec(const Fn &fn, const Args &... args)
Definition: ito.hpp:72
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
bool is_root()
Definition: ito.hpp:66
void poll(PreSuspendCallback &&pre_suspend_cb, PostSuspendCallback &&post_suspend_cb)
Definition: ito.hpp:96
bool is_spmd()
Definition: ito.hpp:61
constexpr with_callback_t with_callback
Definition: thread.hpp:11
void task_group_end(PreSuspendCallback &&pre_suspend_cb, PostSuspendCallback &&post_suspend_cb)
Definition: ito.hpp:112
scheduler::task_group_data task_group_data
Definition: ito.hpp:103
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
auto release_lazy()
Definition: ori.hpp:200
void free_coll(global_ptr< T > ptr)
Definition: ori.hpp:70
core::instance::instance_type::release_handler release_handler
Definition: ori.hpp:204
void poll()
Definition: ori.hpp:224
void release()
Definition: ori.hpp:196
void acquire()
Definition: ori.hpp:206
Definition: allocator.hpp:16
auto create_workhint_range(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Op op, std::size_t n_leaves)
Definition: workhint.hpp:159
ForwardIteratorD transform(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d, UnaryOp unary_op)
Transform elements in a given range and store them in another range.
Definition: parallel_loop.hpp:583
void for_each(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Op op)
Apply an operator to each element in a range.
Definition: parallel_loop.hpp:136
auto root_exec(Fn &&fn, Args &&... args)
Spawn the root thread (collective).
Definition: root_exec.hpp:47
global_iterator< T, Mode > make_global_iterator(ori::global_ptr< T > gptr, Mode)
Make a global iterator to enable/disable automatic checkout.
Definition: global_iterator.hpp:158
Definition: workhint_view.hpp:16
Definition: parallel_invoke.hpp:14