// Dual number `real + dual*e` (with e*e == 0), usable on host and device.
// The arithmetic and math functions below propagate `dual` using derivative
// rules (product rule in operator*, d sin = cos, d exp = exp, ...), so `real`
// carries a value and `dual` carries its derivative (forward-mode autodiff).
// T is the value type; D is the derivative type and may differ from T.
// NOTE(review): this copy of the file has its original line numbers fused into
// the code (the leading "8" and "12" below) and several lines dropped,
// including the declarations of the `real`/`dual` members; restore from the
// pristine source before building.
8template <
typename T,
typename D>
struct Dual {
// Shorthand for this instantiation; used as the operand and return type of
// every operator and math function in the struct.
12 using DT = Dual<T, D>;
// Builds a dual number from an explicit value part and derivative part.
__host__ __device__ Dual(T value, D derivative) : real(value), dual(derivative) {}

// Default state: value 0 and derivative 0.
__host__ __device__ Dual() : real(0), dual(0) {}

// Lifts a plain value into a dual number with zero derivative (a constant).
// Deliberately not `explicit`: a T converts implicitly wherever a Dual is expected.
__host__ __device__ Dual(T value) : real(value), dual(0) {}
// Sum of two dual numbers: values add and derivatives add independently.
__host__ __device__ DT operator+(const DT &other) const {
    const T sum_real = real + other.real;
    const D sum_dual = dual + other.dual;
    return DT(sum_real, sum_dual);
}
// Difference of two dual numbers: values and derivatives subtract independently.
__host__ __device__ DT operator-(const DT &other) const {
    const T diff_real = real - other.real;
    const D diff_dual = dual - other.dual;
    return DT(diff_real, diff_dual);
}
// Unary negation: flips the sign of both the value and the derivative.
__host__ __device__ DT operator-() const {
    DT negated(-real, -dual);
    return negated;
}
// Product via the product rule: (a + a'e)(b + b'e) = ab + (ab' + a'b)e,
// the e*e term vanishing by definition of dual numbers.
__host__ __device__ DT operator*(const DT &other) const {
    const T product = real * other.real;
    const D derivative = real * other.dual + dual * other.real;
    return DT(product, derivative);
}
// Quotient of two dual numbers via the quotient rule:
//   (a + a'e) / (b + b'e) = a/b + ((a'b - ab') / b^2) e
//
// The value part is computed directly as a / b. The earlier
// (a*b) / (b*b) form is algebraically identical, but squaring the divisor
// overflows or underflows for large or small |b|. For example, in float
// b = 1e20 makes b*b == inf, so the quotient collapsed to 0. That form also
// added extra roundings. The derivative is rewritten the same way as
// (a' - q*b') / b with q = a/b, which is equal to (a'b - ab') / b^2 but never
// forms b^2.
//
// Division by an exact zero value part returns +infinity in both components,
// the same sentinel as before, so existing callers see unchanged behaviour.
// NOTE(review): that sentinel ignores the numerator's sign and also yields
// +inf for 0/0. std::numeric_limits<T>::infinity() is a constexpr host
// function; calling it from device code may require nvcc's
// --expt-relaxed-constexpr. Confirm the build flags.
__host__ __device__ DT operator/(const DT &other) const {
    if (other.real == 0) {
        return DT(std::numeric_limits<T>::infinity(),
                  std::numeric_limits<T>::infinity());
    }
    const T quotient = real / other.real;
    return DT(quotient, (dual - quotient * other.dual) / other.real);
}
// In-place addition; the DT& return type suggests it updates *this and returns it.
// NOTE(review): only the signature survives in this copy of the file. The body
// and closing brace are missing, and the leading "43" is stray line-number
// residue. Restore it, presumably as
// `real += other.real; dual += other.dual; return *this;`.
43 __host__ __device__ DT &operator+=(
const DT &other) {
// In-place subtraction; the DT& return type suggests it updates *this and returns it.
// NOTE(review): only the signature survives in this copy of the file. The body
// and closing brace are missing, and the leading "49" is stray line-number
// residue. Restore it, presumably as
// `real -= other.real; dual -= other.dual; return *this;`.
49 __host__ __device__ DT &operator-=(
const DT &other) {
// In-place product rule: intended as *this = *this * other.
// The new value is staged in `new_real` so that the derivative line below still
// reads the *old* `real` (d(ab) = a*b' + a'*b).
// NOTE(review): the lines that presumably follow (`real = new_real;`,
// `return *this;` and the closing brace) are missing from this copy. The
// leading "55"/"56"/"57" tokens are stray line-number residue.
55 __host__ __device__ DT &operator*=(
const DT &other) {
56 T new_real = real * other.real;
57 dual = real * other.dual + dual * other.real;
// In-place quotient rule: intended as *this = *this / other.
// NOTE(review): leading "62".."71" tokens are stray line-number residue, and
// several lines of this method appear to be missing from this copy.
62 __host__ __device__ DT &operator/=(
const DT &other) {
// Exact-zero divisor: both components become +infinity regardless of the
// numerator's sign (also for 0/0).
// NOTE(review): infinity() is a constexpr host function; device use may need
// --expt-relaxed-constexpr.
63 if (other.real == 0) {
65 real = std::numeric_limits<T>::infinity();
66 dual = std::numeric_limits<T>::infinity();
// NOTE(review): this branch's `return *this;` and closing brace appear to be missing here.
69 T denominator = other.real * other.real;
// (real * other.real) / other.real^2 equals real / other.real, but squaring the
// divisor can overflow or underflow. Consider computing real / other.real
// directly. `new_real` is staged so the derivative line still sees the old `real`.
70 T new_real = (real * other.real) / denominator;
71 dual = (dual * other.real - real * other.dual) / denominator;
// NOTE(review): presumably followed by `real = new_real; return *this;` and the
// closing brace; those lines are missing from this copy.
// Sine with derivative propagation: d/dx sin(u) = cos(u) * u'.
__host__ __device__ friend DT sin(const DT &x) {
    const T value = std::sin(x.real);
    const D slope = x.dual * std::cos(x.real);
    return DT(value, slope);
}
// Cosine with derivative propagation: d/dx cos(u) = -sin(u) * u'.
__host__ __device__ friend DT cos(const DT &x) {
    const T value = std::cos(x.real);
    const D slope = -x.dual * std::sin(x.real);
    return DT(value, slope);
}
// Arctangent with derivative propagation: d/dx atan(u) = u' / (1 + u^2).
__host__ __device__ friend DT atan(const DT &x) {
    const auto one_plus_sq = 1 + x.real * x.real;
    const D slope = x.dual / one_plus_sq;
    return DT(std::atan(x.real), slope);
}
// Arccosine with derivative propagation: d/dx acos(u) = -u' / sqrt(1 - u^2).
// For |u| >= 1 the slope is not finite (division by zero or sqrt of a negative).
__host__ __device__ friend DT acos(const DT &x) {
    const auto radicand = 1 - x.real * x.real;
    const D slope = -x.dual / std::sqrt(radicand);
    return DT(std::acos(x.real), slope);
}
// Exponential with derivative propagation: d/dx exp(u) = exp(u) * u'.
__host__ __device__ friend DT exp(const DT &x) {
    // exp is its own derivative, so the value doubles as the slope factor.
    const T value = std::exp(x.real);
    const D slope = x.dual * value;
    return DT(value, slope);
}
// Natural logarithm with derivative propagation: d/dx log(u) = u' / u.
__host__ __device__ friend DT log(const DT &x) {
    const T value = std::log(x.real);
    const D slope = x.dual / x.real;
    return DT(value, slope);
}
// Square root with derivative propagation: d/dx sqrt(u) = u' / (2 sqrt(u)).
// For floating-point T the slope is not finite at u == 0.
__host__ __device__ friend DT sqrt(const DT &x) {
    const T root = std::sqrt(x.real);
    const D slope = x.dual / (2 * root);
    return DT(root, slope);
}
// Absolute value of a dual number; the visible code branches on the sign of the
// value part only.
// NOTE(review): the branch bodies and closing braces are missing from this copy.
// Presumably it returns `-x` (negating the derivative too) when x.real < 0 and
// `x` otherwise; confirm against the pristine source. The leading "106"/"107"
// tokens are stray line-number residue.
106 __host__ __device__
friend DT abs(
const DT &x) {
107 if (x.real < T(0.0)) {
// Strict "less than" on the value part only; derivatives never affect ordering.
__host__ __device__ bool operator<(const DT &other) const { return real < other.real; }
// Strict "greater than" on the value part only; derivatives never affect ordering.
__host__ __device__ bool operator>(const DT &other) const { return real > other.real; }
// "Less than or equal" on the value part only; derivatives never affect ordering.
__host__ __device__ bool operator<=(const DT &other) const { return real <= other.real; }
// "Greater than or equal" on the value part only; derivatives never affect ordering.
__host__ __device__ bool operator>=(const DT &other) const { return real >= other.real; }