/**
File:		MachineLearning/Models/GPModels/FgGPBaseModel<Scalar>.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>

#include <functional>
#include <iostream>
#include <memory>
#include <type_traits>

#include <MachineLearning/FgGPBaseModel.h>
#include <MachineLearning/FgLBFGSsolver.h>
#include <MachineLearning/FgLBFGSBsolver.h>

namespace NeuralEngine::MachineLearning::GPModels
{
	// Explicit instantiations for the two supported scalar types.
	// NOTE(review): these precede the member definitions below. A strict reading of
	// [temp.explicit] instantiates only the members already defined at this point,
	// so on compilers that follow it literally these lines may need to move to the
	// end of the file.
	template class GPBaseModel<float>;
	template class GPBaseModel<double>;

	/**
	 * Builds a GP model around the training targets Y.
	 *
	 * @param Y     Training targets; Y.dims(0) and Y.dims(1) are forwarded to IModel,
	 *              and afBias starts as zeros of size 1 x Y.dims(1).
	 * @param lType Likelihood to attach: LogLikType::Probit selects ProbitLikLayer,
	 *              every other value falls back to GaussLikLayer.
	 * @param mtype Model type forwarded to IModel.
	 */
	template<typename Scalar>
	GPBaseModel<Scalar>::GPBaseModel(const af::array & Y, LogLikType lType, ModelType mtype)
		: IModel<Scalar>(Y.dims(0), Y.dims(1), mtype),
		  afY(Y),
		  afBias(af::constant(0, 1, Y.dims(1))),
		  bInit(false),
		  afSegments()
	{
		// Probit on request; Gaussian is the default for any other likelihood type.
		if (lType == LogLikType::Probit)
			likLayer = new ProbitLikLayer<Scalar>(iN, iD);
		else
			likLayer = new GaussLikLayer<Scalar>(iN, iD);
	}

	/**
	 * Default constructor: an empty model with no data and no likelihood layer.
	 * The layer pointer is explicitly nulled because the destructor deletes it.
	 */
	template<typename Scalar>
	GPBaseModel<Scalar>::GPBaseModel()
		: IModel<Scalar>(0, 0, ModelType::NONE), bInit(false), afY(), afBias(), afSegments()
	{
		// Without this, likLayer could be indeterminate here and the destructor's
		// delete would be undefined behaviour. Assigned in the body rather than the
		// mem-init list so it works whether likLayer is declared here or in a base.
		likLayer = nullptr;
	}

	/**
	 * Releases the likelihood layer allocated by the constructor.
	 * NOTE(review): likLayer is owned through a raw pointer; unless the header
	 * disables or customises copying, a copied model would delete the same layer
	 * twice. Confirm copy/move are deleted or handled.
	 */
	template<typename Scalar>
	GPBaseModel<Scalar>::~GPBaseModel()
	{
		// delete on a null pointer is a no-op, so no explicit guard is required.
		delete likLayer;
		likLayer = nullptr;
	}

	/**
	 * Optimises the model parameters with a gradient-based solver.
	 *
	 * @param method        Solver to build: L_BFGS, L_BFGS_B, ADAM, NADAM or ADAMAX.
	 *                      Any other value prints "Solver not supported." and returns.
	 * @param tol           Passed to the solver via SetTolerance().
	 * @param reinit_hypers Not used by this base implementation.
	 * @param maxiter       Passed to the solver via SetMaxIterations().
	 * @param mb_size       When non-zero, overwrites iBatchSize before optimising.
	 * @param lsType        Line-search policy the solver is instantiated with.
	 * @param disp          Not used by this base implementation.
	 * @param cycle         Forwarded unchanged to Minimize() for every line-search type.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::Optimise(OptimizerType method, Scalar tol, bool reinit_hypers, int maxiter, int mb_size, LineSearchType lsType, bool disp, int* cycle)
	{
		if (mb_size != 0)
			iBatchSize = mb_size;

		if (!bInit) Init();

		// The solvers take the line-search policy as a template argument, so the
		// runtime lsType is converted into a compile-time tag below. This generic
		// lambda is then instantiated once per policy, replacing seven copies of the
		// same solver switch.
		auto runSolver = [&](auto lsTag)
		{
			constexpr LineSearchType LS = decltype(lsTag)::value;
			using Optimizer = BaseGradientOptimizationMethod<Scalar, LS>;

			// Objective/gradient callback shared by every solver.
			auto objective = std::bind(&IModel<Scalar>::Function, this, std::placeholders::_1, std::placeholders::_2);

			// unique_ptr releases the solver even if Minimize() throws.
			std::unique_ptr<Optimizer> optimizer;
			switch (method)
			{
			case L_BFGS:
				optimizer.reset(new LBFGSSolver<Scalar, LS>(GetNumParameters(), objective));
				break;
			case L_BFGS_B:
				optimizer.reset(new LBFGSBSolver<Scalar, LS>(GetNumParameters(), objective));
				break;
			case ADAM:
				optimizer.reset(new AdamSolver<Scalar, LS>(GetNumParameters(), objective));
				break;
			case NADAM:
				optimizer.reset(new NadamSolver<Scalar, LS>(GetNumParameters(), objective));
				break;
			case ADAMAX:
				optimizer.reset(new AdaMaxSolver<Scalar, LS>(GetNumParameters(), objective));
				break;
			default:
				std::cout << "\nSolver not supported." << std::endl;
				return;
			}

			optimizer->SetMaxIterations(maxiter);
			optimizer->SetTolerance(tol);
			// cycle used to be forwarded only for ArmijoBacktracking; it is now passed
			// for every line-search policy.
			optimizer->Minimize(GetParameters(), cycle);
		};

		switch (lsType)
		{
		case ArmijoBacktracking:
			runSolver(std::integral_constant<LineSearchType, ArmijoBacktracking>{});
			break;
		case ArmijoBracketing:
			runSolver(std::integral_constant<LineSearchType, ArmijoBracketing>{});
			break;
		case MoreThuente:
			runSolver(std::integral_constant<LineSearchType, MoreThuente>{});
			break;
		case StrongWolfeBacktracking:
			runSolver(std::integral_constant<LineSearchType, StrongWolfeBacktracking>{});
			break;
		case StrongWolfeBracketing:
			runSolver(std::integral_constant<LineSearchType, StrongWolfeBracketing>{});
			break;
		case WolfeBacktracking:
			runSolver(std::integral_constant<LineSearchType, WolfeBacktracking>{});
			break;
		case WolfeBracketing:
			runSolver(std::integral_constant<LineSearchType, WolfeBracketing>{});
			break;
		default:
			// Previously an unknown policy silently skipped optimisation.
			std::cout << "\nLine search not supported." << std::endl;
			break;
		}
	}

	/**
	 * One-time initialisation: lets the likelihood layer initialise its
	 * parameters and marks the model as initialised.
	 *
	 * @return Always true.
	 */
	template<typename Scalar>
	bool GPBaseModel<Scalar>::Init()
	{
		// Target centring / standardisation and bias removal are currently disabled:
		//afBias = af::mean(afY);
		/*afY = afY - af::tile(af::mean(afY), afY.dims(0));
		afY /= af::tile(af::stdev(afY), afY.dims(0));
		afY(af::isNaN(afY)) = 0.0;*/
		//af_print(afBias);
		//afY -= tile(afBias, iN, 1);

		likLayer->InitParameters();
		bInit = true;
		return true;
	}

	/**
	 * Base-class latent prediction hook. It only makes sure the model is
	 * initialised; testInputs is ignored and mf / vf are left untouched.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::PredictF(const af::array& testInputs, af::array& mf, af::array& vf)
	{
		if (bInit)
			return;
		Init();
	}

	/**
	 * Output-space prediction: obtains the latent moments from PredictF() and
	 * maps them through the likelihood layer's ProbabilisticOutput().
	 *
	 * @param testInputs Inputs to predict at (forwarded to PredictF()).
	 * @param my         Receives the likelihood layer's mean output.
	 * @param vy         Receives the likelihood layer's variance output.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::PredictY(const af::array& testInputs, af::array& my, af::array& vy)
	{
		af::array latentMean;
		af::array latentVar;
		PredictF(testInputs, latentMean, latentVar);

		likLayer->ProbabilisticOutput(latentMean, latentVar, my, vy);
	}

	/**
	 * Base-class sampling hook. It only makes sure the model is initialised;
	 * inputs and numSamples are ignored and outFunctions is left untouched.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::SampleY(const af::array inputs, int numSamples, af::array& outFunctions)
	{
		if (bInit)
			return;
		Init();
	}

	/**
	 * Appends a new block of training targets and records where it starts.
	 *
	 * afSegments collects the starting row of each appended block. iN,
	 * iBatchSize and afIndexes are refreshed to cover every row of the enlarged afY.
	 * NOTE(review): assumes CommonUtil::Join concatenates along dim 0. Confirm.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::AddData(const af::array Ytrain)
	{
		// Data already present but never segmented: its block starts at row 0.
		const bool needsInitialSegment = afSegments.isempty() && !afY.isempty();
		if (needsInitialSegment)
			afSegments = af::constant(0.0, 1);

		// The appended block starts at the current row count (read before afY grows).
		const af::array newBlockStart = af::constant(afY.dims(0), 1);
		afSegments = CommonUtil<Scalar>::Join(afSegments, newBlockStart);

		afY = CommonUtil<Scalar>::Join(afY, Ytrain);

		iN = static_cast<int>(afY.dims(0));
		iBatchSize = iN;
		afIndexes = af::seq(0, iN - 1);
	}

	/**
	 * @return The stored training targets (afY).
	 */
	template<typename Scalar>
	af::array GPBaseModel<Scalar>::GetTrainingData()
	{
		return this->afY;
	}

	/**
	 * Replaces the stored training targets.
	 * NOTE(review): unlike AddData, this does not refresh iN, iBatchSize or
	 * afIndexes. Confirm callers only pass data with the same row count.
	 *
	 * @param data New training targets.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::SetTrainingData(af::array& data)
	{
		this->afY = data;
	}

	/**
	 * @return Number of parameters owned by the base model, which here are
	 *         exactly those of the likelihood layer.
	 */
	template<typename Scalar>
	int GPBaseModel<Scalar>::GetNumParameters()
	{
		return likLayer->GetNumParameters();
	}

	/**
	 * Copies the leading likelihood slots of a flat parameter vector into the
	 * likelihood layer.
	 *
	 * @param param Flat parameter vector; entries [0, n) belong to the likelihood,
	 *              where n is likLayer->GetNumParameters().
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::SetParameters(const af::array & param)
	{
		const int numLik = likLayer->GetNumParameters();
		if (numLik == 0)
			return;	// nothing to hand over

		likLayer->SetParameters(param(af::seq(0, numLik - 1)));
	}

	/**
	 * Packs the base model's parameters into a flat vector.
	 *
	 * Side effect: refreshes m_dType from CommonUtil<Scalar>::CheckDType()
	 * (presumably the af::dtype matching Scalar; confirm).
	 *
	 * @return Vector sized by GPBaseModel::GetNumParameters() whose leading
	 *         entries hold the likelihood layer's parameters.
	 */
	template<typename Scalar>
	af::array GPBaseModel<Scalar>::GetParameters()
	{
		m_dType = CommonUtil<Scalar>::CheckDType();

		// Qualified call: sized by the base-class count, bypassing virtual dispatch.
		af::array packed = af::constant(0.0f, GPBaseModel::GetNumParameters(), m_dType);

		const int numLik = likLayer->GetNumParameters();
		if (numLik != 0)
			packed(af::seq(0, numLik - 1)) = likLayer->GetParameters();

		return packed;
	}

	/**
	 * Forwards the fixed/free flag to the likelihood layer's FixParameters().
	 *
	 * @param isfixed Flag passed straight through to the layer.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::FixLikelihoodParameters(bool isfixed)
	{
		this->likLayer->FixParameters(isfixed);
	}

	/**
	 * Replaces the stored segment boundaries.
	 *
	 * @param segments New segment array; stored in afSegments, which is what
	 *                 GetSegments() returns.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::SetSegments(af::array segments)
	{
		// The body used to be empty, so the argument was silently discarded.
		afSegments = segments;
	}

	/**
	 * @return The stored segment boundaries (afSegments).
	 */
	template<typename Scalar>
	af::array GPBaseModel<Scalar>::GetSegments()
	{
		return this->afSegments;
	}

	/**
	 * Initialises the model if needed, then asks the likelihood layer to
	 * update its parameters.
	 */
	template<typename Scalar>
	void GPBaseModel<Scalar>::UpdateParameters()
	{
		if (!this->bInit)
			this->Init();

		this->likLayer->UpdateParameters();
	}
}