import torch
import math
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from typing import Union, Iterable, Tuple, Callable, List

inf = math.inf


def grad_ignore_norm(parameters, max_norm, norm_type=2):
    """
    Overview:
        Clip the gradient norm of an iterable of parameters.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor
        - max_norm (:obj:`float`): the max norm of the gradients
        - norm_type (:obj:`float`): 2.0 means use norm2 to clip
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.zero_()
    return total_norm


def grad_ignore_value(parameters, clip_value):
    """
    Overview:
        Clip the gradient value of an iterable of parameters.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor
        - clip_value (:obj:`float`): the value to start clipping
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    flag = False
    for p in filter(lambda p: p.grad is not None, parameters):
        val = p.grad.data.abs().max()
        if val >= clip_value:
            flag = True
            break
    if flag:
        for p in filter(lambda p: p.grad is not None, parameters):
            p.grad.data.zero_()


class Adam(torch.optim.Adam):
    """
    Overview:
        Rewrited Adam optimizer to support more features.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
        self,
        params: Iterable,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        optim_type: str = 'adam',
        grad_clip_type: str = None,
        clip_value: Union[float, None] = None,
        clip_coef: float = 5,
        clip_norm_type: float = 2.0,
        clip_momentum_timestep: int = 100,
        grad_norm_type: str = None,
        grad_ignore_type: str = None,
        ignore_value: Union[float, None] = None,
        ignore_coef: float = 5,
        ignore_norm_type: float = 2.0,
        ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            init method of refactored Adam class
        Arguments:
            - params (:obj:`iterable`):  – an iterable of torch.Tensor s or dict s. \
                Specifies what Tensors should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-3
            - betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of gradient and its\
                square, default set to (0.9, 0.999))
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, deault set to 0
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper\
                On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
            - optim_type (:obj:str): support ["adam", "adamw"]
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
                'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping
            - clip_coef (:obj:`float`): the cliping coefficient
            - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
            - clip_momentum_timestep (:obj:`int`): after how many step should we start the momentum clipping
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
                'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring
            - ignore_coef (:obj:`float`): the ignoreing coefficient
            - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many step should we start the momentum ignoring

        """

        self._support_type = {
            'optim': ['adam', 'adamw'],
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert optim_type in self._support_type['optim']
        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._optim_type = optim_type
        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep

        if self._optim_type == 'adamw':
            self._weight_decay = weight_decay
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=0, amsgrad=amsgrad)
        elif self._optim_type == 'adam':
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        else:
            raise NotImplementedError(
                "optimizer type {} is not implemented, support type is {}".format(
                    self._optim_type, self._support_type['optim']
                )
            )

    def _state_init(self, p, amsgrad):
        """
        Overview:
            Initialize the state of the optimizer
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper\
                On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
        """
        state = self.state[p]
        state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device)
        # others
        if torch.__version__ < "1.12.0":
            state['step'] = 0
            # TODO
            # wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
        else:
            state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)

        state['exp_avg'] = torch.zeros_like(p.data)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(p.data)
        if amsgrad:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Performs a single optimization step
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default set to None
        """
        # clipping
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            '''
            This is the implimentation mimic the clip used in OPENAI, quote:
                'Gradients are additionally clipped per parameter to be within between ±5√v
                 where v is the running estimate of the second moment of the (unclipped) gradient'
            '''
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        flag = grad.abs(
                        ) > (state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) * self._clip_coef
                        grad.mul_(~flag).add_(
                            ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                             self._clip_coef).mul_(flag)
                        )
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multi param_group, we should calculate each group differently.
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                step = inf
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    # sum total_norm
                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type

                    # sum momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
                    step = min(step, state['step'])
                if step > self._clip_momentum_timestep:
                    total_norm = total_norm ** (1. / self._clip_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            p.grad.data.mul_(clip_coef)

        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            flag = False
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                        if grad.abs() > (state['thre_exp_avg_sq'].sqrt() /
                                         math.sqrt(bias_correction2)) * self._ignore_coef:
                            flag = True
                            break
                else:
                    continue
                break

            if flag:
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multi param_group, we should calculate each group differently.
            step = inf
            for group in self.param_groups:
                total_norm = 0
                total_momentum_norm = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    # sum total_norm
                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type

                    # sum momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
                    step = min(step, state['step'])

                if step > self._ignore_momentum_timestep:
                    total_norm = total_norm ** (1. / self._ignore_norm_type)
                    total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            p.grad.zero_()

        # Adam optim type
        if self._optim_type == 'adamw':
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.data = p.data.add(-self._weight_decay * group['lr'], p.data)
            return super().step(closure=closure)
        elif self._optim_type == 'adam':
            return super().step(closure=closure)

    def get_grad(self) -> float:
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm