import warnings
from dowel import logger
import math
import numpy as np
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import eigsh
import torch
from torch.optim import Optimizer
from garage.np import unflatten_tensors

def _build_hessian_vector_product(func, params, reg_coeff=1e-5):
    """Build a closure that computes regularized Hessian-vector products.

    Uses Pearlmutter's algorithm:
    `Pearlmutter, Barak A. "Fast exact multiplication by the Hessian." Neural
    computation 6.1 (1994): 147-160.`

    `func` is evaluated once, here, and its gradient w.r.t. `params` is kept
    with its autograd graph so that every later call of the returned closure
    only needs one extra backward pass.

    Args:
        func (callable): A function that returns a torch.Tensor. Hessian of
            the return value will be computed.
        params (list[torch.Tensor]): A list of function parameters.
        reg_coeff (float): A small value so that A -> A + reg*I.

    Returns:
        function: Maps a flat vector `v` to `H v + reg_coeff * v`, where `H`
            is the Hessian of `func()` with respect to `params`.
    """
    # Zero-dim parameters have an empty (falsy) shape; treat them as length 1
    # so that they occupy one slot of the flat vector.
    param_shapes = [p.shape or torch.Size([1]) for p in params]
    f = func()
    # create_graph=True keeps the gradient differentiable for the second pass.
    f_grads = torch.autograd.grad(f, params, create_graph=True)

    def _eval(vector):
        """The evaluation function.

        Args:
            vector (torch.Tensor): The flat vector to be multiplied with
                Hessian.

        Returns:
            torch.Tensor: The flat product of the (regularized) Hessian of
                function f and `vector`.
        """
        unflatten_vector = unflatten_tensors(vector, param_shapes)

        assert len(f_grads) == len(unflatten_vector)
        grad_vector_product = torch.sum(
            torch.stack(
                [torch.sum(g * x) for g, x in zip(f_grads, unflatten_vector)]))

        if grad_vector_product.requires_grad:
            # allow_unused=True makes autograd return None (instead of
            # raising) for parameters the product does not depend on, e.g.
            # parameters `func` is linear in. Their Hessian rows are zero.
            hvp = torch.autograd.grad(grad_vector_product,
                                      params,
                                      retain_graph=True,
                                      allow_unused=True)
        else:
            # The gradient is constant in every parameter: zero Hessian.
            hvp = [None] * len(params)

        hvp = [
            torch.zeros_like(p) if hx is None else hx
            for hx, p in zip(hvp, params)
        ]

        flat_output = torch.cat([h.reshape(-1) for h in hvp])
        return flat_output + reg_coeff * vector

    return _eval


