Itoyori  v0.0.1
mpi_util.hpp
Go to the documentation of this file.
1 #pragma once
2 
3 #include <mpi.h>
4 
5 #include "ityr/common/util.hpp"
6 
7 #ifndef ITYR_DEBUG_UCX
8 #define ITYR_DEBUG_UCX 0
9 #endif
10 
11 #if ITYR_DEBUG_UCX
12 #include <ucs/debug/log_def.h>
13 #include <sys/time.h>
14 #include <cstring>
15 #include <atomic>
16 #endif
17 
18 namespace ityr::common {
19 
20 template <typename T> inline MPI_Datatype mpi_type();
21 template <> inline MPI_Datatype mpi_type<int>() { return MPI_INT; }
22 template <> inline MPI_Datatype mpi_type<unsigned int>() { return MPI_UNSIGNED; }
23 template <> inline MPI_Datatype mpi_type<long>() { return MPI_LONG; }
24 template <> inline MPI_Datatype mpi_type<unsigned long>() { return MPI_UNSIGNED_LONG; }
25 template <> inline MPI_Datatype mpi_type<bool>() { return MPI_CXX_BOOL; }
26 template <> inline MPI_Datatype mpi_type<void*>() { return mpi_type<uintptr_t>(); }
27 
28 inline int mpi_comm_rank(MPI_Comm comm) {
29  int rank;
30  MPI_Comm_rank(comm, &rank);
31  ITYR_CHECK(rank >= 0);
32  return rank;
33 }
34 
35 inline int mpi_comm_size(MPI_Comm comm) {
36  int size;
37  MPI_Comm_size(comm, &size);
38  ITYR_CHECK(size >= 0);
39  return size;
40 }
41 
42 inline void mpi_barrier(MPI_Comm comm) {
43  MPI_Barrier(comm);
44 }
45 
46 inline MPI_Request mpi_ibarrier(MPI_Comm comm) {
47  MPI_Request req;
48  MPI_Ibarrier(comm, &req);
49  return req;
50 }
51 
52 template <typename T>
53 inline void mpi_send(const T* buf,
54  std::size_t count,
55  int target_rank,
56  int tag,
57  MPI_Comm comm) {
58  MPI_Send(buf,
59  sizeof(T) * count,
60  MPI_BYTE,
61  target_rank,
62  tag,
63  comm);
64 }
65 
66 template <typename T>
67 inline MPI_Request mpi_isend(const T* buf,
68  std::size_t count,
69  int target_rank,
70  int tag,
71  MPI_Comm comm) {
72  MPI_Request req;
73  MPI_Isend(buf,
74  sizeof(T) * count,
75  MPI_BYTE,
76  target_rank,
77  tag,
78  comm,
79  &req);
80  return req;
81 }
82 
83 template <typename T>
84 inline void mpi_send_value(const T& value,
85  int target_rank,
86  int tag,
87  MPI_Comm comm) {
88  mpi_send(&value, 1, target_rank, tag, comm);
89 }
90 
91 template <typename T>
92 inline void mpi_recv(T* buf,
93  std::size_t count,
94  int target_rank,
95  int tag,
96  MPI_Comm comm) {
97  MPI_Recv(buf,
98  sizeof(T) * count,
99  MPI_BYTE,
100  target_rank,
101  tag,
102  comm,
103  MPI_STATUS_IGNORE);
104 }
105 
106 template <typename T>
107 inline MPI_Request mpi_irecv(T* buf,
108  std::size_t count,
109  int target_rank,
110  int tag,
111  MPI_Comm comm) {
112  MPI_Request req;
113  MPI_Irecv(buf,
114  sizeof(T) * count,
115  MPI_BYTE,
116  target_rank,
117  tag,
118  comm,
119  &req);
120  return req;
121 }
122 
123 template <typename T>
124 inline T mpi_recv_value(int target_rank,
125  int tag,
126  MPI_Comm comm) {
127  T result {};
128  mpi_recv(&result, 1, target_rank, tag, comm);
129  return result;
130 }
131 
132 template <typename T>
133 inline void mpi_bcast(T* buf,
134  std::size_t count,
135  int root_rank,
136  MPI_Comm comm) {
137  MPI_Bcast(buf,
138  sizeof(T) * count,
139  MPI_BYTE,
140  root_rank,
141  comm);
142 }
143 
144 template <typename T>
145 inline T mpi_bcast_value(const T& value,
146  int root_rank,
147  MPI_Comm comm) {
148  T result = value;
149  mpi_bcast(&result, 1, root_rank, comm);
150  return result;
151 }
152 
153 template <typename T>
154 inline void mpi_reduce(const T* sendbuf,
155  T* recvbuf,
156  std::size_t count,
157  int root_rank,
158  MPI_Comm comm,
159  MPI_Op op = MPI_SUM) {
160  MPI_Reduce(sendbuf,
161  recvbuf,
162  count,
163  mpi_type<T>(),
164  op,
165  root_rank,
166  comm);
167 }
168 
169 template <typename T>
170 inline T mpi_reduce_value(const T& value,
171  int root_rank,
172  MPI_Comm comm,
173  MPI_Op op = MPI_SUM) {
174  T result;
175  mpi_reduce(&value, &result, 1, root_rank, comm, op);
176  return result;
177 }
178 
179 template <typename T>
180 inline void mpi_allreduce(const T* sendbuf,
181  T* recvbuf,
182  std::size_t count,
183  MPI_Comm comm,
184  MPI_Op op = MPI_SUM) {
185  MPI_Allreduce(sendbuf,
186  recvbuf,
187  count,
188  mpi_type<T>(),
189  op,
190  comm);
191 }
192 
193 template <typename T>
194 inline T mpi_allreduce_value(const T& value,
195  MPI_Comm comm,
196  MPI_Op op = MPI_SUM) {
197  T result;
198  mpi_allreduce(&value, &result, 1, comm, op);
199  return result;
200 }
201 
202 template <typename T>
203 inline void mpi_allgather(const T* sendbuf,
204  std::size_t sendcount,
205  T* recvbuf,
206  std::size_t recvcount,
207  MPI_Comm comm) {
208  MPI_Allgather(sendbuf,
209  sendcount,
210  mpi_type<T>(),
211  recvbuf,
212  recvcount,
213  mpi_type<T>(),
214  comm);
215 }
216 
217 template <typename T>
218 inline std::vector<T> mpi_allgather_value(const T& value,
219  MPI_Comm comm) {
220  std::vector<T> result(mpi_comm_size(comm));
221  mpi_allgather(&value, 1, result.data(), 1, comm);
222  return result;
223 }
224 
225 template <typename T>
226 inline void mpi_scatter(const T* sendbuf,
227  T* recvbuf,
228  std::size_t count,
229  int root_rank,
230  MPI_Comm comm) {
231  MPI_Scatter(sendbuf,
232  count,
233  mpi_type<T>(),
234  recvbuf,
235  count,
236  mpi_type<T>(),
237  root_rank,
238  comm);
239 }
240 
241 template <typename T>
242 inline T mpi_scatter_value(const T* sendbuf,
243  int root_rank,
244  MPI_Comm comm) {
245  T result {};
246  mpi_scatter(sendbuf, &result, 1, root_rank, comm);
247  return result;
248 }
249 
250 inline void mpi_wait(MPI_Request& req) {
251  MPI_Wait(&req, MPI_STATUS_IGNORE);
252 }
253 
254 inline bool mpi_test(MPI_Request& req) {
255  int flag;
256  MPI_Test(&req, &flag, MPI_STATUS_IGNORE);
257  return flag;
258 }
259 
260 inline void mpi_make_progress() {
261  int flag;
262  MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &flag, MPI_STATUS_IGNORE);
263 }
264 
265 inline MPI_Comm& mpi_comm_root() {
266  static MPI_Comm comm = MPI_COMM_WORLD;
267  return comm;
268 }
269 
270 #if ITYR_DEBUG_UCX
271 
// printf-style format fragments for one line written by ityr_ucx_log_handler().
#define UCS_LOG_TIME_FMT      "[%lu.%06lu]"
#define UCS_LOG_METADATA_FMT  "%17s:%-4u %-4s %-5s %*s"
#define UCS_LOG_PROC_DATA_FMT "[%s:%-5d:%s]"

