#*
# Authors: Anonymous
# This file is part of OASIS library.
#
# This file is based on the AdaHessian repository
# https://github.com/amirgholami/adahessian
#*

import torch
from torch.optim import Optimizer
import math
import copy

from IPython.core.debugger import set_trace

class AdaAdaHessian(Optimizer):
    r'''Implements AdaAdaHessian algorithm.

    It has been proposed in "citation".

    The update is a gradient step preconditioned by an exponential moving
    average of a Hutchinson estimate of the Hessian diagonal. ``loss.backward``
    must be called with ``create_graph=True`` before :meth:`step` so that the
    gradients can be differentiated a second time.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        beta (float, optional): coefficient for updating the Hessian diagonal
            estimate
        alpha (float, optional): lower bound on the Hessian diagonal elements
        zeta (float, optional): damping for eta choice
        gamma (float, optional): multiplier of theta
        lr_damping (float, optional): multiplier for lr in the update formula
        lr (float, optional): starting value for eta
        weight_decay (float, optional): coefficient of the ``weight_decay * p``
            term added to the preconditioned gradient
        warmstart_samples_fbatch (int, optional): number of samples used for the
            initial Hessian diagonal estimation on the first batch
        eps (float, optional): eps for numerical stability (validated and stored
            in the group defaults; not referenced by the update formula below)
    '''

    def __init__(self, params, beta=0.999, alpha=1e-3, zeta=0.5, gamma=1.0, lr_damping=1.0, lr=1e-2,
                 weight_decay=5e-4, warmstart_samples_fbatch = 1, eps=1e-8):
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta value: {}".format(beta))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        if not 0.0 <= zeta:
            raise ValueError("Invalid zeta value: {}".format(zeta))
        if not 0.0 <= gamma:
            raise ValueError("Invalid gamma value: {}".format(gamma))
        if not 0.0 < lr_damping:
            # report the offending lr_damping value (not lr)
            raise ValueError("Invalid eta damping value: {}".format(lr_damping))
        if not 0.0 < lr:
            raise ValueError("Invalid eta value: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight decay coefficient value: {}".format(weight_decay))
        if not 0 < warmstart_samples_fbatch:
            raise ValueError("Invalid number of samples for first batch: {}".format(warmstart_samples_fbatch))
        if not 0.0 < eps:
            raise ValueError("Invalid eps value: {}".format(eps))

        defaults = dict(beta=beta, alpha=alpha, zeta=zeta, gamma=gamma, lr_damping=lr_damping,
                        lr=lr, weight_decay=weight_decay, eps=eps)
        super(AdaAdaHessian, self).__init__(params, defaults)

        # bookkeeping for warmstarting: whether the first-batch estimate was taken
        # in step(), and how many batches accumulate_h_diag() has averaged so far
        self.warmstarted_fbatch = False
        self.warmstart_counter_idbatch = 0
        self.warmstart_counter_fbatch = 0
        self.warmstart_samples_fbatch = warmstart_samples_fbatch

    def _collect_grad_params(self):
        r'''Gather every parameter that currently has a gradient.

        Returns:
            tuple: ``(params, grads, bounds)`` where ``params`` and ``grads`` are
            flat lists over all groups and ``bounds[i] = (start, stop)`` is the
            slice of those lists that belongs to ``self.param_groups[i]``.
        '''
        params = []
        grads = []
        bounds = []

        for group in self.param_groups:
            start = len(params)
            for p in group['params']:
                if p.grad is not None:
                    params.append(p)
                    grads.append(p.grad)
            bounds.append((start, len(params)))

        return params, grads, bounds

    def get_trace(self, params, grads, n_samples=1):
        r'''Estimate the Hessian diagonal with Hutchinson's method.

        For a random vector v with entries in {-1, +1} the Hessian vector
        product Hv is obtained by differentiating <grads, v>, and v * Hv is
        used as the diagonal estimate. Several samples are averaged.

        Args:
            params (iterable): a list of torch variables
            grads (iterable): a list of gradients
            n_samples (int, optional): number of samples used to obtain
                Hessian diagonal estimate; at least one sample is always drawn

        Returns:
            list: one tensor per parameter holding the averaged estimate

        Raises:
            RuntimeError: if a gradient has no ``grad_fn``, i.e. backward was
                not called with ``create_graph=True``
        '''

        # check backward was called with create_graph set to True, else you can't differentiate grads
        for i, grad in enumerate(grads):
            if grad.grad_fn is None:
                raise RuntimeError('Gradient tensor {:} does not have grad_fn. When calling\n'.format(i) +
                           '\t\t\t  loss.backward(), make sure the option create_graph is\n' +
                           '\t\t\t  set to True.')

        def draw_sample():
            # randint_like inherits dtype and device from p, so this works on
            # CPU as well as on any accelerator the parameters live on
            vs = [2 * torch.randint_like(p, high=2) - 1 for p in params]
            hvs = torch.autograd.grad(grads,
                                      params,
                                      grad_outputs=vs,
                                      retain_graph=True)
            # * is interpreted as component-wise multiplication
            return [v * hv for (v, hv) in zip(vs, hvs)]

        vhvs = draw_sample()

        # averaging samples to get a better estimate (running mean)
        for counter in range(1, n_samples):
            vhvs_next = draw_sample()
            vhvs = [vhv * counter / (counter + 1) + vhv_next / (counter + 1)
                    for (vhv, vhv_next) in zip(vhvs, vhvs_next)]

        return vhvs

    def accumulate_h_diag(self):
        r'''Accumulating diagonal estimates from different batches

        Each call draws one estimate for the current gradients and folds it
        into ``group['exp_h_diag_avg']`` as a running mean over the calls made
        so far. Does nothing when no parameter has a gradient.
        '''
        params, grads, bounds = self._collect_grad_params()
        if not params:
            return

        # one autograd call for all groups, then split the result per group
        h_diags = self.get_trace(params, grads)
        seen = self.warmstart_counter_idbatch

        for group, (start, stop) in zip(self.param_groups, bounds):
            group_h_diags = h_diags[start:stop]

            if 'exp_h_diag_avg' not in group:
                group['exp_h_diag_avg'] = group_h_diags
            else:
                group['exp_h_diag_avg'] = [
                    prev_h_diag.mul_(seen / (seen + 1)).add_(h_diag.mul_(1 / (seen + 1)))
                    for (prev_h_diag, h_diag) in zip(group['exp_h_diag_avg'], group_h_diags)
                ]

        self.warmstart_counter_idbatch += 1

    def compute_dif_norms(self, prev_optimizer):
        r'''Weighted norms of the parameter and gradient differences.

        For every group stores ``group['grad_dif_norm']``, the square root of
        sum((g - g_prev)^2 / D), and ``group['param_dif_norm']``, the square
        root of sum((p - p_prev)^2 * D), where D = clamp(|exp_h_diag_avg|, alpha).

        Args:
            prev_optimizer (Optimizer object): optimizer with a state, corresponding to a previous
                step in the model training, via which we can take previous params and grads values
        '''
        # obviously, relies on the correct passing of the prev_optimizer!
        for group, prev_group in zip(self.param_groups, prev_optimizer.param_groups):
            # exp_h_diag_avg only has entries for parameters that have a gradient,
            # so pair it with exactly those parameters
            pairs = [(p, p_prev)
                     for (p, p_prev) in zip(group['params'], prev_group['params'])
                     if p.grad is not None]

            grad_dif_norm = 0
            param_dif_norm = 0

            for (d, (p, p_prev)) in zip(group['exp_h_diag_avg'], pairs):
                scale = torch.clamp(torch.abs(d), min=group['alpha'])
                grad_dif_norm += torch.sum(((p.grad.data - p_prev.grad.data) ** 2) / scale)
                param_dif_norm += torch.sum(((p.data - p_prev.data) ** 2) * scale)

            if not torch.is_tensor(grad_dif_norm):
                # nothing was compared for this group; torch.sqrt needs a tensor
                continue

            group['grad_dif_norm'] = torch.sqrt(grad_dif_norm)
            group['param_dif_norm'] = torch.sqrt(param_dif_norm)

    def update_damping(self, rho):
        r'''Multiply ``lr_damping`` of every parameter group by ``rho``.'''
        for group in self.param_groups:
            group['lr_damping'] *= rho

    def step(self, closure=None):
        r'''Performs a single optimization step.

        Arguments:
            closure (callable, optional): a closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # need to pass lists to get_trace() to get rid of undesirable generator behaviour
        params, grads, bounds = self._collect_grad_params()
        if not params:
            # no gradients anywhere: nothing to estimate or update
            return loss

        # The estimate for all groups is taken before any parameter is modified.
        # Only the first batch uses warmstart_samples_fbatch samples; later
        # steps use a single sample.
        if not self.warmstarted_fbatch:
            h_diags = self.get_trace(params, grads, n_samples = self.warmstart_samples_fbatch)
            self.warmstarted_fbatch = True
        else:
            h_diags = self.get_trace(params, grads)

        for group, (start, stop) in zip(self.param_groups, bounds):
            group_params = params[start:stop]
            group_grads = grads[start:stop]
            group_h_diags = h_diags[start:stop]

            # if tree for different warmstart options
            if 'first_step' not in group:
                if 'exp_h_diag_avg' not in group:
                    group['exp_h_diag_avg'] = group_h_diags
                else:
                    # an estimate was already accumulated (warmstart): blend into it
                    group['exp_h_diag_avg'] = [prev_h_diag.mul_(group['beta']).add_(h_diag.mul_(1 - group['beta'])) for
                                               (prev_h_diag, h_diag) in zip(group['exp_h_diag_avg'], group_h_diags)]
                group['theta'] = float('Inf')
                group['first_step'] = True
            else:
                # exponential moving average of h_diag
                group['exp_h_diag_avg'] = [prev_h_diag.mul_(group['beta']).add_(h_diag.mul_(1 - group['beta'])) for
                                           (prev_h_diag, h_diag) in zip(group['exp_h_diag_avg'], group_h_diags)]
                group['first_step'] = False

            # updating eta if we can (requires a prior compute_dif_norms call)
            if ('param_dif_norm' in group) and ('grad_dif_norm' in group):
                eta_prev = group['lr']
                group['lr'] = min(math.sqrt(1 + group['theta'] / group['gamma']) * group['lr'],
                                   group['zeta'] * group['param_dif_norm'] / group['grad_dif_norm'])

                group['theta'] = group['lr'] / eta_prev

            step_size = -group['lr_damping'] * group['lr']

            with torch.no_grad():
                for (p, grad, h_diag) in zip(group_params, group_grads, group['exp_h_diag_avg']):
                    # p <- p - lr_damping * lr * (grad / max(|h_diag|, alpha) + weight_decay * p);
                    # detach_() cuts p.grad loose from the graph kept by create_graph=True
                    direction = grad.detach_() / torch.clamp(torch.abs(h_diag), min=group['alpha'])
                    direction.add_(torch.mul(p.data, group['weight_decay']))
                    p.data.add_(torch.mul(direction, step_size))

        return loss