/**
File:		MachineLearning/Models/GPModels/FgGPStateSpaceBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2020 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>

#include <memory>

#include <MachineLearning/FgGPStateSpaceBaseModel.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiations of GPSSBaseModel for the float and double scalar types.
	template class GPSSBaseModel<float>;
	template class GPSSBaseModel<double>;

	/**
	 * Builds a state-space model from observations and Gaussian prior settings.
	 *
	 * The prior is stored both as moments (dPriorMean, dPriorVariance) and as
	 * natural parameters (dPriorX1 = mean / variance, dPriorX2 = 1 / variance).
	 * The factor and posterior natural-parameter arrays are allocated as
	 * zero-filled iN x iq arrays, the control-input dimensions are recorded and
	 * the dynamics index set is built through UpdateDynamicIndexes().
	 *
	 * @param Y               Observations; forwarded to GPBaseModel.
	 * @param latentDimension Stored in iq.
	 * @param priorMean       Prior mean.
	 * @param priorVariance   Prior variance; used as a divisor, so it must be non-zero.
	 * @param xControl        Control inputs; may be empty. Its second dimension is
	 *                        taken as the number of control dimensions.
	 * @param probMode        Stored in pMode.
	 * @param lType           Forwarded to GPBaseModel.
	 * @param GPemission      Stored in bGPemission.
	 * @param controlToEmiss  Stored in bControlToEmiss; when true and controls are
	 *                        present, iDControlEmiss is set to the control dimension.
	 * @param emethod         Stored in eEmMethod; selects the embedding used by
	 *                        Init() and AddWindowData().
	 */
	template<typename Scalar>
	GPSSBaseModel<Scalar>::GPSSBaseModel(const af::array& Y, int latentDimension,
		Scalar priorMean, Scalar priorVariance, af::array& xControl, PropagationMode probMode, LogLikType lType, bool GPemission, bool controlToEmiss, XInit emethod)
		: GPBaseModel(Y, lType, ModelType::SSM), iq(latentDimension), dPriorMean(priorMean), dPriorVariance(priorVariance), dPriorX1(priorMean / priorVariance),
		dPriorX2(1.0 / priorVariance), eEmMethod(emethod), afXControl(), bGPemission(GPemission), bControlToEmiss(controlToEmiss), dSn(log(0.01)),
		afPriorMean(), afPriorVariance(), afPriorMeanCav(), afPriorVarianceCav(), afGradMean(), afGradVariance(), afGradMeanCav(), afGradVarianceCav(),
		afPriorX1(), afPriorX2(), pMode(probMode)
	{
		afFactorX1 = af::constant(0.0, iN, iq, m_dType);
		afFactorX2 = af::constant(0.0, iN, iq, m_dType);

		afPosteriorX1 = af::constant(0.0, iN, iq, m_dType);
		afPosteriorX2 = af::constant(0.0, iN, iq, m_dType);

		// Without control inputs both control dimensions are zero.
		iDControlDyn = 0;
		iDControlEmiss = 0;

		if (!xControl.isempty())
		{
			iDControlDyn = xControl.dims(1);
			afXControl = xControl;

			// Controls reach the emission only when explicitly requested.
			if (controlToEmiss)
				iDControlEmiss = xControl.dims(1);
		}

		// Same index construction that Init() and AddWindowData() use; kept in
		// one place instead of being duplicated here.
		UpdateDynamicIndexes();
	}

	/**
	 * Default constructor: zero latent dimension, prior mean 0, prior variance 1
	 * and PCA as the embedding method.
	 *
	 * The derived scalars are set to the values the parameterised constructor
	 * would compute for that prior (dPriorX1 = 0 / 1, dPriorX2 = 1 / 1,
	 * dSn = log(0.01)) and the control dimensions to zero, so that none of them
	 * is left indeterminate.
	 *
	 * NOTE(review): bGPemission, bControlToEmiss and pMode are not set here and no
	 * safe default is visible from this file; confirm whether they need one.
	 */
	template<typename Scalar>
	GPSSBaseModel<Scalar>::GPSSBaseModel()
		: iq(0), dPriorMean(0), dPriorVariance(1), dPriorX1(0), dPriorX2(1), eEmMethod(XInit::pca), dSn(log(0.01))
	{
		iDControlDyn = 0;
		iDControlEmiss = 0;
	}

	/** Destructor; performs no work of its own. */
	template<typename Scalar>
	GPSSBaseModel<Scalar>::~GPSSBaseModel() = default;
	/**
	 * Optionally selects a random contiguous mini-batch window, then delegates
	 * to GPBaseModel::Optimise with all arguments unchanged.
	 *
	 * A window is selected only when 1 < mb_size < iN. It covers rows
	 * [start, start + mb_size - 1] in afIndexes and [start, start + mb_size - 2]
	 * in afDynIndexes. For any other mb_size the current index sets are kept:
	 * a one-row window would yield an inverted dynamics range, and
	 * mb_size >= iN would make the modulus zero or negative.
	 *
	 * @param method        Forwarded to GPBaseModel::Optimise.
	 * @param tol           Forwarded to GPBaseModel::Optimise.
	 * @param reinit_hypers Forwarded to GPBaseModel::Optimise.
	 * @param maxiter       Forwarded to GPBaseModel::Optimise.
	 * @param mb_size       Mini-batch window length; see above. Also forwarded.
	 * @param lsType        Forwarded to GPBaseModel::Optimise.
	 * @param disp          Forwarded to GPBaseModel::Optimise.
	 * @param cycle         Forwarded to GPBaseModel::Optimise.
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::Optimise(OptimizerType method, Scalar tol, bool reinit_hypers, int maxiter, int mb_size, LineSearchType lsType, bool disp, int* cycle)
	{
		if (mb_size > 1 && mb_size < iN)
		{
			// Seed the C generator once. Re-seeding with the wall clock on every
			// call makes all calls within the same second pick the same window.
			static const bool seeded = (srand(static_cast<unsigned>(time(nullptr))), true);
			static_cast<void>(seeded);

			// There are iN - mb_size + 1 admissible start rows (0 .. iN - mb_size).
			const int startIdx = rand() % (iN - mb_size + 1);
			const int endIdx = startIdx + mb_size;

			afIndexes = af::seq(startIdx, endIdx - 1);
			afDynIndexes = af::seq(startIdx, endIdx - 2);
		}

		GPBaseModel::Optimise(method, tol, reinit_hypers, maxiter, mb_size, lsType, disp, cycle);
	}

	/**
	 * Initialises the model and the latent factor parameters.
	 *
	 * Steps, in order:
	 *  1. If the model has children, each child is initialised and its output is
	 *     joined along dimension 1 into afY; iN, iD, the batch indexes and the
	 *     children's index/batch settings are then updated.
	 *  2. GPBaseModel::Init() is run and its result stored in bInit.
	 *  3. An initial latent mean is computed: an embedding of afY into iq
	 *     dimensions when iD > iq, otherwise a copy of afY.
	 *  4. The factor natural parameters are set from that mean and a fixed
	 *     variance of 0.1, each divided by 3 (UpdateParametersInternal builds the
	 *     posterior as 3 x factor for interior rows).
	 *  5. If the model has a parent, the per-row prior arrays are filled from
	 *     dPriorMean / dPriorVariance.
	 *
	 * @param mx  Receives the initial latent means. It is left as passed in when
	 *            iD > iq and the configured embedding method is not implemented.
	 * @return    The value returned by GPBaseModel::Init().
	 */
	template<typename Scalar>
	bool GPSSBaseModel<Scalar>::Init(af::array& mx)
	{
		if (GetNumChildren() > 0)
		{
			for (uint i = 0; i < GetNumChildren(); i++)
			{
				af::array y;
				GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
				child.Init(y);
				afY = CommonUtil<Scalar>::Join(afY, y, 1);
			}

			iN = afY.dims(0);
			iD = afY.dims(1);

			if (iBatchSize >= iN || iBatchSize == 0)
			{
				iBatchSize = iN;
				afIndexes = af::seq(0, iN - 1);
			}
			else
			{
				af::setSeed(time(nullptr));

				// floor() maps the uniform draws onto rows 0 .. iN - 1; round()
				// could produce the out-of-range row iN and under-sampled row 0.
				// The clamp covers a draw of exactly 1.0. The double cast selects
				// the element-wise af::min overload, not the reduction over a dim.
				// NOTE(review): draws are independent, so rows may repeat.
				af::array draws = af::min(af::floor(af::randu(iN) * iN), static_cast<double>(iN - 1));
				afIndexes = draws(af::seq(iBatchSize));
			}

			for (uint i = 0; i < GetNumChildren(); i++)
			{
				GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
				child.SetIndexes(afIndexes);
				child.SetBatchSize(iBatchSize);
			}
		}

		bInit = GPBaseModel::Init();

		int numNeighbours = 10;

		if (iD > iq)
		{
			// Owned by a smart pointer so the embedder is released even if
			// Compute() throws.
			std::unique_ptr<IEmbed> embed;

			switch (eEmMethod)
			{
			case XInit::pca:
				embed = std::make_unique<PCA>();
				break;
			case XInit::isomap:
				embed = std::make_unique<Isomap>(numNeighbours);
				break;
			case XInit::lle:
				embed = std::make_unique<LLE>(numNeighbours);
				break;
			default:
				std::cout << "Embedding method not implemented." << std::endl;
				break;
			}

			if (embed)
				mx = embed->Compute(afY, iq);
		}
		else
			mx = afY.copy();

		UpdateDynamicIndexes();

		// Natural parameters from mean mx and a fixed initial variance of 0.1:
		// X2 = 1 / v and X1 = mx / v, each divided by 3.
		const af::array vx = af::constant(0.1, mx.dims(), m_dType);
		afFactorX2 = 1.0 / vx;
		afFactorX1 = afFactorX2 * mx / 3.0;
		afFactorX2 = afFactorX2 / 3.0;

		if (GetParent())
		{
			afPriorMean = af::constant(dPriorMean, iN, iq, m_dType);
			afPriorVariance = af::constant(dPriorVariance, iN, iq, m_dType);

			afPriorMeanCav = af::constant(dPriorMean, iN, iq, m_dType);
			afPriorVarianceCav = af::constant(dPriorVariance, iN, iq, m_dType);

			// Prior in natural-parameter form.
			afPriorX1 = afPriorMean / afPriorVariance;
			afPriorX2 = 1.0 / afPriorVariance;

			// The cavity natural parameters start out equal to the prior's.
			afPriorX1Cav = afPriorMean / afPriorVariance;
			afPriorX2Cav = 1.0 / afPriorVariance;
		}

		return bInit;
	}

	/**
	 * Converts the posterior natural parameters to moments.
	 *
	 * @param mx  Receives afPosteriorX1 / afPosteriorX2.
	 * @param vx  Receives 1 / afPosteriorX2.
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::PosteriorLatents(af::array& mx, af::array& vx)
	{
		const af::array& secondNatural = afPosteriorX2;

		vx = 1.0 / secondNatural;
		mx = afPosteriorX1 / secondNatural;
	}

	/**
	 * Maps gradients with respect to the posterior moments onto the optimised
	 * factor parameters.
	 *
	 * With m = X1 / X2 and v = 1 / X2 (see PosteriorLatents) the chain rule gives
	 *   dX1 = dm / X2,   dX2 = -dm * X1 / X2^2 - dv / X2^2.
	 * Each row is then scaled by the number of factor copies that make up its
	 * posterior (3 for interior rows, 2 for the rows UpdateParametersInternal
	 * treats as boundaries). The X2 part is additionally multiplied by
	 * 2 * afFactorX2, matching the exp(2 * p) parameterisation in SetParameters.
	 *
	 * @param dmx  Gradient for the rows selected by afIndexes; presumably with
	 *             respect to the posterior mean. Must match those rows in shape.
	 * @param dvx  Same, presumably with respect to the posterior variance.
	 * @return     Flat vector of length 2 * iN * iq: X1 gradients first, then X2
	 *             gradients. Rows outside afIndexes are zero.
	 */
	template<typename Scalar>
	af::array GPSSBaseModel<Scalar>::PosteriorGradientLatents(const af::array& dmx, const af::array& dvx)
	{
		// Factor multiplicity per row, built over ALL rows with the same boundary
		// selection as UpdateParametersInternal, then gathered for the batch.
		// The previous segmented branch wrote 2 * afFactorX* rows into the scale
		// vector, and indexed by absolute row instead of batch position.
		af::array scaleAll = af::constant(3.0, iN, m_dType);
		if (afSegments.isempty())
		{
			scaleAll(0) = 2.0;
			scaleAll(af::end) = 2.0;
		}
		else
		{
			scaleAll(afSegments) = 2.0;                                  // first row of every segment
			if (afSegments.dims(0) > 1)
				scaleAll(afSegments(af::seq(1, af::end)) - 1) = 2.0;     // last row of every segment but the final one
			scaleAll(af::end) = 2.0;                                     // last row of the final segment
		}

		// One scale per selected row, repeated across the iq latent columns so
		// the product below does not depend on implicit broadcasting.
		const af::array scaleRows = scaleAll(afIndexes);
		const af::array scaleX = af::tile(scaleRows, 1, static_cast<unsigned>(iq));

		const af::array postX1 = afPosteriorX1(afIndexes, af::span);
		const af::array postX2 = afPosteriorX2(afIndexes, af::span);
		const af::array postX2Sq = postX2 * postX2;

		af::array gradX1 = dmx / postX2;
		af::array gradX2 = -dmx * postX1 / postX2Sq - dvx / postX2Sq;

		gradX1 *= scaleX;
		gradX2 *= scaleX;

		// X2 is optimised as p = log(X2) / 2, hence dX2/dp = 2 * X2.
		gradX2 *= 2.0 * afFactorX2(afIndexes, af::span);

		// Scatter the batch gradients into full-size arrays.
		af::array gradX1all = af::constant(0.0, afPosteriorX1.dims(), m_dType);
		af::array gradX2all = af::constant(0.0, afPosteriorX2.dims(), m_dType);
		gradX1all(afIndexes, af::span) = gradX1;
		gradX2all(afIndexes, af::span) = gradX2;

		return af::join(0, af::flat(gradX1all), af::flat(gradX2all));
	}

	/**
	 * Counts the optimised parameters of this model.
	 *
	 * @return Sum of the children's parameter counts, the GPBaseModel count,
	 *         2 * iN * iq latent factor entries, and one extra scalar (the slot
	 *         that GetParameters/SetParameters use for dSn).
	 */
	template<typename Scalar>
	int GPSSBaseModel<Scalar>::GetNumParameters()
	{
		int total = 0;

		for (uint childIdx = 0; childIdx < GetNumChildren(); ++childIdx)
			total += static_cast<GPLVMBaseModel<Scalar>&>(*GetChild(childIdx)).GetNumParameters();

		total += GPBaseModel::GetNumParameters();

		// Latent factor entries (X1 and X2) plus the single trailing scalar.
		const int latentEntries = iN * iq * 2;
		total += latentEntries + 1;

		return total;
	}

	/**
	 * Distributes a flat parameter vector over the children, the base model and
	 * this model's latent factors, then refreshes the derived state.
	 *
	 * Layout consumed, in order: one slice per child, the GPBaseModel slice,
	 * iN * iq entries for afFactorX1, iN * iq entries p with
	 * afFactorX2 = exp(2 * p), and one final scalar stored in dSn.
	 *
	 * @param param Flat parameter vector in the layout described above.
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		int cursor = 0;

		// Returns the next `count` entries of param and advances the cursor.
		const auto take = [&param, &cursor](int count) -> af::array
		{
			const int first = cursor;
			cursor += count;
			return param(af::seq(first, cursor - 1));
		};

		for (uint childIdx = 0; childIdx < GetNumChildren(); ++childIdx)
		{
			GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(childIdx));
			child.SetParameters(take(child.GetNumParameters()));
		}

		// The base slice is skipped entirely when the base model has no parameters.
		const int baseCount = GPBaseModel::GetNumParameters();
		if (baseCount != 0)
			GPBaseModel<Scalar>::SetParameters(take(baseCount));

		const int latentCount = iN * iq;
		afFactorX1 = af::moddims(take(latentCount), iN, iq);
		afFactorX2 = af::moddims(af::exp(2.0 * take(latentCount)), iN, iq);

		dSn = param(cursor).scalar<Scalar>();

		UpdateParametersInternal();
	}

	/**
	 * Collects all optimised parameters into one flat vector.
	 *
	 * Layout written, in order: one slice per child, the GPBaseModel slice,
	 * afFactorX1 flattened, log(afFactorX2) / 2 flattened, and dSn in the last
	 * slot.
	 *
	 * @return Vector of GPSSBaseModel::GetNumParameters() entries.
	 */
	template<typename Scalar>
	af::array GPSSBaseModel<Scalar>::GetParameters()
	{
		af::array param = af::constant(0.0f, GPSSBaseModel::GetNumParameters(), m_dType);

		int cursor = 0;

		// Writes `values` into the next `count` slots and advances the cursor.
		const auto put = [&param, &cursor](int count, const af::array& values)
		{
			const int first = cursor;
			cursor += count;
			param(af::seq(first, cursor - 1)) = values;
		};

		for (uint childIdx = 0; childIdx < GetNumChildren(); ++childIdx)
		{
			GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(childIdx));
			const int childCount = child.GetNumParameters();
			put(childCount, child.GetParameters());
		}

		// The base slice is skipped entirely when the base model has no parameters.
		const int baseCount = GPBaseModel::GetNumParameters();
		if (baseCount != 0)
			put(baseCount, GPBaseModel::GetParameters());

		const int latentCount = iN * iq;
		put(latentCount, af::flat(afFactorX1));
		put(latentCount, af::flat(af::log(afFactorX2) / 2.0));

		param(cursor) = dSn;

		return param;
	}

	/**
	 * Rebuilds afDynIndexes.
	 *
	 * Without segments the result is 0 .. iN - 2. With segments, each entry of
	 * afSegments is read as the first row of a segment, and the result is the
	 * concatenation, per segment, of all its rows except the last one (the final
	 * segment ends at row iN - 1).
	 * NOTE(review): presumably these are the rows that have a successor within
	 * the same sequence; confirm against the callers.
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::UpdateDynamicIndexes()
	{
		const auto lastDynIndex = iN - 2;

		if (afSegments.isempty())
		{
			afDynIndexes = af::seq(0, lastDynIndex);
			return;
		}

		afDynIndexes = af::array();

		// Every segment but the final one ends two rows before the next start.
		int segment = 0;
		while (segment < afSegments.dims(0) - 1)
		{
			const int firstRow = afSegments(segment).as(s32).scalar<int>();
			const int lastRow = afSegments(segment + 1).as(s32).scalar<int>() - 2;
			afDynIndexes = CommonUtil<Scalar>::Join(afDynIndexes, af::seq(firstRow, lastRow));
			++segment;
		}

		// Final segment.
		const int finalFirstRow = afSegments(af::end).as(s32).scalar<int>();
		afDynIndexes = CommonUtil<Scalar>::Join(afDynIndexes, af::seq(finalFirstRow, lastDynIndex));
	}

	/**
	 * Converts the factor natural parameters of the rows in afIndexes to moments.
	 *
	 * @param mx  Receives afFactorX1 / afFactorX2 for the selected rows.
	 * @param vx  Receives 1 / afFactorX2 for the selected rows.
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::GetLatents(af::array& mx, af::array& vx)
	{
		// Rows of `full` selected by the current batch indexes.
		const auto batchRows = [this](const af::array& full) -> af::array
		{
			return full(afIndexes, af::span);
		};

		vx = 1.0 / batchRows(afFactorX2);
		mx = batchRows(afFactorX1) / batchRows(afFactorX2);
	}

	/**
	 * Refreshes the state derived from the current parameters.
	 *
	 * Calls GPBaseModel::UpdateParameters(), rebuilds afY from the children's
	 * posterior means when there are children, and recomputes the posterior
	 * natural parameters from the factors: 3 x factor for every row, replaced by
	 * 2 x factor on the boundary rows, plus the prior natural parameter on the
	 * first row (of every segment, when segments are defined).
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::UpdateParametersInternal()
	{
		GPBaseModel<Scalar>::UpdateParameters();

		if (GetNumChildren() > 0)
		{
			afY = af::array();
			af::array childMean, childVariance;
			for (uint childIdx = 0; childIdx < GetNumChildren(); ++childIdx)
			{
				GPSSBaseModel<Scalar>& child = dynamic_cast<GPSSBaseModel<Scalar>&>(*GetChild(childIdx));
				child.PosteriorLatents(childMean, childVariance);
				afY = CommonUtil<Scalar>::Join(afY, childMean, 1);
			}
		}

		// X1 and X2 follow exactly the same recipe, so it is written once.
		const auto rebuild = [this](af::array& posterior, const af::array& factor, Scalar prior)
		{
			// Overwrites the given rows of the posterior with 2 x factor.
			const auto useTwoFactors = [&posterior, &factor](const af::index& rows)
			{
				posterior(rows, af::span) = 2.0 * factor(rows, af::span);
			};

			posterior = 3.0 * factor;

			if (afSegments.isempty())
			{
				useTwoFactors(0);
				useTwoFactors(af::end);

				posterior(0, af::span) += prior;
			}
			else
			{
				useTwoFactors(afSegments);                                  // first row of every segment
				useTwoFactors(afSegments(af::seq(1, af::end)) - 1);         // last row of every segment but the final one
				useTwoFactors(af::end);                                     // last row of the final segment

				posterior(afSegments, af::span) += prior;
			}
		};

		rebuild(afPosteriorX1, afFactorX1, dPriorX1);
		rebuild(afPosteriorX2, afFactorX2, dPriorX2);
	}

	/**
	 * Appends a window of observations and extends the latent factors to match.
	 *
	 * The new rows are joined onto afY; iN, iBatchSize and afIndexes are reset to
	 * cover the whole data set. Initial latent means for the NEW rows only are
	 * computed (an embedding of `data` into iq dimensions when iD > iq, otherwise
	 * a copy of `data`), converted to natural parameters with a fixed variance of
	 * 0.1 and divided by 3 exactly as in Init(), and joined onto afFactorX1 /
	 * afFactorX2. Finally UpdateParametersInternal() refreshes the posterior.
	 *
	 * NOTE(review): when iD > iq and the embedding method is not implemented,
	 * no latent means are produced for the new rows; confirm how callers are
	 * expected to handle that case.
	 *
	 * @param data  New observation rows to append.
	 */
	template<typename Scalar>
	void GPSSBaseModel<Scalar>::AddWindowData(af::array data)
	{
		afY = CommonUtil<Scalar>::Join(afY, data);
		iN = static_cast<int>(afY.dims(0));
		iBatchSize = iN;
		afIndexes = af::seq(0, iN - 1);

		int numNeighbours = 10;

		af::array mx;

		if (iD > iq)
		{
			// Owned by a smart pointer so the embedder is released even if
			// Compute() throws.
			std::unique_ptr<IEmbed> embed;

			switch (eEmMethod)
			{
			case XInit::pca:
				embed = std::make_unique<PCA>();
				break;
			case XInit::isomap:
				embed = std::make_unique<Isomap>(numNeighbours);
				break;
			case XInit::lle:
				embed = std::make_unique<LLE>(numNeighbours);
				break;
			default:
				std::cout << "Embedding method not implemented." << std::endl;
				break;
			}

			if (embed)
				mx = embed->Compute(data, iq);
		}
		else
			mx = data.copy();

		UpdateDynamicIndexes();

		// Natural parameters for the new rows from mean mx and variance 0.1,
		// each divided by 3.
		const af::array vx = af::constant(0.1, mx.dims(), m_dType);
		af::array fPartX2 = 1.0 / vx;
		const af::array fPartX1 = fPartX2 * mx / 3.0;
		fPartX2 = fPartX2 / 3.0;

		afFactorX2 = CommonUtil<Scalar>::Join(afFactorX2, fPartX2);
		afFactorX1 = CommonUtil<Scalar>::Join(afFactorX1, fPartX1);

		UpdateParametersInternal();
	}
}