import math
from typing import Tuple
from typing import Optional

import torch
from torch import nn
from fast_hadamard_transform import hadamard_transform
from src.models.cache import DynamicCache
from src.utils import repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


class BaseQuantizer(nn.Module):
    """Identity ("no-op") key/value quantizer that subclasses can specialise.

    ``quantize`` and ``dequantize`` both return their input unchanged here, so
    this base class only provides the surrounding plumbing: a simulated
    quantize -> dequantize round-trip (``forward``), the pre-attention
    transform of fresh states (``forward_quant``), plain softmax attention
    over the cached states (``self_attn``) and in-place re-quantization of
    cache windows (``update_cache``).

    An optional ``hadamard`` attribute is consulted but never defined in this
    class; when it is present and truthy, keys (and queries in
    ``forward_quant``) are passed through ``hadamard_transform`` before and
    after the quantization round-trip.
    """

    # Window length (along the second-to-last cache axis) used by
    # ``update_cache`` to decide which cached positions to re-quantize.
    residual_size: int = 1
    # True: rotary embeddings are applied before the quantization round-trip
    # in ``forward_quant``; False: they are applied after it.
    post_rope: bool = True

    def quantize(self, x: torch.Tensor, encode_type: str):
        """Return ``x`` unchanged; ``encode_type`` is ignored in the base class."""
        return x

    def dequantize(self, x: torch.Tensor, *args):
        """Return ``x`` unchanged; extra positional arguments are ignored."""
        return x

    def forward(self, x, encode_type: str):
        """Run ``x`` through ``quantize`` and then ``dequantize``.

        ``encode_type`` is forwarded to ``quantize`` only; this class passes
        ``"k"`` for keys and ``"v"`` for values.
        """
        return self.dequantize(self.quantize(x, encode_type))

    def _hadamard_enabled(self) -> bool:
        """Tell whether an optional, truthy ``hadamard`` attribute is set."""
        return bool(getattr(self, "hadamard", False))

    @staticmethod
    def _hadamard(x: torch.Tensor) -> torch.Tensor:
        """Apply ``hadamard_transform`` scaled by ``1 / sqrt(x.shape[-1])``."""
        return hadamard_transform(x, 1 / math.sqrt(x.shape[-1]))

    def forward_quant(self,
                      query_states: torch.FloatTensor,
                      key_states: torch.FloatTensor,
                      value_states: torch.FloatTensor,
                      position_embeddings: Tuple[torch.FloatTensor, ...],
                      past_key_value: Optional[DynamicCache],
                      cache_position: torch.LongTensor):
        """Apply rotary embeddings and the quantization round-trip to new states.

        Args:
            query_states: Query tensor; never quantized, only rotated and
                (optionally) Hadamard-transformed.
            key_states: Key tensor, passed through ``forward(..., "k")``.
            value_states: Value tensor, passed through ``forward(..., "v")``.
            position_embeddings: Sequence whose first two entries are the
                ``cos`` and ``sin`` tensors handed to ``apply_rotary_pos_emb``;
                any further entries are ignored.
            past_key_value: Unused by this implementation.
            cache_position: Unused by this implementation.

        Returns:
            ``(query_states, key_states, value_states, True)``. The trailing
            flag is always the literal ``True``; its meaning is defined by the
            caller and is not visible in this class.
        """
        # Accept both a plain (cos, sin) pair and longer tuples carrying extra
        # entries; only the first two are needed here.
        cos, sin, *_ = position_embeddings

        if self.post_rope:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if self._hadamard_enabled():
            query_states = self._hadamard(query_states)
            key_states = self._hadamard(key_states)

        key_states = self.forward(key_states, "k")
        value_states = self.forward(value_states, "v")

        # NOTE(review): queries are transformed a second time without any
        # quantization in between, which looks like a pure round-trip;
        # kept as-is to preserve numerics -- confirm whether it is intended.
        if self._hadamard_enabled():
            query_states = self._hadamard(query_states)
            key_states = self._hadamard(key_states)

        if not self.post_rope:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        return query_states, key_states, value_states, True

    def self_attn(self,
                  query_states: torch.Tensor,
                  past_key_value: DynamicCache,
                  layer_idx: int,
                  attention_mask: torch.Tensor,
                  scaling: float,
                  num_key_value_groups: int,
                  **kwargs):
        """Compute softmax attention of the queries over one cached layer.

        Args:
            query_states: Query tensor multiplied against the cached keys.
            past_key_value: Cache; ``past_key_value[layer_idx]`` must yield a
                ``(key, value)`` pair.
            layer_idx: Index of the layer to read from the cache.
            attention_mask: Optional additive mask, indexed as a 4-D tensor
                and cropped on its last axis to the key length.
            scaling: Factor applied to the raw attention scores.
            num_key_value_groups: Forwarded to ``repeat_kv`` (presumably the
                number of query heads per KV head -- confirm in src.utils).
            **kwargs: Accepted and ignored.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has axes 1
            and 2 swapped relative to the matmul result and is contiguous.
        """
        key_states, value_states = past_key_value[layer_idx]

        key_states = repeat_kv(key_states, num_key_value_groups)
        value_states = repeat_kv(value_states, num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Softmax runs in float32 and is cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output, attn_weights

    def update_cache(self, past_key_value: DynamicCache, layer_idx: int, is_prefill: bool):
        """Re-quantize a window of one layer's cached keys/values in place.

        The window is taken along the second-to-last axis of the cache:

        * prefill: every position up to the largest multiple of
          ``residual_size`` (a trailing partial window is left untouched);
        * otherwise: the last ``residual_size`` positions, but only when the
          cached length is an exact multiple of ``residual_size``.

        Nothing happens when the cache holds fewer than ``residual_size``
        positions.

        Args:
            past_key_value: Cache; ``past_key_value[layer_idx]`` must yield
                the ``(key, value)`` tensors, which are modified in place.
            layer_idx: Index of the layer to update.
            is_prefill: Selects the prefill or the incremental window rule.

        Raises:
            ValueError: If ``residual_size`` is smaller than 1.
        """
        residual_size = getattr(self, "residual_size", 1)
        if residual_size < 1:
            # A zero size would divide by zero below and a negative one would
            # silently select a meaningless window.
            raise ValueError(f"residual_size must be a positive integer, got {residual_size}")

        curr_key_cache, curr_value_cache = past_key_value[layer_idx]
        seq_len = curr_key_cache.shape[-2]
        if seq_len < residual_size:
            return

        remainder = seq_len % residual_size
        if is_prefill:
            window = slice(0, seq_len - remainder)
        elif remainder == 0:
            window = slice(seq_len - residual_size, seq_len)
        else:
            # Incremental step with an incomplete window: nothing to do yet.
            return

        assert seq_len == curr_value_cache.shape[-2], "the length of key and value cache should be identical"
        target_key_states = curr_key_cache[..., window, :]
        target_value_states = curr_value_cache[..., window, :]

        # Only keys are Hadamard-transformed around the round-trip.
        if self._hadamard_enabled():
            target_key_states = self._hadamard(target_key_states)
        quantized_key_states = self.forward(target_key_states, "k")
        quantized_value_states = self.forward(target_value_states, "v")
        if self._hadamard_enabled():
            quantized_key_states = self._hadamard(quantized_key_states)

        curr_key_cache[..., window, :] = quantized_key_states
        curr_value_cache[..., window, :] = quantized_value_states
