Itoyori  v0.0.1
topology.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include <vector>
4 
5 #include "ityr/common/util.hpp"
8 #include "ityr/common/numa.hpp"
9 
11 
12 using rank_t = int;
13 
14 class topology {
15 public:
// Default constructor: delegates to the communicator-taking constructor
// with MPI_COMM_WORLD.
topology() : topology(MPI_COMM_WORLD) {}
// Builds the topology on top of `comm`.
//
// NOTE: the initializers below are order-sensitive. Several of them call
// member functions that read members initialized earlier in this list
// (e.g. create_intra_comm() reads enable_shared_memory_ and cg_global_,
// create_inter_comm() reads cg_intra_, create_process_map() reads all three
// comm groups, and the *2global_rank helpers read process_map_).
topology(MPI_Comm comm)
  : enable_shared_memory_(enable_shared_memory_option::value()),
    // `comm` is passed with own=false, so it is never freed by this class.
    cg_global_(comm, false),
    // The intra/inter communicators are marked as owned only when shared
    // memory is enabled, i.e. only when the create_*_comm() helpers
    // actually created a new communicator.
    cg_intra_(create_intra_comm(), enable_shared_memory_),
    cg_inter_(create_inter_comm(), enable_shared_memory_),
    process_map_(create_process_map()),
    intra2global_rank_(create_intra2global_rank()),
    inter2global_rank_(create_inter2global_rank()),
    numa_enabled_(numa::enabled()),
    // One NUMA node entry per rank of the intra communicator.
    numa_nodes_all_(create_intra_numa_nodes()),
    // Bitmask with one bit set for each distinct node in numa_nodes_all_.
    numa_nodemask_all_(get_numa_bitmask(numa_nodes_all_)) {}

// Copying is disabled.
topology(const topology&) = delete;
topology& operator=(const topology&) = delete;
31 
32  MPI_Comm mpicomm() const { return cg_global_.mpicomm; }
33  rank_t my_rank() const { return cg_global_.my_rank; }
34  rank_t n_ranks() const { return cg_global_.n_ranks; }
35 
36  MPI_Comm intra_mpicomm() const { return cg_intra_.mpicomm; }
37  rank_t intra_my_rank() const { return cg_intra_.my_rank; }
38  rank_t intra_n_ranks() const { return cg_intra_.n_ranks; }
39 
40  MPI_Comm inter_mpicomm() const { return cg_inter_.mpicomm; }
41  rank_t inter_my_rank() const { return cg_inter_.my_rank; }
42  rank_t inter_n_ranks() const { return cg_inter_.n_ranks; }
43 
44  rank_t intra_rank(rank_t global_rank) const {
45  ITYR_CHECK(0 <= global_rank);
46  ITYR_CHECK(global_rank < n_ranks());
47  return process_map_[global_rank].intra_rank;
48  }
49 
50  rank_t inter_rank(rank_t global_rank) const {
51  ITYR_CHECK(0 <= global_rank);
52  ITYR_CHECK(global_rank < n_ranks());
53  return process_map_[global_rank].inter_rank;
54  }
55 
57  ITYR_CHECK(0 <= intra_rank);
59  return intra2global_rank_[intra_rank];
60  }
61 
63  ITYR_CHECK(0 <= inter_rank);
65  return inter2global_rank_[inter_rank];
66  }
67 
68  bool is_locally_accessible(rank_t target_global_rank) const {
69  return inter_rank(target_global_rank) == inter_my_rank();
70  }
71 
72  bool numa_enabled() const { return numa_enabled_; }
73 
75  ITYR_CHECK(0 <= intra_rank);
77  return numa_nodes_all_[intra_rank];
78  }
79 
81  return numa_node(intra_my_rank());
82  }
83 
85  return get_unique_numa_nodes(numa_nodes_all_).size();
86  }
87 
89  return numa_nodemask_all_;
90  }
91 
92 private:
93  struct comm_group {
96  MPI_Comm mpicomm = MPI_COMM_NULL;
97  bool own_comm = false;
98 
99  comm_group(MPI_Comm comm, bool own)
100  : my_rank(mpi_comm_rank(comm)), n_ranks(mpi_comm_size(comm)),
101  mpicomm(comm), own_comm(own) {}
102 
103  ~comm_group() {
104  if (own_comm) {
105  MPI_Comm_free(&mpicomm);
106  }
107  }
108  };
109 
110  struct process_map_entry {
113  };
114 
115  MPI_Comm create_intra_comm() {
116  if (enable_shared_memory_) {
117  MPI_Comm h;
118  MPI_Comm_split_type(mpicomm(), MPI_COMM_TYPE_SHARED, my_rank(), MPI_INFO_NULL, &h);
119  return h;
120  } else {
121  return MPI_COMM_SELF;
122  }
123  }
124 
125  MPI_Comm create_inter_comm() {
126  if (enable_shared_memory_) {
127  MPI_Comm h;
128  MPI_Comm_split(mpicomm(), intra_my_rank(), my_rank(), &h);
129  return h;
130  } else {
131  return mpicomm();
132  }
133  }
134 
135  std::vector<process_map_entry> create_process_map() {
136  process_map_entry my_entry {intra_my_rank(), inter_my_rank()};
137  std::vector<process_map_entry> ret(n_ranks());
138  MPI_Allgather(&my_entry,
139  sizeof(process_map_entry),
140  MPI_BYTE,
141  ret.data(),
142  sizeof(process_map_entry),
143  MPI_BYTE,
144  mpicomm());
145  return ret;
146  }
147 
148  std::vector<rank_t> create_intra2global_rank() {
149  std::vector<rank_t> ret;
150  for (rank_t i = 0; i < n_ranks(); i++) {
151  if (process_map_[i].inter_rank == inter_my_rank()) {
152  ret.push_back(i);
153  }
154  }
155  ITYR_CHECK(ret.size() == std::size_t(intra_n_ranks()));
156  return ret;
157  }
158 
159  std::vector<rank_t> create_inter2global_rank() {
160  std::vector<rank_t> ret;
161  for (rank_t i = 0; i < n_ranks(); i++) {
162  if (process_map_[i].intra_rank == intra_my_rank()) {
163  ITYR_CHECK(process_map_[i].inter_rank == ret.size());
164  ret.push_back(i);
165  }
166  }
167  ITYR_CHECK(ret.size() == std::size_t(inter_n_ranks()));
168  return ret;
169  }
170 
171  std::vector<numa::node_t> create_intra_numa_nodes() const {
172  auto my_node = numa::get_current_node();
173  return mpi_allgather_value(my_node, intra_mpicomm());
174  }
175 
176  std::vector<numa::node_t> get_unique_numa_nodes(std::vector<numa::node_t> nodes) const {
177  std::sort(nodes.begin(), nodes.end());
178  nodes.erase(std::unique(nodes.begin(), nodes.end()), nodes.end());
179  return nodes;
180  }
181 
182  numa::node_bitmask get_numa_bitmask(std::vector<numa::node_t> nodes) const {
183  auto unique_nodes = get_unique_numa_nodes(nodes);
184  numa::node_bitmask nodemask;
185  for (const auto& node : unique_nodes) {
186  nodemask.setbit(node);
187  }
188  return nodemask;
189  }
190 
191  bool enable_shared_memory_;
192  comm_group cg_global_;
193  comm_group cg_intra_;
194  comm_group cg_inter_;
195  std::vector<process_map_entry> process_map_; // global_rank -> (intra, inter rank)
196  std::vector<rank_t> intra2global_rank_;
197  std::vector<rank_t> inter2global_rank_;
198 
199  bool numa_enabled_;
200  std::vector<numa::node_t> numa_nodes_all_;
201  numa::node_bitmask numa_nodemask_all_;
202 };
203 
205 
206 inline MPI_Comm mpicomm() { return instance::get().mpicomm(); }
207 inline rank_t my_rank() { return instance::get().my_rank(); }
208 inline rank_t n_ranks() { return instance::get().n_ranks(); }
209 
210 inline MPI_Comm intra_mpicomm() { return instance::get().intra_mpicomm(); }
211 inline rank_t intra_my_rank() { return instance::get().intra_my_rank(); }
212 inline rank_t intra_n_ranks() { return instance::get().intra_n_ranks(); }
213 
214 inline MPI_Comm inter_mpicomm() { return instance::get().inter_mpicomm(); }
215 inline rank_t inter_my_rank() { return instance::get().inter_my_rank(); }
216 inline rank_t inter_n_ranks() { return instance::get().inter_n_ranks(); }
217 
218 inline rank_t intra_rank(rank_t global_rank) { return instance::get().intra_rank(global_rank); };
219 inline rank_t inter_rank(rank_t global_rank) { return instance::get().inter_rank(global_rank); };
220 
221 inline rank_t intra2global_rank(rank_t intra_rank) { return instance::get().intra2global_rank(intra_rank); }
222 inline rank_t inter2global_rank(rank_t inter_rank) { return instance::get().inter2global_rank(inter_rank); }
223 
224 inline bool is_locally_accessible(rank_t target_global_rank) { return instance::get().is_locally_accessible(target_global_rank); };
225 
226 inline bool numa_enabled() { return instance::get().numa_enabled(); }
227 inline numa::node_t numa_my_node() { return instance::get().numa_my_node(); }
228 inline numa::node_t numa_n_nodes() { return instance::get().numa_n_nodes(); }
230 inline const numa::node_bitmask& numa_nodemask_all() { return instance::get().numa_nodemask_all(); }
231 
232 }
Definition: numa.hpp:78
Definition: util.hpp:176
static auto & get()
Definition: util.hpp:180
Definition: topology.hpp:14
rank_t inter_my_rank() const
Definition: topology.hpp:41
numa::node_t numa_my_node() const
Definition: topology.hpp:80
topology(const topology &)=delete
rank_t n_ranks() const
Definition: topology.hpp:34
topology & operator=(const topology &)=delete
numa::node_t numa_n_nodes() const
Definition: topology.hpp:84
rank_t intra_rank(rank_t global_rank) const
Definition: topology.hpp:44
topology()
Definition: topology.hpp:16
bool numa_enabled() const
Definition: topology.hpp:72
rank_t intra2global_rank(rank_t intra_rank) const
Definition: topology.hpp:56
topology(MPI_Comm comm)
Definition: topology.hpp:17
rank_t inter_rank(rank_t global_rank) const
Definition: topology.hpp:50
rank_t inter2global_rank(rank_t inter_rank) const
Definition: topology.hpp:62
rank_t intra_n_ranks() const
Definition: topology.hpp:38
MPI_Comm mpicomm() const
Definition: topology.hpp:32
rank_t inter_n_ranks() const
Definition: topology.hpp:42
MPI_Comm inter_mpicomm() const
Definition: topology.hpp:40
rank_t my_rank() const
Definition: topology.hpp:33
MPI_Comm intra_mpicomm() const
Definition: topology.hpp:36
numa::node_t numa_node(rank_t intra_rank) const
Definition: topology.hpp:74
const numa::node_bitmask & numa_nodemask_all() const
Definition: topology.hpp:88
bool is_locally_accessible(rank_t target_global_rank) const
Definition: topology.hpp:68
rank_t intra_my_rank() const
Definition: topology.hpp:37
#define ITYR_CHECK(cond)
Definition: util.hpp:48
int node_t
Definition: numa.hpp:76
bool enabled()
Definition: numa.hpp:86
node_t get_current_node()
Definition: numa.hpp:87
Definition: topology.hpp:10
numa::node_t numa_node(rank_t intra_rank)
Definition: topology.hpp:229
rank_t inter_my_rank()
Definition: topology.hpp:215
bool numa_enabled()
Definition: topology.hpp:226
rank_t n_ranks()
Definition: topology.hpp:208
int rank_t
Definition: topology.hpp:12
rank_t inter_rank(rank_t global_rank)
Definition: topology.hpp:219
MPI_Comm mpicomm()
Definition: topology.hpp:206
rank_t inter_n_ranks()
Definition: topology.hpp:216
rank_t intra_my_rank()
Definition: topology.hpp:211
const numa::node_bitmask & numa_nodemask_all()
Definition: topology.hpp:230
bool is_locally_accessible(rank_t target_global_rank)
Definition: topology.hpp:224
rank_t inter2global_rank(rank_t inter_rank)
Definition: topology.hpp:222
MPI_Comm inter_mpicomm()
Definition: topology.hpp:214
MPI_Comm intra_mpicomm()
Definition: topology.hpp:210
rank_t intra_n_ranks()
Definition: topology.hpp:212
numa::node_t numa_n_nodes()
Definition: topology.hpp:228
rank_t intra_rank(rank_t global_rank)
Definition: topology.hpp:218
rank_t my_rank()
Definition: topology.hpp:207
numa::node_t numa_my_node()
Definition: topology.hpp:227
rank_t intra2global_rank(rank_t intra_rank)
Definition: topology.hpp:221
std::vector< T > mpi_allgather_value(const T &value, MPI_Comm comm)
Definition: mpi_util.hpp:218
int mpi_comm_rank(MPI_Comm comm)
Definition: mpi_util.hpp:28
int mpi_comm_size(MPI_Comm comm)
Definition: mpi_util.hpp:35
void sort(const ExecutionPolicy &policy, RandomAccessIterator first, RandomAccessIterator last, Compare comp)
Sort a range.
Definition: parallel_sort.hpp:210