from typing import List, Tuple, Optional, Callable
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

from transformers.cache_utils import Cache, QuantizedCache, QuantizedCacheProcessor
from typing import Any, Callable, Optional, Union

from .quant_utils import fake_quant_token_wise

from .utils import repeat_kv, apply_rotary_pos_emb

from transformers.integrations.flash_attention import flash_attention_forward

# ---------------------
# Attention layer class
# ---------------------

class QuarotAttention(nn.Module):
    """
    QuaRot attention implementation: https://github.com/spcl/QuaRot/tree/main

    Wraps an existing HF attention module (e.g. ``LlamaAttention``) and reuses
    its q/k/v/o projection modules (plus ``q_norm``/``k_norm`` when the base
    module has them, as in Qwen3). After RoPE, queries and keys are
    right-multiplied by ``q_rotation`` / ``k_rotation``. Both start as ``None``
    and must be assigned tensors before ``forward`` is called.
    """

    def __init__(self,
                 base_attn: nn.Module,  # like LlamaAttention
                 mask_value: int = 0,
                 eps: float = 1e-12,
                 **kwargs: Any) -> None:
        """
        Args:
            base_attn: attention module to wrap. Its ``config``, ``layer_idx``,
                ``scaling``, ``is_causal``, ``head_dim``,
                ``num_key_value_groups`` and projection layers are read here.
            mask_value: stored as ``self.mask_value``.
            eps: stored as ``self.eps``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.base_config = getattr(base_attn, 'config', None)
        self.attention_type = "quarot_attention"
        self.config = self.base_config
        if self.base_config is not None:
            # Keep a plain-dict copy next to the config object in self.config.
            self.base_config = self.base_config.to_dict()
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = base_attn.layer_idx
        self.base_inference = False

        self.init_weights_(base_attn)

        # Softmax scale and causality are taken from the base module so that
        # flash_attention_forward behaves as it would for the original layer.
        self.scaling = base_attn.scaling
        self.is_causal = base_attn.is_causal

    def init_weights_(self, base_attn: nn.Module) -> None:
        """
        Initialize module layers, weights, positional dependencies, etc.
        from the original softmax attention layer (base_attn).

        The projection (and optional norm) modules are shared by reference
        with ``base_attn``; they are not copied.
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.config.hidden_size
        self.num_heads = base_attn.config.num_attention_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.config.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups

        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]

        # Reuse the original projection layers (same module objects and weights)
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        if hasattr(base_attn, "k_norm"):
            # used in qwen3 models
            self.k_norm = base_attn.k_norm
            self.q_norm = base_attn.q_norm
        else:
            self.k_norm = None
            self.q_norm = None

        # Must be set to tensors externally before forward() runs:
        # forward() passes them straight to torch.matmul.
        self.q_rotation = None
        self.k_rotation = None

        # Only some attention classes define this; copy it if present
        # (needed when using FA2 for ground-truth inference).
        if hasattr(base_attn, "_flash_attn_uses_top_left_mask"):
            self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask

    def repeat_tensors(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Apply ``repeat_kv`` with ``self.num_key_value_groups`` to each tensor independently."""
        return tuple(repeat_kv(t, self.num_key_value_groups) for t in tensors)

    def forward(self,
                hidden_states: torch.Tensor,
                position_embeddings: tuple[torch.Tensor, torch.Tensor],
                attention_mask: Optional[torch.Tensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                **kwargs,
               ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Rotated-QK attention, consistent with HuggingFace Transformers
        attention modules for easy use with their pretrained models.

        Args:
            hidden_states: input activations. The last dim is projected; the
                leading dims are kept as ``input_shape``. Presumably
                (batch, seq, hidden), since ``transpose(1, 2)`` moves the head
                dim forward. TODO confirm for all callers.
            position_embeddings: ``(cos, sin)``, used by
                ``apply_rotary_pos_emb`` and passed to the cache.
            attention_mask: forwarded to ``flash_attention_forward``.
            cache_position: forwarded to the cache in ``cache_kwargs``.
            past_key_value: optional cache. Rotated keys and unrotated values
                are written to it, and the keys/values it returns are used
                for attention.
            **kwargs: forwarded to ``flash_attention_forward``.

        Returns:
            ``(output, attn_weights)``. ``output`` is ``o_proj`` applied to the
            attention output reshaped to ``(*input_shape, -1)``.
            ``attn_weights`` is whatever ``flash_attention_forward`` returns.
        """
        input_shape = hidden_states.shape[:-1]
        # -1 lets the same view serve q (num_heads) and k/v (num_key_value_heads).
        hidden_shape = (*input_shape, -1, self.head_dim)

        if self.k_norm is not None:
            # qwen3 model
            query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        else:
            query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Rotate post-RoPE queries/keys; the cache therefore stores rotated keys.
        query_states = torch.matmul(query_states, self.q_rotation)
        key_states = torch.matmul(key_states, self.k_rotation)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin, "cos": cos, "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Use the base module's softmax scale rather than hard-coding
        # 1/sqrt(head_dim). The two agree for Llama/Qwen-style layers, but some
        # architectures configure a different scale. If the rotations are
        # orthogonal, q.k^T is unchanged, so the base scale still applies.
        attn_output, attn_weights = flash_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            **kwargs,
        )
        y_true = attn_output.reshape(*input_shape, -1).contiguous()
        y_true = self.o_proj(y_true)

        return y_true, attn_weights

