"""Loss functions for EBM training.

This module provides contrastive and InfoNCE loss functions along with
accuracy computation utilities for energy-based models. The functions are
designed to handle batches of positive and negative energy scores, including
support for named negative sampling methods and optional weighting.
"""

from __future__ import annotations

from typing import Dict, List, Optional

import torch
from torch import nn
from torch.nn import functional as f


def contrastive_loss(
    pos_energy: torch.Tensor,
    neg_energy: torch.Tensor,
    margin: float,
) -> torch.Tensor:
    """Return the batch-mean margin-ranking loss between paired energies.

    Every (positive, negative) pair contributes
    ``max(0, margin + pos_energy - neg_energy)``. The loss therefore reaches
    zero once each negative energy sits at least ``margin`` above its paired
    positive energy.

    Args:
        pos_energy (torch.Tensor): Energy scores of the positive samples.
            Shape: (B,).
        neg_energy (torch.Tensor): Energy scores of the paired negative
            samples. Shape: (B,).
        margin (float): Minimum required gap between negative and positive
            energy.

    Returns:
        torch.Tensor: Scalar tensor with the loss averaged over the batch.
    """
    # A +1 target tells margin_ranking_loss that its FIRST argument
    # (neg_energy) should be ranked above the SECOND (pos_energy).
    target = torch.ones_like(pos_energy)
    return f.margin_ranking_loss(
        neg_energy,
        pos_energy,
        target,
        margin=margin,
        reduction="mean",
    )


def individual_contrastive_losses(
    pos_energy: torch.Tensor,
    neg_energies_dict: dict[str, torch.Tensor],
    margin: float,
) -> dict[str, torch.Tensor]:
    """Compute a separate contrastive loss for each negative sampling method.

    For each method, only the pairs whose negative energy is finite are
    used. This drops inf/NaN entries, e.g. samples where a method could not
    produce a negative. Those entries therefore do not distort that
    method's batch mean.

    Args:
        pos_energy (torch.Tensor): Energy scores of the positive samples.
            Shape: (B,).
        neg_energies_dict (Dict[str, torch.Tensor]): Maps a negative sampling
            method name to its batch of negative energy scores (paired
            element-wise with ``pos_energy``).
        margin (float): Margin passed through to ``contrastive_loss``.

    Returns:
        Dict[str, torch.Tensor]: Method name -> scalar loss tensor. A method
        with no finite negative energies in the batch maps to a NaN tensor.
    """
    per_method: dict[str, torch.Tensor] = {}
    for name, neg in neg_energies_dict.items():
        keep = torch.isfinite(neg)
        if keep.any():
            per_method[name] = contrastive_loss(pos_energy[keep], neg[keep], margin)
        else:
            # Nothing usable for this method in this batch: report NaN, not 0,
            # so the gap is visible to whoever logs/aggregates the result.
            per_method[name] = torch.tensor(float("nan"), device=pos_energy.device)
    return per_method


def summed_contrastive_loss(
    pos_energy: torch.Tensor,
    neg_energies_dict: Dict[str, torch.Tensor],
    margin: float,
    weights: Optional[Dict[str, float]] = None,
) -> torch.Tensor:
    """Calculates the (optionally weighted) sum of multiple contrastive losses.

    For every method, only pairs with a finite negative energy are used. A
    method with no finite negative energies in the batch is skipped entirely.
    It contributes nothing, rather than a NaN.

    Args:
        pos_energy (torch.Tensor): A tensor of energy scores for positive samples.
            Shape: (B,).
        neg_energies_dict (Dict[str, torch.Tensor]): A dictionary mapping method
            names to their negative energy scores.
        margin (float): The desired margin for the contrastive loss.
        weights (Optional[Dict[str, float]]): A dictionary mapping method names
            to a specific weight for their loss. Methods missing from the
            dictionary (or all methods, if None) get a weight of 1.0.

    Returns:
        torch.Tensor: The total scalar loss value for the batch. If no method
        contributes, this is a constant zero tensor with no autograd history.
        Calling ``backward()`` on that tensor alone will fail.
    """
    if weights is None:
        # Missing keys already default to 1.0 via ``.get`` below.
        weights = {}

    total_loss = torch.tensor(0.0, device=pos_energy.device)
    for method_name, neg_energy in neg_energies_dict.items():
        valid_indices = torch.isfinite(neg_energy)
        if not valid_indices.any():
            continue

        valid_pos_energy = pos_energy[valid_indices]
        valid_neg_energy = neg_energy[valid_indices]

        loss = contrastive_loss(valid_pos_energy, valid_neg_energy, margin)
        # Out-of-place add on purpose. In-place ``+=`` keeps the accumulator's
        # float32 dtype and would silently downcast float64 losses.
        total_loss = total_loss + weights.get(method_name, 1.0) * loss

    return total_loss


