import os
import PIL
import torch
import pickle
import random
import numpy as np
import torchvision
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data.dataset import Dataset

def seed_torch(seed=0):
    """Seed every random generator used here and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU, the current CUDA device, and all CUDA devices). Also exports
    ``PYTHONHASHSEED`` and switches cuDNN from autotuning to deterministic
    kernels.

    Args:
        seed: integer seed applied to every generator (default 0).
    """
    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime only influences interpreters
    # launched afterwards (e.g. worker subprocesses); this process's hash
    # seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    # manual_seed_all covers the multi-GPU case.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

class CIFAR10(Dataset):
    """Cached feature dataset for CIFAR-10.

    Features and labels are read from
    ``./PyTorch_CIFAR10/cifar_without_pca_l4.pth`` (path relative to the
    current working directory). Judging by the key names, the features come
    from a ResNet-18 model. Each item is an ``(index, features, label)``
    triple.

    Args:
        train: truthy selects the training split and seeds all RNGs with 0;
            falsy selects the test split and seeds them with 1.
    """

    def __init__(self, train):
        super().__init__()
        # Seed first so every split is constructed under a fixed RNG state.
        seed_torch(0 if train else 1)
        self.train = train
        # Original note "128,8,8" -- presumably (C, H, W) of the extracted
        # feature map; TODO confirm.
        self.num_attributes = 128
        self.num_features = 128
        self.num_classes = 10

        # data_size is a recorded constant; __len__ reports the real row count.
        if self.train:
            self.data_size = 50000
            feature_key, label_key = 'resnet18_train_features', 'train_labels'
        else:
            self.data_size = 10000
            feature_key, label_key = 'resnet18_test_features', 'test_labels'

        # Both splits live in one file; only the dictionary keys differ.
        bundle = torch.load('./PyTorch_CIFAR10/cifar_without_pca_l4.pth')
        self.data = torch.Tensor(bundle[feature_key])
        self.labels = torch.Tensor(bundle[label_key]).type(torch.LongTensor)
        # NOTE(review): unlike the PathMNIST/DermaMNIST/SVHN classes in this
        # file, no test_sample_data/test_sample_labels are drawn here.

    def __getitem__(self, i):
        """Return ``(i, features[i], labels[i])``; the index is passed through."""
        sample = self.data[i]
        target = self.labels[i]
        return i, sample, target

    def __len__(self):
        """Number of rows in the loaded feature tensor."""
        return self.data.size(0)

    def plot(self):
        """Placeholder; plotting is not implemented."""
        return None




class PathMNIST(Dataset):
    """Cached feature dataset for PathMNIST.

    Features and labels are read from
    ``./PyTorch_PathMNIST/pathmnist_without_pca_l5.pth`` (path relative to
    the current working directory). Judging by the key names, the features
    come from a ResNet-18 model. Each item is an ``(index, features, label)``
    triple.

    For the test split, five random rows are also stored in
    ``test_sample_data`` / ``test_sample_labels``.

    Args:
        train: truthy selects the training split and seeds all RNGs with 0;
            falsy selects the test split and seeds them with 1.
    """

    def __init__(self, train):
        super().__init__()
        # Seed first so the test-sample draw below is reproducible.
        seed_torch(0 if train else 1)
        self.train = train
        # Original note "64,7,7" -- presumably (C, H, W) of the extracted
        # feature map; TODO confirm.
        self.num_attributes = 64
        self.num_features = 64
        self.num_classes = 9

        # data_size is a recorded constant; __len__ reports the real row count.
        if self.train:
            self.data_size = 89996
            feature_key, label_key = 'resnet18_train_features', 'train_labels'
        else:
            self.data_size = 7180
            feature_key, label_key = 'resnet18_test_features', 'test_labels'

        # Both splits live in one file; only the dictionary keys differ.
        bundle = torch.load('./PyTorch_PathMNIST/pathmnist_without_pca_l5.pth')
        self.data = torch.Tensor(bundle[feature_key])
        self.labels = torch.Tensor(bundle[label_key]).type(torch.LongTensor)

        if not self.train:
            # Five distinct row indices drawn from [0, data_size).
            picks = random.sample(range(self.data_size), 5)
            self.test_sample_data = self.data[picks]
            self.test_sample_labels = self.labels[picks]

    def __getitem__(self, i):
        """Return ``(i, features[i], labels[i])``; the index is passed through."""
        sample = self.data[i]
        target = self.labels[i]
        return i, sample, target

    def __len__(self):
        """Number of rows in the loaded feature tensor."""
        return self.data.size(0)

