from PIL import Image

import numpy as np

import torch

class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Each patch is centred on a pixel drawn uniformly at random (using the
    global ``np.random`` state) and is clipped to the image bounds. A patch
    near an edge can therefore be smaller than ``length x length``. Patches
    may overlap.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The side length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img: Image.Image) -> Image.Image:
        """Zero out ``n_holes`` random square patches of ``img``.

        Args:
            img (PIL Image): Input image. Both multi-channel images (array
                shape ``(H, W, C)``, e.g. RGB) and single-channel images
                (array shape ``(H, W)``, e.g. mode ``"L"``) are accepted.

        Returns:
            PIL Image: A new image, converted to 8-bit (``uint8``) as before.
            It has ``n_holes`` patches of at most ``length x length`` pixels
            set to 0 in every channel. The input image is not modified.
        """
        arr = np.array(img)  # np.array copies, so the caller's image is untouched
        # Only the first two axes are spatial; this handles (H, W) and (H, W, C).
        h, w = arr.shape[:2]

        half = self.length // 2
        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            # Start half a patch before the centre and extend exactly `length`
            # pixels. An odd length then yields a full length x length patch;
            # ending at `centre + length // 2` would drop one row and column.
            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y - half + self.length, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x - half + self.length, 0, w)

            # Slicing only the two spatial axes zeroes every channel at once.
            arr[y1:y2, x1:x2] = 0

        return Image.fromarray(arr.astype(np.uint8))