import warnings
from typing import Optional, Callable, Dict, Any

import torch
import torch.nn as nn

from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, eager_attention_forward
from transformers.activations import ACT2FN

from quantization.qlinear import QLinear
from quantization.quantizer import Quantizer
from quantization.transforms.transforms import BaseTransform, IdentityTransform


class QuantizedMistralMLP(nn.Module):
    """Quantized Mistral MLP (SwiGLU: gate_proj, up_proj, down_proj) with optional input transforms.

    ``gate_proj`` and ``up_proj`` share a single activation quantizer and the
    ``gate_up_in_transform``; ``down_proj`` gets its own activation quantizer and
    ``down_in_transform``. Each projection gets its own weight quantizer.

    Args:
        config: Mistral config; ``hidden_size``, ``intermediate_size``,
            ``hidden_act`` and the optional ``mlp_bias`` are read.
        layer_idx: Decoder-layer index (accepted for interface parity; unused here).
        weight_quantizer_kwargs: Keyword arguments used to build one ``Quantizer``
            per projection weight. Falsy means no weight quantizer.
        act_quantizer_kwargs: Keyword arguments used to build the activation
            ``Quantizer``s. Falsy means no activation quantizer.
        gate_up_in_transform: Transform handed to ``gate_proj``/``up_proj``.
            ``None`` (default) creates a fresh ``IdentityTransform`` for this layer.
        down_in_transform: Transform applied to the intermediate activation and
            handed to ``down_proj``. ``None`` (default) creates a fresh
            ``IdentityTransform`` for this layer.
        norm_gamma: Optional tensor forwarded to ``gate_proj``/``up_proj``
            (presumably the preceding norm's weight -- confirm in ``QLinear``).
    """

    def __init__(
        self,
        config: MistralConfig,
        layer_idx: int,
        weight_quantizer_kwargs: Dict[str, Any] | None = None,
        act_quantizer_kwargs: Dict[str, Any] | None = None,
        gate_up_in_transform: BaseTransform | None = None,
        down_in_transform: BaseTransform | None = None,
        norm_gamma: torch.Tensor | None = None,
    ):
        super().__init__()

        # A default instance in the signature would be created once at import
        # time and shared by every layer; build a fresh identity per layer instead.
        if gate_up_in_transform is None:
            gate_up_in_transform = IdentityTransform()
        if down_in_transform is None:
            down_in_transform = IdentityTransform()

        def _new_weight_quantizer():
            # Fresh quantizer per call so no two projections share weight state.
            return Quantizer(**weight_quantizer_kwargs) if weight_quantizer_kwargs else None

        def _new_act_quantizer():
            return Quantizer(**act_quantizer_kwargs) if act_quantizer_kwargs else None

        # gate & up consume the same input activation, so they share one quantizer
        # (same layout as the project's LLaMA/Qwen quantized MLPs).
        gate_up_act_quantizer = _new_act_quantizer()

        mlp_bias = getattr(config, "mlp_bias", False)

        self.up_proj = QLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=mlp_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=gate_up_act_quantizer,
            norm_gamma=norm_gamma,
        )
        self.gate_proj = QLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=mlp_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=gate_up_act_quantizer,
            norm_gamma=norm_gamma,
        )
        self.down_proj = QLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=mlp_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=_new_act_quantizer(),
        )

        self.act_fn = ACT2FN[config.hidden_act]
        self.gate_up_in_transform = gate_up_in_transform
        self.down_in_transform = down_in_transform

        # Flipped to False once transforms have been folded by fix_parametrization().
        self._train_mode = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``down_proj(T(act_fn(gate_proj(x)) * up_proj(x)))``, where T is ``down_in_transform``."""
        up = self.up_proj(x, self.gate_up_in_transform)
        gate = self.gate_proj(x, self.gate_up_in_transform)

        x = self.act_fn(gate) * up

        # The intermediate activation is transformed here; down_proj also receives
        # the transform (presumably to apply the matching weight-side transform --
        # confirm in QLinear).
        x = self.down_in_transform(x)
        x = self.down_proj(x, self.down_in_transform)
        return x

    def fix_parametrization(self):
        """Fold the input transforms into the projections and leave training mode."""
        self.up_proj.fix_parametrization(self.gate_up_in_transform)
        self.gate_proj.fix_parametrization(self.gate_up_in_transform)
        self.down_proj.fix_parametrization(self.down_in_transform)
        self._train_mode = False


class QuantizedMistralAttention(nn.Module):
    """Quantized Mistral attention with optional input/output transforms.

    Follows the same pattern as the project's quantized LLaMA/Qwen3 attention:
      - ``qkv_in_transform`` is handed to the q/k/v projections,
      - ``o_in_transform`` is applied to ``attn_output`` before ``o_proj``
        (when no ``v_out_transform`` is set),
      - transforms are folded into ``QLinear`` via ``fix_parametrization()``.

    Args:
        config: Mistral config (hidden size, head counts, optional ``head_dim``,
            ``attention_dropout``, optional ``attention_bias``/``sliding_window``,
            and ``_attn_implementation`` are read).
        layer_idx: Decoder-layer index, used as the KV-cache slot.
        weight_quantizer_kwargs: Keyword arguments used to build one ``Quantizer``
            per projection weight. Falsy means no weight quantizer.
        act_quantizer_kwargs: Keyword arguments used to build activation
            ``Quantizer``s. Falsy means no activation quantizer.
        qkv_in_transform: Transform handed to q/k/v projections. ``None`` (default)
            creates a fresh ``IdentityTransform`` for this layer.
        o_in_transform: Transform for the ``o_proj`` input. ``None`` (default)
            creates a fresh ``IdentityTransform`` for this layer.
        v_out_transform: Optional output transform for ``v_proj``. When set, its
            ``block_size`` is passed to ``o_proj`` as ``reverse_r2_transform_dim``.
        norm_gamma: Optional tensor forwarded to the q/k/v projections
            (presumably the preceding norm's weight -- confirm in ``QLinear``).
    """

    def __init__(
        self,
        config: MistralConfig,
        layer_idx: int,
        weight_quantizer_kwargs: Dict[str, Any] | None = None,
        act_quantizer_kwargs: Dict[str, Any] | None = None,
        qkv_in_transform: BaseTransform | None = None,
        o_in_transform: BaseTransform | None = None,
        v_out_transform: BaseTransform | None = None,
        norm_gamma: torch.Tensor | None = None,
    ):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # A default instance in the signature would be created once at import
        # time and shared by every layer; build a fresh identity per layer instead.
        if qkv_in_transform is None:
            qkv_in_transform = IdentityTransform()
        if o_in_transform is None:
            o_in_transform = IdentityTransform()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = head_dim if head_dim is not None else (config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim ** -0.5
        self.is_causal = True

        def _new_weight_quantizer():
            # Fresh quantizer per call so no two projections share weight state.
            return Quantizer(**weight_quantizer_kwargs) if weight_quantizer_kwargs else None

        def _new_act_quantizer():
            return Quantizer(**act_quantizer_kwargs) if act_quantizer_kwargs else None

        # q/k/v consume the same input activation, so they share one quantizer.
        qkv_act_quantizer = _new_act_quantizer()
        attn_bias = getattr(config, "attention_bias", False)

        self.q_proj = QLinear(
            self.hidden_size,
            self.num_heads * self.head_dim,
            bias=attn_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=qkv_act_quantizer,
            norm_gamma=norm_gamma,
        )
        self.k_proj = QLinear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=attn_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=qkv_act_quantizer,
            norm_gamma=norm_gamma,
        )
        self.v_proj = QLinear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=attn_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=qkv_act_quantizer,
            norm_gamma=norm_gamma,
        )
        self.o_proj = QLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=attn_bias,
            weight_quantizer=_new_weight_quantizer(),
            act_quantizer=_new_act_quantizer(),
        )

        self.qkv_in_transform = qkv_in_transform
        self.o_in_transform = o_in_transform
        self.v_out_transform = v_out_transform

        # Flipped to False once transforms have been folded by fix_parametrization().
        self._train_mode = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ):
        """Run quantized attention.

        Returns:
            ``(attn_output, attn_weights)``; ``attn_output`` has the leading shape of
            ``hidden_states`` and ``attn_weights`` is whatever the selected
            attention backend returned.
        """
        input_shape = hidden_states.shape[:-1]
        # Split the last dim into (heads, head_dim), then move heads before the sequence axis.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states, self.qkv_in_transform).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states, self.qkv_in_transform).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(
            hidden_states, self.qkv_in_transform, out_transform=self.v_out_transform
        ).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin/cos needed for RoPE cache update in HF
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        attn_impl = self.config._attn_implementation
        if attn_impl != "eager":
            if attn_impl == "sdpa" and kwargs.get("output_attentions", False):
                # SDPA cannot return attention weights: tell the user and keep
                # the eager path selected above.
                warnings.warn(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention.",
                    stacklevel=2,
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            # Mistral-specific: forward the configured sliding window so backends
            # that take it as an argument (e.g. flash attention) honor it.
            sliding_window=getattr(self.config, "sliding_window", None),
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

        # Transform before output projection
        if self.v_out_transform is None:
            attn_output = self.o_in_transform(attn_output)
            attn_output = self.o_proj(attn_output, self.o_in_transform)
        else:
            attn_output = self.o_proj(
                attn_output,
                self.o_in_transform,
                reverse_r2_transform_dim=self.v_out_transform.block_size,
            )
        return attn_output, attn_weights

    def fix_parametrization(self):
        """Fold the transforms into the q/k/v/o projections and leave training mode."""
        self.q_proj.fix_parametrization(self.qkv_in_transform)
        self.k_proj.fix_parametrization(self.qkv_in_transform)
        self.v_proj.fix_parametrization(self.qkv_in_transform, self.v_out_transform)
        # Same None-check as forward(): a transform module must not be judged by truthiness.
        if self.v_out_transform is not None:
            self.o_proj.fix_parametrization(
                self.o_in_transform,
                reverse_r2_transform_dim=self.v_out_transform.block_size,
            )
        else:
            self.o_proj.fix_parametrization(self.o_in_transform)

        self._train_mode = False