// Complete line: timestamp, process data, metadata, then the message text.
#define UCS_LOG_FMT                                   \
  UCS_LOG_TIME_FMT " " UCS_LOG_PROC_DATA_FMT " "      \
  UCS_LOG_METADATA_FMT "%s\n"

// Argument lists matching the format fragments above, in the same order.
#define UCS_LOG_TIME_ARG(_tv) (_tv).tv_sec, (_tv).tv_usec

#define UCS_LOG_METADATA_ARG(_short_file, _line, _level, _comp_conf)  \
  (_short_file), (_line), (_comp_conf)->name,                         \
  ucs_log_level_names[_level], 0, ""

#define UCS_LOG_PROC_DATA_ARG()                                       \
  ucs_get_host_name(), ucs_log_get_pid(), ucs_log_get_thread_name()

#define UCS_LOG_COMPACT_ARG(_tv)                                      \
  UCS_LOG_TIME_ARG(_tv), UCS_LOG_PROC_DATA_ARG()

#define UCS_LOG_ARG(_short_file, _line, _level, _comp_conf, _tv, _message)  \
  UCS_LOG_TIME_ARG(_tv), UCS_LOG_PROC_DATA_ARG(),                           \
  UCS_LOG_METADATA_ARG(_short_file, _line, _level, _comp_conf), (_message)
