/**
File:		MachineLearning/GPModels/Models/Layers/LikLayers/FgILikelihoodLayer.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgILayer.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Likelihood families that a LikelihoodBaseLayer can be constructed with. </summary>
			///
			/// <remarks>	The enumerator values are written out explicitly because a layer's type is
			/// 			stored when the layer is serialized (LikelihoodBaseLayer::serialize), and
			/// 			Boost.Serialization stores enums by their integer value: reordering or inserting
			/// 			enumerators would silently change the meaning of previously saved archives. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			enum class LogLikType
			{
				Gaussian = 0,	///< Gaussian likelihood.
				Probit = 1		///< Probit likelihood.
			};

			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Abstract class for different GP likelihood layers. </summary>
			///
			/// <remarks>	, 27.02.2018.
			/// 			Concrete likelihoods implement the pure virtual logZ / expected log-likelihood
			/// 			interface declared below. Note that default arguments on virtual functions are
			/// 			taken from the static type of the call expression, not from the override. </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP LikelihoodBaseLayer : public ILayer<Scalar>
			{
			public:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	, 26.04.2018. </remarks>
				///
				/// <param name="type">			The likelihood type. </param>
				/// <param name="numPoints">	Number of data points. </param>
				/// <param name="outputDim">	Dimension of data points. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				LikelihoodBaseLayer(LogLikType type, int numPoints, int outputDim);

				/// <summary>	Virtual destructor, so derived likelihoods can be destroyed through a base pointer. </summary>
				virtual ~LikelihoodBaseLayer();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates logZ. </summary>
				///
				/// <remarks>	Hmetal T, 05/05/2018.
				/// 			Unlike ComputeLogZGradients, alpha comes before the optional gradient
				/// 			outputs here. </remarks>
				///
				/// <param name="mout">			[in] The mean. </param>
				/// <param name="vout">			[in] Documented as the standard deviation.
				/// 							NOTE(review): the 'v' prefix and dlogZ_dv naming suggest
				/// 							a variance - confirm against the implementations. </param>
				/// <param name="y">			[in] Sample data vector. </param>
				/// <param name="alpha">		(Optional) Weight between alpha- and KL-divergence. </param>
				/// <param name="dlogZ_dm"> 	[out] (Optional) If non-null, derivative of logZ w.r.t. the mean. </param>
				/// <param name="dlogZ_dv"> 	[out] (Optional) If non-null, derivative of logZ w.r.t. vout. </param>
				/// <param name="dlogZ_dm2">	[out] (Optional) If non-null, derivative of logZ w.r.t. mean^2. </param>
				///
				/// <returns>	The computed logZ value. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual Scalar ComputeLogZ(const af::array& mout, const af::array& vout, const af::array& y, Scalar alpha = 1.0, af::array* dlogZ_dm = nullptr,
					af::array* dlogZ_dv = nullptr, af::array* dlogZ_dm2 = nullptr) = 0;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates logZ gradients. </summary>
				///
				/// <remarks>	Hmetal T, 05/05/2018.
				/// 			Unlike ComputeLogZ, alpha is the last parameter here. </remarks>
				///
				/// <param name="mout">			[in] The mean. </param>
				/// <param name="vout">			[in] Documented as the standard deviation (see ComputeLogZ). </param>
				/// <param name="y">			[in] Sample data vector. </param>
				/// <param name="dlogZ_dm"> 	[out] (Optional) If non-null, derivative of logZ w.r.t. the mean. </param>
				/// <param name="dlogZ_dv"> 	[out] (Optional) If non-null, derivative of logZ w.r.t. vout. </param>
				/// <param name="dlogZ_dm2">	[out] (Optional) If non-null, derivative of logZ w.r.t. mean^2. </param>
				/// <param name="alpha">		(Optional) Weight between alpha- and KL-divergence. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void ComputeLogZGradients(const af::array& mout, const af::array& vout, const af::array& y, af::array* dlogZ_dm = nullptr,
					af::array* dlogZ_dv = nullptr, af::array* dlogZ_dm2 = nullptr, Scalar alpha = 1.0) = 0;

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Backpropagation step for the logZ objective. </summary>
				///
				/// <remarks>	dmout and dvout are non-const references and are presumably filled with the
				/// 			gradients w.r.t. mout and vout; the returned Scalar is presumably the
				/// 			(scaled) objective value - confirm in the implementations. </remarks>
				///
				/// <param name="mout"> 	[in] The mean. </param>
				/// <param name="vout"> 	[in] The spread of the output (see ComputeLogZ). </param>
				/// <param name="dmout">	[out] Gradient w.r.t. mout (presumably). </param>
				/// <param name="dvout">	[out] Gradient w.r.t. vout (presumably). </param>
				/// <param name="alpha">	(Optional) Weight between alpha- and KL-divergence. </param>
				/// <param name="scale">	(Optional) Scale factor. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual Scalar BackpropagationGradients(const af::array& mout, const af::array& vout, af::array& dmout, af::array& dvout, Scalar alpha = 1.0, Scalar scale = 1.0) = 0;

				/// <summary>	Computes the expected log-likelihood of y for the given mean and spread
				/// 			(presumably E_q[log p(y|f)] - confirm in the implementations). </summary>
				virtual Scalar ComputeLogLikExp(const af::array& mout, const af::array& vout, const af::array& y) = 0;

				/// <summary>	Gradients of ComputeLogLikExp; de_dm / de_dv are optional outputs, written
				/// 			only when non-null (presumably w.r.t. mout and vout). </summary>
				virtual void ComputeLogLikExpGradients(const af::array& mout, const af::array& vout, const af::array& y, af::array* de_dm = nullptr, af::array* de_dv = nullptr) = 0;

				/// <summary>	Backpropagation step for the expected log-likelihood objective; dmout and
				/// 			dvout are output references. </summary>
				/// <remarks>	NOTE(review): y is a non-const reference here, unlike every sibling method.
				/// 			It cannot be made const without breaking existing overrides. </remarks>
				virtual Scalar BackpropagationGradientsLogLikExp(const af::array& mout, const af::array& vout, af::array& dmout, af::array& dvout, af::array& y, Scalar scale = 1.0) = 0;

				/// <summary>	Maps latent moments (mf, vf) to output moments written into myOut / vyOut
				/// 			(presumably the predictive distribution of y - confirm). </summary>
				virtual void ProbabilisticOutput(const af::array& mf, const af::array& vf, af::array& myOut, af::array& vyOut, Scalar alpha = 1.0) = 0;

				/// <summary>	Hook for parameter initialisation; the base version does nothing and returns 0. </summary>
				virtual Scalar InitParameters() { return 0.0f; }

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets log likelihood type. </summary>
				///
				/// <remarks>	, 26.04.2018. </remarks>
				///
				/// <returns>	The log likelihood type. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				LogLikType GetLogLikType();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets number of parameters to be optimized. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The number of parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual int GetNumParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Sets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <param name="param">	The parameter. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void SetParameters(const af::array& param);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Gets the parameters for each optimization iteration. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				///
				/// <returns>	The parameters. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual af::array GetParameters();

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Set to fix the parameters or not for optimization. </summary>
				///
				/// <remarks>	Hmetal T, 17/12/2019. </remarks>
				///
				/// <param name="isfixed">	True if is fixed. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void FixParameters(bool isfixed);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Updates the parameters. The base version does nothing. </summary>
				///
				/// <remarks>	, 26.06.2018. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				virtual void UpdateParameters() { }

			protected:
				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Default constructor. </summary>
				///
				/// <remarks>	, 26.06.2018.
				/// 			Protected; presumably used when a layer is rebuilt by boost::serialization
				/// 			(befriended below). All members carry in-class defaults, so nothing is left
				/// 			indeterminate before the archive is loaded. </remarks>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				LikelihoodBaseLayer() { }

				// True when the likelihood parameters are held fixed during optimization (presumably
				// toggled through FixParameters - confirm). Defaults to false so a default-constructed
				// layer never holds an indeterminate bool (reading or serializing one is UB).
				bool isFixedParam = false;

				// NOTE(review): meaning not evident from this header; it is serialized with the layer.
				// Defaults to false for the same reason as isFixedParam.
				bool bDimMod = false;

			private:
				// Likelihood family of this layer. Gaussian is only a placeholder for default-constructed
				// instances; presumably the typed constructor or deserialization sets the real value.
				LogLikType lltype = LogLikType::Gaussian;

				friend class boost::serialization::access;

				// Boost.Serialization hook: archives the ILayer base part, then this layer's own state.
				// Keep the field order stable so previously saved models can still be loaded.
				template<class Archive>
				void serialize(Archive& ar, unsigned int /*version*/)
				{
					ar& boost::serialization::base_object<ILayer<Scalar>>(*this);
					ar& BOOST_SERIALIZATION_NVP(lltype);
					ar& BOOST_SERIALIZATION_NVP(isFixedParam);
					ar& BOOST_SERIALIZATION_NVP(bDimMod);
				}
			};
		}
	}
}