from PIL import Image

import numpy as np

import torch

class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Each patch is centred on a pixel drawn uniformly at random and clipped at
    the image border, so patches near an edge can be smaller than
    ``length x length``, and independently drawn patches may overlap.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The side length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img: Image.Image) -> Image.Image:
        """
        Args:
            img (PIL Image): Image to mask. Multi-channel images (array form
                (H, W, C)) and single-channel images (array form (H, W)) are
                both supported.
        Returns:
            PIL Image: Copy of ``img`` with n_holes patches of (at most)
            length x length pixels set to zero in every channel. The pixel
            array keeps its original dtype.
        """
        # np.array copies, so the caller's image is never modified.
        arr = np.array(img)
        h, w = arr.shape[:2]

        # True = keep the pixel, False = cut it out.
        mask = np.ones((h, w), dtype=bool)
        half = self.length // 2

        for _ in range(self.n_holes):
            # y is drawn before x to keep the same RNG consumption order.
            y = np.random.randint(h)
            x = np.random.randint(w)

            # Start `half` pixels before the centre and span exactly `length`
            # pixels, so odd lengths are not shortened by one pixel.
            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y - half + self.length, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x - half + self.length, 0, w)

            mask[y1:y2, x1:x2] = False

        # Boolean indexing on the first two axes zeroes all channels at once,
        # and in-place assignment keeps the dtype (no float round-trip).
        arr[~mask] = 0

        return Image.fromarray(arr)
    
def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    '''
    Code adapted from https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py

    Returns mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a partner chosen by a random permutation of
    the batch: ``mixed_x = lam * x + (1 - lam) * x[index]``.

    The permutation is drawn along dimension 0 of ``x`` (and ``y``).
    NOTE(review): for temporal input laid out as (T, B, ...), which ``cutmix``
    in this module handles, this would permute time steps rather than samples
    -- confirm the layout callers pass here.

    Args:
        x: input batch; samples are indexed along dimension 0.
        y: target batch, indexed along dimension 0 like ``x``.
        alpha: parameter of the Beta(alpha, alpha) distribution lambda is drawn
            from. If ``alpha <= 0`` no mixing happens (lambda = 1).

    Returns:
        mixed_x: mixed input batch
        y_a: original targets (``y`` itself)
        y_b: targets of the mixing partners (``y[index]``)
        lam: mixing coefficient in [0, 1]
    '''
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # x[index] (rather than x[index, :]) also works for 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def rand_bbox(size: tuple[int, ...], lam: float) -> tuple[int, int, int, int]:
    """
    Code adapted from https://github.com/Intelligent-Computing-Lab-Yale/NDA_SNN/blob/main/functions/data_loaders.py

    Returns a random bounding box covering roughly ``1 - lam`` of the image.

    The box has the image's aspect ratio (sides ``int(W * sqrt(1 - lam))`` and
    ``int(H * sqrt(1 - lam))``). It is centred on a uniformly drawn pixel and
    clipped to the image, so boxes near a border are smaller.

    Args:
        size: shape of the image or batch, e.g. ``x.size()``; only the last
            two entries are read, as H = size[-2] and W = size[-1].
        lam: mixing coefficient in [0, 1].

    Returns:
        bbx1: x1 coordinate, in [0, W]
        bby1: y1 coordinate, in [0, H]
        bbx2: x2 coordinate, in [0, W], >= bbx1
        bby2: y2 coordinate, in [0, H], >= bby1
    """
    W = size[-1]
    H = size[-2]
    cut_rat = np.sqrt(1. - lam)
    # int() truncates toward zero exactly like the former np.int32 cast.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform box centre (x drawn before y to keep the RNG order)
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Plain Python ints: safe for slicing and immune to int32 overflow when
    # callers multiply box sides to get an area.
    bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
    bby1 = int(np.clip(cy - cut_h // 2, 0, H))
    bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
    bby2 = int(np.clip(cy + cut_h // 2, 0, H))

    return bbx1, bby1, bbx2, bby2


def cutmix(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Code adapted from https://github.com/Intelligent-Computing-Lab-Yale/NDA_SNN/blob/main/functions/data_loaders.py

    Returns mixed inputs, pairs of targets, and lambda.

    A random box (from ``rand_bbox``) is taken from a shuffled copy of the
    batch and pasted into ``x``. ``x`` is modified IN PLACE and also returned.

    Args:
        x: input batch
            - Shape: (T, B, C, H, W) or (B, C, H, W); samples are shuffled
              along the B dimension.
        y: target batch, indexed by sample like the B dimension of ``x``.
        alpha: parameter of the Beta(alpha, alpha) distribution lambda is
            drawn from. If ``alpha <= 0``, lambda is 1, giving an empty box
            (consistent with ``mixup``).

    Returns:
        input: mixed input batch (the same tensor object as ``x``)
        target_a: target batch (``y`` itself)
        target_b: targets of the samples the box was taken from
        lam: fraction of each image NOT covered by the pasted box

    Raises:
        ValueError: if ``x`` is neither 4-D nor 5-D.
    """
    if x.ndim not in (4, 5):
        raise ValueError(f"cutmix expects a 4-D or 5-D tensor, got {x.ndim}-D")

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    # Temporal data is laid out (T, B, ...), so its batch axis is dim 1.
    batch_dim = 1 if x.ndim == 5 else 0
    rand_index = torch.randperm(x.size(batch_dim)).to(x.device)

    target_a = y
    target_b = y[rand_index]

    # generate mixed sample
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # rand_bbox bounds bbx by W = size[-1] and bby by H = size[-2], so the y
    # range must slice the H axis (-2) and the x range the W axis (-1);
    # swapping them truncates the box on non-square inputs and skews lam.
    # Advanced indexing with rand_index copies, so the source patch is not
    # affected by the in-place write below.
    if batch_dim == 1:  # Temporal Data
        patch = x[:, rand_index, ..., bby1:bby2, bbx1:bbx2]
    else:  # Spatial Data
        patch = x[rand_index, ..., bby1:bby2, bbx1:bbx2]
    x[..., bby1:bby2, bbx1:bbx2] = patch

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
    return x, target_a, target_b, float(lam)