NeuralEngine
A Game Engine with embedded Machine Learning algorithms based on Gaussian Processes.
FgAdamSolver.h
1
#pragma once

#include <functional>

#include <MachineLearning/BaseGradientOptimizationMethod.h>
15namespace NeuralEngine
16{
17 namespace MachineLearning
18 {
73 template<typename Scalar, LineSearchType LSType = MoreThuente>
74 class NE_IMPEXP AdamSolver : public BaseGradientOptimizationMethod<Scalar, LSType>
75 {
76 public:
77
87 AdamSolver(int numberOfVariables);
88
100 AdamSolver(int numberOfVariables,
101 std::function<Scalar(const af::array&, af::array&)> function);
102
111
118
126 void SetBeta1(Scalar beta1);
127
135 void SetBeta2(Scalar beta2);
136
144 void SetAlpha(Scalar alpha);
145
153 void SetEpsilon(Scalar epsilon);
154
162 void SetDecay(Scalar decay);
163
172
181
190
199
208
209 protected:
210
221 virtual bool Optimize(int* cycle = nullptr) override;
222
223 private:
224 Scalar min_step; // The minimum step length allowed in the line search.
225 Scalar max_step; // The maximum step length allowed in the line search.
226
227 Scalar sAlpha; // learning rate
228 Scalar sBeta1; // exponential decay rate for the first moment estimates (e.g. 0.9)
229 Scalar sBeta2; // exponential decay rate for the second-moment estimates (e.g. 0.999).
230 Scalar sEpsilon; // small number to prevent any division by zero in the implementation
231 Scalar sDecay;
232 Scalar delta;
233 };
234 }
235}
AdamSolver(int numberOfVariables, std::function< Scalar(const af::array &, af::array &)> function)
Creates a new instance of the Adam optimization algorithm.
Scalar GetAlpha()
Gets the learning rate.
Scalar GetBeta1()
Gets the decay rate for the first-moment estimates.
void SetDecay(Scalar decay)
Sets the initial decay rate.
void SetBeta2(Scalar beta2)
Sets the decay rate for the second-moment estimates.
void SetEpsilon(Scalar epsilon)
Sets the epsilon used to avoid division by zero.
void SetAlpha(Scalar alpha)
Sets the learning rate.
Scalar GetEpsilon()
Gets the epsilon.
AdamSolver(int numberOfVariables)
Creates a new instance of the Adam optimization algorithm.
virtual bool Optimize(int *cycle=nullptr) override
Implements the actual optimization algorithm. This method should try to minimize the objective functi...
Scalar GetBeta2()
Gets the decay rate for the second-moment estimates.
void SetBeta1(Scalar beta1)
Sets the decay rate for the first-moment estimates.
Scalar GetDecay()
Gets the initial decay rate.
AdamSolver(NonlinearObjectiveFunction< Scalar > *function)
Creates a new instance of the Adam optimization algorithm.