/**
File:		MachineLearning/GPModels/Models/Layers/GPLayers/FgGPBaseLayer.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgGPBaseLayer.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class GPBaseLayer<float>;
	template class GPBaseLayer<double>;

	template<typename Scalar>
	GPBaseLayer<Scalar>::GPBaseLayer(int numPoints, int outputDim, int inputDim)
		: ILayer<Scalar>(LayerType::GP, numPoints, outputDim), iq(inputDim), JITTER(1e-5), isFixedHypers(false), mStyles(nullptr)
	{
		/*CompoundKernel* initKern = new CompoundKernel();
		initKern->AddKernel(new RBFKernel(), af::seq(iq));
		initKern->AddKernel(new WhiteKernel(), af::seq(iq));
		kernel = initKern;*/

		kernel = new ARDKernel<Scalar>(iq);
		m_dType = CommonUtil<Scalar>::CheckDType();
		//kernel = new LinearKernel<Scalar>(inputDim);
	}

	template<typename Scalar>
	GPBaseLayer<Scalar>::~GPBaseLayer()
	{
		if (kernel != nullptr) delete kernel;
	}

	template<typename Scalar>
	IKernel<Scalar>* GPBaseLayer<Scalar>::GetKernel()
	{
		return kernel;
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::SetKernel(IKernel<Scalar>* kern)
	{
		if (kernel != nullptr) delete kernel;
		kernel = kern;
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::InitParameters(af::array* X)
	{
		if (X == nullptr)
		{

		}
		else
		{

		}
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::ForwardPredictionPost(const af::array* mx, const af::array* vx, af::array& mout, af::array& vout)
	{
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::UpdateParameters()
	{
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::SetDataSize(int length, int dimension)
	{
		ILayer::SetDataSize(length, dimension);

		ReinitParameters();
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::SetStyles(std::map<std::string, Style<Scalar>>* styles)
	{
		mStyles = styles;
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::SetLatentDimension(int q)
	{
		iq = q;
		ReinitParameters();
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::ReinitParameters()
	{
	}

	template<typename Scalar>
	int GPBaseLayer<Scalar>::GetNumParameters()
	{
		int numParam = 0;
		if (mStyles)
		{
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
				numParam += style->second.GetNumInducingParameters();
		}
		(!isFixedHypers) ? numParam += kernel->GetNumParameter() : numParam += 0;

		return numParam;
	}

	template<typename Scalar>
	af::array GPBaseLayer<Scalar>::GetParameters()
	{
		af::array param;

		(!isFixedHypers) ? param = kernel->GetLogParameters() : param = af::array();

		if (mStyles)
		{
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				param = CommonUtil<Scalar>::Join(param, style->second.GetInducingParameters());
			}
		}

		return param;
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::SetParameters(const af::array& param)
	{
		int iStart = 0, iEnd = 0; 
		if (!isFixedHypers) 
		{
			iEnd = kernel->GetNumParameter();
			kernel->SetLogParameters(param(af::seq(iStart, iEnd - 1)));
		}

		if (mStyles)
		{
			for (auto style = mStyles->begin(); style != mStyles->end(); style++)
			{
				iStart = iEnd; iEnd += style->second.GetNumInducingParameters();
				style->second.SetInducingParameters(param(af::seq(iStart, iEnd - 1)));
			}
		}
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::FixKernelParameters(bool isfixed)
	{
		isFixedHypers = isfixed;
	}

	template<typename Scalar>
	void GPBaseLayer<Scalar>::SampleFromPost(const af::array& inX, af::array& outfsample)
	{
	}
}