import torch
import torch.nn as nn
import torch.nn.functional as F


class MixUp(nn.Module):
    """Blend each sample of a mini-batch with its predecessor (MixUp).

    Pairs are formed by rolling the batch by one position along dim 0, so
    sample ``i`` is mixed with sample ``i - 1`` (sample 0 with the last one).
    This is the same roll trick torchvision uses instead of a second shuffle,
    so the mini-batch should already be shuffled. A separate mixing
    coefficient is drawn for every sample.

    Args:
        alpha: Concentration of the symmetric ``Beta(alpha, alpha)``
            distribution the mixing coefficients are drawn from.
        num_classes: Number of classes used to one-hot encode the labels.
    """

    def __init__(self, alpha=1.0, num_classes=10):
        super().__init__()
        self.alpha = float(alpha)
        self.num_classes = num_classes
        # Build from the float-coerced alpha so an int argument does not
        # produce integer-typed concentration tensors.
        self._dist = torch.distributions.beta.Beta(
            torch.tensor([self.alpha]), torch.tensor([self.alpha])
        )

    def _one_hot(self, labels):
        """One-hot encode integer ``labels`` into ``num_classes`` columns."""
        return F.one_hot(labels, num_classes=self.num_classes)

    @staticmethod
    def _per_sample(coeff, like):
        """Move a ``(batch,)`` coefficient to ``like``'s device and reshape it
        so it broadcasts over every non-batch dimension of ``like``."""
        return coeff.to(like.device).view(-1, *([1] * (like.ndim - 1)))

    def forward(self, embeddings, labels):
        """Mix ``embeddings`` and ``labels`` with the previous batch element.

        Args:
            embeddings: Tensor whose dim 0 is the batch dimension (any rank).
            labels: Integer class indices; dim 0 is the batch dimension.

        Returns:
            ``(mixed_embeddings, soft_labels)``. ``mixed_embeddings`` has the
            shape of ``embeddings``. ``soft_labels`` is the floating-point
            convex combination of the two one-hot encodings, with shape
            ``labels.shape + (num_classes,)``.
        """
        # The distribution has batch shape (1,), so sampling gives (batch, 1);
        # flatten to one coefficient per sample.
        lam = self._dist.sample((labels.shape[0],)).reshape(-1)

        one_hot = self._one_hot(labels)
        lam_y = self._per_sample(lam, one_hot)
        labels = lam_y * one_hot + (1 - lam_y) * self._one_hot(labels.roll(1, 0))

        lam_x = self._per_sample(lam, embeddings)
        embeddings = lam_x * embeddings + (1 - lam_x) * embeddings.roll(1, 0)
        return embeddings, labels


class CutMix(nn.Module):
    """Element-wise CutMix-style mixing of each sample with its predecessor.

    Every element of ``embeddings`` is independently kept with probability
    ``alpha``. Otherwise it is replaced by the corresponding element of the
    preceding sample. The batch is rolled by one along dim 0, so sample 0 is
    paired with the last one, and the mini-batch should therefore be shuffled.
    Labels are mixed in proportion to the fraction of elements each sample
    actually kept.

    Args:
        alpha: Probability in ``[0, 1]`` that an element is kept from its own
            sample. NOTE(review): with the default ``1.0`` every element is
            kept, so this module is an identity on ``embeddings`` and only
            one-hot encodes ``labels``; confirm callers pass explicit values.
        num_classes: Number of classes used to one-hot encode the labels.
    """

    def __init__(self, alpha=1.0, num_classes=10):
        super().__init__()
        self.alpha = float(alpha)
        self.num_classes = num_classes
        # alpha is a keep-probability; Bernoulli's argument validation (when
        # enabled) rejects values outside [0, 1].
        self._dist = torch.distributions.Bernoulli(self.alpha)

    def _one_hot(self, labels):
        """One-hot encode integer ``labels`` into ``num_classes`` columns."""
        return F.one_hot(labels, num_classes=self.num_classes)

    def forward(self, embeddings, labels):
        """Mix ``embeddings`` and ``labels`` with the previous batch element.

        Args:
            embeddings: Tensor whose dim 0 is the batch dimension (any rank).
            labels: Integer class indices of shape ``(batch,)``.

        Returns:
            ``(mixed_embeddings, soft_labels)``. Every element of
            ``mixed_embeddings`` is taken verbatim from either its own sample
            or the previous one. ``soft_labels`` is the floating-point convex
            combination of the two one-hot encodings, weighted by the kept
            fraction.
        """
        # Draw the mask directly on the embeddings' device, avoiding a
        # CPU-sized mask plus a host-to-device copy (or a device mismatch).
        probs = self._dist.probs.to(embeddings.device)
        keep = torch.bernoulli(probs.expand(embeddings.shape)).bool()

        # Fraction of elements each sample kept from itself, over all
        # non-batch dimensions (works for any rank >= 1).
        label_lam = keep.reshape(keep.shape[0], -1).float().mean(1)[:, None]
        labels = label_lam * self._one_hot(labels) + (
            1 - label_lam
        ) * self._one_hot(labels.roll(1, 0))

        # torch.where selects instead of multiplying by 0/1, so an inf/NaN in
        # the discarded element cannot leak into the result.
        embeddings = torch.where(keep, embeddings, embeddings.roll(1, 0))
        return embeddings, labels
