Itoyori  v0.0.1
parallel_sort.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
7 
8 namespace ityr {
9 
10 namespace internal {
11 
12 template <bool Stable, typename W, typename RandomAccessIterator, typename Compare>
13 inline void merge_sort(const execution::parallel_policy<W>& policy,
14  RandomAccessIterator first,
15  RandomAccessIterator last,
16  Compare comp) {
17  std::size_t d = std::distance(first, last);
18 
19  if (d <= 1) return;
20 
21  if (d <= policy.cutoff_count) {
22  auto [css, its] = checkout_global_iterators(d, first);
23  auto first_ = std::get<0>(its);
24  std::stable_sort(first_, std::next(first_, d), comp);
25 
26  } else {
27  auto middle = std::next(first, d / 2);
28 
30  [=] { merge_sort<Stable>(policy, first, middle, comp); },
31  [=] { merge_sort<Stable>(policy, middle, last, comp); });
32 
33  internal::inplace_merge_aux<Stable>(policy, first, middle, last, comp);
34  }
35 }
36 
37 }
38 
70 template <typename ExecutionPolicy, typename RandomAccessIterator, typename Compare>
71 inline void stable_sort(const ExecutionPolicy& policy,
72  RandomAccessIterator first,
73  RandomAccessIterator last,
74  Compare comp) {
75  if constexpr (ori::is_global_ptr_v<RandomAccessIterator>) {
77  policy,
78  internal::convert_to_global_iterator(first, checkout_mode::read_write),
79  internal::convert_to_global_iterator(last , checkout_mode::read_write),
80  comp);
81 
82  } else {
83  internal::merge_sort<true>(policy, first, last, comp);
84  }
85 }
86 
/**
 * @brief Stable sort for a range using the default ordering.
 *
 * Delegates to the comparator-taking overload with a transparent
 * `std::less<>`, i.e. elements are compared with `operator<`.
 *
 * @param policy Execution policy forwarded to the comparator overload.
 * @param first  Beginning of the range to sort.
 * @param last   End of the range to sort.
 */
template <typename ExecutionPolicy, typename RandomAccessIterator>
inline void stable_sort(const ExecutionPolicy& policy,
                        RandomAccessIterator   first,
                        RandomAccessIterator   last) {
  const std::less<> ascending{};
  stable_sort(policy, first, last, ascending);
}
109 
110 ITYR_TEST_CASE("[ityr::pattern::parallel_sort] stable_sort") {
111  ito::init();
112  ori::init();
113 
114  ITYR_SUBCASE("stability test") {
115  // std::pair is not trivially copyable
116  struct item {
117  long key;
118  long val;
119  };
120 
121  long n = 100000;
122  long n_keys = 100;
123  ori::global_ptr<item> p = ori::malloc_coll<item>(n);
124 
125  ito::root_exec([=] {
126  transform(
127  execution::parallel_policy(100),
128  count_iterator<long>(0), count_iterator<long>(n), p,
129  [=](long i) { return item{i % n_keys, (3 * i + 5) % 13}; });
130 
131  stable_sort(execution::parallel_policy(100),
132  p, p + n, [](const auto& a, const auto& b) { return a.val < b.val; });
133 
134  stable_sort(execution::parallel_policy(100),
135  p, p + n, [](const auto& a, const auto& b) { return a.key < b.key; });
136 
137  long n_values_per_key = n / n_keys;
138  for (long key = 0; key < n_keys; key++) {
139  bool sorted = is_sorted(execution::parallel_policy(100),
140  p + key * n_values_per_key,
141  p + (key + 1) * n_values_per_key,
142  [=](const auto& a, const auto& b) {
143  ITYR_CHECK(a.key == key);
144  ITYR_CHECK(b.key == key);
145  return a.val < b.val;
146  });
147  ITYR_CHECK(sorted);
148  }
149  });
150 
151  ori::free_coll(p);
152  }
153 
154  ITYR_SUBCASE("corner cases") {
155  long n = 100000;
156  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
157 
158  ito::root_exec([=] {
159  transform(
160  execution::parallel_policy(100),
161  count_iterator<long>(0), count_iterator<long>(n), p,
162  [=](long i) { return i; });
163 
164  stable_sort(execution::parallel_policy(100),
165  p, p + n, [](const auto&, const auto&) { return false; /* all equal */ });
166 
167  for_each(
168  execution::parallel_policy(100),
169  count_iterator<long>(0), count_iterator<long>(n),
171  [=](long i, long v) { ITYR_CHECK(i == v); });
172  });
173 
174  ori::free_coll(p);
175  }
176 
177  ori::fini();
178  ito::fini();
179 }
180 
209 template <typename ExecutionPolicy, typename RandomAccessIterator, typename Compare>
210 inline void sort(const ExecutionPolicy& policy,
211  RandomAccessIterator first,
212  RandomAccessIterator last,
213  Compare comp) {
214  if constexpr (ori::is_global_ptr_v<RandomAccessIterator>) {
215  sort(
216  policy,
217  internal::convert_to_global_iterator(first, checkout_mode::read_write),
218  internal::convert_to_global_iterator(last , checkout_mode::read_write),
219  comp);
220 
221  } else {
222  internal::merge_sort<false>(policy, first, last, comp);
223  }
224 }
225 
/**
 * @brief Sort a range using the default ordering.
 *
 * Delegates to the comparator-taking overload with a transparent
 * `std::less<>`, i.e. elements are compared with `operator<`.
 *
 * @param policy Execution policy forwarded to the comparator overload.
 * @param first  Beginning of the range to sort.
 * @param last   End of the range to sort.
 */