class HSODMOptimizer(Optimizer):
    """Optimizer that steps along the leftmost eigenvector of an augmented matrix.

    Each call to `hsodm_step` builds the operator `[A g; g' -delta]`, where
    `g` is the flattened gradient currently stored on the parameters and `A`
    is a Hessian-vector-product closure, asks `scipy.sparse.linalg.eigsh` for
    its smallest-algebraic eigenpair, turns the eigenvector into a direction,
    clips it to radius `Delta`, and applies it either directly or through a
    backtracking line search.

    Args:
        params (iterable): Parameters (or parameter groups) to optimize.
        homogeneous_param (object): Configuration object. The attributes read
            by this class are `delta`, `Delta`, `order` and `line_search`
            (`'const'` or `'backtrack'`).
        hvp_reg_coeff (float): Regularization coefficient passed to the
            Hessian-vector-product builder.
    """

    # Values of `homogeneous_param.line_search` that `hsodm_step` handles.
    _LINE_SEARCH_MODES = ('const', 'backtrack')

    def __init__(self,
                 params,
                 homogeneous_param,
                 hvp_reg_coeff=1e-5):
        super().__init__(params, {})
        self._hvp_reg_coeff = hvp_reg_coeff
        self._max_backtracks = 25
        self._backtrack_ratio = 0.8
        self._accept_violation = False
        # Upper bound on f_constraint() enforced by the line search. Set here
        # so that `_backtracking_line_search` never sees it undefined.
        self._max_constraint_value = 0.1

        # for negative curvature
        self.homogeneous_param = homogeneous_param

    def hsodm_step(self, itr, f_loss, f_constraint):
        """Take one optimization step.

        Args:
            itr (int): Iteration counter. Currently unused (only referenced
                by a disabled step-size schedule).
            f_loss (callable): Returns the loss as a torch.Tensor.
            f_constraint (callable): Returns the constraint value as a
                torch.Tensor.

        Raises:
            ValueError: If `homogeneous_param.line_search` is neither
                `'const'` nor `'backtrack'`.
        """
        # Fail fast: do not pay for HVP construction and the eigen-solve
        # only to discover an invalid configuration afterwards.
        line_search = self.homogeneous_param.line_search
        if line_search not in self._LINE_SEARCH_MODES:
            raise ValueError(f"unknown linear search {line_search}")

        # Only parameters that currently hold a gradient take part.
        params = []
        grads = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    params.append(p)
                    grads.append(p.grad.reshape(-1))
        g = torch.cat(grads)

        f_Ax = _build_hessian_vector_product(f_constraint, params,
                                             self._hvp_reg_coeff)
        print('* norm g is:')
        print(g.norm())
        print('------------')

        N = g.size(dim=0)
        g_numpy = g.detach().numpy()

        # identify the regularization parameter
        # TODO: use the adaptive strategy
        delta = self.homogeneous_param.delta

        # compute the leftmost eigenvector of the augmented matrix
        def mv(A, b, v):
            """Multiply the augmented operator with a vector.

            Q = [A b; b' -delta]

            Args:
                A: Callable mapping a flat torch vector of length N to the
                    matrix-vector product of the top-left block.
                b: numpy vector of length N (the off-diagonal block).
                v: the n+1 dimensional numpy vector.

            Returns:
                Q * v as an (N + 1, 1) numpy array.
            """
            v_first = v[0:N]
            Fv = A(torch.from_numpy(v_first)).detach().numpy()

            term1 = (Fv + b * v[N]).reshape((N, 1))
            term2 = (np.dot(b, v_first) - delta * v[N]).reshape((1, 1))
            return np.concatenate((term1, term2), axis=0)

        if self.homogeneous_param.order == 1:
            # compute Q = [F g; g -δ]
            augmented_op = LinearOperator(
                (N + 1, N + 1), matvec=lambda v: mv(f_Ax, g_numpy, v))
            _, eigenvec = eigsh(augmented_op, k=1, which='SA', tol=1e-5,
                                return_eigenvectors=True)
        else:
            # use second-order augmented direction
            # compute generalized eigenvalue problem
            # [H g; g -δ] v = λ*[F 0; 0 0] v
            # The loss HVP is only needed on this branch, so build it here.
            h_Ax = _build_hessian_vector_product(f_loss, params,
                                                 self._hvp_reg_coeff)
            # NOTE(review): `mv` always subtracts delta * v[N], so with b = 0
            # the right-hand operator is [F 0; 0 -δ], not the [F 0; 0 0]
            # written above. eigsh documents that M should be symmetric
            # positive definite when sigma is None — confirm what is intended.
            Fv = LinearOperator(
                (N + 1, N + 1),
                matvec=lambda v: mv(f_Ax, np.zeros_like(g_numpy), v))
            Hv = LinearOperator(
                (N + 1, N + 1), matvec=lambda v: mv(h_Ax, g_numpy, v))
            _, eigenvec = eigsh(Hv, M=Fv, k=1, which='SA', tol=1e-5,
                                return_eigenvectors=True)

        eigenvec = eigenvec.reshape((N + 1,))

        # compute the direction: normalize by the last component when it is
        # non-zero, otherwise use the first N components as returned.
        if eigenvec[-1] != 0:
            homo_direction = eigenvec[0:N] / eigenvec[-1]
        else:
            homo_direction = eigenvec[0:N]

        # compute the norm of the homo direction
        norm_homo = np.linalg.norm(homo_direction)

        if line_search == 'const':
            # baseline version, constant stepsize
            # A schedule making the stepsize a function of itr (for
            # Ant/HalfCheetah/Hopper/Walker) is currently disabled:
            #   ini = self.homogeneous_param.Delta
            #   Delta = 0.8 * 0.01 + ini * (1 + math.cos(math.pi * itr / 200))
            print('homogeneous stepsize is:')
            print(self.homogeneous_param.Delta)

        # Clip the direction to the radius Delta (shared by both modes).
        radius = self.homogeneous_param.Delta
        if norm_homo > radius:
            steps = homo_direction / norm_homo * radius
        else:
            steps = homo_direction
        flat_step = torch.from_numpy(steps).type(torch.float32)

        if line_search == 'const':
            param_shapes = [p.shape or torch.Size([1]) for p in params]
            per_param_steps = unflatten_tensors(flat_step, param_shapes)
            assert len(per_param_steps) == len(params)

            for step, param in zip(per_param_steps, params):
                param.data = param.data + step
        else:
            inner_iteration = self._backtracking_line_search(
                params,
                flat_step,
                f_loss,
                f_constraint,
                backtrack_ratio=self._backtrack_ratio,
                max_backtracks=self._max_backtracks,
            )
            print(f"backtrack {inner_iteration}")

    def _backtracking_line_search(
        self,
        params,
        descent_step,
        f_loss,
        f_constraint,
        backtrack_ratio=0.8,
        max_backtracks=25,
    ):
        """Shrink the step until the loss improves and the constraint holds.

        Code snippet adapted from
        https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py

        The step is tried at scales `backtrack_ratio ** k` for
        `k = 0 .. max_backtracks - 1`. The first scale for which the loss is
        strictly lower than before and the constraint does not exceed
        `self._max_constraint_value` is kept. If no scale qualifies (or a
        value is NaN) and `self._accept_violation` is False, the parameters
        are restored to their values on entry.

        Args:
            params (list[torch.Tensor]): Parameters to update in place.
            descent_step (torch.Tensor): Flat step covering all `params`.
            f_loss (callable): Returns the loss as a torch.Tensor.
            f_constraint (callable): Returns the constraint value as a
                torch.Tensor.
            backtrack_ratio (float): Multiplicative shrink factor per trial.
            max_backtracks (int): Maximum number of trials.

        Returns:
            int: Number of step sizes that were tried.
        """
        prev_params = [p.clone() for p in params]
        ratio_list = backtrack_ratio ** torch.arange(max_backtracks,
                                                     dtype=torch.float32)
        loss_before = f_loss()

        param_shapes = [p.shape or torch.Size([1]) for p in params]
        descent_step = unflatten_tensors(descent_step, param_shapes)
        assert len(descent_step) == len(params)
        innerstep = 0
        for ratio in ratio_list:
            innerstep += 1
            for step, prev_param, param in zip(descent_step, prev_params,
                                               params):
                param.data = prev_param.data + ratio * step

            loss = f_loss()
            constraint_val = f_constraint()
            if (loss < loss_before
                    and constraint_val <= self._max_constraint_value):
                break

        print(constraint_val, self._max_constraint_value)
        # Rejection must be the exact complement of the acceptance test
        # above: a constraint value equal to the bound is acceptable, so
        # only a strict excess counts as a violation.
        constraint_violated = constraint_val > self._max_constraint_value
        if ((torch.isnan(loss) or torch.isnan(constraint_val)
             or loss >= loss_before
             or constraint_violated)
                and not self._accept_violation):
            print('Line search condition violated. Rejecting the step!')
            if torch.isnan(loss):
                print('Violated because loss is NaN')
            if torch.isnan(constraint_val):
                print('Violated because constraint is NaN')
            if loss >= loss_before:
                print('Violated because loss not improving')
            if constraint_violated:
                print('Violated because constraint is violated')
            for prev, cur in zip(prev_params, params):
                cur.data = prev.data
        return innerstep
