import torch
import torchbearer
from torch.distributions import Beta
import torch.nn.functional as F

class RCutMix(torchbearer.Callback):
    """ CutMix-style callback which replaces a random rectangular patch of each image with the corresponding patch
    taken either from another image of the same batch (default) or, when ``reformulate`` is set, from a batch drawn
    from a second data generator. Unless ``mixup_loss`` is set, labels are converted to one hot and combined according
    to the lambda parameter, which is sampled from a beta distribution.

    Example: ::

        >>> from torchbearer import Trial

        # Example Trial which does RCutMix regularisation
        >>> cutmix = RCutMix(alpha=1, classes=10)
        >>> trial = Trial(None, callbacks=[cutmix], metrics=['acc'])

    Args:
        other_generator (iterable, optional): Source of ``(data, target)`` batches used to fill the patches when
            ``reformulate`` is True. Only the data part is used. Not needed otherwise.
        alpha (float): The alpha value for the beta distribution.
        classes (int): The number of classes for conversion to one hot.
        mixup_loss (bool): If True, leave the targets as class indices and store the pair
            ``(target, target[permutation])`` in the state together with the permutation and lambda, instead of
            building a mixed one-hot target.
        reformulate (bool): If True, sample lambda from ``Beta(alpha + 1, alpha)``, fill patches from
            ``other_generator`` and keep the original targets (lambda is forced to 1).

    State Requirements:
        - :attr:`torchbearer.state.X`: State should have the current data stored
        - :attr:`torchbearer.state.TARGET`: State should have the current targets stored
        - :attr:`torchbearer.state.DEVICE`: State should have the current device stored
    """
    def __init__(self, other_generator=None, alpha=1., classes=-1, mixup_loss=False, reformulate=False):
        super(RCutMix, self).__init__()
        self.classes = classes
        self.reformulate = reformulate
        alpha = float(alpha)
        # The reformulated variant uses an asymmetric Beta(alpha + 1, alpha)
        first_concentration = alpha + 1 if reformulate else alpha
        self.dist = Beta(torch.tensor([first_concentration]), torch.tensor([alpha]))
        self.mixup_loss = mixup_loss
        self.other_generator = other_generator
        self.iterator = None

    def on_start_epoch(self, state):
        # Restart the secondary stream each epoch, but only if one was supplied: with the default
        # other_generator=None, calling iter() would raise even though the stream is never read.
        if self.other_generator is not None:
            self.iterator = iter(self.other_generator)
        else:
            self.iterator = None

    def _next_other_batch(self):
        """Return the next batch from ``other_generator``, restarting it once it is exhausted."""
        if self.other_generator is None:
            raise ValueError('RCutMix with reformulate=True requires other_generator to be set')
        if self.iterator is None:
            # on_start_epoch has not run yet (e.g. the callback is driven manually)
            self.iterator = iter(self.other_generator)
        try:
            return next(self.iterator)
        except StopIteration:
            # The secondary generator is shorter than the main one: wrap around
            self.iterator = iter(self.other_generator)
            return next(self.iterator)

    def _to_one_hot(self, target):
        """Convert a 1D tensor of class indices to one hot; tensors of any other rank are returned unchanged."""
        if target.dim() == 1:
            target = target.unsqueeze(1)
            one_hot = torch.zeros_like(target).repeat(1, self.classes)
            one_hot.scatter_(1, target, 1)
            return one_hot
        return target

    def on_sample(self, state):
        super(RCutMix, self).on_sample(state)

        x = state[torchbearer.X]

        lam = self.dist.sample().to(state[torchbearer.DEVICE])
        # Patch side is sqrt(1 - lam) of the image side, so its area is roughly (1 - lam) of the image
        length = (1 - lam).sqrt()
        patch_width = int((length * x.size(-1)).round().item())
        patch_height = int((length * x.size(-2)).round().item())
        cutter = BatchCutout(1, patch_width, patch_height)
        mask = cutter(x)
        erase_locations = mask == 0

        permutation = torch.randperm(x.size(0))
        if self.mixup_loss:
            state[torchbearer.MIXUP_PERMUTATION] = permutation
            state[torchbearer.MIXUP_LAMBDA] = lam

        if self.reformulate:
            other, _ = self._next_other_batch()
            other = other[:(x.shape[0])].to(x.device)
            if other.shape != x.shape:
                # NOTE(review): hard-coded adaptation (pad 2 pixels per side, tile channels x3) - presumably
                # for 1x28x28 data mixed into 3x32x32 data; confirm against the datasets used.
                other = F.pad(other, (2, 2, 2, 2)).repeat(1, 3, 1, 1)
            x[erase_locations] = other[erase_locations]
            # Targets are left untouched in this mode, so lambda is fixed to 1
            lam = 1.
            state[torchbearer.MIXUP_LAMBDA] = 1.
        else:
            x[erase_locations] = x[permutation][erase_locations]

        if self.mixup_loss:
            target = state[torchbearer.TARGET]
            state[torchbearer.TARGET] = (target, target[state[torchbearer.MIXUP_PERMUTATION]])
        else:
            target = self._to_one_hot(state[torchbearer.TARGET]).float()
            state[torchbearer.TARGET] = lam * target + (1 - lam) * target[permutation]

    def on_sample_validation(self, state):
        super(RCutMix, self).on_sample_validation(state)
        # Validation targets only need the same one-hot float format as the mixed training targets
        if not self.mixup_loss:
            state[torchbearer.TARGET] = self._to_one_hot(state[torchbearer.TARGET]).float()

