"""Interpreter Model for explaining energy-based model decisions.

This module implements the interpreter model that explains which input sentences
are most important for generating specific target output sentences in the EBM.
"""

from typing import List, Optional, Tuple

import torch
from torch import nn

from src.energy_model.models import CrossAttentionEncoder

from .config import InterpreterConfig
from .utils.interpreter_utils import (
    apply_hard_mask,
    apply_soft_mask,
    clone_encoder_from_ebm,
    extract_target_sentence,
    freeze_encoder_parameters,
    unfreeze_encoder_parameters,
    gumbel_softmax,
    masked_softmax,
)

# Constants
# NOTE(review): not referenced by the code in this module; presumably read by
# importers when binarizing soft masks — confirm before changing its value.
SOFT_MASK_THRESHOLD = 0.5  # Threshold for converting soft masks to boolean


class ImportanceScorer(nn.Module):
    """Module for computing importance scores from cross-attended embeddings.

    Implements INTERP-3: Pooling or Mapping of Sentence Representations.
    Uses direct MLP mapping from sentence embeddings to importance scores.

    Args:
        d_model: Size of each incoming sentence embedding.
        config: Interpreter configuration. Reads ``mlp_layers``,
            ``mlp_hidden_dim``, ``activation_fn``, ``softmax_type``,
            ``gumbel_temperature``, ``gumbel_hard`` and ``gumbel_k``.
        n_sentences: Unused; accepted only for call-site compatibility.
    """

    def __init__(
        self, d_model: int, config: InterpreterConfig, n_sentences: int = 16
    ) -> None:
        super().__init__()
        self.config = config

        # Build MLP layers based on config.mlp_layers
        layers = []
        if config.mlp_layers == 1:
            # Single linear layer: direct mapping from d_model to scalar score
            layers.append(nn.Linear(d_model, 1))
        else:
            # Any other value builds exactly ONE hidden layer:
            # d_model -> mlp_hidden_dim -> scalar score. Deeper stacks are not
            # built (kept as-is so existing checkpoints still load).
            layers.extend([
                nn.Linear(d_model, config.mlp_hidden_dim),
                # Any activation name other than "relu" falls back to GELU.
                nn.ReLU() if config.activation_fn.lower() == "relu" else nn.GELU(),
                nn.Linear(config.mlp_hidden_dim, 1),
            ])

        self.scorer = nn.Sequential(*layers)

        # Initialize with small weights to avoid extreme logits initially
        for layer in self.scorer:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=0.1)
                nn.init.zeros_(layer.bias)

    def forward(
        self,
        sentence_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute importance scores for input sentences.

        Args:
            sentence_embeddings: [batch_size, n_sentences, d_model]
            attention_mask: [batch_size, n_sentences] boolean mask. Assumed to
                be True for real (non-padded) sentences — TODO confirm against
                the text encoder's mask convention.

        Returns:
            importance_scores: [batch_size, n_sentences], normalized over the
            sentence axis (dim=1).

        Raises:
            ValueError: If ``sentence_embeddings`` is not 3-dimensional.

        """
        if sentence_embeddings.dim() != 3:
            msg = (
                "sentence_embeddings must be [batch_size, n_sentences, d_model], "
                f"got shape {tuple(sentence_embeddings.shape)}"
            )
            raise ValueError(msg)

        # Step 1: Calculate raw scores (logits) for each sentence
        logits = self.scorer(sentence_embeddings).squeeze(-1)

        # Step 2: Convert logits to importance scores.
        if self.config.softmax_type == "gumbel":
            # Differentiable discrete sampling (k-times) for top-k selection;
            # the helper receives the attention mask directly.
            return gumbel_softmax(
                logits,
                temperature=self.config.gumbel_temperature,
                hard=self.config.gumbel_hard,
                dim=1,
                k=self.config.gumbel_k,
                attention_mask=attention_mask,
            )

        # Plain softmax path: push padded slots to the most negative finite
        # value so they receive ~0 probability. A finite value (instead of
        # -inf) keeps fully-padded rows from turning into NaN.
        if attention_mask is not None:
            logits = logits.masked_fill(
                ~attention_mask.bool(), torch.finfo(logits.dtype).min
            )
        return masked_softmax(logits, dim=1)



class InterpreterModel(nn.Module):
    """Main Interpreter Model for explaining EBM decisions.

    This model implements the complete interpreter pipeline:
    - INTERP-1: Input and Target Embedding with Shared Pipeline
    - INTERP-2: Cross-Attention to Target Sentence
    - INTERP-3: Pooling or Mapping of Sentence Representations
    - INTERP-4: Masking and Perturbation of Input

    Text-based entry points (``forward``, ``explain``, ``evaluate``) run the
    text encoder themselves; ``forward_from_encoded`` and
    ``evaluate_from_encoded`` accept pre-computed ``(embeddings, masks)``
    pairs instead.
    """

    def __init__(
        self, energy_model: nn.Module, config: InterpreterConfig, n_sentences: int = 16
    ) -> None:
        """Build the interpreter around an energy model.

        Args:
            energy_model: EBM to explain. Must expose ``text_encoder`` (with
                ``embedder.get_embedding_dim()``), ``x_sa``, ``y_sa``,
                ``cross`` and ``head``.
            config: Interpreter configuration.
            n_sentences: Number of sentence slots per text.

        Raises:
            ValueError: If ``config.encoder_sharing_strategy`` is neither
                ``"shared"`` nor ``"cloned"``.

        """
        super().__init__()
        self.config = config
        self.n_sentences = n_sentences
        self.energy_model = energy_model

        # Get embedding dimension from energy model
        self.d_model = energy_model.text_encoder.embedder.get_embedding_dim()

        # INTERP-1: Setup encoders based on sharing strategy
        strategy = config.encoder_sharing_strategy
        if strategy == "shared":
            # Reuse the EBM's own modules, so any (un)freezing below also
            # affects the EBM itself.
            self.input_encoder = energy_model.x_sa
            self.output_encoder = energy_model.y_sa
            self.text_encoder = energy_model.text_encoder
        elif strategy == "cloned":
            # Clone encoders for independent training
            self.input_encoder = clone_encoder_from_ebm(energy_model.x_sa)
            self.output_encoder = clone_encoder_from_ebm(energy_model.y_sa)
            self.text_encoder = clone_encoder_from_ebm(energy_model.text_encoder)

            # Clones inherit the EBM's frozen state; re-enable training for the
            # input/output encoders only (text_encoder stays frozen below).
            unfreeze_encoder_parameters(self.input_encoder)
            unfreeze_encoder_parameters(self.output_encoder)
        else:
            # Fail fast: otherwise the encoder attributes are never assigned
            # and the error surfaces later as an opaque AttributeError.
            msg = (
                "Invalid encoder_sharing_strategy configuration: "
                f"{strategy!r} (expected 'shared' or 'cloned')"
            )
            raise ValueError(msg)

        # Always freeze text encoder (BERT should remain frozen)
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        # NOTE: text_encoder is a registered submodule, so a later
        # ``self.train()`` call switches it back to training mode.
        self.text_encoder.eval()

        # Optionally freeze encoder parameters based on config
        if config.freeze_encoder_in_interp:
            freeze_encoder_parameters(self.input_encoder)
            freeze_encoder_parameters(self.output_encoder)

        # INTERP-2: Cross-attention from input to target
        self.cross_attention = CrossAttentionEncoder(
            d_model=self.d_model,
            n_layers=config.cross_attention_layers,
            n_heads=config.attention_heads,
            dropout=config.dropout_rate,
        )

        # INTERP-3: Importance scoring
        self.importance_scorer = ImportanceScorer(
            d_model=self.d_model, config=config, n_sentences=n_sentences
        )

    def get_target_embedding(
        self, output_texts: List[str], target_indices: Optional[List[int]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get target sentence embeddings.

        Args:
            output_texts: List of full output texts
            target_indices: Indices of target sentences (if None, or an entry
                is -1, the last sentence slot ``n_sentences - 1`` is used)

        Returns:
            target_embeddings: [batch_size, d_model]
            target_masks: [batch_size, 1] boolean mask

        Raises:
            ValueError: If ``config.target_embedding`` is not ``"isolated"``
                or ``"contextualized"``.

        """
        if self.config.target_embedding == "isolated":
            # Extract and embed target sentences in isolation
            target_texts = []
            for i, output_text in enumerate(output_texts):
                target_idx = target_indices[i] if target_indices else -1
                if target_idx == -1:
                    target_idx = self.n_sentences - 1  # Last sentence slot
                target_texts.append(
                    extract_target_sentence(output_text, target_idx, self.n_sentences)
                )

            target_embs, target_masks = self.text_encoder(target_texts)
            # Take the first (and only meaningful) sentence embedding
            return target_embs[:, 0, :], target_masks[:, 0:1]

        if self.config.target_embedding == "contextualized":
            # Embed the full output, then pick the target slot after the
            # output encoder has contextualized it.
            output_embs, output_masks = self.text_encoder(output_texts)
            return self._get_target_from_encoded_contextualized(
                output_embs, output_masks, target_indices
            )

        msg = f"Invalid target_embedding configuration: {self.config.target_embedding}"
        raise ValueError(msg)

    def _score_sentences(
        self,
        input_proc: torch.Tensor,
        input_masks: torch.Tensor,
        target_embs: torch.Tensor,
        target_masks: torch.Tensor,
    ) -> torch.Tensor:
        """Cross-attend encoded input sentences to the target and score them.

        Args:
            input_proc: Output of ``self.input_encoder``; the first dimension
                is the batch.
            input_masks: Sentence mask for the input.
            target_embs: [batch_size, d_model] target embeddings.
            target_masks: [batch_size, 1] target masks.

        Returns:
            importance_scores from ``self.importance_scorer``.

        """
        batch_size = input_proc.size(0)

        # INTERP-2: every input sentence (query) attends to the target
        # embedding, broadcast to n_sentences key/value slots.
        target_expanded = target_embs.unsqueeze(1).expand(
            batch_size, self.n_sentences, self.d_model
        )
        target_mask_expanded = target_masks.expand(batch_size, self.n_sentences)

        cross_attended = self.cross_attention(
            query=input_proc,
            key_value=target_expanded,
            query_mask=input_masks,
            kv_mask=target_mask_expanded,
        )

        # INTERP-3: Compute importance scores
        return self.importance_scorer(cross_attended, input_masks)

    def forward(
        self,
        input_texts: List[str],
        output_texts: List[str],
        target_indices: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, List[str], torch.Tensor]:
        """Forward pass of interpreter model - generates complete explanations.

        This implements INTERP-1 through INTERP-4: Complete explanation pipeline
        including masking and perturbation of input.

        Args:
            input_texts: List of input texts to explain
            output_texts: List of corresponding output texts
            target_indices: Indices of target sentences to explain

        Returns:
            importance_scores: [batch_size, n_sentences] importance scores
            masked_texts: List of masked input texts based on importance
            mask: [batch_size, n_sentences] boolean mask used for masking

        """
        # INTERP-1: Encode inputs and targets
        input_embs, input_masks = self.text_encoder(input_texts)
        input_proc = self.input_encoder(input_embs, input_masks)

        target_embs, target_masks = self.get_target_embedding(output_texts, target_indices)

        # INTERP-2 / INTERP-3: input queries attend to the target, then score.
        importance_scores = self._score_sentences(
            input_proc, input_masks, target_embs, target_masks
        )

        # INTERP-4: Generate concept-space mask using hard masking
        # (regardless of masking_type, we generate text representation and binary mask)
        masked_texts, concept_mask = apply_hard_mask(
            input_texts,
            importance_scores,
            method=self.config.hard_mask_method,
            top_k=self.config.top_k,
            threshold=self.config.threshold,
        )

        return importance_scores, masked_texts, concept_mask

    def _compute_energy_with_masked_embeddings(
        self,
        input_texts: List[str],
        output_texts: List[str],
        selected_mask: torch.Tensor,
        unselected_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute energies using soft-masked embeddings.

        Encodes inputs and outputs in one call to the EBM's text encoder, then
        delegates to ``_compute_energy_with_masked_from_encoded``.

        Args:
            input_texts: List of input texts
            output_texts: List of output texts
            selected_mask: Soft mask for selected sentences [batch_size, n_sentences]
            unselected_mask: Soft mask for unselected sentences [batch_size, n_sentences]

        Returns:
            Tuple of (selected_energies, unselected_energies)

        """
        # Encode inputs and outputs in a single batch, then split them back.
        batch_size = len(input_texts)
        all_emb, all_masks = self.energy_model.text_encoder(input_texts + output_texts)
        x_emb, x_mask = all_emb[:batch_size], all_masks[:batch_size]
        y_emb, y_mask = all_emb[batch_size:], all_masks[batch_size:]

        return self._compute_energy_with_masked_from_encoded(
            x_emb, x_mask, y_emb, y_mask, selected_mask, unselected_mask, output_texts
        )

    def _comparison_outputs(
        self, output_texts: List[str], target_indices: Optional[List[int]]
    ) -> List[str]:
        """Return the output texts energies are compared against.

        ``comparison_type == "targeted"`` keeps only each output's target
        sentence; any other value ("full") returns ``output_texts`` unchanged.
        """
        if self.config.comparison_type != "targeted":
            return output_texts
        # NOTE(review): unlike get_target_embedding, a default/-1 index is
        # passed through as -1 rather than mapped to n_sentences - 1; confirm
        # extract_target_sentence treats both the same.
        return [
            extract_target_sentence(
                output_text,
                target_indices[i] if target_indices else -1,
                self.n_sentences,
            )
            for i, output_text in enumerate(output_texts)
        ]

    def _build_unselected_texts(
        self, input_texts: List[str], concept_mask: torch.Tensor
    ) -> List[str]:
        """Rebuild each input text from the sentences NOT selected by ``concept_mask``.

        Sentences are re-split and normalized to ``n_sentences`` slots, the
        unselected non-blank ones are joined with ``". "``, and an input with
        nothing left becomes the placeholder ``"[NO_UNSELECTED_SENTENCES]"``.
        """
        from src.energy_model.utils.energy_network import (
            normalize_sentences,
            semantic_sentence_split,
        )

        inverse_mask = ~concept_mask.bool()
        n_mask_cols = inverse_mask.shape[1]

        unselected_texts = []
        for i, original_text in enumerate(input_texts):
            sentences = normalize_sentences(
                semantic_sentence_split(original_text), self.n_sentences
            )
            kept = [
                sentence
                for j, sentence in enumerate(sentences)
                if j < n_mask_cols and inverse_mask[i, j]
            ]
            unselected_text = ". ".join([s for s in kept if s.strip()])
            if not unselected_text.strip():
                unselected_text = "[NO_UNSELECTED_SENTENCES]"
            unselected_texts.append(unselected_text)
        return unselected_texts

    def evaluate(
        self,
        input_texts: List[str],
        output_texts: List[str],
        target_indices: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate energy changes after explanation-based masking.

        This method generates explanations using forward() and then evaluates
        the energy changes by comparing selected vs unselected sentence energies.

        NOTE: This path re-encodes texts. For speed-critical training, prefer
        `evaluate_from_encoded(...)` with a dataloader-level encoding cache.

        Args:
            input_texts: List of input texts to explain
            output_texts: List of corresponding output texts
            target_indices: Indices of target sentences to explain

        Returns:
            importance_scores: [batch_size, n_sentences] (soft masks)
            selected_energies: [batch_size] energies with selected sentences
            unselected_energies: [batch_size] energies with unselected sentences

        """
        importance_scores, masked_texts, concept_mask = self.forward(
            input_texts, output_texts, target_indices
        )

        # Soft masks: the scores select, their complement de-selects.
        selected_mask = importance_scores
        unselected_mask = 1.0 - selected_mask

        comparison_outputs = self._comparison_outputs(output_texts, target_indices)

        if self.config.masking_type == "soft":
            (
                selected_energies,
                unselected_energies,
            ) = self._compute_energy_with_masked_embeddings(
                input_texts, comparison_outputs, selected_mask, unselected_mask
            )
        else:
            # Hard masking: score the text with only selected sentences kept
            # (masked_texts) vs. only unselected ones kept.
            unselected_texts = self._build_unselected_texts(input_texts, concept_mask)
            selected_energies = self.energy_model(masked_texts, comparison_outputs)
            unselected_energies = self.energy_model(unselected_texts, comparison_outputs)

        return importance_scores, selected_energies, unselected_energies

    def evaluate_from_encoded(
        self,
        x_encoded: Tuple[torch.Tensor, torch.Tensor],   # (x_emb, x_mask)  [B, S, D], [B, S]
        y_encoded: Tuple[torch.Tensor, torch.Tensor],   # (y_emb, y_mask)  [B, S, D], [B, S]
        target_indices: Optional[List[int]] = None,
        input_texts: Optional[List[str]] = None,        # only needed if masking_type == "hard"
        output_texts: Optional[List[str]] = None,       # only needed if masking_type == "hard"
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fast evaluation that skips tokenization/encoding by using pre-encoded tensors.

        Soft masking runs entirely on the encoded tensors (contextualized
        targets only). Hard masking needs the raw strings and delegates to
        :meth:`evaluate`, returning its results unchanged so the importance
        scores match the masks the energies were computed with.

        Returns:
            importance_scores: [B, S]
            selected_energies: [B]
            unselected_energies: [B]

        Raises:
            ValueError: Hard masking without both ``input_texts`` and ``output_texts``.
            NotImplementedError: Soft masking with ``target_embedding != "contextualized"``.

        """
        if self.config.masking_type != "soft":
            # evaluate() needs output_texts for the target embedding and for the
            # energy comparison; substituting input_texts would be wrong.
            if input_texts is None or output_texts is None:
                raise ValueError(
                    "Hard masking in evaluate_from_encoded requires input_texts and output_texts. "
                    "Pass them in, or switch masking_type='soft' for full speed-up."
                )
            return self.evaluate(
                input_texts=input_texts,
                output_texts=output_texts,
                target_indices=target_indices,
            )

        x_emb, x_mask = x_encoded
        y_emb, y_mask = y_encoded

        importance_scores = self.forward_from_encoded(x_encoded, y_encoded, target_indices)
        selected_mask = importance_scores
        unselected_mask = 1.0 - selected_mask

        # NOTE(review): unlike evaluate(), comparison_type is not applied here —
        # energies are always computed against the full encoded outputs.
        selected_energies, unselected_energies = self._compute_energy_with_masked_from_encoded(
            x_emb, x_mask, y_emb, y_mask, selected_mask, unselected_mask, output_texts
        )
        return importance_scores, selected_energies, unselected_energies

    def forward_from_encoded(
        self,
        x_encoded: Tuple[torch.Tensor, torch.Tensor],   # (x_emb, x_mask)  [B, S, D], [B, S]
        y_encoded: Tuple[torch.Tensor, torch.Tensor],   # (y_emb, y_mask)  [B, S, D], [B, S]
        target_indices: Optional[List[int]] = None,
        input_texts: Optional[List[str]] = None,        # unused; kept for signature parity
        output_texts: Optional[List[str]] = None,       # unused; kept for signature parity
    ) -> torch.Tensor:
        """Compute importance scores from pre-encoded tensors (no energy evaluation).

        Returns:
            importance_scores: [B, S]

        Raises:
            NotImplementedError: If ``target_embedding != "contextualized"``.

        """
        x_emb, x_mask = x_encoded
        y_emb, y_mask = y_encoded

        if self.config.target_embedding != "contextualized":
            # With "isolated", we need raw sentences to re-encode targets in isolation.
            raise NotImplementedError(
                "evaluate_from_encoded currently supports target_embedding='contextualized'. "
                "Use evaluate() or switch to contextualized for full speed-up."
            )
        target_embs, target_masks = self._get_target_from_encoded_contextualized(
            y_emb, y_mask, target_indices
        )

        input_proc = self.input_encoder(x_emb, x_mask)
        return self._score_sentences(input_proc, x_mask, target_embs, target_masks)

    def explain(
        self,
        input_texts: List[str],
        output_texts: List[str],
        target_indices: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, List[str], torch.Tensor]:
        """Generate explanations without energy evaluation.

        This is a convenience method that calls forward() for inference-only usage
        where energy evaluation is not needed.

        Args:
            input_texts: List of input texts to explain
            output_texts: List of corresponding output texts
            target_indices: Indices of target sentences to explain

        Returns:
            importance_scores: [batch_size, n_sentences] importance scores
            masked_texts: List of masked input texts based on importance
            concept_mask: [batch_size, n_sentences] concept-space mask used for masking

        """
        return self.forward(input_texts, output_texts, target_indices)

    def _get_target_from_encoded_contextualized(
        self,
        output_embs: torch.Tensor,   # [B, S, D]
        output_masks: torch.Tensor,  # [B, S]
        target_indices: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Contextualized target extraction when output texts are ALREADY encoded.

        A missing/-1 index selects slot ``n_sentences - 1``; every index is
        clamped to the last available slot.

        Returns:
            target_embeddings: [B, D]
            target_masks:      [B, 1] (boolean)

        """
        # Contextualize with the interpreter's output encoder (the EBM's y_sa
        # when shared, a trainable clone when cloned).
        output_proc = self.output_encoder(output_embs, output_masks)

        batch_size, seq_len, _ = output_proc.shape
        target_embeddings = []
        target_masks = []

        for i in range(batch_size):
            t_idx = target_indices[i] if target_indices else -1
            if t_idx == -1:
                t_idx = self.n_sentences - 1
            t_idx = min(t_idx, seq_len - 1)

            target_embeddings.append(output_proc[i, t_idx, :])
            # Slice keeps shape [1] so stacking gives [B, 1].
            target_masks.append(output_masks[i, t_idx:t_idx + 1])

        return torch.stack(target_embeddings, dim=0), torch.stack(target_masks, dim=0)

    def _compute_energy_with_masked_from_encoded(
        self,
        x_emb: torch.Tensor, x_mask: torch.Tensor,   # [B, S, D], [B, S]
        y_emb: torch.Tensor, y_mask: torch.Tensor,   # [B, S, D], [B, S]
        selected_mask: torch.Tensor,                 # [B, S] (soft)
        unselected_mask: torch.Tensor,               # [B, S] (soft)
        output_texts: Optional[List[str]] = None,    # unused; kept for compatibility
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Soft-masking energy computation using ALREADY-ENCODED embeddings.

        Runs the EBM's self-attention on both sides, multiplies the input side
        by each soft mask, cross-attends the outputs into the masked inputs and
        applies the EBM head.

        Returns:
            Tuple of (selected_energies, unselected_energies)

        """
        # Self-attend to get contextualized representations
        x_proc = self.energy_model.x_sa(x_emb, x_mask)
        y_proc = self.energy_model.y_sa(y_emb, y_mask)

        # Apply soft masks to the input side
        x_proc_selected = apply_soft_mask(x_proc, selected_mask, mask_method="multiply")
        x_proc_unselected = apply_soft_mask(x_proc, unselected_mask, mask_method="multiply")

        # Output side = queries; masked input side = keys/values.
        y_cross_selected = self.energy_model.cross(
            query=y_proc, key_value=x_proc_selected, query_mask=y_mask, kv_mask=x_mask
        )
        y_cross_unselected = self.energy_model.cross(
            query=y_proc, key_value=x_proc_unselected, query_mask=y_mask, kv_mask=x_mask
        )

        # Head to scalar energies
        selected_energies = self.energy_model.head(y_cross_selected, y_mask)
        unselected_energies = self.energy_model.head(y_cross_unselected, y_mask)
        return selected_energies, unselected_energies