import torch
import torch.nn.functional as F


def sequence_loss(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Compute cross-entropy loss for per-residue predictions with padding mask.

    Args:
        logits: [B, num_classes, L] raw scores per residue (class axis second,
            matching the multi-dimensional convention of ``F.cross_entropy``).
        targets: [B, L] long tensor of labels. Values at padded positions are
            never read, so any padding value (e.g. -1) is allowed there.
        mask: [B, L] boolean, True for valid positions (non-padding). Integer
            or float 0/1 masks are converted to bool.

    Returns:
        Scalar loss averaged over valid positions. If no position is valid,
        returns a zero that stays connected to ``logits`` in the autograd
        graph (instead of the NaN an empty mean would produce).

    Raises:
        ValueError: if ``logits`` is not 3-D or ``targets``/``mask`` are not
            shaped ``[B, L]`` to match ``logits``.
    """
    if logits.dim() != 3:
        raise ValueError(
            f"logits must be [B, num_classes, L], got shape {tuple(logits.shape)}"
        )
    batch, _, length = logits.shape
    if targets.shape != (batch, length) or mask.shape != (batch, length):
        raise ValueError(
            f"targets and mask must be [B, L] = {(batch, length)}, got "
            f"{tuple(targets.shape)} and {tuple(mask.shape)}"
        )

    # Indexing with a non-bool mask would be interpreted as integer indices.
    mask = mask.to(torch.bool)

    # Move the class axis last ([B, C, L] -> [B, L, C]) so boolean selection with
    # the [B, L] mask yields one [C]-row per valid residue. A plain reshape of the
    # [B, C, L] tensor would scramble class scores across residues.
    logits_valid = logits.transpose(1, 2)[mask]  # [N_valid, C]
    targets_valid = targets[mask]  # [N_valid]

    if logits_valid.shape[0] == 0:
        # Sum over an empty tensor is 0.0 and keeps a grad_fn, so callers can
        # still call .backward() on an all-padding batch.
        return logits_valid.sum()

    return F.cross_entropy(logits_valid, targets_valid)