/**
 * @brief Detects whether T declares a nested type alias named `State`.
 *
 * `has_type_alias_State<T>::value` is true when `typename T::State` names a
 * type and false otherwise (C++17 std::void_t detection idiom).
 */
template <typename, typename = void>
struct has_type_alias_State : std::false_type {};

template <typename T>
struct has_type_alias_State<T, std::void_t<typename T::State>>
    : std::true_type {};

/**
 * @brief Selects `T::State` when T declares it, otherwise `Fallback`.
 *
 * Primary template: T has no usable `State` alias, so the result is Fallback.
 */
template <typename T, typename Fallback, typename = void>
struct get_State_or {
  using type = Fallback;
};

/// Specialization chosen when `typename T::State` is well-formed.
template <typename T, typename Fallback>
struct get_State_or<T, Fallback, std::void_t<typename T::State>> {
  using type = typename T::State;
};

/// Convenience alias for `get_State_or<T, Fallback>::type`.
template <typename T, typename Fallback>
using get_State_or_t = typename get_State_or<T, Fallback>::type;
/**
 * @brief Copies the state of each active vertex into a dense device array.
 *
 * Launched by the host with a 1-D grid of at least num_vertices threads.
 * get_thread_id() is defined elsewhere and is presumably the flat global
 * thread index (confirm). Threads with vertex_id >= num_vertices, and threads
 * whose vertex is not active per is_vertex_active(), return without writing.
 *
 * @tparam T  not referenced in this body; kept for interface compatibility.
 * @param vertices      device array of num_vertices vertex pointers
 * @param dst           output array with one State slot per vertex
 * @param active_state  per-vertex flag bytes consumed by is_vertex_active()
 * @param num_vertices  number of vertices / valid entries in dst
 */
template <typename VertexType, typename State, typename Traits, typename T>
__global__ void backup_state_kernel(VertexType **vertices, State *dst,
                                    const uint8_t *active_state,
                                    const size_t num_vertices) {
  const size_t vertex_id = get_thread_id();

  if (vertex_id >= num_vertices || !is_vertex_active(active_state, vertex_id)) {
    return;
  }

  // Traits that declare a State alias provide get_state(); otherwise the
  // whole vertex object is the state (State == VertexType via get_State_or_t).
  if constexpr (has_type_alias_State<Traits>::value) {
    dst[vertex_id] = Traits::get_state(*vertices[vertex_id]);
  } else {
    dst[vertex_id] = *vertices[vertex_id];
  }
}
/**
 * @brief Writes a previously saved state back into each active vertex.
 *
 * Inverse of backup_state_kernel. It uses the same launch shape: 1-D, at
 * least num_vertices threads. get_thread_id() is presumably the flat global
 * thread index (confirm). Out-of-range and inactive vertices are skipped.
 *
 * @tparam T  not referenced in this body; kept for interface compatibility.
 * @param vertices      device array of num_vertices vertex pointers
 * @param src           saved states, one per vertex (must hold num_vertices)
 * @param active_state  per-vertex flag bytes consumed by is_vertex_active()
 * @param num_vertices  number of vertices
 */
template <typename VertexType, typename State, typename Traits, typename T>
__global__ void set_state_kernel(VertexType **vertices, const State *src,
                                 const uint8_t *active_state,
                                 const size_t num_vertices) {
  const size_t vertex_id = get_thread_id();

  if (vertex_id >= num_vertices || !is_vertex_active(active_state, vertex_id)) {
    return;
  }

  // Traits with a State alias know how to apply it; otherwise overwrite the
  // whole vertex object.
  if constexpr (has_type_alias_State<Traits>::value) {
    Traits::set_state(*vertices[vertex_id], src[vertex_id]);
  } else {
    *vertices[vertex_id] = src[vertex_id];
  }
}
/// Scalar type of the block-diagonal buffers (the InvP* parameters below):
/// T when is_low_precision<S>::value holds, otherwise S.
using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;

// --- Solver operations; each takes the CUDA stream to run on. ---

/// Applies an update to the vertex parameters asynchronously.
virtual void apply_update_async(const T *delta_x,
                                T *jacobian_scales,
                                cudaStream_t stream) = 0;

/// Adds damping `mu` to the Hessian block diagonal asynchronously.
virtual void augment_block_diagonal_async(InvP *block_diagonal,
                                          InvP *scalar_diagonal,
                                          T mu,
                                          cudaStream_t stream) = 0;

