Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
vector.hpp
Go to the documentation of this file.
1
2#pragma once
3
4namespace graphite {
5namespace ops {
/// Computes z[i] = a * x[i] + y[i] for every i in [0, n).
///
/// Launch: 1-D grid with at least n total threads, one element per thread;
/// threads whose global index is >= n do nothing. Each thread reads and writes
/// only its own index, so z may be the same buffer as x or y (in-place axpy).
///
/// `y` is only read, so it is taken as `const T *` like `x`; non-const
/// pointers convert implicitly, so existing call sites are unaffected.
template <typename T>
__global__ void axpy_kernel(size_t n, T *z, const T a, const T *x,
                            const T *y) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  if (idx < n) {
    z[idx] = a * x[idx] + y[idx];
  }
}
15
/// Enqueues axpy_kernel on `stream` to compute z = a * x + y over n elements.
///
/// Asynchronous with respect to the host: synchronize `stream` before reading
/// z. Launch errors are not checked here. Does nothing when n == 0.
template <typename T>
void axpy_async(cudaStream_t stream, size_t n, T *z, const T a, const T *x,
                T *y) {
  // A zero-block grid is rejected with cudaErrorInvalidConfiguration, which a
  // later cudaGetLastError() would report as a spurious failure; skip it.
  if (n == 0) {
    return;
  }
  constexpr unsigned int threads_per_block = 256;
  // gridDim.x is unsigned int; n is assumed small enough for the block count
  // to fit.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  axpy_kernel<T><<<static_cast<unsigned int>(num_blocks), threads_per_block, 0,
                   stream>>>(n, z, a, x, y);
}
23
/// Adds damping to z in place for each of the n elements:
///   use_identity == true : z[i] += damping_factor * x[i]
///   use_identity == false: z[i] += damping_factor * diag[i] * x[i]
///
/// `diag` is never read on the identity path. On the diagonal path its entries
/// are expected to have been clamped already; this kernel does not clamp them.
/// Launch: 1-D grid with at least n total threads, one element per thread.
template <typename T>
__global__ void damping_kernel(size_t n, T *z, const T damping_factor,
                               const bool use_identity, const T *diag,
                               const T *x) {
  const size_t idx = size_t{blockIdx.x} * blockDim.x + threadIdx.x;
  if (idx >= n) {
    return;
  }
  // Per-element damping weight; the ternary evaluates diag[idx] only on the
  // diagonal path.
  const T weight = use_identity ? damping_factor : damping_factor * diag[idx];
  z[idx] += weight * x[idx];
}
40
/// Enqueues damping_kernel on `stream`: z[i] += damping_factor * x[i] when
/// use_identity is true, otherwise z[i] += damping_factor * diag[i] * x[i].
///
/// Asynchronous with respect to the host: synchronize `stream` before reading
/// z. Launch errors are not checked here. Does nothing when n == 0.
template <typename T>
void damp_by_factor_async(cudaStream_t stream, size_t n, T *z,
                          const T damping_factor, const bool use_identity,
                          const T *diag, const T *x) {
  // A zero-block grid is rejected with cudaErrorInvalidConfiguration, which a
  // later cudaGetLastError() would report as a spurious failure; skip it.
  if (n == 0) {
    return;
  }
  constexpr unsigned int threads_per_block = 256;
  // gridDim.x is unsigned int; n is assumed small enough for the block count
  // to fit.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  damping_kernel<T><<<static_cast<unsigned int>(num_blocks), threads_per_block,
                      0, stream>>>(n, z, damping_factor, use_identity, diag,
                                   x);
}
50
/// Clamps each of the n elements of x in place to [min_val, max_val].
///
/// Launch: 1-D grid with at least n total threads, one element per thread.
/// Precondition: min_val <= max_val (the same requirement std::clamp has).
///
/// The comparison is written out rather than calling std::clamp. std::clamp is
/// a host-only constexpr function, so calling it from device code compiles only
/// with nvcc's --expt-relaxed-constexpr flag, and it needs <algorithm>. The
/// expression below mirrors std::clamp's definition, including leaving a NaN
/// element unchanged (both comparisons are false).
template <typename T>
__global__ void clamp_kernel(size_t n, T min_val, T max_val, T *x) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  if (idx < n) {
    const T value = x[idx];
    x[idx] = (value < min_val) ? min_val
                               : ((max_val < value) ? max_val : value);
  }
}
60
/// Enqueues clamp_kernel on `stream` to clamp x[0..n) in place to
/// [min_val, max_val].
///
/// Asynchronous with respect to the host: synchronize `stream` before reading
/// x. Launch errors are not checked here. Does nothing when n == 0.
template <typename T>
void clamp_async(cudaStream_t stream, size_t n, T min_val, T max_val, T *x) {
  // A zero-block grid is rejected with cudaErrorInvalidConfiguration, which a
  // later cudaGetLastError() would report as a spurious failure; skip it.
  if (n == 0) {
    return;
  }
  constexpr unsigned int threads_per_block = 256;
  // gridDim.x is unsigned int; n is assumed small enough for the block count
  // to fit.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  clamp_kernel<T><<<static_cast<unsigned int>(num_blocks), threads_per_block,
                    0, stream>>>(n, min_val, max_val, x);
}
68
/// Writes out[i] = scale * x[i] for each i in [0, n).
///
/// Launch: 1-D grid with at least n total threads, one element per thread.
/// Each thread touches only its own index, so out may be the same buffer as x.
template <typename T>
__global__ void rescale_vec_kernel(size_t n, T *out, const T scale,
                                   const T *x) {
  const size_t idx = size_t{blockIdx.x} * blockDim.x + threadIdx.x;
  if (idx >= n) {
    return;
  }
  const T value = x[idx];
  out[idx] = scale * value;
}
79
/// Enqueues rescale_vec_kernel on `stream` to compute out = scale * x over n
/// elements.
///
/// Asynchronous with respect to the host: synchronize `stream` before reading
/// out. Launch errors are not checked here. Does nothing when n == 0.
template <typename T>
void rescale_vec_async(cudaStream_t stream, size_t n, T *out, const T scale,
                       const T *x) {
  // A zero-block grid is rejected with cudaErrorInvalidConfiguration, which a
  // later cudaGetLastError() would report as a spurious failure; skip it.
  if (n == 0) {
    return;
  }
  constexpr unsigned int threads_per_block = 256;
  // gridDim.x is unsigned int; n is assumed small enough for the block count
  // to fit.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  rescale_vec_kernel<T><<<static_cast<unsigned int>(num_blocks),
                          threads_per_block, 0, stream>>>(n, out, scale, x);
}
88
/// Performs one Adam update for each of n parameters and writes the resulting
/// step vector.
///
/// For each i in [0, n), with g = -gradient[i]:
///   m[i]    <- beta1 * m[i] + (1 - beta1) * g
///   v[i]    <- beta2 * v[i] + (1 - beta2) * g * g
///   step[i]  = -lr * m_hat / (sqrt(v_hat) + epsilon)
/// where m_hat = m[i] / (1 - beta1^t) and v_hat = v[i] / (1 - beta2^t).
/// m and v are updated in place; gradient is only read (hence `const T *`,
/// which existing non-const call sites still convert to implicitly).
///
/// NOTE(review): the sign flip on `gradient` suggests the buffer holds the
/// negated gradient; confirm against callers.
///
/// Precondition: t >= 1. With t == 0 both bias corrections are 1 - 1 == 0, so
/// m_hat and v_hat would be divisions by zero.
/// Launch: 1-D grid with at least n total threads, one parameter per thread.
template <typename T>
__global__ void compute_adam_step(const size_t n, const T *gradient, T *step,
                                  T *m, T *v, const T lr, const T beta1,
                                  const T beta2, const T epsilon,
                                  const size_t t) {
  const size_t i =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  if (i >= n) {
    return;
  }

  // Exponential moving averages of the first and second moments.
  const auto g = -gradient[i];
  m[i] = beta1 * m[i] + (1 - beta1) * g;
  v[i] = beta2 * v[i] + (1 - beta2) * g * g;

  // Bias corrections; these are identical for every thread.
  const T step_count = static_cast<T>(t);
  const auto m_hat = m[i] / (1 - cuda::std::pow(beta1, step_count));
  const auto v_hat = v[i] / (1 - cuda::std::pow(beta2, step_count));

  step[i] = -lr * m_hat / (cuda::std::sqrt(v_hat) + epsilon);
}
112
/// Enqueues compute_adam_step on `stream` for n parameters.
///
/// `t` is forwarded to the kernel as t + 1, so a caller passing t == 0 still
/// gets bias corrections 1 - beta^1, which are non-zero for beta in [0, 1).
/// Presumably t counts completed iterations; confirm against callers.
///
/// Asynchronous with respect to the host: synchronize `stream` before reading
/// step, m or v. Launch errors are not checked here. Does nothing when n == 0.
template <typename T>
void compute_adam_step_async(cudaStream_t stream, const size_t n, T *gradient,
                             T *step, T *m, T *v, const T lr, const T beta1,
                             const T beta2, const T epsilon, const size_t t) {
  // A zero-block grid is rejected with cudaErrorInvalidConfiguration, which a
  // later cudaGetLastError() would report as a spurious failure; skip it.
  if (n == 0) {
    return;
  }
  constexpr unsigned int threads_per_block = 256;
  // gridDim.x is unsigned int; n is assumed small enough for the block count
  // to fit.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  compute_adam_step<T><<<static_cast<unsigned int>(num_blocks),
                         threads_per_block, 0, stream>>>(
      n, gradient, step, m, v, lr, beta1, beta2, epsilon, t + 1);
}
122
123} // namespace ops
124
125} // namespace graphite
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4