Graphite
Loading...
Searching...
No Matches
stream.hpp
Go to the documentation of this file.
1
2#pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>

#include <cuda_runtime.h>
4
5namespace graphite {
6
7class StreamPool {
8public:
9 StreamPool(size_t num_streams)
10 : num_streams(num_streams), cleanup_streams(true) {
11 streams = new cudaStream_t[num_streams];
12 for (size_t i = 0; i < num_streams; ++i) {
13 cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking);
14 }
15 }
16
17 StreamPool(cudaStream_t *p_streams, size_t n)
18 : streams(p_streams), num_streams(n), cleanup_streams(false) {}
19
20 ~StreamPool() {
21 if (!cleanup_streams) {
22 return;
23 }
24 for (size_t i = 0; i < num_streams; ++i) {
25 cudaStreamDestroy(streams[i]);
26 }
27 delete[] streams;
28 }
29
30 cudaStream_t &select(size_t index) { return streams[index % num_streams]; }
31
32 void sync_all() {
33 for (size_t i = 0; i < num_streams; ++i) {
34 cudaStreamSynchronize(streams[i]);
35 }
36 }
37
38 void sync_n(size_t n) {
39 n = std::min(n, num_streams);
40 for (size_t i = 0; i < n; ++i) {
41 cudaStreamSynchronize(streams[i]);
42 }
43 }
44
45 cudaStream_t *streams;
46 size_t num_streams;
47 bool cleanup_streams;
48};
49
50StreamPool create_default_stream_pool() {
51 static cudaStream_t default_stream = cudaStreamPerThread;
52 static StreamPool default_pool(&default_stream, 1);
53 return default_pool;
54}
55
56} // namespace graphite
Definition stream.hpp:7