from torchvision.transforms import v2 as transforms

import numpy as np
import torch


"""
Cutout implementation from https://github.com/Intelligent-Computing-Lab-Yale/NDA_SNN/blob/main/functions/data_loaders.py
"""
class Cutout(transforms.Transform):
    """Zero out one randomly placed square patch of an image tensor.

    A centre pixel is drawn uniformly over the spatial grid (the last two
    dimensions of the input) using NumPy's global random state, and the
    square region around it is multiplied by zero. The same patch is applied
    to every leading dimension (e.g. channels, time steps or batch entries).

    Args:
        length (int): Nominal side length (in pixels) of the square patch.
            The patch extends ``length // 2`` pixels on each side of the
            centre, so it is ``2 * (length // 2)`` pixels wide at most and
            is smaller whenever it is clipped at the image border.
    """

    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``img`` with one random square patch zeroed.

        Args:
            img: Tensor whose last two dimensions are height and width,
                e.g. ``(C, H, W)`` or ``(T, C, H, W)``.

        Returns:
            ``img`` multiplied by a float32 0/1 mask of shape ``(H, W)``
            broadcast over the leading dimensions; the result dtype follows
            PyTorch type promotion between ``img`` and float32.

        Raises:
            IndexError: If ``img`` has fewer than two dimensions.
        """
        # Use the trailing dims so 3-D images work as well as 4-D inputs.
        h = img.size(-2)
        w = img.size(-1)

        # Allocate the mask on the input's device; a CPU-only mask would make
        # the multiplication fail for tensors living on an accelerator.
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)

        # np.random.randint(0) raises, so skip the draw for empty images.
        if h > 0 and w > 0:
            half = self.length // 2
            # Draw order (row first, then column) is kept so that results
            # stay reproducible under an existing np.random.seed(...).
            cy = np.random.randint(h)
            cx = np.random.randint(w)
            y1 = int(np.clip(cy - half, 0, h))
            y2 = int(np.clip(cy + half, 0, h))
            x1 = int(np.clip(cx - half, 0, w))
            x2 = int(np.clip(cx + half, 0, w))
            mask[y1:y2, x1:x2] = 0.

        # (H, W) broadcasts against (..., H, W); no explicit expand needed.
        return img * mask