template <typename ExecutionPolicy, typename RandomAccessIterator>
inline void sort(const ExecutionPolicy& policy,
                 RandomAccessIterator   first,
                 RandomAccessIterator   last) {
  const std::less<> ascending{};
  sort(policy, first, last, ascending);
}
254 
255 ITYR_TEST_CASE("[ityr::pattern::parallel_sort] sort") {
256  ito::init();
257  ori::init();
258 
259  long n = 100000;
260  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
261 
262  ito::root_exec([=] {
263  transform(
264  execution::parallel_policy(100),
265  count_iterator<long>(0), count_iterator<long>(n), p,
266  [=](long i) { return (3 * i + 5) % 13; });
267 
268  ITYR_CHECK(is_sorted(execution::parallel_policy(100),
269  p, p + n) == false);
270 
271  sort(execution::parallel_policy(100), p, p + n);
272 
273  ITYR_CHECK(is_sorted(execution::parallel_policy(100),
274  p, p + n) == true);
275  });
276 
277  ori::free_coll(p);
278 
279  ori::fini();
280  ito::fini();
281 }
282 
283 }
#define ITYR_SUBCASE(name)
Definition: util.hpp:41
#define ITYR_CHECK(cond)
Definition: util.hpp:48
constexpr read_write_t read_write
Read+Write checkout mode.
Definition: checkout_span.hpp:39
constexpr read_t read
Read-only checkout mode.
Definition: checkout_span.hpp:19
void fini()
Definition: ito.hpp:45
auto root_exec(Fn &&fn, Args &&... args)
Definition: ito.hpp:50
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
void free_coll(global_ptr< T > ptr)
Definition: ori.hpp:70
Definition: allocator.hpp:16
void stable_sort(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator last, Compare comp)
Stable sort for a range.
Definition: parallel_sort.hpp:71
auto parallel_invoke(Args &&... args)
Fork parallel tasks and join them.
Definition: parallel_invoke.hpp:238
ForwardIteratorD transform(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d, UnaryOp unary_op)
Transform elements in a given range and store them in another range.
Definition: parallel_loop.hpp:583
void for_each(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Op op)
Apply an operator to each element in a range.
Definition: parallel_loop.hpp:136
void sort(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator last, Compare comp)
Sort a range.
Definition: parallel_sort.hpp:210
void stable_sort(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator last)
Stable sort for a range.
Definition: parallel_sort.hpp:104
global_iterator< T, Mode > make_global_iterator(ori::global_ptr< T > gptr, Mode)
Make a global iterator to enable/disable automatic checkout.
Definition: global_iterator.hpp:158
bool is_sorted(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Compare comp)
Check if a range is sorted.
Definition: parallel_reduce.hpp:1054