import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple, List
import math
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm, apply_rotary_pos_emb
)
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

from model import CustomFFN
from model import CustomLayerNorm, CustomRMSNorm
from config import MoLConfig


class ModalitySpecificLoRA(nn.Module):
    """
    Single low-rank adapter computing ``lora_B(dropout(lora_A(x))) * alpha / rank``.

    ``lora_A`` maps ``in_features`` down to ``rank`` and ``lora_B`` maps back up
    to ``out_features``; neither has a bias. ``lora_B`` starts at zero, so a
    freshly built adapter contributes nothing until it is trained.

    Args:
        in_features: Input width of the adapted projection.
        out_features: Output width of the adapted projection.
        rank: Bottleneck width ``r``.
        alpha: Scaling numerator; the output is multiplied by ``alpha / rank``.
        dropout: Dropout probability applied between the two linears.
        modality: Label stored on the module (informational only here).
        dtype: Parameter dtype of both linears.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 16,
        alpha: float = 32.0,
        dropout: float = 0.1,
        modality: str = "text",
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.rank, self.alpha, self.modality = rank, alpha, modality
        self.scaling = alpha / rank

        # Creation order (A, B, dropout) fixes both the state_dict layout and
        # how much of the global RNG each default initialisation consumes.
        self.lora_A = nn.Linear(in_features, rank, bias=False, dtype=dtype)
        self.lora_B = nn.Linear(rank, out_features, bias=False, dtype=dtype)
        self.dropout = nn.Dropout(dropout)

        # Re-initialise A randomly and zero B so the adapter begins as a no-op.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``."""
        bottleneck = self.dropout(self.lora_A(x))
        update = self.lora_B(bottleneck)
        return update * self.scaling