class DermaMNIST(Dataset):
    """Cached feature dataset for DermaMNIST.

    Features and labels are read from
    ``./PyTorch_DermaMNIST/dermamnist_without_pca_l5.pth`` (path relative to
    the current working directory). Judging by the key names, the features
    come from a ResNet-18 model. Each item is an ``(index, features, label)``
    triple.

    For the test split, five random rows are also stored in
    ``test_sample_data`` / ``test_sample_labels``.

    Args:
        train: truthy selects the training split and seeds all RNGs with 0;
            falsy selects the test split and seeds them with 1.
    """

    def __init__(self, train):
        super().__init__()
        # Seed first so the test-sample draw below is reproducible.
        seed_torch(0 if train else 1)
        self.train = train
        # Original note "64,7,7" -- presumably (C, H, W) of the extracted
        # feature map; TODO confirm.
        self.num_attributes = 64
        self.num_features = 64
        self.num_classes = 7

        # data_size is a recorded constant; __len__ reports the real row count.
        if self.train:
            self.data_size = 7007
            feature_key, label_key = 'resnet18_train_features', 'train_labels'
        else:
            self.data_size = 2005
            feature_key, label_key = 'resnet18_test_features', 'test_labels'

        # Both splits live in one file; only the dictionary keys differ.
        bundle = torch.load('./PyTorch_DermaMNIST/dermamnist_without_pca_l5.pth')
        self.data = torch.Tensor(bundle[feature_key])
        self.labels = torch.Tensor(bundle[label_key]).type(torch.LongTensor)

        if not self.train:
            # Five distinct row indices drawn from [0, data_size).
            picks = random.sample(range(self.data_size), 5)
            self.test_sample_data = self.data[picks]
            self.test_sample_labels = self.labels[picks]

    def __getitem__(self, i):
        """Return ``(i, features[i], labels[i])``; the index is passed through."""
        sample = self.data[i]
        target = self.labels[i]
        return i, sample, target

    def __len__(self):
        """Number of rows in the loaded feature tensor."""
        return self.data.size(0)

class SVHN(Dataset):
    """Cached feature dataset for SVHN.

    Features and labels are read from
    ``./PyTorch_SVHN/svhn_without_pca_l5.pth`` (path relative to the current
    working directory). Judging by the key names, the features come from a
    ResNet-18 model. Each item is an ``(index, features, label)`` triple.

    For the test split, five random rows are also stored in
    ``test_sample_data`` / ``test_sample_labels``.

    Args:
        train: truthy selects the training split and seeds all RNGs with 0;
            falsy selects the test split and seeds them with 1.
    """

    def __init__(self, train):
        super(SVHN, self).__init__()
        # Seed first so the test-sample draw below is reproducible.
        if train:
            seed_torch(0)
        else:
            seed_torch(1)
        self.train = train
        # Original note "64,8,8" -- presumably (C, H, W) of the extracted
        # feature map; TODO confirm.
        self.num_attributes = 64
        self.num_features = 64
        self.num_classes = 10

        # data_size is a recorded constant; __len__ reports the real row count.
        if self.train:
            self.data_size = 73257
            feature_key, label_key = 'resnet18_train_features', 'train_labels'
        else:
            self.data_size = 26032
            feature_key, label_key = 'resnet18_test_features', 'test_labels'

        # Both splits live in one file; only the dictionary keys differ.
        cifar = torch.load('./PyTorch_SVHN/svhn_without_pca_l5.pth')
        self.data = torch.Tensor(cifar[feature_key])
        self.labels = torch.Tensor(cifar[label_key])
        self.labels = self.labels.type(torch.LongTensor)

        if not self.train:
            # Five distinct row indices drawn from [0, data_size).
            idx = random.sample(range(0, self.data_size), 5)
            self.test_sample_data = self.data[idx]
            self.test_sample_labels = self.labels[idx]

    def __getitem__(self, i):
        """Return ``(i, features[i], labels[i])``; the index is passed through."""
        return i, self.data[i], self.labels[i]

    def __len__(self):
        """Number of rows in the loaded feature tensor."""
        return self.data.shape[0]

    def class_wise(self):
        """Print and return the number of samples per class.

        Counts every loaded label, not just the first ``data_size``, so a
        stale hard-coded size cannot truncate the count or raise IndexError.
        The list has at least ``num_classes`` entries. ``torch.bincount``
        extends it if a label is ``>= num_classes``, and raises on negative
        labels.

        Returns:
            list[int]: ``counts[c]`` is the number of samples with label ``c``.
        """
        # reshape(-1) also accepts (N, 1)-shaped label tensors.
        flat_labels = self.labels.reshape(-1).long()
        counts = torch.bincount(flat_labels, minlength=self.num_classes).tolist()
        print(counts)
        return counts