Graphite  0.5.0
GPU-accelerated graph optimization framework
Loading...
Searching...
No Matches
state.hpp
1
2#pragma once
4
5namespace graphite {
6
/// Compile-time detector for a nested `State` type.
///
/// `has_type_alias_State<Candidate>::value` is `true` exactly when
/// `typename Candidate::State` is well-formed, and `false` otherwise.
/// The second parameter is only a SFINAE hook; callers leave it defaulted.
template <typename Candidate, typename Enable = void>
struct has_type_alias_State : std::bool_constant<false> {};

// Chosen whenever `Candidate::State` names a type: std::void_t collapses it
// to `void`, which matches the primary template's default argument.
template <typename Candidate>
struct has_type_alias_State<Candidate, std::void_t<typename Candidate::State>>
    : std::bool_constant<true> {};
13
/// Selects `Source::State` if `Source` declares such a nested type, and
/// `Default` otherwise. The result is exposed as the member alias `type`.
/// The third parameter is only a SFINAE hook; callers leave it defaulted.
template <typename Source, typename Default, typename Enable = void>
struct get_State_or {
  using type = Default;
};

// Partial specialization, viable only when `typename Source::State` is
// well-formed; when it is, it is preferred over the primary template.
template <typename Source, typename Default>
struct get_State_or<Source, Default, std::void_t<typename Source::State>> {
  using type = typename Source::State;
};

/// Shorthand for `typename get_State_or<Source, Default>::type`.
template <typename Source, typename Default>
using get_State_or_t = typename get_State_or<Source, Default>::type;
27
28namespace ops {
29
30template <typename VertexType, typename State, typename Traits, typename T>
31__global__ void backup_state_kernel(VertexType **vertices, State *dst,
32 const uint8_t *active_state,
33 const size_t num_vertices) {
34
35 const size_t vertex_id = get_thread_id();
36
37 if (vertex_id >= num_vertices || !is_vertex_active(active_state, vertex_id))
38 return;
40 dst[vertex_id] = Traits::get_state(*vertices[vertex_id]);
41 } else {
42 dst[vertex_id] = *vertices[vertex_id];
43 }
44}
45
46template <typename VertexType, typename State, typename Traits, typename T>
47__global__ void set_state_kernel(VertexType **vertices, const State *src,
48 const uint8_t *active_state,
49 const size_t num_vertices) {
50
51 const size_t vertex_id = get_thread_id();
52
53 if (vertex_id >= num_vertices || !is_vertex_active(active_state, vertex_id))
54 return;
55
56 if constexpr (has_type_alias_State<Traits>::value) {
57 Traits::set_state(*vertices[vertex_id], src[vertex_id]);
58 } else {
59 *vertices[vertex_id] = src[vertex_id];
60 }
61}
62
63} // namespace ops
64
65} // namespace graphite
The top-level namespace for Graphite.
Definition eigen_solver.cpp:4
Definition state.hpp:14
Definition state.hpp:8