/// Element-wise scaled addition: z[i] = a * x[i] + y[i] for every i < n.
///
/// Indexing uses only the x components of blockIdx/blockDim/threadIdx, and
/// each thread handles at most one element, so the launch must provide at
/// least n threads in total along x.
///
/// @param n Number of elements to process.
/// @param z Output array; elements [0, n) are overwritten.
/// @param a Scalar multiplier applied to x.
/// @param x First input array (read-only).
/// @param y Second input array; only read here.
template <typename T>
__global__ void axpy_kernel(size_t n, T *z, const T a, const T *x, T *y) {
  // Widen each term before multiplying so the flat index cannot wrap in
  // 32-bit arithmetic for very large grids.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  // The grid is rounded up to whole blocks; surplus threads do nothing.
  if (idx < n) {
    z[idx] = a * x[idx] + y[idx];
  }
}
/// Enqueues axpy_kernel on `stream` to compute z = a * x + y over n elements.
///
/// The launch is issued on the given stream and this function does not
/// synchronize, so results are only available once work on `stream` has
/// completed. The pointers are dereferenced by the kernel.
///
/// @param stream Stream on which the kernel is launched.
/// @param n      Number of elements; n == 0 is a no-op.
/// @param z      Output array.
/// @param a      Scalar multiplier applied to x.
/// @param x      First input array.
/// @param y      Second input array.
template <typename T>
void axpy_async(cudaStream_t stream, size_t n, T *z, const T a, const T *x,
                T *y) {
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // vector must not reach the launch below.
  if (n == 0) {
    return;
  }
  const size_t threads_per_block = 256;
  // Ceiling division: a partial tail block covers the last elements.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  axpy_kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(n, z, a, x, y);
}
/// Adds a damping term to z in place, one element per thread.
///
/// For every i < n:
///   - use_identity == true:  z[i] += damping_factor * x[i]
///   - use_identity == false: z[i] += damping_factor * diag[i] * x[i]
///
/// Indexing uses only the x components of blockIdx/blockDim/threadIdx, so the
/// launch must provide at least n threads in total along x.
///
/// @param n              Number of elements to process.
/// @param z              Array updated in place.
/// @param damping_factor Scalar weight of the damping term.
/// @param use_identity   Selects the unscaled branch; diag is not read then.
/// @param diag           Per-element scale, read only when !use_identity.
/// @param x              Input array (read-only).
template <typename T>
__global__ void damping_kernel(size_t n, T *z, const T damping_factor,
                               const bool use_identity, const T *diag,
                               const T *x) {
  // Widen each term before multiplying so the flat index cannot wrap in
  // 32-bit arithmetic for very large grids.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  // The grid is rounded up to whole blocks; surplus threads do nothing.
  if (idx < n) {
    if (use_identity) {
      z[idx] += damping_factor * x[idx];
    } else {
      z[idx] += damping_factor * diag[idx] * x[idx];
    }
  }
}
/// Enqueues damping_kernel on `stream` to add a damping term to z in place.
///
/// The launch is issued on the given stream and this function does not
/// synchronize. The pointers are dereferenced by the kernel.
///
/// @param stream         Stream on which the kernel is launched.
/// @param n              Number of elements; n == 0 is a no-op.
/// @param z              Array updated in place.
/// @param damping_factor Scalar weight of the damping term.
/// @param use_identity   When true the kernel adds damping_factor * x and does
///                       not read diag; otherwise it adds
///                       damping_factor * diag * x.
/// @param diag           Per-element scale used when !use_identity.
/// @param x              Input array.
template <typename T>
void damp_by_factor_async(cudaStream_t stream, size_t n, T *z,
                          const T damping_factor, const bool use_identity,
                          const T *diag, const T *x) {
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // vector must not reach the launch below.
  if (n == 0) {
    return;
  }
  const size_t threads_per_block = 256;
  // Ceiling division: a partial tail block covers the last elements.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  damping_kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
      n, z, damping_factor, use_identity, diag, x);
}
/// Clamps x[i] into [min_val, max_val] in place for every i < n.
///
/// Indexing uses only the x components of blockIdx/blockDim/threadIdx, so the
/// launch must provide at least n threads in total along x.
///
/// @param n       Number of elements to process.
/// @param min_val Lower bound passed to std::clamp.
/// @param max_val Upper bound passed to std::clamp.
/// @param x       Array clamped in place.
template <typename T>
__global__ void clamp_kernel(size_t n, T min_val, T max_val, T *x) {
  // Widen each term before multiplying so the flat index cannot wrap in
  // 32-bit arithmetic for very large grids.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  // The grid is rounded up to whole blocks; surplus threads do nothing.
  if (idx < n) {
    // NOTE(review): std::clamp is a host constexpr function; calling it from
    // device code presumably relies on nvcc's relaxed-constexpr mode being
    // enabled in the build - confirm.
    x[idx] = std::clamp(x[idx], min_val, max_val);
  }
}
/// Enqueues clamp_kernel on `stream` to clamp n elements of x in place.
///
/// The launch is issued on the given stream and this function does not
/// synchronize. The pointer is dereferenced by the kernel.
///
/// @param stream  Stream on which the kernel is launched.
/// @param n       Number of elements; n == 0 is a no-op.
/// @param min_val Lower clamp bound.
/// @param max_val Upper clamp bound.
/// @param x       Array clamped in place.
template <typename T>
void clamp_async(cudaStream_t stream, size_t n, T min_val, T max_val, T *x) {
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // vector must not reach the launch below.
  if (n == 0) {
    return;
  }
  const size_t threads_per_block = 256;
  // Ceiling division: a partial tail block covers the last elements.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  clamp_kernel<T>
      <<<num_blocks, threads_per_block, 0, stream>>>(n, min_val, max_val, x);
}
/// Element-wise scaling: out[i] = scale * x[i] for every i < n.
///
/// Indexing uses only the x components of blockIdx/blockDim/threadIdx, so the
/// launch must provide at least n threads in total along x.
///
/// @param n     Number of elements to process.
/// @param out   Output array; elements [0, n) are overwritten.
/// @param scale Scalar multiplier.
/// @param x     Input array (read-only).
template <typename T>
__global__ void rescale_vec_kernel(size_t n, T *out, const T scale,
                                   const T *x) {
  // Widen each term before multiplying so the flat index cannot wrap in
  // 32-bit arithmetic for very large grids.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  // The grid is rounded up to whole blocks; surplus threads do nothing.
  if (idx < n) {
    out[idx] = scale * x[idx];
  }
}
/// Enqueues rescale_vec_kernel on `stream` to compute out = scale * x.
///
/// The launch is issued on the given stream and this function does not
/// synchronize. The pointers are dereferenced by the kernel.
///
/// @param stream Stream on which the kernel is launched.
/// @param n      Number of elements; n == 0 is a no-op.
/// @param out    Output array.
/// @param scale  Scalar multiplier.
/// @param x      Input array.
template <typename T>
void rescale_vec_async(cudaStream_t stream, size_t n, T *out, const T scale,
                       const T *x) {
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // vector must not reach the launch below.
  if (n == 0) {
    return;
  }
  const size_t threads_per_block = 256;
  // Ceiling division: a partial tail block covers the last elements.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  rescale_vec_kernel<T>
      <<<num_blocks, threads_per_block, 0, stream>>>(n, out, scale, x);
}
/// Per-element Adam update: refreshes the moment estimates m and v from the
/// negated gradient and writes the bias-corrected step.
///
/// For every i < n, with g = -gradient[i]:
///   m[i]    = beta1 * m[i] + (1 - beta1) * g
///   v[i]    = beta2 * v[i] + (1 - beta2) * g * g
///   step[i] = -lr * (m[i] / (1 - beta1^t)) /
///             (sqrt(v[i] / (1 - beta2^t)) + epsilon)
///
/// Indexing uses only the x components of blockIdx/blockDim/threadIdx, so the
/// launch must provide at least n threads in total along x.
///
/// @param n        Number of elements to process.
/// @param gradient Gradient array; only read here.
/// @param step     Output array receiving the computed step.
/// @param m        First-moment estimate, updated in place.
/// @param v        Second-moment estimate, updated in place.
/// @param lr       Learning rate.
/// @param beta1    Decay rate for m.
/// @param beta2    Decay rate for v.
/// @param epsilon  Added to the denominator of the step.
/// @param t        Exponent for the bias correction. Must be non-zero:
///                 beta^0 == 1 would make both correction denominators zero.
template <typename T>
__global__ void compute_adam_step(const size_t n, T *gradient, T *step, T *m,
                                  T *v, const T lr, const T beta1,
                                  const T beta2, const T epsilon,
                                  const size_t t) {
  // Widen each term before multiplying so the flat index cannot wrap in
  // 32-bit arithmetic for very large grids.
  const size_t i =
      static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x) +
      static_cast<size_t>(threadIdx.x);
  // The grid is rounded up to whole blocks; surplus threads do nothing.
  if (i < n) {
    // The moments are accumulated from the negated gradient.
    const auto g = -gradient[i];
    m[i] = beta1 * m[i] + (1 - beta1) * g;
    v[i] = beta2 * v[i] + (1 - beta2) * g * g;

    // Bias correction for the first moment.
    const auto b1t = cuda::std::pow(beta1, static_cast<T>(t));
    const auto m_hat = m[i] / (1 - b1t);

    // Bias correction for the second moment.
    const auto b2t = cuda::std::pow(beta2, static_cast<T>(t));
    const auto v_hat = v[i] / (1 - b2t);

    step[i] = -lr * m_hat / (cuda::std::sqrt(v_hat) + epsilon);
  }
}
/// Enqueues compute_adam_step on `stream` for n elements.
///
/// The launch is issued on the given stream and this function does not
/// synchronize. The pointers are dereferenced by the kernel.
///
/// @param stream   Stream on which the kernel is launched.
/// @param n        Number of elements; n == 0 is a no-op.
/// @param gradient Gradient array.
/// @param step     Output array receiving the computed step.
/// @param m        First-moment estimate, updated in place by the kernel.
/// @param v        Second-moment estimate, updated in place by the kernel.
/// @param lr       Learning rate.
/// @param beta1    Decay rate for m.
/// @param beta2    Decay rate for v.
/// @param epsilon  Added to the denominator of the step.
/// @param t        Step counter; the kernel receives t + 1, so t == 0 is a
///                 valid first call.
template <typename T>
void compute_adam_step_async(cudaStream_t stream, const size_t n, T *gradient,
                             T *step, T *m, T *v, const T lr, const T beta1,
                             const T beta2, const T epsilon, const size_t t) {
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // vector must not reach the launch below.
  if (n == 0) {
    return;
  }
  const size_t threads_per_block = 256;
  // Ceiling division: a partial tail block covers the last elements.
  const size_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  // Pass t + 1 so the kernel's bias-correction exponent is never zero.
  compute_adam_step<T><<<num_blocks, threads_per_block, 0, stream>>>(
      n, gradient, step, m, v, lr, beta1, beta2, epsilon, t + 1);
}
// The top-level namespace for Graphite.
// Definition: eigen_solver.cpp:4