import torch
import numpy as np

# source: https://github.com/hysts/pytorch_mixup/blob/master/utils.py


def onehot(label, n_classes):
    return torch.zeros(label.size(0), n_classes,
                       device=label.device).scatter_(1, label.view(-1, 1), 1)


def mixup(data, targets, alpha, n_classes):
    indices = torch.randperm(data.size(0))
    data2 = data[indices]
    targets2 = targets[indices]

    targets = onehot(targets, n_classes)
    targets2 = onehot(targets2, n_classes)

    lam = torch.FloatTensor([np.random.beta(alpha, alpha)]).to(data.device)
    data = data * lam + data2 * (1 - lam)
    targets = targets * lam + targets2 * (1 - lam)
    return data, targets
