/**
File:		MachineLearning/Models/GPModels/FgSparseGPLVMBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgSparseGPLVMBaseModel.h>
#include <MachineLearning/FgAEPSparseGPR.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiation definitions for the two supported scalar types.
	// NOTE(review): per [temp.explicit], an explicit instantiation definition of a
	// class template instantiates only those members that are already defined at
	// the point of instantiation. Every out-of-line member definition in this file
	// comes after these two lines, so a strictly conforming compiler may not emit
	// them here (some compilers, e.g. MSVC, are more lenient). Consider moving these
	// lines to the end of the namespace — confirm against the non-MSVC build.
	template class SparseGPLVMBaseModel<float>;
	template class SparseGPLVMBaseModel<double>;

	/// Constructs a sparse GPLVM over the observations \p Y.
	/// @param Y               observed data, forwarded unchanged to GPLVMBaseModel
	/// @param latentDimension forwarded unchanged to GPLVMBaseModel
	/// @param priorMean       forwarded unchanged to GPLVMBaseModel
	/// @param priorVariance   forwarded unchanged to GPLVMBaseModel
	/// @param numInducing     number of inducing points, stored in ik
	/// @param lType           forwarded unchanged to GPLVMBaseModel
	/// @param emethod         forwarded unchanged to GPLVMBaseModel
	/// gpLayer starts out null. Init() dereferences it unconditionally, so a layer
	/// must be installed (presumably by a derived class) before Init() runs.
	template<typename Scalar>
	SparseGPLVMBaseModel<Scalar>::SparseGPLVMBaseModel(const af::array& Y, int latentDimension, Scalar priorMean, Scalar priorVariance, int numInducing, LogLikType lType, XInit emethod)
		// The base is spelled GPLVMBaseModel<Scalar>: it is a dependent base, so its
		// bare injected-class-name is not visible to standard two-phase lookup.
		: GPLVMBaseModel<Scalar>(Y, latentDimension, priorMean, priorVariance, lType, emethod), ik(numInducing), gpLayer(nullptr)
	{
	}

	/// Default constructor.
	/// gpLayer is explicitly null-initialised so that the destructor's delete is
	/// well-defined even when no GP layer is ever created for this object.
	/// NOTE(review): ik is not initialised here; confirm that it is always
	/// assigned (e.g. by a derived class or a loader) before it is read.
	template<typename Scalar>
	SparseGPLVMBaseModel<Scalar>::SparseGPLVMBaseModel()
		: GPLVMBaseModel<Scalar>(), gpLayer(nullptr)
	{
	}

	/// Destroys the model and the GP layer it owns.
	/// Deleting a null pointer is a no-op, so no explicit null check is needed.
	/// NOTE(review): ownership is held through a raw pointer; unless copy
	/// operations are deleted in the header, copying a model would lead to a
	/// double delete — confirm.
	template<typename Scalar>
	SparseGPLVMBaseModel<Scalar>::~SparseGPLVMBaseModel()
	{
		delete gpLayer;
		gpLayer = nullptr;
	}

	/// Predicts the latent function at \p testInputs.
	/// Runs GPBaseModel<Scalar>::PredictF first (a qualified, non-virtual call), then
	/// hands the inputs to the GP layer's ForwardPredictionPost, which receives mf
	/// and vf as outputs (presumably predictive mean and variance — confirm).
	template<typename Scalar>
	void SparseGPLVMBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		GPBaseModel<Scalar>::PredictF(testInputs, mf, vf);

		// The second input argument is deliberately left null here.
		this->gpLayer->ForwardPredictionPost(&testInputs, nullptr, mf, vf);
	}

	/// Draws \p numSamples samples from the GP layer's posterior at \p inputs.
	/// @param inputs       sample locations; inputs.dims(0) is the number of points
	/// @param numSamples   number of independent draws
	/// @param outFunctions overwritten with an array of shape
	///                     [inputs.dims(0) x iD x numSamples] and type m_dType;
	///                     slice i along the third dimension holds draw i
	/// The model is initialised first if bInit is false.
	template<typename Scalar>
	void SparseGPLVMBaseModel<Scalar>::SampleY(const af::array inputs, int numSamples, af::array& outFunctions)
	{
		if (!bInit) Init();

		outFunctions = af::array(inputs.dims(0), iD, numSamples, m_dType);

		// Per-draw output buffer, filled by SampleFromPost on every iteration.
		// Default-constructed on purpose: af::array has no dtype-only constructor,
		// and af::array(m_dType) silently binds to array(dim_t, dtype = f32),
		// allocating an f32 vector whose length is the numeric dtype value.
		af::array sample;

		for (int i = 0; i < numSamples; i++)
		{
			gpLayer->SampleFromPost(inputs, sample);
			outFunctions(af::span, af::span, i) = sample;
		}
	}

	/// Initialises the model, its GP layer and, recursively, its children.
	///
	/// The sequence is as follows:
	/// 1. Latents.
	///    - Root model (no parent): the latent means come from GPLVMBaseModel::Init,
	///      whose return value is stored in bInit.
	///    - Child model: latent means and variances come from PosteriorLatents.
	/// 2. The layer is sized to (iN, iD).
	/// 3. Styles (only if mStyles is set): an InterDomainKernel is installed.
	///    - Its sub-kernel is a TensorKernel: an ARDKernel over latent columns
	///      [0, iq) followed by one StyleKernel per style over that style's
	///      substyle columns.
	///    - Its window kernel is an ARDKernel over all of those columns.
	/// 4. The layer parameters are initialised from the latent means.
	/// 5. Every child is initialised.
	///
	/// @return bInit.
	/// NOTE(review): on the child path bInit is not assigned here. Unless it is set
	/// elsewhere, guards such as `if (!bInit) Init();` in SampleY will re-run this
	/// function and re-initialise the layer parameters — confirm this is intended.
	template<typename Scalar>
	bool SparseGPLVMBaseModel<Scalar>::Init()
	{
		af::array latentMean, latentVar;

		if (!GetParent())
		{
			bInit = GPLVMBaseModel<Scalar>::Init(latentMean);
		}
		else
		{
			// Child model: take latents from PosteriorLatents instead of the base initialiser.
			PosteriorLatents(latentMean, latentVar);
		}

		this->gpLayer->SetDataSize(iN, iD);

		if (mStyles)
		{
			InterDomainKernel<Scalar>* interKernel = new InterDomainKernel<Scalar>();

			// Tensor kernel: ARD over the first iq latent columns, then one
			// StyleKernel per style covering that style's substyle columns.
			TensorKernel<Scalar>* tensorKernel = new TensorKernel<Scalar>();
			tensorKernel->AddKernel(new ARDKernel<Scalar>(iq), af::seq(0, iq - 1));

			int column = iq; // one past the last column assigned to a kernel so far
			for (auto& style : *mStyles)
			{
				const int first = column;
				column += style.second.GetNumSubstyles();
				tensorKernel->AddKernel(new StyleKernel<Scalar>(style.second.GetNumSubstyles()), af::seq(first, column - 1));
			}

			// The tensor kernel becomes the inter-domain sub-kernel; an ARD kernel
			// spanning every column assigned above becomes the window kernel.
			interKernel->AddSubKernel(tensorKernel);
			interKernel->AddWindowKernel(new ARDKernel<Scalar>(column));

			this->gpLayer->SetKernel(interKernel);
			this->gpLayer->SetStyles(mStyles);
		}

		// An experimental warm start of the layer from an AEP::SGPR fit used to be
		// kept here as commented-out code; it was never active.
		this->gpLayer->InitParameters(&latentMean);

		for (uint i = 0; i < GetNumChildren(); i++)
			dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i)).Init();

		return bInit;
	}

	/// Total number of optimisable parameters.
	/// The count is GPLVMBaseModel's parameters plus the GP layer's, the same
	/// two-part layout used by GetParameters / SetParameters.
	template<typename Scalar>
	int SparseGPLVMBaseModel<Scalar>::GetNumParameters()
	{
		// GPLVMBaseModel<Scalar> is a dependent base: name it with its template
		// argument so standard two-phase lookup can resolve the qualified call.
		const int numBase = GPLVMBaseModel<Scalar>::GetNumParameters();
		const int numLayer = gpLayer->GetNumParameters();

		return numBase + numLayer;
	}

	/// Unpacks a flat parameter vector produced in the GetParameters layout.
	/// - param[0, numBase) goes to GPLVMBaseModel.
	/// - param[numBase, numBase + numLayer) goes to the GP layer.
	/// @param param flat parameter vector; assumed to hold at least
	///              GetNumParameters() entries (not checked here).
	template<typename Scalar>
	void SparseGPLVMBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		const int numBase = GPLVMBaseModel<Scalar>::GetNumParameters();

		// Forward only the base-class slice, exactly the range GetParameters
		// packs it into; the tail belongs to the GP layer.
		GPLVMBaseModel<Scalar>::SetParameters(param(af::seq(0, numBase - 1)));

		const int numLayer = gpLayer->GetNumParameters();
		gpLayer->SetParameters(param(af::seq(numBase, numBase + numLayer - 1)));
	}

	/// Packs all parameters into one flat vector of type m_dType.
	/// GPLVMBaseModel's parameters come first, followed by the GP layer's.
	/// SetParameters expects the same layout.
	template<typename Scalar>
	af::array SparseGPLVMBaseModel<Scalar>::GetParameters()
	{
		// Qualified (non-virtual) call: sized for this class's own two-part layout.
		af::array param = af::constant(0.0f, SparseGPLVMBaseModel::GetNumParameters(), m_dType);

		// Dependent base named with <Scalar> so standard two-phase lookup resolves it.
		const int numBase = GPLVMBaseModel<Scalar>::GetNumParameters();
		param(af::seq(0, numBase - 1)) = GPLVMBaseModel<Scalar>::GetParameters();

		const int numLayer = gpLayer->GetNumParameters();
		param(af::seq(numBase, numBase + numLayer - 1)) = gpLayer->GetParameters();

		return param;
	}

	/// Propagates a parameter update.
	/// GPLVMBaseModel<Scalar>::UpdateParameters runs first, then the GP layer's
	/// UpdateParameters.
	template<typename Scalar>
	void SparseGPLVMBaseModel<Scalar>::UpdateParameters()
	{
		GPLVMBaseModel<Scalar>::UpdateParameters();
		this->gpLayer->UpdateParameters();
	}

	/// Forwards the fixed/free flag for kernel parameters to the GP layer.
	/// @param isfixed passed through unchanged to gpLayer->FixKernelParameters.
	template<typename Scalar>
	void SparseGPLVMBaseModel<Scalar>::FixKernelParameters(bool isfixed)
	{
		this->gpLayer->FixKernelParameters(isfixed);
	}

	/// Forwards the fixed/free flag for the inducing inputs to the GP layer.
	/// @param isfixed passed through unchanged to gpLayer->FixInducing.
	template<typename Scalar>
	void SparseGPLVMBaseModel<Scalar>::FixInducing(bool isfixed)
	{
		this->gpLayer->FixInducing(isfixed);
	}
}