/**
File:		MachineLearning/GPModels/Backconstraints/FgPeriodicTopologicalConstraint.cpp

Author:		
Email:		
Site:       

Copyright (c) 2020 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgPeriodicTopologicalConstraint.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class PTC<float>;
	template class PTC<double>;

	template<typename Scalar>
	PTC<Scalar>::PTC()
		: IBackconstraint<Scalar>(BackConstType::ptc), afA(), pSegments(), iNumSegments(0)
	{
	}

	template<typename Scalar>
	PTC<Scalar>::~PTC()
	{
		//delete[] pSegments;
	}

	template<typename Scalar>
	void PTC<Scalar>::Init(const af::array& Y, const af::array& X, const af::array& segments)
	{
		LogAssert(X.dims(1) == 2, "Dimension missmatch! PTC needs exact 2 latent dimensions!");

		iNumSegments = segments.dims(0);
		iN = Y.dims(0);
		iq = X.dims(1);

		PCA* pca = new PCA();

		//pSegments = new int[iNumSegments];
		for (auto i = 0; i < iNumSegments; i++)
			pSegments.push_back((int)segments(i).scalar<Scalar>());

		af::array curY, curX;
		afA = af::constant(0.0, iNumSegments * 2, m_dType);	// offset and step size

		for (int i = 0; i < iNumSegments - 1; i++)
		{
			curY = Y(af::seq(pSegments[i], pSegments[i + 1] - 1), af::span);
			curX = pca->Compute(curY, iq);
			afA(iNumSegments + i) = 2.0 * af::Pi / curX.dims(0); // computing step size*/

			//_Y[ILMath.r(_segments[i], _segments[i + 1] - 1), ILMath.full] = curY;
		}

		curY = Y(af::seq(pSegments[iNumSegments - 1], Y.dims(0) - 1), af::span);
		curX = pca->Compute(curY, iq);
		afA(af::end) = 2.0 * af::Pi / curX.dims(0);

		//_Y[ILMath.r(_segments[ILMath.end], _Y.Size[0] - 1), ILMath.full] = curY;
		
		delete pca;
	}

	template<typename Scalar>
	int PTC<Scalar>::GetNumParameters()
	{
		return afA.dims(0);
	}

	template<typename Scalar>
	void PTC<Scalar>::SetParameters(const af::array& param)
	{
		afA = param;
	}

	template<typename Scalar>
	af::array PTC<Scalar>::GetParameters()
	{
		return afA;
	}

	template<typename Scalar>
	af::array PTC<Scalar>::GetConstraintX()
	{
		af::array X = af::constant(0.0, iN, iq, m_dType);

		Scalar theta_0, delta;
		for (int i = 0; i < iNumSegments - 1; i++)
		{
			/*theta_0 = afA(i).scalar<Scalar>();
			delta = afA(iNumSegments + i).scalar<Scalar>();*/
			for (int j = pSegments[i]; j < pSegments[i + 1]; j++)
			{
				int j_1 = j - pSegments[i];
				X(j, 0) = af::cos(afA(i) + (j_1) * afA(iNumSegments + i));
				X(j, 1) = af::sin(afA(i) + (j_1) * afA(iNumSegments + i)); // extracting the phase and setting to x_j
			}
		}

		/*theta_0 = afA(iNumSegments - 1).scalar<Scalar>();
		delta = afA(iNumSegments + iNumSegments - 1).scalar<Scalar>();*/
		for (int j = pSegments[iNumSegments - 1]; j < iN; j++)
		{
			int j_1 = j - pSegments[iNumSegments - 1];
			X(j, 0) = af::cos(afA(iNumSegments - 1) + (j_1) * afA(iNumSegments + iNumSegments - 1));
			X(j, 1) = af::sin(afA(iNumSegments - 1) + (j_1) * afA(iNumSegments + iNumSegments - 1)); // extracting the phase and setting to x_j
		}
		
		return X;
	}

	template<typename Scalar>
	af::array PTC<Scalar>::BackconstraintGradient(const af::array& gX)
	{
		af::array dX_dA = af::constant(0.0, iN, iq, iNumSegments * 2, m_dType);
		af::array dL_dA = af::constant(0.0, iNumSegments * 2, m_dType);

		Scalar theta_0, delta;
		for (int n = 0; n < iNumSegments - 1; n++)
		{
			/*theta_0 = afA(n).scalar<Scalar>();
			delta = afA(iNumSegments + n).scalar<Scalar>();*/
			for (int j = pSegments[n]; j < pSegments[n + 1]; j++)
			{
				int j_1 = j - pSegments[n];

				dX_dA(j, 0, n) = -af::sin(afA(n) + (j_1) * afA(iNumSegments + n));
				dX_dA(j, 1, n) = af::cos(afA(n) + (j_1) * afA(iNumSegments + n));

				dX_dA(j, 0, iNumSegments + n) = -af::sin(afA(n) + (j_1)* afA(iNumSegments + n));
				dX_dA(j, 1, iNumSegments + n) = af::cos(afA(n) + (j_1)* afA(iNumSegments + n));
			}
		}

		/*theta_0 = afA(iNumSegments - 1).scalar<Scalar>();
		delta = afA(iNumSegments + iNumSegments - 1).scalar<Scalar>();*/
		for (int j = pSegments[iNumSegments - 1]; j < iN; j++)
		{
			int j_1 = j - pSegments[iNumSegments - 1];

			dX_dA(j, 0, iNumSegments - 1) = -af::sin(afA(iNumSegments - 1) + (j_1)* afA(iNumSegments + iNumSegments - 1));
			dX_dA(j, 1, iNumSegments - 1) = af::cos(afA(iNumSegments - 1) + (j_1)* afA(iNumSegments + iNumSegments - 1));

			dX_dA(j, 0, iNumSegments + iNumSegments - 1) = -af::sin(afA(iNumSegments - 1) + (j_1)* afA(iNumSegments + iNumSegments - 1));
			dX_dA(j, 1, iNumSegments + iNumSegments - 1) = af::cos(afA(iNumSegments - 1) + (j_1)* afA(iNumSegments + iNumSegments - 1));
		}


		for (int n = 0; n < iNumSegments; n++)
		{
			dL_dA(n) = af::sum(af::diag(af::matmulTN(gX, dX_dA(af::span, af::span, n))));
			dL_dA(iNumSegments + n) = af::sum(af::diag(af::matmulTN(gX, dX_dA(af::span, af::span, iNumSegments + n))));
		}

		return dL_dA;
	}
}