// Excerpt from ityr/ito/wsqueue.hpp: a work-stealing queue whose state is
// exposed through MPI RMA windows so that remote processes can steal from
// (and optionally pass work to) it. Source lines elided in this listing are
// marked with "// ...".

class wsqueue_full_exception : public std::exception {
public:
  const char* what() const noexcept override { return "Work stealing queue is full."; }
};
template <typename Entry, bool EnablePass = true>
class wsqueue {
public:
  wsqueue(int n_entries, int n_queues = 1)
    : n_entries_(n_entries),
      n_queues_(n_queues),
      // With EnablePass, both indices start at the middle of the buffer so
      // that pass() can prepend entries below `base`.
      initial_pos_(EnablePass ? n_entries / 2 : 0),
      // The queue_state window holds n_queues_ states plus n_queues_ scratch
      // slots used by for_each_nonempty_queue().
      queue_state_win_(common::topology::mpicomm(), n_queues_ * 2, initial_pos_),
      entries_win_(common::topology::mpicomm(), n_entries_ * n_queues_),
      queue_lock_(n_queues_),
      local_empty_(n_queues_, false) {}
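  // Usage sketch (illustrative, not from this file; assumes MPI and the ityr
  // common layer are already initialized, as in the test cases below):
  //
  //   wsqueue<int> wsq(1024);              // EnablePass defaults to true
  //   wsq.push(42);                        // owner-side push at `top`
  //   std::optional<int> v = wsq.pop();    // owner-side pop, also at `top`
  //   auto s = wsq.steal(target_rank);     // thief-side steal at `base`
  //
  // With EnablePass, the indices start at n_entries / 2, leaving room below
  // `base` for pass() to prepend entries and room above `top` for push().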
  void push(const Entry& entry, int idx = 0) {
    // ...
    queue_state& qs = local_queue_state(idx);
    auto entries = local_entries(idx);

    int t = qs.top.load(std::memory_order_relaxed);

    if (t == n_entries_) {
      // No room above `top`: shift the live entries [base, top) downward,
      // halving the free space below `base`.
      // ... (elided: the queue lock is taken around the move)
      int b = qs.base.load(std::memory_order_relaxed);
      int offset = -(b + 1) / 2;
      move_entries(offset, idx);
      t += offset;
      // ...
    }

    entries[t] = entry;

    // Publish the entry before the incremented `top` becomes visible.
    qs.top.store(t + 1, std::memory_order_release);

    if constexpr (!EnablePass) {
      local_empty_[idx] = false;
    }
  }
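  // Worked example of the full-queue shift above (illustrative numbers): with
  // n_entries_ = 1000, t = 1000, and b = 600, offset = -(600 + 1) / 2 = -300,
  // so the 400 live entries move from [600, 1000) to [300, 700). This leaves
  // room above `top` for further push() calls while keeping free space below
  // `base` for remote pass() calls.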
  template <bool EnsureEmpty = true>
  std::optional<Entry> pop(int idx = 0) {
    // ...
    queue_state& qs = local_queue_state(idx);

    if constexpr (EnablePass) {
      // Rebalance: if little space remains below `base` while plenty remains
      // above `top`, shift the live entries upward so that remote pass()
      // calls (which prepend below `base`) do not run out of room.
      int b = qs.base.load(std::memory_order_relaxed);
      if (b < n_entries_ / 10) {
        int t = qs.top.load(std::memory_order_relaxed);
        if (n_entries_ - t > n_entries_ / 10) {
          // ... (elided: the queue lock is taken around the move)
          int t = qs.top.load(std::memory_order_relaxed); // reloaded under the lock
          int offset = (n_entries_ - t + 1) / 2;
          move_entries(offset, idx);
          // ...
        }
      }
    }

    if constexpr (!EnablePass) {
      if (local_empty_[idx]) {
        return std::nullopt; // fast path: the queue is known to be empty
      }
    }

    // ...
    if constexpr (!EnsureEmpty) {
      // ... (elided: cheaper handling when the caller tolerates a racy result)
    }

    std::optional<Entry> ret;
    auto entries = local_entries(idx);

    // Speculatively claim the top entry (Chase-Lev-style owner-side pop).
    int t = qs.top.load(std::memory_order_relaxed) - 1;
    qs.top.store(t, std::memory_order_relaxed);

    std::atomic_thread_fence(std::memory_order_seq_cst);

    int b = qs.base.load(std::memory_order_relaxed);

    if (b < t) {
      // At least two entries remained; the claimed one is safely ours.
      ret = entries[t];
    } else if (b == t) {
      // Exactly one entry remains and a thief may be racing for it:
      // momentarily restore `top`, then resolve the race under the lock.
      qs.top.store(t + 1, std::memory_order_relaxed);
      // ... (elided: lock acquisition)
      qs.top.store(t, std::memory_order_relaxed);
      int b = qs.base.load(std::memory_order_relaxed);
      // ... (elided: claim entries[t] only if the thief did not take it)
      // Either way the queue is now empty; reset both indices.
      qs.top.store(initial_pos_, std::memory_order_relaxed);
      qs.base.store(initial_pos_, std::memory_order_relaxed);
      // ...
      if constexpr (!EnablePass) {
        local_empty_[idx] = true;
      }
      // ... (elided: unlock)
    } else {
      // The queue was already empty; reset both indices.
      qs.top.store(initial_pos_, std::memory_order_relaxed);
      qs.base.store(initial_pos_, std::memory_order_relaxed);
      // ...
      if constexpr (!EnablePass) {
        local_empty_[idx] = true;
      }
    }

    return ret;
  }
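  // Ordering note (a summary, not text from this file): the release store in
  // push() orders the entry write before the new `top` value in the owner's
  // memory, so a thief whose MPI get observes the incremented `top` can also
  // fetch the entry written just before it. The seq_cst fence in pop() plays
  // the usual Chase-Lev role of ordering the speculative decrement of `top`
  // against concurrent increments of `base`, while the lock-based slow path
  // arbitrates the race for the final entry.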
  std::optional<Entry> steal_nolock(common::topology::rank_t target_rank, int idx = 0) {
    // ...
    std::optional<Entry> ret;

    // Reserve the bottom entry by atomically incrementing the victim's `base`.
    int b = common::mpi_atomic_faa_value<int>(1, target_rank, queue_state_base_disp(idx),
                                              queue_state_win_.win());
    int t = common::mpi_get_value<int>(target_rank, queue_state_top_disp(idx),
                                       queue_state_win_.win());

    if (b < t) {
      ret = common::mpi_get_value<Entry>(target_rank, entries_disp(b, idx),
                                         entries_win_.win());
    } else {
      // The queue had no entry for us; roll back the reservation.
      common::mpi_atomic_faa_value<int>(-1, target_rank, queue_state_base_disp(idx),
                                        queue_state_win_.win());
    }

    return ret;
  }
  std::optional<Entry> steal(common::topology::rank_t target_rank, int idx = 0) {
    // ...
    queue_lock_.lock(target_rank, idx);
    auto ret = steal_nolock(target_rank, idx);
    queue_lock_.unlock(target_rank, idx);
    return ret;
  }
  void abort_steal(common::topology::rank_t target_rank, int idx = 0) {
    // ...
    common::mpi_atomic_faa_value<int>(-1, target_rank, queue_state_base_disp(idx),
                                      queue_state_win_.win());
    // ...
  }
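  // Protocol note (a summary): steal_nolock() optimistically reserves the
  // bottom entry with a +1 fetch-and-add on the victim's `base` and reads
  // `top` afterwards; if the reservation overshoots (b >= t), it rolls back
  // with a -1. steal() wraps the exchange in the per-queue global lock,
  // which is what pop()'s locked slow path synchronizes with for the last
  // entry, and abort_steal() issues the same -1 rollback for a reservation
  // the caller decides not to keep.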
  bool trypass(const Entry& entry, common::topology::rank_t target_rank, int idx = 0) {
    if constexpr (!EnablePass) {
      common::die("Pass operation is not allowed");
    }
    // ...
    queue_lock_.lock(target_rank, idx);

    int b = common::mpi_get_value<int>(target_rank, queue_state_base_disp(idx),
                                       queue_state_win_.win());

    if (b <= 0) { // reconstructed check: no free slot below `base`
      queue_lock_.unlock(target_rank, idx);
      return false;
    }

    // Write the entry into the slot just below `base`, then publish it.
    common::mpi_put_value<Entry>(entry, target_rank, entries_disp(b - 1, idx),
                                 entries_win_.win());
    common::mpi_put_value<int>(b - 1, target_rank, queue_state_base_disp(idx),
                               queue_state_win_.win());

    queue_lock_.unlock(target_rank, idx);
    return true;
  }
  void pass(const Entry& entry, common::topology::rank_t target_rank, int idx = 0) {
    // ...
    while (!trypass(entry, target_rank, idx)); // spin until a slot opens up
  }
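  // Usage sketch for pass() (illustrative; `owner_rank` and `task_entry` are
  // placeholders, not names from this file): a scheduler can hand work to
  // another rank's queue instead of its own, e.g. when a migrated task
  // should continue on its home process:
  //
  //   if (common::topology::my_rank() != owner_rank) {
  //     wsq.pass(task_entry, owner_rank); // spins in trypass() until room opens
  //   } else {
  //     wsq.push(task_entry);
  //   }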
  template <typename Fn, bool EnsureEmpty = true>
  void for_each_entry(Fn fn, int idx = 0) {
    if constexpr (!EnablePass) {
      if (local_empty_[idx]) {
        return;
      }
    }
    // ...
    queue_state& qs = local_queue_state(idx);
    if constexpr (!EnsureEmpty) {
      // ... (elided: handling when exact emptiness need not be guaranteed)
    }
    // ... (elided: the queue lock is held while iterating)
    auto entries = local_entries(idx);

    int t = qs.top.load(std::memory_order_relaxed);
    int b = qs.base.load(std::memory_order_relaxed);
    for (int i = b; i < t; i++) {
      fn(entries[i]);
    }
    // ...
    if constexpr (!EnablePass) {
      // ... (elided: only when the queue was observed to be empty)
      local_empty_[idx] = true;
    }
    // ...
  }
  int size(int idx = 0) const {
    return local_queue_state(idx).size();
  }
  bool empty(common::topology::rank_t target_rank, int idx = 0) const {
    // ...
    auto remote_qs = common::mpi_get_value<queue_state>(target_rank, queue_state_disp(idx),
                                                        queue_state_win_.win());
    return remote_qs.empty();
  }
  template <typename Fn>
  void for_each_nonempty_queue(common::topology::rank_t target_rank,
                               int idx_begin, int idx_end, bool reverse, Fn fn) {
    // Fetch the queue states for [idx_begin, idx_end) with a single MPI get
    // into the scratch half of the queue_state window.
    common::mpi_get(&local_queue_state_buf(idx_begin), idx_end - idx_begin,
                    target_rank, queue_state_disp(idx_begin), queue_state_win_.win());
    // ...
    if (reverse) {
      for (int idx = idx_end - 1; idx >= idx_begin; idx--) {
        if (!local_queue_state_buf(idx).empty()) {
          fn(idx);
        }
      }
    } else {
      for (int idx = idx_begin; idx < idx_end; idx++) {
        if (!local_queue_state_buf(idx).empty()) {
          fn(idx);
        }
      }
    }
  }
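  // Example (illustrative): a thief can visit the victim's queues from the
  // highest index downward and steal from the first non-empty one:
  //
  //   std::optional<int> stolen;
  //   wsq.for_each_nonempty_queue(target_rank, 0, wsq.n_queues(), /*reverse=*/true,
  //                               [&](int idx) {
  //                                 if (!stolen) stolen = wsq.steal(target_rank, idx);
  //                               });
  //
  // Because all states in the range are fetched with one MPI get, the
  // emptiness snapshot may be slightly stale, so callers must still handle
  // a failed steal.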
private:
  struct queue_state {
    std::atomic<int> top;
    std::atomic<int> base;
    // ...
    static_assert(sizeof(std::atomic<int>) == sizeof(int));

    queue_state(int initial_pos = 0) : top(initial_pos), base(initial_pos) {}

    // std::atomic is neither copyable nor movable, so copies go through
    // relaxed loads and stores.
    queue_state(const queue_state& qs)
      : top(qs.top.load(std::memory_order_relaxed)),
        base(qs.base.load(std::memory_order_relaxed)) {}

    queue_state& operator=(const queue_state& qs) {
      top.store(qs.top.load(std::memory_order_relaxed), std::memory_order_relaxed);
      base.store(qs.base.load(std::memory_order_relaxed), std::memory_order_relaxed);
      return *this;
    }

    int size() const {
      return std::max(0, top.load(std::memory_order_relaxed) -
                         base.load(std::memory_order_relaxed));
    }

    bool empty() const {
      return top.load(std::memory_order_relaxed) <=
             base.load(std::memory_order_relaxed);
    }
  };

  static_assert(std::is_standard_layout_v<queue_state>);
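  // Layout note: the standard-layout assertion matters because queue_state
  // is accessed through raw MPI RMA displacements. offsetof(queue_state, top)
  // and offsetof(queue_state, base), used by the *_disp helpers below, are
  // only well-defined for standard-layout types, and the assertion that
  // sizeof(std::atomic<int>) == sizeof(int) ensures a remote fetch-and-add
  // on a plain int aliases the atomic correctly.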
  // Each queue_state sits in its own aligned wrapper so that states occupy
  // a fixed stride inside the RMA window and remote displacements can be
  // computed with plain arithmetic.
  static constexpr std::size_t queue_state_align = sizeof(queue_state);

  struct alignas(queue_state_align) queue_state_wrapper {
    template <typename... Args>
    queue_state_wrapper(Args&&... args) : value(std::forward<Args>(args)...) {}
    queue_state value;
  };
  std::size_t queue_state_disp(int idx) const {
    return idx * sizeof(queue_state_wrapper) + offsetof(queue_state_wrapper, value);
  }

  std::size_t queue_state_top_disp(int idx) const {
    return idx * sizeof(queue_state_wrapper) + offsetof(queue_state_wrapper, value)
         + offsetof(queue_state, top);
  }

  std::size_t queue_state_base_disp(int idx) const {
    return idx * sizeof(queue_state_wrapper) + offsetof(queue_state_wrapper, value)
         + offsetof(queue_state, base);
  }

  std::size_t entries_disp(int entry_num, int idx) const {
    return (entry_num + idx * n_entries_) * sizeof(Entry);
  }

  queue_state& local_queue_state(int idx) const {
    return queue_state_win_.local_buf()[idx].value;
  }

  // The second half of the queue_state window (indices n_queues_ and up)
  // serves as a local scratch buffer for for_each_nonempty_queue().
  queue_state& local_queue_state_buf(int idx) const {
    // ...
    return queue_state_win_.local_buf()[n_queues_ + idx].value;
  }

  auto local_entries(int idx) const {
    return entries_win_.local_buf().subspan(idx * n_entries_, n_entries_);
  }
  void move_entries(int offset, int idx) {
    // ... (elided: sanity checks)
    queue_state& qs = local_queue_state(idx);
    auto entries = local_entries(idx);

    int t = qs.top.load(std::memory_order_relaxed);
    int b = qs.base.load(std::memory_order_relaxed);
    // ...
    int new_b = b + offset;
    int new_t = t + offset;

    if (offset == 0 || new_b < 0 || n_entries_ < new_t) {
      // No room to shift in the requested direction: the queue is full.
      throw wsqueue_full_exception{};
    }

    std::move(&entries[b], &entries[t], &entries[new_b]);

    qs.top.store(new_t, std::memory_order_relaxed);
    qs.base.store(new_b, std::memory_order_relaxed);
  }
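  // Failure example (illustrative): if push() fills the buffer while `base`
  // is already 0 (t == n_entries_, b == 0), then offset = -(0 + 1) / 2 == 0
  // and move_entries() throws wsqueue_full_exception, since entries can be
  // shifted in neither direction.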
  int n_entries_;
  int n_queues_;
  int initial_pos_;
  common::mpi_win_manager<queue_state_wrapper> queue_state_win_;
  common::mpi_win_manager<Entry> entries_win_;
  common::global_lock queue_lock_;
  std::vector<bool> local_empty_;
};
ITYR_TEST_CASE("[ityr::ito::wsqueue] single queue") {
  int n_entries = 1000;
  // ... (elided: number of trials and the entry type `entry_t`)
  common::runtime_options common_opts;
  common::singleton_initializer<common::topology::instance> topo;
  wsqueue<entry_t> wsq(n_entries);
  // ... (elided: a local push/pop subcase begins)
  for (int t = 0; t < n_trial; t++) {
    for (int i = 0; i < n_entries; i++) {
      // ... (elided: push an entry)
    }
    for (int i = 0; i < n_entries; i++) {
      auto result = wsq.pop();
      // ... (elided: check the popped value)
    }
  }
  // ...
  for (int i = 0; i < n_entries; i++) {
    // ...
  }
  // ... (elided: a subcase in which one rank fills the queue and the others steal)
  entry_t sum_expected = 0;
  // ...
  for (int i = 0; i < n_entries; i++) {
    // ... (elided: push entries and accumulate sum_expected)
  }
  // ...
  entry_t local_sum = 0;
  // ...
  for (int i = 0; i < n_entries; i++) {
    auto result = wsq.steal(target_rank);
    // ...
    local_sum += *result;
  }
  // ...
  while (!wsq.empty(target_rank)) {
    auto result = wsq.steal(target_rank);
    if (result.has_value()) {
      local_sum += *result;
    }
  }
  // ... (elided: reduce the local sums across ranks and compare with sum_expected)
  ITYR_SUBCASE("local pop and remote steal concurrently") {
    // ... (elided: on the owner rank, pop entries)
    auto result = wsq.pop();
    if (result.has_value()) {
      local_sum += *result;
    }
    // ... (elided: on the thief ranks, steal until the queue drains)
    while (!wsq.empty(target_rank)) {
      auto result = wsq.steal(target_rank);
      if (result.has_value()) {
        local_sum += *result;
      }
    }
    // ...
  }
  // ... (elided: a subcase interleaving repeated push/pop with remote steals)
  entry_t sum_expected = 0;
  entry_t local_sum = 0;
  // ...
  for (int r = 0; r < n_repeats; r++) {
    for (int i = 0; i < n_entries; i++) {
      // ... (elided: push an entry)
    }
    // ...
    auto result = wsq.pop();
    if (result.has_value()) {
      local_sum += *result;
    }
    // ...
  }
  // ... (elided: on the thief ranks)
  entry_t local_sum = 0;
  // ...
  auto result = wsq.steal(target_rank);
  if (result.has_value()) {
    local_sum += *result;
  }
  // ...
  // ... (elided: a subcase stealing half of the pushed entries)
  for (int i = 0; i < n_entries; i++) {
    // ... (elided: push an entry)
  }
  // ...
  for (int i = 0; i < n_entries / 2; i++) {
    auto result = wsq.steal(target_rank);
    // ...
  }
  // ...
  for (int i = 0; i < n_entries / 2; i++) {
    // ...
  }
  for (int i = 0; i < n_entries; i++) {
    auto result = wsq.pop();
    // ...
  }
  // ... (elided: a subcase exercising pass())
  auto result = wsq.steal(target_rank);
  if (result.has_value()) {
    // ...
  }
  // ...
  for (int i = 0; i < n_entries / n_ranks; i++) {
    wsq.pass(i, target_rank);
  }
  // ...
}
ITYR_TEST_CASE("[ityr::ito::wsqueue] multiple queues") {
  int n_entries = 1000;
  // ... (elided: n_queues, n_repeats, and the entry type `entry_t`)
  common::runtime_options common_opts;
  common::singleton_initializer<common::topology::instance> topo;
  wsqueue<entry_t> wsq(n_entries, n_queues);
  // ...
  for (int q = 0; q < n_queues; q++) {
    // ... (elided: per-queue setup)
  }
  // ...
  entry_t sum_expected = 0;
  entry_t local_sum = 0;
  // ...
  for (int r = 0; r < n_repeats; r++) {
    for (int i = 0; i < n_entries; i++) {
      for (int q = 0; q < n_queues; q++) {
        // ... (elided: push to queue q)
      }
    }
    for (int q = 0; q < n_queues; q++) {
      while (!wsq.empty(my_rank, q)) {
        auto result = wsq.pop(q);
        if (result.has_value()) {
          local_sum += *result;
        }
      }
    }
  }
  // ... (elided: on the thief ranks)
  entry_t local_sum = 0;
  // ...
  for (int q = 0; q < n_queues; q++) {
    auto result = wsq.steal(target_rank, q);
    if (result.has_value()) {
      local_sum += *result;
    }
  }
  // ...
  for (int q = 0; q < n_queues; q++) {
    // ...
  }
  // ...
}