from functools import partial

import torch
from src.quantizers.base_quantizer import BaseQuantizer


class KIVIQuantizer(BaseQuantizer):
    """Simulated (fake) round-to-nearest quantizer for key/value tensors.

    Keys and values use independent group sizes. For keys (``encode_type="k"``)
    the last two axes are swapped before grouping, so groups run along the
    input's second-to-last axis; for values (``encode_type="v"``) groups run
    along the last axis. The output has the input's shape and holds the
    dequantized approximation, which is why :meth:`dequantize` is the identity.

    NOTE(review): assuming inputs are laid out as (..., seq_len, head_dim), this
    gives per-channel key and per-token value grouping as in KIVI — confirm
    against the caller.
    NOTE(review): ``hadamard`` and ``residual_size`` are stored but not read by
    this class, and ``sym`` / ``clip_ratio`` are forwarded to
    :meth:`quantize_tensor` but ignored there (it always performs asymmetric
    min/max quantization without clipping). Presumably other code reads them —
    confirm.
    """

    def __init__(
        self,
        n_bits: int,
        residual_size: int,
        key_group_size: int,
        value_group_size: int,
        hadamard: bool,
        sym: bool,
        clip_ratio: float,
    ) -> None:
        """Store the quantization configuration.

        Args:
            n_bits: Bit width; ``n_bits >= 16`` disables quantization.
            residual_size: Stored only; must be a multiple of
                ``key_group_size`` when the latter is positive.
            key_group_size: Elements per key group; ``<= 0`` means one group
                per row.
            value_group_size: Elements per value group; ``<= 0`` means one
                group per row.
            hadamard: Stored only (not used by this class).
            sym: Forwarded to :meth:`quantize_tensor`, which ignores it.
            clip_ratio: Forwarded to :meth:`quantize_tensor`, which ignores it.

        Raises:
            AssertionError: If ``key_group_size > 0`` and ``residual_size`` is
                not divisible by it.
        """
        super().__init__()
        self.n_bits = n_bits
        self.residual_size = residual_size
        self.key_group_size = key_group_size
        self.value_group_size = value_group_size

        # A non-positive group size means "no grouping" (see quantize_tensor),
        # so divisibility only applies to positive sizes; checking it
        # unconditionally raised ZeroDivisionError for key_group_size == 0.
        assert key_group_size <= 0 or residual_size % key_group_size == 0, (
            "Residual size must be divisible by key group size."
        )
        self.hadamard = hadamard
        self.sym = sym
        self.clip_ratio = clip_ratio

    @classmethod
    @torch.no_grad()
    def quantize_tensor(
        cls,
        w: torch.Tensor,
        n_bits: int,
        group_size: int,
        sym: bool,
        clip_ratio: float = 1.0,
    ) -> torch.Tensor:
        """Fake-quantize a 2-D tensor with asymmetric round-to-nearest (RTN).

        Each run of ``group_size`` consecutive elements within a row (or the
        whole row when ``group_size <= 0``) is mapped onto ``2 ** n_bits``
        evenly spaced levels between the group's min and max, then mapped back.

        Args:
            w: 2-D tensor; not modified in place.
            n_bits: Bit width of the quantization grid.
            group_size: Elements per group; must divide ``w.shape[-1]`` when
                positive.
            sym: Currently ignored (quantization is always asymmetric).
            clip_ratio: Currently ignored (no range clipping is applied).

        Returns:
            Tensor with the same shape as ``w`` holding the dequantized values.
        """
        saved_shape = w.shape
        assert w.dim() == 2

        if group_size > 0:
            assert w.shape[-1] % group_size == 0
            # Row-major reshape: since group_size divides the row length,
            # groups never straddle two rows.
            w = w.reshape(-1, group_size)

        w_max = w.amax(dim=-1, keepdim=True)
        w_min = w.amin(dim=-1, keepdim=True)
        max_level = 2 ** n_bits - 1
        scale = (w_max - w_min) / max_level
        # A constant group has zero range; dividing by a zero scale gives
        # 0/0 = NaN, which would survive clamp/round and poison the output.
        # Any non-zero scale reconstructs a constant group exactly
        # (it maps to level 0, i.e. back to w_min).
        scale = scale.masked_fill(scale == 0, 1.0)

        levels = ((w - w_min) / scale).clamp_(0, max_level).round_()
        return (levels * scale + w_min).reshape(saved_shape)

    def quantize(self, x: torch.Tensor, encode_type: str) -> torch.Tensor:
        """Fake-quantize a key (``"k"``) or value (``"v"``) tensor.

        Args:
            x: Tensor to quantize; for keys it needs at least 2 dims.
                NOTE(review): presumably (..., seq_len, head_dim) — confirm.
            encode_type: ``"k"`` to group along the second-to-last axis with
                ``key_group_size``; ``"v"`` to group along the last axis with
                ``value_group_size``.

        Returns:
            Tensor of the same shape as ``x`` with dequantized values, or ``x``
            itself when ``n_bits >= 16``.

        Raises:
            ValueError: If ``encode_type`` is neither ``"k"`` nor ``"v"``.
            AssertionError: If the grouped axis is not divisible by a positive
                group size.
        """
        if self.n_bits >= 16:
            # 16+ bits means "no quantization". This is checked before
            # encode_type so forward(x) with its default still passes through.
            return x

        if encode_type == "k":
            group_size = self.key_group_size
            x = x.transpose(-1, -2)
        elif encode_type == "v":
            group_size = self.value_group_size
        else:
            # Previously fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"encode_type must be 'k' or 'v', got {encode_type!r}"
            )

        saved_shape = x.shape
        assert group_size <= 0 or saved_shape[-1] % group_size == 0, (
            "Last dimension must be divisible by group size."
        )

        x = self.quantize_tensor(
            x.reshape(-1, saved_shape[-1]),
            n_bits=self.n_bits,
            group_size=group_size,
            sym=self.sym,
            clip_ratio=self.clip_ratio,
        ).reshape(saved_shape)

        if encode_type == "k":
            # Undo the transpose so the output matches the input layout.
            x = x.transpose(-1, -2)
        return x

    def dequantize(self, x):
        """Return ``x`` unchanged: :meth:`quantize` already outputs dequantized values."""
        return x

    @torch.no_grad()
    def forward(self, x, encode_type: str = None):
        """Module entry point; delegates to :meth:`quantize`."""
        return self.quantize(x, encode_type)
