NeuralEngine
A Game Engine with embedded Machine Learning algorithms based on Gaussian Processes.
FgAEPSparseDGPSSM.h
1
11#pragma once
12
13#include <MachineLearning/FgSparseDeepGPSSMBaseModel.h>
14
15namespace NeuralEngine
16{
17 namespace MachineLearning
18 {
19 namespace GPModels
20 {
21 namespace AEP
22 {
23
65 template<typename Scalar>
66 class NE_IMPEXP SDGPSSM : public SparseDeepGPSSMBaseModel<Scalar>
67 {
68 public:
80 SDGPSSM(const af::array& Y, int latentDimension, HiddenLayerDescription description, Scalar alpha = 1.0,
81 Scalar priorMean = 0.0, Scalar priorVariance = 1.0, af::array xControl = af::array(),
82 PropagationMode probMode = PropagationMode::MomentMatching, LogLikType lType = LogLikType::Gaussian, XInit emethod = XInit::pca);
83
94 SDGPSSM(const af::array& Y, int latentDimension, std::vector<HiddenLayerDescription> descriptions, Scalar alpha = 1.0,
95 Scalar priorMean = 0.0, Scalar priorVariance = 1.0, af::array xControl = af::array(),
96 PropagationMode probMode = PropagationMode::MomentMatching, LogLikType lType = LogLikType::Gaussian, XInit emethod = XInit::pca);
97
105
111 virtual ~SDGPSSM();
112
123 virtual Scalar Function(const af::array& x, af::array& outGradient) override;
124
125 protected:
126
137 virtual void CavityLatents(af::array& mcav, af::array& vcav, af::array& cav1, af::array& cav2);
138
156 Scalar ComputeTiltedTransition(const af::array& mprob, const af::array& vprob, const af::array& mcav_t1, const af::array& vcav_t1,
157 Scalar scaleLogZDyn, af::array& dlogZ_dmProb, af::array& dlogZ_dvProb, af::array& dlogZ_dmt, af::array& dlogZ_dvt, Scalar& dlogZ_sn);
158
166 virtual af::array PosteriorGradientLatents();
167
178 virtual af::array CavityGradientLatents(const af::array& cav1, const af::array& cav2);
179
196 virtual af::array LogZGradientLatents(const af::array& cav1, const af::array& cav2, const af::array& dmcav_up, const af::array& dvcav_up,
197 const af::array& dmcav_prev, const af::array& dvcav_prev, const af::array& dmcav_next, const af::array& dvcav_next);
198
207
216
225
226 private:
228
229 friend class boost::serialization::access;
230
231 template<class Archive>
232 void serialize(Archive& ar, unsigned int version)
233 {
234 ar& boost::serialization::base_object<SparseDeepGPSSMBaseModel<Scalar>>(*this);
235 //ar& boost::serialization::make_nvp("SparseDeepGPLVMBaseModel", boost::serialization::base_object<SparseDeepGPLVMBaseModel<Scalar>>(*this));
236 ar& BOOST_SERIALIZATION_NVP(dAlpha);
237 }
238 };
239 }
240 }
241 }
242}
Sparse deep GPSSM via Approximate Expectation Propagation (AEP).
SDGPSSM(const af::array &Y, int latentDimension, std::vector< HiddenLayerDescription > descriptions, Scalar alpha=1.0, Scalar priorMean=0.0, Scalar priorVariance=1.0, af::array xControl=af::array(), PropagationMode probMode=PropagationMode::MomentMatching, LogLikType lType=LogLikType::Gaussian, XInit emethod=XInit::pca)
Constructor.
virtual Scalar ComputePhiPosteriorLatents()
Calculates the phi posterior.
virtual af::array LogZGradientLatents(const af::array &cav1, const af::array &cav2, const af::array &dmcav_up, const af::array &dvcav_up, const af::array &dmcav_prev, const af::array &dvcav_prev, const af::array &dmcav_next, const af::array &dvcav_next)
Gradient of log Z with respect to the latent variables.
virtual af::array PosteriorGradientLatents()
Gradient of the posterior term with respect to the latent variables.
Scalar ComputeTiltedTransition(const af::array &mprob, const af::array &vprob, const af::array &mcav_t1, const af::array &vcav_t1, Scalar scaleLogZDyn, af::array &dlogZ_dmProb, af::array &dlogZ_dvProb, af::array &dlogZ_dmt, af::array &dlogZ_dvt, Scalar &dlogZ_sn)
Calculates the tilted transition.
virtual Scalar Function(const af::array &x, af::array &outGradient) override
Cost function for the given parameter inputs.
virtual void CavityLatents(af::array &mcav, af::array &vcav, af::array &cav1, af::array &cav2)
Computes the cavity distribution.
SDGPSSM(const af::array &Y, int latentDimension, HiddenLayerDescription description, Scalar alpha=1.0, Scalar priorMean=0.0, Scalar priorVariance=1.0, af::array xControl=af::array(), PropagationMode probMode=PropagationMode::MomentMatching, LogLikType lType=LogLikType::Gaussian, XInit emethod=XInit::pca)
Constructor.
virtual Scalar ComputePhiCavityLatents()
Calculates the phi cavity.
virtual af::array CavityGradientLatents(const af::array &cav1, const af::array &cav2)
Gradient of the cavity term with respect to the latent variables.
virtual Scalar ComputePhiPriorLatents()
Calculates the phi prior.
Base class with abstract and basic function definitions. All deep GP models will be derived from this class.