import torch
import torch.nn.functional as F


def generate_true_labels(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """
    Build float one-hot targets from integer class labels.

    Args:
        labels (Tensor): Integer class indices of shape (B,)
        num_classes (int): Size of the class dimension C

    Returns:
        Tensor: float32 tensor of shape (B, C) holding a single 1.0 per row
            at the position of the corresponding label
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    # F.one_hot yields an integer tensor; targets are consumed as floats.
    return encoded.to(dtype=torch.float32)

def generate_hard_labels(
    preds: torch.Tensor,
    labels: torch.Tensor,
    num_classes: int,
    alpha: float = 1.0,
) -> torch.Tensor:
    """
    Generate hard labels by sampling one wrong class per sample.

    For each row, a class other than the true label is drawn with probability
    proportional to ``exp(alpha * log_softmax(preds))`` restricted to the wrong
    classes. After renormalization this equals ``softmax(alpha * preds)`` over
    the wrong classes, which is what is computed here (numerically stable for
    any ``alpha``). ``alpha`` therefore acts as an inverse temperature:
    ``alpha > 1`` sharpens toward the most confident wrong class, ``alpha == 0``
    samples wrong classes uniformly, and ``alpha < 0`` favours the least
    confident wrong classes.

    Rows whose scores yield no usable distribution (e.g. NaN predictions, or
    every wrong-class score being -inf) fall back to a uniform draw over the
    wrong classes, so the true label is never returned.

    Args:
        preds (Tensor): Model predictions (logits) of shape (B, C)
        labels (Tensor): True integer labels of shape (B,)
        num_classes (int): Total number of classes
        alpha (float): Scaling factor for log probabilities

    Returns:
        Tensor: One-hot encoded hard labels of shape (B, num_classes), float32

    Raises:
        ValueError: If there are fewer than two classes (no wrong class exists).
    """
    if num_classes < 2 or preds.shape[1] < 2:
        raise ValueError(
            "generate_hard_labels requires at least 2 classes to sample a wrong label"
        )

    index = labels.unsqueeze(1)
    # Upcast low-precision inputs (half/bfloat16) so alpha * preds does not overflow.
    work_dtype = torch.promote_types(preds.dtype, torch.float32)

    # 1.0 on every wrong class, 0.0 on the true class; also the fallback weights.
    wrong_weights = torch.ones_like(preds, dtype=work_dtype).scatter_(1, index, 0.0)

    # Scale first, then mask: masking before scaling would turn the sentinel into
    # 0 (alpha == 0) or +huge (alpha < 0) and leak probability onto the true class.
    # detach(): sampling is non-differentiable, so no autograd graph is needed.
    scaled = alpha * preds.detach().to(work_dtype)
    masked = scaled.masked_fill(wrong_weights == 0, float("-inf"))

    probs = F.softmax(masked, dim=1).nan_to_num(nan=0.0)

    # torch.multinomial needs a positive row sum; replace degenerate rows with
    # uniform weights over the wrong classes only.
    usable = probs.sum(dim=1, keepdim=True) > 0
    probs = torch.where(usable, probs, wrong_weights)

    wrong_labels = torch.multinomial(probs, num_samples=1).squeeze(1)

    return F.one_hot(wrong_labels, num_classes=num_classes).float()