294 
295 inline const char *ucs_log_level_names[] = {
296  [UCS_LOG_LEVEL_FATAL] = "FATAL",
297  [UCS_LOG_LEVEL_ERROR] = "ERROR",
298  [UCS_LOG_LEVEL_WARN] = "WARN",
299  [UCS_LOG_LEVEL_DIAG] = "DIAG",
300  [UCS_LOG_LEVEL_INFO] = "INFO",
301  [UCS_LOG_LEVEL_DEBUG] = "DEBUG",
302  [UCS_LOG_LEVEL_TRACE] = "TRACE",
303  [UCS_LOG_LEVEL_TRACE_REQ] = "REQ",
304  [UCS_LOG_LEVEL_TRACE_DATA] = "DATA",
305  [UCS_LOG_LEVEL_TRACE_ASYNC] = "ASYNC",
306  [UCS_LOG_LEVEL_TRACE_FUNC] = "FUNC",
307  [UCS_LOG_LEVEL_TRACE_POLL] = "POLL",
308  [UCS_LOG_LEVEL_LAST] = NULL,
309  [UCS_LOG_LEVEL_PRINT] = "PRINT"
310 };
311 
/**
 * @brief Returns the host name with everything from the first '.' removed.
 *
 * The name is looked up on the first call and kept in a function-local static
 * buffer; the lookup is repeated for as long as the buffer is still empty.
 *
 * @return Pointer to the static, NUL-terminated host name buffer.
 */
inline const char* ucs_get_host_name() {
  static char hostname[256] = {0};
  if (*hostname == 0) {
    gethostname(hostname, sizeof(hostname));
    // gethostname() may leave the buffer unterminated when the name is truncated.
    hostname[sizeof(hostname) - 1] = '\0';
    // Cut at the first dot with strchr rather than strtok, so the hidden
    // tokenizer state shared with other strtok callers is left alone.
    char* dot = strchr(hostname, '.');
    if (dot != nullptr) {
      *dot = '\0';
    }
  }
  return hostname;
}
320 
/**
 * @brief Returns the id of the calling process for use in log lines.
 *
 * The previous implementation kept a static cache that was never written, so
 * it always fell through to getpid(); the dead cache has been removed.
 *
 * @return The value of getpid() as an int.
 */
inline int ucs_log_get_pid() {
  return static_cast<int>(getpid());
}
328 
/**
 * @brief Returns a short per-thread name for log lines.
 *
 * The first call on each thread takes the next value of a shared atomic
 * counter and stores its decimal representation in a thread-local buffer;
 * later calls on that thread return the same buffer.
 *
 * @return Pointer to the calling thread's NUL-terminated name buffer.
 */
inline const char* ucs_log_get_thread_name() {
  static thread_local char ucs_log_thread_name[32] = {0};
  static std::atomic<int> ucs_log_thread_count{0};

  if (ucs_log_thread_name[0] == '\0') {
    // Counter value before the increment becomes this thread's number.
    const int thread_num = ucs_log_thread_count.fetch_add(1);
    snprintf(ucs_log_thread_name, sizeof(ucs_log_thread_name), "%d", thread_num);
  }

  return ucs_log_thread_name;
}
342 
/**
 * @brief Returns the part of `path` after its last '/'.
 *
 * @param path NUL-terminated path.
 * @return Pointer into `path` just past the last '/', or `path` itself when
 *         it contains no '/'.
 */
