import math
import warnings
from typing import Callable, Iterable, Tuple

import torch
from torch import nn
from torch.optim import Optimizer

from transformers.utils.versions import require_version


class AdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Parameter groups that contain a `"lora"` key are handled by a LoRA-specific update: the group's `"params"`
    list is consumed as consecutive `(lora_B, lora_A)` pairs and each gradient is rescaled by a preconditioner
    built from the *other* factor before the usual Adam moment updates. Such a group must also provide a
    `"precon_type"` entry:

    - `"diag"`: divide `grad_B` by `diag(A @ A.T)` and scale the rows of `grad_A` by `1 / diag(B.T @ B)`
      (both factors must be 2-D matrices for the matrix products to be defined).
    - `"norm"`: divide each gradient by the sum of squared entries of the other factor.
    - any other value: no preconditioning (plain Adam on both factors).

    All preconditioner denominators are clamped to at least `1e-6`. If a LoRA group holds an odd number of
    parameters, the trailing unpaired parameter is not updated.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).

    Raises:
        ValueError: If `lr`, either beta, or `eps` is outside its valid range.
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
                " warning",
                FutureWarning,
            )
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by `closure`, or `None` when no closure is given.

        Raises:
            RuntimeError: If a parameter being updated has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # This method runs under `no_grad`; a closure normally performs a forward and
            # backward pass, so gradient tracking has to be re-enabled around it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if "lora" in group:
                self._step_lora_group(group)
                continue

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
                self._adam_update(p, p.grad, group)

        return loss

    def _step_lora_group(self, group):
        """Update one LoRA parameter group, walking its params as consecutive (lora_B, lora_A) pairs."""
        precon_type = group["precon_type"]
        params = group["params"]

        # (params[0], params[1]), (params[2], params[3]), ...; a trailing unpaired entry is ignored.
        for p_B, p_A in zip(params[::2], params[1::2]):
            assert p_B.grad is not None and p_A.grad is not None
            if p_B.grad.is_sparse or p_A.grad.is_sparse:
                raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

            # Both preconditioned gradients are computed *before* either factor is modified,
            # so each preconditioner sees the pre-step value of the other factor.
            grad_cond_B, grad_cond_A = self._lora_preconditioned_grads(p_B, p_A, precon_type)

            self._adam_update(p_B, grad_cond_B, group)
            self._adam_update(p_A, grad_cond_A, group)

    @staticmethod
    def _lora_preconditioned_grads(p_B, p_A, precon_type):
        """Return `(grad_B, grad_A)` rescaled according to `precon_type` (see the class docstring)."""
        grad_B, grad_A = p_B.grad, p_A.grad

        if precon_type == "diag":
            # One scale per shared (rank) index: broadcast over the last dim of grad_B
            # and, after unsqueeze, over the rows of grad_A.
            scale_B = p_A.mm(p_A.t()).diag().clamp(min=1e-6)
            scale_A = p_B.t().mm(p_B).diag().clamp(min=1e-6)
            return grad_B / scale_B, 1 / scale_A.unsqueeze(1) * grad_A

        if precon_type == "norm":
            # Scalar scale: sum of squared entries of the other factor.
            scale_B = (p_A * p_A).sum().clamp(min=1e-6)
            scale_A = (p_B * p_B).sum().clamp(min=1e-6)
            return grad_B / scale_B, grad_A / scale_A

        return grad_B, grad_A

    def _adam_update(self, p, grad, group):
        """Apply one Adam step with decoupled weight decay to `p`, using `grad` as the gradient signal."""
        state = self.state[p]

        if "step" not in state:
            state["step"] = 0

        # State initialization
        if "exp_avg" not in state:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(p.grad)
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(p.grad)

        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]

        state["step"] += 1

        # Decay the first and second moment running average coefficient
        # In-place operations to update the averages at the same time
        exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])

        # Always start from the raw learning rate so the step size is defined
        # whether or not bias correction is enabled.
        step_size = group["lr"]
        if group["correct_bias"]:  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        p.add_(exp_avg / denom, alpha=-step_size)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        # Add weight decay at the end (fixed version)
        if group["weight_decay"] > 0.0:
            p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
