"""
From https://github.com/zhuchen03/MaxVA at commit e0de2436bd08d4af4322d015d3e96132e43e0d73

MAdam for large-batch BERT pretraining, compatible with
https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT.
"""

import math
import torch


class MAdam(torch.optim.Optimizer):
    """Adam-style optimizer with a per-coordinate, adaptively chosen second-moment decay.

    On every step after the first, the decay factor applied to the running
    first/second gradient moments is computed element-wise (the value that
    maximizes the estimated variance) and clamped to ``beta2_range``.  A
    separate momentum buffer with decay ``beta1`` provides the numerator of
    the update.  Gradients are clipped by their global L2 norm before use.

    Arguments:
        params (iterable): parameters or parameter-group dicts to optimize.
        lr (float): learning rate.
        beta1 (float): decay rate of the momentum buffer ``exp_avg_grad``.
        beta2_range (tuple): ``(min, max)`` clamp for the adaptive decay
            factor; ``max`` is used directly on the first step and whenever
            ``min == max``.
        eps (float): constant added to the denominator.
        weight_decay (float): weight-decay coefficient.
        amsgrad (bool): keep the element-wise running maximum of the
            denominator and use it instead of the current denominator.
        moment_warmup (int): for the first ``moment_warmup`` steps of a group
            only the moment estimates are updated; parameters are untouched.
        nesterov (bool): use a Nesterov-style look-ahead of the momentum.
        adamw (bool): if True weight decay is decoupled (applied directly to
            the parameters); otherwise it is added to the gradient.
        lamb (bool): scale each parameter's update by
            ``||weight|| / ||update||`` when ``weight_decay > 0``.
        max_grad_norm (float): threshold for global gradient-norm clipping.
    """

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2_range=(0.5, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, moment_warmup=0, nesterov=False, adamw=True,
                 lamb=False, max_grad_norm=1.0):
        defaults = dict(lr=lr, beta1=beta1, beta2_range=beta2_range, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        moment_warmup=moment_warmup, nesterov=nesterov, adamw=adamw,
                        lamb=lamb)
        # The clipping threshold is global (not per group), so it lives on the
        # optimizer instance rather than in the group defaults.
        self.max_grad_norm = max_grad_norm
        self.last_step = 0
        super(MAdam, self).__init__(params, defaults)

    @property
    def update_size(self):
        """Return ``(None, None, size)``.

        ``size`` is the mean absolute update of the last parameter updated
        through the non-LAMB path in the most recent ``step`` call, or
        ``None`` if no such update happened.
        """
        if getattr(self, "update_size_", None) is not None:
            return None, None, self.update_size_
        else:
            return None, None, None

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        self.update_size_ = None

        # Pass 1: cache an fp32 view of every available gradient and
        # accumulate the squared global norm.  Parameters without a gradient
        # are skipped here exactly as they are in the update pass, and the
        # cache is keyed by parameter so the two passes cannot get out of sync.
        grads_fp32 = {}
        sq_norm = None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                g32 = p.grad.data.float()
                grads_fp32[p] = g32
                sq = g32.square().sum()
                sq_norm = sq if sq_norm is None else sq_norm + sq

        if sq_norm is None:
            # No gradients at all: nothing to clip.
            clipped_ratio = 1.0
        else:
            total_grad_norm = torch.sqrt(sq_norm)
            # Equals 1 while the norm is below the threshold, otherwise
            # threshold / norm.
            clipped_ratio = self.max_grad_norm / max(self.max_grad_norm, total_grad_norm)

        # Pass 2: update moments and parameters.
        for group in self.param_groups:
            # The step counter is kept per group (not per parameter).
            group['step'] = group.get('step', 0) + 1
            step_count = group['step']

            amsgrad = group['amsgrad']
            beta1 = group['beta1']
            beta2_min, beta2_max = group['beta2_range']
            lr = group['lr']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                # The multiplication yields a fresh tensor, so the in-place
                # operations below never modify ``p.grad``.
                grad = grads_fp32[p] * clipped_ratio

                p_data_fp32 = p.data.float()

                # Coupled (L2) weight decay: fold the penalty into the gradient.
                if not group['adamw'] and weight_decay != 0:
                    grad.add_(p_data_fp32, alpha=weight_decay)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Running (biased) mean of gradients with the adaptive decay
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    # Running (biased) mean of squared gradients with the adaptive decay
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    if amsgrad:
                        # Running maximum of the denominator
                        state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    # Momentum buffer with decay beta1
                    state['exp_avg_grad'] = torch.zeros_like(p_data_fp32)
                    # Total weight accumulated by the running means (bias correction)
                    state['total_w'] = torch.zeros_like(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']

                total_w = state['total_w']

                if step_count == 1 or beta2_max == beta2_min:
                    # Fixed decay: plain exponential moving averages.
                    exp_avg.mul_(beta2_max).add_(grad, alpha=1 - beta2_max)
                    exp_avg_sq.mul_(beta2_max).addcmul_(grad, grad, value=1 - beta2_max)
                    state['total_w'] = 1 - beta2_max ** step_count
                else:
                    # find the beta that maximize the variance
                    # beta is the multiplier for the new grad
                    exp_avg_sq_unbiased = exp_avg_sq / total_w
                    exp_avg_unbiased = exp_avg / total_w
                    moment_diff = exp_avg_sq_unbiased - exp_avg_unbiased ** 2
                    mean_diff_sq = (grad - exp_avg_unbiased) ** 2
                    sum_diff = mean_diff_sq + moment_diff
                    denominator = (mean_diff_sq - moment_diff).mul_(total_w).add_(sum_diff)

                    # 1e-16 guards against division by zero.
                    adv_beta = sum_diff.div_(denominator.add_(1e-16))

                    # clamp the range
                    adv_beta.clamp_(min=beta2_min, max=beta2_max)

                    adv_beta_comp = 1 - adv_beta
                    exp_avg.mul_(adv_beta).add_(adv_beta_comp * grad)
                    exp_avg_sq.mul_(adv_beta).add_(adv_beta_comp.mul(grad).mul_(grad))

                    state['total_w'] = state['total_w'] * adv_beta + adv_beta_comp

                # During warm-up only the moment estimates are refreshed.
                if step_count <= group['moment_warmup']:
                    continue

                denom = (exp_avg_sq / state['total_w']).sqrt() + group['eps']

                if amsgrad:
                    torch.max(denom, max_exp_avg_sq, out=max_exp_avg_sq)
                    denom.copy_(max_exp_avg_sq)

                state['exp_avg_grad'].mul_(beta1).add_(grad, alpha=(1 - beta1))
                # Bias correction counts only the steps on which the momentum
                # buffer was actually updated (i.e. after warm-up).
                bias_correction0 = 1 - beta1 ** (step_count - group['moment_warmup'])

                if group['nesterov']:
                    exp_avg_grad = state['exp_avg_grad'] * beta1 + (1 - beta1) * grad
                else:
                    exp_avg_grad = state['exp_avg_grad']

                if group['lamb']:
                    if bias_correction0 < 1:
                        update_ = exp_avg_grad / denom / bias_correction0 + p_data_fp32 * weight_decay
                    else:
                        update_ = exp_avg_grad / denom + p_data_fp32 * weight_decay
                    trust_ratio = 1

                    if weight_decay > 0:
                        weight_norm = torch.norm(p_data_fp32)
                        update_norm = torch.norm(update_)
                        # Fall back to a ratio of 1 when either norm vanishes.
                        if weight_norm == 0 or update_norm == 0:
                            trust_ratio = 1
                        else:
                            trust_ratio = weight_norm / update_norm
                    p_data_fp32.add_(update_, alpha=-lr * trust_ratio)
                else:
                    step_size = lr / bias_correction0
                    # Decoupled weight decay: shrink the weights directly.
                    if group['adamw'] and weight_decay > 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-lr * weight_decay)

                    update = - step_size * exp_avg_grad / denom
                    p_data_fp32.add_(update)
                    self.update_size_ = torch.mean(update.abs()).item()

                # Write back (a real copy only when p is not stored in fp32).
                p.data.copy_(p_data_fp32)

        return loss

