from typing import Dict, Optional, List
import warnings

import torch.nn as nn
import torch
import torch.nn.functional as F

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from config import MoLConfig
from model import MixtureOfLoRALayer, MixtureOfLoRAProjection


class MixtureOfLoRAModel(nn.Module):
    """Causal-LM wrapper whose decoder layers are replaced by ``MixtureOfLoRALayer``.

    On construction every entry of ``base_model.model.layers`` is swapped in
    place for a ``MixtureOfLoRALayer`` wrapping the original layer. ``forward``
    then runs the embedding, the wrapped layers, the final norm and the LM head
    by hand so that ``modality_mask`` can be handed to every layer.
    """

    def __init__(
        self,
        base_model,
        config: "MoLConfig",
        llm_config
    ):
        """Store the base model/configs and wrap the decoder layers.

        Args:
            base_model: Causal LM exposing ``.model.layers``,
                ``.model.embed_tokens``, ``.model.norm``, ``.lm_head`` and
                ``.config``.
            config: Mixture-of-LoRA configuration shared with every wrapped layer.
            llm_config: Config of the underlying LLM; ``model_type`` is read here.
        """
        super().__init__()
        self.base_model = base_model
        self.config = config
        self.llm_config = llm_config

        self.model_type = llm_config.model_type

        self._replace_layers_with_mol()

    def _replace_layers_with_mol(self):
        """Replace each decoder layer with a ``MixtureOfLoRALayer`` in place.

        NOTE: The path to the layers may need adjustment for different base models.

        Raises:
            ValueError: If ``base_model.model.layers`` cannot be found, or if
                the model type expects a shared rotary embedding and
                ``base_model.model.rotary_emb`` is missing.
        """
        # This path is common for models like Llama, Mistral, etc.
        try:
            target_layers = self.base_model.model.layers
        except AttributeError as err:
            raise ValueError("Could not find transformer layers in `base_model.model.layers`. "
                             "Please adjust the path for your model architecture.") from err

        rotary_emb = None
        if self.model_type in ("llama", "mistral", "qwen2", "qwen3"):
            # Get the single rotary_emb instance from the base model. Looked up
            # separately so a missing attribute is reported accurately instead
            # of being blamed on the layer path.
            try:
                rotary_emb = self.base_model.model.rotary_emb
            except AttributeError as err:
                raise ValueError(
                    "Could not find a shared rotary embedding at "
                    "`base_model.model.rotary_emb` for model type "
                    f"'{self.model_type}'. Please adjust the path for your "
                    "model architecture."
                ) from err

        for i, layer in enumerate(target_layers):
            target_layers[i] = MixtureOfLoRALayer(
                base_layer=layer,
                rotary_emb=rotary_emb,
                config=self.config,
                model_type=self.model_type
            )

    def _set_embedding_trainability(self, trainable: bool):
        """Set ``requires_grad`` on every parameter of the embedding layer.

        Records the flag in ``self.trainable_embedding``. Emits a warning and
        changes nothing when ``base_model.model`` has no ``embed_tokens``.
        """
        if hasattr(self.base_model.model, 'embed_tokens'):
            for param in self.base_model.model.embed_tokens.parameters():
                param.requires_grad = trainable
            self.trainable_embedding = trainable
        else:
            warnings.warn("Embedding layer not found at `self.base_model.model.embed_tokens`. "
                          "Cannot set embedding trainability.")

    def set_trainable_modalities(self, trainable_modalities: List[str]):
        """Dynamically choose which modalities are trainable.

        The list is stored on the config and forwarded to every
        ``MixtureOfLoRAProjection`` inside the base model. The embedding layer
        is made trainable only when ``'embedding'`` or ``'embeddings'`` is in
        the list; otherwise it is frozen.

        Example: ``model.set_trainable_modalities(['image'])``
        """
        self.config.trainable_modalities = trainable_modalities
        for module in self.base_model.modules():
            if isinstance(module, MixtureOfLoRAProjection):
                module.set_trainable_modalities(trainable_modalities)

        # Set trainability for the embedding layer
        if (
            'embedding' in trainable_modalities or
            'embeddings' in trainable_modalities
        ):
            self._set_embedding_trainability(True)
            print("Embeddings in trainable parameters.")
        else:
            self._set_embedding_trainability(False)  # Ensure it's off if not specified

        print(f"Set trainable modalities to: {trainable_modalities}")
        self.print_trainable_parameters()

    def switch_component_specificity(self, use_ln: bool, use_ffn: bool):
        """Switch between shared and modality-specific LN/FFN.

        Only the two flags on ``self.config`` are updated here; that config
        object is the one passed to every wrapped layer at construction.
        """
        self.config.use_modality_specific_ln = use_ln
        self.config.use_modality_specific_ffn = use_ffn
        print(f"Modality-specific LN: {use_ln}, Modality-specific FFN: {use_ffn}")

    @staticmethod
    def _mask_first_modality_labels(shift_labels, modality_mask):
        """Return a copy of ``shift_labels`` with the leading modality ignored.

        Each modality mask is shifted by one position to align with the shifted
        labels. For every sequence whose shifted mask is set at position 0
        (i.e. original token index 1), all label positions belonging to that
        modality are set to -100. The input tensor is not modified.
        """
        masked_labels = shift_labels.clone()
        if masked_labels.shape[-1] == 0:
            # Single-token sequences have no shifted positions to mask.
            return masked_labels

        for mask in modality_mask.values():
            # Cast to bool so integer 0/1 masks select by value, not by index.
            s_mask = mask[:, 1:].to(device=masked_labels.device, dtype=torch.bool)
            starts_with_modality = s_mask[:, 0]
            # A single combined boolean mask writes into `masked_labels`
            # itself; chaining two boolean indexes would assign into a
            # temporary copy and leave the labels untouched.
            tokens_to_ignore = s_mask & starts_with_modality.unsqueeze(1)
            masked_labels[tokens_to_ignore] = -100
        return masked_labels

    @staticmethod
    def _per_modality_losses(shift_logits, shift_labels, modality_mask):
        """Compute one mean cross-entropy value per modality.

        Masks are shifted by one position to align with the shifted
        logits/labels. Positions whose label is -100 are excluded. A modality
        with no remaining target positions reports ``0.0``.

        Returns:
            Dict mapping ``'<modality>_loss'`` to a Python float.
        """
        labels = shift_labels.to(shift_logits.device)
        valid_targets = labels != -100
        losses = {}
        for modality, mask in modality_mask.items():
            shifted_mask = mask[:, 1:].to(device=shift_logits.device, dtype=torch.bool)
            target_mask = valid_targets & shifted_mask

            # Select the aligned logits and labels for the current modality
            modality_logits = shift_logits[target_mask]
            modality_labels = labels[target_mask]

            if modality_labels.numel() > 0:
                losses[f'{modality}_loss'] = F.cross_entropy(
                    modality_logits, modality_labels
                ).item()
            else:
                # No tokens of this modality to calculate loss on in this batch
                losses[f'{modality}_loss'] = 0.0
        return losses

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None,
        output_per_modality_loss: bool = False,
        condition_on_first_modality: bool = False,
        **kwargs
    ):
        """Forward pass with modality information.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is None.
            attention_mask: Optional mask. A missing or 2-D mask is expanded
                to a 4-D causal mask; any other rank is passed through as is.
            position_ids: Optional positions; by default a range starting at
                the cached length.
            past_key_values: Optional per-layer cache, indexed as
                ``past_key_values[layer][0]`` with the cached length read from
                dimension 2 of that tensor.
            inputs_embeds: Optional precomputed embeddings.
            labels: Optional targets; shifted by one for next-token loss.
            use_cache, output_attentions, output_hidden_states, return_dict:
                Fall back to the values on ``base_model.config`` when None.
            modality_mask: Optional mapping from modality name to a per-token
                mask, forwarded to every decoder layer.
            output_per_modality_loss: Also compute one loss per modality.
            condition_on_first_modality: Exclude the leading modality's tokens
                from the overall loss.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            ``CausalLMOutputWithPast`` carrying an extra ``per_modality_loss``
            attribute, or, when ``return_dict`` is false, the flat tuple
            ``([loss,] logits, *hidden_states, *attentions, per_modality_loss)``.
        """
        # Warn if a mask is for an unsupported modality
        if modality_mask:
            for m in modality_mask:
                if m not in self.config.modalities:
                    warnings.warn(
                        f"Input mask for modality '{m}' was provided, but this modality "
                        f"is not in the model config: {self.config.modalities}. Ignoring."
                    )

        # Determine output settings from config if not provided
        base_config = self.base_model.config
        if output_attentions is None:
            output_attentions = base_config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = base_config.output_hidden_states
        if use_cache is None:
            use_cache = base_config.use_cache
        if return_dict is None:
            return_dict = base_config.use_return_dict

        # 1. Get input embeddings
        # The base_model is LlamaForCausalLM, its main components are in `.model`
        if inputs_embeds is None:
            inputs_embeds = self.base_model.model.embed_tokens(input_ids)

        seq_length = inputs_embeds.shape[1]
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # 2. Prepare attention mask for causal attention.
        # The cached length must be passed so the mask covers the cached keys,
        # and a missing mask still needs a causal mask to be built.
        if attention_mask is None or attention_mask.ndim == 2:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                inputs_embeds.shape[:2],
                inputs_embeds,
                past_key_values_length,
            )

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=inputs_embeds.device
            ).unsqueeze(0)

        # 3. Manually iterate through the transformer layers
        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.base_model.model.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            # Pass all necessary arguments, INCLUDING the modality_mask, to our custom layer.
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                modality_mask=modality_mask,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # 4. Apply the final LayerNorm (not modality specific)
        hidden_states = self.base_model.model.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        # 5. Apply the language model head to get logits
        logits = self.base_model.lm_head(hidden_states)

        # 6. Calculate the loss if labels are provided (overall loss)
        loss = None
        per_modality_loss = {}
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            # Enable model parallelism: keep labels on the logits' device
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

            # Optionally train on modality-conditioned prediction by ignoring
            # the first modality's tokens in the overall loss.
            final_shift_labels = shift_labels
            if condition_on_first_modality and modality_mask is not None:
                final_shift_labels = self._mask_first_modality_labels(
                    shift_labels, modality_mask
                )

            # Flatten using the logits' real last dimension; it can differ
            # from `llm_config.vocab_size` when the LM head is padded/resized.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                final_shift_labels.reshape(-1),
            )

            if output_per_modality_loss and modality_mask is not None:
                # Reported against the unconditioned shifted labels.
                per_modality_loss = self._per_modality_losses(
                    shift_logits, shift_labels, modality_mask
                )

        if not return_dict:
            if output_per_modality_loss:
                warnings.warn("When `output_per_modality_loss` is True, it is recommended "
                              "to set `return_dict=True` for cleaner output handling.")
            # The accumulators are None when their output flag is off.
            output = (logits,) + (all_hidden_states or ()) + (all_self_attns or ())
            output = output + (per_modality_loss,)
            return ((loss,) + output) if loss is not None else output

        out = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

        out.per_modality_loss = per_modality_loss

        return out

    def print_trainable_parameters(self):
        """Print trainable vs. total parameter counts and the trainable percentage."""
        trainable_params = 0
        all_param = 0
        for p in self.parameters():
            count = p.numel()
            all_param += count
            if p.requires_grad:
                trainable_params += count
        # Guard against a parameter-free module instead of dividing by zero.
        trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"Trainable params: {trainable_params:,} || "
            f"All params: {all_param:,} || "
            f"Trainable %: {trainable_pct:.4f}"
        )