/**
 * Evaluates the quadratic form r^T * P * r for one factor.
 *
 * Callable from both host and device code, so the same routine can serve as a
 * CPU reference for the kernel that uses it.
 *
 * @tparam T  Residual / accumulator type.
 * @tparam P  Storage type of the matrix entries; each entry is converted to T
 *            before it is multiplied.
 * @tparam E  Number of residual components per factor.
 *
 * @param residuals  Packed residuals; factor k occupies elements
 *                   [k * E, (k + 1) * E).
 * @param pmat       Packed E x E matrices; entry (i, j) of factor k is read
 *                   from pmat[k * E * E + i * E + j].
 * @param factor_id  Index k of the factor to evaluate.
 * @return sum over i of (sum over j of P(i, j) * r[j]) * r[i].
 */
template <typename T, typename P, size_t E>
__host__ __device__ T compute_chi2(const T *residuals, const P *pmat,
                                   const size_t factor_id) {
  // Base pointers of this factor's residual vector and matrix.
  const T *r = residuals + factor_id * E;
  const P *m = pmat + factor_id * E * E;

  T value = 0;
  for (size_t i = 0; i < E; i++) {
    // Row i of (P * r). A scalar accumulator replaces the per-thread scratch
    // array while keeping the same summation order.
    T row = 0;
    for (size_t j = 0; j < E; j++) {
      row += static_cast<T>(m[i * E + j]) * r[j];
    }
    value += row * r[i];
  }
  return value;
}
/**
 * Kernel: one thread per factor. Each thread evaluates the raw chi2 of its
 * factor and stores the loss-adjusted value and the loss derivative.
 *
 * Launch layout: any configuration for which get_thread_id() yields a unique
 * index per factor; threads whose index is >= num_threads exit immediately.
 * No shared memory is used.
 *
 * @tparam T  Residual / chi2 type.
 * @tparam S  Type of the matrix entries and of the derivative output.
 * @tparam E  Number of residual components per factor.
 * @tparam L  Loss type; must provide loss(T) and loss_derivative(T) callable
 *            on a const object.
 *
 * @param chi2             Output, one entry per factor: loss[idx].loss(raw).
 * @param chi2_derivative  Output, one entry per factor:
 *                         loss[idx].loss_derivative(raw).
 * @param residuals        Packed residuals, E entries per factor.
 * @param num_threads      Number of factors to process.
 * @param pmat             Packed E x E matrices, one per factor.
 * @param loss             One loss object per factor.
 */
template <typename T, typename S, size_t E, typename L>
__global__ void compute_chi2_kernel(T *chi2, S *chi2_derivative,
                                    const T *residuals,
                                    const size_t num_threads, const S *pmat,
                                    const L *loss) {
  const size_t idx = get_thread_id();

  // The grid is rounded up to whole blocks, so trailing threads have no work.
  if (idx >= num_threads) {
    return;
  }

  const T raw_chi2 = compute_chi2<T, S, E>(residuals, pmat, idx);

  // Both outputs are derived from the same raw value and the same loss object.
  const L &factor_loss = loss[idx];
  chi2[idx] = factor_loss.loss(raw_chi2);
  chi2_derivative[idx] = factor_loss.loss_derivative(raw_chi2);
}
/**
 * Enqueues the chi2 computation for all active factors of `f` on stream 0.
 *
 * The call does not synchronize with the host: both the fill and the kernel
 * are only enqueued, so results must not be read until the stream has been
 * synchronized by the caller.
 *
 * @tparam T  Residual / chi2 type.
 * @tparam S  Type of the precision-matrix entries and of the derivative.
 * @tparam F  Factor container type. Must expose F::error_dim,
 *            active_count(), and the device vectors chi2_vec,
 *            chi2_derivative, residuals, precision_matrices and loss.
 *
 * @param f  Factor container whose chi2_vec / chi2_derivative are written.
 */
template <typename T, typename S, typename F>
void compute_chi2_async(F *f) {
  constexpr auto error_dim = F::error_dim;
  constexpr size_t threads_per_block = 256;

  // One thread per active factor.
  const size_t num_threads = static_cast<size_t>(f->active_count());

  // Zero the whole chi2 buffer: the kernel only writes the first num_threads
  // entries, so any remaining slots are left at zero.
  thrust::fill(thrust::cuda::par_nosync.on(0), f->chi2_vec.begin(),
               f->chi2_vec.end(), static_cast<T>(0));

  // A grid with zero blocks is an invalid launch configuration and would
  // leave a pending CUDA error, so skip the launch when nothing is active.
  if (num_threads == 0) {
    return;
  }

  // Ceil-divide so a partially filled last block still covers the tail.
  const size_t num_blocks =
      (num_threads + threads_per_block - 1) / threads_per_block;

  compute_chi2_kernel<T, S, error_dim>
      <<<num_blocks, threads_per_block, 0, 0>>>(
          f->chi2_vec.data().get(), f->chi2_derivative.data().get(),
          f->residuals.data().get(), num_threads,
          f->precision_matrices.data().get(), f->loss.data().get());
}