/**
File:		MachineLearning/GPModels/SparseGPModels/PowerEP/FgPEPSparseGPLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgPEPSparseGPLayer.h>

namespace NeuralEngine::MachineLearning::GPModels::PowerEP
{
	// Explicit instantiations of SGPLayer for float and double, emitted from this translation unit.
	// NOTE(review): these precede the member definitions below. Some compilers only instantiate the
	// members that are already defined at this point — confirm on the target toolchain, or move the
	// two lines to the end of the namespace.
	template class SGPLayer<float>;
	template class SGPLayer<double>;

	/**
	 * Constructs the layer by forwarding every argument, unchanged and in order, to the
	 * SparseGPBaseLayer<Scalar> constructor. No further initialisation happens here.
	 *
	 * @param numPoints   forwarded to the base class (presumably the number of data points — confirm in the header).
	 * @param numPseudos  forwarded to the base class (presumably the number of pseudo/inducing points — confirm).
	 * @param outputDim   forwarded to the base class.
	 * @param inputDim    forwarded to the base class.
	 */
	template<typename Scalar>
	SGPLayer<Scalar>::SGPLayer(int numPoints, int numPseudos, int outputDim, int inputDim)
		: SparseGPBaseLayer<Scalar>(numPoints, numPseudos, outputDim, inputDim)
	{
		// Disabled: explicit constant fill of the site parameters T1 / T2.
		/*T1 = af::constant(0.1f, iN, iD, ik, m_dType);
		T2 = af::constant(0.2f, iN, iD, ik, ik, m_dType);*/
	}

	/** Destructor. Performs no work of its own; members and the base class clean up after themselves. */
	template<typename Scalar>
	SGPLayer<Scalar>::~SGPLayer() = default;

	/**
	 * Cavity forward prediction entry point. Dispatches on whether an input variance is supplied.
	 *
	 * @param mout   output: filled by the selected implementation.
	 * @param vout   output: filled by the selected implementation.
	 * @param n      forwarded unchanged (used as the sample index array by the deterministic path).
	 * @param mx     input means, forwarded unchanged.
	 * @param vx     optional input variances; when null the deterministic path is taken,
	 *               otherwise the pointee is forwarded to the random-input path.
	 * @param alpha  Power-EP fraction, forwarded unchanged.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionCavity(af::array& mout, af::array& vout, af::array& n, af::array& mx, af::array* vx, Scalar alpha)
	{
		const bool hasInputVariance = (vx != nullptr);
		if (hasInputVariance)
		{
			ForwardPredictionRandomCavity(mout, vout, n, mx, *vx, alpha);
			return;
		}

		ForwardPredictionDeterministicCavity(mout, vout, n, mx, alpha);
	}

	/*void SGPLayer::ForwardPredictionPost(af::array& mout, af::array& vout, af::array* mx, af::array* vx = nullptr)
	{
	}*/

	/**
	 * Back-propagates the gradients of logZ w.r.t. the predictive mean/variance to the cavity
	 * parameters and stores them in grad_cav.
	 *
	 * With kfu = K(x, Xu) (numSamples x ik) the function computes, per sample n and output d:
	 *   dGammaHat[n,d,:]   = dlogZ_dm[n,d] * kfu[n,:]
	 *   dBetaHat[n,d,:,:]  = dlogZ_dv[n,d] * kfu[n,:]^T kfu[n,:]
	 *   dvcav[n,d,:,:]     = Kuu^-1 * dBetaHat[n,d] * Kuu^-1
	 *   dmcav[n,d,:]       = Kuu^-1 * dGammaHat[n,d,:]
	 *
	 * @param m          not used by this implementation.
	 * @param v          not used by this implementation.
	 * @param dlogZ_dm   gradient w.r.t. the predictive mean; indexed as (sample, output).
	 * @param dlogZ_dv   gradient w.r.t. the predictive variance; indexed as (sample, output).
	 * @param x          inputs; x.dims(0) is taken as the number of samples.
	 * @param grad_hyper not used by this implementation.
	 * @param grad_cav   output map; entries "dmcav" and "dvcav" are written (overwritten if present).
	 * @param alpha      not used by this implementation.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::BackpropGradientsReg(af::array& m, af::array& v, af::array& dlogZ_dm, af::array& dlogZ_dv, af::array& x, std::map<std::string, af::array>& grad_hyper, std::map<std::string, af::array>& grad_cav, Scalar alpha)
	{
		const int numSamples = static_cast<int>(x.dims(0));

		af::array kfu(m_dType);
		kernel->ComputeKernelMatrix(x, afXu, kfu);

		// gradient wrt GammaHat: broadcast dlogZ_dm over the pseudo-point axis and kfu over the output axis
		af::array dGammaHat = af::tile(dlogZ_dm, 1, 1, ik) * af::tile(af::moddims(kfu, numSamples, 1, ik), 1, iD, 1);

		// Allocate with the layer dtype so the double instantiation is not silently truncated to f32.
		af::array dvcav = af::constant(0.0f, numSamples, iD, ik, ik, m_dType);
		for (int s = 0; s < numSamples; s++)
		{
			// kfu[s,:]^T kfu[s,:] does not depend on the output index, so build it once per sample.
			af::array kfuRow = kfu(s, af::span);
			af::array kufKfu = af::matmulTN(kfuRow, kfuRow);

			for (int d = 0; d < iD; d++)
			{
				// gradient wrt BetaHat for this (sample, output) pair
				af::array dBetaHat = af::tile(dlogZ_dv(s, d), ik, ik) * kufKfu;

				dvcav(s, d, af::span, af::span) = af::moddims(af::matmul(afInvKuu, af::matmul(dBetaHat, afInvKuu)), 1, 1, ik, ik);
			}
		}

		// Kuu^-1 * dGammaHat, written as a broadcast multiply followed by a reduction over the last axis
		af::array dmcav = af::sum(af::tile(af::moddims(afInvKuu, 1, 1, ik, ik), numSamples, iD) * af::tile(af::moddims(dGammaHat, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);

		// insert() would keep a stale entry when the caller reuses the map; always store the fresh gradients.
		grad_cav.insert_or_assign("dmcav", dmcav);
		grad_cav.insert_or_assign("dvcav", dvcav);
	}

	/**
	 * Power-EP site update for the samples selected by n.
	 *
	 * From the cavity quantities (afMuHat, afSuHat, afInvSuHat, afInvSuMuHat) and the cavity
	 * gradients it forms, per sample and output dimension:
	 *   munew    = muHat + SuHat * dmcav
	 *   Sunew    = SuHat - SuHat * (dmcav dmcav^T - 2 dvcav) * SuHat
	 *   Suinvnew = Sunew^-1,  SuinvMunew = Suinvnew * munew
	 * The fractional site update is (Suinvnew - afInvSuHat, SuinvMunew - afInvSuMuHat) and the
	 * new site is (1 - alpha) * old + fraction.
	 *
	 * When n selects a single sample the site and the posterior (afMu, afSu, afInvSuMu, afInvSu)
	 * are overwritten directly. Otherwise the sites are blended with the old ones using decay
	 * and the posterior is recomputed via UpdateParameters().
	 *
	 * @param n         index array selecting rows of T1 / T2; n.dims(0) is the number of samples.
	 * @param grad_cav  map holding "dmcav" and "dvcav" (as produced by BackpropGradientsReg).
	 * @param alpha     Power-EP fraction.
	 * @param decay     damping factor; only used when more than one sample is updated.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::UpdateFactor(af::array & n, std::map<std::string, af::array> grad_cav, Scalar alpha, Scalar decay)
	{
		const int numSamples = static_cast<int>(n.dims(0));

		af::array dmcav = grad_cav["dmcav"];
		af::array dvcav = grad_cav["dvcav"];

		// perform PowerEP update
		af::array munew = afMuHat + af::sum(afSuHat * af::tile(af::moddims(dmcav, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);

		// Allocate with the layer dtype so the double instantiation is not silently truncated to f32.
		af::array Sunew = af::constant(0.0f, numSamples, iD, ik, ik, m_dType);
		af::array Suinvnew = af::constant(0.0f, numSamples, iD, ik, ik, m_dType);

		// The loop index is deliberately not called n: that name is the index-array parameter.
		for (int s = 0; s < numSamples; s++)
		{
			for (int d = 0; d < iD; d++)
			{
				af::array dm = af::moddims(dmcav(s, d, af::span), 1, ik);
				af::array SuHat = af::moddims(afSuHat(s, d, af::span, af::span), ik, ik, 1, 1);

				// dmcav dmcav^T - 2 dvcav
				af::array inner = af::matmulTN(dm, dm) - 2 * af::moddims(dvcav(s, d, af::span, af::span), ik, ik, 1, 1);

				// SuHat - SuHat * inner * SuHat
				af::array SuNewSd = SuHat - af::matmul(SuHat, af::matmul(inner, SuHat));

				Sunew(s, d, af::span, af::span) = af::moddims(SuNewSd, 1, 1, ik, ik);
				Suinvnew(s, d, af::span, af::span) = af::moddims(af::inverse(SuNewSd), 1, 1, ik, ik);
			}
		}

		af::array SuinvMunew = af::sum(Suinvnew * af::tile(af::moddims(munew, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);

		// fractional site = new posterior natural parameters minus cavity natural parameters
		af::array t2_frac = Suinvnew - afInvSuHat;
		af::array t1_frac = SuinvMunew - afInvSuMuHat;
		af::array t1_old = T1(n, af::span, af::span);
		af::array t2_old = T2(n, af::span, af::span, af::span);
		af::array t1_new = (static_cast<Scalar>(1) - alpha) * t1_old + t1_frac;
		af::array t2_new = (static_cast<Scalar>(1) - alpha) * t2_old + t2_frac;

		if (numSamples == 1)
		{
			// sequential update: the new posterior is already known, store it directly
			T1(n, af::span, af::span) = t1_new;
			T2(n, af::span, af::span, af::span) = t2_new;

			afMu = munew(0, af::span, af::span);
			afSu = Sunew(0, af::span, af::span, af::span);

			afInvSuMu = SuinvMunew(0, af::span, af::span);
			afInvSu = Suinvnew(0, af::span, af::span, af::span);
		}
		else
		{
			// parallel update: damp the sites, then rebuild the posterior from all sites
			T1(n, af::span, af::span) = decay * t1_old + (static_cast<Scalar>(1) - decay) * t1_new;
			T2(n, af::span, af::span, af::span) = decay * t2_old + (static_cast<Scalar>(1) - decay) * t2_new;
			UpdateParameters();
		}
	}

	/**
	 * Cavity forward prediction for deterministic inputs.
	 *
	 * First refreshes the cavity members (afMuHat, afSuHat, afInvSuMuHat, afInvSuHat) through
	 * ComputeCavity(), then with kfu = K(mx, Xu) and kff = diag K(mx, mx) computes, per sample n
	 * and output d:
	 *   gammaHat[n,d,:]  = Kuu^-1 * muHat[n,d,:]
	 *   betaHat[n,d]     = Kuu^-1 * SuHat[n,d] * Kuu^-1 - Kuu^-1
	 *   mout[n,d]        = kfu[n,:] . gammaHat[n,d,:]
	 *   vout[n,d]        = kff[n] + kfu[n,:] * betaHat[n,d] * kfu[n,:]^T
	 *
	 * @param mout   output: predictive means, one per (sample, output).
	 * @param vout   output: predictive variances, one per (sample, output).
	 * @param idx    index array selecting the site rows; idx.dims(0) is the number of samples.
	 * @param mx     input locations passed to the kernel.
	 * @param alpha  Power-EP fraction, forwarded to ComputeCavity().
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionDeterministicCavity(af::array& mout, af::array& vout, af::array& idx, af::array& mx, Scalar alpha)
	{
		const int numSamples = static_cast<int>(idx.dims(0));
		af::array kff(m_dType), kfu(m_dType);

		ComputeCavity(idx, afMuHat, afSuHat, afInvSuMuHat, afInvSuHat, alpha);

		// Kuu^-1 * muHat as a broadcast multiply followed by a reduction over the last axis
		af::array gammaHat = af::sum(af::tile(af::moddims(afInvKuu, 1, 1, ik, ik), numSamples, iD) * af::tile(af::moddims(afMuHat, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);

		// Allocate with the layer dtype so the double instantiation is not silently truncated to f32.
		af::array betaHat = af::constant(0.0f, numSamples, iD, ik, ik, m_dType);
		for (int s = 0; s < numSamples; s++)
			for (int d = 0; d < iD; d++)
				betaHat(s, d, af::span, af::span) = af::moddims(af::matmul(afInvKuu, af::matmul(af::moddims(afSuHat(s, d, af::span, af::span), ik, ik, 1, 1), afInvKuu)), 1, 1, ik, ik);

		betaHat = betaHat - af::tile(af::moddims(afInvKuu, 1, 1, ik, ik), numSamples, iD, 1, 1);

		//kff = af::constant(1.0f, numSamples);
		kernel->ComputeDiagonal(mx, kff);
		kernel->ComputeKernelMatrix(mx, afXu, kfu);

		mout = af::sum(af::tile(af::moddims(kfu, numSamples, 1, ik), 1, iD, 1) * gammaHat, 2);

		af::array kfuBetaHatkuf = af::constant(0.0f, numSamples, iD, m_dType);
		for (int s = 0; s < numSamples; s++)
		{
			// the kernel row depends only on the sample, not on the output dimension
			af::array kfuRow = kfu(s, af::span);
			for (int d = 0; d < iD; d++)
				kfuBetaHatkuf(s, d) = af::matmulNT(af::matmul(kfuRow, af::moddims(betaHat(s, d, af::span, af::span), ik, ik, 1, 1)), kfuRow);
		}

		vout = af::tile(kff, 1, iD) + kfuBetaHatkuf;
	}

	/**
	 * Cavity forward prediction for inputs with uncertainty (mean mx, variance vx).
	 *
	 * NOTE(review): not implemented — the body is empty, so mout and vout are left untouched.
	 * ForwardPredictionCavity() routes here whenever a non-null vx is supplied; callers on that
	 * path currently receive whatever mout / vout held before the call.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionRandomCavity(af::array& mout, af::array& vout, af::array& idx, af::array& mx, af::array& vx, Scalar alpha)
	{
	}

	/**
	 * Computes the cavity distribution for the samples selected by idx by removing a fraction
	 * alpha of their sites from the posterior natural parameters:
	 *   T1uHat = afInvSuMu - alpha * T1[idx]
	 *   T2uHat = afInvSu   - alpha * T2[idx]
	 *   Suhat  = T2uHat^-1            (per sample and output dimension)
	 *   muhat  = Suhat * T1uHat
	 *
	 * @param idx     index array selecting rows of T1 / T2; idx.dims(0) is the number of samples.
	 * @param muhat   output: cavity means.
	 * @param Suhat   output: cavity covariances, allocated here as (numSamples, iD, ik, ik).
	 * @param T1uHat  output: cavity first natural parameter.
	 * @param T2uHat  output: cavity second natural parameter.
	 * @param alpha   Power-EP fraction of the site to remove.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::ComputeCavity(af::array& idx, af::array& muhat, af::array& Suhat, af::array& T1uHat, af::array& T2uHat, Scalar alpha)
	{
		const int numSamples = static_cast<int>(idx.dims(0));
		af::array t1n = T1(idx, af::span, af::span);
		af::array t2n = T2(idx, af::span, af::span, af::span);

		// broadcast the single posterior over the selected samples before subtracting the site fraction
		T1uHat = af::tile(af::moddims(afInvSuMu, 1, iD, ik), t1n.dims(0), 1, 1) - alpha * t1n;
		T2uHat = af::tile(af::moddims(afInvSu, 1, iD, ik, ik), t1n.dims(0), 1, 1, 1) - alpha * t2n;

		// Allocate with the layer dtype so the double instantiation is not silently truncated to f32.
		Suhat = af::constant(0.0f, numSamples, iD, ik, ik, m_dType);
		const int numRows = static_cast<int>(T2uHat.dims(0));
		for (int nn = 0; nn < numRows; nn++)
			for (int d = 0; d < iD; d++)
				Suhat(nn, d, af::span, af::span) = af::moddims(af::inverse(af::moddims(T2uHat(nn, d, af::span, af::span), ik, ik, 1, 1)), 1, 1, ik, ik);

		// Suhat * T1uHat as a broadcast multiply followed by a reduction over the last axis
		muhat = af::sum(Suhat * af::tile(af::moddims(T1uHat, numSamples, iD, 1, ik), 1, 1, ik, 1), 3);
	}

	/**
	 * Rebuilds the posterior over the pseudo-point values from all sites, after letting the
	 * base class refresh its own quantities:
	 *   afInvSu   = Kuu^-1 + sum_n T2[n]
	 *   afInvSuMu = sum_n T1[n]
	 *   afSu      = afInvSu^-1          (per output dimension)
	 *   afMu      = afSu * afInvSuMu
	 *
	 * afSu is written slice by slice, so it is assumed to be already allocated with shape
	 * (1, iD, ik, ik) — presumably by the base class; confirm.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::UpdateParameters()
	{
		// Name the base with its template argument, as the constructor's initializer list does.
		SparseGPBaseLayer<Scalar>::UpdateParameters();

		afInvSu = af::tile(af::moddims(afInvKuu, 1, 1, ik, ik), 1, iD, 1, 1) + af::sum(T2, 0);
		afInvSuMu = af::sum(T1, 0);

		for (int i = 0; i < iD; i++)
			afSu(0, i, af::span, af::span) = af::moddims(af::inverse(af::moddims(afInvSu(0, i, af::span, af::span), ik, ik, 1)), 1, 1, ik, ik);

		// The vector is laid out along axis 3 and tiled along axis 2, so the matrix-vector product
		// must reduce over axis 3 (as in ComputeCavity / UpdateFactor). Reducing over axis 2 would
		// instead scale each entry by a column sum of afSu and yield a (1, iD, 1, ik) result.
		afMu = af::sum(afSu * af::tile(af::moddims(afInvSuMu, 1, iD, 1, ik), 1, 1, ik, 1), 3);
	}

	/**
	 * Initialises the pseudo-input locations afXu and then calls UpdateParameters().
	 *
	 * @param X  optional data; when null, afXu is filled from an af::seq and reshaped to (ik, iq).
	 *           When non-null nothing is initialised here (the branch is empty).
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::InitParameters(af::array* X)
	{
		if (X == nullptr)
		{
			// NOTE(review): af::seq's two-argument form is (begin, end), so this builds a sequence from
			// ik to the numeric value of m_dType rather than "ik values of type m_dType" — confirm intent.
			afXu = af::seq(ik, (m_dType));
			// NOTE(review): reshaping to (ik, iq) needs exactly ik * iq elements — verify for iq > 1.
			afXu = af::moddims(afXu, ik, iq);
		}
		else
		{
			// NOTE(review): initialisation of afXu from X is not implemented; afXu keeps its previous value.

		}

		UpdateParameters();
	}

	/**
	 * Posterior forward prediction for deterministic inputs.
	 *
	 * NOTE(review): not implemented — the body is empty, so *mout and *vout are never written.
	 */
	template<typename Scalar>
	void SGPLayer<Scalar>::ForwardPredictionDeterministicPost(af::array & mx, af::array * mout, af::array * vout)
	{
	}
}