class LaMAdam(torch.optim.Optimizer):
    """Optimizer with adaptive second-moment decay and momentum on the normalized update.

    The running gradient moments use a per-coordinate decay factor that is
    computed on every step after the first (the value that maximizes the
    estimated variance) and clamped to ``[beta_min, beta]``.  The gradient is
    divided by the square root of the bias-corrected second moment, scaled by
    the learning rate, and averaged into a momentum buffer that is then
    applied to the parameters.

    Arguments:
        params (iterable): parameters or parameter-group dicts to optimize.
        lr (float): learning rate; must be non-negative.
        momentum (float): decay rate of the running average of normalized,
            learning-rate-scaled gradients.
        beta (float): upper bound of the adaptive decay factor; used directly
            on the first step and whenever ``beta_min == beta``.
        beta_min (float): lower bound of the adaptive decay factor.
        eps (float): constant added to the denominator; must be non-negative.
        weight_decay (float): weight-decay coefficient; must be non-negative.
        use_adamw (bool): if True weight decay is decoupled (applied directly
            to the parameters); otherwise it is added to the gradient.
        amsgrad (bool): use the element-wise running maximum of the (biased)
            second-moment estimate in the denominator.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative.
    """

    def __init__(self, params, lr=1e-3, momentum=0.9, beta=0.98, beta_min=0.5,
                    eps=1e-15, weight_decay=0, use_adamw=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, beta=beta, beta_min=beta_min,
                        eps=eps, weight_decay=weight_decay,
                        use_adamw=use_adamw, amsgrad=amsgrad)
        super(LaMAdam, self).__init__(params, defaults)

    @property
    def update_size(self):
        """Return ``(None, None, size)``.

        ``size`` is the mean absolute update of the last parameter updated by
        the most recent ``step`` call, or ``None`` if nothing has been
        updated yet.
        """
        if getattr(self, "update_size_", None) is not None:
            return None, None, self.update_size_
        else:
            return None, None, None

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum, beta = group['momentum'], group['beta']
            beta_min = group['beta_min']
            lr = group['lr']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('LaMAdam does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Running average of the learning rate (bias correction
                    # for the momentum buffer).
                    state['update_lr_bc'] = 0.
                    # Momentum buffer of normalized, lr-scaled gradients
                    state['update_est'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Running (biased) mean of squared gradients
                    state['g_sq_est'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Running (biased) mean of gradients
                    state['g_est'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Total weight accumulated by the running means
                    state['total_w'] = torch.zeros_like(p)

                    if group['amsgrad']:
                        state['max_sq_est'] = torch.zeros_like(p)

                update_est, g_sq_est = state['update_est'], state['g_sq_est']

                state['step'] += 1

                if weight_decay != 0:
                    if group['use_adamw']:
                        # Decoupled weight decay: shrink the weights directly.
                        p.add_(p, alpha=-weight_decay * lr)
                    else:
                        # Out-of-place so the caller's p.grad is not modified.
                        grad = grad.add(p, alpha=weight_decay)

                if state['step'] > 1 and beta_min != beta:
                    # Choose, per coordinate, the decay that maximizes the
                    # estimated variance, then clamp it to [beta_min, beta].
                    total_w = state['total_w']
                    g_sq_est_unbiased = g_sq_est / total_w
                    g_est_unbiased = state['g_est'] / total_w
                    moment_diff = g_sq_est_unbiased - g_est_unbiased ** 2
                    mean_diff_sq = (grad - g_est_unbiased) ** 2
                    sum_diff = mean_diff_sq + moment_diff
                    denominator = (mean_diff_sq - moment_diff).mul_(total_w).add_(sum_diff)

                    # 1e-16 guards against division by zero.
                    adv_beta = sum_diff.div_(denominator.add_(1e-16))
                    adv_beta.clamp_(min=beta_min, max=beta)

                    adv_beta_comp = 1 - adv_beta

                    state['g_est'].mul_(adv_beta).add_(adv_beta_comp * grad)
                    g_sq_est.mul_(adv_beta).add_(adv_beta_comp.mul(grad).mul_(grad))
                    total_w.mul_(adv_beta).add_(adv_beta_comp)
                else:
                    # Fixed decay: plain exponential moving averages.
                    g_sq_est.mul_(beta).addcmul_(grad, grad, value=1 - beta)
                    state['g_est'].mul_(beta).add_(grad, alpha=1 - beta)
                    total_w = 1 - beta ** state['step']
                    # fill_ (rather than slice assignment) also works for
                    # 0-dim parameters.
                    state['total_w'].fill_(total_w)

                # ``group`` is a dict, so the flag must be read by key.
                if group['amsgrad']:
                    torch.max(state['max_sq_est'], g_sq_est, out=state['max_sq_est'])
                    g_sq_est = state['max_sq_est']

                denom = g_sq_est.div(total_w).sqrt_().add_(group['eps'])

                update_est.mul_(momentum).addcdiv_(grad, denom, value=(1 - momentum) * lr)
                state['update_lr_bc'] = state['update_lr_bc'] * momentum + lr * (1 - momentum)

                # With lr == 0 the correction term is 0; take no step then.
                step_size = 0
                if state['update_lr_bc'] != 0:
                    step_size = lr / state['update_lr_bc']

                update = - step_size * update_est
                p.add_(update)
                self.update_size_ = torch.mean(update.abs()).item()

        return loss