/// Applies the block Jacobi preconditioner: z computed from r.
virtual void apply_block_jacobi(T *z,
                                const T *r,
                                InvP *block_diagonal,
                                cudaStream_t stream) = 0;

// --- Size queries. ---

/// Dimension of one vertex's parameterization.
virtual size_t dimension() const = 0;
/// Number of vertices, fixed ones included.
virtual size_t count() const = 0;

// --- State backup / restore and device preparation. ---

virtual void backup_parameters_async() = 0;
virtual void restore_parameters_async() = 0;
virtual void to_device() = 0;

// --- Id mapping, Hessian layout and per-vertex flags. ---

/// Global vertex id -> local index map.
virtual const std::unordered_map<size_t, size_t> &get_global_map() const = 0;
virtual const size_t *get_hessian_ids() const = 0;
virtual void set_hessian_column(const size_t global_id,
                                const size_t hessian_column,
                                const size_t block_index) = 0;
virtual bool is_fixed(const size_t id) const = 0;
virtual bool is_active(const size_t id) const = 0;
virtual uint8_t *get_active_state() const = 0;
virtual const size_t *get_block_ids() const = 0;
// Template head of the concrete vertex descriptor (its class declaration line
// is not present in this listing). T and S are forwarded to the ops:: helpers;
// VTraits supplies the vertex type and, optionally, a State alias.
103template <
typename T,
typename S,
typename VTraits>
// Scalar type of the block-diagonal buffers: T when is_low_precision<S>::value
// holds, otherwise S.
106 using InvP = std::conditional_t<is_low_precision<S>::value, T, S>;
108 using Traits = VTraits;
110 using VertexType =
typename Traits::Vertex;
// What backup/restore store per vertex: Traits::State when Traits declares
// it, otherwise the whole VertexType.
111 using State = get_State_or_t<Traits, VertexType>;
// Vertex pointers: device-side array (handed to kernels) and host-side list
// (count() returns its size).
114 thrust::device_vector<VertexType *> x_device;
115 thrust::host_vector<VertexType *> x_host;
// One saved State per vertex, filled by backup_parameters_async().
116 thrust::device_vector<State> backup_state;
// Global id -> local index, and local index -> global id.
120 std::unordered_map<size_t, size_t> global_to_local_map;
121 std::vector<size_t> local_to_global_map;
// Host-side Hessian column per vertex; to_device() copies it into hessian_ids.
122 thrust::host_vector<size_t> local_to_hessian_offsets;
123 thrust::device_vector<size_t> hessian_ids;
// Per-vertex block index, and per-vertex flag byte (bit 0 = fixed, as read by
// is_fixed()).
124 managed_vector<size_t> block_ids;
125 managed_vector<uint8_t> active_state;
// Per-vertex parameter dimension, taken from Traits::dimension.
127 static constexpr size_t dim = Traits::dimension;
/**
 * @brief Applies an update to the vertex parameters asynchronously.
 *
 * Delegates to ops::apply_update<T, S> on the given stream.
 */
void apply_update_async(const T *delta_x, T *jacobian_scales,
                        cudaStream_t stream) override {
  ops::apply_update<T, S>(this, delta_x, jacobian_scales, stream);
}

/**
 * @brief Adds damping to the Hessian block diagonal asynchronously.
 *
 * Delegates to ops::augment_block_diagonal<T, S> on the given stream.
 */
void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal,
                                  T mu, cudaStream_t stream) override {
  ops::augment_block_diagonal<T, S>(this, block_diagonal, scalar_diagonal, mu,
                                    stream);
}

/**
 * @brief Applies the block Jacobi preconditioner asynchronously.
 *
 * Delegates to ops::apply_block_jacobi<T, S> on the given stream.
 */