inline const char* ucs_basename(const char *path)
{
  const char *last_slash = strrchr(path, '/');
  if (last_slash != nullptr) {
    return last_slash + 1;
  }
  return path;
}
347 
348 inline FILE* ityr_ucx_log_fileptr() {
349  static std::unique_ptr<char[]> outbuf;
350  static std::unique_ptr<FILE, void(*)(FILE*)> outfile(NULL, [](FILE*){});
351  std::size_t outbufsize = 1L * 1024 * 1024 * 1024;
352 
353  if (outfile == nullptr) {
354  outbuf = std::make_unique<char[]>(outbufsize);
355 
356  char buf[256];
357  snprintf(buf, sizeof(buf), "ityr_ucx.log.%d", mpi_comm_rank(MPI_COMM_WORLD));
358  outfile = std::unique_ptr<FILE, void(*)(FILE*)>(fopen(buf, "w"),
359  [](FILE* fp) { if (fp) ::fclose(fp); });
360  if (outfile == nullptr) {
361  perror("fopen");
362  die("could not open file %s", buf);
363  }
364 
365  int ret = setvbuf(outfile.get(), outbuf.get(), _IOFBF, outbufsize);
366  if (ret != 0) {
367  perror("setvbuf");
368  die("setvbuf failed");
369  }
370  }
371 
372  return outfile.get();
373 }
374 
/**
 * @brief Queries and optionally changes the UCX log capture switch.
 *
 * @param mode 0 turns capture off, 1 turns it on, any other value (default -1)
 *             leaves the switch unchanged.
 * @return The state of the switch after applying `mode`.
 */
