/**
File:		MachineLearning/Models/GPModels/FgSparseGPBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgSparseGPBaseModel.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class SparseGPBaseModel<float>;
	template class SparseGPBaseModel<double>;

	template<typename Scalar>
	SparseGPBaseModel<Scalar>::SparseGPBaseModel(const af::array& Y, const af::array& X, int numInducing, LogLikType lType)
		: GPBaseModel<Scalar>(Y, lType), ik(numInducing), afX(X), iq(X.dims(1)), gpLayer(nullptr)
	{
	}

	template<typename Scalar>
	SparseGPBaseModel<Scalar>::SparseGPBaseModel()
		: GPBaseModel<Scalar>(), ik(0), afX(m_dType), iq(0), gpLayer(nullptr)
	{
	}

	template<typename Scalar>
	SparseGPBaseModel<Scalar>::~SparseGPBaseModel()
	{
		if (gpLayer != nullptr) delete gpLayer;
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		GPBaseModel<Scalar>::PredictF(testInputs, mf, vf);

		gpLayer->ForwardPredictionPost(&testInputs, nullptr, mf, vf);
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::SampleY(const af::array inputs, int numSamples, af::array& outFunctions)
	{
		if (!bInit) Init();

		outFunctions = af::array(inputs.dims(0), iD, numSamples, m_dType);
		af::array tmpFunc(m_dType);

		for (int i = 0; i < numSamples; i++)
		{
			gpLayer->SampleFromPost(inputs, tmpFunc);
			outFunctions(af::span, af::span, i) = tmpFunc;
		}
			
	}

	template<typename Scalar>
	af::array SparseGPBaseModel<Scalar>::GetTrainingInputs()
	{
		return afX;
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::SetTrainingInputs(af::array& inputs)
	{
		afX = inputs;
	}

	template<typename Scalar>
	af::array SparseGPBaseModel<Scalar>::GetPseudoInputs()
	{
		return gpLayer->GetPseudoInputs();
	}

	template<typename Scalar>
	bool SparseGPBaseModel<Scalar>::Init()
	{
		GPBaseModel::Init();

		gpLayer->InitParameters(&afX);
		return bInit;
	}

	template<typename Scalar>
	int SparseGPBaseModel<Scalar>::GetNumParameters()
	{
		int numParam = GPBaseModel::GetNumParameters();

		numParam += gpLayer->GetNumParameters();

		return numParam;
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::SetParameters(const af::array& param)
	{
		int iStart = 0, iEnd = gpLayer->GetNumParameters();
		gpLayer->SetParameters(param(af::seq(iStart, iEnd - 1)));

		GPBaseModel::SetParameters(param(af::seq(iEnd, af::end)));
	}

	template<typename Scalar>
	af::array SparseGPBaseModel<Scalar>::GetParameters()
	{
		m_dType = CommonUtil<Scalar>::CheckDType();
		af::array param = af::constant(0.0f, GetNumParameters(), m_dType);

		int iStart = 0, iEnd = gpLayer->GetNumParameters();
		param(af::seq(iStart, iEnd - 1)) = gpLayer->GetParameters();

		iStart = iEnd; iEnd += GPBaseModel::GetNumParameters();
		
		if (iStart != iEnd)
			param(af::seq(iStart, iEnd - 1)) = GPBaseModel::GetParameters();

		return param;
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::UpdateParameters()
	{
		GPBaseModel::UpdateParameters();

		gpLayer->UpdateParameters();
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::FixKernelParameters(bool isfixed)
	{
		gpLayer->FixKernelParameters(isfixed);
	}

	template<typename Scalar>
	void SparseGPBaseModel<Scalar>::FixInducing(bool isfixed)
	{
		gpLayer->FixInducing(isfixed);
	}

	template<typename Scalar>
	SparseGPBaseLayer<Scalar>* SparseGPBaseModel<Scalar>::GetGPLayer()
	{
		return gpLayer;
	}
}