/**
File:		MachineLearning/GPModels/Models/Layers/FgProbitLikelihoodLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/


#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgProbitLikelihoodLayer.h>
#include <MachineLearning/FgGaussHermiteQuadrature.h>
#define _USE_MATH_DEFINES
#include <math.h>

// Value passed as the first argument to GaussHermiteQuadrature<Scalar>::Compute and
// used as the tile count along the leading axis when means, variances and labels are
// replicated against the quadrature nodes (presumably the number of Gauss-Hermite
// nodes -- confirm against GaussHermiteQuadrature).
// NOTE(review): mutable global with external linkage; changing it after the node and
// weight arrays have been cached by a layer would desynchronise the tiling -- confirm
// nothing writes to it at runtime.
int GH_DEGREE = 10;

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiation definitions: the member templates in this file are
	// emitted for the float and double scalar types.
	template class ProbitLikLayer<float>;
	template class ProbitLikLayer<double>;

	/**
	 * @brief Constructs a probit likelihood layer.
	 *
	 * Forwards LogLikType::Probit together with the given sizes to the
	 * LikelihoodBaseLayer constructor and default-constructs the cached
	 * quadrature arrays (afGHx, afGHw), which are filled in later on demand.
	 *
	 * @param numPoints	Forwarded unchanged to the base layer.
	 * @param outputDim	Forwarded unchanged to the base layer.
	 */
	template<typename Scalar>
	ProbitLikLayer<Scalar>::ProbitLikLayer(int numPoints, int outputDim)
		: LikelihoodBaseLayer<Scalar>(LogLikType::Probit, numPoints, outputDim)
		, afGHx()
		, afGHw()
	{
		// Nothing else to set up.
	}

	/**
	 * @brief Computes the summed log normaliser log Z of the (power) probit likelihood.
	 *
	 * For alpha == 1 the closed form Z = Phi(y * m / sqrt(1 + v)) is used. For any other
	 * alpha the integral of Phi(y * f)^alpha under N(f; m, v) is approximated with
	 * Gauss-Hermite quadrature; the nodes/weights are computed on first use and cached
	 * in afGHx / afGHw.
	 *
	 * @param mout		Means; treated as a 2-D array (dims 0 and 1 are used).
	 * @param vout		Variances, same layout as mout.
	 * @param y			Targets, same layout as mout (presumably labels in {-1, +1} -- confirm).
	 * @param alpha		Power applied to the likelihood.
	 * @param dlogZ_dm	Not read or written by this function.
	 * @param dlogZ_dv	Not read or written by this function.
	 * @param dlogZ_dm2	Not read or written by this function.
	 *
	 * @return Sum of log Z over all entries.
	 */
	template<typename Scalar>
	Scalar ProbitLikLayer<Scalar>::ComputeLogZ(const af::array& mout, const af::array& vout, const af::array& y, Scalar alpha, af::array* dlogZ_dm, af::array* dlogZ_dv, af::array* dlogZ_dm2)
	{
		const Scalar eps = std::numeric_limits<Scalar>::epsilon();

		// Closed-form case: no quadrature needed.
		if (alpha == 1.0)
		{
			const af::array scaled = y * mout / af::sqrt(1 + vout);
			const af::array cdf = 0.5 * (1.0 + af::erf(scaled / sqrt(2.0)));
			return af::sum<Scalar>(af::sum(af::log(cdf + eps)));
		}

		// Lazily build the quadrature rule the first time it is required.
		if (afGHx.isempty() || afGHw.isempty())
			GaussHermiteQuadrature<Scalar>::Compute(GH_DEGREE, afGHx, afGHw);

		// Move a 2-D array behind a new leading axis and replicate it once per node.
		auto perNode = [](const af::array& x)
		{
			return af::tile(af::moddims(x, 1, x.dims(0), x.dims(1)), GH_DEGREE);
		};

		const af::array nodes = af::tile(afGHx, 1, mout.dims(0), mout.dims(1));
		const af::array weights = af::tile(afGHw, 1, mout.dims(0), mout.dims(1));

		// Quadrature abscissae mapped onto N(m, v).
		const af::array samples = nodes * af::sqrt(2.0 * perNode(vout)) + perNode(mout);
		const af::array cdfs = 0.5 * (1.0 + af::erf(perNode(y) * samples / sqrt(2.0))) + eps;
		const af::array tiltedZ = af::sum(af::pow(cdfs, alpha) * weights, 0) / sqrt(M_PI);

		return af::sum<Scalar>(af::sum(af::log(tiltedZ)));
	}

	/**
	 * @brief Computes derivatives of log Z of the (power) probit likelihood.
	 *
	 * For alpha == 1, Z = Phi(t) with t = y * m / sqrt(1 + v), and the derivatives follow
	 * from d log Phi(t) / dt = N(t; 0, 1) / Phi(t). For any other alpha, Z is the
	 * Gauss-Hermite approximation of the integral of Phi(y * f)^alpha under N(f; m, v)
	 * and the derivatives are obtained by differentiating the quadrature sum.
	 *
	 * Each output is only written when its pointer is non-null; outputs have the
	 * 2-D layout of mout / vout.
	 *
	 * @param mout		Means; treated as a 2-D array (dims 0 and 1 are used).
	 * @param vout		Variances, same layout as mout.
	 * @param y			Targets, same layout as mout (presumably labels in {-1, +1};
	 *					the second-derivative formulas rely on y * y == 1 -- confirm).
	 * @param dlogZ_dm	[out, optional] d log Z / d m.
	 * @param dlogZ_dv	[out, optional] d log Z / d v.
	 * @param dlogZ_dm2	[out, optional] d^2 log Z / d m^2.
	 * @param alpha		Power applied to the likelihood.
	 */
	template<typename Scalar>
	void ProbitLikLayer<Scalar>::ComputeLogZGradients(const af::array& mout, const af::array& vout, const af::array& y, af::array* dlogZ_dm, af::array* dlogZ_dv, af::array* dlogZ_dm2, Scalar alpha)
	{
		Scalar eps = std::numeric_limits<Scalar>::epsilon();
		if (alpha == 1.0)
		{
			af::array t = y * mout / af::sqrt(1 + vout);
			af::array Z = 0.5 * (1.0 + af::erf(t / sqrt(2.0)));

			// d log Phi(t) / dt = pdf(t) / cdf(t); the Gaussian density multiplies 1 / Z.
			af::array dlogZ_dt = af::exp(-af::pow(t, 2.0) / 2.0) / (sqrt(2.0 * M_PI) * (Z + eps));

			// The mean gradient is also the basis of the second derivative, so it is
			// formed independently of whether the caller asked for it.
			af::array dt_dm = y / af::sqrt(1 + vout);
			af::array gradMean = dlogZ_dt * dt_dm;

			if (dlogZ_dm != nullptr)
			{
				*dlogZ_dm = gradMean;
			}
			if (dlogZ_dv != nullptr)
			{
				af::array dt_dv = -0.5 * y * mout / af::pow(1.0 + vout, 1.5);
				*dlogZ_dv = dlogZ_dt * dt_dv;
			}
			if (dlogZ_dm2 != nullptr)
			{
				// beta = (pdf / cdf) / sqrt(1 + v); must use the gradient VALUE, not the pointer.
				af::array beta = gradMean / y;
				*dlogZ_dm2 = -(af::pow(beta, 2.0) + mout * y * beta / (1.0 + vout));
			}
		}
		else
		{
			// The quadrature rule may not have been built yet if this method is the
			// first one called with alpha != 1.
			if (afGHx.isempty() || afGHw.isempty())
				GaussHermiteQuadrature<Scalar>::Compute(GH_DEGREE, afGHx, afGHw);

			af::array ghx = af::tile(afGHx, 1, mout.dims(0), mout.dims(1));
			af::array ghw = af::tile(afGHw, 1, mout.dims(0), mout.dims(1));

			// Inputs moved behind a leading axis of length 1 so they broadcast/tile
			// against the quadrature nodes.
			af::array m3 = af::moddims(mout, 1, mout.dims(0), mout.dims(1));
			af::array v3 = af::moddims(vout, 1, vout.dims(0), vout.dims(1));
			af::array y3 = af::moddims(y, 1, y.dims(0), y.dims(1));
			af::array tiledY = af::tile(y3, GH_DEGREE);

			af::array ts = ghx * af::sqrt(2.0 * af::tile(v3, GH_DEGREE)) + af::tile(m3, GH_DEGREE);
			af::array pdfs = 0.5 * (1.0 + af::erf(tiledY * ts / sqrt(2.0))) + eps;
			af::array Ztilted = af::sum(af::pow(pdfs, alpha) * ghw, 0) / sqrt(M_PI);

			// Common factor Phi^(alpha - 1) * exp(-ts^2 / 2) of the first derivatives.
			af::array a = af::pow(pdfs, alpha - 1.0) * af::exp(-af::pow(ts, 2.0) / 2.0);

			// dZ/dm, shared by the first- and second-order mean derivatives.
			af::array dZdm = af::sum(ghw * a, 0) * y3 * alpha / M_PI / sqrt(2.0);

			if (dlogZ_dm != nullptr)
			{
				*dlogZ_dm = af::moddims(dZdm / Ztilted, mout.dims(0), mout.dims(1)) + eps;
			}
			if (dlogZ_dv != nullptr)
			{
				// d ts / d v = ghx / sqrt(2 v)
				af::array dZdv = af::sum(ghw * (a * ghx), 0) * y3
					* alpha / M_PI / sqrt(2.0) / af::sqrt(2.0 * v3);
				*dlogZ_dv = af::moddims(dZdv / Ztilted, vout.dims(0), vout.dims(1)) + eps;
			}
			if (dlogZ_dm2 != nullptr)
			{
				af::array b = (alpha - 1.0) * af::pow(pdfs, alpha - 2.0) * af::exp(-af::pow(ts, 2.0)) / sqrt(2.0 * M_PI)
					- af::pow(pdfs, alpha - 1.0) * tiledY * ts * af::exp(-af::pow(ts, 2.0) / 2);
				af::array dZdm2 = af::sum(ghw * b, 0) * alpha / M_PI / sqrt(2.0);

				// d^2 log Z / dm^2 = Z'' / Z - (Z' / Z)^2
				*dlogZ_dm2 = af::moddims(dZdm2 / Ztilted - af::pow(dZdm, 2.0) / af::pow(Ztilted, 2.0), mout.dims(0), mout.dims(1)) + eps;
			}
		}
	}

	/**
	 * @brief Gauss-Hermite estimate of the summed expected log likelihood.
	 *
	 * Evaluates CommonUtil<Scalar>::LogNormalCDF at y * (ghx * sqrt(2 v) + m) for every
	 * quadrature node, weights the values with ghw / sqrt(pi) and sums everything.
	 * The quadrature rule is computed on first use and cached in afGHx / afGHw.
	 *
	 * @param mout	Means; treated as a 2-D array (dims 0 and 1 are used).
	 * @param vout	Variances, same layout as mout.
	 * @param y		Targets, same layout as mout (presumably labels in {-1, +1} -- confirm).
	 *
	 * @return Sum of the weighted log-CDF values over nodes and all entries.
	 */
	template<typename Scalar>
	Scalar ProbitLikLayer<Scalar>::ComputeLogLikExp(const af::array & mout, const af::array & vout, const af::array & y)
	{
		// Build the quadrature rule lazily.
		const bool needRule = afGHx.isempty() || afGHw.isempty();
		if (needRule)
			GaussHermiteQuadrature<Scalar>::Compute(GH_DEGREE, afGHx, afGHw);

		// Move a 2-D array behind a new leading axis and replicate it once per node.
		auto perNode = [](const af::array& x)
		{
			return af::tile(af::moddims(x, 1, x.dims(0), x.dims(1)), GH_DEGREE);
		};

		const af::array nodes = af::tile(afGHx, 1, mout.dims(0), mout.dims(1));
		const af::array weights = af::tile(afGHw, 1, mout.dims(0), mout.dims(1)) / sqrt(M_PI);

		const af::array meanPerNode = perNode(mout);
		const af::array varPerNode = perNode(vout);

		// Quadrature abscissae mapped onto N(m, v).
		const af::array samples = nodes * af::sqrt(2.0 * varPerNode) + meanPerNode;
		const af::array logCdf = CommonUtil<Scalar>::LogNormalCDF(samples * perNode(y));
		const af::array weighted = weights * logCdf;

		return af::sum<Scalar>(af::sum(af::sum(weighted)));
	}

	/**
	 * @brief Gradients of the Gauss-Hermite expected log likelihood w.r.t. mean and variance.
	 *
	 * With ts = ghx * sqrt(2 v) + m, the per-node derivative of the log-CDF term is
	 * y * pdf(y * ts) / cdf(y * ts) (using CommonUtil<Scalar>::NormalPDF / NormalCDF),
	 * weighted by ghw / sqrt(pi) and chained with d ts / d m = 1 and
	 * d ts / d v = ghx / sqrt(2 v). The quadrature rule is computed on first use and
	 * cached in afGHx / afGHw.
	 *
	 * Each output is only written when its own pointer is non-null.
	 * NOTE(review): the results are the sums over the leading (node) axis and are not
	 * reshaped back to the 2-D layout of mout -- confirm callers expect that.
	 *
	 * @param mout	Means; treated as a 2-D array (dims 0 and 1 are used).
	 * @param vout	Variances, same layout as mout.
	 * @param y		Targets, same layout as mout (presumably labels in {-1, +1} -- confirm).
	 * @param de_dm	[out, optional] Gradient w.r.t. the means.
	 * @param de_dv	[out, optional] Gradient w.r.t. the variances.
	 */
	template<typename Scalar>
	void ProbitLikLayer<Scalar>::ComputeLogLikExpGradients(const af::array & mout, const af::array & vout, const af::array & y, af::array * de_dm, af::array * de_dv)
	{
		if (afGHx.isempty() || afGHw.isempty())
			GaussHermiteQuadrature<Scalar>::Compute(GH_DEGREE, afGHx, afGHw);

		af::array ghx = af::tile(afGHx, 1, mout.dims(0), mout.dims(1));
		af::array ghw = af::tile(afGHw, 1, mout.dims(0), mout.dims(1)) / sqrt(M_PI);

		af::array tiled_m = af::tile(af::moddims(mout, 1, mout.dims(0), mout.dims(1)), GH_DEGREE);
		af::array tiled_v = af::tile(af::moddims(vout, 1, vout.dims(0), vout.dims(1)), GH_DEGREE);
		af::array tiledY = af::tile(af::moddims(y, 1, y.dims(0), y.dims(1)), GH_DEGREE);

		// Quadrature abscissae mapped onto N(m, v).
		af::array ts = ghx * af::sqrt(2.0 * tiled_v) + tiled_m;

		af::array pdfs = CommonUtil<Scalar>::NormalPDF(ts * tiledY);
		af::array cdfs = CommonUtil<Scalar>::NormalCDF(ts * tiledY);
		af::array grad_cdfs = tiledY * ghw * pdfs / cdfs;
		Scalar dts_dm = 1.0;
		af::array dts_dv = 0.5 * ghx * af::sqrt(2 / tiled_v);

		if (de_dm != nullptr)
			*de_dm = af::sum(grad_cdfs * dts_dm, 0);
		// Guard on de_dv itself: the variance gradient must not depend on de_dm.
		if (de_dv != nullptr)
			*de_dv = af::sum(grad_cdfs * dts_dv, 0);
	}

	/**
	 * @brief Placeholder: performs no computation.
	 *
	 * None of the arguments are read and dmout / dvout are left untouched.
	 *
	 * @return Always zero.
	 */
	template<typename Scalar>
	Scalar ProbitLikLayer<Scalar>::BackpropagationGradientsLogLikExp(const af::array & mout, const af::array & vout, af::array & dmout, af::array & dvout, af::array & y, Scalar scale)
	{
		// Not implemented for the probit likelihood in this file.
		return static_cast<Scalar>(0.0);
	}

	/**
	 * @brief Placeholder: performs no computation.
	 *
	 * None of the arguments are read and dmout / dvout are left untouched.
	 *
	 * @return Always zero.
	 */
	template<typename Scalar>
	Scalar ProbitLikLayer<Scalar>::BackpropagationGradients(const af::array & mout, const af::array & vout, af::array & dmout, af::array & dvout, Scalar alpha, Scalar scale)
	{
		// Not implemented for the probit likelihood in this file.
		return static_cast<Scalar>(0.0);
	}

	/**
	 * @brief Placeholder: the body is empty.
	 *
	 * None of the inputs are read, and myOut / vyOut are not modified.
	 */
	template<typename Scalar>
	void ProbitLikLayer<Scalar>::ProbabilisticOutput(const af::array& mf, const af::array& vf, af::array& myOut, af::array& vyOut, Scalar alpha)
	{
	}
}