import torch
import math

from torch import Tensor
from typing import List, Optional, Tuple
import time
import os

import matplotlib.pyplot as plt
import seaborn as sns

from transformers.utils.versions import require_version

from .optimizer import LowBitOptimizer
from ..functional import vectorwise_dequant, vectorwise_quant, create_normal_map

__all__ = ["Adafactor"]


class Adafactor(LowBitOptimizer):
    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
        qconfig=None,
        is_adafactor_quantized=True, # apply quantization to low-rank approximation of second moments
        is_model_quantized=True,
        use_error_feedback=False,
        use_adaptive_qmap=True,
        use_sparse_dense=True,
        embedding_not_quantized=True,
        use_intermediate_as_sparse=True, # normally, intermediate values are dense but we use these values as sparse for full precision.
        steps_in_epoch=None,
        *,
        fused: Optional[bool] = False,
    ):
        """Validate the hyper-parameters and register them as group defaults.

        Args:
            params: Parameters (or parameter groups) forwarded to the base class.
            lr: Manual learning rate; must be ``None`` when ``relative_step`` is on.
            eps: Pair ``(eps0, eps1)``. ``eps0`` is added to the squared gradient
                in the update; ``eps1`` is the lower bound of the parameter scale
                in ``_get_lr``.
            clip_threshold: The update is divided by
                ``max(1, rms(update) / clip_threshold)``.
            decay_rate: Exponent used as ``beta2t = 1 - step ** decay_rate``.
            beta1: First-moment decay; ``None`` disables the first moment.
            weight_decay: Non-negative weight-decay coefficient.
            scale_parameter: Scale the step size by the parameter RMS.
            relative_step: Derive the step size from the step count.
            warmup_init: Use the step-proportional cap in ``_get_lr``; requires
                ``relative_step``.
            qconfig: Quantization config forwarded to the base class.
            is_adafactor_quantized: Quantize the second-moment state.
            is_model_quantized: Keep a quantized copy of the weights.
            use_error_feedback: Track a quantization-error buffer per parameter.
            use_adaptive_qmap: Rebuild the model quantization map from the
                current weight statistics.
            use_sparse_dense: Split the weights into a sparse part kept as-is and
                a dense part that is quantized.
            embedding_not_quantized: Leave the first factored parameter of each
                group unquantized (see ``step``).
            use_intermediate_as_sparse: Also route the band returned by
                ``find_range_intermediate`` into the sparse part.
            steps_in_epoch: Modulus used by the update code to decide when the
                thresholds are re-estimated.
            fused: Stored in the group defaults only.

        Raises:
            ValueError: On an inconsistent ``lr``/``relative_step``/``warmup_init``
                combination or a negative ``weight_decay``.
        """
        use_first_moment = beta1 is not None

        require_version("torch>=1.5.0")  # add_ with alpha

        # Argument validation, in the same order as the checks are reported.
        manual_lr_given = lr is not None
        if manual_lr_given and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        # Per-group options; key order is what ends up in every param group.
        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
            "use_first_moment": use_first_moment,
            "is_adafactor_quantized": is_adafactor_quantized,
            "is_model_quantized": is_model_quantized,
            "use_error_feedback": use_error_feedback,
            "use_adaptive_qmap": use_adaptive_qmap,
            "use_sparse_dense": use_sparse_dense,
            "fused": fused,
        }
        super().__init__(params, defaults, qconfig)

        # Optimizer-wide (not per-group) options.
        self.steps_in_epoch = steps_in_epoch
        self.embedding_not_quantized = embedding_not_quantized
        self.use_intermediate_as_sparse = use_intermediate_as_sparse

    def __setstate__(self, state):
        """Restore optimizer state and normalise fields of older checkpoints.

        Adds a ``"fused"`` entry (``None``) to groups that lack one and, unless
        the first per-parameter state already stores ``step`` as a tensor,
        converts every stored ``step`` to a float tensor.
        """
        super().__setstate__(state)
        for param_group in self.param_groups:
            param_group.setdefault("fused", None)

        per_param_states = list(self.state.values())
        # Only the first entry is inspected to decide whether conversion is needed.
        first_step_is_tensor = bool(per_param_states) and torch.is_tensor(
            per_param_states[0]["step"]
        )
        if first_step_is_tensor:
            return
        for param_state in per_param_states:
            param_state["step"] = torch.tensor(float(param_state["step"]))

    def get_subqconfig(self, optimizer_state_name):
        if optimizer_state_name in ['exp_avg', 'model']:
            return self.qconfig.QUANT.M
        elif optimizer_state_name == 'exp_avg_sq': # we can make this better if we separate exp_avg_sq and exp_avg_sq_factored
            return self.qconfig.QUANT.SQM    # SQM
        else:
            raise ValueError(
                f""
            )

    def update_qstate_qmap(self, p, state_name, mu, sigma, min_weight=None, max_weight=None): # update qstate qmap for adaptive normal mapping
        """Rebuild the adaptive normal quantization map for one state of ``p``.

        A new map is created with ``create_normal_map`` from ``mu``/``sigma``
        (and the optional weight bounds), moved to ``p.device``, stored in
        ``self.qmaps`` under the ``(quant_type, b, signed)`` key of the state's
        quantization metadata, and attached to ``state["<state_name>_qstate"]["qmap"]``.

        Note that the ``self.qmaps`` entry is keyed by metadata only, so it is
        overwritten by every call that shares the same metadata.
        """
        param_state = self.state[p]
        qstate_key = f"{state_name}_qstate"

        metadata = self.get_qmetadata_by_state_name(state_name)
        cache_key = (metadata['quant_type'], metadata['b'], metadata['signed'])

        normal_map = create_normal_map(
            offset=0.995,
            use_adaptive_map=True,
            total_bits=metadata['b'],
            mu=mu,
            sigma=sigma,
            min_weight=min_weight,
            max_weight=max_weight,
        )
        self.qmaps[cache_key] = normal_map
        # Replace the cached entry with its copy on the parameter's device.
        device_map = self.qmaps[cache_key].to(p.device)
        self.qmaps[cache_key] = device_map
        param_state[qstate_key]["qmap"] = self.qmaps[cache_key]

    def find_range_except_outliers(self, p_tensor, r=0.5):
        """
        Find the r% largest and r% smallest values in a PyTorch tensor.

        Args:
        p_tensor (torch.Tensor): The input tensor.

        Returns:
        tuple: A tuple containing the r% largest and r% smallest values.
        """
        # Flatten the tensor to 1D
        p_tensor_flat = p_tensor.flatten()
        if p_tensor_flat.dtype in {torch.float16, torch.bfloat16}: # Does it affect to memory & performance?
            p_tensor_flat = p_tensor_flat.float()

        # Calculate the indices for r% largest and smallest
        k_lower = int(r * 0.01 * len(p_tensor_flat))
        k_upper = int((1 - r * 0.01) * len(p_tensor_flat))

        # Find the r% largest value
        upper_value = torch.kthvalue(p_tensor_flat, k_upper).values

        # Find the r% smallest value
        lower_value = torch.kthvalue(p_tensor_flat, k_lower).values

        return lower_value, upper_value

    def find_range_intermediate(self, p_tensor, r=0.25):
        # Flatten the tensor to 1D
        p_tensor_flat = p_tensor.flatten()
        if p_tensor_flat.dtype in {torch.float16, torch.bfloat16}: # Does it affect to memory & performance?
            p_tensor_flat = p_tensor_flat.float()

        k_lower = int((0.5 - r * 0.01) * len(p_tensor_flat))
        k_upper = int((0.5 +  r * 0.01) * len(p_tensor_flat))

        inter_upper_value = torch.kthvalue(p_tensor_flat, k_upper).values

        inter_lower_value = torch.kthvalue(p_tensor_flat, k_lower).values

        return (inter_lower_value, inter_upper_value)

    def sparse_dense_decomposition(self, p_tensor, lower_value, upper_value, intermediate_values=None):
        """
        Decompose a tensor into sparse and dense components while preserving gradient tracking.

        Args:
        p_tensor (torch.Tensor): The input tensor.
        lower_value (float): The threshold below which values are included in the sparse tensor.
        upper_value (float): The threshold above which values are included in the sparse tensor.

        Returns:
        tuple: A tuple containing the sparse tensor and the dense tensor.
        """
        # Create a mask for values that are either below the lower_value or above the upper_value
        if not intermediate_values:
            sparse_mask = (p_tensor < lower_value) | (p_tensor > upper_value)
        else:
            sparse_mask = (p_tensor < lower_value) | ((p_tensor > intermediate_values[0]) & (p_tensor < intermediate_values[1])) | (p_tensor > upper_value)

        # Create the sparse tensor by cloning and then zeroing out elements not in the sparse mask
        s_tensor = p_tensor.clone()
        s_tensor[~sparse_mask] = 0.0

        # Create the dense tensor by cloning and then zeroing out elements in the sparse mask
        d_tensor = p_tensor.clone()
        d_tensor[sparse_mask] = 0.0

        return s_tensor, d_tensor

    def plot_weight_distribution(
        self,
        p_data_tensor,
        quantized_values,
        step,
        dir_name='plots',
        file_name='weight_distribution_webglm_nf4_q_proj_embedding_no_quant_intermediate_sparse.png',
    ):
        """Save a log-scale histogram of the weights with the quantization levels marked.

        The figure is written to ``<cwd>/<dir_name>/<file_name>``; the directory
        is created when missing and an existing file of the same name is
        overwritten.

        Args:
            p_data_tensor (torch.Tensor): Weights to histogram (any shape).
            quantized_values (torch.Tensor): Quantization levels, drawn as red
                markers near the bottom of the plot.
            step: Current optimizer step. Accepted for the caller's convenience;
                it is not used in the plot or the file name.
            dir_name (str): Output directory, relative to the working directory.
            file_name (str): Name of the image file inside ``dir_name``.
        """
        output_dir = os.path.join(os.getcwd(), dir_name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)

        weights = p_data_tensor.to(torch.float32).detach().cpu().numpy().flatten()
        levels = quantized_values.to(torch.float32).detach().cpu().numpy().ravel()

        try:
            plt.hist(weights, bins=500, color='green', log=True, density=True)
            plt.title('Distribution of Weights')
            plt.xlabel('Weight Value')
            plt.ylabel('Density')

            # Place the markers at 1e-6 of the top of the (log-scaled) y-axis.
            ymin, ymax = plt.ylim()
            marker_y = ymax * 1e-6

            if levels.size:
                # One vectorized call each instead of one artist per level.
                plt.scatter(levels, [marker_y] * levels.size, color='red', s=3)
                plt.vlines(levels, ymin, marker_y, colors='red', linestyles='--', linewidth=0.5)

            plt.savefig(os.path.join(output_dir, file_name))
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close()

    @staticmethod
    def _get_lr(param_group, param_state):
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        factored = len(param_shape) >= 2
        # use_first_moment = param_group["beta1"] is not None
        return factored
    
    @staticmethod
    def _rms(tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        For every parameter with a gradient this method lazily creates the
        optimizer state on first use, dequantizes the stored weights (when the
        model is quantized), applies the Adafactor update and re-quantizes the
        result. Parameters with at least two dimensions are delegated to
        ``_single_quantized_factored_update``; the rest are handled inline
        with an unfactored second moment.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for idx, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients")
                # All update math below runs on a float32 gradient.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                factored = self._get_options(group, grad.shape)
                # The first parameter of a group is treated as the embedding
                # when it is factored (>= 2 dims) and is left unquantized.
                # NOTE(review): this relies on the embedding being at index 0
                # of the group - confirm against how the groups are built.
                is_embedding_and_not_quantized = (idx == 0) and factored and self.embedding_not_quantized
                # State initialization
                if len(state) == 0:
                    # `step` is created without a device argument.
                    state["step"] = torch.tensor(0.0)
                    if group["use_first_moment"]:
                        # 0-dim placeholder; the full-size buffer is allocated
                        # on the first update (see the numel() <= 1 checks).
                        state["exp_avg"] = torch.zeros((), dtype=torch.float, device=p.device)
                        self.init_qstate(p, "exp_avg")
                    # Exponential moving average of squared gradient values
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad.shape[:-1], device=p.device)
                        state["exp_avg_sq_col"] = torch.zeros(grad.shape[:-2] + grad.shape[-1:], device=p.device)
                    else:
                        if group["is_adafactor_quantized"]:
                            # Placeholder, replaced lazily like exp_avg.
                            state["exp_avg_sq"] = torch.zeros((), dtype=torch.float, device=p.device)
                        else:
                            state["exp_avg_sq"] = torch.zeros_like(grad)

                    if group["is_adafactor_quantized"]:
                        self.init_qstate(p, "exp_avg_sq")

                    if group["is_model_quantized"] and not is_embedding_and_not_quantized:
                        self.init_qstate(p, "model")

                        model_qmetadata = self.get_qmetadata_by_state_name("model")
                        if group["use_adaptive_qmap"]:
                            mu, sigma = torch.mean(p).item(), torch.std(p).item()
                            lower_value, upper_value = self.find_range_except_outliers(p)
                            state["lower_threshold"] = lower_value
                            state["upper_threshold"] = upper_value
                            if group["use_sparse_dense"] and self.use_intermediate_as_sparse:
                                intermediate_values = self.find_range_intermediate(p)
                                state["intermediate_lower_threshold"] = intermediate_values[0]
                                state["intermediate_upper_threshold"] = intermediate_values[1]
                            if group["use_sparse_dense"]: # this can be more memory-efficient later
                                if self.use_intermediate_as_sparse:
                                    s_tensor, d_tensor = self.sparse_dense_decomposition(p, lower_value, upper_value, intermediate_values)
                                else:
                                    s_tensor, d_tensor = self.sparse_dense_decomposition(p, lower_value, upper_value)
                                # Only the dense part is quantized; the sparse
                                # part is kept as a separate tensor.
                                p.copy_(d_tensor)
                                state["sparse_model_state"] = s_tensor
                            else:
                                p.clamp_(lower_value, upper_value)
                            self.update_qstate_qmap(p, "model", mu, sigma)

                        qx, gen = vectorwise_quant(p, qmap=state["model_qstate"]["qmap"], shape=p.shape, **model_qmetadata)
                        state["model_qstate"]["overhead"].update(gen)
                        state["model_state"] = qx
                        if group["use_error_feedback"]:
                            state["error_feedback"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Keep the stored buffers on the gradient's device/dtype.
                    # NOTE(review): `exp_avg` is cast even when it may hold
                    # quantized data, unlike `exp_avg_sq` - confirm intended.
                    if group["use_first_moment"]:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        if not group["is_adafactor_quantized"]:
                            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
                    # The buffer only exists for quantized parameters, so an
                    # unquantized one (e.g. the embedding) must be skipped.
                    if group["use_error_feedback"] and "error_feedback" in state:
                        state["error_feedback"] = state["error_feedback"].to(grad)

                if group["is_model_quantized"] and not is_embedding_and_not_quantized:
                    param = state["model_state"]
                    # NOTE(review): "sparse_model_state" is only created when
                    # use_adaptive_qmap is on; use_sparse_dense without it
                    # raises KeyError here.
                    sparse_param = state["sparse_model_state"] if group["use_sparse_dense"] else None
                    lower_threshold = state["lower_threshold"] if group["use_adaptive_qmap"] else None
                    upper_threshold = state["upper_threshold"] if group["use_adaptive_qmap"] else None
                    # The intermediate thresholds are stored only when the
                    # adaptive map and sparse/dense split are both enabled, so
                    # fall back to None instead of raising when they are absent.
                    intermediate_lower_threshold = state.get("intermediate_lower_threshold") if self.use_intermediate_as_sparse else None
                    intermediate_upper_threshold = state.get("intermediate_upper_threshold") if self.use_intermediate_as_sparse else None
                else:
                    param = p
                    sparse_param, lower_threshold, upper_threshold, intermediate_lower_threshold, intermediate_upper_threshold = None, None, None, None, None

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]
                    q_exp_avg_sq = None
                else:
                    exp_avg_sq_row = None
                    exp_avg_sq_col = None
                    q_exp_avg_sq = state["exp_avg_sq"]

                if group["use_first_moment"]:
                    q_exp_avg = state["exp_avg"]
                    exp_avg_q_enabled = self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["exp_avg_qstate"]["enable"]
                    exp_avg_q_overhead = state["exp_avg_qstate"]["overhead"]
                    exp_avg_qmap = state["exp_avg_qstate"]["qmap"]
                else:
                    q_exp_avg, exp_avg_q_enabled, exp_avg_q_overhead, exp_avg_qmap = None, None, None, None

                if group["is_adafactor_quantized"]:
                    exp_avg_sq_q_enabled = self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["exp_avg_sq_qstate"]["enable"]
                    exp_avg_sq_q_overhead = state["exp_avg_sq_qstate"]["overhead"]
                    exp_avg_sq_qmap = state["exp_avg_sq_qstate"]["qmap"]
                else:
                    exp_avg_sq_q_enabled, exp_avg_sq_q_overhead, exp_avg_sq_qmap = None, None, None

                if group["is_model_quantized"] and not is_embedding_and_not_quantized:
                    model_q_enabled = self.override_q_enable[id(p)] if id(p) in self.override_q_enable else state["model_qstate"]["enable"]
                    model_q_overhead = state["model_qstate"]["overhead"]
                    model_qmap = state["model_qstate"]["qmap"]
                    error_feedback = state["error_feedback"] if group["use_error_feedback"] else None
                else:
                    model_q_enabled, model_q_overhead, model_qmap, error_feedback = None, None, None, None

                # update step
                state["step"] += 1

                model_qmetadata=self.get_qmetadata_by_state_name("model")
                exp_avg_qmetadata=self.get_qmetadata_by_state_name("exp_avg")
                exp_avg_sq_qmetadata=self.get_qmetadata_by_state_name("exp_avg_sq")

                if group["is_model_quantized"] and not is_embedding_and_not_quantized:
                    # Rebuild the full-precision weights: dequantized dense
                    # part plus (optionally) the separately stored sparse part.
                    model_q_overhead.update(model_qmetadata)
                    dequant_param = vectorwise_dequant(param, qmap=model_qmap, shape=p.shape, **model_q_overhead)
                    if group["use_sparse_dense"]:
                        dequant_param.add_(sparse_param)
                    model_q_overhead.clear()
                    p_data_fp32 = dequant_param
                    if dequant_param.dtype in {torch.float16, torch.bfloat16}:
                        p_data_fp32 = p_data_fp32.float()
                else:
                    p_data_fp32 = param
                    if param.dtype in {torch.float16, torch.bfloat16}:
                        p_data_fp32 = p_data_fp32.float()

                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])

                if factored:
                    self._single_quantized_factored_update(
                        state,
                        p,
                        param,
                        sparse_param,
                        p_data_fp32,
                        grad,
                        group["use_first_moment"],
                        q_exp_avg,
                        q_exp_avg_sq,
                        exp_avg_sq_row,
                        exp_avg_sq_col,
                        model_q_enabled,
                        model_q_overhead,
                        model_qmap,
                        model_qmetadata,
                        error_feedback,
                        exp_avg_q_enabled,
                        exp_avg_q_overhead,
                        exp_avg_qmap,
                        exp_avg_qmetadata,
                        exp_avg_sq_q_enabled,
                        exp_avg_sq_q_overhead,
                        exp_avg_sq_qmap,
                        exp_avg_sq_qmetadata,
                        lower_threshold,
                        upper_threshold,
                        intermediate_lower_threshold,
                        intermediate_upper_threshold,
                        lr,
                        group["weight_decay"],
                        group["beta1"],
                        beta2t,
                        group["eps"],
                        group["clip_threshold"],
                        group["is_adafactor_quantized"],
                        group["is_model_quantized"],
                        group["use_error_feedback"],
                        group["use_adaptive_qmap"],
                        group["use_sparse_dense"],
                        is_embedding_and_not_quantized,
                        state["step"].item(),
                        idx
                    )
                else:
                    if group["use_first_moment"]:
                        # dequantize
                        if q_exp_avg.numel() <= 1:
                            # First update: swap the placeholder for a real buffer.
                            q_exp_avg.data = exp_avg = torch.zeros_like(p, memory_format=torch.preserve_format)
                        elif exp_avg_q_enabled:
                            exp_avg_q_overhead.update(exp_avg_qmetadata)
                            exp_avg = vectorwise_dequant(q_exp_avg, qmap=exp_avg_qmap, shape=p.shape, **exp_avg_q_overhead)
                            exp_avg_q_overhead.clear()
                        else:
                            exp_avg = q_exp_avg

                    if group["is_adafactor_quantized"]:
                        if q_exp_avg_sq.numel() <= 1:
                            q_exp_avg_sq.data = exp_avg_sq = torch.zeros_like(p, memory_format=torch.preserve_format)
                        elif exp_avg_sq_q_enabled:
                            exp_avg_sq_q_overhead.update(exp_avg_sq_qmetadata)
                            exp_avg_sq = vectorwise_dequant(q_exp_avg_sq, qmap=exp_avg_sq_qmap, shape=p.shape, **exp_avg_sq_q_overhead)
                            exp_avg_sq_q_overhead.clear()
                        else:
                            exp_avg_sq = q_exp_avg_sq
                    else:
                        exp_avg_sq = q_exp_avg_sq

                    update = (grad**2) + group["eps"][0]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                    # Clip the update so its RMS does not exceed clip_threshold.
                    update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                    update.mul_(lr)

                    if group["use_first_moment"]:
                        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                        update = exp_avg

                    if group["weight_decay"] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                    p_data_fp32.add_(-update)

                    # An unfactored parameter is never the unquantized embedding
                    # (that requires `factored`), so the group flag suffices.
                    if group["is_model_quantized"]:
                        if group["use_error_feedback"]:
                            p_data_fp32_pre_quant = p_data_fp32.clone().detach()
                            p_data_fp32.add_(error_feedback)
                        if group["use_adaptive_qmap"]:
                            mu, sigma = torch.mean(p_data_fp32).item(), torch.std(p_data_fp32).item()
                            # Re-estimate the thresholds once per epoch; with
                            # steps_in_epoch=None the initial ones are kept.
                            if self.steps_in_epoch is not None and state["step"] % self.steps_in_epoch == 1:
                                lower_value, upper_value = self.find_range_except_outliers(p_data_fp32)
                                lower_threshold.copy_(lower_value)
                                upper_threshold.copy_(upper_value)
                                if group["use_sparse_dense"] and self.use_intermediate_as_sparse:
                                    intermediate_values = self.find_range_intermediate(p_data_fp32)
                                    intermediate_lower_threshold.copy_(intermediate_values[0])
                                    intermediate_upper_threshold.copy_(intermediate_values[1])
                            if group["use_sparse_dense"]: # this can be more memory-efficient later
                                if self.use_intermediate_as_sparse:
                                    s_tensor, d_tensor = self.sparse_dense_decomposition(p_data_fp32, lower_threshold, upper_threshold, (intermediate_lower_threshold, intermediate_upper_threshold))
                                else:
                                    s_tensor, d_tensor = self.sparse_dense_decomposition(p_data_fp32, lower_threshold, upper_threshold)
                                p_data_fp32.copy_(d_tensor)
                                sparse_param.copy_(s_tensor)
                            else:
                                p_data_fp32.clamp_(lower_threshold, upper_threshold)
                            self.update_qstate_qmap(p, "model", mu, sigma)
                            model_qmap = state["model_qstate"]["qmap"]

                        qx, gen = vectorwise_quant(p_data_fp32, qmap=model_qmap, shape=p.shape, **model_qmetadata)
                        param.copy_(qx)
                        model_q_overhead.update(gen)

                        # Write back what the quantized weights decode to, so
                        # the forward pass sees the quantization error.
                        model_q_overhead.update(model_qmetadata)
                        dequant_param = vectorwise_dequant(qx, qmap=model_qmap, shape=p.shape, **model_q_overhead)
                        if group["use_sparse_dense"]:
                            dequant_param.add_(sparse_param)
                        model_q_overhead.clear()
                        p.copy_(dequant_param)
                        # Keep only the generated fields for the next dequantization.
                        model_q_overhead.update(gen)

                        if group["use_error_feedback"]:
                            error_feedback.copy_(p_data_fp32_pre_quant - dequant_param)
                    else:
                        # float32 parameters were updated in place already.
                        if param.dtype in {torch.float16, torch.bfloat16}:
                            param.copy_(p_data_fp32)

                    if group["use_first_moment"]:
                        # quantize
                        if exp_avg_q_enabled:
                            qx, gen = vectorwise_quant(exp_avg, qmap=exp_avg_qmap, shape=p.shape, **exp_avg_qmetadata)
                            q_exp_avg.data = qx
                            exp_avg_q_overhead.update(gen)

                    if group["is_adafactor_quantized"]:
                        if exp_avg_sq_q_enabled:
                            qx, gen = vectorwise_quant(exp_avg_sq, qmap=exp_avg_sq_qmap, shape=p.shape, **exp_avg_sq_qmetadata)
                            q_exp_avg_sq.data = qx
                            exp_avg_sq_q_overhead.update(gen)

        return loss
    
    def _single_quantized_factored_update(
        self,
        state,
        param_with_grad,
        param,
        sparse_param,
        p_data_fp32,
        grad,
        use_first_moment,
        q_exp_avg,
        q_exp_avg_sq,
        exp_avg_sq_row,
        exp_avg_sq_col,
        model_q_enabled,
        model_q_overhead,
        model_qmap,
        model_qmetadata,
        error_feedback,
        exp_avg_q_enabled,
        exp_avg_q_overhead,
        exp_avg_qmap,
        exp_avg_qmetadata,
        exp_avg_sq_q_enabled,
        exp_avg_sq_q_overhead,
        exp_avg_sq_qmap,
        exp_avg_sq_qmetadata,
        lower_threshold,
        upper_threshold,
        intermediate_lower_threshold,
        intermediate_upper_threshold,
        lr,
        weight_decay,
        beta1,
        beta2t,
        eps,
        clip_threshold,
        is_adafactor_quantized,
        is_model_quantized,
        use_error_feedback,
        use_adaptive_qmap,
        use_sparse_dense,
        is_embedding_and_not_quantized,
        step,
        idx
    ):
        """Apply one Adafactor update with a factored second moment.

        Updates the row/column second-moment statistics in place, builds the
        clipped and scaled update, applies it (plus weight decay) to
        ``p_data_fp32`` and then either re-quantizes the weights or copies
        them back to a half-precision parameter.

        Args:
            state: Optimizer state dict of ``param_with_grad``.
            param_with_grad: The model parameter; receives the dequantized weights.
            param: Quantized weight storage, or the parameter itself when the
                weights are not quantized.
            sparse_param: Sparse part of the weights (sparse/dense mode) or ``None``.
            p_data_fp32: Full-precision working copy of the weights; modified in place.
            grad: Gradient of the parameter.
            use_first_moment, q_exp_avg, exp_avg_q_enabled, exp_avg_q_overhead,
            exp_avg_qmap, exp_avg_qmetadata: First-moment flag, storage and its
                quantization settings.
            exp_avg_sq_row, exp_avg_sq_col: Factored second-moment statistics;
                modified in place.
            model_q_overhead, model_qmap, model_qmetadata: Quantization settings
                of the weights.
            error_feedback: Quantization-error buffer or ``None``.
            lower_threshold, upper_threshold, intermediate_lower_threshold,
            intermediate_upper_threshold: Threshold tensors of the adaptive map;
                refreshed in place when a refresh is due.
            lr, weight_decay, beta1, beta2t, eps, clip_threshold: Update
                hyper-parameters for this step.
            is_model_quantized, use_error_feedback, use_adaptive_qmap,
            use_sparse_dense, is_embedding_and_not_quantized: Feature switches.
            step: Current step count as a Python number.
            idx: Index of the parameter inside its group.

        ``q_exp_avg_sq``, ``model_q_enabled``, ``exp_avg_sq_q_enabled``,
        ``exp_avg_sq_q_overhead``, ``exp_avg_sq_qmap``, ``exp_avg_sq_qmetadata``
        and ``is_adafactor_quantized`` are accepted but not used here.
        """
        if use_first_moment:
            # dequantize
            if q_exp_avg.numel() <= 1:
                # First update: swap the placeholder for a full-size buffer.
                q_exp_avg.data = exp_avg = torch.zeros_like(param_with_grad, memory_format=torch.preserve_format)
            elif exp_avg_q_enabled:
                exp_avg_q_overhead.update(exp_avg_qmetadata)
                exp_avg = vectorwise_dequant(q_exp_avg, qmap=exp_avg_qmap, shape=param_with_grad.shape, **exp_avg_q_overhead)
                exp_avg_q_overhead.clear()
            else:
                exp_avg = q_exp_avg

        update = (grad**2) + eps[0]

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)

        # Clip the update so its RMS does not exceed clip_threshold.
        update.div_((self._rms(update) / clip_threshold).clamp_(min=1.0))
        update.mul_(lr)

        if use_first_moment:
            exp_avg.mul_(beta1).add_(update, alpha=(1 - beta1))
            update = exp_avg

        if weight_decay != 0:
            p_data_fp32.add_(p_data_fp32, alpha=(-weight_decay * lr))

        p_data_fp32.add_(-update)

        quantizes_model = is_model_quantized and not is_embedding_and_not_quantized

        if quantizes_model:
            if use_error_feedback:
                p_data_fp32_pre_quant = p_data_fp32.clone().detach()
                p_data_fp32.add_(error_feedback)
            if use_adaptive_qmap:
                mu, sigma = torch.mean(p_data_fp32).item(), torch.std(p_data_fp32).item()
                # Re-estimate the thresholds once per epoch; with
                # steps_in_epoch=None the previously stored ones are kept.
                if self.steps_in_epoch is not None and step % self.steps_in_epoch == 1:
                    lower_value, upper_value = self.find_range_except_outliers(p_data_fp32)
                    lower_threshold.copy_(lower_value)
                    upper_threshold.copy_(upper_value)
                    if use_sparse_dense and self.use_intermediate_as_sparse:
                        intermediate_values = self.find_range_intermediate(p_data_fp32)
                        intermediate_lower_threshold.copy_(intermediate_values[0])
                        intermediate_upper_threshold.copy_(intermediate_values[1])
                if use_sparse_dense: # this can be more memory-efficient later
                    if self.use_intermediate_as_sparse:
                        s_tensor, d_tensor = self.sparse_dense_decomposition(p_data_fp32, lower_threshold, upper_threshold, (intermediate_lower_threshold, intermediate_upper_threshold))
                    else:
                        s_tensor, d_tensor = self.sparse_dense_decomposition(p_data_fp32, lower_threshold, upper_threshold)
                    # Only the dense part is quantized below.
                    p_data_fp32.copy_(d_tensor)
                    sparse_param.copy_(s_tensor)
                else:
                    p_data_fp32.clamp_(lower_threshold, upper_threshold)
                self.update_qstate_qmap(param_with_grad, "model", mu, sigma)
                model_qmap = state["model_qstate"]["qmap"]

            qx, gen = vectorwise_quant(p_data_fp32, qmap=model_qmap, shape=param_with_grad.shape, **model_qmetadata)
            param.copy_(qx)
            model_q_overhead.update(gen)

            # Write back what the quantized weights decode to.
            model_q_overhead.update(model_qmetadata)
            dequant_param = vectorwise_dequant(qx, qmap=model_qmap, shape=param_with_grad.shape, **model_q_overhead)
            if use_sparse_dense:
                dequant_param.add_(sparse_param)
            model_q_overhead.clear()
            param_with_grad.copy_(dequant_param)
            # Keep only the generated fields for the next dequantization.
            model_q_overhead.update(gen)

            if use_error_feedback:
                error_feedback.copy_(p_data_fp32_pre_quant - dequant_param)
        else:
            # float32 parameters were updated in place already.
            if param.dtype in {torch.float16, torch.bfloat16}:
                param.copy_(p_data_fp32)

        # Debug plot for the parameter at index 1, once per epoch. `gen` and
        # `model_qmap` only exist when the weights were quantized above, so the
        # block is skipped otherwise instead of failing on an unbound name.
        if quantizes_model and idx == 1:
            if self.steps_in_epoch is not None and step % self.steps_in_epoch == 1:
                # NOTE(review): 401 is a hard-coded row of gen['max1'] taken
                # from an earlier argmax experiment - confirm it is in range
                # for other models.
                quantized_values = gen['max1'][401] * model_qmap
                print('quantized values are {}'.format(quantized_values.to(torch.float32).detach().cpu().numpy()))
                if use_sparse_dense:
                    # Plot the full weights (dense + sparse part).
                    p_data_fp32.add_(sparse_param)
                self.plot_weight_distribution(p_data_tensor=p_data_fp32, quantized_values=quantized_values, step=step)

        if use_first_moment:
            # quantize
            if exp_avg_q_enabled:
                qx, gen = vectorwise_quant(exp_avg, qmap=exp_avg_qmap, shape=param_with_grad.shape, **exp_avg_qmetadata)
                q_exp_avg.data = qx
                exp_avg_q_overhead.update(gen)