void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal,
                        cudaStream_t stream) override {
  ops::apply_block_jacobi<T, S>(this, z, r, block_diagonal, stream);
}
// Part of to_device(): publishes the host-side Hessian column offsets into the
// device vector returned by get_hessian_ids(). The remaining statements of
// to_device() are not visible in this listing.
162 hessian_ids = local_to_hessian_offsets;
/// Returns the device-side array of vertex pointers stored in x_device
/// (the same storage the backup/restore kernels receive).
VertexType **vertices() {
  VertexType **const device_array = x_device.data().get();
  return device_array;
}
/**
 * @brief Backs up the state of the vertices asynchronously.
 *
 * Resizes backup_state to one slot per vertex, then launches
 * backup_state_kernel on the default stream (no stream is passed) with one
 * thread per vertex. Only active vertices are copied.
 *
 * With zero vertices no kernel is launched: a grid with zero blocks is an
 * invalid launch configuration.
 *
 * NOTE(review): the kernel indexes x_device with indices < count(), and
 * count() is x_host.size(). This assumes x_device mirrors x_host — confirm.
 */
virtual void backup_parameters_async() override {
  VertexType **vertex_ptrs = x_device.data().get();
  const size_t num_vertices = count();

  backup_state.resize(num_vertices);

  if (num_vertices == 0) {
    return;  // <<<0, N>>> would fail with cudaErrorInvalidConfiguration
  }

  // Keep sizes in size_t end to end; the kernel takes num_vertices as size_t.
  constexpr unsigned int block_size = 256;
  const auto num_blocks =
      static_cast<unsigned int>((num_vertices + block_size - 1) / block_size);
  backup_state_kernel<VertexType, State, Traits, T>
      <<<num_blocks, block_size>>>(vertex_ptrs, backup_state.data().get(),
                                   active_state.data().get(), num_vertices);
}
/**
 * @brief Restores the state of the vertices from the backup asynchronously.
 *
 * Launches set_state_kernel on the default stream (no stream is passed), one
 * thread per vertex. It writes backup_state[i] back into every active vertex i.
 *
 * Nothing is launched when there are no vertices, because a zero-block grid
 * is an invalid configuration.
 *
 * The kernel reads backup_state[i] for every active i < count(). If
 * backup_state holds fewer entries, that read would go out of bounds. This
 * happens when no backup was taken, or vertices were added after the last
 * backup. In that case an error is printed to std::cerr and nothing is
 * restored.
 */
