13 static_assert(std::is_arithmetic_v<W>);
23 : n_leaves_(n_leaves),
24 bin_tree_(mem_alloc(
size())) {
34 : n_leaves_(r.n_leaves_), bin_tree_(r.bin_tree_) {
35 r.bin_tree_ =
nullptr;
39 n_leaves_ = r.n_leaves_;
40 bin_tree_ = r.bin_tree_;
41 r.bin_tree_ =
nullptr;
46 return n_leaves_ * 2 - 1;
60 ori::global_ptr<bin_tree_node> mem_alloc(std::size_t
size) {
62 return ori::malloc_coll<bin_tree_node>(
size);
66 common::die(
"workhint_range must be created on the root thread or SPMD region.");
70 void mem_free(ori::global_ptr<bin_tree_node> p) {
72 ori::free_coll<bin_tree_node>(p);
76 common::die(
"workhint_range must be destroyed on the root thread or SPMD region.");
80 std::size_t n_leaves_;
81 ori::global_ptr<bin_tree_node> bin_tree_;
89 template <
typename W,
typename Op,
typename ReleaseHandler,
90 typename ForwardIterator,
typename... ForwardIterators>
92 std::size_t checkout_count,
95 ForwardIterator first,
97 ForwardIterators... firsts) {
104 if (target_wh.
empty()) {
107 execution::sequenced_policy(checkout_count),
108 [&](
auto&&... refs) {
109 w += op(std::forward<decltype(refs)>(refs)...);
111 first, last, firsts...);
115 std::size_t d = std::distance(first, last);
116 auto mid = std::next(first, d / 2);
121 workhint_range_view<W> c1, c2;
125 c2 = children.second;
132 return create_workhint_range_aux(c1, checkout_count, op, rh,
133 first, mid, firsts...);
136 W w2 = create_workhint_range_aux(c2, checkout_count, op, rh,
137 mid, last, std::next(firsts, d / 2)...);
139 if (!th.serialized()) {
147 if (!th.serialized()) {
158 template <
typename ExecutionPolicy,
typename ForwardIterator,
typename Op>
160 ForwardIterator first,
161 ForwardIterator last,
163 std::size_t n_leaves) {
165 "The number of leaves for workhint_range must be a power of two.");
167 if constexpr (ori::is_global_ptr_v<ForwardIterator>) {
176 using value_type =
typename std::iterator_traits<ForwardIterator>::value_type;
177 using workhint_t = std::invoke_result_t<Op, value_type>;
182 auto rh = ori::release_lazy();
183 internal::create_workhint_range_aux(wh, policy.checkout_count, op, rh, first, last);
187 internal::create_workhint_range_aux(
workhint.view(), policy.checkout_count, op, rh, first, last);
189 common::die(
"workhint_range must be created on the root thread or SPMD region.");
196 template <
typename ExecutionPolicy,
typename ForwardIterator,
typename Op>
198 ForwardIterator first,
199 ForwardIterator last,
205 ITYR_TEST_CASE(
"[ityr::workhint] workhint range test") {
210 ori::global_ptr<long> p = ori::malloc_coll<long>(n);
214 execution::parallel_policy(100),
215 count_iterator<long>(0),
216 count_iterator<long>(n),
218 [](
long i,
long& x) { x = i; });
221 execution::parallel_policy(100),
223 [](
long x) {
return x; });
225 auto check_workhint = [&](
auto& wh) {
226 auto [w1, w2] = wh.view().get_workhint();
228 ITYR_CHECK(w2 == n / 2 * (n / 2 * 3 - 1) / 2);
230 auto [c1, c2] = wh.view().get_children();
232 auto [w11, w12] = c1.get_workhint();
234 ITYR_CHECK(w12 == n / 4 * (n / 4 * 3 - 1) / 2);
236 auto [w21, w22] = c2.get_workhint();
237 ITYR_CHECK(w21 == n / 4 * (n / 4 * 5 - 1) / 2);
238 ITYR_CHECK(w22 == n / 4 * (n / 4 * 7 - 1) / 2);
241 check_workhint(workhint);
244 execution::parallel_policy(100),
246 [](
long x) {
return x; },
249 check_workhint(workhint2);
252 execution::parallel_policy(100, workhint),
256 [](
long x) {
return x * 2; });
259 execution::parallel_policy(100, workhint2),
260 count_iterator<long>(0),
261 count_iterator<long>(n),
263 [](
long i,
long x) {
ITYR_CHECK(i * 2 == x); });
Definition: workhint_view.hpp:12
bool has_children() const
Definition: workhint_view.hpp:44
bool empty() const
Definition: workhint_view.hpp:42
void set_workhint(const value_type &v1, const value_type &v2)
Definition: workhint_view.hpp:30
std::pair< workhint_range_view, workhint_range_view > get_children() const
Definition: workhint_view.hpp:34
Definition: workhint.hpp:12
workhint_range(workhint_range &&r)
Definition: workhint.hpp:33
workhint_range & operator=(const workhint_range &)=delete
~workhint_range()
Definition: workhint.hpp:28
workhint_range & operator=(workhint_range &&r)
Definition: workhint.hpp:37
std::size_t size() const
Definition: workhint.hpp:45
workhint_range(std::size_t n_leaves)
Definition: workhint.hpp:22
typename workhint_range_view< W >::bin_tree_node bin_tree_node
Definition: workhint.hpp:18
workhint_range()
Definition: workhint.hpp:20
workhint_range_view< W > view() const
Definition: workhint.hpp:49
workhint_range(const workhint_range &)=delete
W value_type
Definition: workhint.hpp:16
#define ITYR_CHECK(cond)
Definition: util.hpp:48
#define ITYR_REQUIRE_MESSAGE(cond, msg,...)
Definition: util.hpp:43
constexpr read_t read
Read-only checkout mode.
Definition: checkout_span.hpp:19
constexpr write_t write
Write-only checkout mode.
Definition: checkout_span.hpp:29
bool is_pow2(T x)
Definition: util.hpp:125
uint64_t next_pow2(uint64_t x)
Definition: util.hpp:102
void fini()
Definition: ito.hpp:45
void task_group_begin(task_group_data *tgdata)
Definition: ito.hpp:105
auto coll_exec(const Fn &fn, const Args &... args)
Definition: ito.hpp:72
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
bool is_root()
Definition: ito.hpp:66
void poll(PreSuspendCallback &&pre_suspend_cb, PostSuspendCallback &&post_suspend_cb)
Definition: ito.hpp:96
bool is_spmd()
Definition: ito.hpp:61
constexpr with_callback_t with_callback
Definition: thread.hpp:11
void task_group_end(PreSuspendCallback &&pre_suspend_cb, PostSuspendCallback &&post_suspend_cb)
Definition: ito.hpp:112
scheduler::task_group_data task_group_data
Definition: ito.hpp:103
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
auto release_lazy()
Definition: ori.hpp:200
void free_coll(global_ptr< T > ptr)
Definition: ori.hpp:70
core::instance::instance_type::release_handler release_handler
Definition: ori.hpp:204
void poll()
Definition: ori.hpp:224
void release()
Definition: ori.hpp:196
void acquire()
Definition: ori.hpp:206
Definition: allocator.hpp:16
auto create_workhint_range(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Op op, std::size_t n_leaves)
Definition: workhint.hpp:159
ForwardIteratorD transform(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d, UnaryOp unary_op)
Transform elements in a given range and store them in another range.
Definition: parallel_loop.hpp:583
void for_each(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Op op)
Apply an operator to each element in a range.
Definition: parallel_loop.hpp:136
auto root_exec(Fn &&fn, Args &&... args)
Spawn the root thread (collective).
Definition: root_exec.hpp:47
global_iterator< T, Mode > make_global_iterator(ori::global_ptr< T > gptr, Mode)
Make a global iterator to enable/disable automatic checkout.
Definition: global_iterator.hpp:158
Definition: workhint_view.hpp:16
Definition: parallel_invoke.hpp:14