Itoyori  v0.0.1
parallel_filter.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include "ityr/common/util.hpp"
6 
7 namespace ityr {
8 
9 namespace internal {
10 
11 template <typename W, typename BidirectionalIterator, typename Predicate>
12 inline BidirectionalIterator
13 stable_partition_aux(const execution::parallel_policy<W>& policy,
14  BidirectionalIterator first,
15  BidirectionalIterator last,
16  Predicate pred) {
17  std::size_t d = std::distance(first, last);
18 
19  if (d <= policy.cutoff_count) {
20  // TODO: consider policy.checkout_count
21  ITYR_CHECK(policy.cutoff_count == policy.checkout_count);
22 
23  auto&& [css, its] = checkout_global_iterators(d, first);
24  auto&& first_ = std::get<0>(its);
25  auto m = std::stable_partition(first_, std::next(first_, d), pred);
26  return std::next(first, std::distance(first_, m));
27  }
28 
29  auto mid = std::next(first, d / 2);
30 
31  auto [m1, m2] = parallel_invoke(
32  [=] { return stable_partition_aux(policy, first, mid , pred); },
33  [=] { return stable_partition_aux(policy, mid , last, pred); });
34 
35  return rotate(policy, m1, mid, m2);
36 }
37 
38 }
39 
75 template <typename ExecutionPolicy, typename BidirectionalIterator, typename Predicate>
76 inline BidirectionalIterator stable_partition(const ExecutionPolicy& policy,
77  BidirectionalIterator first,
78  BidirectionalIterator last,
79  Predicate pred) {
80  if constexpr (ori::is_global_ptr_v<BidirectionalIterator>) {
81  return stable_partition(
82  policy,
83  internal::convert_to_global_iterator(first, checkout_mode::read_write),
84  internal::convert_to_global_iterator(last , checkout_mode::read_write),
85  pred);
86 
87  } else {
88  return internal::stable_partition_aux(policy, first, last, pred);
89  }
90 }
91 
92 ITYR_TEST_CASE("[ityr::pattern::parallel_filter] stable_partition") {
93  ito::init();
94  ori::init();
95 
96  ITYR_SUBCASE("split half") {
97  long n = 100000;
98  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
99 
100  ito::root_exec([=] {
101  transform(
102  execution::parallel_policy(100),
103  count_iterator<long>(0), count_iterator<long>(n), p,
104  [=](long i) { return i; });
105 
106  auto pp = stable_partition(
107  execution::parallel_policy(100),
108  p, p + n,
109  [](long x) { return x % 2 == 0; });
110 
111  ITYR_CHECK(pp == p + n / 2);
112 
113  for_each(
114  execution::parallel_policy(100),
117  count_iterator<long>(0),
118  [](long x, long i) { ITYR_CHECK(x == i * 2); });
119 
120  for_each(
121  execution::parallel_policy(100),
124  count_iterator<long>(0),
125  [](long x, long i) { ITYR_CHECK(x == i * 2 + 1); });
126  });
127 
128  ori::free_coll(p);
129  }
130 
131  ITYR_SUBCASE("split 1:2") {
132  long n = 90000;
133  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
134 
135  ito::root_exec([=] {
136  transform(
137  execution::parallel_policy(100),
138  count_iterator<long>(0), count_iterator<long>(n), p,
139  [=](long i) { return i; });
140 
141  auto pp = stable_partition(
142  execution::parallel_policy(100),
143  p, p + n,
144  [](long x) { return x % 3 == 0; });
145 
146  ITYR_CHECK(pp == p + n / 3);
147 
148  for_each(
149  execution::parallel_policy(100),
152  count_iterator<long>(0),
153  [](long x, long i) { ITYR_CHECK(x == i * 3); });
154 
155  for_each(
156  execution::parallel_policy(100),
159  count_iterator<long>(0),
160  [](long x, long i) { ITYR_CHECK(x == (i / 2) * 3 + (i % 2) + 1); });
161  });
162 
163  ori::free_coll(p);
164  }
165 
166  ITYR_SUBCASE("corner cases") {
167  long n = 100000;
168  ori::global_ptr<long> p = ori::malloc_coll<long>(n);
169 
170  ito::root_exec([=] {
171  transform(
172  execution::parallel_policy(100),
173  count_iterator<long>(0), count_iterator<long>(n), p,
174  [=](long i) { return i; });
175 
176  auto pp1 = stable_partition(
177  execution::parallel_policy(100),
178  p, p + n,
179  [](long) { return true; });
180 
181  ITYR_CHECK(pp1 == p + n);
182 
183  for_each(
184  execution::parallel_policy(100),
187  count_iterator<long>(0),
188  [](long x, long i) { ITYR_CHECK(x == i); });
189 
190  auto pp2 = stable_partition(
191  execution::parallel_policy(100),
192  p, p + n,
193  [](long) { return false; });
194 
195  ITYR_CHECK(pp2 == p);
196 
197  for_each(
198  execution::parallel_policy(100),
201  count_iterator<long>(0),
202  [](long x, long i) { ITYR_CHECK(x == i); });
203  });
204 
205  ori::free_coll(p);
206  }
207 
208  ori::fini();
209  ito::fini();
210 }
211 
/**
 * @brief Partition elements into two disjoint parts in place.
 *
 * Currently a thin wrapper: the call is forwarded unchanged to
 * `stable_partition` and its result is returned.
 *
 * @param policy Execution policy.
 * @param first  Beginning of the range to partition.
 * @param last   End of the range to partition.
 * @param pred   Unary predicate selecting the elements of the first part.
 * @return The partition point reported by `stable_partition`.
 */
template <typename ExecutionPolicy, typename BidirectionalIterator, typename Predicate>
inline BidirectionalIterator partition(const ExecutionPolicy& policy,
                                       BidirectionalIterator  first,
                                       BidirectionalIterator  last,
                                       Predicate              pred) {
  // TODO: implement faster unstable partition
  BidirectionalIterator boundary = stable_partition(policy, first, last, pred);
  return boundary;
}
254 
255 }
#define ITYR_SUBCASE(name)
Definition: util.hpp:41
#define ITYR_CHECK(cond)
Definition: util.hpp:48
constexpr read_write_t read_write
Read+Write checkout mode.
Definition: checkout_span.hpp:39
constexpr read_t read
Read-only checkout mode.
Definition: checkout_span.hpp:19
void fini()
Definition: ito.hpp:45
auto root_exec(Fn &&fn, Args &&... args)
Definition: ito.hpp:50
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ito.hpp:41
void fini()
Definition: ori.hpp:49
void init(MPI_Comm comm=MPI_COMM_WORLD)
Definition: ori.hpp:45
void free_coll(global_ptr< T > ptr)
Definition: ori.hpp:70
Definition: allocator.hpp:16
BidirectionalIterator rotate(const ExecutionPolicy &policy, BidirectionalIterator first, BidirectionalIterator middle, BidirectionalIterator last)
Rotate a range.
Definition: parallel_loop.hpp:1135
auto parallel_invoke(Args &&... args)
Fork parallel tasks and join them.
Definition: parallel_invoke.hpp:238
ForwardIteratorD transform(const ExecutionPolicy &policy, ForwardIterator1 first1, ForwardIterator1 last1, ForwardIteratorD first_d, UnaryOp unary_op)
Transform elements in a given range and store them in another range.
Definition: parallel_loop.hpp:583
BidirectionalIterator stable_partition(const ExecutionPolicy &policy, BidirectionalIterator first, BidirectionalIterator last, Predicate pred)
Partition elements into two disjoint parts in place.
Definition: parallel_filter.hpp:76
void for_each(const ExecutionPolicy &policy, ForwardIterator first, ForwardIterator last, Op op)
Apply an operator to each element in a range.
Definition: parallel_loop.hpp:136
BidirectionalIterator partition(const ExecutionPolicy &policy, BidirectionalIterator first, BidirectionalIterator last, Predicate pred)
Partition elements into two disjoint parts in place.
Definition: parallel_filter.hpp:247
global_iterator< T, Mode > make_global_iterator(ori::global_ptr< T > gptr, Mode)
Make a global iterator to enable/disable automatic checkout.
Definition: global_iterator.hpp:158