Itoyori  v0.0.1
parallel_shuffle.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
9 
10 namespace ityr {
11 
12 namespace internal {
13 
14 template <typename W, typename RandomAccessIterator,
15  typename SplittableUniformRandomBitGenerator>
16 inline RandomAccessIterator
17 random_partition(const execution::parallel_policy<W>& policy,
18  RandomAccessIterator first,
19  RandomAccessIterator last,
20  SplittableUniformRandomBitGenerator&& urbg) {
21  std::size_t d = std::distance(first, last);
22 
23  if (d <= policy.cutoff_count) {
24  // TODO: consider policy.checkout_count
25  ITYR_CHECK(policy.cutoff_count == policy.checkout_count);
26 
27  auto&& [css, its] = checkout_global_iterators(d, first);
28  auto&& first_ = std::get<0>(its);
29  auto m = std::stable_partition(first_, std::next(first_, d),
30  [&](auto&&) { return urbg() % 2 == 0; });
31  return std::next(first, std::distance(first_, m));
32  }
33 
34  auto mid = std::next(first, d / 2);
35 
36  auto child_urbg1 = urbg.split();
37  auto child_urbg2 = urbg.split();
38 
39  auto [m1, m2] = parallel_invoke(
40  [=]() mutable { return random_partition(policy, first, mid , child_urbg1); },
41  [=]() mutable { return random_partition(policy, mid , last, child_urbg2); });
42 
43  // TODO: use swap_ranges; stability is not needed
44  return rotate(policy, m1, mid, m2);
45 }
46 
47 template <typename W, typename RandomAccessIterator,
48  typename SplittableUniformRandomBitGenerator>
49 inline void shuffle(const execution::parallel_policy<W>& policy,
50  RandomAccessIterator first,
51  RandomAccessIterator last,
52  SplittableUniformRandomBitGenerator&& urbg) {
53  std::size_t d = std::distance(first, last);
54 
55  if (d <= 1) return;
56 
57  if (d <= policy.cutoff_count) {
58  auto [css, its] = checkout_global_iterators(d, first);
59  auto first_ = std::get<0>(its);
60  std::shuffle(first_, std::next(first_, d), urbg);
61 
62  } else {
63  auto mid = random_partition(policy, first, last, urbg);
64 
65  auto child_urbg1 = urbg.split();
66  auto child_urbg2 = urbg.split();
67 
69  [=]() mutable { shuffle(policy, first, mid , child_urbg1); },
70  [=]() mutable { shuffle(policy, mid , last, child_urbg2); });
71  }
72 }
73 
74 }
75 
111 template <typename ExecutionPolicy, typename RandomAccessIterator,
112  typename UniformRandomBitGenerator>
113 inline void shuffle(const ExecutionPolicy& policy,
114  RandomAccessIterator first,
115  RandomAccessIterator last,
116  UniformRandomBitGenerator&& urbg) {
117  if constexpr (ori::is_global_ptr_v<RandomAccessIterator>) {
118  shuffle(
119  policy,
120  internal::convert_to_global_iterator(first, checkout_mode::read_write),
121  internal::convert_to_global_iterator(last , checkout_mode::read_write),
122  std::forward<UniformRandomBitGenerator>(urbg));
123 
124  } else {
125  internal::shuffle(policy, first, last, std::forward<UniformRandomBitGenerator>(urbg));
126  }
127 }
128 
129 ITYR_TEST_CASE("[ityr::pattern::parallel_shuffle] shuffle") {
130  ito::init();
131  ori::init();
132 
133  long n = 100000;
134  ori::global_ptr<long> p1 = ori::malloc_coll<long>(n);
135  ori::global_ptr<long> p2 = ori::malloc_coll<long>(n);
136 
137  ito::root_exec([=] {
138  transform(
139  execution::parallel_policy(100),
140  count_iterator<long>(0), count_iterator<long>(n), p1,
141  [=](long i) { return i; });
142 
143  transform(
144  execution::parallel_policy(100),
145  count_iterator<long>(0), count_iterator<long>(n), p2,
146  [=](long i) { return i; });
147 
148  ITYR_CHECK(equal(execution::parallel_policy(100),
149  p1, p1 + n, p2) == true);
150  });
151 
152  ITYR_SUBCASE("should not lose values") {
153  ito::root_exec([=] {
154  shuffle(execution::parallel_policy(100), p1, p1 + n,
156 
157  ITYR_CHECK(equal(execution::parallel_policy(100),
158  p1, p1 + n, p2) == false);
159 
160  ITYR_CHECK(reduce(execution::parallel_policy(100),
161  p1, p1 + n) == n * (n - 1) / 2);
162 
163  sort(execution::parallel_policy(100), p1, p1 + n);
164 
165  ITYR_CHECK(equal(execution::parallel_policy(100),
166  p1, p1 + n, p2) == true);
167  });
168  }
169 
170  ITYR_SUBCASE("same RNG, same result") {
171  ito::root_exec([=] {
172  uint64_t seed = 42;
173  default_random_engine rng(seed);
174 
175  auto rng_copy = rng;
176 
177  shuffle(execution::parallel_policy(100), p1, p1 + n, rng);
178  shuffle(execution::parallel_policy(100), p2, p2 + n, rng_copy);
179 
180  ITYR_CHECK(equal(execution::parallel_policy(100),
181  p1, p1 + n, p2) == true);
182  });
183  }
184 
185  ITYR_SUBCASE("differente RNG, different result") {
186  ito::root_exec([=] {
187  uint64_t seed1 = 42;
188  default_random_engine rng1(seed1);
189 
190  uint64_t seed2 = 417;
191  default_random_engine rng2(seed2);
192 
193  shuffle(execution::parallel_policy(100), p1, p1 + n, rng1);
194  shuffle(execution::parallel_policy(100), p2, p2 + n, rng2);
195 
196  ITYR_CHECK(equal(execution::parallel_policy(100),
197  p1, p1 + n, p2) == false);
198  });
199  }
200 
201  ori::free_coll(p1);
202  ori::free_coll(p2);
203 
204  ori::fini();
205  ito::fini();
206 }
207 
208 }
#define ITYR_SUBCASE(name)
Definition: util.hpp:41
#define ITYR_CHECK(cond)
Definition: util.hpp:48
constexpr read_write_t read_write
Read+Write checkout mode.
Definition: checkout_span.hpp:39
void fini()
Definition: ito.hpp:45
auto root_exec(Fn &&fn, Args &&... args)
Definition: ito.hpp:50
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
void free_coll(global_ptr< T > ptr)
Definition: ori.hpp:70
Definition: allocator.hpp:16
void shuffle(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator last, UniformRandomBitGenerator &&urbg)
Randomly shuffle elements in a range.
Definition: parallel_shuffle.hpp:113
BidirectionalIterator rotate(const ExecutionPolicy &policy, BidirectionalIterator first, BidirectionalIterator middle, BidirectionalIterator last)
Rotate a range.
Definition: parallel_loop.hpp:1135
internal::random_engine_dummy default_random_engine
Definition: random.hpp:41
auto parallel_invoke(Args &&... args)
Fork parallel tasks and join them.
Definition: parallel_invoke.hpp:238
ForwardIteratorD transform(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d, UnaryOp unary_op)
Transform elements in a given range and store them in another range.
Definition: parallel_loop.hpp:583
BidirectionalIterator stable_partition(const ExecutionPolicy &policy, BidirectionalIterator first, BidirectionalIterator last, Predicate pred)
Partition elements into two disjoint parts in place.
Definition: parallel_filter.hpp:76
void sort(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator last, Compare comp)
Sort a range.
Definition: parallel_sort.hpp:210
Reducer::accumulator_type reduce(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Reducer reducer)
Calculate reduction.
Definition: parallel_reduce.hpp:340
bool equal(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIterator2 first2, BinaryPredicate pred)
Check if two ranges have equal values.
Definition: parallel_reduce.hpp:887