def infonce_loss(
    pos_energy: torch.Tensor,
    neg_energies: list[torch.Tensor],
    temperature: float = 0.1,
) -> torch.Tensor:
    """Calculates the InfoNCE (Noise Contrastive Estimation) loss.

    This loss treats the task as a classification problem where the goal is to
    identify the single positive sample from a set of negative samples. Lower
    energy scores are treated as higher probability logits.

    Args:
        pos_energy (torch.Tensor): A tensor of positive sample energies.
            Shape: (B,).
        neg_energies (List[torch.Tensor]): A list of tensors, each containing
            negative sample energies. Each tensor has shape (B,).
        temperature (float): Strictly positive temperature that sharpens
            (small values) or softens (large values) the distribution.

    Returns:
        torch.Tensor: The mean scalar InfoNCE loss for the batch.

    Raises:
        ValueError: If ``neg_energies`` is empty, or ``temperature`` is not
            strictly positive. A zero temperature yields inf/NaN losses, and a
            negative one silently inverts the objective.
    """
    if not neg_energies:
        msg = "At least one negative energy tensor is needed."
        raise ValueError(msg)
    if not temperature > 0:  # written this way so NaN is rejected too
        msg = f"temperature must be strictly positive, got {temperature}."
        raise ValueError(msg)

    # Convert to logits (lower energy = higher probability).
    pos_logits = -pos_energy / temperature  # [B]
    neg_logits = -torch.stack(neg_energies, dim=1) / temperature  # [B, num_negatives]

    # Column 0 is the positive, so every target class index is 0.
    all_logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
    targets = torch.zeros(pos_energy.size(0), dtype=torch.long, device=pos_energy.device)

    # cross_entropy uses log_softmax internally (log-sum-exp with max
    # subtraction), so no manual stabilisation shift is needed.
    return f.cross_entropy(all_logits, targets)


def individual_infonce_losses(
    pos_energy: torch.Tensor,
    neg_energies_dict: dict[str, torch.Tensor],
    temperature: float = 0.1,
) -> dict[str, torch.Tensor]:
    """Computes individual InfoNCE losses for each negative sampling method.

    Each loss is calculated as a 1-vs-1 classification problem between the
    positive sample and the negative sample from a single method.

    Args:
        pos_energy (torch.Tensor): A tensor of positive sample energies.
            Shape: (B,).
        neg_energies_dict (Dict[str, torch.Tensor]): A dictionary mapping method
            names to their negative energy scores.
        temperature (float): Strictly positive temperature scaling parameter.

    Returns:
        Dict[str, torch.Tensor]: A dictionary mapping each method name to its
        scalar loss value. If a method has no valid samples, its loss is NaN.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if not temperature > 0:  # written this way so NaN is rejected too
        msg = f"temperature must be strictly positive, got {temperature}."
        raise ValueError(msg)

    losses = {}

    for method_name, neg_energy in neg_energies_dict.items():
        # Keep only pairs whose negative energy is finite. Methods isolate
        # their samples, and a method (e.g. sentence masking) may emit inf
        # when it cannot produce a negative. Such inf entries would otherwise
        # corrupt that method's batch average.
        valid_indices = torch.isfinite(neg_energy)

        # No valid pairs in this batch: report NaN so the gap stays visible.
        if not valid_indices.any():
            losses[method_name] = torch.tensor(float("nan"), device=pos_energy.device)
            continue

        valid_pos_energy = pos_energy[valid_indices]
        valid_neg_energy = neg_energy[valid_indices]

        # [B_valid, 2] logits: column 0 = positive, column 1 = this negative.
        # Lower energy -> higher logit.
        logits = torch.stack([-valid_pos_energy, -valid_neg_energy], dim=1) / temperature
        targets = torch.zeros(
            valid_pos_energy.size(0), dtype=torch.long, device=pos_energy.device
        )

        # cross_entropy's internal log_softmax is already numerically stable.
        losses[method_name] = f.cross_entropy(logits, targets)

    return losses


def compute_accuracies(
    pos_energy: torch.Tensor, neg_energies_dict: dict[str, torch.Tensor]
) -> dict[str, float]:
    """Return overall and per-method ranking accuracies, in percent.

    A sample is "correct" for a method when its positive energy is strictly
    lower than that method's negative energy.

    Args:
        pos_energy (torch.Tensor): A tensor of positive sample energies.
            Shape: (B,).
        neg_energies_dict (Dict[str, torch.Tensor]): A dictionary mapping method
            names to their negative energy scores.

    Returns:
        Dict[str, float]: Contains two kinds of entries.

        - ``"overall"``: the percentage of samples whose positive beats the
          negatives of every method at once (NaN for an empty batch).
        - One entry per method: the percentage computed only over samples
          whose negative energy is finite (NaN when there are none).
    """
    n_samples = pos_energy.size(0)

    # A positive wins "overall" only if it beats every method's negative.
    # No masking is needed here. A +inf negative compares as a win (for a
    # finite positive), so it never drags down the overall score. NaN or
    # -inf negatives, however, compare as losses.
    wins_all = torch.ones_like(pos_energy, dtype=torch.bool)
    for neg in neg_energies_dict.values():
        wins_all &= pos_energy < neg

    results = {
        "overall": (
            wins_all.sum().item() / n_samples * 100.0 if n_samples > 0 else float("nan")
        )
    }

    # Per-method accuracy, restricted to pairs with a finite negative energy.
    for name, neg in neg_energies_dict.items():
        finite = torch.isfinite(neg)
        n_finite = finite.sum().item()
        if n_finite == 0:
            results[name] = float("nan")
        else:
            n_wins = (pos_energy[finite] < neg[finite]).sum().item()
            results[name] = n_wins / n_finite * 100.0

    return results