class MixtureOfLoRAProjection(nn.Module):
    """
    Frozen linear projection (Q, K, V, O or an FFN linear) plus LoRA adapters.

    The mode is picked by ``config.baseline_lora``:

    * baseline: a single adapter, ``lora_adapter``, is added to every token and
      ``lora_adapters`` is None;
    * mixture: ``lora_adapters`` holds one adapter per name in
      ``config.modalities``; each one is applied only to the tokens selected by
      that modality's mask, and ``lora_adapter`` is None.

    Which adapters receive gradients is controlled by
    ``set_trainable_modalities``.
    """

    def __init__(self, base_layer: nn.Linear, config: MoLConfig, projection_type: str):
        super().__init__()
        self.config = config
        self.projection_type = projection_type
        self.base_layer = base_layer

        # Adapters are built in the same dtype as the pretrained weight.
        weight_dtype = base_layer.weight.dtype
        # The pretrained projection is never updated; only adapters can learn.
        self.base_layer.requires_grad_(False)

        def build_adapter(tag: str) -> ModalitySpecificLoRA:
            # Every adapter mirrors the base layer's in/out widths.
            return ModalitySpecificLoRA(
                in_features=base_layer.in_features,
                out_features=base_layer.out_features,
                rank=config.lora_rank,
                alpha=config.lora_alpha,
                dropout=config.lora_dropout,
                modality=tag,
                dtype=weight_dtype,
            )

        if config.baseline_lora:
            self.lora_adapter = build_adapter("shared")
            self.lora_adapters = None
        else:
            # Built in config.modalities order, which fixes RNG consumption.
            self.lora_adapters = nn.ModuleDict(
                {name: build_adapter(name) for name in config.modalities}
            )
            self.lora_adapter = None

        self.set_trainable_modalities(config.trainable_modalities)

    def set_trainable_modalities(self, trainable_modalities: List[str]):
        """
        Toggle ``requires_grad`` on the adapters.

        In baseline mode the shared adapter is trainable iff
        ``trainable_modalities`` is non-empty. In mixture mode each adapter is
        trainable iff its modality name is contained in the list.
        """
        if self.config.baseline_lora:
            any_trainable = len(trainable_modalities) > 0
            if self.lora_adapter:
                self.lora_adapter.requires_grad_(any_trainable)
            return

        if not self.lora_adapters:
            return
        for name, adapter in self.lora_adapters.items():
            adapter.requires_grad_(name in trainable_modalities)

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Return ``base_layer(hidden_states)`` plus the LoRA update.

        Args:
            hidden_states: Input to the projection.
            modality_mask: Maps a modality name to a mask used to index
                ``hidden_states`` (presumably a boolean token mask - confirm
                with callers). Ignored in baseline mode and required otherwise.
                Modalities without an adapter, with an all-False mask, or
                ``'text'`` while ``config.text_lora_enabled`` is False are
                skipped.

        Raises:
            ValueError: In mixture mode when ``modality_mask`` is None.
        """
        base_output = self.base_layer(hidden_states)

        if self.config.baseline_lora:
            if not self.lora_adapter:
                return base_output  # Should not happen if configured correctly
            return base_output + self.lora_adapter(hidden_states)

        if modality_mask is None:
            raise ValueError("Modality mask must be provided for Mixture-of-LoRA layers.")

        # Tokens that no adapter claims get a zero update.
        lora_output = torch.zeros_like(base_output)
        for modality, mask in modality_mask.items():
            if modality == 'text' and not self.config.text_lora_enabled:
                continue
            if modality not in self.lora_adapters or not mask.any():
                continue

            selected = hidden_states[mask]
            if selected.numel() == 0:
                continue
            # A later modality overwrites an earlier one where masks overlap.
            lora_output[mask] = self.lora_adapters[modality](selected)

        return base_output + lora_output


class MixtureOfLoRA_FFN(nn.Module):
    """
    Gated feed-forward block whose three linears carry Mixture-of-LoRA adapters.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``. Each projection
    wraps the matching linear of ``base_ffn`` in a ``MixtureOfLoRAProjection``.
    ``base_ffn`` must expose ``gate_proj``, ``up_proj``, ``down_proj`` and
    ``act_fn``.
    """

    def __init__(
            self,
            base_ffn: nn.Module,
            config: MoLConfig
        ):
        super().__init__()
        self.config = config
        self.base_ffn = base_ffn

        # Wrap gate, up and down in that order. The order controls submodule
        # registration and the sequence in which adapters draw from the RNG.
        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            wrapped = MixtureOfLoRAProjection(getattr(base_ffn, proj_name), config, proj_name)
            setattr(self, proj_name, wrapped)

        self.activation_fn = base_ffn.act_fn  # e.g. SiLU or GELU
        gate_linear = base_ffn.gate_proj
        self.hidden_size = gate_linear.in_features
        self.intermediate_size = gate_linear.out_features

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Apply the gated FFN, routing every projection through the modality mask.

        Raises:
            ValueError: If ``modality_mask`` is None.
        """
        if modality_mask is None:
            raise ValueError("Modality mask must be provided for the MixtureOfLoRA_FFN module.")

        gate = self.gate_proj(hidden_states, modality_mask)
        up = self.up_proj(hidden_states, modality_mask)
        return self.down_proj(self.activation_fn(gate) * up, modality_mask)


