/**
File:		MachineLearning/GPModels/SparseGPModels/SparseGPBaseLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgSparseGPBaseLayer.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiations of the layer for float and double scalars.
	// NOTE(review): per [temp.explicit], an explicit instantiation definition of a
	// class template only instantiates the members that are already defined at
	// this point, and every member definition in this file comes after these
	// lines. Consider moving them to the end of the file so all members are
	// guaranteed to be emitted with strictly conforming compilers.
	template class SparseGPBaseLayer<float>;
	template class SparseGPBaseLayer<double>;

	/**
	 * Constructs a sparse GP layer with `numPseudos` inducing (pseudo) points.
	 *
	 * numPoints, outputDim and inputDim are forwarded to GPBaseLayer. This
	 * constructor only pre-allocates buffers so that the per-slice assignments
	 * done in SetParameters / InitParameters have a target:
	 *  - T1 (ik x 1 x iD), T2 and T2_R (ik x ik x iD) are zero-filled; they
	 *    parameterise the posterior over the inducing outputs as
	 *    Su^-1 = Kuu^-1 + T2 and Su^-1 mu = T1 (see UpdateParameters).
	 *  - afMu / afInvSuMu (ik x 1 x iD) and afSu / afInvSu (ik x ik x iD).
	 *  - afKuu / afInvKuu (ik x ik), afKfu (iN x ik), afXu (ik x iq).
	 *
	 * @param numPoints  forwarded to GPBaseLayer.
	 * @param numPseudos number of inducing points (stored in ik).
	 * @param outputDim  forwarded to GPBaseLayer.
	 * @param inputDim   forwarded to GPBaseLayer.
	 */
	template<typename Scalar>
	SparseGPBaseLayer<Scalar>::SparseGPBaseLayer(int numPoints, int numPseudos, int outputDim, int inputDim)
		: GPBaseLayer<Scalar>(numPoints, outputDim, inputDim), ik(numPseudos), isFixedInducing(false)
	{
		T1 = af::constant(0.0f, ik, 1, iD, m_dType);
		T2 = af::constant(0.0f, ik, ik, iD, m_dType);
		T2_R = af::constant(0.0f, ik, ik, iD, m_dType);

		afMu = af::array(ik, 1, iD, m_dType);
		afSu = af::array(ik, ik, iD, m_dType);
		afInvSu = af::array(ik, ik, iD, m_dType);
		// afInvSuMu always holds T1 (one ik x 1 column per output), so it has T1's shape.
		afInvSuMu = af::array(ik, 1, iD, m_dType);

		afKuu = af::array(ik, ik, m_dType);
		afInvKuu = af::array(ik, ik, m_dType);
		afKfu = af::array(iN, ik, m_dType);

		afXu = af::array(ik, iq, m_dType);
	}

	/// Destructor. All members are value types released automatically, so the
	/// compiler-generated destructor body is sufficient.
	template<typename Scalar>
	SparseGPBaseLayer<Scalar>::~SparseGPBaseLayer() = default;

	/**
	 * Dispatches a posterior forward prediction.
	 *
	 * If `vx` is null, ForwardPredictionDeterministicPost(*mx, &mout, &vout) is
	 * called; otherwise ForwardPredictionRandomPost(*mx, *vx, mout, vout) is
	 * called. Presumably `vx` holds input variances — confirm with callers.
	 *
	 * @param mx   input array; must not be null (it is dereferenced on both paths).
	 * @param vx   optional second input, or nullptr to select the deterministic path.
	 * @param mout output, filled (or not) by the selected implementation.
	 * @param vout output, filled (or not) by the selected implementation.
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::ForwardPredictionPost(const af::array* mx, const af::array* vx, af::array& mout, af::array& vout)
	{
		// Fail with a clear message instead of dereferencing a null pointer (UB).
		LogAssert(mx != nullptr, "Input mean mx is null!");

		if (vx == nullptr)
			ForwardPredictionDeterministicPost(*mx, &mout, &vout);
		else
			ForwardPredictionRandomPost(*mx, *vx, mout, vout);
	}

	/**
	 * Draws one sample of the layer outputs at `inX` from the current posterior.
	 *
	 * Step 1: for every output column d, u_d = mu_d + L_d eps_d, with
	 *         L_d L_d^T = afSu(:, :, d) and eps_d ~ N(0, I).
	 * Step 2: f = Kfu Kuu^-1 u + Lf eps, with Lf Lf^T = Kff - Kfu Kuu^-1 Kuf.
	 *
	 * Both Gaussian draws use the LOWER Cholesky factor. ArrayFire's default
	 * (is_upper = true) returns R with Sigma = R^T R, and R * eps would then have
	 * covariance R R^T instead of Sigma. afSu is (ik x ik x iD) and af::cholesky
	 * factorises one 2-D matrix at a time, so the inducing draw loops over d.
	 *
	 * Relies on afMu, afSu and afInvKuu being current (UpdateParameters).
	 * NOTE(review): the int status returned by af::cholesky is not checked, as
	 * before; a non-positive-definite conditional covariance is not detected here.
	 *
	 * @param inX     inputs, one row per point.
	 * @param fsample receives the sample, shape (inX.dims(0) x iD).
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::SampleFromPost(const af::array& inX, af::array& fsample)
	{
		const dim_t numX = inX.dims(0);

		// Inducing-output sample, one column per output dimension.
		af::array u_sample = af::constant(0.0, ik, iD, m_dType);
		for (int d = 0; d < iD; d++)
		{
			af::array Lu;
			af::cholesky(Lu, afSu(af::span, af::span, d), false);
			u_sample(af::span, d) = afMu(af::span, 0, d) + af::matmul(Lu, af::randn(ik, 1, m_dType));
		}

		// Prior conditional f | u. JITTER keeps Kff numerically positive definite.
		af::array kff, kfu;
		kernel->ComputeKernelMatrix(inX, inX, kff);
		kff += af::diag(af::constant(JITTER, numX, m_dType), 0, false);
		kernel->ComputeKernelMatrix(inX, afXu, kfu);

		af::array qfu = af::matmul(kfu, afInvKuu);           // Kfu Kuu^-1
		af::array mf = af::matmul(qfu, u_sample);            // conditional mean
		af::array vf = kff - af::matmul(qfu, kfu.T());       // Kff - Kfu Kuu^-1 Kuf

		af::array Lf;
		af::cholesky(Lf, vf, false);
		fsample = mf + af::matmul(Lf, af::randn(numX, iD, m_dType));
	}

	/**
	 * Rebuilds the inducing-point kernel matrix and its inverse:
	 *   afKuu    = k(afXu, afXu) + JITTER * I
	 *   afInvKuu = CommonUtil::PDInverse(afKuu)
	 * Asserts that afXu has been initialised first.
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::ComputeKuu()
	{
		LogAssert(!afXu.isempty(), "Xu is empty. Init parameters first!");

		kernel->ComputeKernelMatrix(afXu, afXu, afKuu);

		// Small diagonal jitter so the inverse stays well defined when inducing
		// points are close together.
		const af::array jitterDiag = af::diag(af::constant(JITTER, ik, m_dType), 0, false);
		afKuu += jitterDiag;

		afInvKuu = CommonUtil<Scalar>::PDInverse(afKuu);
	}

	/**
	 * Rebuilds the cross-covariance between `inX` and the inducing inputs:
	 * afKfu = k(inX, afXu). Asserts that both afXu and inX are non-empty.
	 *
	 * @param inX inputs, one row per point.
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::ComputeKfu(const af::array& inX)
	{
		LogAssert(!afXu.isempty(), "Xu is empty. Init parameters first!");
		LogAssert(!inX.isempty(), "X is empty. Init parameters first!");

		kernel->ComputeKernelMatrix(inX, afXu, afKfu);
	}

	/// Returns the inducing (pseudo) inputs afXu. When styles are active these
	/// may carry extra style columns appended after the first iq columns.
	template<typename Scalar>
	af::array SparseGPBaseLayer<Scalar>::GetPseudoInputs()
	{
		af::array pseudoInputs = afXu;
		return pseudoInputs;
	}

	/**
	 * Number of scalars in the flat parameter vector used by
	 * GetParameters / SetParameters:
	 *   base-layer parameters
	 *   + ik * iD                     (T1)
	 *   + ik * (ik + 1) / 2 * iD      (upper triangle of each T2_R factor)
	 *   + ik * iq                     (afXu, only when inducing inputs are free)
	 */
	template<typename Scalar>
	int SparseGPBaseLayer<Scalar>::GetNumParameters()
	{
		const int baseCount = GPBaseLayer<Scalar>::GetNumParameters();
		const int triangleSize = ik * (ik + 1) / 2;
		const int naturalCount = ik * iD + triangleSize * iD;
		const int inducingCount = isFixedInducing ? 0 : ik * iq;
		return baseCount + naturalCount + inducingCount;
	}

	/**
	 * Unpacks a flat parameter vector into the layer and then calls
	 * UpdateParameters() to refresh Kuu and the posterior.
	 *
	 * Layout (the inverse of GetParameters):
	 *   [ base-layer params | T1 (ik*iD) | T2_R triangles (iD x ik*(ik+1)/2, column-major) | Xu (ik*iq) ]
	 *
	 * Each T2_R(:, :, d) is rebuilt from its packed upper triangle, with the
	 * diagonal stored as a log and exponentiated here. Then
	 * T2(:, :, d) = T2_R(:, :, d)^T T2_R(:, :, d).
	 * The Xu segment is read only when the inducing inputs are not fixed. In
	 * that case each style's expanded inducing-style columns are appended to afXu.
	 *
	 * @param param flat parameter vector of length GetNumParameters().
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::SetParameters(const af::array& param)
	{
		int istart = 0, iend = 0;

		// Base layer call -> intrinsic kernel parameter update
		iend = GPBaseLayer::GetNumParameters();
		GPBaseLayer::SetParameters(param(af::seq(istart, iend - 1)));

		// T1
		istart = iend;
		iend += ik * iD;
		T1 = af::moddims(param(af::seq(istart, iend - 1)), ik, 1, iD);

		// T2_R: one packed upper triangle (diagonal in log space) per output.
		const int numTri = ik * (ik + 1) / 2;
		istart = iend;
		iend += numTri * iD;
		af::array paramT2_R = af::moddims(param(af::seq(istart, iend - 1)), iD, numTri);

		const af::array triInx = CommonUtil<Scalar>::TriUpperIdx(ik);
		const af::array diagIdx = CommonUtil<Scalar>::DiagIdx(ik);
		for (int d = 0; d < iD; d++)
		{
			af::array Rd = af::constant(0.0f, ik, ik, m_dType);
			Rd(triInx) = paramT2_R(d, af::span).T();
			Rd(diagIdx) = af::exp(Rd(diagIdx));
			T2_R(af::span, af::span, d) = Rd;
			T2(af::span, af::span, d) = af::matmulTN(Rd, Rd);
		}

		// afXu (last segment; skipped when the inducing inputs are fixed)
		istart = iend;
		iend += ik * iq;
		if (!isFixedInducing)
		{
			afXu = af::moddims(param(af::seq(istart, iend - 1)), ik, iq);

			if (mStyles)
			{
				for (auto& style : *mStyles)
					afXu = CommonUtil<Scalar>::Join(afXu, style.second.GetInducingStyleExpanded(), 1);
			}
		}

		UpdateParameters();
	}

	/**
	 * Packs the layer parameters into a flat vector (the inverse of SetParameters):
	 *   [ base-layer params | T1 (ik*iD) | T2_R triangles (iD x ik*(ik+1)/2, column-major) | Xu (ik*iq) ]
	 *
	 * For each d, the upper triangle of T2_R(:, :, d) is stored with its diagonal
	 * in log space. Only the first iq columns of afXu are emitted, and only when
	 * the inducing inputs are not fixed. Any appended style columns are rebuilt
	 * by SetParameters.
	 * Side effect: refreshes m_dType from CommonUtil<Scalar>::CheckDType().
	 *
	 * @return vector of length GetNumParameters().
	 */
	template<typename Scalar>
	af::array SparseGPBaseLayer<Scalar>::GetParameters()
	{
		m_dType = CommonUtil<Scalar>::CheckDType();
		af::array param = af::constant(0.0f, GetNumParameters(), m_dType);

		int iStart = 0, iEnd = 0;
		iEnd = GPBaseLayer::GetNumParameters();
		param(af::seq(iStart, iEnd - 1)) = GPBaseLayer::GetParameters();

		// T1
		iStart = iEnd;
		iEnd += ik * iD;
		param(af::seq(iStart, iEnd - 1)) = af::flat(T1);

		// T2_R: pack each upper-triangular factor, diagonal in log space.
		const int numTri = ik * (ik + 1) / 2;
		const af::array triInx = CommonUtil<Scalar>::TriUpperIdx(ik);
		const af::array diagIdx = CommonUtil<Scalar>::DiagIdx(ik);

		af::array paramT2_R = af::constant(0.0f, iD, numTri, m_dType);
		for (int d = 0; d < iD; d++)
		{
			af::array Rd = T2_R(af::span, af::span, d).copy();
			Rd(diagIdx) = af::log(Rd(diagIdx));
			// Rd(triInx) is a column; transpose it into row d (mirrors SetParameters).
			paramT2_R(d, af::span) = Rd(triInx).T();
		}

		iStart = iEnd;
		iEnd += numTri * iD;
		param(af::seq(iStart, iEnd - 1)) = af::flat(paramT2_R);

		// afXu (only the first iq, non-style columns are free parameters)
		iStart = iEnd;
		iEnd += ik * iq;
		if (!isFixedInducing)
			param(af::seq(iStart, iEnd - 1)) = af::flat(afXu(af::span, af::seq(0, iq - 1)));

		return param;
	}

	/**
	 * Freezes (true) or releases (false) the inducing inputs. While frozen,
	 * afXu is excluded from GetNumParameters, GetParameters and SetParameters.
	 *
	 * @param isfixed new value of isFixedInducing.
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::FixInducing(bool isfixed)
	{
		this->isFixedInducing = isfixed;
	}

	/**
	 * Refreshes the base layer, Kuu, and the Gaussian posterior over the
	 * inducing outputs. For every output d:
	 *   afInvSu(:, :, d) = Kuu^-1 + T2(:, :, d)
	 *   afSu(:, :, d)    = inverse of afInvSu(:, :, d)
	 *                      (af::inverse; PDInverse if af::inverse throws)
	 *   afMu             = afSu * T1   (af::matmul over the stacked slices)
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::UpdateParameters()
	{
		GPBaseLayer<Scalar>::UpdateParameters();
		ComputeKuu();

		afInvSuMu = T1;
		afInvSu = T2 + af::tile(afInvKuu, 1, 1, iD);

		afSu = af::constant(0.0, afInvSu.dims(), m_dType);
		for (int d = 0; d < iD; d++)
		{
			const af::array precision = afInvSu(af::span, af::span, d);
			try
			{
				afSu(af::span, af::span, d) = af::inverse(precision);
			}
			catch (...)
			{
				afSu(af::span, af::span, d) = CommonUtil<Scalar>::PDInverse(precision);
			}
		}

		afMu = af::matmul(afSu, afInvSuMu);
	}

	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::InitParameters(af::array* X)
	{
		if (X == nullptr)
		{
			af::array param(m_dType);
			afXu = tile(CommonUtil<Scalar>::LinSpace(-1.0f, 1.0f, ik), 1, iq);

			param = af::log(af::randu(kernel->GetNumParameter(), m_dType) * 0.1 + af::constant(1.0, kernel->GetNumParameter(), m_dType));
			param(0) = log(0.5);

			/*param = af::constant(log(0.2), kernel->GetNumParameter(), m_dType);
			param(0) = log(0.5);*/

			/*param = af::constant(0.7849934609, kernel->GetNumParameter(), m_dType);
			param(0) = -0.6931471806;

			afXu = CommonUtil<Scalar>::ReadTXT("../../resources/zu_dynamic.txt");*/

			kernel->SetLogParameters(param);
		}
		else
		{
			af::array idx, means(m_dType), clusters(m_dType), xTrain(m_dType),
				X1(m_dType), distances(m_dType), afMed(m_dType), param(m_dType);
			Scalar med;

			xTrain = X->copy();
			if (iN < 10000/* && !mStyles*/)
			{
				KMeans<Scalar>::Compute(means, clusters, (*X)(af::span, af::seq(0, iq - 1)), ik, 1000);

				if (mStyles)
				{
					af::setSeed(NULL);
					idx = af::round(af::randu(iN) * iN)(af::seq(ik));

					for (auto style = mStyles->begin(); style != mStyles->end(); style++)
					{
						style->second.SetInducingStyleIndex(style->second.GetStyleIndex()(idx));
						means = CommonUtil<Scalar>::Join(means, style->second.GetInducingStyleExpanded(), 1);
					}
				}
			}
			else
			{
				af::setSeed(NULL);
				idx = af::round(af::randu(iN) * iN)(af::seq(ik));
				means = xTrain(idx, af::span);

				if (mStyles)
					for (auto style = mStyles->begin(); style != mStyles->end(); style++)
						style->second.SetInducingStyleIndex(style->second.GetStyleIndex()(idx));
			}

			afXu = means;

			//afXu = tile(CommonUtil<Scalar>::LinSpace(-1.0f, 1.0f, ik), 1, iq);

			if (iN < 1000)
				X1 = xTrain.copy();
			else
			{
				af::setSeed(NULL);
				idx = af::round(af::randu(iN) * iN)(af::seq(1000));
				X1 = xTrain(idx, af::seq(0, iq - 1));
			}

			distances = af::sqrt(CommonUtil<Scalar>::SquareDistance(X1, X1));
			idx = CommonUtil<Scalar>::TriUpperIdx(X1.dims(0));
			med = af::median<Scalar>(distances(idx));

			kernel->InitParameters(med);
		}

		af::array Su(m_dType), mu(m_dType), invSu(m_dType), alpha(m_dType);
		af::array triInx, diagIdx, paramT2_R(m_dType), Rd(m_dType);

		diagIdx = CommonUtil<Scalar>::DiagIdx(ik);
		triInx = CommonUtil<Scalar>::TriUpperIdx(ik);
		for (int d = 0; d < iD; d++)
		{
			// alpha = 0.5 * af::randu(ik, m_dType);
			alpha = af::constant(0.01, ik, m_dType);
			mu = CommonUtil<Scalar>::LinSpace(-1.0f, 1.0f, ik);

			Su = af::diag(alpha, 0, false);
			invSu = af::diag(1.0f / alpha, 0, false);

			/*Su = af::diag(constant(0.01, ik), 0, false);
			invSu = af::diag(1 / (constant(0.01, ik)), 0, false);*/

			T1(af::span, 0, d) = af::matmul(invSu, mu);
			T2(af::span, af::span, d) = invSu;

			af::cholesky(Rd, invSu);
			//Rd(diagIdx) = log(Rd(diagIdx));
			T2_R(af::span, af::span, d) = Rd;
		}

		/*T1 = CommonUtil<Scalar>::ReadTXT("../../resources/eta2_dynamic.txt").T();

		T2_R = af::constant(0.0, ik, ik, m_dType);
		T2_R(triInx) = CommonUtil<Scalar>::ReadTXT("../../resources/eta1_R_dynamic.txt").T();
		T2_R(diagIdx) = af::exp(T2_R(diagIdx));

		T2 = matmulTN(T2_R, T2_R);*/

		UpdateParameters();
	}

	/**
	 * Re-initialises the base layer and re-allocates every buffer of this layer:
	 *  - T1 (ik x 1 x iD), T2 and T2_R (ik x ik x iD) zero-filled;
	 *  - afMu / afInvSuMu (ik x 1 x iD), afSu / afInvSu (ik x ik x iD);
	 *  - afKuu / afInvKuu (ik x ik), afKfu (iN x ik), afXu (ik x iq).
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::ReinitParameters()
	{
		GPBaseLayer::ReinitParameters();

		T1 = af::constant(0.0f, ik, 1, iD, m_dType);
		T2 = af::constant(0.0f, ik, ik, iD, m_dType);
		T2_R = af::constant(0.0f, ik, ik, iD, m_dType);

		afMu = af::array(ik, 1, iD, m_dType);
		afSu = af::array(ik, ik, iD, m_dType);
		afInvSu = af::array(ik, ik, iD, m_dType);
		// afInvSuMu always holds T1 (see UpdateParameters), so it has T1's shape.
		afInvSuMu = af::array(ik, 1, iD, m_dType);

		afKuu = af::array(ik, ik, m_dType);
		afInvKuu = af::array(ik, ik, m_dType);
		afKfu = af::array(iN, ik, m_dType);

		afXu = af::array(ik, iq, m_dType);
	}

	/**
	 * Deterministic-input posterior prediction. Empty in this class: mout and
	 * vout are left untouched.
	 * NOTE(review): presumably overridden by concrete sparse layers — confirm.
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::ForwardPredictionDeterministicPost([[maybe_unused]] const af::array& mx,
		[[maybe_unused]] af::array* mout, [[maybe_unused]] af::array* vout)
	{
		// Intentionally empty at this level.
	}

	/**
	 * Posterior prediction for the two-input path used by ForwardPredictionPost
	 * when vx is provided. Empty in this class: mout and vout are left untouched.
	 * NOTE(review): presumably overridden by concrete sparse layers — confirm.
	 */
	template<typename Scalar>
	void SparseGPBaseLayer<Scalar>::ForwardPredictionRandomPost([[maybe_unused]] const af::array & mx,
		[[maybe_unused]] const af::array & vx, [[maybe_unused]] af::array & mout,
		[[maybe_unused]] af::array & vout, [[maybe_unused]] PropagationMode mode)
	{
		// Intentionally empty at this level.
	}
}


