import torch
import torch.nn.functional as F
from typing import Optional, Tuple

def prepare_for_classification(
    student_logits: torch.Tensor,   # [B, V]
    labels: torch.Tensor,           # [B]
    teacher_logits: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, ...]:
    """
    Bundle the tensors needed by a classification / distillation loss.

    The tensors are passed through untouched; the only work done here is a
    shape check between student and teacher logits when a teacher is given.

    Args:
      student_logits: Student model outputs (annotated as [B, V]).
      labels: Ground-truth class indices (annotated as [B]).
      teacher_logits: Optional teacher outputs; must have exactly the same
        shape as ``student_logits`` when supplied.

    Returns:
      ``(student_logits, teacher_logits, labels)`` when a teacher is given,
      otherwise ``(student_logits, labels)``.

    Raises:
      ValueError: If teacher and student logits differ in shape.
    """
    # No teacher: nothing to validate, hand back the two-element form.
    if teacher_logits is None:
        return student_logits, labels

    shapes_agree = teacher_logits.shape == student_logits.shape
    if not shapes_agree:
        raise ValueError(
            f"Shape mismatch: student {student_logits.shape}, teacher {teacher_logits.shape}"
        )

    return student_logits, teacher_logits, labels

def create_adjusted_ranking(
    logits: torch.Tensor,  # [N, V]
    labels: torch.Tensor   # [N]
) -> torch.Tensor:
    """
    Build a per-sample class ordering with the true label forced to the end.

    For each row, class indices are sorted by ascending logit, the true-label
    index is removed from wherever it landed, and it is then appended as the
    last entry.

    Args:
      logits: 2-D tensor of scores, one row per sample ([N, V]).
      labels: 1-D tensor of true class indices ([N]). Every label must be a
        valid column index in ``[0, V)``; otherwise the reshape to
        ``[N, V - 1]`` below fails with a RuntimeError.

    Returns:
      int64 tensor of shape [N, V] holding the adjusted index order, usable
      directly as the ``index`` argument of ``torch.gather``.
    """
    num_rows, num_classes = logits.shape

    # argsort yields int64 indices; bring labels to the same dtype so the
    # comparison and the concatenation below stay int64.
    label_col = labels.to(dtype=torch.long).unsqueeze(-1)  # [N, 1]

    ascending_idx = torch.argsort(logits, dim=-1, descending=False)

    # Exactly one entry per row equals the label, so masking it out leaves
    # V - 1 indices per row (boolean indexing flattens, hence the view).
    keep = ascending_idx != label_col
    others = ascending_idx[keep].view(num_rows, num_classes - 1)

    return torch.cat([others, label_col], dim=-1)

def plackett_luce_loss(
    student_logits: torch.Tensor,  # [B, V]
    teacher_logits: torch.Tensor,  # [B, V]
    labels: torch.Tensor,          # [B]
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Computes the Plackett-Luce Distillation (PLD) loss.

    Steps per sample:
      1) Permute student and teacher logits with the ordering returned by
         ``create_adjusted_ranking`` (ascending teacher logit, true label
         last).
      2) For each position k, compute ``logsumexp(s[0..k]) - s[k]`` on the
         permuted student logits.
      3) Weight each position by the teacher softmax probability (teacher
         logits divided by ``temperature``).
      4) Sum over classes and average over the batch.

    Args:
      student_logits: Student outputs (annotated as [B, V]).
      teacher_logits: Teacher outputs, same shape as ``student_logits``.
      labels: True class indices (annotated as [B]).
      temperature: Positive softmax temperature applied to the teacher
        logits only.

    Returns:
      Scalar tensor: mean PLD loss over batch.

    Raises:
      ValueError: If ``temperature`` is not strictly positive, or if the
        student/teacher shapes differ (raised by
        ``prepare_for_classification``).
    """
    # A zero temperature divides by zero and a negative one inverts the
    # teacher distribution; both silently corrupt the loss, so fail loudly.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Shape-check the inputs (tensors are returned unchanged).
    s_logits, t_logits, targets = prepare_for_classification(
        student_logits, labels, teacher_logits
    )

    # Permutation indices: ranking is driven by the teacher's logits.
    order = create_adjusted_ranking(t_logits, targets)

    # Apply the same permutation to both models' logits.
    s_perm = torch.gather(s_logits, dim=1, index=order)
    t_perm = torch.gather(t_logits, dim=1, index=order)

    # Running log-normaliser minus the logit at each position. The earlier
    # "+ eps" on both operands cancelled exactly and only added rounding
    # noise, so it is omitted.
    running_lse = torch.logcumsumexp(s_perm, dim=1)
    terms = running_lse - s_perm

    # Teacher probabilities act as per-position weights.
    weights = F.softmax(t_perm / temperature, dim=1)

    # Weighted sum over classes, mean over batch.
    return (terms * weights).sum(dim=1).mean()
