import torch
from torch import uint8, int32, float16, nn, Tensor
import copy
from enum import Enum
from typing import Union      

from .utils import is_divisible, encode_safetensor_type, decode_safetensor_type, quantize_full_to_int3, compensator_dequantize, quantize_full_to_int4
from .optimize import optimize_weights_proximal
from .bitpack import BitPack

_META_TYPE = {
    "scale": torch.Tensor,
    "zero": torch.Tensor,
    "zero_scale": torch.Tensor,
    "compute_dtype": torch.dtype,
    "quant_zero": bool,
    "quant_scale": bool,
    "view_as_float": bool,
    "unpack_view_dtype": torch.dtype,
    "packing": str,
    "axis": int,
    "group_size": int,
    "nbits": int,
    "shape": torch.Size,
    "channel_wise": bool,
    "optimize": bool,
    "round_zero": bool,

    "ranks": dict,
    "compensator_dtype": str,
    "compensator_quant_gs": int,
}



# Main BTMoE Quantizer
class Quantizer:
    """Affine min/max quantizer with optional bit-packing (class methods only).

    Quantization computes ``W_q = clamp(round(W * s + zero), 0, 2**nbits - 1)``
    where ``s`` is the inverse scale; the stored ``meta["scale"]`` is ``1 / s``
    so that :meth:`dequantize` can reconstruct ``(W_q - zero) * scale``.
    """

    SUPPORTED_BITS = [8, 4, 3, 2]
    optimize_weights = optimize_weights_proximal

    # nbits -> packing format name (the keys of `pack`, `unpack`, `unpack_view_dtype`)
    bit_to_packing = {
        8: "8bit_u8",
        4: "4bit_u8",
        3: "3bit_32",
        2: "2bit_u8",
        "4bit_32": "4bit_32",
    }

    pack = {
        "8bit_u8": BitPack.pack_8bit_u8,
        "4bit_u8": BitPack.pack_4bit_u8,
        "3bit_32": BitPack.pack_3bit_32,
        "2bit_u8": BitPack.pack_2bit_u8,
        "4bit_32": BitPack.pack_4bit_32,
    }

    unpack = {
        "8bit_u8": BitPack.unpack_8bit_u8,
        "4bit_u8": BitPack.unpack_4bit_u8,
        "3bit_32": BitPack.unpack_3bit_32,
        "2bit_u8": BitPack.unpack_2bit_u8,
        "4bit_32": BitPack.unpack_4bit_32,
    }

    # Integer dtype a packed tensor is viewed as before unpacking when it was
    # stored reinterpreted as float (view_as_float=True).
    unpack_view_dtype = {
        "8bit_u8": uint8,
        "4bit_u8": uint8,
        "3bit_32": int32,
        "2bit_u8": uint8,
        "4bit_32": int32,
    }

    @classmethod
    def quantize(
        cls,
        tensor: Tensor,
        nbits: float = 4,
        channel_wise: bool = True,
        group_size: int = 64,
        optimize: bool = False,
        round_zero: bool = False,
        axis: int = 0,
        bitpack: bool = True,
        compute_dtype: Union[torch.dtype, None] = None,
        view_as_float: bool = False,
        device: str = "cuda",
    ) -> tuple:
        """Quantize ``tensor`` and optionally bit-pack the integer codes.

        Args:
            tensor: weight to quantize (converted to float32 internally).
            nbits: bit width; must be one of ``SUPPORTED_BITS``.
            channel_wise: reduce min/max along ``axis`` of the grouped tensor;
                when False a single global min/max is used and ``optimize``
                is forced off.
            group_size: elements per group (must divide ``tensor.numel()``),
                or None to skip regrouping.
            optimize: refine scale/zero with ``optimize_weights``.
            round_zero: round the zero-point to an integer.
            axis: 0 or 1. With grouping, axis=1 reshapes to
                ``(-1, group_size)`` and axis=0 to ``(group_size, -1)``.
            bitpack: pack the codes using the ``bit_to_packing[nbits]`` format.
            compute_dtype: float dtype used for the packed view when
                ``view_as_float`` is set (float32 if None).
            view_as_float: reinterpret the packed bits as a float tensor.
            device: forwarded to ``optimize_weights``.

        Returns:
            ``(W_q, meta)`` where ``meta`` holds everything
            :meth:`dequantize` needs (inverted scale, zero, shape, packing...).
        """
        assert nbits in cls.SUPPORTED_BITS, (
            "nbits=" + str(nbits) + " not supported."
        )
        assert axis in [0, 1], "axis should be either 0 or 1"
        if group_size is not None:
            assert is_divisible(tensor.numel(), group_size), (
                "the total number of tensor elements should be divisible by group_size. shape: "
                + str(tensor.shape)
                + ", group_size: "
                + str(group_size)
            )

        W = tensor.float()
        shape = W.shape

        # Reshape for grouping
        if (group_size is not None) and channel_wise:
            W = (
                W.reshape([-1, group_size])
                if (axis == 1)
                else W.reshape([group_size, -1])
            )

        # Get min/max values
        if not channel_wise:
            _min, _max = W.min(), W.max()
            optimize = False
        else:
            _min = W.min(axis=axis, keepdim=True)[0]
            _max = W.max(axis=axis, keepdim=True)[0]

        max_v = round(2**nbits - 1)
        min_v = 0
        min_max = [min_v, max_v]

        # Work with the inverse of the scale to avoid a division: quantize via
        # W * scale + zero, and invert the scale when storing the meta-data.
        # The clamp avoids half-precision overflow (e.g. when _max == _min).
        scale = (max_v / (_max - _min)).clamp(max=2e4)
        zero = -_min * scale

        # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
        if round_zero:
            zero = torch.round(zero)

        # Fine-tune weights
        if optimize:
            W_q, scale, zero = cls.optimize_weights(
                tensor=W,
                scale=scale,
                zero=zero,
                min_max=min_max,
                axis=axis,
                device=device,
            )
        else:
            W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])

        # Store meta-data (the scale is inverted for dequantization)
        meta = {
            "nbits": nbits,
            "group_size": group_size,
            "shape": shape,
            "scale": 1.0 / scale,
            "zero": zero,
            "axis": axis,
            "packing": cls.bit_to_packing[nbits],
        }
        meta["unpack_view_dtype"] = cls.unpack_view_dtype[meta["packing"]]

        # Pack bits
        meta["view_as_float"] = view_as_float
        if bitpack:
            W_q = cls.pack[meta["packing"]](W_q)
            if view_as_float:
                # store quantized weights as compute_dtype
                W_q = W_q.view(
                    torch.float32 if compute_dtype is None else compute_dtype
                )
        else:
            W_q = W_q.to(tensor.dtype)
            meta["packing"] = None

        # cleanup
        del W, _min, _max
        torch.cuda.empty_cache()

        return W_q, meta

    @classmethod
    def dequantize(cls, W_q: Tensor, meta: dict) -> Tensor:
        """Reconstruct the float weight: bit-unpack > (W_q - zero) * scale > reshape.

        The result uses ``meta["compute_dtype"]`` (float16 if absent) for the
        unpacked codes and is reshaped to ``meta["shape"]``.
        """
        compute_dtype = meta.get("compute_dtype", float16)
        if meta["packing"]:
            if meta["view_as_float"]:
                # Packed bits were stored as float; view them back as ints.
                W_q = W_q.view(meta["unpack_view_dtype"])
            W_r = cls.unpack[meta["packing"]](W_q, dtype=compute_dtype)
            if meta["nbits"] == 3:
                # NOTE(review): presumably the 3-bit packer pads rows; trim to
                # the first dimension of the grouped tensor.
                W_r = W_r[
                    : meta["group_size"]
                    if meta["axis"] == 0
                    else meta["shape"][0] * meta["shape"][1] // meta["group_size"]
                ]
        else:
            W_r = W_q.to(compute_dtype)
        W_r = ((W_r - meta["zero"]) * meta["scale"]).reshape(meta["shape"])
        return W_r

    @staticmethod
    def _meta_tensor_to(value: Tensor, compute_dtype, device) -> Tensor:
        """Cast a floating-point tensor to ``compute_dtype`` (integer tensors keep
        their dtype), then move it to ``device`` and make it contiguous."""
        if torch.is_floating_point(value):
            value = value.to(compute_dtype)
        return value.to(device).contiguous()

    @classmethod
    def to_inplace(cls, W_q: Tensor, meta: dict, device) -> tuple:
        """Move ``W_q`` and every plain-``Tensor`` value of ``meta`` to ``device``.

        ``meta`` is modified in place and returned together with the moved
        ``W_q`` (which may be None).
        """
        compute_dtype = meta.get("compute_dtype", float16)
        if W_q is not None:
            W_q = W_q.to(device).contiguous()
        for key in meta:
            # Exact type check: tensor subclasses (e.g. Parameters) are left as-is.
            if type(meta[key]) is torch.Tensor:
                meta[key] = cls._meta_tensor_to(meta[key], compute_dtype, device)
        return W_q, meta

    @classmethod
    def to_ooplace(cls, W_q: Tensor, meta: dict, device) -> tuple:
        """Like :meth:`to_inplace` but returns a new meta dict; ``meta`` is untouched."""
        compute_dtype = meta.get("compute_dtype", float16)
        W_q_c = None if W_q is None else W_q.to(device).contiguous()
        meta_c = {
            key: (
                cls._meta_tensor_to(value, compute_dtype, device)
                if type(value) is torch.Tensor
                else value
            )
            for key, value in meta.items()
        }
        return W_q_c, meta_c

    @classmethod
    def cuda(cls, W_q: Tensor, meta: dict, device) -> tuple:
        """Move to ``device`` in place (see :meth:`to_inplace`)."""
        return cls.to_inplace(W_q, meta, device=device)

    @classmethod
    def cpu(cls, W_q: Tensor, meta: dict) -> tuple:
        """Return CPU copies (see :meth:`to_ooplace`)."""
        return cls.to_ooplace(W_q, meta, device="cpu")




