Graphite
Loading...
Searching...
No Matches
loss.hpp
Go to the documentation of this file.
1
#pragma once

#include <cmath>
3
namespace graphite {

/**
 * Abstract interface for a loss function applied to a scalar value.
 *
 * Both methods are callable from host and device code. Calls go through
 * virtual dispatch. An object constructed on the host carries a host vtable
 * pointer, so it must not be invoked through a base pointer in device code
 * (and vice versa). Construct concrete losses in the execution space where
 * they are used.
 *
 * @tparam T Scalar type of the loss input and output.
 * @tparam E Not used by this interface; presumably the dimension of the
 *           associated error/residual — TODO confirm against callers.
 */
template <typename T, int E> class Loss {
public:
  virtual __device__ __host__ ~Loss() {}

  /// Returns the loss value for input @p x.
  virtual __device__ __host__ T loss(const T &x) const = 0;

  /// Returns the derivative of loss() with respect to @p x.
  virtual __device__ __host__ T loss_derivative(const T &x) const = 0;
};

/**
 * Identity loss: loss(x) == x and loss_derivative(x) == 1 for every input.
 */
template <typename T, int E> class DefaultLoss final : public Loss<T, E> {
public:
  __device__ __host__ DefaultLoss() {}
  // Stateless, so there is nothing to copy. The parameter is unnamed to avoid
  // an unused-parameter warning.
  __device__ __host__ DefaultLoss(const DefaultLoss &) {}

  __device__ __host__ T loss(const T &x) const override { return x; }

  __device__ __host__ T loss_derivative(const T & /*x*/) const override {
    return T(1);
  }
};

/**
 * Huber loss expressed in terms of a squared quantity x, with d = delta:
 *
 *   loss(x)            = x                        if x <= d^2
 *                      = 2 * sqrt(x) * d - d^2    otherwise
 *   loss_derivative(x) = 1                        if x <= d^2
 *                      = d / sqrt(x)              otherwise
 *
 * Both pieces evaluate to d^2 at x == d^2, so the loss is continuous there.
 * The comparison against d^2 and the sqrt suggest that x is a squared
 * residual norm — TODO confirm against callers.
 * NOTE(review): a negative delta is not rejected. The outlier branch would
 * then produce negative values.
 */
template <typename T, int E> class HuberLoss final : public Loss<T, E> {
public:
  /// Threshold on sqrt(x) separating the identity regime from the
  /// sqrt-growth (outlier) regime.
  T delta;

  /// Uses a default threshold of 100.
  __device__ __host__ HuberLoss() : delta(T(100)) {}
  __device__ __host__ HuberLoss(const HuberLoss &other) : delta(other.delta) {}

  __device__ __host__ HuberLoss(T delta) : delta(delta) {}

  __device__ __host__ T loss(const T &x) const override {
    const T delta_sq = delta * delta;
    if (x <= delta_sq) {
      return x;
    }
    // Beyond the threshold the loss grows with sqrt(x) instead of x,
    // down-weighting large values.
    return 2 * std::sqrt(x) * delta - delta_sq;
  }

  __device__ __host__ T loss_derivative(const T &x) const override {
    const T delta_sq = delta * delta;
    if (x <= delta_sq) {
      return T(1);
    }
    // d/dx (2 * sqrt(x) * d - d^2) = d / sqrt(x).
    return delta / std::sqrt(x);
  }
};

} // namespace graphite
Definition loss.hpp:15
Definition loss.hpp:27
Definition loss.hpp:6