class BatchCutout(object):
    """Build a mask with randomly placed rectangular holes for a batch of images.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        width (int): Nominal horizontal extent (in pixels) of each patch.
        height (int): Nominal vertical extent (in pixels) of each patch.
    """
    def __init__(self, n_holes, width, height):
        self.n_holes, self.width, self.height = n_holes, width, height

    def __call__(self, img):
        """Sample hole positions independently for every image and return the resulting mask.

        Args:
            img (Tensor): Tensor image of size (B, C, H, W). Only its sizes and device are used.
        Returns:
            Tensor: Mask of size (B, C, H, W) on the device of ``img``; 0 inside the holes and 1 elsewhere.
            The same spatial mask is repeated over all channels.
        """
        batch_size, channels = img.size(0), img.size(1)
        rows, cols = img.size(-2), img.size(-1)
        device = img.device

        mask = torch.ones((batch_size, rows, cols), device=device)

        # Coordinate grids that broadcast against per-image bounds of shape (B, 1, 1)
        row_idx = torch.arange(rows, device=device).view(1, rows, 1)
        col_idx = torch.arange(cols, device=device).view(1, 1, cols)

        def per_image(bound):
            return bound.to(device).view(batch_size, 1, 1)

        for _ in range(self.n_holes):
            # Patch centres, one per image (rows are drawn before columns)
            centre_y = torch.randint(high=rows, size=(batch_size,)).long()
            centre_x = torch.randint(high=cols, size=(batch_size,)).long()

            # Half-open bounds [top, bottom) x [left, right), clipped to the image
            top = per_image((centre_y - self.height // 2).clamp(0, rows).int())
            bottom = per_image((centre_y + self.height // 2).clamp(0, rows).int())
            left = per_image((centre_x - self.width // 2).clamp(0, cols).int())
            right = per_image((centre_x + self.width // 2).clamp(0, cols).int())

            inside = (row_idx >= top) & (row_idx < bottom) & (col_idx >= left) & (col_idx < right)
            mask[inside] = 0

        return mask.unsqueeze(1).repeat(1, channels, 1, 1)

class Cutout(torchbearer.Callback):
    """ Cutout callback which overwrites randomly placed patches of the image data with a constant value.
    Implementation a modified version of the code found `here <https://github.com/uoguelph-mlrg/Cutout>`_.

    Example: ::

        >>> from torchbearer import Trial
        >>> from torchbearer.callbacks import Cutout

        # Example Trial which does Cutout regularisation
        >>> cutout = Cutout(1, 10)
        >>> trial = Trial(None, callbacks=[cutout], metrics=['acc'])

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) used for both the width and the height of each patch.
        constant (float): Constant value written into each patch

    State Requirements:
        - :attr:`torchbearer.state.X`: State should have the current data stored
    """
    def __init__(self, n_holes, length, constant=0.):
        super(Cutout, self).__init__()
        # Same size in both directions
        self.cutter = BatchCutout(n_holes, length, length)
        self.constant = constant

    def on_sample(self, state):
        super(Cutout, self).on_sample(state)
        images = state[torchbearer.X]
        # The cutter marks hole positions with zeros
        holes = self.cutter(images) == 0
        fill = torch.ones_like(images) * self.constant
        # In-place write, so the tensor stored in the state is the one modified
        images[holes] = fill[holes]