class BTMoEBackend(Enum):
    """Selectable forward implementations of ``BTMoELinear``.

    Each value is the name of a ``BTMoELinear`` method;
    ``BTMoELinear.set_backend`` binds ``getattr(cls, backend.value)`` as
    ``forward``.
    """

    # Forward through the BTMoEMatmulNoCacheMul autograd function
    # (eager and torch.compile'd matmul respectively).
    PYTORCH = "forward_pytorch_backprop"
    PYTORCH_COMPILE = "forward_pytorch_backprop_compile"

    # Backward-compatible names. They repeat the values above, so Enum
    # treats them as aliases of PYTORCH / PYTORCH_COMPILE, not new members.
    PYTORCH_BACKPROP = "forward_pytorch_backprop"
    PYTORCH_BACKPROP_COMPILE = "forward_pytorch_backprop_compile"

    # Direct forward without a custom autograd function.
    PYTORCH_FORWARD = "forward_pytorch"
    PYTORCH_FORWARD_COMPILE = "forward_pytorch_compile"




# No cache: less memory, slower
class BTMoEMatmulNoCacheDeq(torch.autograd.Function):
    """``x @ W.T + bias`` where ``W`` is re-dequantized on every use.

    ``dequantize`` is a zero-argument callable returning the weight. It is
    called in forward and again in backward instead of caching ``W``, which
    trades compute for memory. The quantized weight is frozen, so no weight
    gradient is produced.
    """

    @staticmethod
    def forward(x: Tensor, dequantize, bias: Tensor):
        out = torch.matmul(x, dequantize().t())
        if bias is not None:
            out += bias
        return out

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        _, dequantize, bias = inputs
        # Only the bias is needed in backward; x is not saved so the input
        # activation is not kept alive just for this node.
        ctx.save_for_backward(bias)
        ctx.dequantize = dequantize

    @staticmethod
    def backward(ctx, grad_output):
        (bias,) = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.matmul(grad_output, ctx.dequantize())

        # weight grad for frozen quantized weights not defined

        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over every leading (batch/sequence) dim, not only dim 0,
            # so the result matches bias' shape for inputs of any rank.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        return grad_input, grad_weight, grad_bias