class QuarotAttentionCache(QuantizedCache):
    """
    ``past_key_values`` container used with QuaRot attention.

    It is constructed with ``QuarotCacheProcessor`` as its cache processor.
    That processor keeps the most recent tokens in full precision and older
    tokens in a quantized store.
    Adapted from ``transformers.cache_utils.QuantizedCache``.
    """

    def __init__(self, **kwargs) -> None:
        # Skip QuantizedCache.__init__ on purpose: start at the class that
        # follows QuantizedCache in the MRO and hand it our own processor class.
        # NOTE(review): presumably QuantizedCache.__init__ would select its own
        # processor from a `backend` argument; confirm against the installed
        # transformers version.
        next_init = super(QuantizedCache, self).__init__
        next_init(cache_processor=QuarotCacheProcessor, **kwargs)


class QuarotCacheProcessor(QuantizedCacheProcessor):
    """
    KIVI-style quantized cache processor used by ``QuarotAttentionCache``.

    The most recent tokens of each layer live in full precision in the
    residual cache (``cache.layers[layer_idx].keys`` / ``.values``). Once that
    window holds at least ``residual_length`` tokens, it is passed through
    ``fake_quant_token_wise`` and appended to ``self._quantized_keys`` /
    ``self._quantized_values``. The scheme is adopted from KIVI:
    https://arxiv.org/abs/2402.02750
    """

    def __init__(
        self,
        cache: "Cache",
        backend: str = "kivi",
        nbits: int = 2,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
        compute_dtype: torch.dtype = torch.float16,
        device: str = "cpu",
    ) -> None:
        """Initialize the processor.

        Storage used by this class:
          - ``cache.layers[i].keys/values``: full-precision recent tokens.
          - ``self._quantized_keys/_quantized_values``: per-layer quantized
            history, or ``None`` while a layer has none.

        Raises:
            ValueError: if ``backend`` is not ``"kivi"``.
        """
        # Validate first so an unsupported backend fails before the parent
        # initializer runs.
        if backend != "kivi":
            raise ValueError(f"QuarotCacheProcessor only supports `kivi` backend, but got {backend}")
        super().__init__(
            cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device
        )

    def post_update(
        self,
        cache: "Cache",
        k: torch.Tensor,
        v: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply quantization after the residual cache has been updated.

        Args:
            cache: owning cache. ``cache.layers[layer_idx].keys/values`` is the
                full-precision residual window and may be replaced here.
            k, v: residual-cache contents for this layer after the update
                (sequence on dim -2).
            layer_idx: index of the layer being updated.
            cache_kwargs: unused; accepted for interface compatibility.

        Returns:
            ``(keys, values)`` to attend over. On prefill these are the
            unquantized ``k``/``v``. Afterwards they are the quantized history
            (if any) concatenated with ``k``/``v``.

        Raises:
            ValueError: if ``layer_idx`` exceeds ``len(cache)`` (skipped layers).
        """
        if len(cache) < layer_idx:
            raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")

        # NOTE(review): `_is_quantized_length_zero` is inherited. Presumably it
        # is True until this layer has an entry in `_quantized_keys`, which
        # happens on the first (prefill) call.
        if self._is_quantized_length_zero(layer_idx):
            return self._store_prefill(cache, k, v, layer_idx)
        return self._store_decode(cache, k, v, layer_idx)

    def _store_prefill(
        self, cache: "Cache", k: torch.Tensor, v: torch.Tensor, layer_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split the prompt into a quantized prefix and a full-precision tail; return the prompt unquantized."""
        seq_len = k.shape[-2]
        tail_len = seq_len % self.residual_length

        if tail_len != 0 and seq_len < self.residual_length:
            # Residual window not full yet: keep everything in full precision.
            # A None placeholder still gives this layer an entry in the store.
            self._quantized_keys.append(None)
            self._quantized_values.append(None)
            self.erased_length = 0
        else:
            n_quant = seq_len - tail_len
            if tail_len != 0:
                # Keep the trailing partial window in full precision.
                cache.layers[layer_idx].keys = k[:, :, n_quant:, :].contiguous()
                cache.layers[layer_idx].values = v[:, :, n_quant:, :].contiguous()
                k_prefix = k[:, :, :n_quant, :].contiguous()
                v_prefix = v[:, :, :n_quant, :].contiguous()
            else:
                # Prompt fills whole windows: quantize all of it, empty the residual cache.
                self._clear_residual(cache, layer_idx, k, v)
                k_prefix, v_prefix = k, v

            k_q, v_q = self._fake_quant_pair(k_prefix, v_prefix)
            self._quantized_keys.append(k_q)
            self._quantized_values.append(v_q)
            # Tokens moved out of the residual cache == length of the quantized store.
            self.erased_length = n_quant

        # Prefill attends over the original, unquantized prompt.
        return k, v

    def _store_decode(
        self, cache: "Cache", k: torch.Tensor, v: torch.Tensor, layer_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepend the quantized history to the residual window; flush the window once it is full."""
        k_q = self._quantized_keys[layer_idx]
        v_q = self._quantized_values[layer_idx]
        if k_q is None:
            keys_to_return, values_to_return = k, v
        else:
            keys_to_return = torch.cat([k_q, k], dim=-2)
            values_to_return = torch.cat([v_q, v], dim=-2)

        if k.shape[-2] >= self.residual_length:
            # Window full: quantize it, append it to the history, empty the residual cache.
            k_q_new, v_q_new = self._fake_quant_pair(k, v)
            if k_q is None:
                self._quantized_keys[layer_idx] = k_q_new
                self._quantized_values[layer_idx] = v_q_new
            else:
                self._quantized_keys[layer_idx] = torch.cat([k_q, k_q_new], dim=-2)
                self._quantized_values[layer_idx] = torch.cat([v_q, v_q_new], dim=-2)
            self._clear_residual(cache, layer_idx, k, v)
            # `erased_length` is one attribute written for every layer, and
            # presumably all layers flush on the same step since they see the
            # same tokens. A per-layer `+=` would count each flush once per
            # layer. Assigning the quantized-store length is idempotent.
            self.erased_length = self._quantized_keys[layer_idx].shape[-2]

        return keys_to_return, values_to_return

    def _fake_quant_pair(
        self, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run ``fake_quant_token_wise`` on k and v with this processor's group size and bit width."""
        return (
            fake_quant_token_wise(k, self.q_group_size, self.nbits),
            fake_quant_token_wise(v, self.q_group_size, self.nbits),
        )

    @staticmethod
    def _clear_residual(cache: "Cache", layer_idx: int, k: torch.Tensor, v: torch.Tensor) -> None:
        """Replace the layer's residual keys/values with empty tensors matching k/v dtype and device."""
        layer = cache.layers[layer_idx]
        layer.keys = torch.zeros(0, dtype=k.dtype, device=k.device)
        layer.values = torch.zeros(0, dtype=v.dtype, device=v.device)



