import torch


def mixup_data(x, y, alpha=0.5):
    """Blend a batch with a randomly permuted copy of itself (mixup).

    A mixing coefficient ``lam`` is drawn from ``Beta(alpha, alpha)`` when
    ``alpha > 0``. Otherwise ``lam = 1.0``, which means no mixing. Each
    sample is then blended with the sample at a random permutation of the
    batch: ``mixed_x = lam * x + (1 - lam) * x[index]``.

    Args:
        x: Input tensor whose first dimension is the batch. Any rank >= 1
            is accepted, including 1-D.
        y: Targets indexed along the same first (batch) dimension as ``x``.
        alpha: Concentration for the Beta distribution. Values <= 0
            disable mixing.

    Returns:
        A tuple ``(mixed_x, y_a, y_b, lam)``:

        - ``y_a`` is ``y`` unchanged.
        - ``y_b`` is ``y`` reordered by the same permutation applied to
          ``x``.
        - ``lam`` is a Python float.

        ``y_a``, ``y_b`` and ``lam`` are the arguments ``mixup_criterion``
        expects.
    """
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 1.0

    batch_size = x.size(0)
    # One permutation is shared by inputs and targets so mixed pairs stay
    # aligned with their labels.
    index = torch.randperm(batch_size)

    # Index only the batch axis: x[index] matches x[index, :] for rank >= 2
    # and also works for 1-D inputs, where x[index, :] raises IndexError.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Compute the mixup loss as a lam-weighted blend of two target losses.

    ``criterion`` is called twice on the same predictions: first against
    ``y_a``, then against ``y_b``. The results are combined as
    ``lam * loss_a + (1 - lam) * loss_b``.

    Args:
        criterion: Callable ``criterion(pred, target)``. It presumably
            returns a loss tensor or number that supports scalar
            multiplication and addition (assumption; not checked here).
        pred: Model predictions, passed through to ``criterion``.
        y_a: First set of targets. It receives weight ``lam``.
        y_b: Second set of targets. It receives weight ``1 - lam``.
        lam: Mixing coefficient.

    Returns:
        The blended loss.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    weight_b = 1 - lam
    return lam * loss_a + weight_b * loss_b


if __name__ == "__main__":
    # Smoke test for mixup_data: mix a tiny 4-sample, 2-feature batch and
    # print the resulting (mixed_x, y_a, y_b, lam) tuple.
    demo_inputs = torch.tensor(
        [
            [1, 2],
            [3, 4],
            [5, 6],
            [2, 3],
        ],
        dtype=torch.float32,
    )
    demo_labels = torch.tensor([0, 1, 1, 0], dtype=torch.long)
    demo_alpha = 0.5
    demo_result = mixup_data(demo_inputs, demo_labels, demo_alpha)
    print(demo_result)