inline bool ityr_ucx_log_enable(int mode = -1)
{
  static bool capture_on = false;
  switch (mode) {
    case 0:  capture_on = false; break;
    case 1:  capture_on = true;  break;
    default: break;  // query only
  }
  return capture_on;
}
384 
385 inline void ityr_ucx_log_flush() {
386  fflush(ityr_ucx_log_fileptr());
387 }
388 
389 inline ucs_log_func_rc_t
390 ityr_ucx_log_handler(const char *file, unsigned line, const char *function,
391  ucs_log_level_t level,
392  const ucs_log_component_config_t *comp_conf,
393  const char *format, va_list ap) {
394  if (!ityr_ucx_log_enable()) {
395  return UCS_LOG_FUNC_RC_CONTINUE;
396  }
397 
398  if (!ucs_log_component_is_enabled(level, comp_conf) &&
399  (level != UCS_LOG_LEVEL_PRINT)) {
400  return UCS_LOG_FUNC_RC_CONTINUE;
401  }
402 
403  size_t buffer_size = ucs_log_get_buffer_size();
404  char* buf = reinterpret_cast<char*>(alloca(buffer_size + 1));
405  buf[buffer_size] = 0;
406  vsnprintf(buf, buffer_size, format, ap);
407 
408  const char* short_file = ucs_basename(file);
409  struct timeval tv;
410  gettimeofday(&tv, NULL);
411 
412  char* saveptr = "";
413  char* log_line = strtok_r(buf, "\n", &saveptr);
414  while (log_line != NULL) {
415  fprintf(ityr_ucx_log_fileptr(), UCS_LOG_FMT,
416  UCS_LOG_ARG(short_file, line, level,
417  comp_conf, tv, log_line));
418  log_line = strtok_r(NULL, "\n", &saveptr);
419  }
420 
421  /* flush the log file if the log_level of this message is fatal or error */
422  if (level <= UCS_LOG_LEVEL_ERROR) {
423  ityr_ucx_log_flush();
424  }
425 
426  return UCS_LOG_FUNC_RC_CONTINUE;
427 }
428 #endif
429 
431 public:
432  mpi_initializer(MPI_Comm comm) {
433  mpi_comm_root() = comm;
434  MPI_Initialized(&initialized_outside_);
435  if (!initialized_outside_) {
436  MPI_Init(nullptr, nullptr);
437  }
438 #if ITYR_DEBUG_UCX
439  while (ucs_log_num_handlers() > 0) {
440  ucs_log_pop_handler();
441  }
442  ityr_ucx_log_fileptr();
443  ucs_log_push_handler(ityr_ucx_log_handler);
444 #endif
445  }
446 
448 #if ITYR_DEBUG_UCX
449  ucs_log_pop_handler();
450 #endif
451  if (!initialized_outside_) {
452  MPI_Finalize();
453  }
454  }
455 
458 
461 
462 private:
463  int initialized_outside_ = 1;
464 };
465 
466 template <typename T>
467 inline T getenv_coll(const std::string& env_var, T default_val) {
468  MPI_Comm comm = mpi_comm_root();
469 
470  int rank = mpi_comm_rank(comm);
471  T val = default_val;
472 
473  if (rank == 0) {
474  val = getenv_with_default(env_var.c_str(), default_val);
475  }
476 
477  return mpi_bcast_value(val, 0, comm);
478 }
479 
480 }
Definition: mpi_util.hpp:430
~mpi_initializer()
Definition: mpi_util.hpp:447
mpi_initializer(mpi_initializer &&)=delete
mpi_initializer & operator=(mpi_initializer &&)=delete
mpi_initializer & operator=(const mpi_initializer &)=delete
mpi_initializer(const mpi_initializer &)=delete
mpi_initializer(MPI_Comm comm)
Definition: mpi_util.hpp:432
#define ITYR_CHECK(cond)
Definition: util.hpp:48
bool enabled()
Definition: numa.hpp:86
ITYR_CONCAT(mode_, ITYR_PROFILER_MODE) mode
Definition: profiler.hpp:257
Definition: allocator.hpp:16
void mpi_recv(T *buf, std::size_t count, int target_rank, int tag, MPI_Comm comm)
Definition: mpi_util.hpp:92
MPI_Datatype mpi_type< bool >()
Definition: mpi_util.hpp:25
std::vector< T > mpi_allgather_value(const T &value, MPI_Comm comm)
Definition: mpi_util.hpp:218
void mpi_allgather(const T *sendbuf, std::size_t sendcount, T *recvbuf, std::size_t recvcount, MPI_Comm comm)
Definition: mpi_util.hpp:203
void mpi_reduce(const T *sendbuf, T *recvbuf, std::size_t count, int root_rank, MPI_Comm comm, MPI_Op op=MPI_SUM)
Definition: mpi_util.hpp:154
MPI_Datatype mpi_type< unsigned long >()
Definition: mpi_util.hpp:24
fprintf(stderr, "\x1b[31m%s\x1b[39m\n", msg)
void mpi_make_progress()
Definition: mpi_util.hpp:260
T mpi_bcast_value(const T &value, int root_rank, MPI_Comm comm)
Definition: mpi_util.hpp:145
vsnprintf(msg, slen, fmt, args)
MPI_Datatype mpi_type< void * >()
Definition: mpi_util.hpp:26
fflush(stderr)
T mpi_allreduce_value(const T &value, MPI_Comm comm, MPI_Op op=MPI_SUM)
Definition: mpi_util.hpp:194
MPI_Datatype mpi_type< long >()
Definition: mpi_util.hpp:23
T mpi_scatter_value(const T *sendbuf, int root_rank, MPI_Comm comm)
Definition: mpi_util.hpp:242
void mpi_wait(MPI_Request &req)
Definition: mpi_util.hpp:250
MPI_Datatype mpi_type()
T mpi_recv_value(int target_rank, int tag, MPI_Comm comm)
Definition: mpi_util.hpp:124
MPI_Comm & mpi_comm_root()
Definition: mpi_util.hpp:265
MPI_Request mpi_isend(const T *buf, std::size_t count, int target_rank, int tag, MPI_Comm comm)
Definition: mpi_util.hpp:67
T getenv_with_default(const char *env_var, T default_val)
Definition: util.hpp:88
int mpi_comm_rank(MPI_Comm comm)
Definition: mpi_util.hpp:28
MPI_Request mpi_ibarrier(MPI_Comm comm)
Definition: mpi_util.hpp:46
constexpr auto size(const span< T > &s) noexcept
Definition: span.hpp:61
void mpi_bcast(T *buf, std::size_t count, int root_rank, MPI_Comm comm)
Definition: mpi_util.hpp:133
int mpi_comm_size(MPI_Comm comm)
Definition: mpi_util.hpp:35
T getenv_coll(const std::string &env_var, T default_val)
Definition: mpi_util.hpp:467
T mpi_reduce_value(const T &value, int root_rank, MPI_Comm comm, MPI_Op op=MPI_SUM)
Definition: mpi_util.hpp:170
void mpi_barrier(MPI_Comm comm)
Definition: mpi_util.hpp:42
void mpi_send_value(const T &value, int target_rank, int tag, MPI_Comm comm)
Definition: mpi_util.hpp:84
void mpi_allreduce(const T *sendbuf, T *recvbuf, std::size_t count, MPI_Comm comm, MPI_Op op=MPI_SUM)
Definition: mpi_util.hpp:180
MPI_Datatype mpi_type< unsigned int >()
Definition: mpi_util.hpp:22
MPI_Datatype mpi_type< int >()
Definition: mpi_util.hpp:21
void mpi_send(const T *buf, std::size_t count, int target_rank, int tag, MPI_Comm comm)
Definition: mpi_util.hpp:53
MPI_Request mpi_irecv(T *buf, std::size_t count, int target_rank, int tag, MPI_Comm comm)
Definition: mpi_util.hpp:107
void mpi_scatter(const T *sendbuf, T *recvbuf, std::size_t count, int root_rank, MPI_Comm comm)
Definition: mpi_util.hpp:226
bool mpi_test(MPI_Request &req)
Definition: mpi_util.hpp:254