class BTMoEMatmulNoCacheMul(torch.autograd.Function):  # default
    """``matmul(x) + bias`` delegating the weight product to a callable.

    ``matmul(t, transpose=True)`` must compute ``t @ W.T`` and
    ``matmul(t, transpose=False)`` must compute ``t @ W``. The weight is never
    cached here. It is frozen, so no weight gradient is produced.
    """

    @staticmethod
    def forward(x, matmul, bias):
        out = matmul(x, transpose=True)
        if bias is not None:
            out += bias
        return out

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        _, matmul, bias = inputs
        # Only the bias is needed in backward; x is not saved so the input
        # activation is not kept alive just for this node.
        ctx.save_for_backward(bias)
        ctx.matmul = matmul

    @staticmethod
    def backward(ctx, grad_output):
        (bias,) = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = ctx.matmul(grad_output, transpose=False)

        # weight grad for frozen quantized weights not defined

        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over every leading (batch/sequence) dim, not only dim 0.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        return grad_input, grad_weight, grad_bias


# Cache dequantized tensor: Faster but needs more memory
class BTMoEMatmulCachedDeq(torch.autograd.Function):
    """``x @ W.T + bias`` dequantizing ``W`` once and caching it for backward.

    ``hqq_layer`` must expose ``dequantize()`` returning the weight. The
    dequantized weight is saved for backward, which is faster than
    re-dequantizing but holds the full weight in memory until backward runs.
    """

    @staticmethod
    def forward(ctx, x, hqq_layer, bias):
        weight_tmp = hqq_layer.dequantize()
        out = torch.matmul(x, weight_tmp.t())
        if bias is not None:
            out += bias

        # x itself is not needed in backward, so it is not saved.
        ctx.save_for_backward(bias, weight_tmp)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        bias, weight_tmp = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.matmul(grad_output, weight_tmp)

        del weight_tmp

        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over every leading (batch/sequence) dim, not only dim 0.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        return grad_input, grad_weight, grad_bias


