Graphite
Loading...
Searching...
No Matches
chi2.hpp
Go to the documentation of this file.
1
2#pragma once
4
namespace graphite {

namespace ops {

/**
 * Computes the quadratic form r^T * P * r for a single factor.
 *
 * `residuals` is read as a flat array holding E values per factor, and `pmat`
 * as a flat array holding E*E values per factor; entry pmat[i*E + j] of a
 * factor multiplies residual component j when forming row i of P * r.
 *
 * @tparam T Residual / accumulation type.
 * @tparam P Storage type of the precision matrix; each entry is cast to T.
 * @tparam E Residual (error) dimension of the factor.
 * @param residuals Pointer to all residuals, E values per factor.
 * @param pmat Pointer to all precision matrices, E*E values per factor.
 * @param factor_id Index of the factor to evaluate.
 * @return sum_i r_i * (sum_j P_ij * r_j) for this factor.
 */
template <typename T, typename P, size_t E>
__device__ T compute_chi2(const T *residuals, const P *pmat,
                          const size_t factor_id) {
  const T *r = residuals + factor_id * E;
  const P *p = pmat + factor_id * E * E;

  // Load this factor's residual once; every row of P reuses it.
  T res[E];
#pragma unroll
  for (size_t i = 0; i < E; i++) {
    res[i] = r[i];
  }

  // Same accumulation order as a two-pass (P * r, then r^T * (P * r))
  // evaluation: each row sums over j in order, rows are summed over i in order.
  T value = 0;
#pragma unroll
  for (size_t i = 0; i < E; i++) {
    T row = 0; // (P * r)_i
#pragma unroll
    for (size_t j = 0; j < E; j++) {
      row += static_cast<T>(p[i * E + j]) * res[j];
    }
    value += row * res[i];
  }

  return value;
}

/**
 * Evaluates the loss-weighted chi2 and its derivative for one factor per
 * thread.
 *
 * Launch: 1D grid sized to cover `num_threads`; the thread index comes from
 * the project helper get_thread_id(), and threads with idx >= num_threads
 * exit immediately. Uses no shared memory.
 *
 * @param chi2 Output: loss[idx].loss(raw_chi2) for each factor idx.
 * @param chi2_derivative Output: loss[idx].loss_derivative(raw_chi2).
 * @param residuals E residual values per factor.
 * @param num_threads Number of factors to evaluate.
 * @param pmat E*E precision-matrix values per factor.
 * @param loss One loss object per factor.
 */
template <typename T, typename S, size_t E, typename L>
__global__ void
compute_chi2_kernel(T *chi2, S *chi2_derivative, const T *residuals,
                    const size_t num_threads, const S *pmat, const L *loss) {
  const size_t idx = get_thread_id();

  if (idx >= num_threads) {
    return;
  }

  const T raw_chi2 = compute_chi2<T, S, E>(residuals, pmat, idx);
  const L &factor_loss = loss[idx];
  chi2[idx] = factor_loss.loss(raw_chi2);
  chi2_derivative[idx] = factor_loss.loss_derivative(raw_chi2);
}

/**
 * Enqueues evaluation of the loss-weighted chi2 and its derivative for the
 * first f->active_count() factors of `f`.
 *
 * f->chi2_vec is first filled with zeros (so entries past the active count
 * read as 0), then compute_chi2_kernel writes f->chi2_vec and
 * f->chi2_derivative for each active factor. Both operations are enqueued on
 * `stream` and the function returns without synchronizing: order later work
 * on the same stream, or synchronize, before reading the results.
 *
 * @param f Factor container whose residuals, precision_matrices and loss
 *          buffers are expected to already be on the device.
 * @param stream Stream to enqueue on; defaults to stream 0 (previous
 *               behavior).
 */
template <typename T, typename S, typename F>
void compute_chi2_async(F *f, cudaStream_t stream = 0) {
  const auto num_factors = f->active_count();

  thrust::fill(thrust::cuda::par_nosync.on(stream), f->chi2_vec.begin(),
               f->chi2_vec.end(), static_cast<T>(0));

  // A zero-sized grid is an invalid launch configuration and would leave a
  // pending error for the next cudaGetLastError(); skip the launch instead.
  if (num_factors == 0) {
    return;
  }

  constexpr size_t threads_per_block = 256;
  const size_t num_blocks =
      (num_factors + threads_per_block - 1) / threads_per_block;

  compute_chi2_kernel<T, S, F::error_dim>
      <<<num_blocks, threads_per_block, 0, stream>>>(
          f->chi2_vec.data().get(), f->chi2_derivative.data().get(),
          f->residuals.data().get(), num_factors,
          f->precision_matrices.data().get(), f->loss.data().get());
}

} // namespace ops

} // namespace graphite