/**
File:		MachineLearning/Models/GPModels/FgSparseDeepGPSSMBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2022 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgSparseDeepGPSSMBaseModel.h>
#include <MachineLearning/FgAEPSparseDGPR.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiation definitions of the class template for float and double.
	// NOTE(review): these appear before the member definitions below; some compilers only
	// instantiate members that are already defined at this point — confirm this builds on
	// every supported toolchain, or move the two lines to the end of the namespace.
	template class SparseDeepGPSSMBaseModel<float>;
	template class SparseDeepGPSSMBaseModel<double>;

	/**
	 * Constructs the model from a single hidden-layer description.
	 *
	 * Every argument except `description` is forwarded unchanged to the GPSSBaseModel
	 * constructor. The layer-size table is filled with
	 * { iq, description.GetNumHiddenDimensions(), iD }, iNumLayer is set to the number
	 * of entries minus one, and each layer receives description.GetNumPseudoInputs()
	 * pseudo inputs.
	 *
	 * The likelihood layer is a ProbitLikLayer for LogLikType::Probit and a
	 * GaussLikLayer for every other value of `lType`.
	 */
	template<typename Scalar>
	SparseDeepGPSSMBaseModel<Scalar>::SparseDeepGPSSMBaseModel(const af::array& Y, int latentDimension, HiddenLayerDescription description,
		Scalar priorMean, Scalar priorVariance, af::array& xControl, PropagationMode probMode, LogLikType lType, XInit emethod)
		: GPSSBaseModel<Scalar>(Y, latentDimension, priorMean, priorVariance, xControl, probMode, lType, true, true, emethod),
		vDescription(),
		vNumPseudosPerLayer(),
		vSize(),
		dynLayer(),
		gpEmissLayer()
	{
		vDescription.push_back(description);

		// Layer widths, in order: iq -> hidden dimensions -> iD.
		vSize.push_back(iq);
		vSize.push_back(description.GetNumHiddenDimensions());
		vSize.push_back(iD);

		iNumLayer = vSize.size() - 1;

		// The single description supplies the pseudo-input count for every layer.
		uint layerIdx = 0;
		while (layerIdx < iNumLayer)
		{
			vNumPseudosPerLayer.push_back(description.GetNumPseudoInputs());
			++layerIdx;
		}

		if (lType == LogLikType::Probit)
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
		else
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
	}

	/**
	 * Constructs the model from one description per hidden layer.
	 *
	 * Every argument except `descriptions` is forwarded unchanged to the GPSSBaseModel
	 * constructor. The layer-size table is filled with iq, then the hidden dimension of
	 * each description in order, then iD; iNumLayer is the number of entries minus one
	 * (i.e. descriptions.size() + 1).
	 *
	 * Pseudo-input counts: layer 0 uses descriptions[0]; layer i > 0 uses
	 * descriptions[i - 1], so the first description is used for the first two layers.
	 *
	 * The likelihood layer is a ProbitLikLayer for LogLikType::Probit and a
	 * GaussLikLayer for every other value of `lType`.
	 *
	 * NOTE(review): an empty `descriptions` vector still yields iNumLayer == 1 and would
	 * index descriptions[0] out of bounds below — callers must pass at least one entry.
	 */
	template<typename Scalar>
	SparseDeepGPSSMBaseModel<Scalar>::SparseDeepGPSSMBaseModel(const af::array& Y, int latentDimension, std::vector<HiddenLayerDescription> descriptions, 
		Scalar priorMean, Scalar priorVariance, af::array& xControl, PropagationMode probMode, LogLikType lType, XInit emethod)
		: GPSSBaseModel<Scalar>(Y, latentDimension, priorMean, priorVariance, xControl, probMode, lType, true, true, emethod), vDescription(descriptions), vNumPseudosPerLayer(), vSize(), 
		dynLayer(), gpEmissLayer()
	{
		vSize.push_back(iq);
		// Iterate by reference: the original loop copied every description just to read one value.
		for (auto& description : descriptions)
		{
			vSize.push_back(description.GetNumHiddenDimensions());
		}
		vSize.push_back(iD);

		iNumLayer = vSize.size() - 1;

		for (uint i = 0; i < iNumLayer; i++)
		{
			// Layer 0 and layer 1 both read descriptions[0]; layer i > 0 reads descriptions[i - 1].
			const uint descriptionIdx = (i == 0) ? 0 : i - 1;
			vNumPseudosPerLayer.push_back(descriptions[descriptionIdx].GetNumPseudoInputs());
		}

		switch (lType)
		{
		case LogLikType::Probit:
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
			break;
		default:
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
			break;
		}
	}

	/**
	 * Default constructor: default-constructs the base class, sets the layer count to
	 * zero and value-initialises the layer containers and the dynamics-layer member.
	 */
	template<typename Scalar>
	SparseDeepGPSSMBaseModel<Scalar>::SparseDeepGPSSMBaseModel()
		: GPSSBaseModel<Scalar>(),
		iNumLayer(0),
		vNumPseudosPerLayer(),
		vSize(),
		dynLayer(),
		gpEmissLayer()
	{
	}

	/**
	 * Destructor: deletes every emission GP layer and the dynamics layer.
	 *
	 * NOTE(review): likLayer, which the constructors in this file allocate with `new`,
	 * is not released here — presumably a base class owns it; confirm.
	 */
	template<typename Scalar>
	SparseDeepGPSSMBaseModel<Scalar>::~SparseDeepGPSSMBaseModel()
	{
		for (auto it = gpEmissLayer.begin(); it != gpEmissLayer.end(); ++it)
		{
			delete *it;
		}
		gpEmissLayer.clear();

		// Deleting a null pointer is a no-op, so no explicit null check is needed.
		delete dynLayer;
	}

	/**
	 * Initialises the model by pre-fitting two helper models and copying their parameters.
	 *
	 * Steps:
	 *  1. Obtain latent means `mx`: through GPSSBaseModel::Init(mx) when this model has
	 *     no parent (the result is stored in bInit), otherwise via PosteriorLatents().
	 *  2. Dynamics: fit an AEP::SGPR constructed with (rows 1..iN-1 of mx, rows 0..iN-2 of
	 *     mx, optionally joined with the matching control rows) and copy its GP-layer
	 *     parameters into dynLayer.
	 *  3. Emission: set the data size of each emission layer, fit an AEP::SDGPR on
	 *     (afY, mx, vDescription) and copy each of its GP layers' parameters into the
	 *     corresponding entry of gpEmissLayer.
	 *  4. Call Init() on every child model.
	 *
	 * The helper models are automatic objects confined to their own scope, so they are
	 * released even if an ArrayFire call throws (the previous raw new/delete leaked in
	 * that case), and the dynamics helper is still destroyed before the emission helper
	 * is created.
	 *
	 * Requires dynLayer and gpEmissLayer to be populated already; this file never
	 * allocates them, so that presumably happens in a derived class — confirm.
	 *
	 * @return bInit. NOTE(review): when a parent exists bInit is not assigned here, so
	 *         the returned value is whatever it held before — confirm that is intended.
	 */
	template<typename Scalar>
	bool SparseDeepGPSSMBaseModel<Scalar>::Init()
	{
		af::array mx, vx;
		if (!GetParent())
			bInit = GPSSBaseModel<Scalar>::Init(mx);
		else
			PosteriorLatents(mx, vx);

		// Init dynamic posterior: regress each latent row onto its predecessor.
		af::array x = mx(af::seq(0, iN - 2), af::span);
		af::array y = mx(af::seq(1, iN - 1), af::span);

		if (iDControlDyn > 0) x = CommonUtil<Scalar>::Join(x, afXControl(af::seq(0, iN - 2), af::span));

		std::cout << "\nInit dynamic posterior through SGPR\n===========================\n" << std::endl;
		{
			AEP::SGPR<Scalar> dynModel(y, x, vNumPseudosPerLayer[0], 1.0);
			//dynModel.FixKernelParameters(true);
			//dynModel.FixInducing(true);
			dynModel.Optimise(OptimizerType::ADAM, 0.0, false, 100, iBatchSize);
			//dynModel.FixKernelParameters(false);
			//dynModel.FixInducing(false);
			dynLayer->SetParameters(dynModel.GetGPLayer()->GetParameters());
		}

		vSize.back() = iD;
		for (int i = 0; i < iNumLayer; i++)
			gpEmissLayer[i]->SetDataSize(iN, vSize[i + 1]);

		/*if (mStyles)
		{
			TensorKernel<Scalar>* ktensor = new TensorKernel<Scalar>();
			ktensor->AddKernel(new ARDKernel<Scalar>(iq), af::seq(0, iq - 1));

			int iStart, iEnd = iq;
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				iStart = iEnd;  iEnd += style->second.GetNumSubstyles();
				ktensor->AddKernel(new LinearKernel<Scalar>(style->second.GetNumSubstyles()), af::seq(iStart, iEnd - 1));
			}

			gpLayer[0]->SetKernel(ktensor);
			gpLayer[0]->SetStyles(mStyles);
		}*/

		// Init posterior
		std::cout << "Init emission posterior through SDGPR\n=====================================\n" << std::endl;
		{
			AEP::SDGPR<Scalar> emissModel(afY, mx, vDescription);
			//emissModel.FixKernelParameters(true);
			//emissModel.FixInducing(true);
			emissModel.Optimise(OptimizerType::ADAM, 0.0, false, 100, iBatchSize);
			//emissModel.FixKernelParameters(false);
			//emissModel.FixInducing(false);

			for (int i = 0; i < iNumLayer; i++)
				gpEmissLayer[i]->SetParameters(emissModel.GetGPLayers()[i]->GetParameters());
		}
		std::cout << "done.\n\nOptimizing SDGPSSM\n==================\n" << std::endl;

		/*for (int i = 0; i < iNumLayer; i++)
			if (i == 0) gpEmissLayer[i]->InitParameters(&mx);
			else gpEmissLayer[i]->InitParameters();*/

		for (uint i = 0; i < GetNumChildren(); i++)
		{
			GPLVMBaseModel<Scalar>& child = static_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
			child.Init();
		}

		return bInit;
	}

	/**
	 * Predicts latent function values for the given test inputs.
	 *
	 * Calls the base-class PredictF first, then overwrites `mf` / `vf` with the output of
	 * the dynamics layer evaluated at `testInputs` (no input variance is supplied).
	 * Propagation through the emission layers is currently disabled (see the
	 * commented-out block), so the unused temporaries it needed have been removed.
	 *
	 * @param testInputs Inputs handed to the base class and to the dynamics layer.
	 * @param mf         [out] Predicted mean from the dynamics layer.
	 * @param vf         [out] Predicted variance from the dynamics layer.
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		GPSSBaseModel<Scalar>::PredictF(testInputs, mf, vf);

		dynLayer->ForwardPredictionPost(&testInputs, nullptr, mf, vf);

		/*af::array mfTmp, vfTmp;
		for (uint i = 0; i < iNumLayer; i++)
		{
			gpEmissLayer[i]->ForwardPredictionPost(&mf, &vf, mfTmp, vfTmp);
			mf = mfTmp.copy();
			vf = vfTmp.copy();
		}*/
	}

	/**
	 * Forward prediction entry point: dispatches on the propagation mode `pMode`.
	 *
	 * MomentMatching forwards to PredictForwardMM (which ignores `numSamples` and fills
	 * `vx`); MonteCarlo forwards to PredictForwardMC (which does not receive `vx`).
	 * Any other mode leaves all outputs untouched.
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::PredictForward(int numTimeSamples, af::array& my, af::array& vy, int numSamples, af::array* mx, af::array* vx)
	{
		if (pMode == PropagationMode::MomentMatching)
		{
			PredictForwardMM(numTimeSamples, my, vy, mx, vx);
		}
		else if (pMode == PropagationMode::MonteCarlo)
		{
			PredictForwardMC(numTimeSamples, my, vy, numSamples, mx);
		}
	}

	/**
	 * Forward prediction of `numTimeSamples` steps using moment matching.
	 *
	 * Starting from the last row of the posterior latents, each step
	 *  1. pushes the previous state (optionally joined with the control row `t`, given
	 *     zero variance) through the dynamics layer to obtain the new state (mt, vt),
	 *  2. pushes that state (optionally joined with the control row `t`) through all
	 *     emission layers and the likelihood layer to obtain (myt, vyt),
	 *  3. stores the results in row `t` of the outputs, and
	 *  4. uses (mt, vt) as the previous state of the next step.
	 *
	 * Fixes relative to the previous revision:
	 *  - the state is now advanced between steps (it was never updated, so every step
	 *    predicted from the same posterior row);
	 *  - control inputs are joined into per-step temporaries, so the carried state no
	 *    longer grows by the control width every iteration and `mx` / `vx` receive the
	 *    plain latent state rather than the control-augmented one;
	 *  - `mx` / `vx` are allocated with m_dType like `my` / `vy`.
	 *
	 * @param numTimeSamples Number of steps T to predict.
	 * @param my [out] T x iD predictive means.
	 * @param vy [out] T x iD predictive variances.
	 * @param mx [out, optional] T x iq latent means; skipped when null.
	 * @param vx [out, optional] T x iq latent variances; skipped when null.
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::PredictForwardMM(int numTimeSamples, af::array& my, af::array& vy, af::array* mx, af::array* vx)
	{
		const int T = numTimeSamples;

		af::array post_m, post_v, mt, vt, mft, vft, myt, vyt;

		my = af::constant(0.0, T, iD, m_dType);
		vy = af::constant(0.0, T, iD, m_dType);

		if (mx != nullptr) *mx = af::constant(0.0, T, iq, m_dType);
		if (vx != nullptr) *vx = af::constant(0.0, T, iq, m_dType);

		PosteriorLatents(post_m, post_v);

		// State carried from one step to the next (always iq wide, never control-augmented).
		af::array mtm1 = post_m(af::end, af::span);
		af::array vtm1 = post_v(af::end, af::span);
		for (int t = 0; t < T; t++)
		{
			// Dynamics input: previous state, plus the control row with zero variance if present.
			af::array mDyn = mtm1;
			af::array vDyn = vtm1;
			if (iDControlDyn > 0)
			{
				mDyn = CommonUtil<Scalar>::Join(mtm1, afXControl(t, af::span));
				vDyn = CommonUtil<Scalar>::Join(vtm1, af::constant(0.0, 1, iDControlDyn, m_dType));
			}
			dynLayer->ForwardPredictionPost(&mDyn, &vDyn, mt, vt);

			// Emission input: new state, plus the control row with zero variance if present.
			af::array mEmiss, vEmiss;
			if (iDControlEmiss > 0)
			{
				mEmiss = CommonUtil<Scalar>::Join(mt, afXControl(t, af::span));
				vEmiss = CommonUtil<Scalar>::Join(vt, af::constant(0.0, 1, iDControlEmiss, m_dType));
			}
			else
			{
				mEmiss = mt.copy();
				vEmiss = vt.copy();
			}

			for (uint i = 0; i < iNumLayer; i++)
			{
				gpEmissLayer[i]->ForwardPredictionPost(&mEmiss, &vEmiss, mft, vft);
				mEmiss = mft.copy();
				vEmiss = vft.copy();
			}

			likLayer->ProbabilisticOutput(mft, vft, myt, vyt);

			if (mx != nullptr) mx->row(t) = mt;
			if (vx != nullptr) vx->row(t) = vt;

			my(t, af::span) = myt;
			vy(t, af::span) = vyt;

			// Advance the state for the next time step.
			mtm1 = mt;
			vtm1 = vt;
		}
	}

	/**
	 * Forward prediction of `numTimeSamples` steps using Monte Carlo sampling.
	 *
	 * `numSamples` latent samples are drawn from the last row of the posterior latents.
	 * At each step the samples (optionally joined with the tiled control row `t`) are
	 * pushed through the dynamics layer, new samples are drawn from the resulting
	 * mean/variance, and those samples (optionally joined with the tiled control row)
	 * are pushed through the emission layers and the likelihood layer. The first
	 * emission layer receives samples without input variance; later layers receive the
	 * mean/variance of the previous layer.
	 *
	 * Fixes relative to the previous revision:
	 *  - `mx` is allocated as T x iq x numSamples, matching the 1 x iq x numSamples
	 *    slices written into it (it was T x numSamples x iq, which only matched when
	 *    numSamples == iq) and the layout of `my` / `vy`;
	 *  - with emission control inputs, the emission input is built from the drawn samples
	 *    and the tiled control row, as the dynamics branch does (it was the mean `mt`
	 *    joined with a single untiled control row).
	 *
	 * @param numTimeSamples Number of steps T to predict.
	 * @param my [out] T x iD x numSamples predictive means.
	 * @param vy [out] T x iD x numSamples predictive variances.
	 * @param numSamples Number of Monte Carlo samples.
	 * @param mx [out, optional] T x iq x numSamples latent samples; skipped when null.
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::PredictForwardMC(int numTimeSamples, af::array& my, af::array& vy, int numSamples, af::array* mx)
	{
		const int T = numTimeSamples;

		af::array post_m, post_v, mt, vt, xcSamples, mftmp, vftmp, mft, vft, myt, vyt;

		if (mx != nullptr) *mx = af::constant(0.0, T, iq, numSamples, m_dType);
		my = af::constant(0.0, T, iD, numSamples, m_dType);
		vy = af::constant(0.0, T, iD, numSamples, m_dType);

		PosteriorLatents(post_m, post_v);

		af::array mtm1 = post_m(af::end, af::span);
		af::array vtm1 = post_v(af::end, af::span);
		af::array eps = af::randn(numSamples, iq, m_dType);

		// Initial samples from the last posterior latent: numSamples x iq.
		af::array xSamples = eps * af::tile(af::sqrt(vtm1), numSamples) + af::tile(mtm1, numSamples);
		for (int t = 0; t < T; t++)
		{
			if (iDControlDyn > 0)
				xcSamples = CommonUtil<Scalar>::Join(xSamples, af::tile(afXControl(t, af::span), numSamples), 1);
			else
				xcSamples = xSamples;

			dynLayer->ForwardPredictionPost(&xcSamples, nullptr, mt, vt);
			eps = af::randn(numSamples, iq, m_dType);

			// Draw the next state samples from the dynamics prediction.
			xSamples = eps * af::sqrt(vt) + mt;

			if (iDControlEmiss > 0)
				xcSamples = CommonUtil<Scalar>::Join(xSamples, af::tile(afXControl(t, af::span), numSamples), 1);
			else
				xcSamples = xSamples;

			for (uint i = 0; i < iNumLayer; i++)
			{
				if (i == 0) gpEmissLayer[i]->ForwardPredictionPost(&xcSamples, nullptr, mft, vft);
				else gpEmissLayer[i]->ForwardPredictionPost(&mftmp, &vftmp, mft, vft);
				mftmp = mft.copy();
				vftmp = vft.copy();
			}

			likLayer->ProbabilisticOutput(mft, vft, myt, vyt);

			// Transpose to (dim x numSamples) and reshape into a 1 x dim x numSamples slice.
			if (mx != nullptr) (*mx)(t, af::span, af::span) = af::moddims(xSamples.T(), 1, iq, numSamples);

			my(t, af::span, af::span) = af::moddims(myt.T(), 1, iD, numSamples);
			vy(t, af::span, af::span) = af::moddims(vyt.T(), 1, iD, numSamples);
		}
	}

	/**
	 * Returns the total parameter count: base-class parameters, plus the dynamics
	 * layer, plus every emission GP layer.
	 */
	template<typename Scalar>
	int SparseDeepGPSSMBaseModel<Scalar>::GetNumParameters()
	{
		int total = GPSSBaseModel<Scalar>::GetNumParameters();
		total += dynLayer->GetNumParameters();

		for (auto it = gpEmissLayer.begin(); it != gpEmissLayer.end(); ++it)
		{
			total += (*it)->GetNumParameters();
		}

		return total;
	}

	/**
	 * Returns the summed parameter count of the emission GP layers only; the dynamics
	 * layer and the base-class parameters are not included.
	 */
	template<typename Scalar>
	int SparseDeepGPSSMBaseModel<Scalar>::GetNumGPLayerParameters()
	{
		int total = 0;

		for (auto it = gpEmissLayer.begin(); it != gpEmissLayer.end(); ++it)
		{
			total += (*it)->GetNumParameters();
		}

		return total;
	}

	/**
	 * Distributes a flat parameter vector over the model components.
	 *
	 * Layout of `param`, in order: base-class parameters, dynamics-layer parameters,
	 * then the emission layers from the LAST layer down to layer 0. This is the same
	 * layout GetParameters() produces.
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		// [first, last) is the slice of `param` owned by the current component.
		int first = 0;
		int last = GPSSBaseModel<Scalar>::GetNumParameters();
		GPSSBaseModel<Scalar>::SetParameters(param(af::seq(first, last - 1)));

		first = last;
		last += dynLayer->GetNumParameters();
		dynLayer->SetParameters(param(af::seq(first, last - 1)));

		// Emission layers are stored in reverse order.
		for (int layer = iNumLayer; layer-- > 0; )
		{
			first = last;
			last += gpEmissLayer[layer]->GetNumParameters();
			gpEmissLayer[layer]->SetParameters(param(af::seq(first, last - 1)));
		}
	}

	/**
	 * Collects all parameters into one flat vector of type m_dType.
	 *
	 * Layout, in order: base-class parameters, dynamics-layer parameters, then the
	 * emission layers from the LAST layer down to layer 0. This is the layout
	 * SetParameters() consumes.
	 */
	template<typename Scalar>
	af::array SparseDeepGPSSMBaseModel<Scalar>::GetParameters()
	{
		af::array flat = af::constant(0.0f, SparseDeepGPSSMBaseModel::GetNumParameters(), m_dType);

		// [first, last) is the slice of `flat` owned by the current component.
		int first = 0;
		int last = GPSSBaseModel<Scalar>::GetNumParameters();
		flat(af::seq(first, last - 1)) = GPSSBaseModel<Scalar>::GetParameters();

		first = last;
		last += dynLayer->GetNumParameters();
		flat(af::seq(first, last - 1)) = dynLayer->GetParameters();

		// Emission layers are stored in reverse order.
		for (int layer = iNumLayer; layer-- > 0; )
		{
			first = last;
			last += gpEmissLayer[layer]->GetNumParameters();
			flat(af::seq(first, last - 1)) = gpEmissLayer[layer]->GetParameters();
		}

		return flat;
	}

	/**
	 * Propagates a parameter update through the model: first the base class, then the
	 * dynamics layer, then each emission layer in storage order, and finally
	 * UpdateParametersInternal().
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::UpdateParameters()
	{
		GPSSBaseModel<Scalar>::UpdateParameters();

		dynLayer->UpdateParameters();

		for (auto it = gpEmissLayer.begin(); it != gpEmissLayer.end(); ++it)
		{
			(*it)->UpdateParameters();
		}

		UpdateParametersInternal();
	}

	/**
	 * Returns iNumLayer, which the constructors set to the number of layer-size
	 * entries minus one (zero for a default-constructed model).
	 */
	template<typename Scalar>
	int SparseDeepGPSSMBaseModel<Scalar>::GetNumLayers()
	{
		return this->iNumLayer;
	}

	/**
	 * Returns a copy of the vector of emission GP layer pointers. The pointers stay
	 * owned by this model (they are deleted in the destructor); the dynamics layer is
	 * not part of the returned vector.
	 */
	template<typename Scalar>
	std::vector<SparseGPBaseLayer<Scalar>*> SparseDeepGPSSMBaseModel<Scalar>::GetGPLayers()
	{
		return this->gpEmissLayer;
	}

	/**
	 * Forwards the `isfixed` flag to FixKernelParameters() of the dynamics layer and of
	 * every emission GP layer.
	 */
	template<typename Scalar>
	void SparseDeepGPSSMBaseModel<Scalar>::FixKernelParameters(bool isfixed)
	{
		dynLayer->FixKernelParameters(isfixed);

		for (auto it = gpEmissLayer.begin(); it != gpEmissLayer.end(); ++it)
		{
			(*it)->FixKernelParameters(isfixed);
		}
	}
}