/**
File:		MachineLearning/Models/GPModels/FgSparseDeepGPBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgDeepGPBaseModel.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class DeepGPBaseModel<float>;
	template class DeepGPBaseModel<double>;

	template<typename Scalar>
	DeepGPBaseModel<Scalar>::DeepGPBaseModel(const af::array& Y, HiddenLayerDescription description, LogLikType lType)
		: GPBaseModel<Scalar>(Y, lType, ModelType::DGPR), vNumPseudosPerLayer(), vSize(), gpLayer()
	{
	}

	template<typename Scalar>
	DeepGPBaseModel<Scalar>::DeepGPBaseModel(const af::array& Y, std::vector<HiddenLayerDescription> descriptions, LogLikType lType)
		: GPBaseModel<Scalar>(Y, lType, ModelType::DGPR), vNumPseudosPerLayer(), vSize(), gpLayer()
	{
	}

	template<typename Scalar>
	DeepGPBaseModel<Scalar>::DeepGPBaseModel()
		: GPBaseModel<Scalar>(), iNumLayer(2), gpLayer(), vNumPseudosPerLayer(), vSize()
	{
	}

	template<typename Scalar>
	DeepGPBaseModel<Scalar>::~DeepGPBaseModel()
	{
		if (!gpLayer.empty())
		{
			for (auto obj : gpLayer) delete obj;
			gpLayer.clear();
		}
	}

	template<typename Scalar>
	void DeepGPBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		af::array mfTmp, vfTmp;

		GPBaseModel::PredictF(testInputs, mf, vf);

		for (uint i = 0; i < iNumLayer; i++)
		{
			//SparseGPBaseLayer<Scalar>& slayer = dynamic_cast<SparseGPBaseLayer<Scalar>&>(*gpLayer[i]);
			if (i == 0) gpLayer[i]->ForwardPredictionPost(&testInputs, nullptr, mf, vf);
			else
			{
				gpLayer[i]->ForwardPredictionPost(&mf, &vf, mfTmp, vfTmp);
				mf = mfTmp.copy();
				vf = vfTmp.copy();
			}
		}
	}
	template<typename Scalar>
	void DeepGPBaseModel<Scalar>::SampleY(const af::array inputs, int numSamples, af::array& outFunctions)
	{
		GPBaseModel::SampleY(inputs, numSamples, outFunctions);

		outFunctions = af::array(inputs.dims(0), iD, numSamples, m_dType);
		af::array tmpFunc, inputsTmp;

		for (uint i = 0; i < numSamples; i++)
		{
			inputsTmp = inputs.copy();
			for (uint k = 0; k < iNumLayer; i++)
			{
				gpLayer[k]->SampleFromPost(inputsTmp, tmpFunc);
				inputsTmp = tmpFunc.copy();
			}
			outFunctions(af::span, af::span, i) = tmpFunc;
		}
	}

	template<typename Scalar>
	int DeepGPBaseModel<Scalar>::GetNumParameters()
	{
		int numParam = GPBaseModel<Scalar>::GetNumParameters();

		for (auto layer : gpLayer)
			numParam += layer->GetNumParameters();
		return numParam;
	}

	template<typename Scalar>
	int DeepGPBaseModel<Scalar>::GetNumLayers()
	{
		return iNumLayer;
	}

	template<typename Scalar>
	void DeepGPBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		int istart = 0, iend = 0;
		for (int i = iNumLayer - 1; i >= 0; i--)
		{
			iend += gpLayer[i]->GetNumParameters();
			gpLayer[i]->SetParameters(param(af::seq(istart, iend - 1)));
			istart = iend;
		}

		if (likLayer->GetNumParameters() != 0)
			GPBaseModel<Scalar>::SetParameters(param(af::seq(iend, af::end)));
	}

	template<typename Scalar>
	af::array DeepGPBaseModel<Scalar>::GetParameters()
	{
		af::array param = af::constant(0.0f, DeepGPBaseModel::GetNumParameters(), m_dType);

		int istart = 0, iend = 0;

		for (int i = iNumLayer - 1; i >= 0; i--)
		{
			iend += gpLayer[i]->GetNumParameters();
			param(af::seq(istart, iend - 1)) = gpLayer[i]->GetParameters();
			istart = iend;
		}

		iend += GPBaseModel<Scalar>::GetNumParameters();

		if (istart != iend)
			param(af::seq(istart, iend - 1)) = GPBaseModel::GetParameters();

		return param;
	}
	template<typename Scalar>
	void DeepGPBaseModel<Scalar>::UpdateParameters()
	{
		GPBaseModel::UpdateParameters();

		for (int i = 0; i < iNumLayer; i++)
			gpLayer[i]->UpdateParameters();
	}

	template<typename Scalar>
	std::vector<GPBaseLayer<Scalar>*> DeepGPBaseModel<Scalar>::GetGPLayers()
	{
		return gpLayer;
	}

	template<typename Scalar>
	void DeepGPBaseModel<Scalar>::FixKernelParameters(bool isfixed)
	{
		for (int i = 0; i < iNumLayer; i++)
		{
			gpLayer[i]->FixKernelParameters(isfixed);
		}
	}
}