# Main linear layer
class BTMoELinear(nn.Module):
    """Linear layer holding a quantized weight plus an optional low-rank
    compensator ``U @ V``.

    The effective weight used by :meth:`matmul` is
    ``dequantize() + U @ V`` (the compensator term only when both are set).
    The forward implementation is chosen by the class-level ``backend``
    (see :meth:`set_backend`).
    """

    # Default backend
    backend = BTMoEBackend.PYTORCH

    def __init__(
        self,
        linear_layer: Union[nn.Module, None],
        compress_config: dict,
        del_orig: bool = True,
        compute_dtype: torch.dtype = float16,
        device: str = "cuda",
        initialize: bool = True,
    ):
        """
        Args:
            linear_layer: layer to compress (``.weight``/``.bias`` are read;
                an optional ``.name`` selects the compensator rank), or None
                when the state is restored later via :meth:`load_state_dict`.
            compress_config: config dict (see ``btmoe_base_compress_config``);
                it is deep-copied, so the caller's dict is never modified.
            del_orig: drop the reference to ``linear_layer`` after compression.
            compute_dtype: dtype used for dequantized weights and the bias.
            device: target device for the quantized data.
            initialize: compress immediately.
        """
        super().__init__()
        self.ready = False
        self.in_gpu = False
        self.bias = None
        self.device = device
        self.compute_dtype = compute_dtype
        self.compress_config = copy.deepcopy(compress_config)
        self.del_orig = del_orig
        self.offload_meta = (
            self.compress_config.pop("offload_meta", False)
            if (self.compress_config is not None)
            else None
        )
        # Remove non-compress keys from config to avoid passing unexpected kwargs.
        # These flags are consumed at model patching level, not by BTMoELinear.compress
        if self.compress_config is not None:
            self.compress_config.pop("skip_attention", None)
            self.compress_config.pop("skip_shared_expert", None)

        self.set_backend(BTMoELinear.backend)

        self.linear_layer = linear_layer
        self.W_q = None
        self.meta = None
        self.encoded_state_dict = (
            True  # This makes state_dict compatible with safetensors
        )

        self.U = None
        self.V = None

        if linear_layer is not None:
            self.orig_shape = self.linear_layer.weight.data.shape
            # nn.Linear has no `name`; callers tag layers so the per-layer
            # compensator rank can be looked up. Untagged layers get "".
            self.name = getattr(linear_layer, "name", "")

        if initialize:
            self.initialize()

    def pop_UV_quantized(self):
        """Remove and return the quantized compensator factors (None if absent)."""
        if not hasattr(self, "UV_quantized"):
            return None
        UV_quantized = self.UV_quantized
        del self.UV_quantized
        return UV_quantized

    def initialize(self):
        """Compress ``linear_layer`` (if any) and move its bias to device/compute_dtype."""
        if self.linear_layer is not None:
            self.compress(self.linear_layer.weight.data, **self.compress_config)
            self.bias = (
                None
                if (self.linear_layer.bias is None)
                else self.linear_layer.bias.to(
                    device=self.device, dtype=self.compute_dtype
                )
            )

        if self.del_orig:
            del self.linear_layer
        torch.cuda.empty_cache()

    @classmethod
    def from_weights(
        cls,
        weight: Tensor,
        bias: Union[Tensor, None],
        quant_config: dict,
        compute_dtype: torch.dtype = float16,
        device: str = "cuda",
        del_orig: bool = True,
        name: str = "",
    ):
        """Build a layer from raw ``weight``/``bias`` tensors.

        ``quant_config`` is passed as the ``compress_config``; ``name`` is the
        key used to look up the compensator rank.
        """
        dummy_linear = torch.nn.Linear(1, 1, bias=False)
        dummy_linear.weight.data = weight
        # nn.Linear(bias=False) registers `bias` as a (None) parameter, so a
        # plain tensor must be wrapped in a Parameter to be assignable.
        dummy_linear.bias = (
            None if bias is None else nn.Parameter(bias, requires_grad=False)
        )
        dummy_linear.name = name

        return cls(
            dummy_linear,
            compress_config=quant_config,
            compute_dtype=compute_dtype,
            device=device,
            del_orig=del_orig,
        )

    def extra_repr(self) -> str:
        out = ""
        if hasattr(self, "meta"):
            if self.meta is not None:
                in_features, out_features = self.meta["shape"][::-1]
                out = f"in_features={in_features}, out_features={out_features}, bias={self.bias is not None}"
        return out

    # Set backends
    @classmethod
    def set_backend(cls, backend: BTMoEBackend):
        """Record ``backend`` as the default and bind its method as ``cls.forward``."""
        BTMoELinear.backend = backend
        cls.forward = getattr(cls, backend.value)

    # TODO: rewrite this mess
    def cuda(self, device):
        """Move quantized weight, meta-data and bias to ``device``.

        With ``offload_meta``, zero/scale are stacked into ``meta["zero_scale"]``
        and kept in pinned CPU memory instead.
        """
        self.meta["compute_dtype"] = self.compute_dtype

        if type(self.W_q) == nn.parameter.Parameter:
            self.W_q.data, self.meta = Quantizer.cuda(self.W_q.data, self.meta, device)
        else:
            self.W_q, self.meta = Quantizer.cuda(self.W_q, self.meta, device)

        if self.meta["quant_zero"]:
            if "zero_q" in self.meta:
                self.meta["zero_q"], self.meta["meta_zero"] = Quantizer.cuda(
                    self.meta["zero_q"], self.meta["meta_zero"], device
                )
            else:
                _, self.meta["meta_zero"] = Quantizer.cuda(
                    None, self.meta["meta_zero"], device
                )
        elif "zero" in self.meta:
            self.meta["zero"] = self.meta["zero"].to(device)

        if self.meta["quant_scale"]:
            if "scale_q" in self.meta:
                self.meta["scale_q"], self.meta["meta_scale"] = Quantizer.cuda(
                    self.meta["scale_q"], self.meta["meta_scale"], device
                )
            else:
                _, self.meta["meta_scale"] = Quantizer.cuda(
                    None, self.meta["meta_scale"], device
                )
        elif "scale" in self.meta:
            self.meta["scale"] = self.meta["scale"].to(device)

        if self.offload_meta:
            if "zero_scale" not in self.meta:
                if self.meta["quant_scale"] and self.meta["quant_zero"]:
                    self.meta["zero_scale"] = torch.stack(
                        (self.meta["zero_q"], self.meta["scale_q"])
                    )
                    del self.meta["scale_q"], self.meta["zero_q"]
                else:
                    self.meta["zero_scale"] = torch.stack(
                        (self.meta["zero"], self.meta["scale"])
                    ).to(self.compute_dtype)
                    del self.meta["scale"], self.meta["zero"]

            self.meta["zero_scale"] = (
                self.meta["zero_scale"].contiguous().cpu().pin_memory()
            )

        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=self.compute_dtype)

        self.W_q = nn.Parameter(self.W_q, requires_grad=False)
        self.device = device
        self.in_gpu = True

        torch.cuda.empty_cache()

        return self

    def to(self, *args, **kwargs):
        # TODO: later
        return self

    def type(self, dst_type):
        # TODO: later
        return self

    def half(self, *args, **kwargs):
        return self

    def bfloat16(self, *args, **kwargs):
        # TODO: later
        return self

    def float(self, *args, **kwargs):
        # TODO: later
        return self

    def double(self, *args, **kwargs):
        return self

    def cpu(self):
        # TODO: later
        return self

    # state_dict is encoded by default for safetensors support. You can get the raw dict by setting self.encoded_state_dict=False.
    # Note: you can't change the state once it's done
    def state_dict(self, *args, **kwargs):  # nn.Module override compatible
        """Serialize weight, bias, meta-data and the quantization/compensator config.

        Values are encoded with ``encode_safetensor_type`` when
        ``self.encoded_state_dict`` is True. When ``destination`` and
        ``prefix`` keyword arguments are given, every entry is also written to
        ``destination[prefix + key]``. ``self.compress_config`` is not modified.

        Raises:
            Exception: quantized scale/zero combined with an encoded state dict.
        """
        if (
            self.compress_config["scale_quant_params"]
            or self.compress_config["zero_quant_params"]
        ) and self.encoded_state_dict:
            raise Exception(
                "Unsupported serialization for quantized scale/zero and self.encoded_state_dict=True"
            )
            # TODO: add support for quantized zero/scale case (quant_config and zero/scale)

        _encode_type = (
            encode_safetensor_type if (self.encoded_state_dict) else lambda z: z
        )

        # Core data
        state = {"W_q": self.W_q} | {k: _encode_type(v) for k, v in self.meta.items()}
        if self.bias is not None:
            state["bias"] = self.bias
        state["offload_meta"] = _encode_type(self.offload_meta)

        # Encoding flag
        if self.encoded_state_dict:
            state["encoded_state_dict"] = _encode_type(self.encoded_state_dict)

        # Quant config
        state["stores_quant_config"] = _encode_type(True)
        for k, v in self.compress_config["weight_quant_params"].items():
            state[k] = _encode_type(v)

        # Compensator config. "ranks" is the per-layer rank table used only at
        # compress time and is not serialized; it is skipped rather than
        # popped so the config stays usable afterwards.
        for k, v in self.compress_config["compensator_params"].items():
            if k != "ranks":
                state[k] = _encode_type(v)

        # Export after the state is complete so the destination also receives
        # the compensator config.
        if "destination" in kwargs and "prefix" in kwargs:
            for key, value in state.items():
                kwargs["destination"][kwargs["prefix"] + key] = value

        return state

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Restore the layer from a dict produced by :meth:`state_dict`.

        ``state_dict`` is consumed (entries are popped). After loading, the
        layer is moved to ``self.device`` and marked ready.
        """
        if "encoded_state_dict" in state_dict:
            encoded_state_dict = True
            state_dict.pop("encoded_state_dict")
        else:
            encoded_state_dict = False

        _decode_type = (
            decode_safetensor_type if (encoded_state_dict) else lambda z, w: z
        )

        # Quant-config
        if state_dict.pop(
            "stores_quant_config", False
        ):  # check for backward compatibility
            self.compress_config = {
                "weight_quant_params": {
                    k: _decode_type(state_dict[k], _META_TYPE[k])
                    for k in [
                        "nbits",
                        "channel_wise",
                        "group_size",
                        "optimize",
                        "round_zero",
                        "axis",
                        "view_as_float",
                    ]
                }
            }
            # TODO: scale/zero quant use-case
            self.compress_config["scale_quant_params"] = state_dict.pop(
                "scale_quant_params", None
            )
            self.compress_config["zero_quant_params"] = state_dict.pop(
                "zero_quant_params", None
            )
        if self.compress_config is None:
            # Layer built without a config and checkpoint without one either.
            self.compress_config = {}

        # NOTE(review): compensator values are taken as stored (not passed
        # through _decode_type) even when the state dict is encoded.
        self.compress_config["compensator_params"] = {
            key: state_dict.pop(key, None)
            for key in [
                "sparse_rank",
                "iter",
                "dense_rank",
                "rank_strategy",
                "compensator_dtype",
                "compensator_quant_gs",
            ]
        }

        # W_q/ bias
        self.W_q = state_dict.pop("W_q")
        self.bias = state_dict.pop("bias", None)

        # Meta
        self.offload_meta = _decode_type(state_dict.pop("offload_meta", False), bool)
        if "meta" in state_dict:
            self.meta = state_dict["meta"]  # Backward compatibility
        else:
            self.meta = {
                k: _decode_type(v, _META_TYPE[k]) for k, v in state_dict.items()
            }  # safetensors version

        # Meta-data offloading
        if self.offload_meta is None:
            self.offload_meta = False
        for key in ["zero", "zero_q", "scale", "scale_q", "zero_scale"]:
            if key in self.meta and self.offload_meta:
                self.meta[key] = self.meta[key].cpu().contiguous().pin_memory()

        # Float view settings
        if "unpack_view_dtype" not in self.meta:
            self.meta["unpack_view_dtype"] = Quantizer.unpack_view_dtype[
                self.meta["packing"]
            ]

        if "view_as_float" not in self.meta:
            self.meta["view_as_float"] = False

        if "meta_scale" in self.meta:
            if "view_as_float" not in self.meta["meta_scale"]:
                self.meta["meta_scale"]["view_as_float"] = False

        if "meta_zero" in self.meta:
            if "view_as_float" not in self.meta["meta_zero"]:
                self.meta["meta_zero"]["view_as_float"] = False

        # Check GPU
        self.cuda(self.device)
        self.ready = True

        # Set in_features/out_features
        self.in_features, self.out_features = self.meta["shape"][::-1]

    def _lookup_rank(self, ranks: dict):
        """Return this layer's compensator rank from ``ranks``, or None.

        Tries an exact name match, then a key containing the name, then a key
        contained in the name. The substring matches are skipped for an empty
        name (which would otherwise match every key).
        """
        rank = ranks.get(self.name, None)
        if rank is None and self.name:
            rank = next((v for k, v in ranks.items() if self.name in k), None)
            if rank is None:
                rank = next((v for k, v in ranks.items() if k in self.name), None)
        return rank

    def compress(
        self,
        W: Tensor,
        weight_quant_params: dict,
        scale_quant_params: dict,
        zero_quant_params: dict,
        compensator_params: dict,
    ) -> None:
        """Quantize ``W`` and fit a low-rank compensator to the quantization error.

        Alternates ``compensator_params["iter"]`` times between quantizing
        ``W - U @ V`` and refitting ``U @ V`` to ``W - dequant(W_q)`` with a
        rank-``r`` ``torch.svd_lowrank``. The rank ``r`` comes from
        ``compensator_params["ranks"]``, and ``r == 0`` skips the compensator.
        ``U``/``V`` are then quantized with ``compensator_dtype`` ("int3" or
        "int4"), and zero/scale are quantized if their params are given.
        """
        quant_scale = scale_quant_params is not None
        quant_zero = zero_quant_params is not None

        self.in_features, self.out_features = W.t().shape

        U = None
        V = None
        W_unquant = W.to(self.device)
        W_q = None
        self.UV_quantized = None

        n_iters = compensator_params["iter"]
        rank = self._lookup_rank(compensator_params["ranks"])
        if rank is None:
            print(f"[WARNING] Could not find rank for {self.name}, will use None which may cause an error.")
        compensator_dtype = compensator_params["compensator_dtype"]

        for i in range(0, n_iters + 1):
            if i > 0:
                # Re-quantize what the current low-rank estimate does not explain.
                W = W_unquant.to(self.device) - (U @ V)

            # Quantize
            print(f"quantize {self.name} to {weight_quant_params['nbits']} bits, iter = {i}, rank = {rank}")
            W_q, meta = Quantizer.quantize(
                W,
                device=self.device,
                compute_dtype=self.compute_dtype,
                **weight_quant_params,
            )
            meta.update({"quant_scale": quant_scale, "quant_zero": quant_zero})

            if rank == 0:
                break
            W_q_dequant = Quantizer.dequantize(W_q, meta).to(self.device)
            U_svd, S, V_svd = torch.svd_lowrank(W_unquant.float() - W_q_dequant.float(), q=rank)
            S = torch.diag(S)
            # Split sqrt(S) between the factors so U @ V approximates the error.
            U = (U_svd @ torch.sqrt(S)).to(self.device)
            V = (torch.sqrt(S) @ V_svd.T).to(self.device)

            Error = W_unquant - W_q_dequant - U @ V
            print(f"iter {i}, norm={torch.norm(Error,p='fro')}")

        if rank > 0:
            if compensator_dtype == 'int3':
                UV_quantized = (
                    quantize_full_to_int3(U, compensator_params["compensator_quant_gs"]),
                    quantize_full_to_int3(V, compensator_params["compensator_quant_gs"]),
                )
                self.UV_quantized = UV_quantized
            elif compensator_dtype == 'int4':
                U_quantized = quantize_full_to_int4(U, compensator_params["compensator_quant_gs"])
                V_quantized = quantize_full_to_int4(V, compensator_params["compensator_quant_gs"])
                UV_quantized = (U_quantized, V_quantized)
                self.UV_quantized = UV_quantized

                print(f"Original U norm: {torch.norm(U, p='fro'):.6f}")
                print(f"Original V norm: {torch.norm(V, p='fro'):.6f}")
                print(f"U min/max: {torch.min(U):.6f}/{torch.max(U):.6f}")
            else:
                raise NotImplementedError

        if meta["quant_zero"]:
            meta["zero_q"], meta["meta_zero"] = Quantizer.quantize(
                meta["zero"],
                device=self.device,
                view_as_float=False,
                **zero_quant_params,
            )
            del meta["zero"]
            meta["meta_zero"]["compute_dtype"] = self.compute_dtype

        if meta["quant_scale"]:
            meta["scale_q"], meta["meta_scale"] = Quantizer.quantize(
                meta["scale"],
                device=self.device,
                view_as_float=False,
                **scale_quant_params,
            )
            del meta["scale"]
            meta["meta_scale"]["compute_dtype"] = self.compute_dtype

        self.W_q = W_q
        self.meta = meta
        self.cuda(self.device)
        self.ready = True
        if self.UV_quantized is not None:
            self.U, self.V = compensator_dequantize(self.UV_quantized, self.meta["shape"], rank, compensator_params["compensator_quant_gs"], compensator_dtype)

    def unpack(self, reshape=False, dtype=None):
        """Return the integer codes of the weight as a ``dtype`` tensor.

        ``dtype`` defaults to ``compute_dtype``. Returns None before the layer
        is quantized; with ``reshape=True`` the codes are viewed as
        ``meta["shape"]``.
        """
        if self.ready is False:
            return None
        dtype = dtype if (dtype is not None) else self.compute_dtype
        if self.meta["packing"]:
            W_q = self.W_q
            # Packed bits stored as float must be viewed back as integers
            # before unpacking (as Quantizer.dequantize does).
            if self.meta["view_as_float"]:
                W_q = W_q.view(self.meta["unpack_view_dtype"])
            W_r = Quantizer.unpack[self.meta["packing"]](W_q, dtype=dtype)
        else:
            # Not bit-packed: W_q already holds one code per element.
            W_r = self.W_q.to(dtype)
        return W_r.view(self.meta["shape"]) if (reshape) else W_r

    def dequantize(self):
        """Reconstruct the float base weight (without the U @ V compensator).

        zero/scale are temporarily materialized in ``meta`` (from
        ``zero_scale`` or their quantized forms) and removed again afterwards.
        """
        assert self.ready, "model was not quantized"
        W_q, meta = self.W_q, self.meta
        device = W_q.device
        del_keys = set()

        # Zero/Scale packed together
        if "zero_scale" in meta:
            zero_scale = meta["zero_scale"].to(device=device)

            if zero_scale.dtype == uint8:
                meta["zero_q"], meta["scale_q"] = zero_scale[0], zero_scale[1]
                del_keys.update({"zero_q", "scale_q"})
            else:
                meta["zero"], meta["scale"] = zero_scale[0], zero_scale[1]
                del_keys.update({"zero", "scale"})

        if meta["quant_zero"]:
            meta["zero"] = Quantizer.dequantize(
                meta["zero_q"].to(device=device), meta["meta_zero"]
            )
            del_keys.add("zero")

        if meta["quant_scale"]:
            meta["scale"] = Quantizer.dequantize(
                meta["scale_q"].to(device=device), meta["meta_scale"]
            )
            del_keys.add("scale")

        W_est = Quantizer.dequantize(W_q, meta)

        # Cleanup
        for key in del_keys:
            del meta[key]
        return W_est

    def matmul(self, x: Tensor, transpose: bool = True) -> Tensor:
        """``x @ W_eff.T`` (or ``x @ W_eff`` when ``transpose`` is False), where
        ``W_eff = dequantize() + U @ V`` when the compensator is present."""
        weight = self.dequantize()
        if (self.U is not None) and (self.V is not None):
            weight = weight + self.U @ self.V  # recover E_hat and add to the weight
        return torch.matmul(x, weight.t() if (transpose) else weight)

    @torch.compile()
    def matmul_compile(self, *args, **kwargs):
        return self.matmul(*args, **kwargs)

    def forward_pytorch_backprop(self, x: Tensor) -> Tensor:
        return BTMoEMatmulNoCacheMul.apply(x, self.matmul, self.bias)

    def forward_pytorch_backprop_compile(self, x: Tensor) -> Tensor:
        return BTMoEMatmulNoCacheMul.apply(x, self.matmul_compile, self.bias)

    def forward_pytorch(self, x: Tensor) -> Tensor:
        """Direct forward: ``x @ W_eff.T + bias`` (see :meth:`matmul`)."""
        out = self.matmul(x, transpose=True)
        if self.bias is not None:
            out += self.bias
        return out

    @torch.compile()
    def forward_pytorch_compile(self, x: Tensor) -> Tensor:
        # No print here: side effects inside a compiled function force a graph break.
        return self.forward_pytorch(x)

    



def btmoe_base_compress_config(
    #quantization config
    nbits: int = 3,
    group_size: int = 64,
    quant_zero: bool = True,
    quant_scale: bool = False,
    offload_meta: bool = False,  # meta-data should be quantized with the same settings to use offload_meta
    view_as_float: bool = False,
    axis: int = 0,
    #compensator config
    iter: int = 10,
    sparse_rank: int = 0,
    dense_rank: int = 0,
    rank_strategy: str = None,
    compensator_dtype: str = "int3",  # Options: "int3", "int4"(need Test)
    compensator_quant_gs: int = 64,
    # skip switches
    skip_attention: bool = False,
    skip_shared_expert: bool = False,
):
    """Build the compression config dict consumed by ``BTMoELinear``.

    Returns a dict with ``weight_quant_params``, ``scale_quant_params`` /
    ``zero_quant_params`` (None when that meta-data is not quantized),
    ``offload_meta``, the ``skip_attention`` / ``skip_shared_expert`` flags,
    and ``compensator_params``.

    With ``offload_meta`` enabled, ``quant_scale`` is forced to equal
    ``quant_zero``, because zero and scale are stacked together for
    offloading. ``compensator_params`` has no ``"ranks"`` table; callers must
    add one before compressing a layer.
    """
    assert (
        nbits in Quantizer.SUPPORTED_BITS
    ), "nbits value not supported. Check Quantizer.SUPPORTED_BITS."

    if group_size is not None:
        assert is_divisible(
            group_size, 8
        ), "Invalid group_size param: the value should be a multiple of 8."

    # Offloaded zero/scale are stacked into one tensor, so they must share
    # the same (quantized or not) setting.
    if offload_meta and quant_scale != quant_zero:
        quant_scale = quant_zero

    weight_quant_params = dict(
        nbits=nbits,
        channel_wise=True,
        group_size=group_size,
        optimize=True,
        round_zero=bool(nbits == 4),
        axis=axis,
        view_as_float=view_as_float,
    )

    # Scale quantization settings do not depend on offload_meta.
    if quant_scale:
        scale_quant_params = {"nbits": 8, "channel_wise": True, "group_size": 128, "optimize": False}
    else:
        scale_quant_params = None

    # Zero quantization is grouped only when meta-data is offloaded.
    if not quant_zero:
        zero_quant_params = None
    elif offload_meta:
        zero_quant_params = {"nbits": 8, "channel_wise": True, "group_size": 128, "optimize": False}
    else:
        zero_quant_params = {"nbits": 8, "channel_wise": False, "group_size": None, "optimize": False}

    compensator_params = dict(
        iter=iter,
        sparse_rank=sparse_rank,
        dense_rank=dense_rank,
        rank_strategy=rank_strategy,
        compensator_dtype=compensator_dtype,
        compensator_quant_gs=compensator_quant_gs,
    )

    config = {}
    # quantization configs
    config["weight_quant_params"] = weight_quant_params
    config["scale_quant_params"] = scale_quant_params
    config["zero_quant_params"] = zero_quant_params
    config["offload_meta"] = offload_meta
    # skip flags
    config["skip_attention"] = skip_attention
    config["skip_shared_expert"] = skip_shared_expert
    # compensator configs
    config["compensator_params"] = compensator_params
    return config



# Alias: follow similar Auto-GPTQ naming.
# `BaseCompressConfig` is the same function object as `btmoe_base_compress_config`.
BaseCompressConfig = btmoe_base_compress_config
