"""Loss functions for training the interpreter model.

This module implements INTERP-5: Interpreter Loss and Training Objective.
Provides contrastive (InfoNCE), regression, and sparsity-regularized loss functions for training the interpreter, plus a factory for selecting one by name.
"""

from typing import Dict, Optional, Tuple

import torch
from torch import nn


class InterpreterContrastiveLoss(nn.Module):
    """InfoNCE loss for interpreter training.

    Trains the interpreter using InfoNCE loss to distinguish between original
    and masked inputs. The original input is treated as positive (lower energy)
    and masked input as negative (higher energy).

    Energies are converted to logits as ``-energy / temperature`` and a
    two-way cross-entropy is taken with the original (positive) input as the
    target class.
    """

    def __init__(self, margin: float = 1.0, temperature: float = 0.1) -> None:
        """Initialize the loss.

        Args:
            margin: Kept only for backward compatibility; InfoNCE does not use it.
            temperature: Positive scale dividing the negated energies.

        Raises:
            ValueError: If ``temperature`` is not strictly positive (this also
                rejects NaN). Zero would yield inf/NaN logits, and a negative
                value would silently invert which input is preferred.

        """
        super().__init__()
        if not temperature > 0:
            msg = f"temperature must be positive, got {temperature}"
            raise ValueError(msg)
        self.margin = margin  # Keep for backward compatibility, not used in InfoNCE
        self.temperature = temperature

    def forward(
        self,
        original_energies: torch.Tensor,
        masked_energies: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute InfoNCE loss.

        Args:
            original_energies: [batch_size] energies for original inputs (positive)
            masked_energies: [batch_size] energies for masked inputs (negative)

        Returns:
            loss: Scalar loss value (mean over the batch)
            metrics: Dictionary of metrics for logging

        """
        # Convert energies to logits (lower energy = higher probability).
        pos_logits = -original_energies / self.temperature
        neg_logits = -masked_energies / self.temperature

        # [batch_size, 2]: column 0 is the positive, column 1 the negative.
        all_logits = torch.stack([pos_logits, neg_logits], dim=1)

        # The positive sample is always class 0.
        targets = torch.zeros(
            all_logits.size(0), dtype=torch.long, device=all_logits.device
        )

        # cross_entropy applies log_softmax, which already subtracts the
        # per-row max, so no manual numerical stabilisation is needed.
        loss = torch.nn.functional.cross_entropy(all_logits, targets)

        # Metrics
        energy_diff = masked_energies - original_energies
        predicted_probs = torch.softmax(all_logits, dim=1)
        classification_threshold = 0.5
        accuracy = (predicted_probs[:, 0] > classification_threshold).float().mean()

        metrics = {
            "infonce_loss": loss.item(),
            "energy_diff_mean": energy_diff.mean().item(),
            "original_energy_mean": original_energies.mean().item(),
            "masked_energy_mean": masked_energies.mean().item(),
            "correct_direction": (energy_diff > 0).float().mean().item(),
            "classification_accuracy": accuracy.item(),
            "temperature": self.temperature,
        }

        return loss, metrics


class InterpreterRegressionLoss(nn.Module):
    """Regression loss for interpreter training.

    Trains the interpreter to predict the energy increase when sentences
    are masked based on importance scores. The regressed quantity is
    ``masked_energies - original_energies``.
    """

    def __init__(self, loss_type: str = "mse") -> None:
        """Initialize the loss.

        Args:
            loss_type: One of ``"mse"`` (nn.MSELoss), ``"mae"`` (nn.L1Loss)
                or ``"huber"`` (nn.SmoothL1Loss).

        Raises:
            ValueError: If ``loss_type`` is not one of the supported names.
                Previously an unknown name left ``criterion`` unset and only
                failed later with an AttributeError inside ``forward``.

        """
        super().__init__()
        self.loss_type = loss_type

        if loss_type == "mse":
            self.criterion = nn.MSELoss()
        elif loss_type == "mae":
            self.criterion = nn.L1Loss()
        elif loss_type == "huber":
            self.criterion = nn.SmoothL1Loss()
        else:
            error_msg = f"Unknown regression loss_type: {loss_type}"
            raise ValueError(error_msg)

    def forward(
        self,
        original_energies: torch.Tensor,
        masked_energies: torch.Tensor,
        target_increase: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute regression loss.

        Args:
            original_energies: [batch_size] energies for original inputs
            masked_energies: [batch_size] energies for masked inputs
            target_increase: [batch_size] target energy increase (optional).
                Defaults to a tensor of ones shaped like the actual increase.

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of metrics for logging

        """
        actual_increase = masked_energies - original_energies

        if target_increase is None:
            # Default target: encourage moderate positive increase.
            target_increase = torch.ones_like(actual_increase)

        loss = self.criterion(actual_increase, target_increase)

        metrics = {
            "regression_loss": loss.item(),
            "actual_increase_mean": actual_increase.mean().item(),
            "target_increase_mean": target_increase.mean().item(),
            "mae": torch.abs(actual_increase - target_increase).mean().item(),
            "correct_direction": (actual_increase > 0).float().mean().item(),
        }

        return loss, metrics


class RegularizedInterpreterLoss(nn.Module):
    """Regularized loss function with contrastive and sparsity regularization terms.

    total = contrastive_weight * InfoNCE + sparsity_weight * mean_i ||s_i||_1,
    where s_i is the row of importance scores for batch item i.
    """

    def __init__(
        self,
        contrastive_weight: float = 1.0,
        sparsity_weight: float = 0.1,
        margin: float = 1.0,
        temperature: float = 0.1,
    ) -> None:
        """Initialize the loss.

        Args:
            contrastive_weight: Multiplier on the InfoNCE term.
            sparsity_weight: Multiplier on the L1 sparsity term.
            margin: Forwarded to InterpreterContrastiveLoss (unused by InfoNCE).
            temperature: Forwarded to InterpreterContrastiveLoss.

        """
        super().__init__()
        self.contrastive_weight = contrastive_weight
        self.sparsity_weight = sparsity_weight
        self.contrastive_loss = InterpreterContrastiveLoss(
            margin=margin, temperature=temperature
        )

    def forward(
        self,
        original_energies: torch.Tensor,
        masked_energies: torch.Tensor,
        importance_scores: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute regularized loss.

        Args:
            original_energies: [batch_size] energies for original inputs
            masked_energies: [batch_size] energies for masked inputs
            importance_scores: [batch_size, n_sentences] importance scores

        Returns:
            loss: Scalar loss value
            metrics: Dictionary of contrastive metrics plus sparsity and
                importance-score statistics for logging

        """
        # Contrastive loss
        contrastive_loss, contrastive_metrics = self.contrastive_loss(
            original_energies, masked_energies
        )

        # Sparsity regularization (encourage sparse explanations): per-row L1
        # norm averaged over the batch. torch.linalg.vector_norm replaces the
        # deprecated torch.norm and gives the same value.
        sparsity_loss = torch.linalg.vector_norm(
            importance_scores, ord=1, dim=1
        ).mean()

        total_loss = (
            self.contrastive_weight * contrastive_loss
            + self.sparsity_weight * sparsity_loss
        )

        metrics = {
            **contrastive_metrics,
            "sparsity_loss": sparsity_loss.item(),
            "total_loss": total_loss.item(),
            "importance_mean": importance_scores.mean().item(),
            # NOTE: unbiased std; this is NaN when importance_scores has a
            # single element.
            "importance_std": importance_scores.std().item(),
            "max_importance": importance_scores.max().item(),
            "min_importance": importance_scores.min().item(),
        }

        return total_loss, metrics


def create_interpreter_loss(loss_type: str, **kwargs) -> nn.Module:
    """Build the interpreter loss module selected by ``loss_type``.

    Args:
        loss_type: One of "contrastive", "regression" or "regularized".
        **kwargs: Optional hyperparameters. Contrastive reads ``margin`` and
            ``temperature``. Regression reads ``regression_loss_type``.
            Regularized reads ``contrastive_weight``, ``sparsity_weight``,
            ``margin`` and ``temperature``. Keys that are not given fall back
            to the class defaults.

    Returns:
        A freshly constructed loss module.

    Raises:
        ValueError: If ``loss_type`` names no known loss.

    """
    # Builders are deferred so only the selected loss is constructed.
    builders = (
        (
            "contrastive",
            lambda: InterpreterContrastiveLoss(
                margin=kwargs.get("margin", 1.0),
                temperature=kwargs.get("temperature", 0.1),
            ),
        ),
        (
            "regression",
            lambda: InterpreterRegressionLoss(
                loss_type=kwargs.get("regression_loss_type", "mse"),
            ),
        ),
        (
            "regularized",
            lambda: RegularizedInterpreterLoss(
                contrastive_weight=kwargs.get("contrastive_weight", 1.0),
                sparsity_weight=kwargs.get("sparsity_weight", 0.1),
                margin=kwargs.get("margin", 1.0),
                temperature=kwargs.get("temperature", 0.1),
            ),
        ),
    )

    # Equality comparison (not dict lookup) keeps the original semantics for
    # arbitrary, possibly unhashable, loss_type values.
    for name, build in builders:
        if loss_type == name:
            return build()

    raise ValueError(f"Unknown loss_type: {loss_type}")
