import torch
import math

from torch import Tensor
from typing import List, Optional, Tuple
import time

from transformers.utils.versions import require_version

from .optimizer import LowBitOptimizer
from ..functional import vectorwise_dequant, vectorwise_quant

__all__ = ["Adafactor"]


class Adafactor(LowBitOptimizer):
    """Adafactor optimizer with optional low-bit (quantized) optimizer state.

    Parameters with two or more dimensions use the factored second-moment
    estimate (row and column statistics); other parameters keep a full
    second-moment tensor. Depending on the flags below, the first moment,
    the non-factored second moment and a copy of the model weights are kept
    in quantized form via ``vectorwise_quant`` / ``vectorwise_dequant``.

    Args:
        params: Parameters or parameter groups, forwarded to the base class.
        lr: Manual learning rate. Must be ``None`` when ``relative_step`` is True.
        eps: Pair ``(eps1, eps2)``. ``eps1`` is added to the squared gradient;
            ``eps2`` is the lower bound on the parameter scale used when
            ``scale_parameter`` is True.
        clip_threshold: The update is divided by
            ``max(1, rms(update) / clip_threshold)``.
        decay_rate: Exponent used for the second-moment decay,
            ``beta2t = 1 - step ** decay_rate``.
        beta1: First-moment coefficient. ``None`` disables the first moment.
        weight_decay: Weight decay coefficient (must be non-negative).
        scale_parameter: If True, scale the step size by
            ``max(eps2, rms(parameter))``.
        relative_step: If True, derive the step size from the step count
            instead of ``lr``.
        warmup_init: If True, use ``1e-6 * step`` instead of ``1e-2`` as the
            cap on the relative step size. Requires ``relative_step``.
        qconfig: Quantization configuration, forwarded to the base class.
        is_adafactor_quantized: If True, the non-factored second moment is
            stored quantized.
        is_model_quantized: If True, a quantized copy of each parameter is
            kept in the optimizer state and is the source of truth for updates.
        use_error_feedback: If True (and the model is quantized), the
            quantization residual is stored and added back before the next
            quantization.
        fused: Stored in the group defaults; not read by this class.

    Raises:
        ValueError: On an invalid combination of ``lr`` / ``relative_step`` /
            ``warmup_init`` or a negative ``weight_decay``.
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
        qconfig=None,
        is_adafactor_quantized=True, # apply quantization to low-rank approximation of second moments
        is_model_quantized=True,
        use_error_feedback=False,
        *,
        fused: Optional[bool] = False,
    ):
        use_first_moment = beta1 is not None

        require_version("torch>=1.5.0")  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step=relative_step,
            warmup_init=warmup_init,
            use_first_moment=use_first_moment,
            is_adafactor_quantized=is_adafactor_quantized,
            is_model_quantized=is_model_quantized,
            use_error_feedback=use_error_feedback,
            fused=fused,
        )
        super().__init__(params, defaults, qconfig)

    def __setstate__(self, state):
        """Restore optimizer state, back-filling ``fused`` and tensor-valued ``step``."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("fused", None)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            # Older checkpoints may store `step` as a plain number; convert to tensor.
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def get_subqconfig(self, optimizer_state_name):
        """Return the quantization sub-config for the given state name.

        Args:
            optimizer_state_name: One of ``'exp_avg'``, ``'model'`` or ``'exp_avg_sq'``.

        Raises:
            ValueError: If the state name is not recognized.
        """
        if optimizer_state_name in ['exp_avg', 'model']:
            return self.qconfig.QUANT.M
        elif optimizer_state_name == 'exp_avg_sq': # we can make this better if we separate exp_avg_sq and exp_avg_sq_factored
            return self.qconfig.QUANT.SQM    # SQM
        else:
            raise ValueError(
                f"Unknown optimizer state name {optimizer_state_name!r}; "
                "expected one of 'exp_avg', 'model', 'exp_avg_sq'"
            )

    @staticmethod
    def _get_lr(param_group, param_state):
        """Compute the effective step size for one parameter.

        Uses the relative step size (capped by ``1e-2``, or ``1e-6 * step``
        with warmup) when ``relative_step`` is set, otherwise the group's
        ``lr``; optionally scaled by ``max(eps[1], RMS)``.
        """
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return True when the parameter shape has at least two dimensions (factored update)."""
        factored = len(param_shape) >= 2
        return factored

    @staticmethod
    def _rms(tensor):
        """Return the root-mean-square of all elements of ``tensor``."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Rebuild the inverse-sqrt second-moment estimate from row/column statistics."""
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients")
                # Do the arithmetic in fp32 for half-precision gradients.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                factored = self._get_options(group, grad.shape)
                # State initialization
                if len(state) == 0:
                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = torch.tensor(0.0)
                    if group["use_first_moment"]:
                        # Exponential moving average of gradient values.
                        # 0-dim placeholder; the full-size buffer is allocated
                        # lazily the first time the update runs.
                        state["exp_avg"] = torch.zeros((), dtype=torch.float, device=p.device)
                        self.init_qstate(p, "exp_avg")
                    # Exponential moving average of squared gradient values
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad.shape[:-1], device=p.device)
                        state["exp_avg_sq_col"] = torch.zeros(grad.shape[:-2] + grad.shape[-1:], device=p.device)
                    else:
                        if group["is_adafactor_quantized"]:
                            # 0-dim placeholder, replaced lazily in the update below.
                            state["exp_avg_sq"] = torch.zeros((), dtype=torch.float, device=p.device)
                        else:
                            state["exp_avg_sq"] = torch.zeros_like(grad)

                    if group["is_adafactor_quantized"]:
                        self.init_qstate(p, "exp_avg_sq")

                    if group["is_model_quantized"]:
                        self.init_qstate(p, "model")

                        # Store the quantized copy of the parameter together
                        # with the overhead data produced by the quantizer.
                        model_qmetadata = self.get_qmetadata_by_state_name("model")
                        qx, gen = vectorwise_quant(p, qmap=state["model_qstate"]["qmap"], shape=p.shape, **model_qmetadata)
                        state["model_qstate"]["overhead"].update(gen)
                        state["model_state"] = qx
                        if group["use_error_feedback"]:
                            state["error_feedback"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Existing state: make sure the tensors match grad's device/dtype.
                    if group["use_first_moment"]:
                        # NOTE(review): `.to(grad)` also casts dtype; if exp_avg holds a
                        # quantized tensor this may not be intended (the quantized
                        # exp_avg_sq is deliberately left untouched below) - confirm.
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        if not group["is_adafactor_quantized"]:
                            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
                    # The error-feedback buffer is only allocated when the model is
                    # quantized, so only touch it in that configuration.
                    if group["is_model_quantized"] and group["use_error_feedback"]:
                        state["error_feedback"] = state["error_feedback"].to(grad)

                # `param` is the tensor that holds the authoritative weights:
                # the quantized copy when the model is quantized, else `p` itself.
                if group["is_model_quantized"]:
                    param = state["model_state"]
                else:
                    param = p

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]
                    q_exp_avg_sq = None
                else:
                    exp_avg_sq_row = None
                    exp_avg_sq_col = None
                    q_exp_avg_sq = state["exp_avg_sq"]

                if group["use_first_moment"]:
                    q_exp_avg = state["exp_avg"]
                    exp_avg_q_enabled = self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["exp_avg_qstate"]["enable"]
                    exp_avg_q_overhead = state["exp_avg_qstate"]["overhead"]
                    exp_avg_qmap = state["exp_avg_qstate"]["qmap"]
                else:
                    q_exp_avg, exp_avg_q_enabled, exp_avg_q_overhead, exp_avg_qmap = None, None, None, None

                if group["is_adafactor_quantized"]:
                    exp_avg_sq_q_enabled = self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["exp_avg_sq_qstate"]["enable"]
                    exp_avg_sq_q_overhead = state["exp_avg_sq_qstate"]["overhead"]
                    exp_avg_sq_qmap = state["exp_avg_sq_qstate"]["qmap"]
                else:
                    exp_avg_sq_q_enabled, exp_avg_sq_q_overhead, exp_avg_sq_qmap = None, None, None

                if group["is_model_quantized"]:
                    model_q_enabled = self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["model_qstate"]["enable"]
                    model_q_overhead = state["model_qstate"]["overhead"]
                    model_qmap = state["model_qstate"]["qmap"]
                    error_feedback = state["error_feedback"] if group["use_error_feedback"] else None
                else:
                    model_q_enabled, model_q_overhead, model_qmap, error_feedback = None, None, None, None

                # update step
                state["step"] += 1

                model_qmetadata=self.get_qmetadata_by_state_name("model")
                exp_avg_qmetadata=self.get_qmetadata_by_state_name("exp_avg")
                exp_avg_sq_qmetadata=self.get_qmetadata_by_state_name("exp_avg_sq")

                if group["is_model_quantized"]:
                    # Dequantize the stored weights; the overhead dict is
                    # temporarily merged with the metadata and then emptied.
                    model_q_overhead.update(model_qmetadata)
                    dequant_param = vectorwise_dequant(param, qmap=model_qmap, shape=p.shape, **model_q_overhead)
                    model_q_overhead.clear()
                    p_data_fp32 = dequant_param
                else:
                    p_data_fp32 = param
                    if param.dtype in {torch.float16, torch.bfloat16}:
                        p_data_fp32 = p_data_fp32.float()

                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])

                if factored:
                    _single_quantized_factored_update(
                        p,
                        param,
                        p_data_fp32,
                        grad,
                        group["use_first_moment"],
                        q_exp_avg,
                        q_exp_avg_sq,
                        exp_avg_sq_row,
                        exp_avg_sq_col,
                        model_q_enabled,
                        model_q_overhead,
                        model_qmap,
                        model_qmetadata,
                        error_feedback,
                        exp_avg_q_enabled,
                        exp_avg_q_overhead,
                        exp_avg_qmap,
                        exp_avg_qmetadata,
                        exp_avg_sq_q_enabled,
                        exp_avg_sq_q_overhead,
                        exp_avg_sq_qmap,
                        exp_avg_sq_qmetadata,
                        lr,
                        group["weight_decay"],
                        group["beta1"],
                        beta2t,
                        group["eps"],
                        group["clip_threshold"],
                        group["is_adafactor_quantized"],
                        group["is_model_quantized"],
                        group["use_error_feedback"],
                        state["step"].item()
                    )
                else:
                    if group["use_first_moment"]:
                        # dequantize
                        if q_exp_avg.numel() <= 1:
                            # Placeholder state: allocate the full-size buffer now.
                            q_exp_avg.data = exp_avg = torch.zeros_like(p, memory_format=torch.preserve_format)
                        elif exp_avg_q_enabled:
                            exp_avg_q_overhead.update(exp_avg_qmetadata)
                            exp_avg = vectorwise_dequant(q_exp_avg, qmap=exp_avg_qmap, shape=p.shape, **exp_avg_q_overhead)
                            exp_avg_q_overhead.clear()
                        else:
                            exp_avg = q_exp_avg

                    if group["is_adafactor_quantized"]:
                        if q_exp_avg_sq.numel() <= 1:
                            # Placeholder state: allocate the full-size buffer now.
                            q_exp_avg_sq.data = exp_avg_sq = torch.zeros_like(p, memory_format=torch.preserve_format)
                        elif exp_avg_sq_q_enabled:
                            exp_avg_sq_q_overhead.update(exp_avg_sq_qmetadata)
                            exp_avg_sq = vectorwise_dequant(q_exp_avg_sq, qmap=exp_avg_sq_qmap, shape=p.shape, **exp_avg_sq_q_overhead)
                            exp_avg_sq_q_overhead.clear()
                        else:
                            exp_avg_sq = q_exp_avg_sq
                    else:
                        exp_avg_sq = q_exp_avg_sq

                    update = (grad**2) + group["eps"][0]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                    # Clip the update by its RMS, then apply the step size.
                    update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                    update.mul_(lr)

                    if group["use_first_moment"]:
                        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                        update = exp_avg

                    if group["weight_decay"] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                    p_data_fp32.add_(-update)

                    if group["is_model_quantized"]:
                        if group["use_error_feedback"]:
                            # Keep the exact weights so the new residual can be
                            # measured against the dequantized result.
                            p_data_fp32_pre_quant = p_data_fp32.clone().detach()
                            p_data_fp32.add_(error_feedback)

                        qx, gen = vectorwise_quant(p_data_fp32, qmap=model_qmap, shape=p.shape, **model_qmetadata)
                        param.copy_(qx)
                        model_q_overhead.update(gen)

                        # Write the dequantized weights back into the live parameter,
                        # then restore the overhead dict to the quantizer output.
                        model_q_overhead.update(model_qmetadata)
                        dequant_param = vectorwise_dequant(qx, qmap=model_qmap, shape=p.shape, **model_q_overhead)
                        model_q_overhead.clear()
                        p.copy_(dequant_param)
                        model_q_overhead.update(gen)

                        if group["use_error_feedback"]:
                            error_feedback.copy_(p_data_fp32_pre_quant - dequant_param)
                    else:
                        # fp32 parameters were updated in place; half-precision ones
                        # need the fp32 working copy written back.
                        if param.dtype in {torch.float16, torch.bfloat16}:
                            param.copy_(p_data_fp32)

                    if group["use_first_moment"]:
                        # quantize
                        if exp_avg_q_enabled:
                            qx, gen = vectorwise_quant(exp_avg, qmap=exp_avg_qmap, shape=p.shape, **exp_avg_qmetadata)
                            q_exp_avg.data = qx
                            exp_avg_q_overhead.update(gen)

                    if group["is_adafactor_quantized"]:
                        if exp_avg_sq_q_enabled:
                            qx, gen = vectorwise_quant(exp_avg_sq, qmap=exp_avg_sq_qmap, shape=p.shape, **exp_avg_sq_qmetadata)
                            q_exp_avg_sq.data = qx
                            exp_avg_sq_q_overhead.update(gen)

        return loss


