/**
File:		MachineLearning/Util/CommonUtil<Scalar>.h

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/CommonUtil.h>
#include <math.h>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace NeuralEngine
{
	namespace MachineLearning
	{
		// Explicit instantiation definitions of CommonUtil for float and double.
		// NOTE(review): these appear before the member definitions in this file; some
		// compilers only instantiate members that are already defined at this point.
		// Confirm whether they should be moved to the end of the file.
		template class CommonUtil<float>;
		template class CommonUtil<double>;

		/**
		Euclidean norm of a: the square root of the sum of its squared elements.
		*/
		template<typename Scalar>
		Scalar CommonUtil<Scalar>::Euclidean(const af::array& a)
		{
			const Scalar squaredNorm = SquareEuclidean(a);
			return std::sqrt(squaredNorm);
		}

		/**
		Element-wise square root of SquareEuclidean(a, b), i.e. the Euclidean
		distance between a and b reduced along dimension 1.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::Euclidean(const af::array& a, const af::array& b)
		{
			const af::array squaredDistance = SquareEuclidean(a, b);
			return af::sqrt(squaredDistance);
		}

		/**
		Squared Euclidean distance between a and b: the squared element-wise
		differences summed along dimension 1.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::SquareEuclidean(const af::array& a, const af::array& b)
		{
			const af::array delta = a - b;
			const af::array squaredTerms = delta * delta;
			return af::sum(squaredTerms, 1);
		}

		/**
		Sum of the squared elements of a, returned as a host scalar.
		*/
		template<typename Scalar>
		Scalar CommonUtil<Scalar>::SquareEuclidean(const af::array& a)
		{
			const af::array squaredTerms = a * a;
			return af::sum<Scalar>(squaredTerms);
		}

		/**
		Largest of three values.

		@return a if a > b, otherwise b; unless c is strictly greater than that
		        value, in which case c.
		*/
		template<typename Scalar>
		Scalar CommonUtil<Scalar>::Max(Scalar a, Scalar b, Scalar c)
		{
			// Pick the winner of the first pair, then challenge it with c.
			const Scalar largerOfPair = (a > b) ? a : b;
			return (c > largerOfPair) ? c : largerOfPair;
		}

		/**
		Returns M^T * M divided by (number of rows of M - 1).

		NOTE(review): no column means are subtracted here, so this equals the
		sample covariance only if M is already mean-centred -- confirm with callers.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::Covariance(const af::array& M)
		{
			const af::array gram = af::matmulTN(M, M);
			return gram / (M.dims(0) - 1);
		}

		/**
		Matrix of correlation coefficients.

		If inY is non-empty, inX and inY are flattened and treated as two
		variables (two columns); otherwise every column of inX is a variable and
		every row an observation.

		@param inX  Observations, or the first variable when inY is given.
		@param inY  Optional second variable; must have as many elements as inX.
		@return     m x m matrix (m = number of variables). An empty array is
		            returned when the input is empty or the lengths of X and Y differ.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::CorrelationCoefficients(const af::array& inX, const af::array& inY)
		{
			af::array x = inX;

			if (inY.dims(0) > 0)
			{
				x = af::flat(inX);
				const af::array y = af::flat(inY);

				if (x.dims(0) != y.dims(0))
				{
					std::cout << "Corrcoef:XYmismatch. The lengths of X and Y must match." << std::endl;
					return af::array();
				}
				// Two variables -> two columns (joining along dim 0 would produce
				// a single, twice as long variable).
				x = af::join(1, x, y);
			}

			if (x.isempty())
				return af::array();

			const auto n = x.dims(0);
			const auto m = x.dims(1);

			// Covariance() does not remove the column means, so centre the data here.
			x = x - af::tile(af::mean(x, 0), static_cast<unsigned>(n));

			af::array r = Covariance(x);
			const af::array d = af::sqrt(af::diag(r)); // sqrt first to avoid under/overflow
			r = r / af::matmulNT(d, d);

			// Fix up possible round-off problems while preserving NaN: r / |r| is
			// exactly +/-1 (NaN stays NaN). It replaces every diagonal entry and every
			// off-diagonal entry whose magnitude exceeds 1. af::sign is not used
			// because it returns 1 for negative values and 0 otherwise.
			const af::array unit = r / af::abs(r);
			const af::array onDiagonal = af::identity(m, m) > 0;
			const af::array outOfRange = af::abs(r) > 1;

			return af::select(onDiagonal || outOfRange, unit, r);
		}

		/**
		Pairwise squared distances between the rows of inX1 and the rows of inX2,
		computed as |x|^2 + |c|^2 - 2 x.c and clamped at zero.

		@param inX1  Data, one point per row.
		@param inX2  Centres, one per row; must have as many columns as inX1.
		@return      Array with one row per row of inX1 and one column per row of inX2.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::SquareDistance(const af::array& inX1, const af::array& inX2)
		{
			const int ndata = inX1.dims(0);    // rows of inX1
			const int dimx = inX1.dims(1);     // columns of inX1
			const int ncentres = inX2.dims(0); // rows of inX2
			const int dimc = inX2.dims(1);     // columns of inX2

			LogAssert(dimx == dimc, "Data dimension does not match dimension of centres!");

			// Squared norm of every row, tiled to the shape of the result.
			const af::array normsX1 = af::tile(af::sum(af::pow(inX1, 2), 1), 1, ncentres);
			const af::array normsX2 = af::tile(af::sum(af::pow(inX2, 2), 1).T(), ndata, 1);
			const af::array crossTerm = 2.0f * af::matmulNT(inX1, inX2);

			const af::array squared = normsX1 + normsX2 - crossTerm;

			// Rounding errors occasionally produce small negative entries.
			return af::clamp(squared, 0.0f, af::Inf);
		}

		/**
		Pairwise distances between the rows of inX1 and inX2: the element-wise
		square root of SquareDistance(inX1, inX2).
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::UnscaledDistance(const af::array & inX1, const af::array & inX2)
		{
			const af::array squared = SquareDistance(inX1, inX2);
			return af::sqrt(squared);
		}

		/**
		Pairwise distances after dividing every column of inX1 and inX2 by the
		matching length scale.

		@param inLengtScale  One length scale per column of inX1, given either as a
		                     row or as a column (a column is transposed first).
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::ScaledDistance(const af::array & inX1, const af::array & inX2, const af::array & inLengtScale)
		{
			LogAssert(inX1.dims(1) == inLengtScale.dims(0) || inX1.dims(1) == inLengtScale.dims(1), "Lengthscale has not dimensionality of X.");

			// Use the length scales as a row so they can be tiled over the rows.
			const bool tallerThanWide = inLengtScale.dims(0) > inLengtScale.dims(1);
			const af::array scale = tallerThanWide ? inLengtScale.T() : inLengtScale;

			const af::array scaledX1 = inX1 / af::tile(scale, inX1.dims(0), 1);
			const af::array scaledX2 = inX2 / af::tile(scale, inX2.dims(0), 1);

			return UnscaledDistance(scaledX1, scaledX2);
		}

		/**
		Element-wise normal probability density.

		@param inX      Evaluation points.
		@param inMu     Means; an empty array means 0 everywhere.
		@param inSigma  Standard deviations; an empty array means 1 everywhere.
		                Entries <= 0 are replaced by NaN.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::NormalPDF(const af::array& inX, const af::array& inMu, const af::array& inSigma)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			const af::array mu = inMu.isempty() ? af::constant(0.0, inX.dims(), dType) : inMu;
			af::array Sigma = inSigma.isempty() ? af::constant(1.0, inX.dims(), dType) : inSigma;

			// A non-positive standard deviation has no density.
			Sigma(Sigma <= 0) = af::NaN;

			const af::array z = (inX - mu) / Sigma;
			const af::array kernel = af::exp(-0.5 * af::pow(z, 2));

			return kernel / (std::sqrt(2 * af::Pi) * Sigma);
		}

		/**
		Element-wise log-normal probability density.

		@param inX      Evaluation points; entries <= 0 are replaced by +Inf before
		                the density is evaluated.
		@param inMu     Location parameters; an empty array means 0 everywhere.
		@param inSigma  Scale parameters; an empty array means 1 everywhere.
		                Entries <= 0 are replaced by NaN.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::LogNormalPDF(const af::array& inX, const af::array& inMu, const af::array& inSigma)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			const af::array mu = inMu.isempty() ? af::constant(0.0, inX.dims(), dType) : inMu;
			af::array Sigma = inSigma.isempty() ? af::constant(1.0, inX.dims(), dType) : inSigma;
			af::array x = inX;

			Sigma(Sigma <= 0) = af::NaN;
			x(x <= 0) = af::Inf;

			const af::array z = (af::log(x) - mu) / Sigma;
			const af::array kernel = af::exp(-0.5 * af::pow(z, 2));

			return kernel / (x * std::sqrt(2 * af::Pi) * Sigma);
		}

		/**
		Element-wise normal cumulative distribution, evaluated as
		0.5 * erfc(-z / sqrt(2)) with z = (x - mu) / sigma.

		@param inMu     Means; an empty array means 0 everywhere.
		@param inSigma  Standard deviations; an empty array means 1 everywhere.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::NormalCDF(const af::array& inX, const af::array& inMu, const af::array& inSigma)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			const af::array mu = inMu.isempty() ? af::constant(0.0, inX.dims(), dType) : inMu;
			const af::array Sigma = inSigma.isempty() ? af::constant(1.0, inX.dims(), dType) : inSigma;

			const af::array z = (inX - mu) / Sigma;
			return 0.5 * af::erfc(-z / std::sqrt(2.0));
		}

		/**
		Element-wise log-normal cumulative distribution, evaluated as
		0.5 * erfc(-z / sqrt(2)) with z = (log(x) - mu) / sigma.

		@param inX      Evaluation points; negative entries are set to 0 first.
		@param inMu     Location parameters; an empty array means 0 everywhere.
		@param inSigma  Scale parameters; an empty array means 1 everywhere.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::LogNormalCDF(const af::array& inX, const af::array& inMu, const af::array& inSigma)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			const af::array mu = inMu.isempty() ? af::constant(0.0, inX.dims(), dType) : inMu;
			const af::array Sigma = inSigma.isempty() ? af::constant(1.0, inX.dims(), dType) : inSigma;

			af::array x = inX;
			x(x < 0) = 0;

			const af::array z = (af::log(x) - mu) / Sigma;
			return 0.5 * af::erfc(-z / std::sqrt(2.0));
		}

		/**
		Linear (column-major) indices of the upper triangle, diagonal included,
		of a numRows x numRows matrix.

		The indices are assembled on the host and uploaded in one transfer instead
		of being written to the device one element at a time.

		@param numRows    Size of the square matrix.
		@param dimension  0: row by row (each row from the diagonal rightwards);
		                  1: column by column (each column from the top down to the
		                  diagonal). Any other value leaves all indices at 0.
		@return           Column vector with numRows * (numRows + 1) / 2 entries.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::TriUpperIdx(int numRows, int dimension)
		{
			const int count = (numRows * (numRows + 1)) / 2;
			if (count == 0)
				return af::constant(0, count, 1);

			std::vector<int> host(count); // value-initialised to 0
			int next = 0;

			switch (dimension)
			{
			case 0:
				for (int row = 0; row < numRows; ++row)
					for (int col = row; col < numRows; ++col)
						host[next++] = row + col * numRows;
				break;
			case 1:
				for (int col = 0; col < numRows; ++col)
					for (int row = 0; row <= col; ++row)
						host[next++] = row + col * numRows;
				break;
			default:
				break;
			}

			// Keep the element type that af::constant picks for an integer fill value.
			const af::dtype indexType = af::constant(0, 1, 1).type();
			return af::array(static_cast<long long>(count), 1LL, host.data()).as(indexType);
		}

		/**
		Linear (column-major) indices of the lower triangle, diagonal included,
		of a numRows x numRows matrix, listed column by column from the diagonal
		downwards.

		The indices are assembled on the host and uploaded in one transfer instead
		of being written to the device one element at a time.

		@return Column vector with numRows * (numRows + 1) / 2 entries.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::TriLowerIdx(int numRows)
		{
			const int count = (numRows * (numRows + 1)) / 2;
			if (count == 0)
				return af::constant(0, count, 1);

			std::vector<int> host(count); // value-initialised to 0
			int next = 0;

			for (int col = 0; col < numRows; ++col)
				for (int row = col; row < numRows; ++row)
					host[next++] = col * numRows + row;

			// Keep the element type that af::constant picks for an integer fill value.
			const af::dtype indexType = af::constant(0, 1, 1).type();
			return af::array(static_cast<long long>(count), 1LL, host.data()).as(indexType);
		}

		/**
		Linear indices of the main diagonal of a numRows x numRows matrix:
		0, numRows + 1, 2 * (numRows + 1), ...

		Computed directly in O(numRows) on the host and uploaded once, instead of
		scanning all numRows^2 positions and writing each hit to the device.

		@return Column vector with numRows entries.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::DiagIdx(int numRows)
		{
			if (numRows <= 0)
				return af::constant(0, numRows, 1);

			std::vector<int> host(numRows);
			for (int i = 0; i < numRows; ++i)
				host[i] = i * (numRows + 1);

			// Keep the element type that af::constant picks for an integer fill value.
			const af::dtype indexType = af::constant(0, 1, 1).type();
			return af::array(static_cast<long long>(numRows), 1LL, host.data()).as(indexType);
		}

		/**
		Evenly spaced samples over an interval.

		The samples are computed on the host and uploaded in one transfer instead
		of being written to the device one element at a time.

		@param start     First sample.
		@param stop      End of the interval.
		@param num       Number of samples.
		@param endpoint  If true the last sample is exactly stop (step
		                 (stop - start) / (num - 1)); otherwise stop is excluded
		                 (step (stop - start) / num).
		@return          Array with num elements of the type given by CheckDType().
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::LinSpace(Scalar start, Scalar stop, int num, bool endpoint)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			// Nothing to fill: hand back an array of the requested size, as before.
			if (num <= 0)
				return af::array(num, dType);

			std::vector<Scalar> host(num);

			if (num == 1)
			{
				host[0] = start;
			}
			else if (endpoint)
			{
				const Scalar delta = (stop - start) / (static_cast<Scalar>(num) - 1.0f);

				for (int i = 0; i < num - 1; ++i)
					host[i] = start + delta * i;

				// Make sure the last sample is exactly the requested end value.
				host[num - 1] = stop;
			}
			else
			{
				const Scalar delta = (stop - start) / static_cast<Scalar>(num);

				for (int i = 0; i < num; ++i)
					host[i] = start + delta * i;
			}

			return af::array(static_cast<long long>(num), host.data()).as(dType);
		}

		/**
		Cholesky factorisation with increasing diagonal jitter.

		The first attempt factorises inA unchanged and sets the initial jitter to
		1e-6 times the absolute mean of the diagonal. Every later attempt
		factorises inA + jitter * I; after a failure the jitter is multiplied by
		10 and a "+" is printed. At most 10 attempts are made.

		@return The output of the last af::cholesky call. No error is reported if
		        every attempt failed.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::JitChol(const af::array & inA)
		{
			const int kMaxAttempts = 10;

			af::array factor;
			Scalar jitter = 0.0;

			for (int attempt = 0; attempt < kMaxAttempts; ++attempt)
			{
				if (jitter == 0.0f)
				{
					// No jitter yet: derive it from the diagonal, then try the plain matrix.
					const af::array meanDiagonal = af::abs(af::mean(af::diag(inA)));
					jitter = meanDiagonal.scalar<Scalar>() * 1e-6;

					if (af::cholesky(factor, inA) == 0)
						break;
					continue;
				}

				const af::array regularised = inA + jitter * af::identity(inA.dims(0), inA.dims(1));
				if (af::cholesky(factor, regularised) == 0)
					break;

				// Still not positive definite: use ten times more jitter next time.
				jitter *= 10;
				std::cout << "+";
			}

			return factor;
		}

		/**
		Sum of the natural logarithms of the diagonal entries of inA.

		af::log works element-wise, so this is not the trace of a matrix
		logarithm; it equals log(det(inA)) when inA is triangular with a positive
		diagonal. NOTE(review): presumably callers pass a Cholesky factor -- confirm.

		Only the diagonal is used, so the logarithm is taken of the diagonal alone
		rather than of the whole matrix. Diagonal logs that are NaN or +/-Inf
		contribute 0 to the sum. A 1 x 1 input returns log of its single element
		without that replacement.
		*/
		template<typename Scalar>
		Scalar CommonUtil<Scalar>::LogDet(af::array inA)
		{
			if (inA.dims(0) == 1 && inA.dims(1) == 1)
			{
				return log(inA(0).scalar<Scalar>());
			}

			af::array logDiagonal = af::log(af::diag(inA));
			logDiagonal(af::isNaN(logDiagonal) || af::isInf(logDiagonal)) = 0.0;

			return af::sum<Scalar>(logDiagonal);
		}

		/**
		Merges rhs into lhs in one linear pass over both (sorted) maps.

		Keys present only in rhs are inserted into lhs; for keys present in both,
		the value in lhs is overwritten by the value from rhs.
		*/
		template<typename Scalar>
		void CommonUtil<Scalar>::MergeMaps(std::map<std::string, af::array>& lhs, const std::map<std::string, af::array>& rhs)
		{
			auto target = lhs.begin();
			auto source = rhs.begin();

			while (target != lhs.end() && source != rhs.end())
			{
				if (target->first < source->first)
				{
					// lhs key comes first: nothing to merge at this position.
					++target;
					continue;
				}

				if (source->first < target->first)
				{
					// rhs key is missing from lhs: insert it just before 'target',
					// which serves as the insertion hint.
					lhs.insert(target, *source);
					++source;
					continue;
				}

				// Same key on both sides: rhs wins.
				target->second = source->second;
				++target;
				++source;
			}

			// lhs is exhausted (or rhs is); whatever is left in rhs has no
			// counterpart in lhs and is inserted as is.
			lhs.insert(source, rhs.end());
		}

		/**
		Sorts the rows of inA by their first column (using af::sort with its
		default arguments); rows that tie in the first column are ordered by
		applying the same procedure recursively to the remaining columns.

		inA is taken by value; the reordered array is returned.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::SortRows(af::array inA)
		{

			af::array result;
			af::array index;
			af::array condition;
			// true while inside a run of consecutive rows with equal first-column values
			bool begin = false;

			// first and last row (inclusive) of the current run
			int start, end;

			// basic row sort
			af::sort(result, index, inA(af::span, 0));
			inA(af::span, af::span) = inA(index, af::span);

			// recursive sortrows for remaining columns if
			// some elements of actual one are equal
			start = end = 0;
			for (int j = 1; j < inA.dims(0); j++)
			{
				// compare the first-column value of this row with the previous row
				condition = inA(j, 0) == inA(j - 1, 0);

				// the comparison result is read back to the host once per row
				if (condition.as(f32).scalar<float>() > 0)
				{
					if (!begin)
					{
						// a new run starts at the previous row
						start = j - 1;
						end = j;
						begin = true;
					}
					else end++;
				}
				else
				{
					if (begin)
					{
						// the run ended at the previous row: sort its remaining columns
						if (inA.dims(1) != 1)
							inA(af::seq(start, end), af::seq(1, inA.dims(1) - 1)) = SortRows(inA(af::seq(start, end), af::seq(1, inA.dims(1) - 1)));
						begin = false;
					}
				}
				// a run that reaches the last row is flushed here
				if (begin && j == inA.dims(0) - 1)
				{
					if (inA.dims(1) != 1)
						inA(af::seq(start, end), af::seq(1, inA.dims(1) - 1)) = SortRows(inA(af::seq(start, end), af::seq(1, inA.dims(1) - 1)));
					begin = false;
				}
			}

			return inA;
		}

		/**
		Concatenates two matrices, tolerating empty operands.

		@param inA, inB   Operands. If inA is empty a copy of inB is returned; if
		                  only inB is empty a copy of inA is returned.
		@param dimension  0: stack inB below inA; 1: place inB to the right of inA.
		                  Any other value yields an empty array.
		@return           Array of the type given by CheckDType() when both inputs
		                  are non-empty.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::Join(const af::array & inA, const af::array & inB, int dimension)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			if (inA.isempty())
				return inB.copy();
			if (inB.isempty())
				return inA.copy();

			const int rowsA = inA.dims(0);
			const int colsA = inA.dims(1);
			const int rowsB = inB.dims(0);
			const int colsB = inB.dims(1);

			af::array joint;

			if (dimension == 0)
			{
				// Vertical concatenation: the column count is taken from inA.
				joint = af::constant(0.0, rowsA + rowsB, colsA, dType);
				joint.rows(0, rowsA - 1) = inA.copy();
				joint.rows(rowsA, rowsA + rowsB - 1) = inB.copy();
			}
			else if (dimension == 1)
			{
				// Horizontal concatenation: the row count is taken from inA.
				joint = af::constant(0.0, rowsA, colsA + colsB, dType);
				joint.cols(0, colsA - 1) = inA.copy();
				joint.cols(colsA, colsA + colsB - 1) = inB.copy();
			}

			return joint;
		}

		/**
		Solves A x = b through a QR factorisation: with A = Q R, the upper
		triangular system R x = Q^T b is solved.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::SolveQR(const af::array & A, const af::array & b)
		{
			af::array q;
			af::array upper;
			af::array tau;
			af::qr(q, upper, tau, A);

			const af::array rotated = af::matmulTN(q, b);

			return af::solve(upper, rotated, AF_MAT_UPPER);
		}

		/**
		Inverse of inA computed from its singular value decomposition
		inA = U * diag(s) * V^T as V * diag(1 / s) * U^T.

		NOTE(review): singular values are inverted without any threshold, so a
		zero singular value yields Inf entries -- confirm that inputs are always
		well conditioned.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::PDInverse(const af::array& inA)
		{
			af::array u;
			af::array singularValues;
			af::array vt;

			af::svd(u, singularValues, vt, inA);

			// Diagonal matrix holding the reciprocal singular values.
			const af::array inverseSigma = af::diag(1.0 / singularValues, 0, false);

			return af::matmul(vt.T(), af::matmul(inverseSigma.T(), u.T()));
		}

		/**
		Reads a delimiter-separated text file of numbers into a matrix.

		Every line that contains at least one non-empty token becomes a row; empty
		tokens (e.g. produced by repeated delimiters) and lines without tokens are
		skipped. The number of columns is that of the widest row; shorter rows are
		padded with 0. Tokens are converted with atof.

		The values are collected on the host and uploaded in one transfer.

		@param filename   Path of the file to read.
		@param delimiter  Character separating the values within a line.
		@return           rows x columns array of the type given by CheckDType();
		                  an empty array if the file cannot be opened or holds no
		                  values.
		*/
		template<typename Scalar>
		af::array CommonUtil<Scalar>::ReadTXT(std::string filename, char delimiter)
		{
			const af::dtype dType = CommonUtil<Scalar>::CheckDType();

			std::ifstream input(filename);
			if (!input.is_open())
				return af::array();

			std::vector<std::vector<Scalar>> rows;
			std::size_t numCols = 0;

			std::string line;
			while (std::getline(input, line))
			{
				std::istringstream fields(line);
				std::vector<Scalar> row;
				std::string token;

				while (std::getline(fields, token, delimiter))
				{
					if (!token.empty())
						row.push_back(static_cast<Scalar>(std::atof(token.c_str())));
				}

				if (row.empty())
					continue;

				if (row.size() > numCols)
					numCols = row.size();
				rows.push_back(std::move(row));
			}

			if (rows.empty())
				return af::array();

			// ArrayFire stores matrices column-major: element (i, j) lives at j * rows + i.
			const std::size_t numRows = rows.size();
			std::vector<Scalar> host(numRows * numCols, static_cast<Scalar>(0));

			for (std::size_t i = 0; i < numRows; ++i)
				for (std::size_t j = 0; j < rows[i].size(); ++j)
					host[j * numRows + i] = rows[i][j];

			return af::array(static_cast<long long>(numRows), static_cast<long long>(numCols), host.data()).as(dType);
		}

		/**
		Writes the first two dimensions of data to a text file, one matrix row per
		line, values separated by delimiter.

		The array is copied to the host once instead of fetching every element
		with its own device-to-host transfer. As before, data must hold elements
		of type Scalar (the host copy raises an ArrayFire error otherwise).

		@return true if the file could be opened, false otherwise.
		*/
		template<typename Scalar>
		bool CommonUtil<Scalar>::WriteTXT(const af::array& data, std::string filename, char delimiter)
		{
			std::ofstream output(filename);
			if (!output.is_open())
				return false;

			const auto numRows = data.dims(0);
			const auto numCols = data.dims(1);

			if (numRows > 0 && numCols > 0)
			{
				// Releases the buffer returned by af::array::host<T>().
				struct HostRelease
				{
					void operator()(Scalar* pointer) const { af::freeHost(pointer); }
				};
				const std::unique_ptr<Scalar, HostRelease> host(data.host<Scalar>());
				const Scalar* values = host.get();

				// Column-major storage: element (i, j) lives at j * numRows + i.
				for (auto i = numRows * 0; i < numRows; ++i)
				{
					for (auto j = numCols * 0; j < numCols; ++j)
					{
						output << values[j * numRows + i];
						if (j != numCols - 1)
							output << delimiter;
						else
							output << "\n";
					}
				}
			}

			output.close();
			return true;
		}

		/**
		Exact element-wise equality of two arrays.

		The previous loop overwrote its result on every iteration (so only the last
		compared element counted) and read scalars back inside a gfor; the
		comparison is now reduced on the device in one step.

		@return false if the dimensions differ or any pair of elements differs
		        (NaN never compares equal); true otherwise, including for two
		        empty arrays.
		*/
		template<typename Scalar>
		bool CommonUtil<Scalar>::IsEqual(const af::array& a, const af::array& b)
		{
			if (a.dims() != b.dims()) return false;

			// Nothing to compare.
			if (a.isempty()) return true;

			return af::allTrue<bool>(a == b);
		}

		/**
		Maps the Scalar template parameter to an ArrayFire element type:
		f64 when is_Scalar<Scalar>::value holds, f32 when is_float<Scalar>::value
		holds.

		The result used to be left uninitialised when neither trait matched; it
		now defaults to f32.

		NOTE(review): is_Scalar is declared outside this file; the name suggests
		it was meant to test for double -- confirm.
		*/
		template<typename Scalar>
		af::dtype CommonUtil<Scalar>::CheckDType()
		{
			af::dtype type = f32;

			if (is_Scalar<Scalar>::value) type = f64;
			else if (is_float<Scalar>::value) type = f32;
			// TODO: report an error for non floating point Scalar types.

			return type;
		}
	}
}