"""Utility functions for the interpreter model.

This module provides utility functions for masking, normalization,
and perturbation operations used in the interpreter model.
"""

import copy
from typing import List, Tuple

import torch


def gumbel_softmax(
    logits: torch.Tensor,
    temperature: float = 1.0,
    hard: bool = False,
    dim: int = -1,
    k: int = 4,
    attention_mask: torch.Tensor = None,
) -> torch.Tensor:
    """Average k Gumbel-Softmax samples into one importance distribution.

    For each of the ``k`` rounds, fresh Gumbel(0, 1) noise is added to
    ``logits``, the sum is divided by ``temperature`` and a softmax is taken
    along ``dim``. The ``k`` softmax outputs are averaged. If an attention
    mask is given, masked positions are zeroed and the remaining mass is
    renormalized along ``dim``.

    Args:
        logits: [..., num_classes] unnormalized log probabilities
        temperature: Temperature parameter for Gumbel-Softmax
        hard: If True, return a one-hot tensor (argmax of the averaged
            distribution) whose gradient is that of the soft distribution
            (straight-through estimator); otherwise return the soft values
        dim: Dimension along which to apply Gumbel-Softmax
        k: Number of times to sample Gumbel noise (default: 4); must be >= 1
        attention_mask: Mask (non-zero = valid, zero = padding) used to zero
            out padded positions (default: None). If it has fewer dimensions
            than ``logits``, trailing singleton dimensions are appended.

    Returns:
        Tensor shaped like ``logits``. Masked positions are exactly zero; a
        slice along ``dim`` in which every position is masked is all zeros,
        for both the soft and the hard output.

    Raises:
        ValueError: If ``k`` is smaller than 1.

    """
    if k < 1:
        msg = f"k must be a positive integer, got {k}"
        raise ValueError(msg)

    # Masked mass at or below this value is treated as "nothing valid left"
    eps = 1e-8

    # Keep a running sum instead of stacking k full-size tensors
    summed_distribution = None
    for _ in range(k):
        # -log(Exponential(1)) ~ Gumbel(0, 1)
        gumbels = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        y_soft = ((logits + gumbels) / temperature).softmax(dim)
        summed_distribution = (
            y_soft if summed_distribution is None else summed_distribution + y_soft
        )

    # Average of the k sampled distributions
    normalized_distribution = summed_distribution / k

    mask_float = None
    if attention_mask is not None:
        mask_float = attention_mask.float()
        # Append trailing singleton dims so the mask broadcasts over the distribution
        while mask_float.dim() < normalized_distribution.dim():
            mask_float = mask_float.unsqueeze(-1)

        # Zero out masked positions, then renormalize over what is left
        masked_distribution = normalized_distribution * mask_float
        distribution_sum = masked_distribution.sum(dim=dim, keepdim=True)

        # Fallback when no usable mass remains: uniform over valid positions
        # (which is all zeros when every position is masked)
        valid_positions = mask_float.sum(dim=dim, keepdim=True)
        uniform_fallback = mask_float / torch.clamp(valid_positions, min=1.0)

        normalized_distribution = torch.where(
            distribution_sum > eps,
            masked_distribution / torch.clamp(distribution_sum, min=eps),
            uniform_fallback,
        )

    if not hard:
        return normalized_distribution

    # Straight-through estimator on the final (masked) distribution
    index = normalized_distribution.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(
        logits, memory_format=torch.legacy_contiguous_format
    ).scatter_(dim, index, 1.0)
    if mask_float is not None:
        # When every position is masked the argmax lands on a padded slot;
        # clear it so padding never receives weight in hard mode either.
        y_hard = y_hard * mask_float
    return y_hard - normalized_distribution.detach() + normalized_distribution

