import numpy as np
import torch
import random
import matplotlib.pyplot as plt
import torchvision.datasets as dset
import torchvision.transforms as transforms

# Per-channel CIFAR-10 statistics as flat 3-element tensors.
RAW_CIFAR_MEAN = torch.tensor([0.49139968, 0.48215827, 0.44653124])
RAW_CIFAR_STD = torch.tensor([0.24703233, 0.24348505, 0.26158768])

# Independent copies shaped (3, 1, 1) so they broadcast against a
# channel-first image tensor.
CIFAR_MEAN = RAW_CIFAR_MEAN.clone().view(-1, 1, 1)
CIFAR_STD = RAW_CIFAR_STD.clone().view(-1, 1, 1)

# Debug helper: can be used to visually verify poisons.
def imshow(tensor, name, denormalize=False):
    """Save an image to ``./{name}.png`` with the axes hidden.

    Args:
        tensor: Either a ``torch.Tensor`` laid out channel-first (it is
            transposed with ``(1, 2, 0)`` before plotting), or any other
            object, which is first converted with ``transforms.ToTensor``.
        name: File stem of the PNG written to the current working directory.
        denormalize: When true, multiply by ``CIFAR_STD`` and add
            ``CIFAR_MEAN`` before plotting.
    """
    # isinstance (not an exact type comparison) so Tensor subclasses are not
    # mistakenly pushed through ToTensor.
    if not isinstance(tensor, torch.Tensor):
        tensor = transforms.ToTensor()(tensor)

    # .numpy() raises for tensors that require grad or do not live on the
    # CPU; detaching and moving first also keeps the arithmetic below on the
    # same device as the CPU-resident CIFAR constants.
    tensor = tensor.detach().cpu()

    if denormalize:
        tensor = tensor * CIFAR_STD + CIFAR_MEAN

    img = tensor.numpy().transpose((1, 2, 0))

    # Draw on a dedicated figure and always close it, so repeated calls do
    # not stack images on one set of axes or accumulate open figures.
    fig = plt.figure()
    try:
        plt.imshow(img)
        plt.axis("off")
        plt.savefig(f"./{name}.png")
    finally:
        plt.close(fig)

# Class to do Cutout
# Credit: https://github.com/chenxin061/pdarts
class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    The patch centre is drawn uniformly over the image using NumPy's global
    random state; the patch is clipped at the image borders, so it can be
    smaller than ``length`` x ``length`` near an edge.

    Args:
        length: Nominal side length of the square patch, in pixels.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Return a copy of ``img`` with the patch zeroed in every channel.

        Args:
            img: Tensor whose dimensions 1 and 2 are height and width.

        Returns:
            A new tensor of the same shape and dtype; ``img`` itself is left
            unmodified.
        """
        h, w = img.size(1), img.size(2)
        half = self.length // 2

        # Row is drawn before column, matching the original draw order so
        # seeded runs pick the same patch.
        cy = int(np.random.randint(h))
        cx = int(np.random.randint(w))

        # Clamp the patch bounds to the image, as np.clip(v, 0, limit) did.
        top = min(max(cy - half, 0), h)
        bottom = min(max(cy + half, 0), h)
        left = min(max(cx - half, 0), w)
        right = min(max(cx + half, 0), w)

        # Build the mask with the input's dtype and device so the multiply
        # works for non-CPU tensors and keeps the input's dtype.
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        mask[top:bottom, left:right] = 0

        # Out-of-place multiply: the (h, w) mask broadcasts over channels and
        # the caller's tensor is not mutated.
        return img * mask

# common transforms for CIFAR10
# adapted from https://github.com/chenxin061/pdarts
def data_transforms_cifar10(cutout=False, cutout_length=16):
    """Build the training and validation transform pipelines for CIFAR-10.

    Args:
        cutout: When true, a ``Cutout`` step is appended after normalization
            in the training pipeline.
        cutout_length: Patch length handed to ``Cutout``.

    Returns:
        A ``(train_transform, valid_transform)`` pair of
        ``transforms.Compose`` objects.
    """
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]

    # Training: padded random crop and horizontal flip, then tensor
    # conversion and normalization.
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if cutout:
        train_steps.append(Cutout(cutout_length))

    # Validation: tensor conversion and normalization only.
    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    return transforms.Compose(train_steps), transforms.Compose(valid_steps)