/**
File:		MachineLearning/GPModels/Models/Layers/FgGaussLikelihoodLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgGaussLikelihoodLayer.h>
#define _USE_MATH_DEFINES
#include <math.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicitly instantiate GaussLikLayer for float and double in this translation unit.
	// NOTE(review): these lines precede the member definitions below. Per [temp.explicit], an explicit
	// instantiation definition of a class only instantiates the members already defined at that point,
	// so a strictly-conforming compiler may not emit the definitions that follow. Consider moving these
	// two lines to the end of the namespace, after the last member definition.
	template class GaussLikLayer<float>;
	template class GaussLikLayer<double>;

	/**
	 * Constructs the Gaussian likelihood layer.
	 * numPoints and outputDim are forwarded unchanged to the LikelihoodBaseLayer constructor
	 * together with LogLikType::Gaussian.
	 * The log noise standard deviation _sn starts at 0, i.e. a noise variance exp(2*_sn) of 1
	 * (an earlier revision initialised it to log(0.01)).
	 */
	template<typename Scalar>
	GaussLikLayer<Scalar>::GaussLikLayer(int numPoints, int outputDim)
		: LikelihoodBaseLayer<Scalar>(LogLikType::Gaussian, numPoints, outputDim)
		, _sn(0.0)
	{}

	/** Destructor: this layer releases nothing itself; members and the base class clean up on their own. */
	template<typename Scalar>
	GaussLikLayer<Scalar>::~GaussLikLayer() = default;

	/**
	 * Log normaliser of the Gaussian likelihood, summed over all elements, with optional gradients.
	 *
	 * The effective variance of every prediction is snVout = vout + exp(2*_sn)/alpha.
	 *  - If y and mout have the same number of rows:
	 *      logZ = sum log N(y | mout, snVout)
	 *             + y.dims(0)*iD*(0.5*log(2*pi*sn2/alpha) - 0.5*alpha*log(2*pi*sn2)).
	 *  - Otherwise mout holds several candidate rows per target. Either kStack (=5) consecutive rows
	 *    per row of y are stacked in dim 0, or mout already has a leading candidate dimension.
	 *    Each candidate's log density, including the constant above, is combined by a numerically
	 *    stable log-mean-exp over the candidate dimension.
	 * NaN log densities are replaced by 0 before any reduction.
	 *
	 * @param mout, vout  predictive means / variances (expected to share a shape).
	 * @param y           targets.
	 * @param alpha       divides the noise variance; presumably the power-EP alpha — TODO confirm.
	 * @param dlogZ_dm, dlogZ_dv, dlogZ_dm2  optional outputs (nullptr to skip), returned in mout's shape.
	 *        NOTE(review): with several candidates, dm2 is the responsibility-weighted per-candidate
	 *        curvature. It is not the exact second derivative of the log-mean-exp.
	 * @return the summed log normaliser.
	 * Side effect: bDimMod is set to whether this call used the stacked (kStack rows per target) layout.
	 */
	template<typename Scalar>
	Scalar GaussLikLayer<Scalar>::ComputeLogZ(const af::array & mout, const af::array& vout, const af::array & y, Scalar alpha, af::array* dlogZ_dm, af::array* dlogZ_dv, af::array* dlogZ_dm2)
	{
		// Rows of mout per row of y in the stacked layout (the meaning of 5 is not visible here).
		constexpr int kStack = 5;

		Scalar logZ = 0.0;
		Scalar sn2 = exp(2.0 * _sn);

		// Decide the layout from this call's shapes only. The member flag was previously set but never
		// cleared, so a stale `true` could trigger a wrong reshape of the gradients on later calls.
		const bool stacked = y.dims(0) != mout.dims(0) && mout.dims(0) == kStack * y.dims(0);
		bDimMod = stacked;

		af::array moutTmp = mout;
		af::array snVout;
		if (stacked)
		{
			// (kStack*N, D) -> (kStack, N, D): the kStack consecutive rows of each target move to dim 0.
			moutTmp = af::moddims(mout, kStack, mout.dims(0) / kStack, mout.dims(1));
			snVout = af::moddims(vout, kStack, vout.dims(0) / kStack, vout.dims(1)) + af::constant(sn2 / alpha, moutTmp.dims(), m_dType);
		}
		else
		{
			snVout = vout + af::constant(sn2 / alpha, vout.dims(), m_dType);
		}

		if (y.dims(0) == mout.dims(0))
		{
			const af::array diff = y - mout;
			const af::array diff2 = diff * diff;

			af::array logZtmp = -0.5 * (af::log(2.0 * M_PI * snVout) + diff2 / snVout);
			logZtmp(af::isNaN(logZtmp)) = 0.0;
			logZ = af::sum<Scalar>(af::sum(logZtmp));
			// alpha/sn2-dependent constant, added once per target element
			// (y.dims(0) * iD elements; presumably iD == y.dims(1)).
			logZ += (y.dims(0) * iD * (0.5 * log(2 * M_PI * sn2 / alpha) - 0.5 * alpha * log(2 * M_PI * sn2)));

			if (dlogZ_dm != nullptr)
				*dlogZ_dm = diff / snVout;
			if (dlogZ_dv != nullptr)
				*dlogZ_dv = -0.5 / snVout + 0.5 * diff2 / (snVout * snVout);
			if (dlogZ_dm2 != nullptr)
				*dlogZ_dm2 = -1.0f / snVout;
		}
		else
		{
			const dim_t numCand = moutTmp.dims(0);
			// Replicate every target across the candidate dimension: (1, N, D) -> (numCand, N, D).
			const af::array tileY = af::tile(af::moddims(y, 1, y.dims(0), y.dims(1)), numCand);
			const af::array diff = tileY - moutTmp;
			const af::array diff2 = diff * diff;

			af::array logZtmp = -0.5 * (af::log(2.0 * M_PI * snVout) + diff2 / snVout);
			logZtmp(af::isNaN(logZtmp)) = 0.0;
			logZtmp += (0.5 * log(2.0 * M_PI * sn2 / alpha) - 0.5 * alpha * log(2.0 * M_PI * sn2));

			// Stable log-mean-exp over dim 0: max + log(sum(exp(x - max))) - log(numCand).
			const af::array logZ_max = af::max(logZtmp, 0);
			const af::array exp_term = af::exp(logZtmp - af::tile(logZ_max, numCand));
			const af::array sumexp = af::sum(exp_term, 0);
			af::array logZ_lse = logZ_max + af::log(sumexp);
			logZ_lse -= log(static_cast<double>(numCand));
			logZ = af::sum<Scalar>(af::sum(logZ_lse));

			// Responsibility of each candidate (softmax over dim 0) weights its per-candidate gradient.
			const af::array resp = exp_term / af::tile(sumexp, numCand);

			if (dlogZ_dm != nullptr)
				*dlogZ_dm = resp * diff / snVout;
			if (dlogZ_dv != nullptr)
				*dlogZ_dv = resp * (-0.5 / snVout + 0.5 * diff2 / (snVout * snVout));
			if (dlogZ_dm2 != nullptr)
				*dlogZ_dm2 = resp * (-1.0f / snVout);

			if (stacked)
			{
				// Restore mout's (kStack*N, D) layout for every requested gradient, dm2 included.
				if (dlogZ_dm != nullptr)
					*dlogZ_dm = af::moddims(*dlogZ_dm, mout.dims());
				if (dlogZ_dv != nullptr)
					*dlogZ_dv = af::moddims(*dlogZ_dv, mout.dims());
				if (dlogZ_dm2 != nullptr)
					*dlogZ_dm2 = af::moddims(*dlogZ_dm2, mout.dims());
			}
		}
		return logZ;
	}

	/**
	 * Gradients of the Gaussian log normaliser (same quantity as ComputeLogZ) without computing logZ.
	 *
	 * The effective variance is snVout = vout + exp(2*_sn)/alpha.
	 *  - Same number of rows in y and mout: closed-form per-element gradients.
	 *  - Several candidate rows per target (kStack (=5) rows stacked in dim 0, or an explicit leading
	 *    candidate dimension): per-candidate gradients are weighted by each candidate's responsibility
	 *    (softmax of its log density over dim 0).
	 * NaN log densities are zeroed before the softmax, matching ComputeLogZ.
	 *
	 * @param dlogZ_dm, dlogZ_dv, dlogZ_dm2  optional outputs (nullptr to skip), returned in mout's shape.
	 *        NOTE(review): with several candidates, dm2 is the responsibility-weighted per-candidate
	 *        curvature. It is not the exact second derivative of the log-mean-exp.
	 * @param alpha  divides the noise variance; presumably the power-EP alpha — TODO confirm.
	 * Side effect: bDimMod is set to whether this call used the stacked layout.
	 */
	template<typename Scalar>
	void GaussLikLayer<Scalar>::ComputeLogZGradients(const af::array& mout, const af::array& vout, const af::array& y, af::array* dlogZ_dm, af::array* dlogZ_dv, af::array* dlogZ_dm2, Scalar alpha)
	{
		// Rows of mout per row of y in the stacked layout (the meaning of 5 is not visible here).
		constexpr int kStack = 5;

		Scalar sn2 = exp(2.0 * _sn);

		// Decide the layout from this call's shapes only; the member flag was previously never cleared.
		const bool stacked = y.dims(0) != mout.dims(0) && mout.dims(0) == kStack * y.dims(0);
		bDimMod = stacked;

		af::array moutTmp = mout;
		af::array snVout;
		if (stacked)
		{
			// (kStack*N, D) -> (kStack, N, D): the kStack consecutive rows of each target move to dim 0.
			moutTmp = af::moddims(mout, kStack, mout.dims(0) / kStack, mout.dims(1));
			snVout = af::moddims(vout, kStack, vout.dims(0) / kStack, vout.dims(1)) + af::constant(sn2 / alpha, moutTmp.dims(), m_dType);
		}
		else
		{
			snVout = vout + af::constant(sn2 / alpha, vout.dims(), m_dType);
		}

		if (y.dims(0) == mout.dims(0))
		{
			const af::array diff = y - mout;
			if (dlogZ_dm != nullptr)
				*dlogZ_dm = diff / snVout;
			if (dlogZ_dv != nullptr)
				*dlogZ_dv = -0.5 / snVout + 0.5 * (diff * diff) / (snVout * snVout);
			if (dlogZ_dm2 != nullptr)
				*dlogZ_dm2 = -1.0f / snVout;
			return;
		}

		const dim_t numCand = moutTmp.dims(0);
		// Replicate every target across the candidate dimension: (1, N, D) -> (numCand, N, D).
		const af::array tileY = af::tile(af::moddims(y, 1, y.dims(0), y.dims(1)), numCand);
		const af::array diff = tileY - moutTmp;
		const af::array diff2 = diff * diff;

		af::array logZtmp = -0.5 * (af::log(2.0 * M_PI * snVout) + diff2 / snVout);
		logZtmp(af::isNaN(logZtmp)) = 0.0;
		logZtmp += (0.5 * log(2.0 * M_PI * sn2 / alpha) - 0.5 * alpha * log(2.0 * M_PI * sn2));

		// Max-shifted softmax over dim 0 gives each candidate's responsibility.
		const af::array logZ_max = af::max(logZtmp, 0);
		const af::array exp_term = af::exp(logZtmp - af::tile(logZ_max, numCand));
		const af::array sumexp = af::sum(exp_term, 0);
		const af::array resp = exp_term / af::tile(sumexp, numCand);

		if (dlogZ_dm != nullptr)
			*dlogZ_dm = resp * diff / snVout;
		if (dlogZ_dv != nullptr)
			*dlogZ_dv = resp * (-0.5 / snVout + 0.5 * diff2 / (snVout * snVout));
		if (dlogZ_dm2 != nullptr)
			*dlogZ_dm2 = resp * (-1.0f / snVout);

		if (stacked)
		{
			// Restore mout's (kStack*N, D) layout for every requested gradient, dm2 included.
			if (dlogZ_dm != nullptr)
				*dlogZ_dm = af::moddims(*dlogZ_dm, mout.dims());
			if (dlogZ_dv != nullptr)
				*dlogZ_dv = af::moddims(*dlogZ_dv, mout.dims());
			if (dlogZ_dm2 != nullptr)
				*dlogZ_dm2 = af::moddims(*dlogZ_dm2, mout.dims());
		}
	}

	/**
	 * Expected Gaussian log-likelihood E_q[log N(y | f, sn2)] for q(f) = N(mout, vout), summed over
	 * all elements, with sn2 = exp(2*_sn):
	 *     sum( -0.5*log(2*pi*sn2) - 0.5/sn2 * ((y - mout)^2 + vout) ).
	 * The square is computed by multiplication. In ArrayFire, operator^ is bitwise XOR, and it binds
	 * looser than +/-, so the former `y ^ 2 - ... + mout ^ 2` did not compute a square at all.
	 */
	template<typename Scalar>
	Scalar GaussLikLayer<Scalar>::ComputeLogLikExp(const af::array & mout, const af::array & vout, const af::array & y)
	{
		const Scalar sn2 = exp(2.0 * _sn);
		const Scalar term1 = -0.5 * log(2 * M_PI * sn2);
		const af::array diff = y - mout;
		const af::array term2 = -0.5 / sn2 * (diff * diff + vout);
		const Scalar logLikExp = af::sum<Scalar>(term1 + term2);
		return logLikExp;
	}

	/**
	 * Gradients of the expected log-likelihood (see ComputeLogLikExp) w.r.t. mout and vout:
	 *     de_dm = (y - mout) / sn2,   de_dv = -0.5 / sn2 (constant, shaped like vout),
	 * with sn2 = exp(2*_sn). Either output pointer may be nullptr to skip it.
	 */
	template<typename Scalar>
	void GaussLikLayer<Scalar>::ComputeLogLikExpGradients(const af::array & mout, const af::array & vout, const af::array & y, af::array * de_dm, af::array * de_dv)
	{
		const Scalar sn2 = exp(2.0 * _sn);
		if (de_dm != nullptr) *de_dm = 1.0 / sn2 * (y - mout);
		// af::tile needs an af::array; passing the scalar sn2 did not build an array filled with sn2.
		// Build the constant directly instead.
		if (de_dv != nullptr) *de_dv = af::constant(-0.5 / sn2, vout.dims(), m_dType);
	}

	/**
	 * Gradient of the ComputeLogZ objective w.r.t. the log noise std _sn, multiplied by `scale`.
	 *
	 * Two contributions, with sn2 = exp(2*_sn):
	 *  - the per-element constant of ComputeLogZ contributes numSamples * iD * (1 - alpha);
	 *  - sn2/alpha enters every predictive variance, contributing sum(dvout) * 2*sn2/alpha.
	 *
	 * @param mout   only its shape is read, to recover the number of data points (uses bDimMod).
	 * @param dvout  presumably dlogZ/dvout from ComputeLogZ — TODO confirm against callers.
	 * vout and dmout are not read.
	 */
	template<typename Scalar>
	Scalar GaussLikLayer<Scalar>::BackpropagationGradients(const af::array & mout, const af::array & vout, af::array & dmout, af::array & dvout, Scalar alpha, Scalar scale)
	{
		const Scalar sn2 = exp(2.0 * _sn);

		// Data points: rows of mout (5 rows per point in the stacked layout flagged by bDimMod),
		// or its columns when mout is not taller than wide.
		int numSamples = 0;
		if (mout.dims(0) > mout.dims(1))
		{
			numSamples = static_cast<int>(bDimMod ? mout.dims(0) / 5 : mout.dims(0));
		}
		else
		{
			numSamples = static_cast<int>(mout.dims(1));
		}

		const Scalar dsn = (numSamples * iD * (1.0f - alpha) + af::sum<Scalar>(af::sum(dvout)) * 2.0f * sn2 / alpha) * scale;
		return dsn;
	}

	/**
	 * Gradient of the expected log-likelihood (see ComputeLogLikExp) w.r.t. the log noise std _sn,
	 * multiplied by `scale`:
	 *     scale * sum( -1 + ((y - mout)^2 + vout) / sn2 ),   sn2 = exp(2*_sn).
	 * The square is computed by multiplication, because ArrayFire's operator^ is bitwise XOR
	 * (and binds looser than +/-).
	 * dmout and dvout are not read.
	 */
	template<typename Scalar>
	Scalar GaussLikLayer<Scalar>::BackpropagationGradientsLogLikExp(const af::array & mout, const af::array & vout, af::array & dmout, af::array & dvout, af::array & y, Scalar scale)
	{
		const Scalar sn2 = exp(2.0 * _sn);
		const Scalar term1 = -1;
		const af::array diff = y - mout;
		const af::array term2 = 1 / sn2 * (diff * diff + vout);
		const Scalar dsn = scale * af::sum<Scalar>(term1 + term2);
		return dsn;
	}

	/**
	 * Predictive distribution of y from the latent moments (mf, vf):
	 * the mean is passed through, and the noise variance exp(2*_sn), divided by alpha,
	 * is added to the latent variance.
	 */
	template<typename Scalar>
	void GaussLikLayer<Scalar>::ProbabilisticOutput(const af::array& mf, const af::array& vf, af::array& myOut, af::array& vyOut, Scalar alpha)
	{
		const auto noiseVar = exp(2.0 * _sn) / alpha;
		myOut = mf;
		vyOut = vf + noiseVar;
	}

	/** Base-layer parameter count, plus one for _sn unless the noise parameter is fixed. */
	template<typename Scalar>
	int GaussLikLayer<Scalar>::GetNumParameters()
	{
		const int baseCount = LikelihoodBaseLayer<Scalar>::GetNumParameters();
		return isFixedParam ? baseCount : baseCount + 1;
	}

	/**
	 * Loads parameters from a flat vector laid out as
	 * [base-layer parameters..., _sn (present only when !isFixedParam)].
	 */
	template<typename Scalar>
	void GaussLikLayer<Scalar>::SetParameters(const af::array & param)
	{
		// Qualify the dependent base explicitly, as the rest of the file does.
		const int numBase = LikelihoodBaseLayer<Scalar>::GetNumParameters();
		// af::seq(0, -1) would mean "the whole array" (-1 == af::end), so only slice a non-empty prefix.
		if (numBase > 0)
			LikelihoodBaseLayer<Scalar>::SetParameters(param(af::seq(0, numBase - 1)));
		if (!isFixedParam) _sn = param(numBase).scalar<Scalar>();
	}

	/**
	 * Returns all parameters as a 1 x GetNumParameters() array laid out as
	 * [base-layer parameters..., _sn (present only when !isFixedParam)].
	 * Also refreshes m_dType from CommonUtil<Scalar>::CheckDType().
	 */
	template<typename Scalar>
	af::array GaussLikLayer<Scalar>::GetParameters()
	{
		m_dType = CommonUtil<Scalar>::CheckDType();
		af::array param = af::constant(0.0f, 1, GetNumParameters(), m_dType);
		const int numBase = LikelihoodBaseLayer<Scalar>::GetNumParameters();
		// af::seq(0, -1) would address the whole array (-1 == af::end), so skip an empty base block.
		if (numBase > 0)
			param(af::seq(0, numBase - 1)) = LikelihoodBaseLayer<Scalar>::GetParameters();
		if (!isFixedParam) param(numBase) = _sn;
		return param;
	}

	/**
	 * Delegates to the base layer. Nothing Gaussian-specific is refreshed here: the methods in this
	 * file recompute sn2 from _sn each time they need it.
	 */
	template<typename Scalar>
	void GaussLikLayer<Scalar>::UpdateParameters()
	{
		using Base = LikelihoodBaseLayer<Scalar>;
		Base::UpdateParameters();
	}
}