Itoyori  v0.0.1
parallel_merge.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
7 
8 namespace ityr {
9 
10 namespace internal {
11 
/**
 * @brief Find positions at which two ranges can be cut for a divide-and-conquer merge.
 *
 * Both ranges are presumably sorted with respect to `comp` (confirm against callers).
 * The function looks for a pair `(s1, s2)` with `s1` in `[first1, last1]` and `s2` in
 * `[first2, last2]` such that the element right before one cut is not ordered after the
 * element at the other cut. It aims at placing `(n1 + n2) / 2` elements before the cuts,
 * where `n1` and `n2` are the lengths of the two ranges; when the smaller range ends up
 * entirely on one side, the number of elements before the cuts can differ from that.
 *
 * @param first1,last1 First range; must not be empty (checked with `ITYR_CHECK`).
 * @param first2,last2 Second range.
 * @param comp         Comparison functor; `comp(a, b)` is treated as "a goes before b".
 * @return Pair of the cut position in the first range and the cut position in the second.
 */
template <typename RandomAccessIterator1, typename RandomAccessIterator2, typename Compare>
inline std::pair<RandomAccessIterator1, RandomAccessIterator2>
find_split_points_for_merge(RandomAccessIterator1 first1,
                            RandomAccessIterator1 last1,
                            RandomAccessIterator2 first2,
                            RandomAccessIterator2 last2,
                            Compare               comp) {
  using split_points = std::pair<RandomAccessIterator1, RandomAccessIterator2>;

  const auto len1 = std::distance(first1, last1);
  const auto len2 = std::distance(first2, last2);

  ITYR_CHECK(len1 > 0);

  if (len1 > len2) {
    // Swap the roles of the two ranges so that the rest of the function can assume
    // that the first range is not longer than the second one.
    const auto swapped = find_split_points_for_merge(first2, last2, first1, last1, comp);
    return split_points(swapped.second, swapped.first);
  }

  // Number of elements we would like to have before the two cuts
  const auto target = (len1 + len2) / 2;

  if (len1 == 1) {
    // The only decision left is on which side of the cut the single element goes.
    // NOTE(review): with len1 == len2 == 1, `cut2` equals `last2`, so the one-element
    // checkout below starts at the end of the second range -- confirm that callers
    // never reach this function with two single-element ranges.
    RandomAccessIterator2 cut2 = std::next(first2, target);
    ITYR_CHECK(first2 <= std::prev(cut2));

    auto&& [css, its] = checkout_global_iterators(1, first1, cut2);
    auto [single, right2] = its;

    return comp(*single, *right2) ? split_points(last1, cut2)
                                  : split_points(first1, cut2);
  }

  // Bisection over the (not shorter) second range
  RandomAccessIterator2 lo = first2;
  RandomAccessIterator2 hi = last2;

  for (;;) {
    ITYR_CHECK(lo <= hi);

    RandomAccessIterator2 mid2 = std::next(lo, std::distance(lo, hi) / 2);

    // Number of elements of the second range located before `mid2`
    const auto taken2 = std::distance(first2, mid2);

    if (target <= taken2) {
      // `mid2` is close to `last2`: the whole first range would go after the cut
      auto&& [css, its] = checkout_global_iterators(1, first1, std::prev(mid2));
      auto [right1, left2] = its;

      if (comp(*right1, *left2)) {
        ITYR_CHECK(hi != mid2);
        hi = mid2;
        continue;
      }
      return split_points(first1, mid2);
    }

    if (target - taken2 >= len1) {
      // `mid2` is close to `first2`: the whole first range would go before the cut
      auto&& [css, its] = checkout_global_iterators(1, std::prev(last1), mid2);
      auto [left1, right2] = its;

      if (comp(*right2, *left1)) {
        ITYR_CHECK(lo != std::next(mid2));
        lo = std::next(mid2);
        continue;
      }
      return split_points(last1, mid2);
    }

    // Both ranges are cut strictly inside
    RandomAccessIterator1 mid1 = std::next(first1, target - taken2);

    ITYR_CHECK(mid1 != first1);
    ITYR_CHECK(mid1 != last1);
    ITYR_CHECK(mid2 != first2);
    ITYR_CHECK(mid2 != last2);

    // Check out the two elements around each candidate cut
    auto&& [css, its] = checkout_global_iterators(2, std::prev(mid1), std::prev(mid2));
    auto [base1, base2] = its;

    auto left1  = base1;
    auto right1 = std::next(base1);
    auto left2  = base2;
    auto right2 = std::next(base2);

    if (comp(*right2, *left1)) {
      // The cut in the second range is too far to the left
      ITYR_CHECK(lo != std::next(mid2));
      lo = std::next(mid2);

    } else if (comp(*right1, *left2)) {
      // The cut in the second range is too far to the right
      ITYR_CHECK(hi != mid2);
      hi = mid2;

    } else {
      return split_points(mid1, mid2);
    }
  }
}
112 
113 ITYR_TEST_CASE("[ityr::pattern::parallel_merge] find_split_points_for_merge") {
114  ito::init();
115  ori::init();
116 
117  auto check_fn = [](std::vector<int> v1, std::vector<int> v2) {
118  auto [it1, it2] = find_split_points_for_merge(v1.begin(), v1.end(), v2.begin(), v2.end(), std::less<>{});
119  if (it1 != v1.begin() && it2 != v2.end()) {
120  ITYR_CHECK(*std::prev(it1) <= *it2);
121  }
122  if (it2 != v2.begin() && it1 != v1.end()) {
123  ITYR_CHECK(*std::prev(it2) <= *it1);
124  }
125  ITYR_CHECK(!(it1 == v1.begin() && it2 == v2.begin()));
126  ITYR_CHECK(!(it1 == v1.end() && it2 == v2.end()));
127  };
128 
129  check_fn({0}, {1, 2, 3, 4, 5});
130  check_fn({2}, {1, 2, 3, 4, 5});
131  check_fn({3}, {1, 2, 3, 4, 5});
132  check_fn({6}, {1, 2, 3, 4, 5});
133  check_fn({1, 4}, {1, 2, 3, 4, 5});
134  check_fn({2, 3}, {1, 2, 3, 4, 5});
135  check_fn({0, 6}, {1, 2, 3, 4, 5});
136  check_fn({0, 1}, {2, 2, 2, 4, 5});
137  check_fn({4, 5}, {2, 2, 2, 2, 3});
138  check_fn({3, 3}, {3, 3, 3, 3, 3, 3, 3});
139  check_fn({3, 4}, {2, 2, 3, 3, 3, 3, 4});
140  check_fn({1, 2, 3, 4, 5}, {1, 2, 3, 4, 5});
141  check_fn({1, 2, 3, 4, 5}, {4, 5, 6, 7, 8});
142  check_fn({1, 2, 3, 4, 5}, {6, 7, 8, 9, 10});
143  check_fn({6, 7, 8, 9, 10}, {1, 2, 3, 4, 5});
144  check_fn({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
145 
146  ito::fini();
147  ori::fini();
148 }
149 
150 template <bool Stable, typename W, typename RandomAccessIterator, typename Compare>
151 inline void inplace_merge_aux(const execution::parallel_policy<W>& policy,
152  RandomAccessIterator first,
153  RandomAccessIterator middle,
154  RandomAccessIterator last,
155  Compare comp) {
156  // TODO: implement a version with BidirectionalIterator
157  std::size_t d = std::distance(first, last);
158 
159  if (d <= 1 || first == middle || middle == last) return;
160 
161  if (d <= policy.cutoff_count) {
162  // TODO: consider policy.checkout_count
163  ITYR_CHECK(policy.cutoff_count == policy.checkout_count);
164 
165  auto&& [css, its] = checkout_global_iterators(d, first);
166  auto&& first_ = std::get<0>(its);
167  std::inplace_merge(first_,
168  std::next(first_, std::distance(first, middle)),
169  std::next(first_, d),
170  comp);
171  return;
172  }
173 
174  auto comp_mids = [&]{
175  auto&& [css, its] = checkout_global_iterators(2, std::prev(middle));
176  auto mids = std::get<0>(its);
177  return !comp(*std::next(mids), *mids);
178  };
179  if (comp_mids()) {
180  // middle
181  // ... a || b ... where !(a > b) <=> a <= b
182  return;
183  }
184 
185  auto comp_ends = [&]{
186  auto&& [css, its] = checkout_global_iterators(1, std::prev(last), first);
187  auto [l, f] = its;
188  return comp(*l, *f);
189  };
190  if (comp_ends()) {
191  // middle
192  // a ... || ... b where b < a
193  // (If b == a, we shall not rotate them for stability)
194  rotate(policy, first, middle, last);
195  return;
196  }
197 
198  auto [s1, s2] = find_split_points_for_merge(first, middle, middle, last, comp);
199 
200  if constexpr (Stable) {
201  if (s1 != middle && s2 != middle) {
202  // When equal values are swapped (rotated) across the middle point,
203  // the stability will be lost.
204  // In particular, we want to avoid the following situation:
205  // s1 middle s2
206  // ... a | x ... || ... x | b ...
207  auto&& [css, its] = checkout_global_iterators(1, s1, std::prev(s2));
208  auto [it1r, it2l] = its;
209  if (!comp(*it1r, *it2l) && !comp(*it2l, *it1r)) { // equal
210  // TODO: more efficient impl for cases where the number of equal values is small
211  using value_type = typename std::iterator_traits<RandomAccessIterator>::value_type;
212  if (s1 == first) {
213  // s1 -------> s1 middle
214  // ... x x | x x x x x | a ... || ...
215  s1 = std::partition_point(s1, middle,
216  [&, it = it1r](const value_type& r) { return !comp(*it, r); });
217  } else {
218  // Move s2 so that equal elements are never swapped
219  // middle s2 <------- s2
220  // ... || ... b | x x x x x | x x ...
221  s2 = std::partition_point(middle, s2,
222  [&, it = it2l](const value_type& r) { return comp(r, *it); });
223  }
224  }
225  }
226  }
227 
228  auto m = rotate(policy, s1, middle, s2);
229 
230  ITYR_CHECK(first < m);
231  ITYR_CHECK(m < last);
232 
234  [=, s1 = s1] { inplace_merge_aux<Stable>(policy, first, s1, m, comp); },
235  [=, s2 = s2] { inplace_merge_aux<Stable>(policy, m, s2, last, comp); });
236 }
237 
238 }
239 
272 template <typename ExecutionPolicy, typename RandomAccessIterator, typename Compare>
273 inline void inplace_merge(const ExecutionPolicy& policy,
274  RandomAccessIterator first,
275  RandomAccessIterator middle,
276  RandomAccessIterator last,
277  Compare comp) {
278  if constexpr (ori::is_global_ptr_v<RandomAccessIterator>) {
280  policy,
281  internal::convert_to_global_iterator(first , checkout_mode::read_write),
282  internal::convert_to_global_iterator(middle, checkout_mode::read_write),
283  internal::convert_to_global_iterator(last , checkout_mode::read_write),
284  comp);
285 
286  } else {
287  internal::inplace_merge_aux<true>(policy, first, middle, last, comp);
288  }
289 }
290 
/**
 * @brief Merge two consecutive ranges into one range in place, comparing with `std::less<>`.
 *
 * Forwards to the comparator-taking overload of `inplace_merge()`.
 *
 * @param policy Execution policy.
 * @param first  Beginning of the first range.
 * @param middle End of the first range and beginning of the second one.
 * @param last   End of the second range.
 */
template <typename ExecutionPolicy, typename RandomAccessIterator>
inline void inplace_merge(const ExecutionPolicy& policy,
                          RandomAccessIterator   first,
                          RandomAccessIterator   middle,
                          RandomAccessIterator   last) {
  // Default ordering: transparent less-than
  constexpr std::less<> ascending{};
  inplace_merge(policy, first, middle, last, ascending);
}
322 
323 ITYR_TEST_CASE("[ityr::pattern::parallel_merge] inplace_merge") {
324  ito::init();
325  ori::init();
326 
327  ITYR_SUBCASE("integer") {
328  long n = 100000;
329  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
330 
331  ito::root_exec([=] {
332  long m = n / 3;
333 
334  transform(
335  execution::parallel_policy(100),
336  count_iterator<long>(0), count_iterator<long>(m), p,
337  [=](long i) { return i; });
338 
339  transform(
340  execution::parallel_policy(100),
341  count_iterator<long>(0), count_iterator<long>(n - m), p + m,
342  [=](long i) { return i; });
343 
344  ITYR_CHECK(is_sorted(execution::parallel_policy(100), p , p + m) == true);
345  ITYR_CHECK(is_sorted(execution::parallel_policy(100), p + m, p + n) == true);
346  ITYR_CHECK(is_sorted(execution::parallel_policy(100), p , p + n) == false);
347 
348  inplace_merge(execution::parallel_policy(100), p, p + m, p + n);
349 
350  ITYR_CHECK(is_sorted(execution::parallel_policy(100), p, p + n) == true);
351  });
352 
353  ori::free_coll(p);
354  }
355 
356  ITYR_SUBCASE("stability test") {
357  // std::pair is not trivially copyable
358  struct item {
359  long key;
360  long val;
361  };
362 
363  long n = 100000;
364  long nb = 1738;
365  ori::global_ptr<item> p = ori::malloc_coll<item>(n);
366 
367  ito::root_exec([=] {
368  long m = n / 2;
369 
370  transform(
371  execution::parallel_policy(100),
372  count_iterator<long>(0), count_iterator<long>(m), p,
373  [=](long i) { return item{i / nb, i}; });
374 
375  transform(
376  execution::parallel_policy(100),
377  count_iterator<long>(0), count_iterator<long>(n - m), p + m,
378  [=](long i) { return item{i / nb, i + m}; });
379 
380  auto comp_key = [](const auto& a, const auto& b) { return a.key < b.key ; };
381  auto comp_val = [](const auto& a, const auto& b) { return a.val < b.val; };
382 
383  ITYR_CHECK(is_sorted(execution::parallel_policy(100),
384  p, p + n, comp_val) == true);
385 
386  inplace_merge(execution::parallel_policy(100),
387  p, p + m, p + n, comp_key);
388 
389  ITYR_CHECK(is_sorted(execution::parallel_policy(100),
390  p, p + n, comp_key) == true);
391 
392  for (long key = 0; key < m / nb; key++) {
393  bool sorted = is_sorted(execution::parallel_policy(100),
394  p + key * nb * 2,
395  p + std::min((key + 1) * nb * 2, n),
396  [=](const auto& a, const auto& b) {
397  ITYR_CHECK(a.key == key);
398  ITYR_CHECK(b.key == key);
399  return a.val < b.val;
400  });
401  ITYR_CHECK(sorted);
402  }
403  });
404 
405  ori::free_coll(p);
406  }
407 
408  ITYR_SUBCASE("corner cases") {
409  long n = 1802;
410  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
411 
412  ito::root_exec([=] {
413  auto p_ = ori::checkout(p, n, ori::mode::write);
414  for (int i = 0; i < 4; i++) {
415  p_[i] = 21;
416  }
417  for (int i = 4; i < 187; i++) {
418  p_[i] = 22;
419  }
420  for (int i = 187; i < 1635; i++) {
421  p_[i] = 23;
422  }
423  for (int i = 1635; i < 1802; i++) {
424  p_[i] = 22;
425  }
427 
428  inplace_merge(execution::parallel_policy(100), p, p + 1635, p + n);
429 
430  ITYR_CHECK(is_sorted(execution::parallel_policy(100), p, p + n) == true);
431  });
432 
433  ori::free_coll(p);
434  }
435 
436  ori::fini();
437  ito::fini();
438 }
439 
440 }
#define ITYR_SUBCASE(name)
Definition: util.hpp:41
#define ITYR_CHECK(cond)
Definition: util.hpp:48
constexpr read_write_t read_write
Read+Write checkout mode.
Definition: checkout_span.hpp:39
void fini()
Definition: ito.hpp:45
auto root_exec(Fn &&fn, Args &&... args)
Definition: ito.hpp:50
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
constexpr write_t write
Definition: util.hpp:13
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
void checkin(T *raw_ptr, std::size_t count, mode::read_t)
Definition: ori.hpp:168
void free_coll(global_ptr< T > ptr)
Definition: ori.hpp:70
auto checkout(global_ptr< T > ptr, std::size_t count, Mode mode)
Definition: ori.hpp:148
monoid< T, min_functor<>, highest< T > > min
Definition: reducer.hpp:101
Definition: allocator.hpp:16
BidirectionalIterator rotate(const ExecutionPolicy &policy, BidirectionalIterator first, BidirectionalIterator middle, BidirectionalIterator last)
Rotate a range.
Definition: parallel_loop.hpp:1135
auto parallel_invoke(Args &&... args)
Fork parallel tasks and join them.
Definition: parallel_invoke.hpp:238
ForwardIteratorD transform(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d, UnaryOp unary_op)
Transform elements in a given range and store them in another range.
Definition: parallel_loop.hpp:583
void inplace_merge(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator middle, RandomAccessIterator last, Compare comp)
Merge two sorted ranges into one sorted range in place.
Definition: parallel_merge.hpp:273
bool is_sorted(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Compare comp)
Check if a range is sorted.
Definition: parallel_reduce.hpp:1054
void inplace_merge(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator middle, RandomAccessIterator last)
Merge two sorted ranges into one sorted range in place.
Definition: parallel_merge.hpp:316