/**
File:		MachineLearning/Kernel/FgRBFKernel<Scalar>.cpp

Author:		Nick Taubert
Email:		nick.taubert@uni-tuebingen.de
Site:       http://www.compsens.uni-tuebingen.de/

Copyright (c) 2017 CompSens. All rights reserved.
*/

#include <NeEnginePCH.h>
#include <Core/NeLogger.h>
#include <cmath>
#include <MachineLearning/FgRBFKernel.h>
#include <MachineLearning/CommonUtil.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiations of RBFKernel for the float and double scalar types.
		// NOTE(review): these appear before the member definitions below; an explicit
		// instantiation only covers members already defined at that point on strictly
		// conforming compilers - confirm the placement works on every supported toolchain.
		template class RBFKernel<float>;
		template class RBFKernel<double>;

		/// Default constructor.
		///
		/// Forwards eRBFKernel and the value 2 to the IKernel base (presumably the
		/// kernel type and the parameter count - confirm against IKernel) and
		/// initialises both dVariance and dInvScale to one.
		template<typename Scalar>
		RBFKernel<Scalar>::RBFKernel()
			: IKernel<Scalar>(eRBFKernel, 2)
			, dVariance(static_cast<Scalar>(1))
			, dInvScale(static_cast<Scalar>(1))
		{
		}

		/// Destructor; performs no work of its own.
		template<typename Scalar>
		RBFKernel<Scalar>::~RBFKernel()
		{
		}

		/// Evaluates the kernel for inX1 and inX2 and stores the result in outMatrix:
		///
		///     outMatrix = dVariance * exp(-SquareDistance(inX1, inX2) * 0.5 * (1 / dInvScale))
		///
		/// NOTE(review): the squared distance is divided by dInvScale although the
		/// member name suggests an inverse scale - confirm the intended parameterisation.
		template<typename Scalar>
		void RBFKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			const af::array sqDist = CommonUtil<Scalar>::SquareDistance(inX1, inX2);
			const af::array exponent = -sqDist * 0.5f * (1.0 / dInvScale);

			outMatrix = dVariance * af::exp(exponent);
		}

		/// Fills outDiagonal with an inX.dims(0) x 1 array whose entries all equal dVariance.
		template<typename Scalar>
		void RBFKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			const auto numRows = inX.dims(0);
			outDiagonal = af::constant(dVariance, numRows, 1, (m_dType));
		}

		/// Gradient with respect to inX for a kernel matrix evaluated on inX against itself.
		///
		/// outdL_dX is allocated with the dimensions of inX. For every column q the
		/// kernel derivative dK_dX is obtained from GradX(inX, inX, q, ...) and the
		/// column is set to
		///
		///     2 * sum(indL_dK * dK_dX, 1) - diag(indL_dK) * diag(dK_dX)
		template<typename Scalar>
		void RBFKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			const int numColumns = static_cast<int>(outdL_dX.dims(1));
			af::array kernelGrad;
			for (int col = 0; col < numColumns; ++col)
			{
				GradX(inX, inX, col, kernelGrad);

				const af::array weightedRowSum = af::sum(indL_dK * kernelGrad, 1);
				const af::array diagonalTerm = af::diag(indL_dK) * af::diag(kernelGrad);
				outdL_dX(af::span, col) = 2.0f * weightedRowSum - diagonalTerm;
			}
		}

		/// Gradients with respect to inXu and inX, given indL_dKuu and indL_dKuf.
		///
		/// Both outputs are allocated with the dimensions of their respective input.
		/// For every column q of outdL_dXu:
		///   - the Kuu part is 2 * sum(indL_dKuu * dKuu, 1) - diag(indL_dKuu) * diag(dKuu),
		///     with dKuu from GradX(inXu, inXu, q, ...);
		///   - the Kuf part sum(indL_dKuf * dKuf, 1), with dKuf from GradX(inXu, inX, q, ...),
		///     is added on top;
		///   - column q of outdL_dX is sum(indL_dKuf.T() * dKfu, 1), with dKfu from
		///     GradX(inX, inXu, q, ...).
		///
		/// The loop bound is the column count of inXu and the same index is used for
		/// inX, so inX presumably has the same number of columns - confirm with callers.
		template<typename Scalar>
		void RBFKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const  af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			outdL_dXu = af::constant(0.0f, inXu.dims(), (m_dType));
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			const int numColumns = static_cast<int>(outdL_dXu.dims(1));
			af::array kernelGrad(m_dType); // reused for each of the three derivatives below
			for (int col = 0; col < numColumns; ++col)
			{
				// Contribution of Kuu to the Xu gradient.
				GradX(inXu, inXu, col, kernelGrad);
				outdL_dXu(af::span, col) = 2 * af::sum(indL_dKuu * kernelGrad, 1) - af::diag(indL_dKuu) * af::diag(kernelGrad);

				// Contribution of Kuf to the Xu gradient.
				GradX(inXu, inX, col, kernelGrad);
				outdL_dXu(af::span, col) += af::sum(indL_dKuf * kernelGrad, 1);

				// Contribution of Kuf to the X gradient.
				GradX(inX, inXu, col, kernelGrad);
				outdL_dX(af::span, col) = af::sum(indL_dKuf.T() * kernelGrad, 1);
			}
		}

		/// Gradient with respect to the two kernel parameters, given indL_dK.
		///
		/// outdL_dParam is a 1 x 2 array of type m_dType:
		///   - entry 0 = sum(sum(indL_dK * K)) / dVariance
		///   - entry 1 = -0.5 * sum(sum(indL_dK * K * n2)) / dInvScale
		/// where K = ComputeKernelMatrix(inX1, inX2) and n2 = SquareDistance(inX1, inX2).
		///
		/// NOTE(review): entry 1 is not the analytic derivative, with respect to
		/// dInvScale, of the expression used in ComputeKernelMatrix - confirm which
		/// parameterisation this gradient is meant for.
		template<typename Scalar>
		void RBFKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			af::array kernel(m_dType);
			ComputeKernelMatrix(inX1, inX2, kernel);

			const af::array sqDist = CommonUtil<Scalar>::SquareDistance(inX1, inX2);
			const af::array kernelTimesDist = kernel * sqDist;

			outdL_dParam = af::array(1, 2, (m_dType));

			// Variance slot.
			outdL_dParam(0) = af::sum(af::sum(indL_dK * kernel)) / dVariance;

			// Scale slot.
			outdL_dParam(1) = -0.5 * af::sum(af::sum(indL_dK * kernelTimesDist)) / dInvScale;
		}

		/// Gradients through the cross kernel matrix Kfu = ComputeKernelMatrix(inX, inXu).
		///
		/// With N = inX.dims(0), K = inXu.dims(0), Q = inXu.dims(1),
		/// Lpsi1 = indL_dKfu * Kfu (element-wise) and the N x K x Q difference tensor
		/// Zmu(n, k, q) = inXu(k, q) - inX(n, q):
		///
		/// @param indL_dKfu    Multiplied element-wise with Kfu.
		/// @param inX          First kernel input.
		/// @param inXu         Second kernel input.
		/// @param outdL_dParam Optional. Receives a 1 x GetNumParameter() array with
		///                     entry 0 = sum(sum(Lpsi1)) / dVariance (plus
		///                     sum(sum(*dlogZ_dv)) when dlogZ_dv is non-null) and
		///                     entry 1 = sum over all elements of Lpsi1 * Zmu^2 / dInvScale^3.
		/// @param outdL_dXu    Optional. Receives the K x Q array -sum_n(Lpsi1 * Zmu / dInvScale^2).
		/// @param dlogZ_dv     Optional additive term for entry 0 of outdL_dParam.
		/// @param outdL_dX     Optional. Receives the N x Q array -sum_k(Lpsi1 * Zmu / dInvScale^2).
		///
		/// Every output pointer that is null is skipped.
		///
		/// NOTE(review): Zmu depends on inXu and inX with opposite sign, yet both
		/// position gradients carry a leading minus - confirm the sign of outdL_dX.
		template<typename Scalar>
		void RBFKernel<Scalar>::LogLikGradientCompundKfu(const af::array & indL_dKfu, const af::array & inX, const af::array & inXu, 
			af::array * outdL_dParam, af::array* outdL_dXu, const af::array* dlogZ_dv, af::array * outdL_dX)
		{
			const int iN = static_cast<int>(inX.dims(0));
			const int ik = static_cast<int>(inXu.dims(0));
			const int iq = static_cast<int>(inXu.dims(1));

			af::array kfu(m_dType);
			ComputeKernelMatrix(inX, inXu, kfu);

			const Scalar dInvScale2 = std::pow(dInvScale, 2);

			const af::array Lpsi1 = indL_dKfu * kfu;
			const af::array Zmu = af::tile(af::moddims(inXu, 1, ik, iq), iN, 1, 1) - af::tile(af::moddims(inX, iN, 1, iq), 1, ik, 1);
			// Lpsi1 replicated along the third axis so it lines up with Zmu.
			const af::array Lpsi1Tiled = af::tile(Lpsi1, 1, 1, iq);

			if (outdL_dParam != nullptr)
			{
				*outdL_dParam = af::constant(0.0f, 1, GetNumParameter(), (m_dType));

				(*outdL_dParam)(0) = af::sum(af::sum(Lpsi1)) / dVariance;
				if (dlogZ_dv != nullptr)
					(*outdL_dParam)(0) += (af::sum(af::sum(*dlogZ_dv)));

				(*outdL_dParam)(1) = af::sum(af::sum(af::sum(Lpsi1Tiled * af::pow(Zmu, 2) / std::pow(dInvScale, 3))));
			}

			// Shared N x K x Q term; it is reduced over axis 0 for Xu and over axis 1 for X.
			const af::array scaledDiff = Lpsi1Tiled * Zmu / dInvScale2;

			if (outdL_dXu != nullptr)
				*outdL_dXu = af::moddims(-af::sum(scaledDiff, 0), ik, iq);

			if (outdL_dX != nullptr)
				*outdL_dX = af::moddims(-af::sum(scaledDiff, 1), iN, iq);
		}

		/// Gradients of a term weighted by inM through Kuu = ComputeKernelMatrix(inXu, inXu).
		///
		/// @param inXu         Kernel input; Q = inXu.dims(1).
		/// @param inM          Weight matrix, multiplied element-wise with Kuu.
		/// @param outdL_dParam Receives a 1 x GetNumParameter() array with
		///                     entry 0 = sum(sum(inM * Kuu)) and entry 1 built from
		///                     Ml = 0.5 * inM * Kuu and Xl = inXu / dInvScale as
		///                     sum(1' Ml' Xl^2 + 1' Ml Xl^2 - 2 * 1' (Xl * (Ml Xl))).
		/// @param outdL_dXu    Receives the gradient with respect to inXu, assembled
		///                     from Mbar1 = -inM' * Kuu and Mbar2 = -inM * Kuu.
		///
		/// Both output pointers are dereferenced unconditionally and must be non-null.
		///
		/// The ones vectors used in the matrix products are created with m_dType so
		/// that their element type matches the other matmul operand for every Scalar.
		template<typename Scalar>
		void RBFKernel<Scalar>::LogGradientCompoundKuu(const af::array& inXu, const af::array& inM, af::array* outdL_dParam, af::array* outdL_dXu)
		{
			const int iq = static_cast<int>(inXu.dims(1));

			af::array kuu(m_dType);
			ComputeKernelMatrix(inXu, inXu, kuu);

			*outdL_dParam = af::constant(0.0f, 1, GetNumParameter(), (m_dType));

			(*outdL_dParam)(0) = af::sum(af::sum(inM * kuu));

			const af::array Ml = 0.5 * inM * kuu;
			const af::array Xl = inXu / dInvScale;
			const af::array XlSquared = af::pow(Xl, 2);

			const af::array onesMl = af::constant(1.0f, Ml.dims(0), 1, (m_dType));
			const af::array onesXl = af::constant(1.0f, Xl.dims(0), 1, (m_dType));

			(*outdL_dParam)(1) = af::sum(af::matmulTN(onesMl, af::matmulTN(Ml, XlSquared)) + af::matmulTN(onesMl, af::matmul(Ml, XlSquared))
				- 2.0 * af::matmulTN(onesXl, (Xl * af::matmul(Ml, Xl))));

			// Xbar uses the same scaling as Xl (inXu / dInvScale), so the array is reused.
			const af::array& Xbar = Xl;
			const af::array Mbar1 = -inM.T() * kuu;
			const af::array Mbar2 = -inM * kuu;

			const af::array onesMbar1 = af::constant(1.0f, Mbar1.dims(0), 1, (m_dType));
			const af::array onesMbar2 = af::constant(1.0f, Mbar2.dims(0), 1, (m_dType));

			*outdL_dXu = (Xbar * af::tile(af::matmulTN(onesMbar1, Mbar1), iq, 1).T() - af::matmul(Mbar1, Xbar))
				+ (Xbar * af::tile(af::matmulTN(onesMbar2, Mbar2), iq, 1).T() - af::matmul(Mbar2, Xbar));
		}

		/// Derivative of the kernel matrix with respect to column q of inX1.
		///
		/// With K = ComputeKernelMatrix(inX1, inX2):
		///
		///     outdK_dX(i, j) = -(inX1(i, q) - inX2(j, q)) * K(i, j) / (dInvScale * dInvScale)
		///
		/// @param inX1     First kernel input; the derivative is taken with respect to it.
		/// @param inX2     Second kernel input.
		/// @param q        Column index shared by inX1 and inX2.
		/// @param outdK_dX Receives the inX1.dims(0) x inX2.dims(0) derivative matrix.
		///
		/// The divisor is the arithmetic square of dInvScale. C++ has no power
		/// operator: `^` is bitwise XOR and binds looser than `/`, so it must not be
		/// used to square the scale.
		///
		/// NOTE(review): ComputeKernelMatrix divides the squared distance by the first
		/// power of dInvScale while this gradient (like the other X-gradient routines
		/// in this file) divides by its square - confirm the intended parameterisation.
		template<typename Scalar>
		void RBFKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			const int numData = static_cast<int>(inX1.dims(0));
			const int numData2 = static_cast<int>(inX2.dims(0));

			// Column q of each input, replicated to a numData x numData2 grid.
			const af::array K1 = af::tile(inX1(af::span, q), 1, numData2);
			const af::array K2 = af::tile(inX2(af::span, q).T(), numData, 1);

			af::array K(m_dType);
			ComputeKernelMatrix(inX1, inX2, K);

			const auto invScaleSquared = dInvScale * dInvScale;
			const af::array diff = K1 - K2;

			// Use K as-is when its row count matches the grid, otherwise its transpose.
			if (K1.dims(0) == K.dims(0))
				outdK_dX = -(diff * K) / invScaleSquared;
			else
				outdK_dX = -(diff * K.T()) / invScaleSquared;
		}

		/// Sets outDiagdK_dX to an all-zero array of type m_dType with the dimensions of inX.
		template<typename Scalar>
		void RBFKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			const auto shape = inX.dims();
			outDiagdK_dX = af::constant(0.0f, shape, (m_dType));
		}

		/// Gradient of the kernel diagonal, weighted by inCovDiag, with respect to
		/// the two kernel parameters.
		///
		/// ComputeDiagonal fills the diagonal with dVariance only, so the diagonal
		/// depends on the variance and not on dInvScale. The parameter order used by
		/// SetParameters / GetParameters is (0: dVariance, 1: dInvScale), hence:
		///   - entry 0 (variance) = sum(inCovDiag)
		///   - entry 1 (scale)    = 0
		///
		/// @param inX              Not read by this implementation.
		/// @param inCovDiag        Weights applied to the diagonal derivative.
		/// @param outDiagdK_dParam Receives the 1 x 2 gradient of type m_dType.
		template<typename Scalar>
		void RBFKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			outDiagdK_dParam = af::array(1, 2, (m_dType));

			outDiagdK_dParam(0) = af::sum(inCovDiag);
			outDiagdK_dParam(1) = 0.0;
		}

		/// Reads the kernel parameters from param: element 0 becomes dVariance and
		/// element 1 becomes dInvScale.
		template<typename Scalar>
		void RBFKernel<Scalar>::SetParameters(const af::array & param)
		{
			const af::array varianceEntry = param(0);
			dVariance = varianceEntry.scalar<Scalar>();

			const af::array invScaleEntry = param(1);
			dInvScale = invScaleEntry.scalar<Scalar>();
		}

		/// Reads the kernel parameters from their logarithms: dVariance = exp(param(0))
		/// and dInvScale = exp(param(1)).
		template<typename Scalar>
		void RBFKernel<Scalar>::SetLogParameters(const af::array & param)
		{
			const af::array expVariance = af::exp(param(0));
			dVariance = expVariance.scalar<Scalar>();

			const af::array expInvScale = af::exp(param(1));
			dInvScale = expInvScale.scalar<Scalar>();
		}

		/// Returns a 1 x GetNumParameter() array of type m_dType holding dVariance in
		/// element 0 and dInvScale in element 1.
		template<typename Scalar>
		af::array RBFKernel<Scalar>::GetParameters()
		{
			af::array values = af::constant(0.0f, 1, GetNumParameter(), (m_dType));

			values(0) = dVariance;
			values(1) = dInvScale;

			return values;
		}

		/// Returns the element-wise natural logarithm of the parameter array
		/// (dVariance in element 0, dInvScale in element 1).
		template<typename Scalar>
		af::array RBFKernel<Scalar>::GetLogParameters()
		{
			af::array values = af::constant(0.0f, 1, GetNumParameter(), (m_dType));

			values(0) = dVariance;
			values(1) = dInvScale;

			return af::log(values);
		}

		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// PSI statistics
		////////////////////////////////////////////////////////////////////////////////////////////////////
		/// Gradients through the Psi1 statistic.
		///
		/// With N = inMu.dims(0), K = inZ.dims(0), Q = inMu.dims(1),
		/// Lpsi1 = indL_dpsi1 * inPsi1 (element-wise) and the N x K x Q difference
		/// tensor Zmu(n, k, q) = inZ(k, q) - inMu(n, q):
		///
		/// @param inPsi1       Psi1 statistic, multiplied element-wise with indL_dpsi1.
		/// @param indL_dpsi1   Derivative with respect to Psi1.
		/// @param inZ          Inducing inputs.
		/// @param inMu         Means.
		/// @param inSu         Not read by this implementation.
		/// @param outdL_dParam Receives a 1 x GetNumParameter() array with
		///                     entry 0 = sum(sum(Lpsi1)) / dVariance and
		///                     entry 1 = sum over all elements of Lpsi1 * Zmu^2 / dInvScale^3.
		/// @param outdL_dXu    Receives the K x Q array -sum_n(Lpsi1 * Zmu / dInvScale^2).
		/// @param outdL_dX     Optional. When non-null receives the N x Q array
		///                     -sum_k(Lpsi1 * Zmu / dInvScale^2).
		///
		/// Entry 1 divides by dInvScale^3 exactly once, matching the equivalent
		/// expression in LogLikGradientCompundKfu.
		///
		/// NOTE(review): Zmu depends on inZ and inMu with opposite sign, yet both
		/// position gradients carry a leading minus - confirm the sign of outdL_dX.
		template<typename Scalar>
		void RBFKernel<Scalar>::Psi1Derivative(const af::array& inPsi1, const af::array& indL_dpsi1, const af::array& inZ, const af::array& inMu,
			const af::array& inSu, af::array& outdL_dParam, af::array& outdL_dXu, af::array* outdL_dX)
		{
			const int iN = static_cast<int>(inMu.dims(0));
			const int ik = static_cast<int>(inZ.dims(0));
			const int iq = static_cast<int>(inMu.dims(1));

			const af::array Lpsi1 = indL_dpsi1 * inPsi1;
			const af::array Zmu = af::tile(af::moddims(inZ, 1, ik, iq), iN, 1, 1) - af::tile(af::moddims(inMu, iN, 1, iq), 1, ik, 1);
			// Lpsi1 replicated along the third axis so it lines up with Zmu.
			const af::array Lpsi1Tiled = af::tile(Lpsi1, 1, 1, iq);

			outdL_dParam = af::constant(0.0f, 1, GetNumParameter(), (m_dType));
			outdL_dParam(0) = af::sum(af::sum(Lpsi1)) / dVariance;
			outdL_dParam(1) = af::sum(af::sum(af::sum(Lpsi1Tiled * af::pow(Zmu, 2) / std::pow(dInvScale, 3))));

			// Shared N x K x Q term; it is reduced over axis 0 for Z and over axis 1 for Mu.
			const af::array scaledDiff = Lpsi1Tiled * Zmu / std::pow(dInvScale, 2);

			outdL_dXu = af::moddims(-af::sum(scaledDiff, 0), ik, iq);

			if (outdL_dX != nullptr)
				*outdL_dX = af::moddims(-af::sum(scaledDiff, 1), iN, iq);
		}
	}
}