/**
File:		MachineLearning/Models/GPModels/FgSparseDeepGPBaseModel.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <stdexcept>
#include <MachineLearning/FgSparseDeepGPLVMBaseModel.h>
#include <MachineLearning/FgAEPSparseDGPR.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiations for the two scalar types this model is built for.
	// NOTE(review): per [temp.explicit], an explicit instantiation definition of a class
	// template only instantiates the members that are already defined at its point of
	// instantiation. Every out-of-line member below is defined AFTER these two lines, so a
	// strictly conforming compiler (e.g. GCC/Clang) may not emit them, which leads to link
	// errors. This presumably works today because of MSVC's more lenient template
	// handling (TODO confirm). Consider moving both lines to the end of the namespace.
	template class SparseDeepGPLVMBaseModel<float>;
	template class SparseDeepGPLVMBaseModel<double>;

	/**
	 * Constructs a sparse deep GP-LVM with a single hidden layer.
	 *
	 * Layer widths are set to [iq, description.GetNumHiddenDimensions(), iD], so the model
	 * has iNumLayer == 2 GP layers. Both layers use description.GetNumPseudoInputs()
	 * pseudo inputs.
	 *
	 * The GP layer objects themselves are not created here; gpLayer stays empty.
	 * NOTE(review): presumably a derived class fills gpLayer before Init() is called.
	 *
	 * The base is named as GPLVMBaseModel<Scalar> (as in the default constructor). A bare
	 * GPLVMBaseModel relies on a compiler extension, because unqualified lookup does not
	 * search dependent base classes.
	 *
	 * @param Y               Observed data, forwarded to GPLVMBaseModel.
	 * @param latentDimension Forwarded to GPLVMBaseModel.
	 * @param description     Description of the single hidden layer.
	 * @param priorMean       Forwarded to GPLVMBaseModel.
	 * @param priorVariance   Forwarded to GPLVMBaseModel.
	 * @param lType           LogLikType::Probit creates a ProbitLikLayer; any other value creates a GaussLikLayer.
	 * @param emethod         Forwarded to GPLVMBaseModel.
	 */
	template<typename Scalar>
	SparseDeepGPLVMBaseModel<Scalar>::SparseDeepGPLVMBaseModel(const af::array& Y, int latentDimension, HiddenLayerDescription description,
		Scalar priorMean, Scalar priorVariance, LogLikType lType, XInit emethod)
		: GPLVMBaseModel<Scalar>(Y, latentDimension, priorMean, priorVariance, lType, emethod), vDescription(), vNumPseudosPerLayer(), vSize(), gpLayer()
	{
		vDescription.push_back(description);

		// Layer widths from input to output: iq (presumably the latent dimension) -> hidden -> iD.
		vSize.push_back(iq);
		vSize.push_back(description.GetNumHiddenDimensions());
		vSize.push_back(iD);

		iNumLayer = vSize.size() - 1;

		// Every layer shares the single description's pseudo-input count.
		for (uint i = 0; i < iNumLayer; i++)
			vNumPseudosPerLayer.push_back(description.GetNumPseudoInputs());

		switch (lType)
		{
		case LogLikType::Probit:
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
			break;
		default:
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
			break;
		}
	}

	/**
	 * Constructs a sparse deep GP-LVM with one hidden layer per entry of `descriptions`.
	 *
	 * Layer widths are set to [iq, hidden dims of descriptions[0..n-1], iD], giving
	 * iNumLayer == descriptions.size() + 1 GP layers.
	 *
	 * The GP layer objects themselves are not created here; gpLayer stays empty.
	 * NOTE(review): presumably a derived class fills gpLayer before Init() is called.
	 *
	 * @param Y               Observed data, forwarded to GPLVMBaseModel.
	 * @param latentDimension Forwarded to GPLVMBaseModel.
	 * @param descriptions    One description per hidden layer, in input-to-output order; must not be empty.
	 * @param priorMean       Forwarded to GPLVMBaseModel.
	 * @param priorVariance   Forwarded to GPLVMBaseModel.
	 * @param lType           LogLikType::Probit creates a ProbitLikLayer; any other value creates a GaussLikLayer.
	 * @param emethod         Forwarded to GPLVMBaseModel.
	 * @throws std::invalid_argument if `descriptions` is empty.
	 */
	template<typename Scalar>
	SparseDeepGPLVMBaseModel<Scalar>::SparseDeepGPLVMBaseModel(const af::array& Y, int latentDimension, std::vector<HiddenLayerDescription> descriptions,
		Scalar priorMean, Scalar priorVariance, LogLikType lType, XInit emethod)
		: GPLVMBaseModel<Scalar>(Y, latentDimension, priorMean, priorVariance, lType, emethod), vDescription(descriptions), vNumPseudosPerLayer(), vSize(), gpLayer()
	{
		// The pseudo-input assignment below always reads descriptions[0]; with an empty list
		// that read would be out of bounds (undefined behaviour), so reject it up front.
		if (descriptions.empty())
			throw std::invalid_argument("SparseDeepGPLVMBaseModel: at least one hidden layer description is required");

		// Layer widths from input to output: iq -> each hidden layer -> iD.
		vSize.reserve(descriptions.size() + 2);
		vSize.push_back(iq);
		for (auto& description : descriptions)	// by reference: no per-element copy
			vSize.push_back(description.GetNumHiddenDimensions());
		vSize.push_back(iD);

		iNumLayer = vSize.size() - 1;

		// Layer 0 uses descriptions[0] and every later layer i uses descriptions[i - 1], so
		// the first two layers share descriptions[0]'s pseudo-input count.
		for (uint i = 0; i < iNumLayer; i++)
			vNumPseudosPerLayer.push_back(descriptions[i == 0 ? 0 : i - 1].GetNumPseudoInputs());

		switch (lType)
		{
		case LogLikType::Probit:
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
			break;
		default:
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
			break;
		}
	}

	/**
	 * Default constructor.
	 *
	 * Default-constructs the GPLVMBaseModel base and leaves the per-layer containers
	 * (pseudo-input counts, layer widths, GP layers) empty.
	 */
	template<typename Scalar>
	SparseDeepGPLVMBaseModel<Scalar>::SparseDeepGPLVMBaseModel()
		: GPLVMBaseModel<Scalar>(),
		  vNumPseudosPerLayer(),
		  vSize(),
		  gpLayer()
	{
		// No further setup: layer sizes and GP layers are only filled by the data constructors / subclasses.
	}

	/**
	 * Destructor: deletes every GP layer pointer held in gpLayer, front to back, and then
	 * empties the vector so no dangling pointers remain.
	 */
	template<typename Scalar>
	SparseDeepGPLVMBaseModel<Scalar>::~SparseDeepGPLVMBaseModel()
	{
		for (auto it = gpLayer.begin(); it != gpLayer.end(); ++it)
			delete *it;

		gpLayer.clear();
	}

	/**
	 * Initialises the model, its GP layers and its child models.
	 *
	 * Steps:
	 *  1. Obtain mx. A root model (no parent) calls GPLVMBaseModel::Init(mx), which also
	 *     sets bInit. A child model gets mx/vx from PosteriorLatents instead.
	 *     NOTE(review): bInit is NOT assigned on the child path, so the value returned there
	 *     is whatever bInit already held. Confirm this is intended.
	 *  2. Set the output width to iD and give each GP layer its data size
	 *     (iN x vSize[i + 1]).
	 *  3. If styles are present, install a style-aware InterDomainKernel on gpLayer[0].
	 *  4. Initialise each layer's parameters. Only the first layer receives mx.
	 *  5. Call Init() on every child model.
	 *
	 * Assumes gpLayer already holds at least iNumLayer layers. This file's constructors do not
	 * create them (presumably a subclass does — TODO confirm).
	 *
	 * @return bInit.
	 */
	template<typename Scalar>
	bool SparseDeepGPLVMBaseModel<Scalar>::Init()
	{
		af::array mx, vx;
		if (!GetParent())
			bInit = GPLVMBaseModel::Init(mx);
		else
			PosteriorLatents(mx, vx);

		vSize.back() = iD;
		for (int i = 0; i < iNumLayer; i++)
			gpLayer[i]->SetDataSize(iN, vSize[i + 1]);

		if (mStyles)
		{
			// Inter-domain kernel that will be installed on the first GP layer.
			InterDomainKernel<Scalar>* kInter = new InterDomainKernel<Scalar>();

			// Tensor kernel over the extended input. It has an ARD kernel on columns [0, iq), then one
			// StyleKernel per style on the next GetNumSubstyles() columns. After the loop, iEnd
			// equals iq plus the total number of substyles.
			TensorKernel<Scalar>* kWindow = new TensorKernel<Scalar>();
			kWindow->AddKernel(new ARDKernel<Scalar>(iq), af::seq(0, iq - 1));

			int iStart, iEnd = iq;
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				iStart = iEnd;  iEnd += style->second.GetNumSubstyles();
				kWindow->AddKernel(new StyleKernel<Scalar>(style->second.GetNumSubstyles()), af::seq(iStart, iEnd - 1));
			}

			// NOTE(review): despite its name, kWindow is registered as the *sub* kernel. The ARD
			// kernel over all iEnd columns below is registered as the *window* kernel. The
			// commented-out lines show the swapped alternative that was tried.
			//kInter->AddWindowKernel(kWindow);
			kInter->AddSubKernel(kWindow);

			//kInter->AddSubKernel(new ARDKernel<Scalar>(iEnd));
			kInter->AddWindowKernel(new ARDKernel<Scalar>(iEnd));

			// Install the inter-domain kernel on the first GP layer. The raw `new`ed kernels are
			// handed over here; presumably the layer/kernel takes ownership (TODO confirm).
			gpLayer[0]->SetKernel(kInter);

			gpLayer[0]->SetStyles(mStyles);

			//gpLayer[0]->SetLatentDimension(iEnd);
		}

		// Only the first layer is seeded with mx; deeper layers initialise themselves.
		for (int i = 0; i < iNumLayer; i++)
			if (i == 0) gpLayer[i]->InitParameters(&mx);
			else gpLayer[i]->InitParameters();

		// Disabled experiment, kept for reference: pre-train an AEP::SDGPR on (afY, mx) with
		// the same style kernel, then copy its layer parameters into gpLayer.
		//// Init posterior
		//std::cout << "Init posterior through SDGPR\n============================\n" << std::endl;
		//AEP::SDGPR<Scalar>* model = new AEP::SDGPR<Scalar>(afY, mx, vDescription);
		//if (mStyles)
		//{
		//	// Init inter domain kernel
		//	InterDomainKernel<Scalar>* kInter = new InterDomainKernel<Scalar>();

		//	// Init the window kernel for number of styles
		//	TensorKernel<Scalar>* kWindow = new TensorKernel<Scalar>();
		//	kWindow->AddKernel(new ARDKernel<Scalar>(iq), af::seq(0, iq - 1));

		//	int iStart, iEnd = iq;
		//	for (auto style = mStyles->begin(); style != mStyles->end(); style++)
		//	{
		//		iStart = iEnd;  iEnd += style->second.GetNumSubstyles();
		//		kWindow->AddKernel(new StyleKernel<Scalar>(style->second.GetNumSubstyles()), af::seq(iStart, iEnd - 1));
		//	}

		//	// add window kernel to inter domain kernel
		//	//kInter->AddWindowKernel(kWindow);
		//	kInter->AddSubKernel(kWindow);

		//	// add subset kernel kernel to inter domain kernel
		//	//kInter->AddSubKernel(new ARDKernel<Scalar>(iEnd));
		//	kInter->AddWindowKernel(new ARDKernel<Scalar>(iEnd));

		//	model->GetGPLayers()[0]->SetKernel(kInter);
		//	model->GetGPLayers()[0]->SetStyles(mStyles);
		//	model->GetGPLayers()[0]->SetLatentDimension(iq);
		//}
		//model->FixKernelParameters(true);
		//model->FixInducing(true);
		/*model->Optimise(OptimizerType::ADAM, 0.0, false, 100, iBatchSize);
		model->FixKernelParameters(false);
		model->FixInducing(false);

		for (int i = 0; i < iNumLayer; i++)
			gpLayer[i]->SetParameters(model->GetGPLayers()[i]->GetParameters());

		delete model;
		std::cout << "done.\n\nOptimizing SDGPLVM\n==================\n" << std::endl;*/


		// Recursively initialise child models. dynamic_cast on a reference throws
		// std::bad_cast if a child is not a GPLVMBaseModel<Scalar>.
		for (uint i = 0; i < GetNumChildren(); i++)
		{
			GPLVMBaseModel<Scalar>& child = dynamic_cast<GPLVMBaseModel<Scalar>&>(*GetChild(i));
			child.Init();
		} 

		return bInit;
	}

	/**
	 * Propagates testInputs through every GP layer in order and returns the output of the
	 * last layer.
	 *
	 * @param testInputs Inputs fed to the first GP layer.
	 * @param mf         Output from the final layer (presumably the predictive mean).
	 * @param vf         Output from the final layer (presumably the predictive variance).
	 */
	template<typename Scalar>
	void SparseDeepGPLVMBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		GPLVMBaseModel<Scalar>::PredictF(testInputs, mf, vf);

		af::array nextMf, nextVf;
		for (uint layerIdx = 0; layerIdx < iNumLayer; layerIdx++)
		{
			if (layerIdx != 0)
			{
				// Deeper layers consume the previous layer's (mf, vf); copy the fresh results
				// back so the next iteration reads them.
				gpLayer[layerIdx]->ForwardPredictionPost(&mf, &vf, nextMf, nextVf);
				mf = nextMf.copy();
				vf = nextVf.copy();
				continue;
			}

			// The first layer is fed testInputs directly, with no input variance (nullptr).
			gpLayer[layerIdx]->ForwardPredictionPost(&testInputs, nullptr, mf, vf);
		}
	}

	/**
	 * Draws numSamples function samples at `inputs`. Each sample is produced by sampling
	 * every GP layer's posterior in turn and feeding that layer's sample into the next layer.
	 *
	 * @param inputs       Inputs to the first GP layer; outFunctions gets inputs.dims(0) rows.
	 * @param numSamples   Number of samples to draw.
	 * @param outFunctions Output, allocated as inputs.dims(0) x iD x numSamples; slice i holds sample i.
	 */
	template<typename Scalar>
	void SparseDeepGPLVMBaseModel<Scalar>::SampleY(const af::array inputs, int numSamples, af::array& outFunctions)
	{
		GPLVMBaseModel<Scalar>::SampleY(inputs, numSamples, outFunctions);

		outFunctions = af::array(inputs.dims(0), iD, numSamples, m_dType);
		af::array tmpFunc, inputsTmp;

		for (int i = 0; i < numSamples; i++)
		{
			inputsTmp = inputs.copy();

			// Walk the layers with k (the layer index). Advancing k, not the sample index i,
			// is what lets this loop terminate.
			for (int k = 0; k < iNumLayer; k++)
			{
				gpLayer[k]->SampleFromPost(inputsTmp, tmpFunc);
				inputsTmp = tmpFunc.copy();
			}
			outFunctions(af::span, af::span, i) = tmpFunc;
		}
	}

	/**
	 * Total number of trainable parameters: GPLVMBaseModel's own parameters plus the
	 * parameters of every layer stored in gpLayer.
	 */
	template<typename Scalar>
	int SparseDeepGPLVMBaseModel<Scalar>::GetNumParameters()
	{
		int total = GPLVMBaseModel<Scalar>::GetNumParameters();

		for (auto it = gpLayer.cbegin(); it != gpLayer.cend(); ++it)
			total += (*it)->GetNumParameters();

		return total;
	}

	/**
	 * Number of parameters held by the GP layers alone (excludes GPLVMBaseModel's own).
	 */
	template<typename Scalar>
	int SparseDeepGPLVMBaseModel<Scalar>::GetNumGPLayerParameters()
	{
		int layerParams = 0;

		auto it = gpLayer.begin();
		while (it != gpLayer.end())
		{
			layerParams += (*it)->GetNumParameters();
			++it;
		}

		return layerParams;
	}

	/**
	 * Unpacks a flat parameter vector with the same layout GetParameters produces:
	 *   [ GPLVMBaseModel params | layer iNumLayer-1 params | ... | layer 0 params ].
	 *
	 * The base model now receives only its own leading segment [0, baseCount - 1], matching
	 * how GetParameters packs it. Previously it was handed the whole vector via af::end.
	 * NOTE(review): with baseCount == 0, af::seq(0, -1) still means "whole array" in
	 * ArrayFire (negative end counts from the back). That case is unchanged from before.
	 *
	 * @param param Flat parameter vector, presumably of length GetNumParameters().
	 */
	template<typename Scalar>
	void SparseDeepGPLVMBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		int iStart = 0, iEnd = GPLVMBaseModel<Scalar>::GetNumParameters();
		GPLVMBaseModel<Scalar>::SetParameters(param(af::seq(iStart, iEnd - 1)));

		// GP layers are stored last-to-first, mirroring GetParameters.
		iStart = iEnd;
		for (int i = iNumLayer - 1; i >= 0; i--)
		{
			iEnd += gpLayer[i]->GetNumParameters();
			gpLayer[i]->SetParameters(param(af::seq(iStart, iEnd - 1)));
			iStart = iEnd;
		}
	}

	/**
	 * Packs all parameters into a single flat vector of length GetNumParameters(), in the
	 * layout SetParameters unpacks:
	 *   [ GPLVMBaseModel params | layer iNumLayer-1 params | ... | layer 0 params ].
	 */
	template<typename Scalar>
	af::array SparseDeepGPLVMBaseModel<Scalar>::GetParameters()
	{
		af::array param = af::constant(0.0f, SparseDeepGPLVMBaseModel::GetNumParameters(), m_dType);

		// Leading segment: the base model's own parameters.
		int offset = GPLVMBaseModel<Scalar>::GetNumParameters();
		param(af::seq(0, offset - 1)) = GPLVMBaseModel<Scalar>::GetParameters();

		// Then each GP layer, from the last one down to the first.
		for (int layerIdx = iNumLayer - 1; layerIdx >= 0; --layerIdx)
		{
			const int count = gpLayer[layerIdx]->GetNumParameters();
			param(af::seq(offset, offset + count - 1)) = gpLayer[layerIdx]->GetParameters();
			offset += count;
		}

		return param;
	}

	/**
	 * Calls GPLVMBaseModel::UpdateParameters, then UpdateParameters on each of the first
	 * iNumLayer GP layers in order.
	 */
	template<typename Scalar>
	void SparseDeepGPLVMBaseModel<Scalar>::UpdateParameters()
	{
		GPLVMBaseModel<Scalar>::UpdateParameters();

		int layerIdx = 0;
		while (layerIdx < iNumLayer)
		{
			gpLayer[layerIdx]->UpdateParameters();
			++layerIdx;
		}
	}

	/// Returns the number of GP layers (iNumLayer), i.e. one less than the number of layer widths.
	template<typename Scalar>
	int SparseDeepGPLVMBaseModel<Scalar>::GetNumLayers()
	{
		const int numLayers = iNumLayer;
		return numLayers;
	}

	/**
	 * Returns a copy of the vector of GP layer pointers. The pointed-to layers are still
	 * deleted by this model's destructor, so callers must not delete them or use them after
	 * the model is destroyed.
	 */
	template<typename Scalar>
	std::vector<SparseGPBaseLayer<Scalar>*> SparseDeepGPLVMBaseModel<Scalar>::GetGPLayers()
	{
		std::vector<SparseGPBaseLayer<Scalar>*> layers(gpLayer.begin(), gpLayer.end());
		return layers;
	}

	/**
	 * Forwards `isfixed` to FixKernelParameters on each of the first iNumLayer GP layers.
	 *
	 * @param isfixed Flag passed unchanged to every layer.
	 */
	template<typename Scalar>
	void SparseDeepGPLVMBaseModel<Scalar>::FixKernelParameters(bool isfixed)
	{
		for (int layerIdx = 0; layerIdx < iNumLayer; ++layerIdx)
			gpLayer[layerIdx]->FixKernelParameters(isfixed);
	}
}