/**
File:		MachineLearning/GPModels/Models/GPModels/FgSparseGPSSMBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2020 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgSparseGPSSMBaseModel.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class SparseGPSSMBaseModel<double>;
	template class SparseGPSSMBaseModel<float>;

	template<typename Scalar>
	SparseGPSSMBaseModel<Scalar>::SparseGPSSMBaseModel(const af::array& Y, int latentDimension, int numInducing, Scalar priorMean,
		Scalar priorVariance, af::array& xControl, PropagationMode probMode, LogLikType lType, bool GPemission, bool controlToEmiss, XInit emethod)
		: GPSSBaseModel(Y, latentDimension, priorMean, priorVariance, xControl, probMode, lType, GPemission, controlToEmiss, emethod),
		ik(numInducing), dynLayer(nullptr), gpEmissLayer(nullptr), gaussEmissLayer(nullptr)
	{
	}

	template<typename Scalar>
	SparseGPSSMBaseModel<Scalar>::~SparseGPSSMBaseModel()
	{
		if (dynLayer) delete dynLayer;
		if (gpEmissLayer) delete gpEmissLayer;
		if (gaussEmissLayer) delete gaussEmissLayer;
	}

	template<typename Scalar>
	SparseGPSSMBaseModel<Scalar>::SparseGPSSMBaseModel()
		: GPSSBaseModel(), ik(0), dynLayer(nullptr), gpEmissLayer(nullptr), gaussEmissLayer(nullptr)
	{
	}

	template<typename Scalar>
	bool SparseGPSSMBaseModel<Scalar>::Init()
	{
		af::array mx;
		bInit = GPSSBaseModel::Init(mx);

		// Init dynamic posterior
		af::array x = mx(af::seq(0, iN - 2), af::span);
		af::array y = mx(af::seq(1, iN - 1), af::span);

		if (iDControlDyn > 0) x = CommonUtil<Scalar>::Join(x, afXControl(af::seq(0, iN - 2), af::span), 1);

		std::cout << "\nInit dynamic posterior through SGPR\n===================================\n" << std::endl;
		AEP::SGPR<Scalar>* model1 = new AEP::SGPR<Scalar>(y, x, ik, 0.1);
		//model1->FixKernelParameters(true);
		//model1->FixInducing(true);
		model1->Optimise(OptimizerType::ADAM, 0.0, false, 100, iBatchSize);
		//model1->FixKernelParameters(false);
		//model1->FixInducing(false);
		dynLayer->SetParameters(model1->GetGPLayer()->GetParameters());
		delete model1;

		// Init emission posterior
		if (iDControlEmiss > 0) x = CommonUtil<Scalar>::Join(mx, afXControl, 1);
		else x = mx;

		if (bGPemission)
		{
			std::cout << "\nInit emission posterior through SGPR\n===========================\n" << std::endl;
			AEP::SGPR<Scalar>* model2 = new AEP::SGPR<Scalar>(afY, x, ik, 0.1);
			//model2->FixKernelParameters(true);
			//model2->FixInducing(true);
			model2->Optimise(OptimizerType::ADAM, 0.0, false, 100, iBatchSize);
			//model2->FixKernelParameters(false);
			//model2->FixInducing(false);
			gpEmissLayer->SetParameters(model2->GetGPLayer()->GetParameters());
			delete model2;

			std::cout << "done.\n\nOptimizing SGPSSM\n================\n" << std::endl;
		}
		else
			gaussEmissLayer->InitParameters();
		
		/*dynLayer->InitParameters();
		gaussEmissLayer->InitParameters();*/

		UpdateParametersInternal();

		return bInit;
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		LogAssert(testInputs.dims(1) - iDControlDyn - iq == 0, "Control points not found");

		GPBaseModel<Scalar>::PredictF(testInputs, mf, vf);

		dynLayer->ForwardPredictionPost(&testInputs, nullptr, mf, vf);
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::PredictY(const af::array& testInputs, af::array& my, af::array& vy)
	{
		af::array mf, vf, mg, vg;
		PredictF(testInputs, mf, vf);
		if (iDControlEmiss > 0)
		{
			mf = CommonUtil<Scalar>::Join(mf, testInputs.cols(iD, iD + iDControlEmiss - 1), 1);
			vf = CommonUtil<Scalar>::Join(vf, af::constant(0.0, testInputs.dims(0), iDControlEmiss, m_dType), 1);
		}
		if (bGPemission)
		{
			gpEmissLayer->ForwardPredictionPost(&mf, &vf, mg, vg);
			likLayer->ProbabilisticOutput(mg, vg, my, vy);
		}
		else
		{
			gaussEmissLayer->ProbabilisticOutput(mf, vf, my, vy);
		}
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::PredictForward(int numTimeSamples, af::array& my, af::array& vy, af::array* mx, af::array* vx)
	{
		int T = numTimeSamples;

		af::array post_m, post_v, mt, vt, mft, vft, myt, vyt;

		my = af::constant(0.0, T, iD, m_dType);
		vy = af::constant(0.0, T, iD, m_dType);

		if (mx != nullptr) *mx = af::constant(0.0, T, iq);
		if (vx != nullptr) *vx = af::constant(0.0, T, iq);

		PosteriorLatents(post_m, post_v);

		af::array mtm1 = post_m(af::end, af::span);
		af::array vtm1 = post_v(af::end, af::span);
		for (uint t = 0; t < T; t++)
		{
			if (iDControlDyn > 0)
			{
				mtm1 = CommonUtil<Scalar>::Join(mtm1, afXControl(t, af::span));
				vtm1 = CommonUtil<Scalar>::Join(vtm1, af::constant(0.0, 1, iDControlDyn, m_dType));
			}
			dynLayer->ForwardPredictionPost(&mtm1, &vtm1, mt, vt);

			if (iDControlEmiss > 0)
			{
				mt = CommonUtil<Scalar>::Join(mt, afXControl(t, af::span));
				vt = CommonUtil<Scalar>::Join(vt, af::constant(0.0, 1, iDControlEmiss, m_dType));
			}

			gpEmissLayer->ForwardPredictionPost(&mt, &vt, mft, vft);
			likLayer->ProbabilisticOutput(mft, vft, myt, vyt);

			if (mx != nullptr) mx->row(t) = mt;
			if (vx != nullptr) vx->row(t) = vt;

			my(t, af::span) = myt;
			vy(t, af::span) = vyt;
		}
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::PosteriorData(af::array& my, af::array& vy)
	{
		af::array mx, vx, mf, vf;
		PosteriorLatents(mx, vx);
		if (iDControlEmiss > 0)
		{
			mx = CommonUtil<Scalar>::Join(mx, afXControl);
			vx = CommonUtil<Scalar>::Join(vx, af::constant(0.0, 1, iDControlEmiss, m_dType));
		}

		gpEmissLayer->ForwardPredictionPost(&mx, &vx, mf, vf);
		likLayer->ProbabilisticOutput(mf, vf, my, vy);
	}

	template<typename Scalar>
	int SparseGPSSMBaseModel<Scalar>::GetNumParameters()
	{
		int numParam = GPSSBaseModel::GetNumParameters();

		numParam += dynLayer->GetNumParameters();
		if (bGPemission) numParam += gpEmissLayer->GetNumParameters();
		else numParam += gaussEmissLayer->GetNumParameters();

		return numParam;
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		int iStart = 0, iEnd = GPSSBaseModel::GetNumParameters();
		GPSSBaseModel::SetParameters(param(af::seq(iStart, iEnd - 1)));

		iStart = iEnd, iEnd += dynLayer->GetNumParameters();
		dynLayer->SetParameters(param(af::seq(iStart, iEnd - 1)));

		if (bGPemission)
		{
			iStart = iEnd, iEnd += gpEmissLayer->GetNumParameters();
			gpEmissLayer->SetParameters(param(af::seq(iStart, iEnd - 1)));
		}
		else
		{
			iStart = iEnd, iEnd += gaussEmissLayer->GetNumParameters();
			gaussEmissLayer->SetParameters(param(af::seq(iStart, iEnd - 1)));
		}
	}

	template<typename Scalar>
	af::array SparseGPSSMBaseModel<Scalar>::GetParameters()
	{
		af::array param = af::constant(0.0f, SparseGPSSMBaseModel::GetNumParameters(), m_dType);

		int iStart = 0, iEnd = GPSSBaseModel::GetNumParameters();
		param(af::seq(iStart, iEnd - 1)) = GPSSBaseModel::GetParameters();

		iStart = iEnd; iEnd += dynLayer->GetNumParameters();
		param(af::seq(iStart, iEnd - 1)) = dynLayer->GetParameters();

		if (bGPemission)
		{
			iStart = iEnd; iEnd += gpEmissLayer->GetNumParameters();
			param(af::seq(iStart, iEnd - 1)) = gpEmissLayer->GetParameters();
		}
		else
		{
			iStart = iEnd; iEnd += gaussEmissLayer->GetNumParameters();
			param(af::seq(iStart, iEnd - 1)) = gaussEmissLayer->GetParameters();
		}

		return param;
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::UpdateParameters()
	{
		GPSSBaseModel<Scalar>::UpdateParameters();

		dynLayer->UpdateParameters();
		if (bGPemission) gpEmissLayer->UpdateParameters();
		else gaussEmissLayer->UpdateParameters();
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::FixKernelParameters(bool isfixed)
	{
		dynLayer->FixKernelParameters(isfixed);
		if (bGPemission) gpEmissLayer->FixKernelParameters(isfixed);
	}

	template<typename Scalar>
	void SparseGPSSMBaseModel<Scalar>::FixInducing(bool isfixed)
	{
		dynLayer->FixInducing(isfixed);
		if (bGPemission) gpEmissLayer->FixInducing(isfixed);
	}
}