import torch
from torch import nn
import pickle
from src.models.cache import DynamicCache
from src.models.cache import NSNCache
from src.utils import apply_rotary_pos_emb_single
from src.utils import repeat_kv
from src.quantizers.base_quantizer import BaseQuantizer


class KVQuantQuantizer(BaseQuantizer):
    """Codebook (nearest-centroid) KV-cache quantizer with full-precision outliers.

    Per-layer calibration data is loaded from a pickle file: outlier thresholds
    for keys, centroids for keys and values, and optional normalization
    parameters. ``forward`` simulates quantization, returning reconstructed
    values rather than indices. ``update_cache`` applies it to completed
    ``residual_size``-token blocks of the cache. ``self_attn`` runs eager
    attention over the resulting cache.
    """

    # The cache is quantized in blocks of this many tokens; caches shorter than
    # this and the trailing partial block stay in full precision.
    residual_size: int = 64
    # Fraction of each token's value entries treated as inliers; the remaining
    # fraction is split evenly between the upper and lower tails (outliers).
    sparse_threshold = 0.99
    # Not read in this class. NOTE(review): presumably selects pre- vs post-RoPE
    # quantization elsewhere -- confirm with callers.
    post_rope: bool = False

    def __init__(self,
                 layer_idx: int,
                 quantizer_path: str,
                 norm: bool) -> None:
        """Load the calibrated quantizer parameters for one layer.

        Args:
            layer_idx: Layer index. Only dict entries whose key contains
                ``".{layer_idx}."`` and ``"k_proj"`` / ``"v_proj"`` are used.
            quantizer_path: Path to a pickled mapping of calibration results.
            norm: If True, also register per-tensor normalization scale and
                offset buffers (entries 3 and 4) and apply them in ``forward``.
        """
        super().__init__()
        self.norm = norm
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load quantizer files from a trusted source.
        with open(quantizer_path, "rb") as f:
            quantizer = pickle.load(f)

        layer_tag = f".{layer_idx}."
        for key, params in quantizer.items():
            if layer_tag not in key:
                continue
            # NOTE(review): entry layout assumed to be
            # (outlier_upper, outlier_lower, centroids, normscale, normoffset)
            # based on the indices used here -- confirm against the calibration script.
            if "k_proj" in key:
                prefix = "key"
                # Only keys use calibrated thresholds; value thresholds are
                # computed per token in ``forward``.
                self.register_buffer("key_outlier_upper", params[0])
                self.register_buffer("key_outlier_lower", params[1])
            elif "v_proj" in key:
                prefix = "value"
            else:
                continue
            self.register_buffer(f"{prefix}_centroids", params[2][0][:, 0])
            if norm:
                self.register_buffer(f"{prefix}_normscale", params[3])
                self.register_buffer(f"{prefix}_normoffset", params[4])

    def forward(self, x, encode_type: str):
        """Simulate quantization of a key or value tensor (quantize + dequantize).

        Entries outside the [lower, upper] outlier band are kept exactly. All
        other entries are mapped to [-1, 1], snapped to the nearest centroid and
        mapped back.

        Args:
            x: Tensor unpacked as ``(B, H, L, D)``. NOTE(review): presumably
                (batch, kv_heads, seq_len, head_dim) -- confirm.
            encode_type: ``"k"`` for keys or ``"v"`` for values.

        Returns:
            Tensor of shape ``(B, H, L, D)``.

        Raises:
            ValueError: If ``encode_type`` is neither ``"k"`` nor ``"v"``.
        """
        B, H, L, D = x.shape
        # (B, H, L, D) -> (B*L, H*D): one row per token, heads concatenated.
        x = x.transpose(1, 2).reshape(-1, H * D)

        if encode_type == "k":
            # Keys: calibrated thresholds; outlier fill value is the median
            # taken over rows (dim 0, i.e. across tokens).
            outlier_upper = self.key_outlier_upper
            outlier_lower = self.key_outlier_lower
            centroids = self.key_centroids
            norm_params = (self.key_normscale, self.key_normoffset) if self.norm else None
            median_dim = 0
        elif encode_type == "v":
            # Values: thresholds are per-row (per-token) quantiles, placing
            # (1 - sparse_threshold) / 2 of the entries in each tail.
            q = self.sparse_threshold + (1 - self.sparse_threshold) * 0.5
            x_f32 = x.float()  # torch.quantile requires a float/double input
            outlier_upper = torch.quantile(x_f32, q, dim=-1, keepdim=True).to(x.dtype)
            outlier_lower = torch.quantile(x_f32, 1 - q, dim=-1, keepdim=True).to(x.dtype)
            centroids = self.value_centroids
            norm_params = (self.value_normscale, self.value_normoffset) if self.norm else None
            median_dim = -1
        else:
            raise ValueError(f"encode_type must be 'k' or 'v', got {encode_type!r}")

        outlier_mask = (x >= outlier_upper) | (x <= outlier_lower)
        outliers = x[outlier_mask]

        # Temporarily replace outliers by the median so they do not leak into
        # the normalization below; they are restored exactly at the end.
        # ``.float()`` keeps the float32 promotion of the original blend.
        med = torch.median(x, dim=median_dim, keepdim=True).values
        x = torch.where(outlier_mask, med, x).float()

        # Affine-map the [lower, upper] band onto [-1, 1].
        offset = (outlier_upper + outlier_lower) / 2
        half_range = (outlier_upper - outlier_lower) / 2
        x = ((x - offset) / half_range).clamp_(-1, 1)

        # Nearest-centroid lookup: index selection followed by reconstruction.
        centroid_idx = (x[..., None] - centroids).abs().argmin(dim=-1)
        x = centroids[centroid_idx]
        if norm_params is not None:
            normscale, normoffset = norm_params
            x *= normscale
            x += normoffset
        # Undo the affine map.
        x *= half_range
        x += offset

        # Restore outliers at full precision.
        x[outlier_mask] = outliers
        return x.reshape(B, L, H, D).transpose(1, 2)

    def self_attn(self,
                  query_states: torch.Tensor,
                  past_key_value: DynamicCache,
                  layer_idx: int,
                  attention_mask: torch.Tensor,
                  scaling: float,
                  num_key_value_groups: int,
                  **kwargs):
        """Eager softmax attention over this layer's (partially quantized) cache.

        Quantized blocks of the key cache already carry RoPE, which
        ``update_cache`` applies after quantization. The trailing un-quantized
        keys are stored without RoPE and are rotated here, using the cached
        ``cos``/``sin`` whose length defines that tail.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has dims 1 and 2
            swapped and is made contiguous.
        """
        curr_cache = past_key_value[layer_idx]
        sin = curr_cache["sin"]
        cos = curr_cache["cos"]
        # ``update_cache`` stores ``[]`` when no un-quantized tokens remain.
        residual_length = sin.shape[-2] if torch.is_tensor(sin) else 0

        # Clone so rotating the tail does not modify the cached (pre-RoPE) keys.
        key_states = curr_cache["full_key_cache"].clone()
        if residual_length > 0:
            # Guard required: with 0, ``-0:`` would select (and re-rotate) the
            # whole cache.
            tail = key_states[..., -residual_length:, :]
            key_states[..., -residual_length:, :] = apply_rotary_pos_emb_single(tail, cos, sin)
        key_states = repeat_kv(key_states, num_key_value_groups)
        value_states = repeat_kv(curr_cache["full_value_cache"], num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Softmax in float32 for stability, then back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output, attn_weights

    def update_cache(self, past_key_value: NSNCache, layer_idx: int, is_prefill: bool):
        """Quantize completed blocks of the cache in place and apply RoPE to them.

        - Prefill: every full ``residual_size`` block is quantized. A trailing
          partial block stays full precision, and its ``sin``/``cos`` are kept.
        - Decode: the last ``residual_size`` tokens are quantized only when the
          cache length is an exact multiple of ``residual_size``.

        After quantizing, the remaining ``sin``/``cos`` (or ``[]`` if none) are
        written back, so they describe only the un-quantized tail.
        Does nothing while the cache is shorter than ``residual_size``.
        """
        residual_size = self.residual_size
        curr_cache = past_key_value[layer_idx]
        curr_key_cache = curr_cache["full_key_cache"]
        curr_value_cache = curr_cache["full_value_cache"]
        sin = curr_cache["sin"]
        cos = curr_cache["cos"]

        seq_len = curr_key_cache.shape[-2]
        if seq_len < residual_size:
            return
        if not is_prefill and seq_len % residual_size != 0:
            return  # decode step that does not complete a block

        remaining_slice = None
        if not is_prefill:
            target_slice = (..., slice(-residual_size, None), slice(None))
        elif seq_len % residual_size == 0:
            target_slice = (...,)  # the whole cache
        else:
            remainder = seq_len % residual_size
            target_slice = (..., slice(None, -remainder), slice(None))
            remaining_slice = (..., slice(-remainder, None), slice(None))

        assert seq_len == curr_value_cache.shape[-2], "the length of key and value cache should be identical"
        target_key_states = curr_key_cache[target_slice]
        target_value_states = curr_value_cache[target_slice]
        if remaining_slice is None:
            remaining_sin, remaining_cos = [], []
        else:
            remaining_sin = sin[remaining_slice].contiguous()
            remaining_cos = cos[remaining_slice].contiguous()
        target_sin = sin[target_slice].contiguous()
        target_cos = cos[target_slice].contiguous()

        quantized_key_states = self.forward(target_key_states, "k")
        quantized_value_states = self.forward(target_value_states, "v")

        # Keys are quantized before RoPE; rotate them now that they are final.
        quantized_key_states = apply_rotary_pos_emb_single(quantized_key_states, target_cos, target_sin)
        curr_key_cache[target_slice] = quantized_key_states
        curr_value_cache[target_slice] = quantized_value_states

        past_key_value.direct_update("full_key_cache", layer_idx, curr_key_cache)
        past_key_value.direct_update("full_value_cache", layer_idx, curr_value_cache)
        past_key_value.direct_update("sin", layer_idx, remaining_sin)
        past_key_value.direct_update("cos", layer_idx, remaining_cos)


