import torch
import numpy as np
from torchvision.datasets import CIFAR10, CIFAR100  
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

class Cutout(object):
    """Randomly mask out one or more square patches from an image tensor.

    Each hole is centred on a uniformly random pixel. Patches that extend past
    the image border are clipped, so holes near an edge are smaller than
    ``length x length``. Holes are drawn independently and may overlap.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Side length (in pixels) of each square patch.
            Suggested: 16 for CIFAR-10, 8 for CIFAR-100.
    """

    def __init__(self, n_holes=1, length=16):  # cifar10:16, cifar100:8
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Apply Cutout to ``img``.

        Args:
            img (Tensor): Image tensor of shape (C, H, W). Any tensor whose
                last two dimensions are (H, W) is accepted; the mask is
                broadcast over all leading dimensions.

        Returns:
            Tensor: ``img`` multiplied by a mask that is 0 inside every patch
            and 1 elsewhere. This is a new tensor; ``img`` is not modified
            in place.
        """
        h, w = img.shape[-2], img.shape[-1]

        # Build the mask on the image's device so CUDA tensors work as well.
        # float32 keeps the result dtype identical to multiplying by a
        # float32 NumPy-derived mask.
        mask = torch.ones((h, w), dtype=torch.float32, device=img.device)

        half = self.length // 2
        for _ in range(self.n_holes):
            # Centres come from NumPy's global RNG so np.random.seed controls them.
            cy = np.random.randint(h)
            cx = np.random.randint(w)

            # Start half a patch before the centre and span exactly `length`
            # pixels, so odd lengths also yield a full length x length patch
            # (for even lengths this is identical to centre +/- length // 2).
            y1 = int(np.clip(cy - half, 0, h))
            y2 = int(np.clip(cy - half + self.length, 0, h))
            x1 = int(np.clip(cx - half, 0, w))
            x2 = int(np.clip(cx - half + self.length, 0, w))

            mask[y1:y2, x1:x2] = 0.

        return img * mask.expand_as(img)
    
# trans = transforms.Compose([transforms.RandomCrop(32, 4), 
#                             transforms.RandomHorizontalFlip(), 
#                             transforms.ToTensor(), 
#                             Cutout(1, 16)])
# train_set = CIFAR10(root='./data', train=True, download=True, transform=trans)
# train_loader = DataLoader(train_set, batch_size=50000, shuffle=False)
# count = 1
# for data, target in train_loader:
#     print(data.shape, data.max(), data.min())
#     count += 1
#     if count > 10:
#         break



# print(
#     'ntga' in 'ntga-cutout'
# )


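"""
Example evaluation.py invocations with Cutout augmentation enabled (--cutout).
Each command runs in the background and writes stdout and stderr to a *.output log.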

python evaluation.py \
    --experiment ntga-cutout \
    --dataset cifar10 \
    --data data/CIFAR10 \
    --backbone resnet18 \
    --cutout \
    --gpu-id 0 > ntga-cutout-sl.output 2>&1&


python evaluation.py \
    --experiment ent-cutout \
    --poison-path  ./baseline/ent-cifar10.pt \
    --dataset cifar10 \
    --data data/CIFAR10 \
    --backbone resnet18 \
    --cutout \
    --gpu-id 0 > ent-cutout-sl.output 2>&1&



"""