def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript type inference
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    else:
        return math.sqrt(x)


def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)
    

def _rms(tensor):
    return tensor.norm(2) / (tensor.numel() ** 0.5)


def _get_lr(lr, relative_step, step, warmup_init, scale_parameter, eps, rms):
    rel_step_sz = lr
    if relative_step:
        min_step = 1e-6 * step if warmup_init else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(step))
    param_scale = 1.0
    if scale_parameter:
        param_scale = max(eps[1], rms)
    return param_scale * rel_step_sz


def _single_quantized_factored_update(
    param_with_grad,
    param,
    p_data_fp32,
    grad,
    use_first_moment,
    q_exp_avg,
    q_exp_avg_sq,
    exp_avg_sq_row,
    exp_avg_sq_col,
    model_q_enabled,
    model_q_overhead,
    model_qmap,
    model_qmetadata,
    error_feedback,
    exp_avg_q_enabled,
    exp_avg_q_overhead,
    exp_avg_qmap,
    exp_avg_qmetadata,
    exp_avg_sq_q_enabled,
    exp_avg_sq_q_overhead,
    exp_avg_sq_qmap,
    exp_avg_sq_qmetadata,
    lr,
    weight_decay,
    beta1,
    beta2t,
    eps,
    clip_threshold,
    is_adafactor_quantized,
    is_model_quantized,
    use_error_feedback,
    step,
):
    """Apply one factored Adafactor update to a single parameter, in place.

    Args:
        param_with_grad: The live parameter; receives the dequantized weights
            when ``is_model_quantized`` is set and provides the target shape.
        param: Tensor holding the authoritative weights (the quantized copy
            when the model is quantized).
        p_data_fp32: Working copy of the weights that the update is applied to.
        grad: Gradient of the parameter.
        use_first_moment: Whether to maintain the first-moment average.
        q_exp_avg: Stored first moment (possibly quantized, or a placeholder
            with at most one element before first use).
        exp_avg_sq_row, exp_avg_sq_col: Factored second-moment statistics,
            updated in place.
        model_q_overhead, model_qmap, model_qmetadata: Quantization data for
            the model weights, passed to ``vectorwise_quant`` /
            ``vectorwise_dequant``.
        error_feedback: Buffer holding the quantization residual; read and
            rewritten when ``use_error_feedback`` is set.
        exp_avg_q_enabled, exp_avg_q_overhead, exp_avg_qmap, exp_avg_qmetadata:
            Quantization switch and data for the first moment.
        lr: Step size applied to the clipped update.
        weight_decay: Weight decay coefficient.
        beta1: First-moment coefficient.
        beta2t: Second-moment decay for this step.
        eps: Pair whose first entry is added to the squared gradient.
        clip_threshold: RMS clipping threshold for the update.
        is_model_quantized: Whether the weights are stored quantized.
        use_error_feedback: Whether to apply quantization error feedback.
        q_exp_avg_sq, model_q_enabled, exp_avg_sq_q_enabled,
        exp_avg_sq_q_overhead, exp_avg_sq_qmap, exp_avg_sq_qmetadata,
        is_adafactor_quantized, step: Accepted for signature parity with the
            caller; not read by the code in this block.
    """
    if use_first_moment:
        # dequantize
        if q_exp_avg.numel() <= 1:
            # Placeholder state: allocate the full-size buffer now.
            q_exp_avg.data = exp_avg = torch.zeros_like(param_with_grad, memory_format=torch.preserve_format)
        elif exp_avg_q_enabled:
            # Merge metadata into the overhead dict for the call, then empty it;
            # it is refilled from the quantizer output at the end.
            exp_avg_q_overhead.update(exp_avg_qmetadata)
            exp_avg = vectorwise_dequant(q_exp_avg, qmap=exp_avg_qmap, shape=param_with_grad.shape, **exp_avg_q_overhead)
            exp_avg_q_overhead.clear()
        else:
            exp_avg = q_exp_avg

    update = (grad**2) + eps[0]

    # Update the row/column running averages of the squared gradient.
    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

    update = _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
    update.mul_(grad)

    # Clip the update by its RMS, then apply the step size.
    update.div_((_rms(update) / clip_threshold).clamp_(min=1.0))
    update.mul_(lr)

    if use_first_moment:
        exp_avg.mul_(beta1).add_(update, alpha=(1 - beta1))
        update = exp_avg

    if weight_decay != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-weight_decay * lr))

    p_data_fp32.add_(-update)

    if is_model_quantized:
        if use_error_feedback:
            # Keep the exact weights so the new residual can be measured
            # against the dequantized result.
            p_data_fp32_pre_quant = p_data_fp32.clone().detach()
            p_data_fp32.add_(error_feedback)

        qx, gen = vectorwise_quant(p_data_fp32, qmap=model_qmap, shape=param_with_grad.shape, **model_qmetadata)
        param.copy_(qx)
        model_q_overhead.update(gen)

        # Write the dequantized weights back into the live parameter, then
        # restore the overhead dict to the quantizer output.
        model_q_overhead.update(model_qmetadata)
        dequant_param = vectorwise_dequant(qx, qmap=model_qmap, shape=param_with_grad.shape, **model_q_overhead)
        model_q_overhead.clear()
        param_with_grad.copy_(dequant_param)
        model_q_overhead.update(gen)

        if use_error_feedback:
            error_feedback.copy_(p_data_fp32_pre_quant - dequant_param)
    else:
        # fp32 weights were updated in place; half-precision ones need the
        # fp32 working copy written back.
        if param.dtype in {torch.float16, torch.bfloat16}:
            param.copy_(p_data_fp32)

    if use_first_moment:
        # quantize
        if exp_avg_q_enabled:
            qx, gen = vectorwise_quant(exp_avg, qmap=exp_avg_qmap, shape=param_with_grad.shape, **exp_avg_qmetadata)
            q_exp_avg.data = qx
            exp_avg_q_overhead.update(gen)