import os
import tqdm
import json
import torch
import numpy as np
from PIL import Image
import scipy.io as sio
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset


def get_transform():
    """Build the train and validation preprocessing pipelines.

    Both pipelines are identical: resize to 224x224, convert to a tensor,
    then normalize with ImageNet channel statistics. No augmentation is
    applied to either one.

    Returns:
        (train_transform, val_transform): two independent
        ``transforms.Compose`` objects.
    """
    # ImageNet channel statistics, matching the ImageNet-pretrained backbones.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def _make_pipeline():
        # A fresh Compose per call, so the two returned objects are distinct.
        steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        return transforms.Compose(steps)

    train_pipeline = _make_pipeline()
    val_pipeline = _make_pipeline()
    return train_pipeline, val_pipeline


def get_train_val_dataset(
    root_path='/mnt/bn/benchmark-dataset/TIP_SSOD/data/ID/cifar-10/',
    download=True,
):
    """Load the CIFAR-10 train and test splits with the preprocessing from get_transform().

    Args:
        root_path: directory holding (or receiving) the CIFAR-10 files.
            Defaults to the original hard-coded location.
        download: forwarded to ``torchvision.datasets.CIFAR10``. When True,
            the archive is fetched if it is not already present.

    Returns:
        (train_set, val_set): ``CIFAR10(train=True)`` with the train
        transform, and ``CIFAR10(train=False)`` (the official test split)
        with the validation transform.
    """
    train_transform, val_transform = get_transform()
    train_set = datasets.CIFAR10(root=root_path, download=download, train=True, transform=train_transform)
    val_set = datasets.CIFAR10(root=root_path, download=download, train=False, transform=val_transform)
    # To train on both splits instead, wrap them in
    # torch.utils.data.ConcatDataset([train_set, val_set]) at the call site.
    return train_set, val_set

def get_svhn_dataset(
    data_path='./data/OOD/As_CIFAR10_OOD/classicOOD/SVHN/selected_test_svhn_32x32.mat',
):
    """Load the selected SVHN test images from a ``.mat`` file as an unlabeled dataset.

    Args:
        data_path: path to the ``.mat`` file containing an ``'X'`` image
            array. Defaults to the original hard-coded location.

    Returns:
        An ``SVHNDataset`` that yields transformed images only. Any labels
        stored in the file are not read.
    """
    loaded_mat = sio.loadmat(data_path)
    # Move axis 3 to the front so that data[i] is one image. This presumes
    # 'X' is stored as (H, W, C, N), as in the standard SVHN .mat release
    # -- verify for other files.
    data = np.transpose(loaded_mat['X'], (3, 0, 1, 2))
    return SVHNDataset(data)


class SVHNDataset(Dataset):
    """Unlabeled dataset over an in-memory array of images.

    Args:
        data: array indexed by sample first. Each ``data[i]`` must be
            accepted by ``PIL.Image.fromarray`` (e.g. a uint8 HxWxC array).
        transform: optional callable applied to each RGB image. When None
            (the default), the validation transform from ``get_transform()``
            is used, matching the previous behavior.
    """

    def __init__(self, data, transform=None):
        self.data = data
        if transform is None:
            _, transform = get_transform()
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the transformed image at ``index``; no label is returned."""
        img = Image.fromarray(self.data[index]).convert('RGB')
        return self.transform(img)


class MyDataset(Dataset):
    """Dataset over a list of image file paths, optionally paired with labels.

    Args:
        names: if ``id`` is True, a sequence of ``(path, label)`` pairs;
            otherwise a sequence of paths.
        transform: callable applied to each RGB ``PIL.Image``.
        id: whether ``names`` carries labels. The parameter name shadows the
            builtin ``id`` but is kept for backward compatibility with callers.
    """

    def __init__(self, names, transform, id=True):
        self.names = names
        self.transform = transform
        self.label = id

    def __len__(self):
        return len(self.names)

    def _load(self, img_name):
        """Open ``img_name``, convert it to RGB, and apply ``self.transform``."""
        # convert() loads the pixels into memory. The context manager then
        # closes the underlying file right away; otherwise the handle stays
        # open until garbage collection, which may exhaust file descriptors
        # in DataLoader workers.
        with Image.open(img_name) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

    def __getitem__(self, index):
        """Return ``(image, int(label))`` when labeled, otherwise just ``image``."""
        if self.label:
            img_name, label = self.names[index]
            return self._load(img_name), int(label)
        return self._load(self.names[index])


def get_cifar_ood_dataset(ood_type):
    """Build one of the classic OOD test sets used against CIFAR-10.

    Args:
        ood_type: one of 'iSUN', 'LSUN', 'Places', 'SVHN', 'Texture'.

    Returns:
        An unlabeled dataset yielding transformed images: ``SVHNDataset``
        for 'SVHN', otherwise ``MyDataset(..., id=False)`` over the image
        files using the validation transform.

    Raises:
        ValueError: if ``ood_type`` is not one of the supported names.
            Previously an unknown name crashed with UnboundLocalError.
    """
    ood_root = './data/OOD/As_CIFAR10_OOD/classicOOD/'
    # OOD sets stored as a single flat directory of image files.
    flat_dirs = {
        'iSUN': ood_root + 'iSUN/',
        'LSUN': ood_root + 'LSUN_crop/',
        'Places': ood_root + 'Places365/',
    }

    if ood_type == 'SVHN':
        return get_svhn_dataset()

    if ood_type in flat_dirs:
        ood_path = flat_dirs[ood_type]
        # os.listdir order is arbitrary; sorting makes the sample order reproducible.
        ood_names = [os.path.join(ood_path, name) for name in sorted(os.listdir(ood_path))]
    elif ood_type == 'Texture':
        # Texture images sit one level down, inside subfolders of ood_path.
        ood_path = ood_root + 'Texture/images/'
        ood_names = []
        for folder in sorted(os.listdir(ood_path)):
            abs_folder = os.path.join(ood_path, folder)
            if not os.path.isdir(abs_folder):
                continue  # skip stray files (e.g. .DS_Store) next to the subfolders
            ood_names.extend(
                os.path.join(abs_folder, name) for name in sorted(os.listdir(abs_folder))
            )
    else:
        valid = ', '.join(sorted([*flat_dirs, 'SVHN', 'Texture']))
        raise ValueError(f'Unknown ood_type {ood_type!r}; expected one of: {valid}')

    _, val_transform = get_transform()
    return MyDataset(ood_names, val_transform, id=False)
    

if __name__ == '__main__':
    # Smoke test: iterate one OOD set and print each batch's shape and value range.
    # Usage: python <this file> [iSUN|LSUN|Places|SVHN|Texture]   (default: Texture)
    import sys

    batch_size = 64
    ood_type = sys.argv[1] if len(sys.argv) > 1 else 'Texture'

    # check ood dataset
    ood_set = get_cifar_ood_dataset(ood_type=ood_type)
    ood_loader = DataLoader(ood_set, batch_size=batch_size, shuffle=False, num_workers=3)
    for x in tqdm.tqdm(ood_loader):
        print(x.shape, torch.min(x), torch.max(x))

    # To check the in-distribution CIFAR-10 splits instead:
    # train_set, val_set = get_train_val_dataset()
    # print(len(train_set), len(val_set))

    # train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=3)
    # for x, y in tqdm.tqdm(train_loader):
    #     print(x.shape, y.shape, torch.min(x), torch.max(x))

    # val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=3)
    # for x, y in tqdm.tqdm(val_loader):
    #     print(x.shape, y.shape, torch.min(x), torch.max(x))