import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.data import Dataset


def get_torchvision_dataset(dataset_name,
                            train_transforms,
                            test_transforms=None,
                            root="../data",
                            ):
    """Build the train and test splits of a supported torchvision dataset.

    Args:
        dataset_name: One of ``"mnist"``, ``"fashion_mnist"``, ``"svhn"``,
            ``"cifar10"``.
        train_transforms: List of transforms, wrapped in
            ``transforms.Compose`` and applied to the training split.
        test_transforms: List of transforms for the test split. Defaults to
            ``train_transforms`` when None.
        root: Directory passed to the torchvision dataset constructor
            (with ``download=True``).

    Returns:
        ``(train_set, test_set)`` tuple of torchvision dataset objects.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Public name -> attribute name in torchvision.datasets. The class is
    # resolved only after validation so a bad name fails before any download.
    class_names = {
        "mnist": "MNIST",
        "fashion_mnist": "FashionMNIST",
        "svhn": "SVHN",
        "cifar10": "CIFAR10",
    }
    if dataset_name not in class_names:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; "
            f"expected one of {sorted(class_names)}"
        )

    if test_transforms is None:
        test_transforms = train_transforms

    data_init = getattr(torchvision.datasets, class_names[dataset_name])

    # SVHN selects its split with a `split=` string; the others use `train=`.
    if dataset_name == "svhn":
        train_split, test_split = {"split": "train"}, {"split": "test"}
    else:
        train_split, test_split = {"train": True}, {"train": False}

    train_set = data_init(root=root,
                          download=True,
                          transform=transforms.Compose(train_transforms),
                          **train_split)
    test_set = data_init(root=root,
                         download=True,
                         transform=transforms.Compose(test_transforms),
                         **test_split)
    return train_set, test_set


class PoisonedDataset(torch.utils.data.Dataset):
    """Wrap a dataset and replace the labels of selected indices with wrong ones.

    Items whose index is in ``poison_ids`` are returned as
    ``(sample[0], poison_label)``; every other item is returned unchanged.

    Args:
        original_dataset: Indexable dataset whose items expose the input at
            position 0 and the label at position 1.
        poison_ids: Iterable of integer indices to poison. It is stored as-is
            on ``self.poison_ids``. A hashed copy is taken at construction time
            for membership tests, so later in-place edits to the original
            container are not seen.
        num_class: Number of classes; poisoned labels lie in ``range(num_class)``.
        poisoning_strategy: ``"shift_right"`` maps label ``y`` to ``y + 1``,
            wrapping to 0 past the last class. ``"random_flip"`` draws a label
            uniformly from the other classes using NumPy's global RNG, and
            re-draws it on every access.

    Raises:
        NotImplementedError: If ``poisoning_strategy`` is not supported.
    """

    _STRATEGIES = ("random_flip", "shift_right")

    def __init__(self, original_dataset, poison_ids, num_class, poisoning_strategy="shift_right"):
        # Fail at construction rather than on the first poisoned sample.
        if poisoning_strategy not in self._STRATEGIES:
            raise NotImplementedError(
                f"Unknown poisoning_strategy {poisoning_strategy!r}; "
                f"expected one of {self._STRATEGIES}"
            )
        self.data = original_dataset
        self.poison_ids = poison_ids
        self.num_class = num_class
        self.poisoning_strategy = poisoning_strategy
        # O(1) membership per __getitem__ instead of scanning a list/array.
        self._poison_id_set = {int(i) for i in poison_ids}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if idx not in self._poison_id_set:
            return sample

        true_label = sample[1]
        if self.poisoning_strategy == "random_flip":
            candidates = [c for c in range(self.num_class) if c != true_label]
            # Cast so poisoned labels are plain ints, like unpoisoned ones.
            poison_label = int(np.random.choice(candidates))
        elif self.poisoning_strategy == "shift_right":
            poison_label = true_label + 1
            if poison_label >= self.num_class:
                poison_label = 0
        else:
            # Only reachable if poisoning_strategy was reassigned after __init__.
            raise NotImplementedError(
                f"Unknown poisoning_strategy {self.poisoning_strategy!r}"
            )
        return (sample[0], poison_label)



class FashionMNISTSubset(Dataset):
    """FashionMNIST restricted to selected classes, with labels re-indexed.

    The torchvision dataset is loaded with ``ToTensor`` followed by
    ``Resize(8)``. Every sample whose label is in ``subset_class`` is then
    transformed once and cached in memory as ``(image, new_label)``, where
    ``new_label`` is the label's position in ``subset_class``.

    Args:
        train: Load the training split if True, else the test split.
        root: Directory passed to ``torchvision.datasets.FashionMNIST``
            (with ``download=True``).
        subset_class: Original labels to keep. Their order defines the new
            labels ``0..n_class-1``, and duplicates are ignored. The default
            keeps every class except 2 and 6.

    Attributes:
        data: list of ``(image_tensor, new_label)`` tuples.
        n_class: number of distinct kept classes.
        shape: shape of the first transformed image.

    Raises:
        ValueError: If no sample matches ``subset_class``.
    """

    def __init__(self, train: bool, root: str = '../data',
                 subset_class=(0, 1, 3, 4, 5, 7, 8, 9)):
        self.transforms = [
            transforms.ToTensor(),
            transforms.Resize(8),
            # transforms.Normalize((0.5,), (0.5,)),  # left disabled
        ]
        self.data = torchvision.datasets.FashionMNIST(
            root=root, train=train, download=True,
            transform=transforms.Compose(self.transforms),
        )
        # Replace the raw torchvision dataset with the filtered, pre-transformed list.
        self.data, self.n_class = self._subset(subset_class=subset_class)
        if not self.data:
            raise ValueError(f"No samples found for subset_class={subset_class!r}")
        self.shape = self.data[0][0].shape

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def _subset(self, subset_class):
        """Keep samples of ``self.data`` whose label is in ``subset_class``.

        Returns:
            ``(samples, n_class)``: a list of ``(input, new_label)`` tuples and
            the number of distinct kept classes.
        """
        # dict.fromkeys drops duplicates but keeps first-seen order, so the new
        # labels are always exactly 0..n_class-1.
        subset_id = {c: i for i, c in enumerate(dict.fromkeys(subset_class))}
        data = [
            (image, subset_id[label])
            for image, label in tqdm(self.data, desc='Slicing')
            if label in subset_id
        ]
        return data, len(subset_id)