import torch
import math

from torch import Tensor
from typing import List, Optional

from .optimizer import LowBitOptimizer
from ..config import get_config

__all__ = ["AdamW"]


class AdamW(LowBitOptimizer):
    """AdamW (decoupled weight decay) with optionally low-bit and/or factored states.

    Differences from ``torch.optim.AdamW``:

    * ``use_first_moment=False`` drops ``exp_avg`` entirely; the raw gradient is
      used in its place in the update.
    * ``factor_second_moment=True`` stores the second moment of parameters with
      ``ndim >= 2`` as row/column running means (Adafactor-style) instead of a
      full tensor.
    * ``exp_avg`` / ``exp_avg_sq`` are passed through the ``quantize`` /
      ``dequantize`` methods inherited from ``LowBitOptimizer`` between steps,
      configured by ``qconfig``.

    ``amsgrad``, ``foreach``, ``capturable``, ``differentiable`` and ``fused`` are
    accepted and stored in the param groups for API compatibility with
    ``torch.optim.AdamW`` but are not used by :meth:`step`.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        qconfig=None,
        use_first_moment=True,
        factor_second_moment= True,
        amsgrad=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        """Validate hyper-parameters and build the param-group defaults.

        Args:
            params: Iterable of parameters (or param-group dicts) to optimize.
            lr: Learning rate, must be >= 0.
            betas: ``(beta1, beta2)`` decay rates of the first/second moment
                running averages; each must lie in ``[0, 1)``.
            eps: Term added to the update denominator for stability, >= 0.
            weight_decay: Decoupled weight-decay coefficient, >= 0. Each step
                multiplies the parameter by ``1 - lr * weight_decay``.
            qconfig: Quantization config; ``None`` falls back to ``get_config(None)``.
            use_first_moment: Keep an ``exp_avg`` state (True) or use the raw
                gradient instead (False).
            factor_second_moment: Factor the second moment of parameters with
                ``ndim >= 2`` into row/column statistics.
            amsgrad: Stored only; not used by :meth:`step`.
            maximize: Maximize the objective instead of minimizing it.
            foreach, capturable, differentiable, fused: Stored only; not used by
                :meth:`step`.

        Raises:
            ValueError: If ``lr``, ``eps``, ``betas`` or ``weight_decay`` is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            foreach=foreach,
            maximize=maximize,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
            use_first_moment=use_first_moment,
            factor_second_moment=factor_second_moment,
        )
        if qconfig is None:
            qconfig = get_config(None)
        super().__init__(params, defaults, qconfig)
        # Optimizer states that may be stored in quantized form.
        self.qstate_name_list = ['exp_avg', 'exp_avg_sq']

    def __setstate__(self, state):
        """Restore optimizer state (``Optimizer.load_state_dict`` routes through here).

        Fills in defaults for param-group options missing from older checkpoints.
        If the first per-parameter state's ``step`` is not a tensor, converts every
        ``step`` to a float tensor.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            group.setdefault("fused", None)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def get_subqconfig(self, optimizer_state_name):
        """Return the quantization sub-config for one optimizer state.

        Args:
            optimizer_state_name: ``'exp_avg'`` or ``'exp_avg_sq'``.

        Returns:
            ``self.qconfig.QUANT.M`` for ``'exp_avg'``, ``self.qconfig.QUANT.SQM``
            for ``'exp_avg_sq'``.

        Raises:
            ValueError: For any other state name.
        """
        if optimizer_state_name == 'exp_avg':
            return self.qconfig.QUANT.M
        elif optimizer_state_name == 'exp_avg_sq':
            return self.qconfig.QUANT.SQM
        else:
            raise ValueError(
                f"Unknown optimizer state name {optimizer_state_name!r}; "
                f"expected 'exp_avg' or 'exp_avg_sq'"
            )

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return ``(factored, use_first_moment)`` for a parameter of ``param_shape``.

        Factoring only applies to tensors with at least 2 dimensions, and only
        when the group enables ``factor_second_moment``.
        """
        factored = len(param_shape) >= 2 and param_group["factor_second_moment"]
        use_first_moment = param_group["use_first_moment"]
        return factored, use_first_moment

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Rebuild a full second-moment estimate from its row/column factors.

        ``exp_avg_sq_row`` holds running means over the last dim (shape
        ``(..., n)``) and ``exp_avg_sq_col`` running means over the second-to-last
        dim (shape ``(..., m)``). The result, of shape ``(..., n, m)``, is
        ``row / mean(row) * col``: the rank-1 Adafactor estimate.
        """
        # Adapted from fairseq's / transformers' Adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        row_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
        # If every row statistic is 0 (the tensor has only seen zero gradients so
        # far), row / row_mean would be 0/0 = NaN and poison the parameter.
        # Clamping the denominator makes the estimate 0 instead, which matches
        # the unfactored path. Positive means are left untouched.
        row_mean = row_mean.clamp(min=torch.finfo(row_mean.dtype).tiny)
        r_factor = (exp_avg_sq_row / row_mean).unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2)
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        For every parameter with a gradient, this:
        1. dequantizes (or lazily creates) its moment states;
        2. applies decoupled weight decay;
        3. updates the moments and takes a bias-corrected Adam step;
        4. re-quantizes the non-factored moments.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            maximize = group.get("maximize", False)

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                grad = p.grad.data
                # Moment arithmetic is done in fp32 for half-precision gradients.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if maximize:
                    # Ascend instead of descend, as torch.optim.AdamW does.
                    grad = -grad

                state = self.state[p]
                grad_shape = p.grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State initialization
                if len(state) == 0:
                    # `step` starts as a plain Python int, so the bias-correction
                    # math runs on the host. After __setstate__ it may be a tensor.
                    state["step"] = 0
                    # Exponential moving average of gradient values; a 0-dim
                    # placeholder that is materialized on the first update below.
                    if use_first_moment:
                        state["exp_avg"] = torch.tensor(0.0)
                    # Exponential moving average of squared gradient values
                    if factored:
                        state["exp_avg_sq_row"] = grad.new_zeros(grad_shape[:-1])
                        state["exp_avg_sq_col"] = grad.new_zeros(grad_shape[:-2] + grad_shape[-1:])
                    else:
                        state["exp_avg_sq"] = torch.tensor(0.0)
                    # quantization state
                    self.init_qstate(p)

                # `step` is 0 only before this parameter's first update, i.e.
                # exactly when the moments are still placeholders. A `numel() <= 1`
                # check would also match single-element parameters and wipe
                # their moments on every step.
                first_update = state["step"] == 0

                # take out optimizer state (dequantize)
                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    if first_update:
                        exp_avg.data = torch.zeros_like(p, memory_format=torch.preserve_format)
                    else:
                        hat_exp_avg = self.dequantize(p, 'exp_avg', exp_avg)
                        # None means the stored state is already full precision.
                        if hat_exp_avg is not None:
                            exp_avg.data = hat_exp_avg
                        del hat_exp_avg
                else:
                    exp_avg = grad
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    if first_update:
                        exp_avg_sq.data = torch.zeros_like(p, memory_format=torch.preserve_format)
                    else:
                        hat_exp_avg_sq = self.dequantize(p, 'exp_avg_sq', exp_avg_sq)
                        if hat_exp_avg_sq is not None:
                            exp_avg_sq.data = hat_exp_avg_sq
                        del hat_exp_avg_sq

                # update
                state["step"] += 1
                # Decoupled (stepweight) decay
                p.mul_(1 - lr * weight_decay)

                # Decay the first and second moment running average coefficient
                if use_first_moment:
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                if factored:
                    grad_sq = grad ** 2
                    exp_avg_sq_row.mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=1 - beta2)
                    exp_avg_sq_col.mul_(beta2).add_(grad_sq.mean(dim=-2), alpha=1 - beta2)
                    exp_avg_sq = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                else:
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                step = state["step"]
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = lr / bias_correction1
                # Works whether `step` is a Python number or (after loading) a tensor.
                bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
                p.addcdiv_(exp_avg, denom, value=-step_size)

                # take in optimizer state (quantize); factored stats stay as-is
                if use_first_moment:
                    q_exp_avg = self.quantize(p, 'exp_avg', exp_avg)
                    if q_exp_avg is not None:
                        exp_avg.data = q_exp_avg
                if not factored:
                    q_exp_avg_sq = self.quantize(p, 'exp_avg_sq', exp_avg_sq)
                    if q_exp_avg_sq is not None:
                        exp_avg_sq.data = q_exp_avg_sq

        return loss


def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript type inference
    """Square root that accepts either a Python number or a tensor.

    In eager mode a tensor input goes through ``Tensor.sqrt`` and stays a tensor.
    Under TorchScript, and for plain numbers, ``math.sqrt`` is used instead.

    Args:
        x: A Python number or a ``torch.Tensor``.

    Returns:
        ``x.sqrt()`` for an eager-mode tensor, otherwise ``math.sqrt(x)``.
    """
    if torch.jit.is_scripting() or not isinstance(x, torch.Tensor):
        return math.sqrt(x)
    return x.sqrt()