import math

from typing import Tuple
from typing import Optional

import nsn_tools
import torch
from torch import nn
from fast_hadamard_transform import hadamard_transform
from src.models.cache import DynamicCache
from src.models.cache import NSNCache
from src.utils import apply_rotary_pos_emb_single
from src.utils import repeat_kv
from src.utils import get_restore_function
from src.utils import get_scale_adjustment_function
from src.utils import get_weighted_sum_function
from src.utils import get_weighted_sum_residual_function
from src.utils import get_dot_product_fused_function
from src.utils import get_dot_product_fused_residual_function
from src.utils import pack_idx_signs
from src.utils import pseudo_quantize_tensor_4bit
from src.utils import quantize_tensor_4bit
from src.utils import dequantize_tensor_4bit

from src.quantizers.base_quantizer import BaseQuantizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from src.recorder import recorder


class NSNQuantizer(BaseQuantizer):
    """Vector quantizer for attention key/value tensors.

    Every tensor is first preprocessed by :meth:`nsn`:

    1. divide each token by its RMS;
    2. subtract a per-window mean token;
    3. divide by the per-token RMS again.

    Each 8-element sub-vector of the result is then matched against a
    codebook and packed by :meth:`vq`. The first RMS values and the window
    means are stored quantized via ``quantize_tensor_4bit``.

    In the KV cache (:meth:`update_cache`), tokens stay in full precision
    until at least ``residual_size`` of them are present, and are then
    quantized in bulk. :meth:`self_attn` attends over both the quantized
    part and the full-precision part.
    """

    # Class-level flag; not read by the methods below (presumably consulted
    # by the surrounding attention code -- TODO confirm).
    post_rope: bool = False

    # Side-information entries stored in the cache under "key_<field>" and
    # "value_<field>", listed in the positional order the fused kernels take.
    _SIDE_INFO_FIELDS = (
        "norm_idx", "norm_scale", "norm_offset",
        "mean_idx", "mean_scale", "mean_offset",
        "norm2",
    )

    def __init__(self,
                 n_bits: int,
                 codebook_path: str,
                 residual_size: int,
                 window_size: int,
                 hadamard: bool) -> None:
        """Load the codebook and store the quantizer configuration.

        Args:
            n_bits: Code-width selector. :meth:`vq` supports 1 (packed width
                ``D // 32``) and 2 (packed width ``D // 16``).
            codebook_path: Path to the codebook tensor, loaded with
                ``torch.load(..., weights_only=True)``.
            residual_size: Number of tokens :meth:`update_cache` quantizes
                at a time. Must be a multiple of ``window_size``.
            window_size: Number of consecutive tokens that share one
                quantization group for their RMS values and one mean vector
                (see :meth:`nsn`).
            hadamard: Stored as ``self.hadamard``. It is not read by the
                methods of this class.

        Raises:
            AssertionError: if ``residual_size`` is not divisible by
                ``window_size``.
        """
        super().__init__()
        # Validate before touching the filesystem.
        assert residual_size % window_size == 0, "Residual size must be divisible by group size."
        self.n_bits = n_bits

        # Load the codebook once. Each helper gets its own copy in case
        # either one modifies its input in place (not visible from here).
        raw_codebook = torch.load(codebook_path, weights_only=True)
        codebook_idx, codebook_scale, codebook_offset = quantize_tensor_4bit(raw_codebook.clone(), 8)
        # "codebook" is the quantize->dequantize (pseudo-quantized) codebook
        # used by the torch reference path (vq_torch / vdq_torch). The
        # idx/scale/offset triple feeds the packed kernels.
        self.register_buffer("codebook", pseudo_quantize_tensor_4bit(raw_codebook.clone(), 8))
        self.register_buffer("codebook_idx", codebook_idx)
        self.register_buffer("codebook_scale", codebook_scale)
        self.register_buffer("codebook_offset", codebook_offset)
        self.residual_size = residual_size
        self.window_size = window_size
        self.hadamard = hadamard

    @staticmethod
    def _apply_hadamard(x: torch.Tensor) -> torch.Tensor:
        """Apply ``hadamard_transform`` over the last dim with scale ``1/sqrt(D)``."""
        return hadamard_transform(x, 1/math.sqrt(x.shape[-1]))

    def nsn(self, x: torch.Tensor, encode_type: str):
        """Normalize, mean-center and renormalize ``x`` ahead of vector quantization.

        Steps:
          1. Divide each token by its RMS (L2 norm / sqrt(D)). The RMS values
             are quantized with ``quantize_tensor_4bit`` in groups of
             ``window_size`` tokens. The *dequantized* values are what ``x``
             is divided by, so the stored side info reproduces the scaling.
          2. Subtract the mean token of each window of ``window_size``
             consecutive tokens. The mean is also quantized, with second
             argument 32 (presumably a group size matching head_dim --
             TODO confirm).
          3. Divide each token by its RMS again. This second RMS is returned
             unquantized as ``x_norm2``.

        NOTE(review): no epsilon is applied. A zero divisor in step 1 or 3
        yields inf/NaN.

        Args:
            x: Rank-4 tensor unpacked as ``(B, H, L, D)``. ``L`` must be a
                multiple of ``window_size`` (required by the reshapes).
            encode_type: Not used by this method. Callers pass "k" or "v".

        Returns:
            ``(x, x_norm, (x_norm_idx, x_norm_scale, x_norm_offset), x_mean,
            (x_mean_idx, x_mean_scale, x_mean_offset), x_norm2)``, where:

            - ``x`` is the preprocessed tensor, shape ``(B, H, L, D)``.
            - ``x_norm`` is the dequantized first RMS, shape ``(B, H, L, 1)``.
            - ``x_mean`` is the dequantized window mean, computed with
              ``keepdim=True`` over a ``(B, H, L // window_size,
              window_size, D)`` view.
            - ``x_norm2`` is the second RMS, shape ``(B, H, L, 1)``.
        """
        B, H, L, D = x.shape
        n_windows = L // self.window_size

        # Step 1: per-token RMS, quantized per window of tokens.
        x_norm = x.norm(p=2, dim=-1, keepdim=True) / math.sqrt(D)
        x_norm_idx, x_norm_scale, x_norm_offset = quantize_tensor_4bit(
            x_norm.reshape(B, H, n_windows, self.window_size), self.window_size)
        x_norm = dequantize_tensor_4bit(x_norm_idx, x_norm_scale, x_norm_offset).reshape(B, H, L, 1)
        x = x / x_norm

        # Step 2: subtract the (quantized) mean token of each window.
        x = x.reshape(B, H, n_windows, self.window_size, D)
        x_mean_idx, x_mean_scale, x_mean_offset = quantize_tensor_4bit(x.mean(dim=-2, keepdim=True), 32)
        x_mean = dequantize_tensor_4bit(x_mean_idx, x_mean_scale, x_mean_offset)
        x = (x - x_mean).reshape(B, H, L, D)

        # Step 3: renormalize each centered token.
        x_norm2 = x.norm(p=2, dim=-1, keepdim=True) / math.sqrt(D)
        x = x / x_norm2
        return x, x_norm, (x_norm_idx, x_norm_scale, x_norm_offset), x_mean, (x_mean_idx, x_mean_scale, x_mean_offset), x_norm2

    @torch.no_grad()
    def quantize(self, x: torch.Tensor, encode_type: str):
        """Preprocess ``x`` with :meth:`nsn`, vector-quantize it, and run the scale-adjustment kernel.

        No RoPE or Hadamard transform is applied here (compare the key path
        in :meth:`update_cache`).

        Args:
            x: Rank-4 tensor, as required by :meth:`nsn`.
            encode_type: Forwarded to :meth:`nsn`.

        Returns:
            ``(packed, x_norm, norm_q, x_mean, mean_q, x_norm2)``. ``packed``
            comes from :meth:`vq`; the remaining entries are as returned by
            :meth:`nsn`.
        """
        nsn_res, x_norm, norm_q, x_mean, mean_q, x_norm2 = self.nsn(x, encode_type)
        packed = self.vq(nsn_res)
        # The kernel's return value is discarded. Presumably it updates
        # x_norm2 in place to fit the chosen codes (TODO confirm). x_norm2 is
        # freshly allocated by nsn, so .contiguous() returns the same tensor.
        get_scale_adjustment_function(nsn_res.shape[-1], self.n_bits)(
            packed, self.codebook_idx, self.codebook_scale, self.codebook_offset,
            x_norm2.contiguous(), nsn_res.contiguous())
        return packed, x_norm, norm_q, x_mean, mean_q, x_norm2

    def vq(self, x: torch.FloatTensor) -> torch.IntTensor:
        """Match every 8-element sub-vector of ``x`` against the codebook and return packed codes.

        Args:
            x: Tensor whose total size is divisible by 8.

        Returns:
            Packed codes shaped ``x.shape[:-1] + (x.shape[-1] // 16,)`` for
            ``n_bits == 2``, or ``x.shape[:-1] + (x.shape[-1] // 32,)`` for
            ``n_bits == 1``.

        Raises:
            ValueError: if ``self.n_bits`` is neither 1 nor 2. (Previously
                ``None`` was returned silently.)
        """
        original_shape = x.shape
        sub_vectors = x.reshape(-1, 8)
        if self.n_bits == 2:
            quantized = nsn_tools.dist_argmin_half_packed_dq(
                sub_vectors, self.codebook_idx, self.codebook_scale, self.codebook_offset)
            elems_per_word = 16
        elif self.n_bits == 1:
            quantized = nsn_tools.dist_argmin_half_packed_dq_1bit(
                sub_vectors, self.codebook_idx, self.codebook_scale, self.codebook_offset)
            elems_per_word = 32
        else:
            raise ValueError(f"NSNQuantizer.vq supports n_bits of 1 or 2, got {self.n_bits}")
        return quantized.reshape(*original_shape[:-1], original_shape[-1] // elems_per_word)

    def vq_torch(self, x: torch.FloatTensor) -> Tuple[torch.IntTensor, torch.BoolTensor]:
        """Reference (unpacked) vector quantization against ``self.codebook``.

        For ``n_bits == 2``, ``abs(x)`` is matched and the signs are returned
        separately as ``x > 0``. For ``n_bits == 1``, ``x`` is matched as-is
        and the sign mask is all ``True``. For any other value, ``x`` is
        matched as-is and the mask is ``x > 0``.

        Returns:
            ``(indices, sign_mask)``. ``indices`` has shape
            ``x.shape[:-1] + (x.shape[-1] // 8,)``. ``sign_mask`` is a
            boolean tensor shaped like ``x``.
        """
        ori_x = x
        if self.n_bits == 2:
            x = abs(x)
        original_shape = x.shape
        x = x.reshape(-1, 8)
        quantized = nsn_tools.dist_argmin_half(x, self.codebook)
        return quantized.reshape(*original_shape[:-1], original_shape[-1]//8), \
            torch.ones_like((ori_x > 0)) if self.n_bits == 1 else ori_x > 0

    def vdq_torch(self, x: torch.IntTensor, sign_mask: torch.BoolTensor) -> torch.FloatTensor:
        """Inverse of :meth:`vq_torch`.

        Looks up codebook rows, flattens the last two dims, and negates the
        entries where ``sign_mask`` is False.
        """
        x = self.codebook[x.int()]
        x = x.view(*x.shape[:-2], -1)
        x = torch.where(sign_mask, x, -x)
        return x

    def dequantize(self,
                   packed: torch.IntTensor,
                   x_norm: torch.FloatTensor,
                   x_mean: torch.FloatTensor,
                   x_norm2: torch.FloatTensor) -> torch.FloatTensor:
        """Reconstruct a tensor from packed codes and NSN side info using the restore kernel.

        The kernel is selected by ``(window_size, x_mean.shape[-1], n_bits)``.
        Callers may pass a zero ``x_mean`` to skip mean restoration (see
        :meth:`forward_quant`).
        """
        return get_restore_function(self.window_size, x_mean.shape[-1], self.n_bits)(
            packed, x_norm, x_mean, x_norm2,
            self.codebook_idx, self.codebook_scale, self.codebook_offset)

    @torch.no_grad()
    def forward(self, x: torch.FloatTensor, encode_type: str) -> torch.FloatTensor:
        """Simulate vector quantization to measure errors.

        Args:
            x: Rank-4 tensor to be quantized (see :meth:`nsn`).
            encode_type: Forwarded to :meth:`quantize`.

        Returns:
            The quantize -> dequantize reconstruction of ``x``.
        """
        # quantize() returns six values; the 4-bit side-info triples are not
        # needed for reconstruction. (Unpacking into four names always raised.)
        packed, x_norm, _, x_mean, _, x_norm2 = self.quantize(x, encode_type)
        return self.dequantize(packed, x_norm, x_mean, x_norm2)

    def forward_quant(self,
                      query_states: torch.FloatTensor,
                      key_states: torch.FloatTensor,
                      value_states: torch.FloatTensor,
                      position_embeddings: Tuple[torch.FloatTensor, ...],
                      past_key_value: Optional[DynamicCache],
                      cache_position: torch.LongTensor):
        """Fake-quantize keys and values and return full-precision reconstructions.

        Keys go through :meth:`nsn`, RoPE, the Hadamard transform and VQ.
        They are dequantized without their mean. The mean is then rotated,
        Hadamard-transformed, scaled by ``key_norm`` and added back before
        the final inverse Hadamard. Values go through :meth:`nsn` and VQ
        only.

        Args:
            position_embeddings: Unpacked as ``(cos, sin, _, _)``. Only the
                first two entries are used.
            past_key_value, cache_position: Not used by this method.

        Returns:
            ``(query_states, key_states, value_states, True)``. The constant
            fourth element is presumably a flag for the caller (TODO confirm).
        """
        key_states, key_norm, _, key_mean, _, key_norm2 = self.nsn(key_states, "k")
        value_states, value_norm, _, value_mean, _, value_norm2 = self.nsn(value_states, "v")

        cos, sin, _, _ = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # NOTE(review): query_states is transformed here and transformed back
        # below without being used in between. With 1/sqrt(D) scaling the two
        # applications presumably cancel (up to rounding).
        query_states = self._apply_hadamard(query_states)
        key_states = self._apply_hadamard(key_states)

        key_packed = self.vq(key_states)
        value_packed = self.vq(value_states)

        get_scale_adjustment_function(key_states.shape[-1], self.n_bits)(
            key_packed, self.codebook_idx, self.codebook_scale, self.codebook_offset, key_norm2, key_states)
        get_scale_adjustment_function(value_states.shape[-1], self.n_bits)(
            value_packed, self.codebook_idx, self.codebook_scale, self.codebook_offset, value_norm2, value_states)

        # Keys are reconstructed without their mean. The mean was subtracted
        # before RoPE, so it must be rotated and transformed before being
        # added back.
        key_states = self.dequantize(key_packed, key_norm, torch.zeros_like(key_mean), key_norm2)
        value_states = self.dequantize(value_packed, value_norm, value_mean, value_norm2)

        # Broadcast each window's mean token back to every token in the window.
        key_mean = key_mean.repeat(1, 1, 1, self.window_size, 1).reshape(*key_states.shape)
        key_mean = apply_rotary_pos_emb_single(key_mean, cos, sin)
        key_mean = self._apply_hadamard(key_mean)

        key_states += key_mean * key_norm

        query_states = self._apply_hadamard(query_states)
        key_states = self._apply_hadamard(key_states)
        return query_states, key_states, value_states, True

    @torch.no_grad()
    def self_attn(self, query_states: torch.Tensor,
                  past_key_value: NSNCache,
                  layer_idx: int,
                  attention_mask: torch.Tensor,
                  scaling: float,
                  num_key_value_groups: int,
                  **kwargs):
        """Attention over a cache holding a quantized prefix and a full-precision tail.

        Full-precision keys are cached without RoPE. They are rotated here
        with the cached ``cos``/``sin`` when no quantized keys exist;
        otherwise the fused kernels receive ``cos``/``sin``.

        Args:
            query_states: Query tensor. It is multiplied directly against
                ``repeat_kv`` keys, so it is presumably ``(B, H, Lq, D)`` --
                TODO confirm.
            past_key_value: Cache indexed by ``layer_idx`` that returns a
                mapping of named entries.
            attention_mask: Optional additive mask, cropped to the key length.
            scaling: Multiplier applied to the logits before masking.
            num_key_value_groups: Grouped-query repeat factor.
            **kwargs: ``inv_freq`` and ``offset`` are required once the
                quantized cache is non-empty.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has dims 1 and 2
            swapped relative to the kernel output and is made contiguous.
            ``attn_weights`` are the softmax probabilities in
            ``query_states.dtype``.
        """
        had_query_states = self._apply_hadamard(query_states)
        curr_cache = past_key_value[layer_idx]

        # Full-precision tail of the cache (not yet quantized).
        full_key_cache = curr_cache["full_key_cache"]
        full_value_cache = curr_cache["full_value_cache"]

        # Quantized prefix: packed codes plus side info, in kernel argument order.
        quantized_key_cache = curr_cache["quantized_key_cache"]
        key_side_info = tuple(curr_cache[f"key_{field}"] for field in self._SIDE_INFO_FIELDS)
        quantized_value_cache = curr_cache["quantized_value_cache"]
        value_side_info = tuple(curr_cache[f"value_{field}"] for field in self._SIDE_INFO_FIELDS)

        sin = curr_cache["sin"]
        cos = curr_cache["cos"]

        head_dim = query_states.shape[-1]

        # len() is used because empty cache slots may be plain lists
        # (update_cache stores [] for an empty remainder).
        if len(quantized_key_cache) == 0:
            full_key_cache = apply_rotary_pos_emb_single(full_key_cache, cos, sin)
            full_key_cache = repeat_kv(full_key_cache, num_key_value_groups)
            attn_weights = torch.matmul(query_states, full_key_cache.transpose(2, 3))
        else:
            inv_freq = kwargs["inv_freq"]
            offset = kwargs["offset"]
            if len(full_key_cache) > 0:
                dot_product_fused_residual = get_dot_product_fused_residual_function(
                    self.window_size, head_dim, self.n_bits)
                attn_weights = dot_product_fused_residual(
                    quantized_key_cache, *key_side_info,
                    self.codebook_idx, self.codebook_scale, self.codebook_offset,
                    inv_freq, offset, cos, sin, full_key_cache,
                    had_query_states, query_states, num_key_value_groups)
            else:
                dot_product_fused = get_dot_product_fused_function(
                    self.window_size, head_dim, self.n_bits)
                attn_weights = dot_product_fused(
                    quantized_key_cache, *key_side_info,
                    self.codebook_idx, self.codebook_scale, self.codebook_offset,
                    inv_freq, offset,
                    had_query_states, query_states, num_key_value_groups)

        attn_weights *= scaling

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, :attn_weights.shape[-1]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        full_length = full_value_cache.shape[-2] if len(full_value_cache) > 0 else 0
        # Gate on and measure the same (value) cache for consistency.
        quantized_length = quantized_value_cache.shape[-2] if len(quantized_value_cache) > 0 else 0

        if quantized_length == 0:
            attn_output = torch.matmul(attn_weights, repeat_kv(full_value_cache, num_key_value_groups))
        elif full_length == 0:
            attn_output = get_weighted_sum_function(self.window_size, head_dim, self.n_bits)(
                quantized_value_cache, *value_side_info,
                self.codebook_idx, self.codebook_scale, self.codebook_offset,
                attn_weights, num_key_value_groups)
        else:
            attn_output = get_weighted_sum_residual_function(self.window_size, head_dim, self.n_bits)(
                quantized_value_cache, *value_side_info,
                full_value_cache,
                self.codebook_idx, self.codebook_scale, self.codebook_offset,
                attn_weights, num_key_value_groups)

        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output, attn_weights

    @staticmethod
    def _append_quantized(past_key_value, layer_idx, prefix, packed, norm_q, mean_q, norm2):
        """Append one quantized chunk for ``prefix`` ("key" or "value") to the cache.

        Codes, RMS side info and ``norm2`` are concatenated along dim -2.
        Mean side info is concatenated along dim -3, matching its extra
        keepdim axis from :meth:`nsn`.
        """
        past_key_value.concat_by_dim(f"quantized_{prefix}_cache", layer_idx, packed, -2)
        for suffix, tensor in zip(("idx", "scale", "offset"), norm_q):
            past_key_value.concat_by_dim(f"{prefix}_norm_{suffix}", layer_idx, tensor, -2)
        for suffix, tensor in zip(("idx", "scale", "offset"), mean_q):
            past_key_value.concat_by_dim(f"{prefix}_mean_{suffix}", layer_idx, tensor, -3)
        past_key_value.concat_by_dim(f"{prefix}_norm2", layer_idx, norm2, -2)

    def update_cache(self, past_key_value: NSNCache, layer_idx: int, is_prefill: bool):
        """Move full-precision tokens into the quantized cache in multiples of ``residual_size``.

        The method does nothing while fewer than ``residual_size`` tokens are
        held in full precision.

        - Prefill: quantizes the largest multiple of ``residual_size``
          leading tokens and keeps the remaining trailing tokens (and their
          ``sin``/``cos``) in full precision.
        - Otherwise: quantizes only when the full-precision length is an
          exact multiple of ``residual_size``, and then quantizes all of it.

        Keys are processed by :meth:`nsn`, RoPE (with the cached
        ``cos``/``sin``), the Hadamard transform and :meth:`vq`. Values are
        processed by :meth:`quantize`.

        Raises:
            AssertionError: if the key-cache, ``sin`` and ``cos`` lengths
                differ.
        """
        residual_size = self.residual_size
        curr_cache = past_key_value[layer_idx]
        full_key_cache = curr_cache["full_key_cache"]
        full_value_cache = curr_cache["full_value_cache"]
        sin = curr_cache["sin"]
        cos = curr_cache["cos"]

        assert full_key_cache.shape[-2] == sin.shape[-2] == cos.shape[-2], "The length of sin and cos should be the same as the key cache"
        length = full_key_cache.shape[-2]
        if length < residual_size:
            return

        remainder = length % residual_size
        if remainder != 0 and not is_prefill:
            # While decoding, wait until the buffer fills to an exact multiple.
            return

        if remainder == 0:
            target_slice = slice(None)
            remaining_slice = None
        else:
            target_slice = (..., slice(None, -remainder), slice(None))
            remaining_slice = (..., slice(-remainder, None), slice(None))

        def tail(t):
            # A fresh [] per entry when nothing remains in full precision.
            return t[remaining_slice].contiguous() if remaining_slice is not None else []

        remaining_key_cache = tail(full_key_cache)
        remaining_value_cache = tail(full_value_cache)
        remaining_sin = tail(sin)
        remaining_cos = tail(cos)

        target_key_cache = full_key_cache[target_slice].contiguous()
        target_value_cache = full_value_cache[target_slice].contiguous()
        target_sin = sin[target_slice].contiguous()
        target_cos = cos[target_slice].contiguous()

        # Keys: NSN -> RoPE -> Hadamard -> VQ (+ scale adjustment).
        key_cache, _, key_norm_q, _, key_mean_q, key_norm2 = self.nsn(target_key_cache, "k")
        key_cache = apply_rotary_pos_emb_single(key_cache, target_cos, target_sin)
        key_cache = self._apply_hadamard(key_cache)
        quantized_key_cache = self.vq(key_cache)
        get_scale_adjustment_function(key_cache.shape[-1], self.n_bits)(
            quantized_key_cache, self.codebook_idx, self.codebook_scale, self.codebook_offset, key_norm2, key_cache)

        # Values: NSN -> VQ (+ scale adjustment).
        quantized_value_cache, _, value_norm_q, _, value_mean_q, value_norm2 = self.quantize(target_value_cache, "v")

        self._append_quantized(past_key_value, layer_idx, "key",
                               quantized_key_cache, key_norm_q, key_mean_q, key_norm2)
        self._append_quantized(past_key_value, layer_idx, "value",
                               quantized_value_cache, value_norm_q, value_mean_q, value_norm2)

        past_key_value.direct_update("full_key_cache", layer_idx, remaining_key_cache)
        past_key_value.direct_update("full_value_cache", layer_idx, remaining_value_cache)
        past_key_value.direct_update("sin", layer_idx, remaining_sin)
        past_key_value.direct_update("cos", layer_idx, remaining_cos)

