Graphite
Loading...
Searching...
No Matches
dual.hpp
Go to the documentation of this file.
1
2#pragma once
3#include <cuda/std/cmath>
4#include <cuda/std/limits>
5
6namespace graphite {
7
/// Dual number `real + dual * e` with `e * e == 0`, usable on host and device.
///
/// Arithmetic on Dual values propagates a derivative alongside the value:
/// `operator*` applies the product rule, `operator/` the quotient rule, and the
/// math friends (sin, cos, exp, ...) the chain rule to the `dual` component.
///
/// @tparam T  Type of the real (value) part.
/// @tparam D  Type of the dual (derivative) part. It must be constructible from
///            `0` and from `T`, and support `+`, `-`, `T * D`, `D * T` and
///            `D / T`. NOTE(review): presumably usually the same as T — confirm
///            against callers.
template <typename T, typename D> struct Dual {
  T real; ///< Value component.
  D dual; ///< Infinitesimal (derivative) component.

  using DT = Dual<T, D>;

  /// Zero value with zero derivative.
  __host__ __device__ Dual() : real(0), dual(0) {}
  /// Explicit value and derivative.
  __host__ __device__ Dual(T real, D dual) : real(real), dual(dual) {}
  /// Constant: value `real`, zero derivative. This constructor is implicit, so
  /// plain T values convert to Dual in mixed expressions.
  __host__ __device__ Dual(T real) : real(real), dual(0) {}

  __host__ __device__ DT operator+(const DT &other) const {
    return DT(real + other.real, dual + other.dual);
  }

  __host__ __device__ DT operator-(const DT &other) const {
    return DT(real - other.real, dual - other.dual);
  }

  __host__ __device__ DT operator-() const { return DT(-real, -dual); }

  /// Product rule: (a + a'e)(b + b'e) = ab + (a b' + a' b) e.
  __host__ __device__ DT operator*(const DT &other) const {
    return DT(real * other.real, real * other.dual + dual * other.real);
  }

  /// Quotient rule: (a + a'e) / (b + b'e) = a/b + ((a' - (a/b) b') / b) e.
  ///
  /// If `other.real == 0`, returns the sentinel (+inf, +inf) regardless of the
  /// numerator's sign. For floating-point T this is not the IEEE result
  /// (e.g. 0/0 yields +inf rather than NaN).
  __host__ __device__ DT operator/(const DT &other) const {
    if (other.real == 0) {
      // Deliberate sentinel for a zero divisor. cuda::std::numeric_limits is
      // callable from device code; std::numeric_limits is host-only constexpr
      // and needs --expt-relaxed-constexpr to compile for the device.
      const T inf = cuda::std::numeric_limits<T>::infinity();
      return DT(inf, inf);
    }
    // Divide by b twice rather than by b*b once. Squaring the divisor
    // overflows (giving inf/inf = NaN) or underflows (giving x/0) long before
    // the quotient itself leaves the representable range.
    const T quotient = real / other.real;
    return DT(quotient, (dual - quotient * other.dual) / other.real);
  }

  __host__ __device__ DT &operator+=(const DT &other) {
    real += other.real;
    dual += other.dual;
    return *this;
  }

  __host__ __device__ DT &operator-=(const DT &other) {
    real -= other.real;
    dual -= other.dual;
    return *this;
  }

  __host__ __device__ DT &operator*=(const DT &other) {
    // The new dual part must use the old real part, so stash the new real
    // first. The RHS reads other.* before any write, so `x *= x` is safe.
    T new_real = real * other.real;
    dual = real * other.dual + dual * other.real;
    real = new_real;
    return *this;
  }

  /// In-place quotient. Delegates to operator/ so both share the same zero
  /// handling and numerics; it is also safe when `&other == this`.
  __host__ __device__ DT &operator/=(const DT &other) {
    *this = *this / other;
    return *this;
  }

  // Elementary functions below are hidden friends, found by argument-dependent
  // lookup on Dual arguments. Each applies the chain rule to `dual`.

  /// d/dx sin(x) = cos(x).
  __host__ __device__ friend DT sin(const DT &x) {
    return DT(std::sin(x.real), x.dual * std::cos(x.real));
  }

  /// d/dx cos(x) = -sin(x).
  __host__ __device__ friend DT cos(const DT &x) {
    return DT(std::cos(x.real), -x.dual * std::sin(x.real));
  }

  /// d/dx atan(x) = 1 / (1 + x^2).
  __host__ __device__ friend DT atan(const DT &x) {
    return DT(std::atan(x.real), x.dual / (1 + x.real * x.real));
  }

  /// d/dx acos(x) = -1 / sqrt(1 - x^2). The derivative is finite only for
  /// |x.real| < 1.
  __host__ __device__ friend DT acos(const DT &x) {
    return DT(std::acos(x.real), -x.dual / std::sqrt(1 - x.real * x.real));
  }

  /// d/dx exp(x) = exp(x). The exponential is evaluated once and reused.
  __host__ __device__ friend DT exp(const DT &x) {
    T exp_real = std::exp(x.real);
    return DT(exp_real, x.dual * exp_real);
  }

  /// d/dx log(x) = 1 / x.
  __host__ __device__ friend DT log(const DT &x) {
    return DT(std::log(x.real), x.dual / x.real);
  }

  /// d/dx sqrt(x) = 1 / (2 sqrt(x)). The derivative is not finite at
  /// x.real == 0.
  __host__ __device__ friend DT sqrt(const DT &x) {
    T sqrt_real = std::sqrt(x.real);
    return DT(sqrt_real, x.dual / (2 * sqrt_real));
  }

  /// Negates both parts when real < 0. At real == 0, x is returned unchanged,
  /// i.e. the derivative from the positive side is used.
  __host__ __device__ friend DT abs(const DT &x) {
    if (x.real < T(0.0)) {
      return -x;
    }
    return x;
  }

  // Ordering compares the real parts only; the dual parts are ignored.

  __host__ __device__ bool operator<(const DT &other) const {
    return real < other.real;
  }

  __host__ __device__ bool operator>(const DT &other) const {
    return real > other.real;
  }

  __host__ __device__ bool operator<=(const DT &other) const {
    return real <= other.real;
  }

  __host__ __device__ bool operator>=(const DT &other) const {
    return real >= other.real;
  }
};
129
130} // namespace graphite
Definition dual.hpp:8