import torch
import math
import numpy as np
import random
from torch.distributions.beta import Beta

def mixup(img1, img2, lam):
    """Blend two inputs linearly: ``img1 * lam + img2 * (1 - lam)``.

    Works for any operands supporting ``*`` with a scalar and ``+`` between
    results (e.g. tensors, arrays, plain numbers).
    """
    other_weight = 1 - lam
    weighted_first = img1 * lam
    return weighted_first + img2 * other_weight

def cutMix(img1, img2, lam):
    """Paste a random rectangular patch of ``img2`` into a copy of ``img1`` (CutMix).

    The patch sides are approximately ``round(sqrt(1 - lam) * W)`` by
    ``round(sqrt(1 - lam) * H)``; its position is sampled by ``BatchCutout``.

    Args:
        img1: Tensor of shape (C, H, W) providing the background.
        img2: Tensor of the same shape as ``img1`` providing the patch.
        lam: Mixing ratio, expected in [0, 1] (``sqrt(1 - lam)`` must be real).

    Returns:
        A new (C, H, W) tensor. Neither ``img1`` nor ``img2`` is modified.
    """
    # unsqueeze/reshape return a view of a contiguous input; clone so the
    # in-place masked assignment below does not overwrite the caller's img1.
    batch = img1.unsqueeze(0).clone()
    batch2 = img2.unsqueeze(0)

    length = torch.tensor(math.sqrt(1 - lam))
    cutter = BatchCutout(1, (length * batch.size(-1)).round().item(), (length * batch.size(-2)).round().item())
    erase_locations = cutter(batch) == 0

    batch[erase_locations] = batch2[erase_locations]
    return batch[0]

def cutMix_original(img1, img2, lam):
    """CutMix variant whose patch centre may lie anywhere in the image.

    A patch of approximately ``round(sqrt(1 - lam) * W)`` by
    ``round(sqrt(1 - lam) * H)`` pixels, positioned by ``BatchCutoutOriginal``,
    is copied from ``img2`` into a copy of ``img1``.

    Args:
        img1: Tensor of shape (C, H, W) providing the background.
        img2: Tensor of the same shape as ``img1`` providing the patch.
        lam: Mixing ratio, expected in [0, 1] (``sqrt(1 - lam)`` must be real).

    Returns:
        A new (C, H, W) tensor. Neither ``img1`` nor ``img2`` is modified.
    """
    # Clone so the in-place masked assignment does not write through a view
    # into the caller's img1.
    batch = img1.unsqueeze(0).clone()
    batch2 = img2.unsqueeze(0)

    length = torch.tensor(math.sqrt(1 - lam))
    cutter = BatchCutoutOriginal(1, (length * batch.size(-1)).round().item(), (length * batch.size(-2)).round().item())
    erase_locations = cutter(batch) == 0

    batch[erase_locations] = batch2[erase_locations]
    return batch[0]

def cutout(img1, img2, lam):
    """Zero out a random rectangular patch in a copy of ``img1`` (Cutout).

    Args:
        img1: Tensor of shape (C, H, W) to cut from.
        img2: Ignored (presumably kept so all augmentations in this module
            share the ``(img1, img2, lam)`` signature).
        lam: Expected in [0, 1]. The patch sides are approximately
            ``round(sqrt(1 - lam) * W)`` by ``round(sqrt(1 - lam) * H)``.
            The patch position is sampled by ``BatchCutout``.

    Returns:
        A new (C, H, W) tensor. ``img1`` is not modified.
    """
    # Clone so zeroing the patch does not write through a view into img1.
    batch = img1.unsqueeze(0).clone()

    length = torch.tensor(math.sqrt(1 - lam))
    cutter = BatchCutout(1, (length * batch.size(-1)).round().item(), (length * batch.size(-2)).round().item())
    erase_locations = cutter(batch) == 0

    batch[erase_locations] = 0
    return batch[0]

def cutout_original(img1, img2, lam):
    """Cutout variant whose patch centre may lie anywhere in the image.

    Args:
        img1: Tensor of shape (C, H, W) to cut from.
        img2: Ignored (presumably kept for a uniform ``(img1, img2, lam)``
            signature across augmentations).
        lam: Expected in [0, 1]. The patch sides are approximately
            ``round(sqrt(1 - lam) * W)`` by ``round(sqrt(1 - lam) * H)``.
            The patch position is sampled by ``BatchCutoutOriginal``.

    Returns:
        A new (C, H, W) tensor. ``img1`` is not modified.
    """
    # Clone so zeroing the patch does not write through a view into img1.
    batch = img1.unsqueeze(0).clone()

    length = torch.tensor(math.sqrt(1 - lam))
    cutter = BatchCutoutOriginal(1, (length * batch.size(-1)).round().item(), (length * batch.size(-2)).round().item())
    erase_locations = cutter(batch) == 0

    batch[erase_locations] = 0
    return batch[0]


