/**
File:		MachineLearning/GPModels/Models/Layers/EmissionLayers/FgGaussEmissionLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2022 . All rights reserved.
*/


#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgGaussEmissionLayer.h>
#define _USE_MATH_DEFINES
#include <math.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class GaussEmission<float>;
	template class GaussEmission<double>;
	
	template <typename Scalar>
	GaussEmission<Scalar>::GaussEmission(const af::array& y, int outputDim,
		int inputDim) : ILayer<Scalar>(LayerType::Emission, y.dims(0), outputDim),
	afY(y), afU(), afR(), iq(inputDim)
	{
	}
	
	template<typename Scalar>
	GaussEmission<Scalar>::~GaussEmission()
	= default;

	template<typename Scalar>
	GaussEmission<Scalar>::GaussEmission()
		: ILayer<Scalar>(), afY(), afU(), afR(), iq(0)
	{
	}

	template<typename Scalar>
	void GaussEmission<Scalar>::ProbabilisticOutput(const af::array & mf, const af::array & vf, af::array & myOut, af::array & vyOut, Scalar alpha)
	{
		int bN = mf.dims(0);

		af::array UtileL = af::tile(af::moddims(afU, iD, 1, iq), 1, iD, 1);
		af::array UtileR = af::tile(af::moddims(afU, 1, iD, iq), iD, 1, 1);

		myOut = af::matmulNT(afU, mf).T();
		vyOut = af::constant(0.0, iD, iD, bN, m_dType);

		
		for (auto i = 0; i < bN; i++)
		{
			af::array vi = af::tile(af::moddims(vf(i, af::span), 1, 1, iq), iD, iD, 1);
			vyOut(af::span, af::span, i) = af::diag(afR, 0, false) + af::sum(UtileL * vi * UtileR, 2);
		}

		vyOut = af::moddims(af::diag(vyOut), iD, bN).T();
	}

	template<typename Scalar>
	Scalar GaussEmission<Scalar>::ComputeLogZEmission(const af::array& mx, const af::array& vx, const af::array& y, const Scalar scale, std::map<std::string, af::array>* outGradInput, af::array& outHyperGrad, Scalar alpha)
	{
		af::array dL_dmx, dL_dvx, dL_dR, dL_dU;
		Scalar logZ;
		int bN = y.dims(0);

		//af::array C = af::constant(1.0 / (2 * 2), 3, 2, m_dType);
		//af::array R = af::constant(0.1, 3, m_dType);
		//R(1) = 0.01;

		//C(0, 0) = 0.2;
		//C(1, 1) = 0.3;

		//af::array vxt = af::constant(0.1, 200, 2, m_dType);
		//af::array mxt = af::constant(0.2, 200, 2, m_dType);
		//af::array yt = af::constant(0.4, 200, 3, m_dType);
		//af::array tmp = af::constant(0.0, 3, 3, 200, m_dType);
		//af::array tmpinv = af::constant(0.0, 3, 3, 200, m_dType);
		//af::array UtmpL = af::tile(af::moddims(C, 3, 1, 2), 1, 3, 1);
		//af::array UtmpR = af::tile(af::moddims(C, 1, 3, 2), 3, 1, 1);
		//af::array VinvYt = af::constant(0.0, 200, 3, m_dType);
		//af::array Ydifft = yt - matmulNT(C, mxt).T();
		//for (auto i = 0; i < vxt.dims(0); i++)
		//{
		//	af::array vi = af::tile(af::moddims(vxt(i, af::span), 1, 1, 2), 3, 3, 1);
		//	tmp(af::span, af::span, i) = af::diag(R / alpha, 0, false) + af::sum(UtmpL * vi * UtmpR, 2);
		//	tmpinv(af::span, af::span, i) = af::inverse(tmp(af::span, af::span, i));
		//	VinvYt(i, af::span) = CommonUtil<Scalar>::SolveQR(tmp(af::span, af::span, i), Ydifft(i, af::span).T()).T();
		//}

		//af_print(af::moddims(af::diag(tmpinv), 3, bN).T());

		///*af::array UVURt = tmp / tile(R.T(), 3, 1, 200);
		//af::array It = af::tile(af::identity(3, 3, m_dType), 1, 1, 200);
		//af::array ICVCRt = It + alpha * UVURt;
		//Scalar logDet_ = 0.0;
		//for (auto i = 0; i < ICVCRt.dims(2); i++)
		//	logDet_ += log(af::det<Scalar>(ICVCRt(af::span, af::span, i)));

		//Scalar Rlogdet_term_ = -0.5 * bN * alpha * af::sum<Scalar>(af::log(R));*/

		///*af::array dL_dRt = (-0.5 * af::sum(af::diag(tmpinv), 2) + 0.5 * af::sum(af::pow(VinvYt, 2), 0).T()) / alpha;
		//dL_dRt += 0.5 * bN * (1.0 - alpha) / R;
		//dL_dRt *= 2.0 * R;*/
		//af::array dSigma_ = -0.5 * tmpinv + 0.5 * tile(moddims(VinvYt.T(), 3, 1, bN), 1, 3, 1) * tile(moddims(VinvYt.T(), 1, 3, bN), 3, 1, 1);
		//dL_dvx = af::constant(0.0, bN, 2, m_dType);
		//af::array UtmpL_ = af::tile(af::moddims(C, 3, 1, 2), 1, 3, 1);
		//af::array UtmpR_ = af::tile(af::moddims(C, 1, 3, 2), 3, 1, 1);
		//af_print(UtmpL_);
		//for (auto i = 0; i < vx.dims(0); i++)
		//{
		//	//af_print(dL_dvx(i, af::span));
		//	dL_dvx(i, af::span) = moddims(sum(sum(UtmpL * tile(dSigma_(af::span, af::span, i), 1, 1, 2) * UtmpR, 0), 1), 1, 2, 1);
		//}
		//af_print(dL_dvx);
		////np.einsum('nab,da,db->nd', dSigma, C.T, C.T)

		// sum out iq for all N points
		af::array UtileL = af::tile(af::moddims(afU, iD, 1, iq), 1, iD, 1);
		af::array UtileR = af::tile(af::moddims(afU, 1, iD, iq), iD, 1, 1);
		af::array Vy = af::constant(0.0, iD, iD, bN, m_dType);
		af::array Vyinv = af::constant(0.0, iD, iD, bN, m_dType);
		af::array UVU = af::constant(0.0, iD, iD, bN, m_dType);
		af::array VinvY = af::constant(0.0, bN, iD, m_dType);
		af::array Ydiff = y - matmulNT(afU, mx).T();

		/*af_print(y);
		af_print(mx);*/

		for (auto i = 0; i < vx.dims(0); i++)
		{
			af::array vi = af::tile(af::moddims(vx(i, af::span), 1, 1, iq), iD, iD, 1);
			Vy(af::span, af::span, i) = af::diag(afR / alpha, 0, false) + af::sum(UtileL * vi * UtileR, 2);
			Vyinv(af::span, af::span, i) = af::inverse(Vy(af::span, af::span, i));
			UVU(af::span, af::span, i) = af::sum(UtileL * vi * UtileR, 2);
			VinvY(i, af::span) = CommonUtil<Scalar>::SolveQR(Vy(af::span, af::span, i), Ydiff(i, af::span).T()).T();
		}

		/*af_print(Vy);
		af_print(Ydiff);
		af_print(Vyinv);
		af_print(UVU);
		af_print(VinvY);*/

		af::array UVUR = UVU / tile(afR.T(), iD, 1, bN);
		af::array I = af::tile(af::identity(iD, iD, m_dType), 1, 1, bN);
		af::array ICVCR = I + alpha * UVUR;
		Scalar logDet = 0.0;
		for (auto i = 0; i < ICVCR.dims(2); i++)
			logDet += log(af::det<Scalar>(ICVCR(af::span, af::span, i)));
		
		Scalar quad_term = -0.5 * af::sum<Scalar>(af::sum(Ydiff * VinvY));
		Scalar Vlogdet_term = -0.5 * logDet;
		Scalar const_term = -bN * iD * 0.5 * alpha * log(2.0 * M_PI);
		Scalar Rlogdet_term = -0.5 * bN * alpha * af::sum<Scalar>(af::log(afR));

		logZ = const_term + Rlogdet_term + Vlogdet_term + quad_term;

		dL_dR = (-0.5 * af::sum(af::diag(Vyinv), 2) + 0.5 * af::sum(af::pow(VinvY, 2), 0).T()) / alpha;
		dL_dR += 0.5 * bN * (1.0 - alpha) / afR;
		dL_dR *= 2.0 * afR;

		af::array dSigma = -0.5 * Vyinv + 0.5 * af::tile(af::moddims(VinvY.T(), iD, 1, bN), 1, iD, 1) * af::tile(af::moddims(VinvY.T(), 1, iD, bN), iD, 1, 1);
		af::array dmu = VinvY;
		af::array dU1 = af::matmulTN(dmu, mx);
		af::array dU2 = 2.0 * af::sum(af::tile(af::moddims(vx.T(), 1, iq, bN), iD, 1, 1) * af::matmul(dSigma, afU), 2);
		dL_dU = dU1 + dU2;

		dL_dmx = matmul(dmu, afU);
		dL_dvx = af::constant(0.0, bN, iq, m_dType);
		for (auto i = 0; i < dL_dvx.dims(0); i++)
		{
			af::array dSigmai = af::tile(dSigma(af::span, af::span, i), 1, 1, iq);
			dL_dvx(i, af::span) = af::moddims(af::sum(af::sum(UtileL * dSigmai * UtileR, 0), 1), 1, iq, 1);
		}

		// Collecting gradients
		int iStart = 0, iEnd =  iD * iq;
		outHyperGrad = af::constant(0.0, GetNumParameters(), m_dType);
		outHyperGrad(af::seq(iStart, iEnd - 1)) = af::flat(dL_dU) * scale;

		iStart = iEnd; iEnd += iD;
		outHyperGrad(af::seq(iStart, iEnd - 1)) = dL_dR * scale;

		if (outGradInput != nullptr)
		{
			outGradInput->clear();
			outGradInput->insert(std::pair<std::string, af::array>("dL_dmx", dL_dmx * scale));
			outGradInput->insert(std::pair<std::string, af::array>("dL_dvx", dL_dvx * scale));
		}

		return logZ * scale;
	}

	template<typename Scalar>
	int GaussEmission<Scalar>::GetNumParameters()
	{
		int numParam = iD * iq;
		numParam += iD;
		return numParam;
	}
	
	template<typename Scalar>
	void GaussEmission<Scalar>::SetParameters(const af::array& param)
	{
		int istart = 0, iend = iD * iq;

		afU = af::moddims(param(af::seq(istart, iend - 1)), iD, iq);

		istart = iend; iend += iD;
		afR = af::exp(2.0 * param(af::seq(istart, iend - 1)));
	}
	
	template<typename Scalar>
	af::array GaussEmission<Scalar>::GetParameters()
	{
		m_dType = CommonUtil<Scalar>::CheckDType();
		af::array param = af::constant(0.0f, GetNumParameters(), (m_dType));

		int istart = 0, iend = iD * iq;
		param(af::seq(istart, iend - 1)) = af::flat(afU);

		istart = iend; iend += iD;
		param(af::seq(istart, iend - 1)) = 0.5 * af::log(afR);

		return param;
	}
	
	template<typename Scalar>
	void GaussEmission<Scalar>::UpdateParameters()
	{
	}
	
	template<typename Scalar>
	void GaussEmission<Scalar>::InitParameters()
	{
		afU = af::constant(1.0 / (iD * iq), iD, iq, m_dType);
        afR = af::constant(exp(2.0 * log(0.01)), iD, m_dType);
	}
}