import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd
from transformers import AutoConfig

from quantization.utils.common_utils import to
from .llama_utils import QuantizedLlamaMLP, QuantizedLlamaAttention
from .mistral_utils import QuantizedMistralAttention, QuantizedMistralMLP
from .qwen2_utils import QuantizedQwen2MLP, QuantizedQwen2Attention
from .qwen3_utils import QuantizedQwen3MLP, QuantizedQwen3Attention
from quantization.transforms.transforms import BaseTransform

### Calibration utils and modules

# Module types grouped together as "linear" layers: ``nn.Linear`` plus
# ``_ConvNd``, the private base class PyTorch's convolution modules derive from.
# NOTE(review): presumably consumed via ``isinstance(module, LINEAR_LAYERS)``
# elsewhere in the project -- confirm against callers.
LINEAR_LAYERS = (nn.Linear, _ConvNd)


class ForwardInterrupt(Exception):
    """Signal raised by ``InputCollector.forward`` once the inputs are recorded.

    It carries no payload; it only aborts the forward pass that is in progress.
    NOTE(review): the code that catches it is not part of this module -- the
    calibration loop is presumably expected to handle it.
    """


class InputCollector(nn.Module):
    """Stand-in module that records the inputs it is called with.

    The wrapped ``module`` is stored but never executed: every call to
    ``forward`` appends the received arguments to ``input_args`` /
    ``input_kwargs`` and then raises ``ForwardInterrupt``.
    """

    def __init__(self, module: nn.Module, cpu_offload: bool = False):
        """
        Args:
            module: Module being wrapped (kept as the ``module`` submodule).
            cpu_offload: When True, recorded inputs are moved to the CPU
                before being stored.
        """
        super().__init__()
        self.module = module
        # Mirror ``attention_type`` when the wrapped module exposes it, so the
        # collector can be inspected like the module it replaces.
        if hasattr(module, "attention_type"):
            self.attention_type = module.attention_type

        self.cpu_offload = cpu_offload
        # One entry per intercepted call, in call order.
        self.input_args = []
        self.input_kwargs = []

    def forward(self, *input_args, **input_kwargs):
        """
        Record the positional and keyword inputs of this call, then abort
        the forward pass by raising ``ForwardInterrupt``.

        Raises:
            ForwardInterrupt: always, after the inputs have been stored.
        """
        recorded_args, recorded_kwargs = input_args, input_kwargs
        if self.cpu_offload:
            # Keep the stored copies on the CPU.
            recorded_args = to(recorded_args, device="cpu")
            recorded_kwargs = to(recorded_kwargs, device="cpu")

        self.input_args.append(recorded_args)
        self.input_kwargs.append(recorded_kwargs)
        raise ForwardInterrupt
    
def get_number_of_rows_and_cols(layer):
    """Return ``layer.weight`` viewed as a 2-D matrix: ``(rows, cols)``.

    ``rows`` is the size of the first weight dimension and ``cols`` is the
    product of all remaining dimensions.

    Args:
        layer: Any object exposing a ``weight`` with a ``shape`` attribute.

    Returns:
        Tuple ``(rows, cols)``; ``cols`` is always a Python ``int``.
    """
    weight_shape = layer.weight.shape
    # ``np.prod`` yields a NumPy scalar (and float ``1.0`` for an empty tail,
    # i.e. a 1-D weight); normalise so callers always get a plain int.
    return weight_shape[0], int(np.prod(weight_shape[1:]))

def get_mlp_layer(config: AutoConfig):
    """Return the quantized MLP class matching ``config.model_type``.

    Raises:
        ValueError: if the model type is not one of ``llama``, ``qwen3``,
            ``qwen2`` or ``mistral``.
    """
    mlp_classes = {
        "llama": QuantizedLlamaMLP,
        "qwen3": QuantizedQwen3MLP,
        "qwen2": QuantizedQwen2MLP,
        "mistral": QuantizedMistralMLP,
    }
    mlp_class = mlp_classes.get(config.model_type)
    if mlp_class is None:
        raise ValueError(f"Model type {config.model_type} not supported")
    return mlp_class

def get_attention_layer(config: AutoConfig):
    """Return the quantized attention class matching ``config.model_type``.

    Raises:
        ValueError: if the model type is not one of ``llama``, ``qwen3``,
            ``qwen2`` or ``mistral``.
    """
    attention_classes = {
        "llama": QuantizedLlamaAttention,
        "qwen3": QuantizedQwen3Attention,
        "qwen2": QuantizedQwen2Attention,
        "mistral": QuantizedMistralAttention,
    }
    attention_class = attention_classes.get(config.model_type)
    if attention_class is None:
        raise ValueError(f"Model type {config.model_type} not supported")
    return attention_class

