/**
File:		MachineLearning/Models/GPModels/FgSparseDeepGPBaseModel.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgSparseDeepGPBaseModel.h>

#include <stdexcept>

namespace NeuralEngine::MachineLearning::GPModels
{
	template class SparseDeepGPBaseModel<float>;
	template class SparseDeepGPBaseModel<double>;

	template<typename Scalar>
	SparseDeepGPBaseModel<Scalar>::SparseDeepGPBaseModel(const af::array& Y, const af::array& X, std::vector<HiddenLayerDescription> descriptions, LogLikType lType)
		: DeepGPBaseModel<Scalar>(Y, descriptions, lType), afX(X), iq(X.dims(1))
	{
		vSize.push_back(iq);
		for (auto description : descriptions)
		{
			vSize.push_back(description.GetNumHiddenDimensions());
		}
		vSize.push_back(iD);

		iNumLayer = vSize.size() - 1;

		for (uint i = 0; i < iNumLayer; i++)
			if (i == 0) vNumPseudosPerLayer.push_back(descriptions[i].GetNumPseudoInputs());
			else vNumPseudosPerLayer.push_back(descriptions[i - 1].GetNumPseudoInputs());

		switch (lType)
		{
		case LogLikType::Probit:
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
			break;
		default:
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
			break;
		}
	}

	template<typename Scalar>
	SparseDeepGPBaseModel<Scalar>::SparseDeepGPBaseModel(const af::array& Y, const af::array& X, HiddenLayerDescription description, LogLikType lType)
		: DeepGPBaseModel<Scalar>(Y, description, lType), afX(X), iq(X.dims(1))
	{
		vSize.push_back(iq);
		vSize.push_back(description.GetNumHiddenDimensions());
		vSize.push_back(iD);

		iNumLayer = vSize.size() - 1;

		for (uint i = 0; i < iNumLayer; i++) vNumPseudosPerLayer.push_back(description.GetNumPseudoInputs());

		switch (lType)
		{
		case LogLikType::Probit:
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
			break;
		default:
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
			break;
		}
	}
	
	template<typename Scalar>
	SparseDeepGPBaseModel<Scalar>::SparseDeepGPBaseModel()
		: DeepGPBaseModel<Scalar>(), iq(0), afX(m_dType)
	{
	}

	template<typename Scalar>
	SparseDeepGPBaseModel<Scalar>::~SparseDeepGPBaseModel()
	{
	}

	template<typename Scalar>
	bool SparseDeepGPBaseModel<Scalar>::Init()
	{
		DeepGPBaseModel::Init();

		for (int i = 0; i < iNumLayer; i++)
			if (i == 0) gpLayer[i]->InitParameters(&afX);
			else gpLayer[i]->InitParameters();

		return bInit;
	}

	template<typename Scalar>
	af::array SparseDeepGPBaseModel<Scalar>::GetTrainingInputs()
	{
		return afX;
	}

	template<typename Scalar>
	void SparseDeepGPBaseModel<Scalar>::FixInducing(bool isfixed)
	{
		for (int i = 0; i < iNumLayer; i++)
		{
			SparseGPBaseLayer<Scalar>& slayer = dynamic_cast<SparseGPBaseLayer<Scalar>&>(*gpLayer[i]);
			slayer.FixInducing(isfixed);
		}
	}
}