class BatchCutout(object):
    """Build masks that cut one or more rectangular patches out of each image.

    Patch centres are sampled away from the image border so that each patch
    stays inside the image; coordinates are additionally clamped. A patch
    spans ``2 * (size // 2)`` pixels per axis, so odd sizes lose one pixel.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        width (int | float): The width (in pixels) of each patch.
        height (int | float): The height (in pixels) of each patch.
    """
    def __init__(self, n_holes, width, height):
        self.n_holes = n_holes
        self.width = width
        self.height = height

    @staticmethod
    def _sample_centres(size, extent, b):
        """Sample ``b`` patch centres along one axis of length ``extent``."""
        half = round(size / 2)
        if half == round(extent / 2):
            # Patch (nearly) spans the whole axis: randint would have an
            # empty range, so centre it. One value per batch element.
            return torch.full((b,), int(size / 2), dtype=torch.long)
        return torch.randint(low=half, high=extent - half, size=(b,)).long()

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image batch of size (B, C, H, W).
        Returns:
            Tensor: Float mask of size (B, C, H, W) on ``img.device``, holding
            1 for kept pixels and 0 inside the cut-out patches.
        """
        b = img.size(0)
        c = img.size(1)
        h = img.size(-2)
        w = img.size(-1)

        mask = torch.ones((b, h, w), device=img.device)

        for _ in range(self.n_holes):
            y = self._sample_centres(self.height, h, b)
            x = self._sample_centres(self.width, w, b)

            # int64 coordinates: a uint8 cast would wrap around for images
            # larger than 255 pixels and cut the wrong region.
            y1 = (y - self.height // 2).clamp(0, h).long()
            y2 = (y + self.height // 2).clamp(0, h).long()
            x1 = (x - self.width // 2).clamp(0, w).long()
            x2 = (x + self.width // 2).clamp(0, w).long()

            for batch in range(b):
                mask[batch, y1[batch]: y2[batch], x1[batch]: x2[batch]] = 0

        mask = mask.unsqueeze(1).repeat(1, c, 1, 1)

        return mask

def fftfreq2d(w, h):
    """Return radial DFT bin magnitudes ``sqrt(fx**2 + fy**2)`` as a float32 tensor.

    Rows use the full ``np.fft.fftfreq(h)`` bins. Columns keep the first
    ``w // 2 + 1`` bins of ``np.fft.fftfreq(w)`` for even ``w`` and
    ``w // 2 + 2`` for odd ``w``. The result has shape ``(h, n_cols)``.
    """
    n_cols = w // 2 + (2 if w % 2 == 1 else 1)
    row_freqs = np.fft.fftfreq(h).reshape(-1, 1)
    col_freqs = np.fft.fftfreq(w)[:n_cols]
    radius = np.sqrt(col_freqs * col_freqs + row_freqs * row_freqs)
    return torch.from_numpy(radius).float()

def fftfreqnd(h, w=None, z=None):
    """ Get bin values for discrete fourier transform of size (h, w, z)

    The last given dimension is treated as the half-spectrum axis of a real
    FFT (the layout ``np.fft.irfftn`` consumes). It keeps ``n // 2 + 1`` bins
    for even ``n`` and ``n // 2 + 2`` for odd ``n``; ``irfftn`` crops the
    extra bin. All earlier dimensions keep their full ``np.fft.fftfreq`` bins.

    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size (only used together with ``w``)
    :return: Array of radial frequency magnitudes with shape ``(h,)``,
        ``(h, w_half)`` or ``(h, w, z_half)``
    """
    def half_spectrum(n):
        # Keeps one extra bin for odd sizes, matching the 2-D behaviour.
        return np.fft.fftfreq(n)[: n // 2 + (2 if n % 2 == 1 else 1)]

    fz = fx = 0
    fy = np.fft.fftfreq(h)

    if w is not None:
        fy = np.expand_dims(fy, -1)
        if z is None:
            # 2-D: w is the trailing half-spectrum axis.
            fx = half_spectrum(w)
        else:
            # 3-D: w is a full middle axis and z is the trailing half-spectrum
            # axis. The previous code swapped these roles, yielding shape
            # (h, z, w_half), which is wrong whenever w != z.
            fy = np.expand_dims(fy, -1)
            fx = np.fft.fftfreq(w)[:, None]
            fz = half_spectrum(z)

    return np.sqrt(fx * fx + fy * fy + fz * fz)


def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Samples a fourier image with given size and frequencies decayed by decay power

    Each bin receives two i.i.d. standard-normal draws (trailing axis of size 2),
    scaled by ``1 / max(freq, 1 / max(w, h, z)) ** decay_power``.

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape ``(ch, *freqs.shape, 2)``
    """
    # Clamp frequencies from below so the zero-frequency bin stays finite.
    min_freq = np.array([1. / max(w, h, z)])
    amplitude = np.ones(1) / (np.maximum(freqs, min_freq) ** decay_power)

    noise = np.random.randn(ch, *freqs.shape, 2)

    # Broadcast amplitude over the channel axis and the 2-wide trailing axis.
    return amplitude[None, ..., None] * noise