def create_quantized_mlp(config: AutoConfig, layer_idx, weight_quantizer_kwargs=None, act_quantizer_kwargs=None,
                         transform1=None, transform2=None, norm_gamma=None):
    """
    Instantiate the quantized MLP class for ``config.model_type``.

    Args:
        config: Model configuration; its ``model_type`` selects the MLP class.
        layer_idx: Forwarded to the MLP constructor as ``layer_idx``.
        weight_quantizer_kwargs: Weight quantizer configuration.
        act_quantizer_kwargs: Activation quantizer configuration.
        transform1: Passed as ``gate_up_in_transform``.
        transform2: Passed as ``down_in_transform``.
        norm_gamma: Passed as ``norm_gamma`` for llama/qwen3/qwen2 only.

    Returns:
        Quantized MLP layer instance.

    Raises:
        ValueError: if the model type is not supported.
    """
    mlp_class = get_mlp_layer(config)

    receives_norm_gamma = config.model_type in ["llama", "qwen3", "qwen2"]
    if not receives_norm_gamma and config.model_type != "mistral":
        raise ValueError(f"Model type {config.model_type} not supported")

    # Constructor keywords shared by every supported architecture.
    mlp_kwargs = dict(
        layer_idx=layer_idx,
        weight_quantizer_kwargs=weight_quantizer_kwargs,
        act_quantizer_kwargs=act_quantizer_kwargs,
        gate_up_in_transform=transform1,
        down_in_transform=transform2,
    )
    # NOTE(review): ``norm_gamma`` is not forwarded for mistral, so a caller
    # supplying it there has it dropped silently -- confirm this is intended.
    if receives_norm_gamma:
        mlp_kwargs["norm_gamma"] = norm_gamma

    return mlp_class(config, **mlp_kwargs)


def load_quantized_modules_state_dict(
        block,
        quantized_attn,
        quantized_mlp,
        transformed_input_attn_norm,
        transformed_input_mlp_norm,
        model_config
):
    """
    Copy a decoder block's weights into its quantized replacements and swap
    the replacements into the block in place.

    Args:
        block: Transformer decoder block exposing ``self_attn``, ``mlp``,
            ``input_layernorm`` and ``post_attention_layernorm``.
        quantized_attn: Quantized attention module; replaces ``block.self_attn``.
        quantized_mlp: Quantized MLP module; replaces ``block.mlp``.
        transformed_input_attn_norm: Replacement for ``block.input_layernorm``.
        transformed_input_mlp_norm: Replacement for
            ``block.post_attention_layernorm``.
        model_config: Model configuration (currently unused, kept for
            interface compatibility).

    Raises:
        AssertionError: if ``block`` has no ``mlp`` attribute. Raised before
            anything is modified.
    """
    # Validate up front with an explicit raise: a bare ``assert`` disappears
    # under ``python -O`` and, placed later, would fire only after the
    # quantized modules had already been overwritten. The exception type and
    # message are kept so existing handlers still work.
    if not hasattr(block, 'mlp'):
        raise AssertionError("Expecting model blocks to have attribute 'mlp', model is unsupported")

    # ``strict=False`` tolerates keys present on only one side, e.g. extra
    # quantizer/transform state on the replacement modules.
    quantized_attn.load_state_dict(block.self_attn.state_dict(), strict=False)
    transformed_input_attn_norm.load_state_dict(block.input_layernorm.state_dict(), strict=False)
    transformed_input_mlp_norm.load_state_dict(block.post_attention_layernorm.state_dict(), strict=False)
    quantized_mlp.load_state_dict(block.mlp.state_dict(), strict=False)

    # Swap the replacements into the block only after every load succeeded.
    block.mlp = quantized_mlp
    block.self_attn = quantized_attn
    block.input_layernorm = transformed_input_attn_norm
    block.post_attention_layernorm = transformed_input_mlp_norm


class PreNormTransformWrapper(nn.Module):
    """
    Wrap an existing norm so that an input transform is applied BEFORE the norm.

    On construction the wrapped norm's ``weight`` is overwritten with ones, so
    the norm passed in is modified in place.
    NOTE(review): presumably the original scale is accounted for elsewhere
    (e.g. folded into the following layers) -- confirm against callers.
    """
    def __init__(self, norm: nn.Module, in_transform: BaseTransform):
        """
        Args:
            norm: Normalization module with a ``weight`` parameter.
            in_transform: Transform applied to the input before ``norm``. If it
                exposes a non-``None`` ``bias`` attribute, that bias is added
                to the transformed input.
        """
        super().__init__()
        self.norm = norm
        self.in_transform = in_transform

        # Reset the norm's scale to identity (all ones), keeping shape/dtype/device.
        self.norm.weight.data = torch.ones_like(self.norm.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transform (plus its optional bias), then the wrapped norm."""
        transformed = self.in_transform(x)
        # A module may register ``bias`` as ``None`` to mean "no bias";
        # ``hasattr`` alone would then crash on ``None.to(...)``.
        bias = getattr(self.in_transform, "bias", None)
        if bias is not None:
            # The bias is cast to the dtype of the *input* tensor.
            transformed = transformed + bias.to(x.dtype)
        return self.norm(transformed)