def masked_softmax(logits: torch.Tensor, dim: int = 1, attention_mask: torch.Tensor = None) -> torch.Tensor:
    """Apply softmax over the valid (unmasked) positions only.

    Padded positions are excluded *before* the softmax by setting their
    logits to ``-inf``. Masking after the softmax loses all probability mass
    when a padded logit is much larger than every valid one, because the
    valid probabilities underflow to zero before they can be renormalized.

    Args:
        logits: [..., num_classes] unnormalized log probabilities
        dim: Dimension along which to apply softmax
        attention_mask: Mask (non-zero for valid positions, 0 for padding).
            If it has fewer dimensions than ``logits``, trailing singleton
            dimensions are appended. ``None`` means a plain softmax.

    Returns:
        Softmax probabilities with padded positions exactly zero and valid
        positions summing to 1 along ``dim``. A slice along ``dim`` with no
        valid position is returned as all zeros.
    """
    if attention_mask is None:
        return torch.softmax(logits, dim=dim)

    valid = attention_mask.bool()
    # Append trailing singleton dims so the mask broadcasts over the logits
    while valid.dim() < logits.dim():
        valid = valid.unsqueeze(-1)

    # Slices with no valid position would be all -inf and yield NaN from the
    # softmax; give them finite logits here and zero them out afterwards.
    has_valid = valid.any(dim=dim, keepdim=True)
    safe_logits = logits.masked_fill(~valid, float("-inf"))
    safe_logits = safe_logits.masked_fill(~has_valid, 0.0)

    probs = torch.softmax(safe_logits, dim=dim)
    return probs.masked_fill(~valid, 0.0)


def sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute the sparsemax of a tensor along one dimension.

    The values are shifted by a threshold ``tau`` and clamped at zero, so the
    result can contain exact zeros.

    Args:
        input_tensor: Scores to transform
        dim: Dimension along which the transformation is applied

    Returns:
        Tensor shaped like ``input_tensor`` holding the clamped, shifted values

    """
    n_entries = input_tensor.size(dim)

    # Scores from largest to smallest, with their running totals
    z_sorted = torch.sort(input_tensor, dim=dim, descending=True)[0]
    z_cumsum = z_sorted.cumsum(dim=dim)

    # Ranks 1..n laid out along `dim` and broadcast to the sorted scores
    view_shape = [1 for _ in range(input_tensor.dim())]
    view_shape[dim] = -1
    ranks = torch.arange(
        1, n_entries + 1, device=input_tensor.device, dtype=input_tensor.dtype
    )
    ranks = ranks.view(view_shape).expand_as(z_sorted)

    # A sorted entry belongs to the support while it exceeds its running threshold
    in_support = z_sorted - (z_cumsum - 1) / ranks > 0
    support_size = in_support.sum(dim=dim, keepdim=True)

    # Threshold taken from the cumulative sum at the last supported rank
    last_supported = (support_size - 1).long()
    cumsum_at_last = z_cumsum.gather(dim, last_supported)
    tau = (cumsum_at_last - 1) / support_size.float()

    return (input_tensor - tau).clamp(min=0)


def apply_hard_mask(
    texts: List[str],
    importance_scores: torch.Tensor,
    method: str,
    top_k: int = 8,
    threshold: float = 0.5,
    placeholder_token: str = ""
) -> Tuple[List[str], torch.Tensor]:
    """Apply hard masking to input texts based on importance scores.

    A boolean keep-mask is built from ``importance_scores``; each text is then
    split into sentences, the sentences whose mask entry is False are dropped
    (or replaced by ``placeholder_token`` when one is given), and the result
    is joined back with ``". "``.

    Args:
        texts: List of input text strings, one per row of ``importance_scores``
        importance_scores: Tensor of shape [batch_size, n_sentences]
        method: Type of masking ("top_k" or "threshold")
        top_k: Number of top sentences to keep (for top_k method); capped at
            ``n_sentences``
        threshold: Sentences with score >= threshold are kept (for threshold
            method)
        placeholder_token: Token inserted in place of each masked sentence.
            With the default empty string, masked sentences and empty
            sentences are simply omitted from the output.

    Returns:
        Tuple of (masked_texts, mask_tensor), where ``mask_tensor`` is a
        boolean tensor shaped like ``importance_scores`` (True = kept)

    Raises:
        ValueError: If ``method`` is neither "top_k" nor "threshold".

    """
    batch_size, n_sentences = importance_scores.shape

    # Create binary mask
    if method == "top_k":
        # Ensure top_k doesn't exceed available sentences
        effective_k = min(top_k, n_sentences)
        _, top_indices = torch.topk(importance_scores, k=effective_k, dim=1)
        mask = torch.zeros_like(importance_scores, dtype=torch.bool)
        mask.scatter_(1, top_indices, 1)
    elif method == "threshold":
        mask = importance_scores >= threshold
    else:
        error_msg = f"Unknown masking method: {method}. Must be 'top_k' or 'threshold'."
        raise ValueError(error_msg)

    # Use the same sentence splitting helpers as the importance score
    # computation. Imported lazily (once, not per text) and only after the
    # arguments have been validated.
    from src.energy_model.utils.energy_network import (
        normalize_sentences,
        semantic_sentence_split,
    )

    # Read the mask out once instead of indexing the tensor per sentence
    keep_rows = mask.tolist()

    masked_texts = []
    for i, keep_row in enumerate(keep_rows):
        sentences = semantic_sentence_split(texts[i])
        sentences = normalize_sentences(sentences, n_sentences)

        pieces = []
        for j, sentence in enumerate(sentences):
            # Sentences beyond the mask width count as masked
            if j < n_sentences and keep_row[j]:
                pieces.append(sentence)
            elif placeholder_token:
                pieces.append(placeholder_token)

        if not placeholder_token:
            # Without a placeholder, drop empty entries so the join does not
            # produce stray separators.
            pieces = [piece for piece in pieces if piece]

        # With a placeholder, every piece is kept: filtering on the
        # placeholder value here would strip the tokens that were just added.
        masked_texts.append(". ".join(pieces))

    return masked_texts, mask


def apply_soft_mask(
    sentence_embeddings: torch.Tensor,
    importance_scores: torch.Tensor,
    mask_method: str = "multiply",
) -> torch.Tensor:
    """Weight sentence embeddings by their importance scores.

    Args:
        sentence_embeddings: Tensor of shape [batch_size, n_sentences, d_model]
        importance_scores: Tensor of shape [batch_size, n_sentences]
            (already normalized soft mask)
        mask_method: "multiply" scales each embedding by its score;
            "interpolate" blends each embedding with an all-zero embedding
            using the score as the blend weight

    Returns:
        Soft-masked sentence embeddings

    Raises:
        ValueError: If ``mask_method`` is not one of the supported methods.

    """
    if mask_method not in ("multiply", "interpolate"):
        supported = "'multiply', 'interpolate'"
        raise ValueError(
            f"Unknown soft mask method: {mask_method}. Available methods: {supported}"
        )

    # Trailing axis lets the per-sentence score broadcast over d_model
    weights = importance_scores.unsqueeze(-1)

    if mask_method == "multiply":
        return sentence_embeddings * weights

    # Blend between the embedding and a zero baseline
    baseline = torch.zeros_like(sentence_embeddings)
    kept_part = weights * sentence_embeddings
    baseline_part = (1 - importance_scores.unsqueeze(-1)) * baseline
    return kept_part + baseline_part


def extract_target_sentence(output_text: str, target_index: int, n_sentences: int) -> str:
    """Extract target sentence from output text.

    The text is split with ``semantic_sentence_split`` and passed through
    ``normalize_sentences`` with ``n_sentences`` before indexing.

    Args:
        output_text: Full output text
        target_index: Index of target sentence to extract. Negative values
            index from the end, as with a Python list.
        n_sentences: Expected number of sentences

    Returns:
        The sentence at ``target_index``; the last sentence if the index is
        out of bounds in either direction; an empty string if there are no
        sentences at all.

    """
    from src.energy_model.utils.energy_network import (
        normalize_sentences,
        semantic_sentence_split,
    )

    sentences = semantic_sentence_split(output_text)
    sentences = normalize_sentences(sentences, n_sentences)

    if not sentences:
        return ""

    # Check both bounds: a negative index below -len(sentences) would
    # otherwise raise IndexError instead of using the fallback.
    if -len(sentences) <= target_index < len(sentences):
        return sentences[target_index]

    # Return last sentence if index out of bounds
    return sentences[-1]


def clone_encoder_from_ebm(ebm_encoder: torch.nn.Module) -> torch.nn.Module:
    """Return a deep copy of the EBM encoder.

    Args:
        ebm_encoder: EBM encoder module to duplicate

    Returns:
        A new module produced by ``copy.deepcopy``; it shares no objects with
        ``ebm_encoder``, so the two can be trained independently

    """
    independent_encoder = copy.deepcopy(ebm_encoder)
    return independent_encoder


def freeze_encoder_parameters(encoder: torch.nn.Module) -> None:
    """Disable gradient tracking for every parameter of ``encoder``.

    Args:
        encoder: Encoder module whose parameters are modified in place

    """
    trainable = False
    for weight in encoder.parameters():
        weight.requires_grad = trainable


def unfreeze_encoder_parameters(encoder: torch.nn.Module) -> None:
    """Enable gradient tracking for every parameter of ``encoder``.

    Args:
        encoder: Encoder module whose parameters are modified in place

    """
    trainable = True
    for weight in encoder.parameters():
        weight.requires_grad = trainable
