Graphite
Loading...
Searching...
No Matches
vector.hpp
Go to the documentation of this file.
1
2#pragma once
3
4namespace graphite {
5namespace ops {
/// Element-wise scaled add: writes z[i] = a * x[i] + y[i] for each i in [0, n).
///
/// Launch layout: one thread per element on a 1-D grid (only the .x components
/// of blockIdx / blockDim / threadIdx are used). Threads whose global index is
/// >= n return without touching memory, so the grid may overshoot n.
///
/// @param n  number of elements to process
/// @param z  output array, written at indices [0, n)
/// @param a  scalar multiplier applied to x
/// @param x  input array that gets scaled by a
/// @param y  input array added to the scaled x (only read by this kernel)
template <typename T>
__global__ void axpy_kernel(size_t n, T *z, const T a, const T *x, T *y) {
  // Index math is done in size_t so blockIdx.x * blockDim.x cannot wrap at
  // 32 bits.
  const size_t block_start = static_cast<size_t>(blockIdx.x) * blockDim.x;
  const size_t i = block_start + threadIdx.x;
  if (i >= n) {
    return;
  }
  z[i] = a * x[i] + y[i];
}
15
/// Enqueues axpy_kernel on `stream` to compute z[i] = a * x[i] + y[i] for
/// i in [0, n).
///
/// The launch uses 256 threads per block and ceil(n / 256) blocks. This
/// function only enqueues work: it does not synchronize with the stream and
/// does not query the launch status, so callers that need error reporting
/// should check cudaGetLastError() / synchronize themselves.
///
/// @param stream  CUDA stream the kernel is launched on
/// @param n       number of elements; n == 0 is a no-op
/// @param z       output array, written at indices [0, n)
/// @param a       scalar multiplier applied to x
/// @param x       input array that gets scaled by a
/// @param y       input array added to the scaled x
template <typename T>
void axpy_async(cudaStream_t stream, size_t n, T *z, const T a, const T *x,
                T *y) {
  // A grid with zero blocks is an invalid launch configuration: the launch
  // would fail with cudaErrorInvalidConfiguration and leave that error behind
  // for the next cudaGetLastError() call. Nothing needs to run, so skip it.
  if (n == 0) {
    return;
  }
  constexpr size_t threads_per_block = 256;
  // Round up so a trailing partial block is still covered; the kernel's own
  // bounds check discards the surplus threads.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  axpy_kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(n, z, a, x, y);
}
23
/// Element-wise in-place update: z[i] += a * diag[i] * x[i] for each i in
/// [0, n).
///
/// Launch layout: one thread per element on a 1-D grid (only the .x components
/// of blockIdx / blockDim / threadIdx are used). Threads whose global index is
/// >= n return without touching memory.
///
/// @param n     number of elements to process
/// @param z     array that is read and accumulated into at indices [0, n)
/// @param a     scalar factor applied to every diag[i] * x[i] term
/// @param diag  per-element factor (only read)
/// @param x     per-element input (only read)
template <typename T>
__global__ void damping_kernel(size_t n, T *z, const T a, const T *diag,
                               const T *x) {
  // Index math is done in size_t so blockIdx.x * blockDim.x cannot wrap at
  // 32 bits.
  const size_t block_start = static_cast<size_t>(blockIdx.x) * blockDim.x;
  const size_t i = block_start + threadIdx.x;
  if (i >= n) {
    return;
  }
  z[i] += a * diag[i] * x[i];
}
34
/// Enqueues damping_kernel on `stream` to apply z[i] += a * diag[i] * x[i]
/// for i in [0, n).
///
/// The launch uses 256 threads per block and ceil(n / 256) blocks. This
/// function only enqueues work: it does not synchronize with the stream and
/// does not query the launch status.
///
/// @param stream  CUDA stream the kernel is launched on
/// @param n       number of elements; n == 0 is a no-op
/// @param z       array accumulated into at indices [0, n)
/// @param a       scalar factor applied to every diag[i] * x[i] term
/// @param diag    per-element factor
/// @param x       per-element input
template <typename T>
void damp_by_factor_async(cudaStream_t stream, size_t n, T *z, const T a,
                          const T *diag, const T *x) {
  // A grid with zero blocks is an invalid launch configuration: the launch
  // would fail with cudaErrorInvalidConfiguration and leave that error behind
  // for the next cudaGetLastError() call. Nothing needs to run, so skip it.
  if (n == 0) {
    return;
  }
  constexpr size_t threads_per_block = 256;
  // Round up so a trailing partial block is still covered; the kernel's own
  // bounds check discards the surplus threads.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  damping_kernel<T>
      <<<num_blocks, threads_per_block, 0, stream>>>(n, z, a, diag, x);
}
43
/// Clamps every element of x to the closed range [min_val, max_val] in place.
///
/// Launch layout: one thread per element on a 1-D grid (only the .x components
/// of blockIdx / blockDim / threadIdx are used). Threads whose global index is
/// >= n do nothing.
///
/// The comparison sequence is the one std::clamp specifies
/// (v < lo ? lo : hi < v ? hi : v), written out inline: std::clamp is a host
/// constexpr function, and calling it from device code only compiles when
/// nvcc is given --expt-relaxed-constexpr. As with std::clamp, callers are
/// expected to pass min_val <= max_val.
///
/// @param n        number of elements to process
/// @param min_val  lower bound of the range
/// @param max_val  upper bound of the range
/// @param x        array clamped in place at indices [0, n)
template <typename T>
__global__ void clamp_kernel(size_t n, T min_val, T max_val, T *x) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  if (idx < n) {
    const T v = x[idx];
    // Only operator< is used, and in the same operand order as std::clamp, so
    // values that compare false both ways (e.g. NaN) pass through unchanged.
    x[idx] = (v < min_val) ? min_val : ((max_val < v) ? max_val : v);
  }
}
53
/// Enqueues clamp_kernel on `stream` to clamp x[i] into [min_val, max_val]
/// for i in [0, n).
///
/// The launch uses 256 threads per block and ceil(n / 256) blocks. This
/// function only enqueues work: it does not synchronize with the stream and
/// does not query the launch status.
///
/// @param stream   CUDA stream the kernel is launched on
/// @param n        number of elements; n == 0 is a no-op
/// @param min_val  lower bound of the range
/// @param max_val  upper bound of the range
/// @param x        array clamped in place at indices [0, n)
template <typename T>
void clamp_async(cudaStream_t stream, size_t n, T min_val, T max_val, T *x) {
  // A grid with zero blocks is an invalid launch configuration: the launch
  // would fail with cudaErrorInvalidConfiguration and leave that error behind
  // for the next cudaGetLastError() call. Nothing needs to run, so skip it.
  if (n == 0) {
    return;
  }
  constexpr size_t threads_per_block = 256;
  // Round up so a trailing partial block is still covered; the kernel's own
  // bounds check discards the surplus threads.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  clamp_kernel<T>
      <<<num_blocks, threads_per_block, 0, stream>>>(n, min_val, max_val, x);
}
61
/// Element-wise scaling: writes out[i] = scale * x[i] for each i in [0, n).
///
/// Launch layout: one thread per element on a 1-D grid (only the .x components
/// of blockIdx / blockDim / threadIdx are used). Threads whose global index is
/// >= n return without touching memory.
///
/// @param n      number of elements to process
/// @param out    output array, written at indices [0, n)
/// @param scale  scalar multiplier
/// @param x      input array (only read)
template <typename T>
__global__ void rescale_vec_kernel(size_t n, T *out, const T scale,
                                   const T *x) {
  // Index math is done in size_t so blockIdx.x * blockDim.x cannot wrap at
  // 32 bits.
  const size_t block_start = static_cast<size_t>(blockIdx.x) * blockDim.x;
  const size_t i = block_start + threadIdx.x;
  if (i >= n) {
    return;
  }
  out[i] = scale * x[i];
}
72
/// Enqueues rescale_vec_kernel on `stream` to compute out[i] = scale * x[i]
/// for i in [0, n).
///
/// The launch uses 256 threads per block and ceil(n / 256) blocks. This
/// function only enqueues work: it does not synchronize with the stream and
/// does not query the launch status.
///
/// @param stream  CUDA stream the kernel is launched on
/// @param n       number of elements; n == 0 is a no-op
/// @param out     output array, written at indices [0, n)
/// @param scale   scalar multiplier
/// @param x       input array
template <typename T>
void rescale_vec_async(cudaStream_t stream, size_t n, T *out, const T scale,
                       const T *x) {
  // A grid with zero blocks is an invalid launch configuration: the launch
  // would fail with cudaErrorInvalidConfiguration and leave that error behind
  // for the next cudaGetLastError() call. Nothing needs to run, so skip it.
  if (n == 0) {
    return;
  }
  constexpr size_t threads_per_block = 256;
  // Round up so a trailing partial block is still covered; the kernel's own
  // bounds check discards the surplus threads.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  rescale_vec_kernel<T>
      <<<num_blocks, threads_per_block, 0, stream>>>(n, out, scale, x);
}
81
82} // namespace ops
83
84} // namespace graphite