def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image from fourier space

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels of the sampled spectrum; only the first
        channel is returned
    :return: Array of shape ``(1, *shape)`` rescaled so its min is 0 and max is 1
    """
    freqs = fftfreqnd(*shape)
    spectrum = get_spectrum(freqs, decay, ch, *shape)
    # The trailing axis of the spectrum holds the (real, imaginary) pair.
    # Indexing axis 1 instead (spectrum[:, 0] / spectrum[:, 1]) picked two rows
    # of a frequency axis and discarded most of the sampled coefficients.
    spectrum = spectrum[..., 0] + 1j * spectrum[..., 1]
    mask = np.real(np.fft.irfftn(spectrum, shape))

    # irfftn with s=shape already yields exactly `shape` on the transformed
    # axes, so only the channel axis needs cropping.
    mask = mask[:1]

    mask = mask - mask.min()
    mask = mask / mask.max()
    return mask


def sample_lam(alpha, reformulate=False):
    """ Sample a lambda from a beta distribution with given alpha

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, sample from Beta(alpha + 1, alpha) instead of
        the symmetric Beta(alpha, alpha).
    :return: A float in [0, 1]
    """
    # The previous code called `beta.rvs` (scipy.stats), which is never
    # imported here and so raised NameError; numpy provides the same sampler.
    first = alpha + 1 if reformulate else alpha
    return np.random.beta(first, alpha)


def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Binarises a given low frequency image such that it has mean lambda.

    The ``num`` largest entries become 1 and the rest 0. ``num`` is
    ``lam * mask.size`` rounded up or down with equal probability, using
    ``random.random()``. When softening is active, the ``2 * soft`` entries
    ranked around that threshold are instead filled with a linear ramp
    from 1 down to 0.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`.
        It is flattened with ``reshape`` and written through, so a
        C-contiguous numpy input is modified in place.
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges
        in the mask; it is capped at ``min(lam, 1 - lam)``.
    :return: The mask reshaped to ``(1, *in_shape)``
    """
    flat = mask.reshape(-1)
    order = flat.argsort()[::-1]  # positions from largest to smallest value
    total = flat.size

    if random.random() > 0.5:
        num = math.ceil(lam * total)
    else:
        num = math.floor(lam * total)

    too_soft = max_soft > lam or max_soft > (1 - lam)
    soft_frac = min(lam, 1 - lam) if too_soft else max_soft

    soft = int(total * soft_frac)
    lo = num - soft
    hi = num + soft

    # These three writes must stay in this order: the later ones overwrite
    # parts of the earlier ones.
    flat[order[:hi]] = 1
    flat[order[lo:]] = 0
    flat[order[lo:hi]] = np.linspace(1, 0, (hi - lo))

    return flat.reshape((1, *in_shape))