class MixtureOfLoRAAttention(nn.Module):
    """
    Mixture-of-LoRA self-attention layer.

    Recomputes multi-head attention (explicit matmul + softmax) on top of a
    pretrained attention module. The module's Q, K, V and O projections are
    wrapped in ``MixtureOfLoRAProjection`` so that each modality can carry its
    own LoRA adapter. Grouped-query attention is handled by repeating K/V heads.
    If the wrapped module defines ``q_norm`` / ``k_norm`` (per-head norms, as
    Qwen3 attention does), they are applied to the query/key heads before the
    rotary embedding.
    """

    def __init__(
        self,
        base_attention_layer: nn.Module,
        rotary_emb: nn.Module,
        config: MoLConfig,
    ):
        """
        Args:
            base_attention_layer: Pretrained attention module. It must expose
                ``q_proj``/``k_proj``/``v_proj``/``o_proj`` linears and a
                ``config`` with ``hidden_size`` and ``num_attention_heads``.
                ``num_key_value_heads`` and ``head_dim`` are optional.
            rotary_emb: Rotary embedding module. It is stored but not called
                here; ``forward`` expects precomputed ``cos``/``sin`` kwargs.
            config: MoL configuration forwarded to every projection wrapper.

        Raises:
            ValueError: If ``num_heads * head_dim`` does not match
                ``q_proj.out_features`` or ``o_proj.in_features``.
        """
        super().__init__()
        self.config = config
        self.q_proj = MixtureOfLoRAProjection(
            base_attention_layer.q_proj, config, 'q_proj'
        )
        self.k_proj = MixtureOfLoRAProjection(
            base_attention_layer.k_proj, config, 'k_proj'
        )
        self.v_proj = MixtureOfLoRAProjection(
            base_attention_layer.v_proj, config, 'v_proj'
        )
        self.o_proj = MixtureOfLoRAProjection(
            base_attention_layer.o_proj, config, 'o_proj'
        )
        self.rotary_emb = rotary_emb

        # Per-head query/key norms (Qwen3-style). The base modules are reused
        # so the pretrained weights are kept. The value is None when the base
        # layer has no such attribute. The generic copy loop below skips
        # nn.Module attributes, so these must be picked up explicitly.
        self.q_norm = getattr(base_attention_layer, 'q_norm', None)
        self.k_norm = getattr(base_attention_layer, 'k_norm', None)

        base_config = base_attention_layer.config
        self.hidden_size = base_config.hidden_size
        self.num_heads = base_config.num_attention_heads
        self.num_key_value_heads = (
            getattr(base_config, 'num_key_value_heads', None) or self.num_heads
        )
        # Some configs omit head_dim and others set it to None. Both cases fall
        # back to the classic hidden_size // num_heads.
        self.head_dim = (
            getattr(base_config, 'head_dim', None)
            or self.hidden_size // self.num_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # The attention width num_heads * head_dim need not equal hidden_size
        # when a config sets head_dim explicitly. Validate it against the
        # projections that actually produce and consume that width.
        attn_width = self.num_heads * self.head_dim
        q_out = base_attention_layer.q_proj.out_features
        o_in = base_attention_layer.o_proj.in_features
        if q_out != attn_width or o_in != attn_width:
            raise ValueError(
                f"num_heads * head_dim ({self.num_heads} * {self.head_dim} = {attn_width}) "
                f"must match q_proj.out_features ({q_out}) and o_proj.in_features ({o_in})."
            )

        # Mirror any remaining plain (non-callable, non-Module) public
        # attributes of the base layer without overwriting anything set above.
        for attr_name in dir(base_attention_layer):
            if not attr_name.startswith('_') and not hasattr(self, attr_name):
                attr_value = getattr(base_attention_layer, attr_name)
                if not callable(attr_value) and not isinstance(attr_value, nn.Module):
                    setattr(self, attr_name, attr_value)

    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeats the key and value heads to match the query heads in GQA.
        From (batch, num_key_value_heads, seq_len, head_dim)
        To   (batch, num_attention_heads, seq_len, head_dim)
        """
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch,
            num_key_value_heads,
            n_rep,
            slen,
            head_dim
        )
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs
    ):
        """
        Compute attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden).
            attention_mask: Additive mask added to the attention scores
                (batch, heads, q_len, kv_len). A 4-D mask that is wider than
                the key length is trimmed to it.
            position_ids: Forwarded to ``apply_rotary_pos_emb``.
            past_key_value: Optional ``(key, value)`` from earlier steps. If
                non-empty, it is concatenated before the new keys/values along
                the sequence axis.
            output_attentions: When False, the returned weights are None.
            use_cache: When True, the third return value is ``(key, value)``
                including this step (before GQA repetition); otherwise None.
            modality_mask: Per-modality token masks; required.
            **kwargs: ``cos``/``sin`` rotary tables (applied only when both are
                present) and an optional additive ``position_bias``.

        Returns:
            ``(attn_output, attn_weights_or_None, present_key_value_or_None)``

        Raises:
            ValueError: If ``modality_mask`` is None.
        """
        if modality_mask is None:
            raise ValueError("Modality mask was not provided to the MixtureOfLoRAAttention module.")

        batch_size, seq_len, _ = hidden_states.size()
        q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)

        # Project in q, k, v order and split into heads: (b, s, heads, head_dim).
        query_states = self.q_proj(hidden_states, modality_mask).view(q_shape)
        key_states = self.k_proj(hidden_states, modality_mask).view(kv_shape)
        value_states = self.v_proj(hidden_states, modality_mask).view(kv_shape)

        # Per-head norms over head_dim (Qwen3-style), applied before RoPE.
        if self.q_norm is not None:
            query_states = self.q_norm(query_states)
        if self.k_norm is not None:
            key_states = self.k_norm(key_states)

        # -> (batch, heads, seq_len, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # --- Apply Rotary Embeddings using the passed tensors ---
        if 'cos' in kwargs and 'sin' in kwargs:
            query_states, key_states = apply_rotary_pos_emb(
                query_states,
                key_states,
                kwargs['cos'],
                kwargs['sin'],
                position_ids)

        if past_key_value is not None and len(past_key_value) > 0:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        present_key_value = (key_states, value_states) if use_cache else None

        key_states = self._repeat_kv(key_states, self.num_key_value_groups)
        value_states = self._repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3))
            / math.sqrt(self.head_dim)
        )

        # Optional additive position bias (for T5-like models).
        position_bias = kwargs.get('position_bias')
        if position_bias is not None:
            attn_weights = attn_weights + position_bias

        if attention_mask is not None:
            if attention_mask.dim() == 4:
                # 4-D masks may be allocated wider than the keys attended to here.
                attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + attention_mask
        # NOTE(review): when attention_mask is None no causal mask is applied,
        # so every query attends to every key. Confirm that callers always pass
        # a mask for autoregressive use.

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        # Merge heads back to num_heads * head_dim, which can differ from
        # hidden_size. o_proj consumes exactly that width.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(
            batch_size, seq_len, self.num_heads * self.head_dim
        )
        attn_output = self.o_proj(attn_output, modality_mask)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, present_key_value


class MixtureOfLoRALayer(nn.Module):
    """
    Complete MoL decoder layer with attention, norms and FFN.

    Layout: ``x + attn(norm1(x))`` followed by ``h + mlp(norm2(h))``. Attention
    always uses modality-specific LoRA projections. Depending on the config,
    the two norms may be replaced by modality-specific norms
    (``use_modality_specific_ln``). The MLP may be replaced by a
    modality-specific FFN (``use_modality_specific_ffn``) or by a LoRA-adapted
    copy of the original MLP (``use_lora_ffn``).
    """

    # Model types for which forward() computes rotary (cos, sin) tables.
    SUPPORTED_MODEL_TYPES = ("llama", "mistral", "qwen2", "qwen3")

    def __init__(
        self,
        base_layer: nn.Module,
        rotary_emb: nn.Module,
        config: MoLConfig,
        model_type: str,
    ):
        """
        Args:
            base_layer: Pretrained decoder layer exposing ``self_attn``,
                ``input_layernorm``, ``post_attention_layernorm`` and ``mlp``.
            rotary_emb: Called in ``forward`` to obtain ``(cos, sin)``.
            config: MoL configuration.
            model_type: Architecture name; must be one of
                ``SUPPORTED_MODEL_TYPES`` by the time ``forward`` runs.

        Raises:
            NotImplementedError: If ``use_modality_specific_ln`` is set and the
                base norm is neither ``nn.LayerNorm`` nor an RMSNorm.
        """
        super().__init__()
        self.config = config
        self.rotary_emb = rotary_emb
        self.model_type = model_type

        # Attention is always modality-specific via MoL Projections
        self.attention = MixtureOfLoRAAttention(
            base_attention_layer=base_layer.self_attn,
            rotary_emb=rotary_emb,
            config=config
        )

        # Store original components to fall back on
        self.original_input_layernorm = base_layer.input_layernorm
        self.original_post_attention_layernorm = base_layer.post_attention_layernorm
        self.original_mlp = base_layer.mlp

        hidden_size = self.original_input_layernorm.weight.shape[0]

        if self.config.use_modality_specific_ln:
            if isinstance(self.original_input_layernorm, nn.LayerNorm):
                self.custom_input_layernorm = CustomLayerNorm(
                    normalized_shape=hidden_size,
                    config=config,
                    eps=self.original_input_layernorm.eps
                )
                self.custom_post_attention_layernorm = CustomLayerNorm(
                    normalized_shape=hidden_size,
                    config=config,
                    eps=self.original_post_attention_layernorm.eps
                )
            elif self._is_rms_norm(self.original_input_layernorm):
                self.custom_input_layernorm = CustomRMSNorm(
                    normalized_shape=hidden_size,
                    config=config,
                )
                self.custom_post_attention_layernorm = CustomRMSNorm(
                    normalized_shape=hidden_size,
                    config=config,
                )
            else:
                # Fail here rather than with an AttributeError on the missing
                # custom_* norms inside forward().
                raise NotImplementedError(
                    "Modality-specific normalization does not support "
                    f"'{type(self.original_input_layernorm).__name__}'; "
                    "expected nn.LayerNorm or an RMSNorm."
                )

        if self.config.use_modality_specific_ffn:
            # Assuming FFN intermediate size can be found this way
            intermediate_size = base_layer.mlp.gate_proj.out_features
            self.custom_mlp = CustomFFN(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                config=config
            )
            self.original_mlp = None
        elif self.config.use_lora_ffn:
            self.custom_mlp = MixtureOfLoRA_FFN(
                base_ffn=base_layer.mlp,
                config=config
            )
            self.original_mlp = None

    @staticmethod
    def _is_rms_norm(norm: nn.Module) -> bool:
        """Return True if ``norm`` should be replaced by ``CustomRMSNorm``."""
        # The imported classes are checked first. The name test also covers
        # RMSNorm variants not imported in this file (e.g. Qwen3RMSNorm).
        return (
            isinstance(norm, (LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm))
            or type(norm).__name__.endswith("RMSNorm")
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs
    ):
        """
        Run the attention and FFN sub-blocks with residual connections.

        Returns:
            ``(hidden_states, attn_weights)``, or
            ``(hidden_states, attn_weights, present_key_value)`` when
            ``use_cache`` is True. ``attn_weights`` is None unless
            ``output_attentions`` is True.

        Raises:
            ValueError: If ``modality_mask`` is None.
            NotImplementedError: If ``model_type`` is not supported.
        """
        if modality_mask is None:
            raise ValueError("Modality mask was not provided to the MixtureOfLoRALayer module.")

        residual = hidden_states

        if self.config.use_modality_specific_ln:
            ln_input = self.custom_input_layernorm(hidden_states, modality_mask)
        else:
            ln_input = self.original_input_layernorm(hidden_states)

        if self.model_type not in self.SUPPORTED_MODEL_TYPES:
            raise NotImplementedError(
                f"Support for model type '{self.model_type}' has not been implemented. "
                f"Currently supported types are: {list(self.SUPPORTED_MODEL_TYPES)}. "
                "Check whether the chosen model is compatible with this implementation."
            )

        cos, sin = self.rotary_emb(
            ln_input,
            position_ids=position_ids
        )
        attention_kwargs = {'cos': cos, 'sin': sin}

        attn_output, attn_weights, present_key_value = self.attention(
            ln_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            modality_mask=modality_mask,
            **attention_kwargs,  # Pass the architecture-specific args
            **kwargs
        )

        hidden_states = residual + attn_output

        residual = hidden_states
        if self.config.use_modality_specific_ln:
            ln_output = self.custom_post_attention_layernorm(hidden_states, modality_mask)
        else:
            ln_output = self.original_post_attention_layernorm(hidden_states)

        if self.config.use_modality_specific_ffn or self.config.use_lora_ffn:
            ffn_output = self.custom_mlp(ln_output, modality_mask)
        else:
            ffn_output = self.original_mlp(ln_output)

        hidden_states = residual + ffn_output

        # Index 1 is always the attention weights, so requested weights are
        # never dropped. The cache is appended only when it was asked for.
        if use_cache:
            return (hidden_states, attn_weights, present_key_value)
        return (hidden_states, attn_weights)
