/**
File:		MachineLearning/GPModels/Models/GPModels/FgAEPSparseGPR.cpp

Author:		
Email:		
Site:       

Copyright (c) 2018 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAEPSparseGPR.h>
#include <MachineLearning/FgAEPSparseGPLayer.h>

namespace NeuralEngine::MachineLearning::GPModels::AEP
{
	// Explicit instantiation definitions: the member definitions of SGPR live in this
	// translation unit, so the float and double specialisations are emitted here.
	// NOTE(review): these appear *before* the member definitions below. An explicit
	// instantiation definition is only required to cover members already defined at
	// that point, so some compilers may not emit the constructors/Function from here.
	// Confirm against the toolchains in use, or move these two lines to the end of
	// the namespace.
	template class SGPR<float>;
	template class SGPR<double>;

	/**
	 * Constructs the model from data and creates its sparse GP layer.
	 *
	 * @param Y            Forwarded unchanged to SparseGPBaseModel (presumably the observed
	 *                     outputs; confirm against the base class).
	 * @param X            Forwarded unchanged to SparseGPBaseModel (presumably the inputs;
	 *                     confirm against the base class).
	 * @param numInducing  Forwarded unchanged to SparseGPBaseModel.
	 * @param alpha        Stored in dAlpha and later passed to the layer routines in Function().
	 * @param lType        Forwarded unchanged to SparseGPBaseModel.
	 */
	template<typename Scalar>
	SGPR<Scalar>::SGPR(const af::array& Y, const af::array& X, int numInducing, Scalar alpha, LogLikType lType)
		: SparseGPBaseModel<Scalar>(Y, X, numInducing, lType)
		, dAlpha(alpha)
	{
		// Only the layer object is created here; the former InitParameters(&afX) call
		// on the new layer is disabled in this constructor.
		SGPLayer<Scalar>* const sparseLayer = new SGPLayer<Scalar>(iN, ik, iD, iq);
		gpLayer = sparseLayer;
	}

	/**
	 * Default constructor.
	 *
	 * Default-constructs the SparseGPBaseModel base and sets dAlpha to 0. No GP layer is
	 * allocated here: the layer creation and InitParameters(&afX) code is disabled in this
	 * constructor.
	 */
	template<typename Scalar>
	SGPR<Scalar>::SGPR()
		: SparseGPBaseModel<Scalar>()
		, dAlpha(0.0)
	{
	}

	/**
	 * Evaluates the objective and its gradient at the parameter vector x.
	 *
	 * Steps:
	 *  1. Writes x into the model via SetParameters().
	 *  2. Selects the batch: all of afX/afY when iBatchSize >= iN (iBatchSize is then
	 *     set to iN), otherwise iBatchSize rows drawn uniformly at random (with
	 *     replacement).
	 *  3. Propagates the batch through the sparse GP layer's cavity prediction and
	 *     asks the likelihood layer for the log normalizer, scaled by
	 *     -iN / iBatchSize / dAlpha.
	 *  4. Adds the GP layer's Phi term and collects the gradients of both layers.
	 *
	 * @param x            Flat parameter vector handed to SetParameters().
	 * @param outGradient  Overwritten with GetNumParameters() entries: the GP layer's
	 *                     gradients first, then the likelihood layer's (when it has
	 *                     parameters), all divided by iN.
	 * @return             (scaled log normalizer + Phi contribution) / iN.
	 *
	 * NOTE(review): dAlpha == 0 (as set by the default constructor) or iBatchSize == 0
	 * makes the scale factor a division by zero; gpLayer must already point at an
	 * SGPLayer<Scalar>, otherwise the dynamic_cast below fails.
	 */
	template<typename Scalar>
	Scalar SGPR<Scalar>::Function(const af::array& x, af::array& outGradient)
	{
		SetParameters(x);

		SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer);

		af::array xBatch(m_dType), yBatch(m_dType);

		if (iBatchSize >= iN)
		{
			// Full-batch mode: clamp the batch size and use every sample.
			iBatchSize = iN;
			xBatch = afX;
			yBatch = afY;
		}
		else
		{
			// Seed the generator once only. Re-seeding with a one-second clock on every
			// evaluation made all calls within the same second draw the identical batch
			// and kept resetting the global ArrayFire RNG state.
			static const bool rngSeeded = (af::setSeed(static_cast<unsigned long long>(time(nullptr))), true);
			(void)rngSeeded;

			// Uniform row indices in [0, iN - 1]. floor() gives every row the same
			// probability; round() under-weighted row 0 and could yield iN, which is one
			// past the last row. The clamp covers a random draw of exactly 1.0.
			const af::array idx = af::min(af::floor(af::randu(iBatchSize) * iN), static_cast<double>(iN - 1)).as(s32);
			xBatch = afX(idx, af::span);
			yBatch = afY(idx, af::span);
		}

		// Rescales the batch log normalizer to the full data set and divides by -alpha.
		const Scalar scaleLogZ = -iN * 1.0 / iBatchSize / dAlpha;

		// propagate x forward through cavity
		af::array mout(m_dType), vout(m_dType);
		slayer.ForwardPredictionCavity(mout, vout, nullptr, nullptr, xBatch, nullptr, dAlpha);

		// compute log normalizer
		const Scalar logZ = likLayer->ComputeLogZ(mout, vout, yBatch, dAlpha) * scaleLogZ;

		// likelihood contribution from gp layer
		const Scalar sgpContribution = slayer.ComputePhi(dAlpha);

		/// Gradient computation
		outGradient = af::constant(0.0f, GetNumParameters(), (m_dType));

		af::array dlogZ_dm(m_dType), dlogZ_dv(m_dType);
		likLayer->ComputeLogZGradients(mout, vout, yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);

		// The GP layer receives pre-scaled gradients; the likelihood layer is given the
		// raw gradients together with the scale factor.
		const af::array dlogZ_dm_scale = dlogZ_dm * scaleLogZ;
		const af::array dlogZ_dv_scale = dlogZ_dv * scaleLogZ;

		// Collecting gradients: [GP layer parameters | likelihood parameters]
		int iStart = 0;
		int iEnd = gpLayer->GetNumParameters();
		outGradient(af::seq(iStart, iEnd - 1)) = slayer.BackpropGradientsReg(mout, vout, dlogZ_dm_scale, dlogZ_dv_scale, xBatch, nullptr, dAlpha);

		iStart = iEnd;
		iEnd += likLayer->GetNumParameters();
		if (iStart != iEnd)
			outGradient(af::seq(iStart, iEnd - 1)) = likLayer->BackpropagationGradients(mout, vout, dlogZ_dm, dlogZ_dv, dAlpha, scaleLogZ);

		outGradient /= iN;

		return (logZ + sgpContribution) / iN;
	}
}