def sample_mask(lam, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Creates a low frequency image and binarises it so that its mean is ``lam``.

    Note: ``lam`` is given directly; nothing is sampled from a beta
    distribution here.

    :param lam: Target mean of the mask (the mixing ratio)
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask; an int or a sequence of up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: Unused by this function; accepted for backward
        compatibility with existing callers.
    :return: Mask array of shape ``(1, *shape)``
    """
    if isinstance(shape, int):
        shape = (shape,)

    low_freq = make_low_freq_image(decay_power, shape)
    return binarise_mask(low_freq, lam, shape, max_soft)


def fmix(img1, img2, lam, decay_power=3, shape=(64, 64), max_soft=0.0, reformulate=False):
    """ Mix two images with an FMix mask: ``img1 * mask + img2 * (1 - mask)``.

    :param img1: Image whose trailing dims broadcast against a ``(1, *shape)``
        mask, e.g. ``(C, *shape)``
    :param img2: Image with the same shape as ``img1``
    :param lam: Target mean of the mask, i.e. the share of ``img1`` in the mix
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: Forwarded to ``sample_mask``
    :return: The mixed image
    """
    # Build the float32 mask on the same device as the images; a CPU-only mask
    # fails to multiply with CUDA tensors.
    device = img1.device if isinstance(img1, torch.Tensor) else None
    mask = torch.as_tensor(sample_mask(lam, decay_power, shape, max_soft, reformulate),
                           dtype=torch.float32, device=device)

    x1, x2 = img1 * mask, img2 * (1 - mask)
    return x1 + x2

def fout(img1, img2, lam, decay_power=3, shape=(28, 28), max_soft=0.0, reformulate=False):
    """ Erase part of ``img1`` with an FMix mask: returns ``img1 * mask``.

    :param img1: Image whose trailing dims broadcast against a ``(1, *shape)``
        mask, e.g. ``(C, *shape)``
    :param img2: Ignored (presumably kept so augmentations share one signature)
    :param lam: Target mean of the mask, i.e. the kept share of ``img1``
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: Forwarded to ``sample_mask``
    :return: The masked image
    """
    # Put the float32 mask on the image's device so CUDA inputs work.
    device = img1.device if isinstance(img1, torch.Tensor) else None
    mask = torch.as_tensor(sample_mask(lam, decay_power, shape, max_soft, reformulate),
                           dtype=torch.float32, device=device)
    return img1 * mask

class FMixBase:
    r""" Abstract base for FMix augmentation objects.

        Stores the mask-generation settings. Both ``__call__`` and ``loss``
        raise ``NotImplementedError`` and must be provided by subclasses.

        Args:
            decay_power (float): Decay power for frequency decay prop 1/f**d
            alpha (float): Alpha value for beta distribution from which to sample mean of mask
            size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
            max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
            reformulate (bool): If True, uses the reformulation of [1].
    """

    def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
        super().__init__()
        # Mask-generation hyper-parameters.
        self.alpha = alpha
        self.decay_power = decay_power
        self.size = size
        self.max_soft = max_soft
        self.reformulate = reformulate
        # Per-call state; None here, presumably filled in by subclasses.
        self.lam = None
        self.index = None

    def __call__(self, x):
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        raise NotImplementedError

class BatchCutoutOriginal(object):
    """Build masks that cut one or more rectangular patches out of each image.

    Patch centres are drawn uniformly over the whole image, so patches near
    the border are clipped and can be smaller than ``width`` x ``height``.
    A patch spans at most ``2 * (size // 2)`` pixels per axis.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        width (int | float): The width (in pixels) of each patch.
        height (int | float): The height (in pixels) of each patch.
    """
    def __init__(self, n_holes, width, height):
        self.n_holes = n_holes
        self.width = width
        self.height = height

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image batch of size (B, C, H, W).
        Returns:
            Tensor: Float mask of size (B, C, H, W) on ``img.device``, holding
            1 for kept pixels and 0 inside the cut-out patches.
        """
        b = img.size(0)
        c = img.size(1)
        h = img.size(-2)
        w = img.size(-1)

        mask = torch.ones((b, h, w), device=img.device)

        for _ in range(self.n_holes):
            y = torch.randint(high=h, size=(b,)).long()
            x = torch.randint(high=w, size=(b,)).long()

            # int64 coordinates: a uint8 cast would wrap around for images
            # larger than 255 pixels and cut the wrong region.
            y1 = (y - self.height // 2).clamp(0, h).long()
            y2 = (y + self.height // 2).clamp(0, h).long()
            x1 = (x - self.width // 2).clamp(0, w).long()
            x2 = (x + self.width // 2).clamp(0, w).long()

            for batch in range(b):
                mask[batch, y1[batch]: y2[batch], x1[batch]: x2[batch]] = 0

        mask = mask.unsqueeze(1).repeat(1, c, 1, 1)

        return mask
