/**
File:		MachineLearning/Util/FgGaussHermiteQuadrature.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgGaussHermiteQuadrature.h>
#define _USE_MATH_DEFINES
#include <math.h>

namespace NeuralEngine::MachineLearning
{
	// Explicitly instantiate the class template for double and float so the
	// member definitions in this translation unit are emitted for both types.
	// NOTE(review): these explicit instantiation definitions appear before the
	// member function definitions below. Per [temp.explicit], an explicit
	// instantiation definition only instantiates members that are already
	// defined at that point, so some compilers may not emit the members defined
	// later in this file. Consider moving these two lines to the end of the
	// namespace -- confirm against the toolchains this project targets.
	template class GaussHermiteQuadrature<double>;
	template class GaussHermiteQuadrature<float>;

	/// Computes the sample points and weights of the n-point Gauss-Hermite
	/// quadrature rule.
	///
	/// The eigenvalues of the Hermite companion matrix provide a first estimate
	/// of the roots, one Newton step refines them, and the weights are derived
	/// from the normalized Hermite polynomial of degree n - 1 evaluated at the
	/// refined roots.
	///
	/// @param n  Number of sample points; checked with LogAssert to be positive.
	/// @param x  [out] Overwritten with the sample points, in the ascending
	///           order produced by af::sort and then symmetrized about zero.
	/// @param w  [out] Overwritten with the weights, rescaled so that they sum
	///           to sqrt(pi).
	template<typename Scalar>
	void GaussHermiteQuadrature<Scalar>::Compute(int n, af::array & x, af::array & w)
	{
		LogAssert(n > 0, "n must be positive.");

		const af::dtype dType = CommonUtil<Scalar>::CheckDType();

		// First approximation of the roots. We use the fact that the companion
		// matrix is symmetric in this case in order to obtain better zeros.
		// c holds the series coefficients of the single basis polynomial of degree n.
		af::array c = af::constant(0.0, n + 1, dType);
		c(af::end) = 1.0;
		af::array m = HermCompanion(c);

		// The eigen decomposition is done on the host through OpenCV.
		cv::Mat mat, eval;
		if (dType == f64)
			mat = AfCv::ArrayToMat(m, CV_64F);
		else
			mat = AfCv::ArrayToMat(m);
		cv::eigen(mat, eval);

		x = af::sort(AfCv::MatToArray(eval));

		// Improve the roots by one application of Newton's method.
		af::array dy = NormedHermite(x, n);
		af::array df = NormedHermite(x, n - 1) * sqrt(2 * n);
		x -= dy / df;

		// Compute the weights. We scale the factor by its largest magnitude to
		// avoid possible numerical overflow when squaring.
		af::array fm = NormedHermite(x, n - 1);
		fm /= af::tile(af::max(af::abs(fm)), fm.dims(0));
		w = 1.0 / (fm * fm);

		// For Hermite the rule is symmetric about zero, so we can symmetrize:
		// weights are averaged with their mirror, nodes are made antisymmetric.
		w = (w + af::flip(w, 0)) / 2;
		x = (x - af::flip(x, 0)) / 2;

		// Scale w so that the weights sum to sqrt(pi).
		w *= sqrt(M_PI) / af::tile(af::sum(w), w.dims(0));
	}

	/// Builds the scaled companion matrix of a Hermite series.
	///
	/// For a series with coefficients c (lowest degree first, length n + 1) the
	/// eigenvalues of the returned matrix are the roots of the series. The basis
	/// polynomials are scaled so that the matrix is symmetric whenever c
	/// describes a single basis polynomial, i.e. all coefficients except the
	/// last one are zero.
	///
	/// @param c  Series coefficients stored along dimension 0; at least two
	///           entries are required (checked with LogAssert).
	/// @return   A single-element array holding the root when c has exactly two
	///           entries, otherwise the n-by-n companion matrix.
	template<typename Scalar>
	af::array GaussHermiteQuadrature<Scalar>::HermCompanion(const af::array & c)
	{
		const af::dtype dType = CommonUtil<Scalar>::CheckDType();

		LogAssert(c.dims(0) >= 2, "Series must have maximum degree of at least 1.");

		// Degree-1 series c0 + c1 * H_1(x) = c0 + 2 * c1 * x has the single
		// root -c0 / (2 * c1); note the negative sign.
		if (c.dims(0) == 2) return -0.5 * c(0) / c(1);

		const int n = static_cast<int>(c.dims(0)) - 1;

		af::array mat = af::constant(0.0, n, n, dType);

		// Linear indices of the two off-diagonals of the n-by-n matrix.
		const af::array firstIdx = af::seq(1, (n * n) - 1, n + 1);
		const af::array secondIdx = af::seq(n, (n * n) - 1, n + 1);

		// Symmetric tridiagonal part: sqrt(k / 2) for k = 1 .. n - 1.
		const af::array k = af::range(af::dim4(n - 1), 0, dType) + 1;
		const af::array offDiag = af::sqrt(0.5 * k);
		mat(firstIdx) = offDiag;
		mat(secondIdx) = offDiag;

		// Scaling factors: the reversed cumulative product of
		// [1, 1/sqrt(2(n-1)), ..., 1/sqrt(2)], computed on the device.
		const af::array factors = af::join(0, af::constant(1.0, 1, dType), af::flip(1.0 / af::sqrt(2.0 * k), 0));
		const af::array scl = af::flip(af::scan(factors, 0, AF_BINARY_MUL), 0);

		// Correction of the last column by the lower-order coefficients,
		// normalized by the leading coefficient. It vanishes when all
		// lower-order coefficients are zero, leaving the matrix symmetric.
		const af::array lower = c(af::seq(0, n - 1));
		const af::array lead = af::tile(c(af::end), static_cast<unsigned>(n));
		mat(af::span, af::end) -= (scl * lower / (2.0 * lead)).as(dType);

		return mat;
	}

	/// Evaluates the normalized Hermite polynomial of degree n element-wise.
	///
	/// The value is built with a three-term recurrence whose order counter runs
	/// from n down to 2, starting from the degree-0 value pi^(-1/4).
	///
	/// @param x  Points at which to evaluate; the result has the same dimensions.
	/// @param n  Degree of the polynomial.
	/// @return   Array of polynomial values, one per element of x.
	template<typename Scalar>
	af::array GaussHermiteQuadrature<Scalar>::NormedHermite(const af::array& x, int n)
	{
		const af::dtype valueType = CommonUtil<Scalar>::CheckDType();

		// Degree-0 value of the normalized polynomial: pi^(-1/4).
		const double base = 1.0 / sqrt(sqrt(M_PI));

		af::array high = af::constant(base, x.dims(), valueType);
		if (n == 0)
			return high;

		af::array low = af::constant(0.0, x.dims(), valueType);
		Scalar order = static_cast<Scalar>(n);

		// Runs n - 1 times; each pass advances the (low, high) pair one step.
		for (int remaining = n - 1; remaining > 0; --remaining)
		{
			const af::array carried = low;
			low = -high * sqrt((order - 1.0) / order);
			high = carried + high * x * sqrt(2.0 / order);
			order = order - 1.0;
		}

		return low + high * x * sqrt(2.0);
	}
}