import torch

def mixup_criterion(criterion: torch.nn.Module, pred: torch.Tensor, y_a: torch.Tensor, y_b: torch.Tensor, lam: float) -> torch.Tensor:
    """Blend a loss evaluated against two target sets, as in mixup training.

    The criterion is applied to the same predictions twice, once per target
    set. The two losses are combined as
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.

    Code adapted from https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py

    Args:
        criterion: Loss callable invoked as ``criterion(pred, target)``. It is
            called with ``y_a`` first and ``y_b`` second.
        pred: Model output for the mixed inputs.
        y_a: Targets paired with weight ``lam``.
        y_b: Targets paired with weight ``1 - lam``.
        lam: Mixing coefficient. Presumably in ``[0, 1]``, as in mixup; this
            function does not check it.

    Returns:
        The weighted sum of the two losses. Its shape follows whatever
        ``criterion`` returns (a scalar for the usual reduced losses).
    """
    # Keep the evaluation order (y_a, then y_b) so stateful criteria see the
    # same call sequence.
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    weight_second = 1 - lam
    return lam * loss_first + weight_second * loss_second