/**
File:		MachineLearning/Kernel/FgLinearAccelerationKernel.cpp

Author:		Nick Taubert
Email:		nick.taubert@uni-tuebingen.de
Site:       http://www.compsens.uni-tuebingen.de/

Copyright (c) 2017 CompSens. All rights reserved.
*/

#include <NeEnginePCH.h>
#include <Core/NeLogger.h>
#include <MachineLearning/FgLinearAccelerationKernel.h>
#include <MachineLearning/CommonUtil.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiations of the kernel for single and double precision.
		// NOTE(review): per [temp.explicit], an explicit instantiation definition only covers
		// members that are already defined at the point of instantiation. All member
		// definitions in this file follow these lines, so conforming compilers may not emit
		// them. Consider moving these two lines to the end of the namespace.
		template class LinearAccelerationKernel<float>;
		template class LinearAccelerationKernel<double>;

		/**
		 * Creates the kernel with both variances (first and second column half) set to 1.
		 * The base is tagged eLinearAccelerationKernel; the `2` passed to IKernel presumably
		 * is the hyperparameter count (dVariance1, dVariance2) — confirm against IKernel.
		 */
		template<typename Scalar>
		LinearAccelerationKernel<Scalar>::LinearAccelerationKernel()
			: IKernel<Scalar>(eLinearAccelerationKernel, 2)
			, dVariance1(1.0f)
			, dVariance2(1.0f)
		{
		}

		// Nothing to release explicitly; members clean up after themselves.
		template<typename Scalar>
		LinearAccelerationKernel<Scalar>::~LinearAccelerationKernel() = default;

		/**
		 * Cross kernel matrix K(X1, X2) = v1 * X1_lo * X2_lo^T + v2 * X1_hi * X2_hi^T.
		 * "lo" is the first half of the columns and "hi" the second half. When the column
		 * count is odd, the last column is not used.
		 *
		 * @param inX1       Left inputs, one sample per row.
		 * @param inX2       Right inputs, one sample per row.
		 * @param outMatrix  Receives the rows(inX1) x rows(inX2) kernel matrix.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::ComputeKernelMatrix(const af::array& inX1, const af::array& inX2, af::array& outMatrix)
		{
			const int half = static_cast<int>(inX1.dims(1) / 2);
			const af::seq lowerCols(0, half - 1);
			const af::seq upperCols(half, 2 * half - 1);

			// Linear (dot-product) Gram matrix of each column half.
			const af::array gramLower = af::matmulNT(inX1(af::span, lowerCols), inX2(af::span, lowerCols));
			const af::array gramUpper = af::matmulNT(inX1(af::span, upperCols), inX2(af::span, upperCols));

			outMatrix = dVariance1 * gramLower + dVariance2 * gramUpper;
		}

		/**
		 * Diagonal of K(X, X): for each row i, v1 * |x_i,lo|^2 + v2 * |x_i,hi|^2.
		 *
		 * @param inX          Inputs, one sample per row.
		 * @param outDiagonal  Receives a rows(inX) x 1 column of diagonal entries.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::ComputeDiagonal(const af::array& inX, af::array& outDiagonal)
		{
			const int half = static_cast<int>(inX.dims(1) / 2);
			const af::array lower = inX(af::span, af::seq(0, half - 1));
			const af::array upper = inX(af::span, af::seq(half, 2 * half - 1));

			// Row-wise squared norms of each half (sum along dimension 1 = across columns).
			const af::array lowerSq = af::sum(lower * lower, 1);
			const af::array upperSq = af::sum(upper * upper, 1);

			outDiagonal = lowerSq * dVariance1 + upperSq * dVariance2;
		}

		/**
		 * Gradient of the objective w.r.t. the inputs X for the symmetric kernel K(X, X).
		 *
		 * @param inX       Inputs, one sample per row.
		 * @param indL_dK   dL/dK, rows(inX) x rows(inX) (assumed symmetric — TODO confirm).
		 * @param outdL_dX  Receives dL/dX with the same shape as inX.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::LogLikGradientX(const af::array& inX, const af::array& indL_dK, af::array& outdL_dX)
		{
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			// The diagonal of dL/dK is the same for every column, so extract it once.
			const af::array dL_dKDiag = af::diag(indL_dK);
			const int numCols = static_cast<int>(outdL_dX.dims(1));

			af::array dK_dX(m_dType);
			for (int col = 0; col < numCols; ++col)
			{
				GradX(inX, inX, col, dK_dX);

				// K(X, X) depends on row i through both arguments: take twice the weighted
				// row sum, then subtract the diagonal product once so it is not double-counted.
				const af::array weightedRowSums = af::sum(indL_dK * dK_dX, 1);
				outdL_dX(af::span, col) = 2.0f * weightedRowSums - dL_dKDiag * af::diag(dK_dX);
			}
		}

		/**
		 * Gradients w.r.t. inducing inputs Xu and data inputs X for a sparse approximation.
		 * The Xu gradient combines a contribution through Kuu and one through Kuf. The X
		 * gradient comes only through Kuf.
		 *
		 * @param inXu        Inducing inputs, one per row.
		 * @param indL_dKuu   dL/dKuu, rows(inXu) x rows(inXu).
		 * @param inX         Data inputs, one per row.
		 * @param indL_dKuf   dL/dKuf, rows(inXu) x rows(inX).
		 * @param outdL_dXu   Receives dL/dXu (same shape as inXu).
		 * @param outdL_dX    Receives dL/dX (same shape as inX).
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::LogLikGradientX(const af::array& inXu, const af::array& indL_dKuu, const af::array& inX, const af::array& indL_dKuf, af::array& outdL_dXu, af::array& outdL_dX)
		{
			outdL_dXu = af::constant(0.0f, inXu.dims(), (m_dType));
			outdL_dX = af::constant(0.0f, inX.dims(), (m_dType));

			// These pieces do not depend on the column, so compute them once.
			const af::array dL_dKuuDiag = af::diag(indL_dKuu);
			const af::array dL_dKfu = indL_dKuf.T();
			// Iterates over the columns of Xu; assumes X has the same column count.
			const int numCols = static_cast<int>(outdL_dXu.dims(1));

			af::array dKuu_dXu(m_dType);
			af::array dKuf_dXu(m_dType);
			af::array dKfu_dX(m_dType);

			for (int col = 0; col < numCols; ++col)
			{
				GradX(inXu, inXu, col, dKuu_dXu);
				GradX(inXu, inX, col, dKuf_dXu);
				GradX(inX, inXu, col, dKfu_dX);

				// Xu through Kuu (symmetric: double the row sum, remove one diagonal copy) ...
				outdL_dXu(af::span, col) = 2 * af::sum(indL_dKuu * dKuu_dXu, 1) - dL_dKuuDiag * af::diag(dKuu_dXu);
				// ... plus Xu through Kuf.
				outdL_dXu(af::span, col) += af::sum(indL_dKuf * dKuf_dXu, 1);

				// X through Kuf, evaluated from the X side with dL/dKuf transposed.
				outdL_dX(af::span, col) = af::sum(dL_dKfu * dKfu_dX, 1);
			}
		}

		/**
		 * Gradient of the objective w.r.t. the two variances.
		 *
		 * @param inX1          Left inputs, one per row.
		 * @param inX2          Right inputs, one per row.
		 * @param indL_dK       dL/dK, rows(inX1) x rows(inX2).
		 * @param outdL_dParam  Receives a 1 x 2 array [d/dv1, d/dv2], each scaled by its variance.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::LogLikGradientParam(const af::array& inX1, const af::array& inX2, const af::array& indL_dK, af::array& outdL_dParam)
		{
			const int half = static_cast<int>(inX1.dims(1) / 2);
			const af::seq lowerCols(0, half - 1);
			const af::seq upperCols(half, 2 * half - 1);

			// dK/dv of each half is that half's Gram matrix.
			const af::array gramLower = af::matmulNT(inX1(af::span, lowerCols), inX2(af::span, lowerCols));
			const af::array gramUpper = af::matmulNT(inX1(af::span, upperCols), inX2(af::span, upperCols));

			outdL_dParam = af::array(1, 2, (m_dType));
			// Full contraction of dL/dK with dK/dv, times the variance itself.
			// NOTE(review): the extra variance factor matches a gradient w.r.t. log-variance — confirm.
			outdL_dParam(0) = af::sum(af::sum(indL_dK * gramLower)) * dVariance1;
			outdL_dParam(1) = af::sum(af::sum(indL_dK * gramUpper)) * dVariance2;
		}

		/**
		 * Derivative of K(X1, X2) w.r.t. column q of X1, laid out as a rows(X1) x rows(X2)
		 * matrix whose entry (i, j) is dK_ij / dX1_iq.
		 *
		 * Columns [0, D/2) use dVariance1 and the remaining columns use dVariance2.
		 *
		 * Symmetric case (equal row counts, treated as X1 == X2):
		 *   off-diagonal  dK_ij/dX_iq = v * X_jq
		 *   diagonal      dK_ii/dX_iq = 2 * v * X_iq  (x_i appears in both arguments)
		 * LogLikGradientX relies on this doubled diagonal when it subtracts diag(dL/dK) * diag(dK/dX).
		 *
		 * Cross case: dK_ij/dX1_iq = v * X2_jq.
		 *
		 * NOTE(review): the symmetric branch is chosen by equal row counts, not by X1 and X2
		 * holding the same data. A cross call such as GradX(Xu, X, ...) with as many inducing
		 * points as data rows takes the symmetric branch.
		 *
		 * @param inX1      Inputs differentiated against, one per row.
		 * @param inX2      Second kernel argument, one per row.
		 * @param q         Column of inX1 to differentiate w.r.t.
		 * @param outdK_dX  Receives the rows(inX1) x rows(inX2) derivative matrix.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::GradX(const af::array& inX1, const af::array& inX2, int q, af::array& outdK_dX)
		{
			const int numData = static_cast<int>(inX1.dims(0));
			const int numData2 = static_cast<int>(inX2.dims(0));

			const auto variance = (q < inX1.dims(1) / 2) ? dVariance1 : dVariance2;

			if (numData == numData2)
			{
				// Every row is v * X(:, q)^T ...
				outdK_dX = af::tile(inX1(af::span, q).T(), numData, 1) * variance;
				// ... and the diagonal is added once more. extract=false turns the extracted
				// diagonal vector back into a diagonal matrix (the default, extract=true,
				// would extract again and yield a single element).
				outdK_dX += af::diag(af::diag(outdK_dX), 0, false);
			}
			else
			{
				outdK_dX = af::tile(inX2(af::span, q).T(), numData, 1) * variance;
			}
		}

		/**
		 * Derivative of each diagonal entry k(x_i, x_i) w.r.t. its own row x_i.
		 * For the first column half this is 2 * v1 * X_iq, and for the second 2 * v2 * X_iq.
		 * When the column count is odd, the last column is unused by the kernel and stays zero.
		 *
		 * @param inX           Inputs, one sample per row.
		 * @param outDiagdK_dX  Receives an array with the same shape as inX.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::DiagGradX(const af::array& inX, af::array& outDiagdK_dX)
		{
			// Half the column count, as in every other method. Without "/ 2" the second
			// slice would index past the last column.
			const int q = static_cast<int>(inX.dims(1) / 2);
			outDiagdK_dX = af::constant(0.0f, inX.dims(), (m_dType));

			outDiagdK_dX(af::span, af::seq(0, q - 1)) = 2.0f * dVariance1 * inX(af::span, af::seq(0, q - 1));
			outDiagdK_dX(af::span, af::seq(q, 2 * q - 1)) = 2.0f * dVariance2 * inX(af::span, af::seq(q, 2 * q - 1));
		}

		/**
		 * Weighted gradient of the kernel diagonal w.r.t. the two variances.
		 * Since d k(x_i, x_i) / d v1 = |x_i,lo|^2, entry 0 is sum_i w_i * |x_i,lo|^2 and entry 1
		 * is the same for the second column half. Each entry is divided by its variance, as
		 * in the original code.
		 *
		 * The previous loop overwrote the result on every row and assigned a full
		 * rows x rows matmul into a single element, so it only worked for a single row.
		 * This vectorized form gives the same value in that case.
		 *
		 * NOTE(review): LogLikGradientParam multiplies by the variance while this divides by
		 * it. Confirm which scaling the caller expects.
		 *
		 * @param inX               Inputs, one sample per row.
		 * @param inCovDiag         One weight per row of inX (presumably dL/dKdiag — confirm).
		 * @param outDiagdK_dParam  Receives a 1 x 2 array.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::DiagGradParam(const af::array& inX, const af::array& inCovDiag, af::array& outDiagdK_dParam)
		{
			const int q = static_cast<int>(inX.dims(1) / 2);
			const af::seq lowerCols(0, q - 1);
			const af::seq upperCols(q, 2 * q - 1);

			// Row-wise squared norms of each half: rows(inX) x 1.
			const af::array lower = inX(af::span, lowerCols);
			const af::array upper = inX(af::span, upperCols);
			const af::array sqNormLower = af::sum(lower * lower, 1);
			const af::array sqNormUpper = af::sum(upper * upper, 1);

			// flat() lines the weights up with the column of norms whether they come as a row or a column.
			const af::array weights = af::flat(inCovDiag);

			outDiagdK_dParam = af::array(1, 2, (m_dType));
			outDiagdK_dParam(0) = af::sum(weights * sqNormLower) / dVariance1;
			outDiagdK_dParam(1) = af::sum(weights * sqNormUpper) / dVariance2;
		}

		/**
		 * Loads the hyperparameters from a device array.
		 * param(0) becomes dVariance1 (first column half) and param(1) becomes dVariance2.
		 * Each value is copied to the host before the next element is read.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::SetParameters(const af::array & param)
		{
			const Scalar firstVariance = param(0).scalar<Scalar>();
			dVariance1 = firstVariance;

			const Scalar secondVariance = param(1).scalar<Scalar>();
			dVariance2 = secondVariance;
		}

		/**
		 * Returns the hyperparameters as a 1 x GetNumParameter() device array in the
		 * layout SetParameters reads: [dVariance1, dVariance2, 0...].
		 */
		template<typename Scalar>
		af::array LinearAccelerationKernel<Scalar>::GetParameters()
		{
			// Qualified like every other ArrayFire call in this file, rather than relying on
			// an unseen using-directive to resolve a bare `constant`.
			af::array param = af::constant(0.0f, 1, GetNumParameter(), (m_dType));
			param(0) = dVariance1;
			param(1) = dVariance2;
			return param;
		}

		/**
		 * Derivatives of the Psi1 statistic for this kernel.
		 *
		 * NOTE(review): not implemented. The body is empty, so outdL_dParam, outdL_dXu and the
		 * array behind outdL_dX are not written, and callers keep whatever those outputs held
		 * before the call. Consider logging or failing loudly until an implementation exists.
		 */
		template<typename Scalar>
		void LinearAccelerationKernel<Scalar>::Psi1Derivative(const af::array & inPsi1, const af::array & indL_dpsi1, const af::array & inZ, 
			const af::array & inMu, const af::array & inSu, af::array & outdL_dParam, af::array & outdL_dXu, af::array * outdL_dX)
		{
		}
	}
}