import torch
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import random
import numpy as np

class OccludedCifar(Dataset):
    """CIFAR-10 images with a binarized MNIST digit painted on top.

    Each CIFAR-10 image ``i`` is paired with one MNIST image chosen by a
    seeded sample without replacement. The MNIST digit is thresholded
    (pixels > 0.5 become 1.0, others 0.0) and merged into every CIFAR channel
    with an element-wise maximum. The result is then mapped from [0, 1] to
    [-1, 1].

    Args:
        cifar_path: Root directory of an already-downloaded CIFAR-10 copy.
        mnist_path: Root directory of an already-downloaded MNIST copy.
        train: Use the training split (True) or the test split (False) of
            both datasets.
        label: Which target ``__getitem__`` returns: ``"cifar"`` or ``"mnist"``.
        aug: If True, apply ``RandomCrop(32, padding=4)`` to the combined image.
        seed: Seed for the CIFAR -> MNIST pairing. A private RNG is used, so
            the global ``random`` module state is not modified.

    Raises:
        ValueError: If ``label`` is not a supported value, or (from
            ``random.sample``) if MNIST has fewer images than CIFAR-10.
    """

    _LABEL_KINDS = ("cifar", "mnist")

    def __init__(self, cifar_path, mnist_path, train, label, aug, seed=314):
        # Fail fast. Previously an unknown label surfaced only as an
        # UnboundLocalError on the first __getitem__ call.
        if label not in self._LABEL_KINDS:
            raise ValueError(
                f"label must be one of {self._LABEL_KINDS}, got {label!r}")
        self.label = label
        self.aug = aug

        cifar_transform = transforms.Compose([transforms.ToTensor()])
        self.cifar_set = torchvision.datasets.CIFAR10(
            root=cifar_path, train=train, download=False,
            transform=cifar_transform)

        # Resize the 28x28 digit to 20x20, then pad 12 px on the left and top
        # (torchvision's 4-tuple order is left, top, right, bottom). This gives
        # exactly 32x32, so the 32x32 "random" crop has a single position and
        # the digit always sits in the bottom-right 20x20 region.
        mnist_transform = transforms.Compose(
            [transforms.Resize(20),
             transforms.RandomCrop(32, padding=(12, 12, 0, 0),
                                   pad_if_needed=True, fill=0,
                                   padding_mode='constant'),
             transforms.ToTensor()])
        self.mnist_set = torchvision.datasets.MNIST(
            root=mnist_path, train=train, download=False,
            transform=mnist_transform)

        # A private Random(seed) yields the same sample the old
        # random.seed(314) + random.sample produced, but without reseeding
        # the process-global RNG as a side effect.
        rng = random.Random(seed)
        self.shuffle_index = rng.sample(range(len(self.mnist_set)),
                                        len(self.cifar_set))

        # NOTE: this runs after normalization to [-1, 1], so the zero padding
        # corresponds to mid-gray (0.5 in [0, 1] space), not black.
        self.transform = transforms.Compose(
            [transforms.RandomCrop(32, padding=4)])

    def __getitem__(self, i):
        """Return ``(image, target)`` for index ``i``.

        The image is the combined CIFAR+MNIST tensor in [-1, 1]. The target is
        the CIFAR or MNIST class label, as selected by ``self.label``.
        """
        cifar_img, cifar_label = self.cifar_set[i]
        mnist_img, mnist_label = self.mnist_set[self.shuffle_index[i]]

        # Cast the boolean mask explicitly instead of relying on bool/float
        # type promotion inside torch.maximum. The single-channel mask
        # broadcasts across the CIFAR channels.
        digit_mask = (mnist_img > 0.5).to(cifar_img.dtype)
        img = torch.maximum(cifar_img, digit_mask)
        img = (img - 0.5) / 0.5

        if self.aug:
            img = self.transform(img)

        if self.label == "cifar":
            return img, cifar_label
        if self.label == "mnist":
            return img, mnist_label
        raise ValueError(
            f"label must be one of {self._LABEL_KINDS}, got {self.label!r}")

    def __len__(self):
        """Number of samples; one per CIFAR-10 image in the chosen split."""
        return len(self.cifar_set)
    
class Cifar10h(Dataset):
    """CIFAR-10 test images paired with CIFAR-10H soft (probabilistic) labels.

    Args:
        cifar10_path: Root directory of an already-downloaded CIFAR-10 copy.
            Only the test split is used.
        transform: Optional callable applied to each image. No transform is
            set on the underlying torchvision dataset, so this receives the
            raw image it returns.
        probs_path: Path to the ``.npy`` array of per-image class
            probabilities. Defaults to the original cwd-relative location
            for backward compatibility.

    Raises:
        ValueError: If the number of probability rows differs from the number
            of CIFAR-10 test images, since rows are matched to images by index.
    """

    def __init__(self, cifar10_path, transform=None,
                 probs_path="cifar-10h/data/cifar10h-probs.npy"):
        self.cifar_set = torchvision.datasets.CIFAR10(
            root=cifar10_path, train=False, download=False)
        self.prob_labels = np.load(probs_path)
        # Rows are paired with images purely by position. A length mismatch
        # would otherwise silently misalign labels or fail late with IndexError.
        if len(self.prob_labels) != len(self.cifar_set):
            raise ValueError(
                f"{probs_path} has {len(self.prob_labels)} rows but the "
                f"CIFAR-10 test set has {len(self.cifar_set)} images")
        self.transform = transform

    def __getitem__(self, i):
        """Return ``(image, probs)``, where ``probs`` is a float32 tensor row."""
        img, _ = self.cifar_set[i]  # hard CIFAR label is replaced by soft one
        label = torch.tensor(self.prob_labels[i], dtype=torch.float32)

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Number of CIFAR-10 test images."""
        return len(self.cifar_set)