/**
File:		MachineLearning/Optimization/Unconstrained/L-BFGS.cpp

Author:		
Email:		
Site:       

Copyright (c) 2019 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgLBFGSBsolver.h>
#include <cmath>
#include <math.h>
#include <limits>
#include <iomanip>

namespace NeuralEngine::MachineLearning
{
	// Explicit instantiation definitions: LBFGSBSolver is built here for float and
	// double, once for each LineSearchType value listed below.
	// NOTE(review): these statements precede the member function definitions in this
	// file. The C++ standard makes an explicit instantiation definition of a class
	// template cover only the members already defined at that point, so some
	// conforming compilers may not emit the members defined further down - confirm
	// against the toolchains this project targets.
	template class LBFGSBSolver<float, ArmijoBacktracking>;
	template class LBFGSBSolver<float, ArmijoBracketing>;
	template class LBFGSBSolver<float, MoreThuente>;
	template class LBFGSBSolver<float, StrongWolfeBacktracking>;
	template class LBFGSBSolver<float, StrongWolfeBracketing>;
	template class LBFGSBSolver<float, WolfeBacktracking>;
	template class LBFGSBSolver<float, WolfeBracketing>;

	template class LBFGSBSolver<double, ArmijoBacktracking>;
	template class LBFGSBSolver<double, ArmijoBracketing>;
	template class LBFGSBSolver<double, MoreThuente>;
	template class LBFGSBSolver<double, StrongWolfeBacktracking>;
	template class LBFGSBSolver<double, StrongWolfeBracketing>;
	template class LBFGSBSolver<double, WolfeBacktracking>;
	template class LBFGSBSolver<double, WolfeBracketing>;

	template<typename Scalar, LineSearchType LSType>
	LBFGSBSolver<Scalar, LSType>::LBFGSBSolver(int numberOfVariables)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables),
		m_historySize(5), theta(1.0)
	{
	}

	template<typename Scalar, LineSearchType LSType>
	LBFGSBSolver<Scalar, LSType>::LBFGSBSolver(int numberOfVariables, std::function<Scalar(const af::array&, af::array&)> function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(numberOfVariables, function),
		m_historySize(5), theta(1.0)
	{
	}

	template<typename Scalar, LineSearchType LSType>
	LBFGSBSolver<Scalar, LSType>::LBFGSBSolver(NonlinearObjectiveFunction<Scalar> * function)
		: BaseGradientOptimizationMethod<Scalar, LSType>(function),
		m_historySize(5), theta(1.0)
	{
	}

	/**
	 * Sets the number of (s, y) correction pairs kept by Optimize().
	 * The value is stored as given; no validation is performed here.
	 */
	template<typename Scalar, LineSearchType LSType>
	void LBFGSBSolver<Scalar, LSType>::SetHistorySize(const int hs)
	{
		this->m_historySize = hs;
	}

	/** Destructor; the solver has no clean-up of its own to perform. */
	template<typename Scalar, LineSearchType LSType>
	LBFGSBSolver<Scalar, LSType>::~LBFGSBSolver() = default;

	/**
	 * Runs the L-BFGS-B iteration starting from the solver's current solution.
	 *
	 * Each pass computes the generalized Cauchy point, minimizes over the free
	 * variables, performs a line search towards the subspace minimizer and, when
	 * the curvature condition holds, refreshes the limited-memory matrices W and M.
	 *
	 * The loop ends when the projected-gradient optimality measure drops to
	 * _tolerance or below, when iterations reaches maxIterations, or when two
	 * successive function values differ by less than 1e-8.
	 *
	 * @param cycle  Optional out-parameter; when non-null it receives iterations + 1
	 *               after the initial evaluation and after every completed pass.
	 * @return       Always true; the final iterate is stored through SetSolution().
	 */
	template<typename Scalar, LineSearchType LSType>
	bool LBFGSBSolver<Scalar, LSType>::Optimize(int* cycle)
	{
		const af::array x0 = GetSolution();
		const size_t n = GetNumberOfVariables();
		theta = 1.0;

		// Start with an all-zero limited-memory model.
		W = af::constant((Scalar)0.0, n, 1, m_dtype);
		M = af::constant((Scalar)0.0, 1, 1, m_dtype);

		af::array yHistory;
		af::array sHistory;
		af::array x = x0.copy();

		Scalar f = _function->Value(x);
		af::array g = _function->Gradient(x);

		// conv. crit.
		Scalar optimality = GetOptimality(x, g);

		// print out some useful information, if specified
		if (_display)
		{
			std::cout << "Numerical Optimization via L-BFGS-B\n===================================\n\n";
			std::cout << "iter: " << iterations + 1 << "\t\t f(x): " << f << "\t\t optimality: " << optimality << std::endl;
		}

		if (cycle)
			*cycle = (iterations + 1);

		// A non-positive history size cannot hold a correction pair; fall back to one pair.
		const int historySize = (m_historySize > 0) ? m_historySize : 1;

		while ((optimality > _tolerance) && (iterations < maxIterations))
		{
			const Scalar f_old = f;
			af::array x_old = x.copy();
			const af::array g_old = g.copy();

			// STEP 2: compute the cauchy point
			af::array CauchyPoint;
			af::array c;
			GetGeneralizedCauchyPoint(x, g, CauchyPoint, c);

			// STEP 3: compute a search direction d_k by the primal method for the sub-problem
			af::array SubspaceMin;
			SubspaceMinimization(CauchyPoint, x, c, g, SubspaceMin);

			// STEP 4: perform linesearch and STEP 5: compute gradient
			// NOTE(review): the code below relies on Linesearch updating f, x and g in place - confirm.
			Scalar alpha_init = 1.0;
			linesearch->Linesearch(*_function, f, x, g, alpha_init, SubspaceMin - x, x_old);

			// prepare for next iteration
			const af::array newY = g - g_old;
			const af::array newS = x - x_old;

			// STEP 6: accept the pair only when the curvature s^T y is sufficiently positive.
			// A pair with s^T y <= 0 would give a non-positive theta and an indefinite model,
			// so it is skipped (the reference algorithm skips the update in that case too).
			// NOTE(review): SquareEuclidean is presumably ||y||^2 - confirm.
			const Scalar test = matmulTN(newS, newY).scalar<Scalar>();
			if (test > 1e-7 * CommonUtil<Scalar>::SquareEuclidean(newY))
			{
				if (yHistory.dims(1) < historySize)
				{
					yHistory = CommonUtil<Scalar>::Join(yHistory, newY, 1);
					sHistory = CommonUtil<Scalar>::Join(sHistory, newS, 1);
				}
				else if (historySize == 1)
				{
					// Only one pair is kept: the new pair replaces it.
					yHistory = newY;
					sHistory = newS;
				}
				else
				{
					// Drop the oldest column and append the new pair. Building fresh arrays
					// avoids copying a column range onto an overlapping range of the same array.
					const af::array yTail = yHistory.cols(1, historySize - 1);
					const af::array sTail = sHistory.cols(1, historySize - 1);
					yHistory = CommonUtil<Scalar>::Join(yTail, newY, 1);
					sHistory = CommonUtil<Scalar>::Join(sTail, newS, 1);
				}

				// STEP 7: rebuild theta, W = [Y, theta*S] and M = [[-D, L^T], [L, theta*S^T S]]^-1
				theta = (matmulTN(newY, newY) / matmulTN(newY, newS)).scalar<Scalar>();
				W = CommonUtil<Scalar>::Join(yHistory, theta * sHistory, 1);
				af::array A = matmulTN(sHistory, yHistory);
				af::array L = af::lower(A);
				af::array diagIdx = CommonUtil<Scalar>::DiagIdx(A.dims(0));
				L(diagIdx) *= 0.0f;		// keep only the strictly lower triangle

				af::array D;
				if (A.dims(0) == 1) D = -A;
				else D = -tile(diag(A), 1, A.dims(1)) * identity(A.dims());

				af::array MM = CommonUtil<Scalar>::Join(CommonUtil<Scalar>::Join(D, L.T(), 1), CommonUtil<Scalar>::Join(L, theta * matmulTN(sHistory, sHistory), 1), 0);
				M = inverse(MM);
			}

			if (std::abs(f_old - f) < 1e-8)
				// successive function values too similar
				break;

			++iterations;

			// Evaluated once per pass and reused by both the display and the loop condition.
			optimality = GetOptimality(x, g);
			if (_display)
				std::cout << "iter: " << iterations + 1 << "\t\t f(x): " << f << "\t\t optimality: " << optimality << std::endl;

			if (cycle)
				*cycle = (iterations + 1);
		}

		SetSolution(x);
		return true;
	}

	/**
	 * Projected-gradient optimality measure.
	 *
	 * Takes the step x - g, clamps it to the objective's lower/upper bounds,
	 * subtracts x again and returns the largest absolute entry of the result.
	 *
	 * @param x  Current iterate.
	 * @param g  Gradient at x.
	 * @return   max |clamp(x - g, lower, upper) - x|.
	 */
	template<typename Scalar, LineSearchType LSType>
	Scalar LBFGSBSolver<Scalar, LSType>::GetOptimality(const af::array & x, const af::array & g)
	{
		const af::array stepped = x - g;
		const af::array projected = clamp(stepped, _function->LowerBound(), _function->UpperBound());
		const af::array residual = projected - x;
		const af::array largest = af::max(abs(residual));

		return largest.scalar<Scalar>();
	}

	/**
	 * Returns the first components of v, ordered by ascending second component.
	 *
	 * The comparator looks the second component up by using each first component as
	 * a position in v, so callers are expected to pass pairs whose first component
	 * equals their position.
	 *
	 * @param v  (index, value) pairs.
	 * @return   The indices sorted by their associated value.
	 */
	template<typename Scalar, LineSearchType LSType>
	std::vector<int> LBFGSBSolver<Scalar, LSType>::SortIndexes(const std::vector<std::pair<int, Scalar>>& v)
	{
		std::vector<int> order;
		order.reserve(v.size());
		for (const auto& entry : v)
			order.push_back(entry.first);

		const auto byValue = [&v](size_t lhs, size_t rhs)
		{
			return v[lhs].second < v[rhs].second;
		};
		std::sort(order.begin(), order.end(), byValue);

		return order;
	}

	/**
	 * Computes the generalized Cauchy point: the first local minimizer of the
	 * quadratic model along the projected steepest-descent path, using the
	 * compact model B = theta*I - W*M*W^T held in the members theta, W and M.
	 *
	 * @param x         Current iterate.
	 * @param g         Gradient at x.
	 * @param x_cauchy  Out: the generalized Cauchy point.
	 * @param c         Out: accumulated sum of dt * p over the examined segments
	 *                  (consumed by SubspaceMinimization).
	 *
	 * Elements are read back to the host one at a time, so the cost grows with the
	 * number of variables.
	 */
	template<typename Scalar, LineSearchType LSType>
	void LBFGSBSolver<Scalar, LSType>::GetGeneralizedCauchyPoint(const af::array & x, const af::array & g, af::array & x_cauchy, af::array & c)
	{
		const int DIM = x.dims(0);
		const Scalar th = static_cast<Scalar>(theta);

		x_cauchy = x.copy();
		c = af::constant((Scalar)0.0, W.dims(1), m_dtype);
		if (DIM <= 0)
			return;		// nothing to search over

		// Given x,l,u,g, and B = \theta I-WMW
		// {all t_i} = { (idx,value), ... }
		std::vector<std::pair<int, Scalar> > SetOfT;
		SetOfT.reserve(DIM);

		// the feasible set is implicitly given by "SetOfT - {t_i==0}"
		af::array d = af::constant(0.0, DIM, m_dtype);
		// n operations
		for (int j = 0; j < DIM; j++)
		{
			const Scalar gj = g(j).scalar<Scalar>();
			if (gj == 0)
			{
				SetOfT.push_back(std::make_pair(j, std::numeric_limits<Scalar>::max()));
				continue;
			}

			// Pick the bound on the host. A branch-free select of the form
			// cond*A + (!cond)*B turns into NaN as soon as the unused bound is infinite.
			af::array breakpoint;
			if (gj < 0)
				breakpoint = ((x(j) - _function->UpperBound()(j)) / g(j)).as(m_dtype);
			else
				breakpoint = ((x(j) - _function->LowerBound()(j)) / g(j)).as(m_dtype);
			const Scalar tmp = breakpoint.scalar<Scalar>();

			SetOfT.push_back(std::make_pair(j, tmp));

			// d_i = -g_i, except d_i = 0 for variables whose breakpoint is t_i = 0:
			// those cannot move along the projected path and must not contribute to p, f' or f''.
			if (tmp != 0)
				d(j) = -1 * g(j);
		}

		// sortedindices [1,0,2] means the minimal element is on the 1-st entry
		const std::vector<int> sortedIndices = SortIndexes(SetOfT);

		// Initialize
		// p := W^T*d
		af::array p = matmulTN(W, d);						// (2mn operations)
		// f' := g^T*d = -d^T*d
		Scalar f_prime = -matmulTN(d, d).scalar<Scalar>();	// (n operations)

		// f'' := \theta*d^T*d - d^T*W*M*W^T*d = -\theta*f' - p^T*M*p
		Scalar f_doubleprime = -th * f_prime - matmulTN(p, matmul(M, p)).scalar<Scalar>(); // (O(m^2) operations)
		f_doubleprime = std::max<Scalar>(std::numeric_limits<Scalar>::epsilon(), f_doubleprime);

		const Scalar f_dp_orig = f_doubleprime;
		// \delta t_min := -f'/f''
		Scalar dt_min = -f_prime / f_doubleprime;
		// t_old := 0
		Scalar t_old = 0;

		// b := argmin {t_i , t_i > 0}
		int i = 0;
		for (int j = 0; j < DIM; j++)
		{
			i = j;
			if (SetOfT[sortedIndices[j]].second != 0)
				break;
		}
		int b = sortedIndices[i];
		// t := min{t_i : i in F}
		Scalar t = SetOfT[b].second;
		// \delta t := t - 0
		Scalar dt = t - t_old;

		// examination of subsequent segments
		while ((dt_min >= dt) && (i < DIM))
		{
			// Variable b reaches a bound at this breakpoint: pin it there.
			const Scalar db = d(b).scalar<Scalar>();
			if (db > 0)
				x_cauchy(b) = _function->UpperBound()(b).as(m_dtype);
			else if (db < 0)
				x_cauchy(b) = _function->LowerBound()(b).as(m_dtype);

			// z_b = x_b^{cp} - x_b
			const Scalar zb = (x_cauchy(b) - x(b)).scalar<Scalar>();

			// c := c + \delta t*p
			c += dt * p;

			// cache
			const af::array wbt = W.row(b);					// 1 x 2m row of W
			const Scalar gb = g(b).scalar<Scalar>();
			const Scalar wMc = matmul(wbt, matmul(M, c)).scalar<Scalar>();
			const Scalar wMp = matmul(wbt, matmul(M, p)).scalar<Scalar>();
			// wbt is a row vector, so M must be applied to its transpose.
			const Scalar wMw = matmul(wbt, matmulNT(M, wbt)).scalar<Scalar>();

			// f'  := f' + dt*f'' + g_b^2 + theta*g_b*z_b - g_b*w_b^T*M*c
			f_prime += dt * f_doubleprime + gb * gb + th * gb * zb - gb * wMc;

			// f'' := f'' - theta*g_b^2 - 2*g_b*w_b^T*M*p - g_b^2*w_b^T*M*w_b
			f_doubleprime += -th * gb * gb - (Scalar)2.0 * gb * wMp - gb * gb * wMw;
			f_doubleprime = std::max<Scalar>(std::numeric_limits<Scalar>::epsilon() * f_dp_orig, f_doubleprime);

			// p := p + g_b*w_b ; d_b := 0
			p += gb * wbt.T();
			d(b) = 0;
			dt_min = -f_prime / f_doubleprime;
			t_old = t;
			++i;
			if (i < DIM)
			{
				b = sortedIndices[i];
				t = SetOfT[b].second;
				dt = t - t_old;
			}
		}
		dt_min = std::max<Scalar>(dt_min, (Scalar)0.0);
		t_old += dt_min;

		// Remaining variables move along d for the total step length t_old.
		for (int ii = i; ii < x_cauchy.dims(0); ii++)
			x_cauchy(sortedIndices[ii]) = x(sortedIndices[ii]) + t_old * d(sortedIndices[ii]);

		c += dt_min * p;
	}

	/**
	 * Largest step length alpha in (-inf, 1] such that x_cp + alpha*du stays within
	 * the objective's bounds on the free variables.
	 *
	 * @param x_cp           Generalized Cauchy point (indexed by variable).
	 * @param du             Step for the free variables (indexed by position in FreeVariables).
	 * @param FreeVariables  Indices of the free variables; du must have one entry per index.
	 * @return               The step length; 1 when no bound restricts the step.
	 *
	 * A component with du(i) == 0 does not move and therefore imposes no limit; it is
	 * skipped instead of being divided by, which previously could yield -inf or NaN.
	 */
	template<typename Scalar, LineSearchType LSType>
	Scalar LBFGSBSolver<Scalar, LSType>::FindAlpha(af::array & x_cp, af::array & du, std::vector<int>& FreeVariables)
	{
		Scalar alphastar = 1;
		const unsigned int n = FreeVariables.size();
		LogAssert(du.dims(0) == n, "Dimension mismatch");
		for (unsigned int i = 0; i < n; i++)
		{
			const int idx = FreeVariables[i];
			const Scalar step = du(i).scalar<Scalar>();

			if (step > 0)
			{
				// moving up: limited by the distance to the upper bound
				const af::array room = _function->UpperBound()(idx) - x_cp(idx);
				alphastar = std::min<Scalar>(alphastar, room.scalar<Scalar>() / step);
			}
			else if (step < 0)
			{
				// moving down: limited by the distance to the lower bound
				const af::array room = _function->LowerBound()(idx) - x_cp(idx);
				alphastar = std::min<Scalar>(alphastar, room.scalar<Scalar>() / step);
			}
		}
		return alphastar;
	}

	/**
	 * Direct primal subspace minimization: minimizes the quadratic model over the
	 * variables that are not at a bound at the generalized Cauchy point, then
	 * truncates the step with FindAlpha so the result stays within the bounds.
	 *
	 * @param x_cauchy     Generalized Cauchy point.
	 * @param x            Current iterate.
	 * @param c            Vector produced by GetGeneralizedCauchyPoint.
	 * @param g            Gradient at x.
	 * @param SubspaceMin  Out: x_cauchy with the truncated step added on the free variables.
	 *
	 * When no variable is free, SubspaceMin is simply a copy of x_cauchy.
	 */
	template<typename Scalar, LineSearchType LSType>
	void LBFGSBSolver<Scalar, LSType>::SubspaceMinimization(af::array & x_cauchy, af::array & x, af::array & c, af::array & g, af::array & SubspaceMin)
	{
		const Scalar theta_inverse = 1 / theta;

		// A variable is free when the Cauchy point sits strictly between its bounds.
		std::vector<int> FreeVariablesIndex;
		af::array condition;
		for (int i = 0; i < x_cauchy.dims(0); i++)
		{
			condition = (x_cauchy(i) != _function->UpperBound()(i)) && (x_cauchy(i) != _function->LowerBound()(i));
			if (condition.as(m_dtype).scalar<Scalar>() > 0)
				FreeVariablesIndex.push_back(i);
		}

		const int FreeVarCount = FreeVariablesIndex.size();
		if (FreeVarCount == 0)
		{
			// Every variable is at a bound: there is no subspace to minimize over,
			// and this avoids matrix products and solves on zero-sized arrays.
			SubspaceMin = x_cauchy.copy();
			return;
		}

		// WZ = W^T restricted to the free variables (one column per free variable).
		af::array WZ = af::constant((Scalar)0.0, W.dims(1), FreeVarCount, m_dtype);
		for (int i = 0; i < FreeVarCount; i++)
			WZ.col(i) = W.row(FreeVariablesIndex[i]);

		// Reduced gradient of the model at the Cauchy point, over all variables.
		const af::array rr = (g + theta * (x_cauchy - x) - matmul(W, matmul(M, c)));

		// r=r(FreeVariables);
		af::array r = af::constant((Scalar)0.0, FreeVarCount, 1, m_dtype);
		for (int i = 0; i < FreeVarCount; i++)
			r.row(i) = rr.row(FreeVariablesIndex[i]);

		// STEP 2: "v = w^T*Z*r" and STEP 3: "v = M*v"
		af::array v = matmul(M, matmul(WZ, r));

		// STEP 4: N = 1/theta*W^T*Z*(W^T*Z)^T
		af::array N = theta_inverse * matmulNT(WZ, WZ);
		// N = I - MN
		N = af::identity(N.dims(0), N.dims(0)) - matmul(M, N);

		// STEP: 5
		// v = N^{-1}*v
		if (v.dims(0) > 0)
			v = CommonUtil<Scalar>::SolveQR(N, v);

		// STEP: 6
		// Original author's note: this step deliberately differs from the formula printed in the paper.
		af::array du = -theta_inverse * r - theta_inverse * theta_inverse * matmulTN(WZ, v);

		// STEP: 7
		const Scalar alpha_star = FindAlpha(x_cauchy, du, FreeVariablesIndex);

		// STEP: 8
		const af::array dStar = alpha_star * du;
		SubspaceMin = x_cauchy.copy();
		for (int i = 0; i < FreeVarCount; i++)
		{
			SubspaceMin(FreeVariablesIndex[i]) = SubspaceMin(FreeVariablesIndex[i]) + dStar(i);
		}
	}
}