virtual void restore_parameters_async() override {
  VertexType **vertex_ptrs = x_device.data().get();
  const size_t num_vertices = count();

  if (num_vertices == 0) {
    return;  // <<<0, N>>> would fail with cudaErrorInvalidConfiguration
  }

  if (backup_state.size() < num_vertices) {
    std::cerr << "Cannot restore parameters: backup holds "
              << backup_state.size() << " states but there are "
              << num_vertices << " vertices." << std::endl;
    return;
  }

  constexpr unsigned int block_size = 256;
  const auto num_blocks =
      static_cast<unsigned int>((num_vertices + block_size - 1) / block_size);
  set_state_kernel<VertexType, State, Traits, T>
      <<<num_blocks, block_size>>>(vertex_ptrs, backup_state.data().get(),
                                   active_state.data().get(), num_vertices);
}
/// Number of vertices currently registered, fixed ones included
/// (the size of the host-side pointer list x_host).
virtual size_t count() const override {
  const size_t registered = x_host.size();
  return registered;
}
// Body of reserve(size): pre-sizes the capacity of the host-side and managed
// per-vertex containers that add_vertex() appends to. This presumably avoids
// regrowth during graph construction, assuming managed_vector::reserve
// behaves like std::vector::reserve (confirm).
// NOTE(review): the device vectors (x_device, backup_state, hessian_ids) are
// not reserved in the visible lines.
215 x_host.reserve(size);
216 global_to_local_map.reserve(size);
217 local_to_global_map.reserve(size);
218 local_to_hessian_offsets.reserve(size);
219 block_ids.reserve(size);
220 active_state.reserve(size);
// Body of remove_vertex(id): swap-with-last removal. The last vertex is moved
// into the removed vertex's slot, then every per-vertex container is popped.
// Unknown ids print an error. The statement that exits on that path is not
// visible in this listing — presumably a return (confirm).
232 if (global_to_local_map.find(
id) == global_to_local_map.end()) {
233 std::cerr <<
"Vertex with id " <<
id <<
" not found." << std::endl;
237 const auto local_id = global_to_local_map[id];
238 const auto last_index = x_host.size() - 1;
// Move the last vertex pointer into the freed slot.
241 std::swap(x_host[local_id], x_host[last_index]);
// add_vertex() always pushes an offset, so this guard appears redundant while
// the containers stay aligned; the pop_back below is unguarded.
243 if (local_to_hessian_offsets.size() > 0) {
244 std::swap(local_to_hessian_offsets[local_id],
245 local_to_hessian_offsets[last_index]);
// Re-point both id maps at the moved vertex. If id was itself the last
// vertex, the erase below removes the entry written here, which is intended.
249 const auto last_global_id = local_to_global_map[last_index];
250 global_to_local_map[last_global_id] = local_id;
251 local_to_global_map[local_id] = last_global_id;
254 active_state[local_id] = active_state[last_index];
// Drop the (now duplicated) last entries.
// NOTE(review): no move of block_ids[last_index] into block_ids[local_id] is
// visible before block_ids.pop_back(). Confirm it happens on a hidden line;
// otherwise the moved vertex keeps the removed vertex's block index.
257 active_state.pop_back();
261 local_to_hessian_offsets.pop_back();
262 global_to_local_map.erase(
id);
263 local_to_global_map.pop_back();
264 block_ids.pop_back();
// Body of replace_vertex(id, vertex): swaps the host-side pointer stored for
// an existing id. Unknown ids print an error. The statement that exits on that
// path is not visible here — presumably a return (confirm).
273 if (global_to_local_map.find(
id) == global_to_local_map.end()) {
274 std::cerr <<
"Vertex with id " <<
id <<
" not found." << std::endl;
278 const auto local_id = global_to_local_map[id];
// Only x_host is updated here. Kernels read x_device, so presumably
// to_device() must run again before the new pointer is seen on the GPU
// (confirm).
279 x_host[local_id] = vertex;
// Tail of add_vertex(id, vertex, fixed = false): appends the vertex and one
// entry to each per-vertex container so they all stay index-aligned.
290 const bool fixed =
false) {
291 x_host.push_back(vertex);
292 const auto local_id = x_host.size() - 1;
// NOTE(review): insert() leaves an existing mapping untouched, but x_host and
// the other vectors still grow. A duplicate id is not rejected in the visible
// lines and would desynchronize the maps — confirm callers never do this.
293 global_to_local_map.insert({id, local_id});
294 local_to_global_map.push_back(
id);
// Hessian column and block index start at 0 until set_hessian_column().
295 local_to_hessian_offsets.push_back(0);
296 block_ids.push_back(0);
// Flag byte starts as just the fixed bit (bit 0, as read by is_fixed()).
299 active_state.push_back(
static_cast<uint8_t
>(fixed));
// Body of set_fixed(id, fixed): throws std::out_of_range for an unknown id
// (unordered_map::at).
308 const auto local_id = global_to_local_map.at(
id);
// Overwrites the whole flag byte with 0 or 1.
// NOTE(review): is_fixed() masks only bit 0, which suggests other bits may
// exist. If is_vertex_active() reads another bit, this assignment clears it —
// confirm whether a bitwise update was intended.
310 active_state[local_id] =
static_cast<uint8_t
>(fixed);
/**
 * @brief Checks whether a vertex is fixed (bit 0 of its flag byte).
 * @throws std::out_of_range if id is not registered (unordered_map::at).
 */
bool is_fixed(const size_t id) const override {
  const auto local_id = global_to_local_map.at(id);
  return (active_state[local_id] & 0x1) > 0;
}

/**
 * @brief Checks whether a vertex is active, as decided by is_vertex_active().
 * @throws std::out_of_range if id is not registered (unordered_map::at).
 */
bool is_active(const size_t id) const override {
  const auto local_id = global_to_local_map.at(id);
  return is_vertex_active(active_state.data().get(), local_id);
}

/// Raw pointer to the per-vertex flag bytes (one uint8_t per vertex).
uint8_t *get_active_state() const override {
  return active_state.data().get();
}

/// Raw pointer to the per-vertex block indices set by set_hessian_column().
const size_t *get_block_ids() const override {
  return block_ids.data().get();
}
// Body of get_vertex(id): returns the host-side vertex pointer for a
// registered id. For an unknown id it prints an error. The value returned on
// that path is not visible here — presumably nullptr (confirm).
356 auto it = global_to_local_map.find(
id);
357 if (it != global_to_local_map.end()) {
358 return x_host[it->second];
360 std::cerr <<
"Vertex with id " <<
id <<
" not found." << std::endl;
/// Checks whether a vertex with the given global id is registered.
bool exists(const size_t id) const {
  return global_to_local_map.find(id) != global_to_local_map.end();
}

/// Global vertex id -> local index map.
const std::unordered_map<size_t, size_t> &get_global_map() const override {
  return global_to_local_map;
}

/// Device pointer to the per-vertex Hessian columns copied in by to_device().
const size_t *get_hessian_ids() const override {
  return hessian_ids.data().get();
}
/**
 * @brief Sets the Hessian column and block index for a vertex.
 *
 * Writes the host-side offset table; the device copy (hessian_ids) is
 * refreshed by to_device().
 *
 * @throws std::out_of_range if global_id is not registered (unordered_map::at).
 */
void set_hessian_column(const size_t global_id, const size_t hessian_column,
                        const size_t block_index) {
  const auto local_id = global_to_local_map.at(global_id);
  local_to_hessian_offsets[local_id] = hessian_column;
  block_ids[local_id] = block_index;
}
// Body of clear(): empties the backup and the id/offset/flag containers.
// NOTE(review): clearing x_host, x_device, hessian_ids and block_ids is not
// among the visible lines. count() reads x_host, so confirm they are cleared
// on hidden lines.
407 backup_state.clear();
409 global_to_local_map.clear();
410 local_to_global_map.clear();
411 local_to_hessian_offsets.clear();
414 active_state.clear();
Represents a collection of optimizable variables to be processed together on the GPU.
Definition vertex.hpp:104
void clear()
Clears all vertices and resets the descriptor to an empty state.
Definition vertex.hpp:404
virtual void to_device() override
Prepares descriptor for GPU processing.
Definition vertex.hpp:160
bool exists(const size_t id) const
Checks if a vertex exists in the descriptor.
Definition vertex.hpp:370
virtual size_t count() const override
Returns the number of vertices in the descriptor (including fixed vertices).
Definition vertex.hpp:207
VertexType * get_vertex(const size_t id)
Retrieves a pointer to the vertex associated with the given ID.
Definition vertex.hpp:355
void apply_block_jacobi(T *z, const T *r, InvP *block_diagonal, cudaStream_t stream) override
Applies the block Jacobi preconditioner asynchronously.
Definition vertex.hpp:152
void replace_vertex(const size_t id, VertexType *vertex)
Replaces a vertex pointer in the descriptor for the given ID.
Definition vertex.hpp:272
void remove_vertex(const size_t id)
Removes a vertex from the descriptor by its ID.
Definition vertex.hpp:227
virtual void backup_parameters_async() override
Backs up the state of the vertices asynchronously.
Definition vertex.hpp:174
VertexType ** vertices()
Returns a pointer to the array of vertex pointers for the descriptor.
Definition vertex.hpp:169
virtual void restore_parameters_async() override
Restores the state of the vertices from the backup asynchronously.
Definition vertex.hpp:191
uint8_t * get_active_state() const override
Retrieves a pointer to the active state array for the vertices.
Definition vertex.hpp:337
void augment_block_diagonal_async(InvP *block_diagonal, InvP *scalar_diagonal, T mu, cudaStream_t stream) override
Adds damping to the Hessian block diagonal asynchronously.
Definition vertex.hpp:143
void add_vertex(const size_t id, VertexType *vertex, const bool fixed=false)
Adds a vertex to the descriptor.
Definition vertex.hpp:289
void set_hessian_column(const size_t global_id, const size_t hessian_column, const size_t block_index)
Sets the Hessian column and block index for a vertex.
Definition vertex.hpp:394
size_t dimension() const override
Returns the dimension of the vertex parameterization.
Definition vertex.hpp:382
void reserve(size_t size)
Reserves memory for the specified number of vertices. You should call this before constructing the graph.
Definition vertex.hpp:214
bool is_fixed(const size_t id) const override
Checks if a vertex is fixed.
Definition vertex.hpp:318
void apply_update_async(const T *delta_x, T *jacobian_scales, cudaStream_t stream) override
Applies an update to the vertex parameters asynchronously.
Definition vertex.hpp:135
const size_t * get_block_ids() const override
Retrieves a pointer to the block IDs array for the vertices.
Definition vertex.hpp:345
void set_fixed(const size_t id, const bool fixed)
Sets the fixed state of a vertex.
Definition vertex.hpp:307
bool is_active(const size_t id) const override
Checks if a vertex is active.
Definition vertex.hpp:328