/**
File:		MachineLearning/GPModels/Models/GPModels/FgAEPSparseDGPR.cpp

Author:		
Email:		
Site:       

Copyright (c) 2018 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgAEPSparseDGPR.h>
#include <MachineLearning/FgAEPSparseGPLayer.h>

namespace NeuralEngine::MachineLearning::GPModels::AEP
{
	// Explicit instantiation definitions: SDGPR is emitted from this translation
	// unit for the float and double scalar types.
	// NOTE(review): these appear before the member function definitions below.
	// The C++ standard only requires members already defined at this point to be
	// instantiated, so some non-MSVC compilers may leave the members undefined —
	// confirm, and consider moving these to the end of the file if needed.
	template class SDGPR<float>;
	template class SDGPR<double>;

	/// Constructs the model from training data and one hidden-layer description.
	///
	/// Y, X, hiddenLayerdescription and lType are handed unchanged to the
	/// SparseDeepGPBaseModel constructor and alpha is stored in dAlpha. Afterwards
	/// one SGPLayer is appended to gpLayer for each of the iNumLayer layers.
	template<typename Scalar>
	SDGPR<Scalar>::SDGPR(const af::array & Y, const af::array & X, HiddenLayerDescription hiddenLayerdescription, Scalar alpha, LogLikType lType)
		: SparseDeepGPBaseModel<Scalar>(Y, X, hiddenLayerdescription, lType), dAlpha(alpha)
	{
		int layer = 0;
		while (layer < iNumLayer)
		{
			// Layer `layer` is built from iN, its pseudo-point count and the two
			// neighbouring entries of vSize (presumably output size vSize[layer + 1]
			// and input size vSize[layer] — confirm against the SGPLayer constructor).
			gpLayer.push_back(new SGPLayer<Scalar>(iN, vNumPseudosPerLayer[layer], vSize[layer + 1], vSize[layer]));
			++layer;
		}
	}

	/// Constructs the model from training data and a list of hidden-layer
	/// descriptions.
	///
	/// Y, X, hiddenLayerdescriptions and lType are handed unchanged to the
	/// SparseDeepGPBaseModel constructor and alpha is stored in dAlpha. Afterwards
	/// one SGPLayer is appended to gpLayer for each of the iNumLayer layers.
	template<typename Scalar>
	SDGPR<Scalar>::SDGPR(const af::array & Y, const af::array & X, std::vector<HiddenLayerDescription> hiddenLayerdescriptions, Scalar alpha, LogLikType lType)
		: SparseDeepGPBaseModel<Scalar>(Y, X, hiddenLayerdescriptions, lType), dAlpha(alpha)
	{
		int layer = 0;
		while (layer < iNumLayer)
		{
			// Layer `layer` is built from iN, its pseudo-point count and the two
			// neighbouring entries of vSize (presumably output size vSize[layer + 1]
			// and input size vSize[layer] — confirm against the SGPLayer constructor).
			gpLayer.push_back(new SGPLayer<Scalar>(iN, vNumPseudosPerLayer[layer], vSize[layer + 1], vSize[layer]));
			++layer;
		}
	}

	/// Default constructor: default-constructs the base model and sets dAlpha to
	/// zero. No layers are created here.
	template<typename Scalar>
	SDGPR<Scalar>::SDGPR()
		: SparseDeepGPBaseModel<Scalar>()
		, dAlpha(static_cast<Scalar>(0.0))
	{
	}

	/// Destructor. This class performs no clean-up of its own here.
	template<typename Scalar>
	SDGPR<Scalar>::~SDGPR() = default;

	/// Evaluates the training objective and its gradient at the parameter vector x.
	///
	/// Steps performed:
	///  1. x is written into the model via SetParameters.
	///  2. The full data set is used when iBatchSize >= iN (iBatchSize is then
	///     clamped to iN); otherwise iBatchSize rows are drawn uniformly at random,
	///     with replacement, from afX / afY.
	///  3. The batch is propagated through every sparse GP layer with
	///     ForwardPredictionCavity, accumulating each layer's ComputePhi term.
	///  4. The likelihood layer's log normaliser is scaled by -iN / iBatchSize / dAlpha.
	///  5. Gradients are back-propagated from the last layer to the first and
	///     written into outGradient, followed by the likelihood-layer gradients.
	///
	/// @param x            Parameter vector applied to the model.
	/// @param outGradient  Overwritten with the gradient (GetNumParameters()
	///                     entries), divided by iN.
	/// @return (scaled log normaliser + sum of layer Phi terms) / iN.
	///
	/// NOTE(review): dAlpha == 0 (the value set by the default constructor) makes
	/// the scale factor divide by zero — confirm callers always set a non-zero alpha.
	template<typename Scalar>
	Scalar SDGPR<Scalar>::Function(const af::array & x, af::array& outGradient)
	{
		SetParameters(x);

		af::array xBatch(m_dType), yBatch(m_dType);

		std::map<std::string, af::array> GradInput;

		if (iBatchSize >= iN)
		{
			// Full-batch evaluation.
			iBatchSize = iN;
			xBatch = afX;
			yBatch = afY;
		}
		else
		{
			// Seed the ArrayFire generator only once. time() has one-second
			// resolution, so reseeding on every call made all evaluations within
			// the same second draw the identical mini-batch.
			static const bool rngSeeded = []
			{
				af::setSeed(static_cast<unsigned long long>(time(nullptr)));
				return true;
			}();
			(void)rngSeeded;

			// Draw iBatchSize row indices in [0, iN - 1]. floor() keeps the draw
			// uniform over the rows; the clamp guards against a product that
			// reaches iN (round() used to yield iN, one past the last row).
			// The bound is passed as a double on purpose: an int second argument
			// would select the af::min reduction-along-dimension overload.
			const af::array idx = af::min(af::floor(af::randu(iBatchSize) * iN), static_cast<double>(iN - 1));
			xBatch = afX(idx, af::span);
			yBatch = afY(idx, af::span);
		}

		// Scale applied to the log normaliser and to its gradients.
		const Scalar scaleLogZ = -iN * 1.0 / iBatchSize / dAlpha;

		// Propagate the cavity forward through the layers and collect each
		// layer's contribution to the objective.
		Scalar sgpContribution = 0;
		std::vector<af::array> mout, vout, psi1, psi2;
		af::array mi(m_dType), vi(m_dType), psi1i(m_dType), psi2i(m_dType);
		for (int i = 0; i < iNumLayer; i++)
		{
			SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer[i]);
			if (i == 0)
			{
				// First layer: fed with the batch inputs, no input variance and
				// no psi2 output are passed.
				slayer.ForwardPredictionCavity(mi, vi, &psi1i, nullptr, xBatch, nullptr, dAlpha);
				mout.push_back(mi);
				vout.push_back(vi);
				psi1.push_back(psi1i);
				psi2.push_back(af::array());
			}
			else
			{
				// Later layers: fed with the previous layer's mean and variance.
				slayer.ForwardPredictionCavity(mi, vi, &psi1i, &psi2i, mout[i - 1], &vout[i - 1], dAlpha);
				mout.push_back(mi);
				vout.push_back(vi);
				psi1.push_back(psi1i);
				psi2.push_back(psi2i);
			}
			sgpContribution += slayer.ComputePhi(dAlpha);
		}

		// Log normaliser of the likelihood layer on the last layer's output.
		const Scalar logZ = likLayer->ComputeLogZ(mout.back(), vout.back(), yBatch, dAlpha) * scaleLogZ;

		/// Gradient computation
		outGradient = af::constant(0.0f, GetNumParameters(), (m_dType));
		af::array dlogZ_dm(m_dType), dlogZ_dv(m_dType);

		likLayer->ComputeLogZGradients(mout.back(), vout.back(), yBatch, &dlogZ_dm, &dlogZ_dv, nullptr, dAlpha);
		af::array dlogZ_dmi = dlogZ_dm * scaleLogZ;
		af::array dlogZ_dvi = dlogZ_dv * scaleLogZ;

		// Backpropagation. Layers are visited from last to first, so the last
		// layer's gradient occupies the first slots of outGradient.
		// NOTE(review): this layout presumably matches the ordering used by
		// SetParameters — confirm against the base model.
		int iStart = 0, iEnd = 0;
		for (int i = iNumLayer - 1; i >= 0; i--)
		{
			SGPLayer<Scalar>& slayer = dynamic_cast<SGPLayer<Scalar>&>(*gpLayer[i]);

			iEnd += slayer.GetNumParameters();
			if (i == 0)
				outGradient(af::seq(iStart, iEnd - 1)) = slayer.BackpropGradientsReg(mout[i], vout[i], dlogZ_dmi, dlogZ_dvi, xBatch, nullptr, dAlpha);
			else
			{
				outGradient(af::seq(iStart, iEnd - 1)) = slayer.BackpropGradientsMM(mout[i], vout[i], dlogZ_dmi, dlogZ_dvi, psi1[i], psi2[i], mout[i - 1], vout[i - 1], &GradInput, dAlpha);
				// Gradients w.r.t. this layer's input become the upstream
				// gradients for the layer below.
				dlogZ_dmi = GradInput["dL_dmx"];
				dlogZ_dvi = GradInput["dL_dvx"];
			}
			iStart = iEnd;
		}

		iEnd += likLayer->GetNumParameters();

		// Only written when the likelihood layer reports at least one parameter.
		if (iStart != iEnd)
			outGradient(af::seq(iStart, iEnd - 1)) = likLayer->BackpropagationGradients(mout.back(), vout.back(), dlogZ_dm, dlogZ_dv, dAlpha, scaleLogZ);

		outGradient /= iN;

		return (logZ + sgpContribution) / iN;
	}
}