/// Applies an update step to every active vertex, one thread per vertex.
///
/// Template parameters:
///   T          - scalar type of the update and scale buffers.
///   S          - not referenced in this kernel body (kept for the template
///                interface used by callers).
///   Descriptor - must provide a compile-time `dim` (number of update entries
///                per vertex) and `Traits::update(vertex, T *)`.
///   V          - element type pointed to by `vertices`.
///
/// Each thread takes index `get_thread_id()` (presumably the flat global
/// thread index — TODO confirm). It returns early when that index is
/// >= `num_threads` or when `is_vertex_active(active_state, index)` is false.
/// Otherwise it reads `Descriptor::dim` entries from both `delta_x` and
/// `jacobian_scales`, starting at offset `hessian_ids[vertex_id]`. It
/// multiplies the two elementwise and passes the product to
/// `Descriptor::Traits::update` for that vertex.
///
/// Launch: 1D. The bounds guard makes a grid with more than `num_threads`
/// threads safe. No shared memory is used. All pointers must be
/// device-accessible.
template <typename T, typename S, typename Descriptor, typename V>
__global__ void apply_update_kernel(V **vertices, const T *delta_x,
                                    const T *jacobian_scales,
                                    const size_t *hessian_ids,
                                    const uint8_t *active_state,
                                    const size_t num_threads) {
  const size_t vertex_id = get_thread_id();
  if (vertex_id >= num_threads || !is_vertex_active(active_state, vertex_id)) {
    return;
  }

  // Offset of this vertex's entries inside the flattened update/scale buffers.
  const size_t offset = hessian_ids[vertex_id];
  const T *delta = delta_x + offset;
  const T *scales = jacobian_scales + offset;

  // Plain C array rather than std::array: std::array's accessors are
  // constexpr __host__ functions, which nvcc will not call from device code
  // unless --expt-relaxed-constexpr is enabled.
  T scaled_delta[Descriptor::dim];
#pragma unroll
  for (size_t i = 0; i < Descriptor::dim; ++i) {
    scaled_delta[i] = delta[i] * scales[i];
  }

  Descriptor::Traits::update(*vertices[vertex_id], scaled_delta);
}
/// Launches apply_update_kernel over all vertices of `v` on `stream`.
///
/// Launches one thread per vertex (`v->count()` threads) in 1D blocks of
/// 256 threads. The launch is asynchronous with respect to the host. This
/// function does not synchronize `stream`.
///
/// Template-argument mapping for the kernel:
///   - `V` is passed as the kernel's `Descriptor`, so it must expose `dim`
///     and `Traits::update`.
///   - `V::VertexType` is passed as the kernel's `V`, i.e. the pointee type
///     of the array returned by `v->vertices()`.
///
/// @param v               vertex container providing vertices(), count(),
///                        get_hessian_ids() and get_active_state().
/// @param delta_x         device-accessible update buffer, indexed through the
///                        per-vertex hessian ids.
/// @param jacobian_scales device-accessible scale buffer, indexed like
///                        delta_x (only read by the kernel).
/// @param stream          stream the kernel is enqueued on.
template <typename T, typename S, typename V>
void apply_update(V *v, const T *delta_x, T *jacobian_scales,
                  cudaStream_t stream) {
  const size_t num_threads = v->count();
  // A zero-block grid is an invalid launch configuration
  // (cudaErrorInvalidConfiguration). With no vertices there is nothing to do.
  if (num_threads == 0) {
    return;
  }

  constexpr unsigned int threads_per_block = 256;
  // Ceil-divide so the last, partially filled block is still launched.
  const unsigned int num_blocks = static_cast<unsigned int>(
      (num_threads + threads_per_block - 1) / threads_per_block);

  apply_update_kernel<T, S, V, typename V::VertexType>
      <<<num_blocks, threads_per_block, 0, stream>>>(
          v->vertices(), delta_x, jacobian_scales, v->get_hessian_ids(),
          v->get_active_state(), num_threads);
  // NOTE(review): launch errors are not checked here. Consider calling
  // cudaGetLastError() (or the project's error-check macro) after the launch.
}