﻿/**
File:		MachineLearning/GPModels/Models/Layers/FgProbitLikelihoodLayer.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#pragma once

#include <MachineLearning/FgLikelihoodBaseLayer.h>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		namespace GPModels
		{
			////////////////////////////////////////////////////////////////////////////////////////////////////
			/// <summary>	Likelihood layer based on the probit link. </summary>
			///
			/// <remarks>
			/// 	NOTE(review): presumably models binary targets with p(y | f) = Φ(y · f), y ∈ {−1, +1},
			/// 	where Φ is the standard normal CDF — confirm against the implementation.
			///
			/// 	Approximating an unnormalized distribution means replacing it with a much simpler
			/// 	parametric distribution. This is often needed for intractable conditionals or integrals.
			/// 	It can be done via EP, Power EP or VFE. These algorithms are based on minimizing the
			/// 	Kullback-Leibler divergence and presuppose that the approximating distribution Q is
			/// 	restricted to a family of probability distributions that is closed under the product
			/// 	operation. This is the exponential family:
			///
			/// 		Q(x | θ) = exp(η(θ) · T(x) − A(θ)),
			///
			/// 	where η(θ) are the natural parameters, T(x) the sufficient statistics and A(θ) the
			/// 	log-normalizer, also known as logZ. The exponential family simplifies the parameter
			/// 	updates in each iteration step, because they need only the first and second derivatives
			/// 	of logZ.
			///
			/// 	For a Gaussian distribution N(x | μ, σ^2) = 1/sqrt(2πσ^2) exp{−(x − μ)^2 / (2σ^2)} the
			/// 	exponential family parameters are:
			///
			/// 		η(θ) = (μ / σ^2, −1/(2σ^2))^T,
			/// 		T(x) = (x, x^2)^T,
			/// 		A(θ) = 1/2 log(π / (−η_2)) − η_1^2 / (4 η_2).
			///
			/// 	Hmetal T, 04/05/2018.
			/// </remarks>
			////////////////////////////////////////////////////////////////////////////////////////////////////
			template<typename Scalar>
			class NE_IMPEXP ProbitLikLayer : public LikelihoodBaseLayer<Scalar>
			{
			public:

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Constructor. </summary>
				///
				/// <remarks>	23.04.2018. </remarks>
				///
				/// <param name="numPoints">	Number of data points (samples). </param>
				/// <param name="outputDim">	Number of output dimensions. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				ProbitLikLayer(int numPoints, int outputDim);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates logZ and, optionally, its derivatives. </summary>
				///
				/// <remarks>	Hmetal T, 05/05/2018. </remarks>
				///
				/// <param name="mout">			Mean of the latent function. </param>
				/// <param name="vout">			Variance of the latent function (assumed from the "v" prefix;
				/// 							TODO confirm). </param>
				/// <param name="y">			Sample data (targets). </param>
				/// <param name="alpha">		(Optional) Weight between alpha- and KL-divergence. </param>
				/// <param name="dlogZ_dm"> 	[out] (Optional) If non-null, receives the derivative of logZ
				/// 							w.r.t. the mean. </param>
				/// <param name="dlogZ_dv"> 	[out] (Optional) If non-null, receives the derivative of logZ
				/// 							w.r.t. vout. </param>
				/// <param name="dlogZ_dm2">	[out] (Optional) If non-null, receives the second-order
				/// 							derivative of logZ w.r.t. the mean (presumably d²logZ/dm²;
				/// 							TODO confirm). </param>
				///
				/// <returns>	The computed logZ as a single scalar. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				Scalar ComputeLogZ(const af::array& mout, const af::array& vout, const af::array& y, Scalar alpha = 1.0, af::array* dlogZ_dm = nullptr,
					af::array* dlogZ_dv = nullptr, af::array* dlogZ_dm2 = nullptr);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates the logZ gradients. </summary>
				///
				/// <remarks>
				/// 	Hmetal T, 05/05/2018.
				/// 	Note that, unlike ComputeLogZ, alpha is the last parameter here.
				/// </remarks>
				///
				/// <param name="mout">			Mean of the latent function. </param>
				/// <param name="vout">			Variance of the latent function (assumed from the "v" prefix;
				/// 							TODO confirm). </param>
				/// <param name="y">			Sample data (targets). </param>
				/// <param name="dlogZ_dm"> 	[out] (Optional) If non-null, receives the derivative of logZ
				/// 							w.r.t. the mean. </param>
				/// <param name="dlogZ_dv"> 	[out] (Optional) If non-null, receives the derivative of logZ
				/// 							w.r.t. vout. </param>
				/// <param name="dlogZ_dm2">	[out] (Optional) If non-null, receives the second-order
				/// 							derivative of logZ w.r.t. the mean (presumably d²logZ/dm²;
				/// 							TODO confirm). </param>
				/// <param name="alpha">		(Optional) Weight between alpha- and KL-divergence. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void ComputeLogZGradients(const af::array& mout, const af::array& vout, const af::array& y, af::array* dlogZ_dm = nullptr,
					af::array* dlogZ_dv = nullptr, af::array* dlogZ_dm2 = nullptr, Scalar alpha = 1.0);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates the expected log-likelihood term. </summary>
				///
				/// <remarks>
				/// 	NOTE(review): presumably E_q[log p(y | f)] with q(f) = N(mout, vout), as used by VFE;
				/// 	confirm against the implementation.
				/// </remarks>
				///
				/// <param name="mout">	Mean of the latent function. </param>
				/// <param name="vout">	Variance of the latent function (TODO confirm). </param>
				/// <param name="y">   	Sample data (targets). </param>
				///
				/// <returns>	The expected log-likelihood as a single scalar. </returns>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				Scalar ComputeLogLikExp(const af::array& mout, const af::array& vout, const af::array& y);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Calculates the gradients of the expected log-likelihood term. </summary>
				///
				/// <param name="mout"> 	Mean of the latent function. </param>
				/// <param name="vout"> 	Variance of the latent function (TODO confirm). </param>
				/// <param name="y">		Sample data (targets). </param>
				/// <param name="de_dm">	[out] (Optional) If non-null, receives the gradient w.r.t. mout. </param>
				/// <param name="de_dv">	[out] (Optional) If non-null, receives the gradient w.r.t. vout. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void ComputeLogLikExpGradients(const af::array& mout, const af::array& vout, const af::array& y, af::array* de_dm = nullptr, af::array* de_dv = nullptr);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Backpropagation step for the expected log-likelihood term. </summary>
				///
				/// <remarks>
				/// 	NOTE(review): judging by the signature, dmout/dvout receive gradients w.r.t. mout/vout
				/// 	and the return value is the (scaled) objective — confirm against the implementation.
				/// </remarks>
				///
				/// <param name="mout"> 	Mean of the latent function. </param>
				/// <param name="vout"> 	Variance of the latent function (TODO confirm). </param>
				/// <param name="dmout">	[out] Gradient w.r.t. mout. </param>
				/// <param name="dvout">	[out] Gradient w.r.t. vout. </param>
				/// <param name="y">		Sample data (targets); taken by non-const reference. </param>
				/// <param name="scale">	(Optional) Scaling factor. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				Scalar BackpropagationGradientsLogLikExp(const af::array& mout, const af::array& vout, af::array& dmout, af::array& dvout, af::array& y, Scalar scale = 1.0);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Backpropagation step for the (Power-)EP logZ term. </summary>
				///
				/// <remarks>
				/// 	NOTE(review): no targets are passed here, so y is presumably taken from layer state;
				/// 	confirm against the implementation.
				/// </remarks>
				///
				/// <param name="mout"> 	Mean of the latent function. </param>
				/// <param name="vout"> 	Variance of the latent function (TODO confirm). </param>
				/// <param name="dmout">	[out] Gradient w.r.t. mout. </param>
				/// <param name="dvout">	[out] Gradient w.r.t. vout. </param>
				/// <param name="alpha">	(Optional) Weight between alpha- and KL-divergence. </param>
				/// <param name="scale">	(Optional) Scaling factor. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				Scalar BackpropagationGradients(const af::array& mout, const af::array& vout, af::array& dmout, af::array& dvout, Scalar alpha = 1.0, Scalar scale = 1.0);

				////////////////////////////////////////////////////////////////////////////////////////////////////
				/// <summary>	Computes the predictive output distribution. </summary>
				///
				/// <param name="mf">   	Mean of the latent function. </param>
				/// <param name="vf">   	Variance of the latent function (TODO confirm). </param>
				/// <param name="myOut">	[out] Predictive output mean. </param>
				/// <param name="vyOut">	[out] Predictive output variance. </param>
				/// <param name="alpha">	(Optional) Weight between alpha- and KL-divergence. </param>
				////////////////////////////////////////////////////////////////////////////////////////////////////
				void ProbabilisticOutput(const af::array& mf, const af::array& vf, af::array& myOut, af::array& vyOut, Scalar alpha = 1.0);

			protected:
				// Default constructor, reachable by derived classes and by the befriended
				// boost::serialization::access.
				// NOTE(review): afGHx/afGHw stay empty here, and serialize() does not store them either.
				// Confirm they are rebuilt before use after deserialization.
				ProbitLikLayer() { }

				// Quadrature nodes (afGHx) and weights (afGHw); presumably Gauss-Hermite, judging by the
				// "GH" suffix — confirm against the constructor.
				af::array afGHx;
				af::array afGHw;
			private:

				friend class boost::serialization::access;

				// Boost.Serialization hook: only the LikelihoodBaseLayer<Scalar> base-class state is
				// archived. The class version is currently unused, hence the unnamed parameter.
				template<class Archive>
				void serialize(Archive& ar, const unsigned int /*version*/)
				{
					ar& boost::serialization::base_object<LikelihoodBaseLayer<Scalar>>(*this);
				}
			};
		}
	}
}
