import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os
import numpy as np
from config import data_raw_dir, data_dir
from glob import glob
import tarfile
import io
from tqdm import tqdm

class CustomDataset(Dataset):
    '''Dataset that loads the images listed in a split file.

    Every sample is returned with the constant label 0.
    '''

    def __init__(self, root, split, transform=None):
        '''Custom dataset for loading images from a list of paths
        Args:
            root (str): root directory the listed paths are joined onto
            split (str): path to a text file containing one image path per line
            transform (callable, optional): applied to every loaded image
        '''
        self.root = root
        self.transform = transform
        with open(f'{split}', 'r') as f:
            # Strip newline characters and surrounding whitespace.
            entries = [line.strip() for line in f]
        # Blank lines (e.g. a trailing empty line) are skipped: joined onto
        # `root` they would name the root directory itself and make
        # __getitem__ fail when it tries to open that entry as an image.
        self.data = [os.path.join(root, entry) for entry in entries if entry]

    def __len__(self):
        '''Returns the number of images in the dataset'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Returns the image at the given index
        Args:
            idx (int): index of the image
        Returns:
            image: the image converted to RGB, passed through `transform`
                when one was given
            label (int): always 0
        '''
        image = Image.open(self.data[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0
    
def inaturalist_dataloader(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the iNaturalist split
    Args:
        batch_size (int): batch size
        img_size (int): images are resized to (img_size, img_size)
        id_dataset (str): 'imagenet' selects test_inaturalist_2.txt, any
            other value selects test_inaturalist.txt
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader over the split
    '''
    split_name = 'test_inaturalist_2.txt' if id_dataset == 'imagenet' else 'test_inaturalist.txt'
    split_file = os.path.join(data_dir, 'splits', split_name)
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    images = CustomDataset(data_raw_dir, split_file, preprocess)
    return DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=0)

def ninco_dataloader(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the NINCO split
    Args:
        batch_size (int): batch size
        img_size (int): images are resized to (img_size, img_size)
        id_dataset (str): 'imagenet' selects test_ninco_2.txt, any other
            value selects test_ninco.txt
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader over the split
    '''
    if id_dataset == 'imagenet':
        split_name = 'test_ninco_2.txt'
    else:
        split_name = 'test_ninco.txt'
    pipeline = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    samples = CustomDataset(data_raw_dir, os.path.join(data_dir, 'splits', split_name), pipeline)
    return DataLoader(samples, batch_size=batch_size, shuffle=True, num_workers=0)

def ssbhard_dataloader(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the SSB-hard split
    Args:
        batch_size (int): batch size
        img_size (int): images are resized to (img_size, img_size)
        id_dataset (str): accepted but has no effect on the result
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader over the split
    '''
    # NOTE(review): every id_dataset value reads the same test_ssb_hard.txt
    # file; confirm whether an imagenet-specific split was intended.
    split_file = os.path.join(data_dir, 'splits', 'test_ssb_hard.txt')
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    samples = CustomDataset(data_raw_dir, split_file, preprocess)
    return DataLoader(samples, batch_size=batch_size, shuffle=True, num_workers=0)

def openimageo_dataloader(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the OpenImage-O split
    Args:
        batch_size (int): batch size
        img_size (int): images are resized to (img_size, img_size)
        id_dataset (str): 'imagenet' selects test_openimage_o_2.txt, any
            other value selects test_openimage_o.txt
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader over the split
    '''
    use_imagenet_split = id_dataset == 'imagenet'
    split_name = 'test_openimage_o_2.txt' if use_imagenet_split else 'test_openimage_o.txt'
    resize = transforms.Resize((img_size, img_size))
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    preprocess = transforms.Compose([resize, transforms.ToTensor(), normalize])
    images = CustomDataset(
        data_raw_dir,
        os.path.join(data_dir, 'splits', split_name),
        preprocess,
    )
    return DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=0)

def dtd_dataloader(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the DTD split
    Args:
        batch_size (int): batch size
        img_size (int): target side length, used only for the RGB pipeline
        id_dataset (str): 'cifar10', 'tinyimagenet' and 'imagenet' get RGB
            images of size (img_size, img_size); any other value gets
            1-channel grayscale images of size (32, 32). 'tinyimagenet'
            reads test_dtd_2.txt, 'imagenet' reads test_dtd_3.txt and
            everything else reads test_dtd.txt
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader over the split
    '''
    if id_dataset in ('cifar10', 'tinyimagenet', 'imagenet'):
        steps = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    else:
        steps = [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5)),
        ]
    preprocess = transforms.Compose(steps)

    split_name = (
        'test_dtd_2.txt' if id_dataset == 'tinyimagenet'
        else 'test_dtd_3.txt' if id_dataset == 'imagenet'
        else 'test_dtd.txt'
    )
    textures = CustomDataset(data_raw_dir, os.path.join(data_dir, 'splits', split_name), preprocess)
    return DataLoader(textures, batch_size=batch_size, shuffle=True, num_workers=0)

def places365_dataloader(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the Places365 split
    Args:
        batch_size (int): batch size
        img_size (int): target side length, used only for the RGB pipeline
        id_dataset (str): 'cifar10', 'tinyimagenet' and 'imagenet' get RGB
            images of size (img_size, img_size); any other value gets
            1-channel grayscale images of size (32, 32). 'imagenet' reads
            test_places365_2.txt, everything else reads test_places365.txt
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader over the split
    '''
    # 'imagenet' has its own split file below, so it must get the RGB
    # pipeline too (as in dtd_dataloader); previously it fell through to
    # the 1-channel 32x32 grayscale branch.
    if id_dataset in ('cifar10', 'tinyimagenet', 'imagenet'):
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5))
        ])

    split_name = 'test_places365_2.txt' if id_dataset == 'imagenet' else 'test_places365.txt'
    dataset = CustomDataset(data_raw_dir, os.path.join(data_dir, 'splits', split_name), transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return dataloader

class MNISTC(Dataset):
    '''Dataset over the test images of one MNIST-C corruption.

    Every sample is returned with the constant label 0.
    '''

    def __init__(self, root, corruption, transform=None):
        '''Loads <root>/mnist_c/<corruption>/test_images.npy into memory
        Args:
            root (str): directory containing the mnist_c folder
            corruption (str): name of the corruption sub-folder
            transform (callable, optional): applied to every image
        '''
        self.root = root
        self.transform = transform
        images_file = os.path.join(root, 'mnist_c', f'{corruption}', 'test_images.npy')
        # squeeze() drops every length-1 axis of the stored array.
        self.data = np.load(images_file).squeeze()

    def __len__(self):
        '''Returns the number of images'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Returns (image, 0) for the array stored at position idx'''
        sample = Image.fromarray(self.data[idx])
        if not self.transform:
            return sample, 0
        return self.transform(sample), 0

def mnistc_dataloader(batch_size, corruption):
    '''Returns a shuffled dataloader over one MNIST-C corruption
    Args:
        batch_size (int): batch size
        corruption (str): name of the corruption to load
    Returns:
        dataloader (torch.utils.data.DataLoader): images resized to (32, 32)
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),
    ])
    corrupted = MNISTC(data_raw_dir, corruption, preprocess)
    return DataLoader(corrupted, batch_size=batch_size, shuffle=True, num_workers=0)

class ImageNetC(Dataset):
    '''Dataset over one ImageNet-C corruption at one intensity.

    All matching images are decoded when the dataset is constructed; every
    sample is returned with the constant label 0.
    '''

    def __init__(self, root, corruption, intensity, transform=None):
        '''Decodes <root>/ImageNet-C/<corruption>/<intensity>/*/*.JPEG
        Args:
            root (str): directory containing the ImageNet-C folder
            corruption (str): name of the corruption sub-folder
            intensity: name of the intensity sub-folder
            transform (callable, optional): applied to every image
        '''
        self.root = root
        self.transform = transform
        pattern = os.path.join(root, 'ImageNet-C', f'{corruption}', f'{intensity}', '*', '*.JPEG')
        # The dataset keeps the decoded RGB images, not the file paths.
        self.data = [Image.open(path).convert('RGB') for path in glob(pattern)]

    def __len__(self):
        '''Returns the number of images'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Returns (image, 0) for the image stored at position idx'''
        sample = self.data[idx]
        if not self.transform:
            return sample, 0
        return self.transform(sample), 0
    
def imagenetc_dataloader(batch_size, corruption, intensity):
    '''Returns a shuffled dataloader over one ImageNet-C corruption
    Args:
        batch_size (int): batch size
        corruption (str): name of the corruption to load
        intensity: intensity level to load
    Returns:
        dataloader (torch.utils.data.DataLoader): images resized to (128, 128)
    '''
    preprocess = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    corrupted = ImageNetC(data_raw_dir, corruption, intensity, preprocess)
    return DataLoader(corrupted, batch_size=batch_size, shuffle=True, num_workers=0)

class TinyImageNetC(Dataset):
    '''Dataset over one Tiny-ImageNet-C corruption at one intensity.

    Only the file paths are collected up front; each image is opened on
    access and returned with the constant label 0.
    '''

    def __init__(self, root, corruption, intensity, transform=None):
        '''Collects <root>/Tiny-ImageNet-C/<corruption>/<intensity>/*/*.JPEG
        Args:
            root (str): directory containing the Tiny-ImageNet-C folder
            corruption (str): name of the corruption sub-folder
            intensity: name of the intensity sub-folder
            transform (callable, optional): applied to every image
        '''
        self.root = root
        self.transform = transform
        pattern = os.path.join(root, 'Tiny-ImageNet-C', f'{corruption}', f'{intensity}', '*', '*.JPEG')
        self.data = glob(pattern)

    def __len__(self):
        '''Returns the number of image paths found'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Returns (image, 0) for the file at position idx, converted to RGB'''
        sample = Image.open(self.data[idx]).convert('RGB')
        if not self.transform:
            return sample, 0
        return self.transform(sample), 0
    
def tinyimagenetc_dataloader(batch_size, corruption, intensity):
    '''Returns a shuffled dataloader over one Tiny-ImageNet-C corruption
    Args:
        batch_size (int): batch size
        corruption (str): name of the corruption to load
        intensity: intensity level to load
    Returns:
        dataloader (torch.utils.data.DataLoader): images resized to (64, 64)
    '''
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    corrupted = TinyImageNetC(data_raw_dir, corruption, intensity, preprocess)
    return DataLoader(corrupted, batch_size=batch_size, shuffle=True, num_workers=0)

class CIFAR10C(Dataset):
    '''Dataset over one CIFAR-10-C corruption at one intensity.

    Every sample is returned with the constant label 0.
    '''

    def __init__(self, root, corruption, intensity, transform=None):
        '''Loads one 10000-row slice of <root>/CIFAR-10-C/<corruption>.npy
        Args:
            root (str): directory containing the CIFAR-10-C folder
            corruption (str): file name (without .npy) of the corruption
            intensity (int): selects rows (intensity-1)*10000 up to
                intensity*10000 of the stored array
            transform (callable, optional): applied to every image
        '''
        self.root = root
        self.transform = transform
        corrupted = np.load(os.path.join(root, 'CIFAR-10-C', f'{corruption}.npy'))
        first_row = (intensity - 1) * 10000
        self.data = corrupted[first_row:first_row + 10000]

    def __len__(self):
        '''Returns the number of images in the selected slice'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Returns (image, 0) for the array stored at position idx'''
        sample = Image.fromarray(self.data[idx])
        if not self.transform:
            return sample, 0
        return self.transform(sample), 0

def cifar10c_dataloader(batch_size, corruption, intensity):
    '''Returns a shuffled dataloader over one CIFAR-10-C corruption
    Args:
        batch_size (int): batch size
        corruption (str): name of the corruption to load
        intensity (int): intensity level to load
    Returns:
        dataloader (torch.utils.data.DataLoader): images resized to (32, 32)
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    corrupted = CIFAR10C(data_raw_dir, corruption, intensity, preprocess)
    return DataLoader(corrupted, batch_size=batch_size, shuffle=True, num_workers=0)

def mnist_dataloader_train(batch_size, img_size):
    '''Returns a shuffled dataloader over the MNIST training set
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
    Returns:
        dataloader (torch.utils.data.DataLoader): MNIST training dataloader
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),
    ])
    train_set = datasets.MNIST(data_raw_dir, train=True, transform=preprocess, download=True)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)

def mnist_dataloader_test(batch_size, img_size, in_dataset='mnist'):
    '''Returns a shuffled dataloader over the MNIST test set
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
        in_dataset (str): for any value other than 'mnist' the normalized
            tensor is additionally repeated 3 times along its first axis
    Returns:
        dataloader (torch.utils.data.DataLoader): MNIST test dataloader
    '''
    steps = [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),
    ]
    if not (in_dataset == 'mnist'):
        steps.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))
    preprocess = transforms.Compose(steps)

    test_set = datasets.MNIST(data_raw_dir, train=False, transform=preprocess, download=True)
    return DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

def fashionmnist_dataloader(batch_size, img_size):
    '''Returns a shuffled dataloader over the FashionMNIST test set
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
    Returns:
        dataloader (torch.utils.data.DataLoader): FashionMNIST test dataloader
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),
    ])
    test_set = datasets.FashionMNIST(data_raw_dir, train=False, transform=preprocess, download=True)
    return DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

def cifar10_dataloader_train(batch_size, img_size):
    '''Returns a shuffled dataloader over the CIFAR-10 training set
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
    Returns:
        dataloader (torch.utils.data.DataLoader): CIFAR-10 training dataloader
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_set = datasets.CIFAR10(data_raw_dir, train=True, transform=preprocess, download=True)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)

def cifar10_dataloader_test(batch_size, img_size, in_dataset='mnist'):
    '''Returns a shuffled dataloader over the CIFAR-10 test set
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
        in_dataset (str): 'mnist' converts the images to 1-channel
            grayscale; any other value keeps 3-channel normalization
    Returns:
        dataloader (torch.utils.data.DataLoader): CIFAR-10 test dataloader
    '''
    if in_dataset == 'mnist':
        steps = [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5)),
        ]
    else:
        steps = [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    preprocess = transforms.Compose(steps)

    test_set = datasets.CIFAR10(data_raw_dir, train=False, transform=preprocess, download=True)
    return DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

def cifar100_dataloader(batch_size, img_size):
    '''Returns a shuffled dataloader over the CIFAR-100 test set
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
    Returns:
        dataloader (torch.utils.data.DataLoader): CIFAR-100 test dataloader
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_set = datasets.CIFAR100(data_raw_dir, train=False, transform=preprocess, download=True)
    return DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

def svhn_dataloader(batch_size, img_size):
    '''Returns a shuffled dataloader over the SVHN test split
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (32, 32)
    Returns:
        dataloader (torch.utils.data.DataLoader): SVHN test dataloader
    '''
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_set = datasets.SVHN(data_raw_dir, split='test', transform=preprocess, download=True)
    return DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

class TinyImageNet(Dataset):
    '''Dataset over the Tiny ImageNet train or test images.

    Only the file paths are collected up front; each image is opened on
    access and returned with the constant label 0.
    '''

    def __init__(self, root, train=False, transform=None):
        '''Collects the JPEG paths of the requested subset
        Args:
            root (str): directory containing the tinyimagenet folder
            train (bool): True globs tinyimagenet/train/*/images/*.JPEG,
                False globs tinyimagenet/test/images/*.JPEG
            transform (callable, optional): applied to every image
        '''
        self.root = root
        self.transform = transform
        base = os.path.join(root, 'tinyimagenet')
        pattern = (
            os.path.join(base, 'train', '*', 'images', '*.JPEG') if train
            else os.path.join(base, 'test', 'images', '*.JPEG')
        )
        self.data = glob(pattern)

    def __len__(self):
        '''Returns the number of image paths found'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Returns (image, 0) for the file at position idx, converted to RGB'''
        sample = Image.open(self.data[idx]).convert('RGB')
        if not self.transform:
            return sample, 0
        return self.transform(sample), 0

def tinyimagenet_dataloader_train(batch_size, img_size):
    '''Returns a shuffled dataloader over the Tiny ImageNet training images
    Args:
        batch_size (int): batch size
        img_size (int): not used; images are always resized to (64, 64)
    Returns:
        dataloader (torch.utils.data.DataLoader): Tiny ImageNet training dataloader
    '''
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_set = TinyImageNet(data_raw_dir, train=True, transform=preprocess)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)

def tinyimagenet_dataloader_test(batch_size, img_size, id_dataset='cifar10'):
    '''Returns a shuffled dataloader over the Tiny ImageNet test images
    Args:
        batch_size (int): batch size
        img_size (int): target side length, used only for the RGB pipeline
        id_dataset (str): 'cifar10' and 'tinyimagenet' get RGB images of
            size (img_size, img_size); any other value gets 1-channel
            grayscale images of size (32, 32)
    Returns:
        dataloader (torch.utils.data.DataLoader): Tiny ImageNet test dataloader
    '''
    if id_dataset in ('cifar10', 'tinyimagenet'):
        steps = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    else:
        steps = [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5)),
        ]
    preprocess = transforms.Compose(steps)

    test_set = TinyImageNet(data_raw_dir, train=False, transform=preprocess)
    return DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

class ImageNetDataset(Dataset):
    '''ImageNet images read from tar archives and kept decoded in memory.

    The images to load are listed in a split file under `data_dir`; every
    sample is returned with the constant label 0.
    '''

    def __init__(self, root, train = False, transform=None):
        '''Decodes every image named in the split file
        Args:
            root (str): directory containing imagenet/train/ (one tar file
                per entry) and imagenet/test.tar
            train (bool): True reads splits/train_images.txt and the train
                archives, False reads splits/test_images.txt and test.tar
            transform (callable, optional): applied to every image on access
        '''
        self.root = root
        self.transform = transform
        self.train = train
        self.imgs = []
        if train:
            with open(os.path.join(data_dir,'splits', 'train_images.txt'), 'r') as f:
                names = [line.strip() for line in f]
            train_dir = os.path.join(root, 'imagenet', 'train')
            archives = {}
            try:
                # Archives are keyed by file name minus its last 4 characters
                # (the extension); the second-to-last path component of each
                # listed image selects the archive it is read from.
                for fname in tqdm(os.listdir(train_dir), desc='Opening tar files'):
                    archives[f"{fname[:-4]}"] = tarfile.open(os.path.join(train_dir, fname))
                # Directory part of the first member of the first image's
                # archive; it is prepended to every listed image path.
                first_member = archives[names[0].split('/')[-2]].getnames()[0]
                self.prefix = '/'.join(first_member.split('/')[:-1])
                # load all images to memory
                self.imgs = [
                    self._decode(archives[name.split('/')[-2]], f'{self.prefix}/{name}')
                    for name in tqdm(names, desc='Loading images')
                ]
            finally:
                # Close every archive explicitly, including when loading fails.
                for archive in archives.values():
                    archive.close()
        else:
            with open(os.path.join(data_dir,'splits', 'test_images.txt'), 'r') as f:
                names = [line.strip() for line in f]
            # The context manager closes the archive even if decoding raises.
            with tarfile.open(os.path.join(root, 'imagenet', 'test.tar')) as archive:
                self.prefix = archive.getnames()[0]
                # load all images to memory
                self.imgs = [
                    self._decode(archive, f'{self.prefix}/{name}')
                    for name in tqdm(names, desc='Loading images')
                ]

    @staticmethod
    def _decode(archive, member):
        '''Reads `member` from the open tar `archive` and returns it as an RGB image'''
        return Image.open(io.BytesIO(archive.extractfile(member).read())).convert('RGB')

    def __len__(self):
        '''Returns the number of loaded images'''
        return len(self.imgs)

    def __getitem__(self, idx):
        '''Returns (image, 0) for the image stored at position idx'''
        img = self.imgs[idx]
        if self.transform:
            img = self.transform(img)
        return img, 0
    
def imagenet_train_loader(batch_size, input_shape = None):
    '''Returns a shuffled dataloader over the ImageNet training images
    Args:
        batch_size (int): batch size
        input_shape (int, optional): images are resized to
            (input_shape, input_shape); None means (128, 128)
    Returns:
        training_loader (torch.utils.data.DataLoader): created with pin_memory=True
    '''
    if input_shape is not None:
        resize = transforms.Resize((input_shape, input_shape))
    else:
        resize = transforms.Resize((128, 128))
    preprocess = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_set = ImageNetDataset(root=data_raw_dir, train=True, transform=preprocess)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
        
def imagenet_val_loader(batch_size, input_shape = None):
    '''Returns a shuffled dataloader over the ImageNet test images
    Args:
        batch_size (int): batch size
        input_shape (int, optional): images are resized to
            (input_shape, input_shape); None means (128, 128)
    Returns:
        validation_loader (torch.utils.data.DataLoader): created with pin_memory=True
    '''
    if input_shape is not None:
        resize = transforms.Resize((input_shape, input_shape))
    else:
        resize = transforms.Resize((128, 128))
    preprocess = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    val_set = ImageNetDataset(root=data_raw_dir, train=False, transform=preprocess)
    return DataLoader(val_set, batch_size=batch_size, shuffle=True, pin_memory=True)

def pick_dataset(name='cifar10', train=True, batch_size=64, img_size=32, id_dataset='cifar10'):
    '''Returns a dataloader for the given dataset
    Args:
        name (str): name of the dataset
        train (bool): whether to load the training or test set; only used
            for 'mnist', 'cifar10', 'tinyimagenet' and 'imagenet'
        batch_size (int): batch size
        img_size (int): size of the images
        id_dataset (str): forwarded to every loader that accepts it, where
            it selects the preprocessing and/or the split file; not used
            for 'fashionmnist', 'cifar100', 'svhn' and 'imagenet', nor for
            the training loaders
    Returns:
        dataloader (torch.utils.data.DataLoader): dataloader for the given dataset
    Raises:
        ValueError: if name is not a known dataset
    '''
    if name == 'mnist':
        if train:
            return mnist_dataloader_train(batch_size, img_size)
        return mnist_dataloader_test(batch_size, img_size, in_dataset=id_dataset)
    elif name == 'fashionmnist':
        return fashionmnist_dataloader(batch_size, img_size)
    elif name == 'cifar10':
        if train:
            return cifar10_dataloader_train(batch_size, img_size)
        return cifar10_dataloader_test(batch_size, img_size, in_dataset=id_dataset)
    elif name == 'cifar100':
        return cifar100_dataloader(batch_size, img_size)
    elif name == 'tinyimagenet':
        if train:
            return tinyimagenet_dataloader_train(batch_size, img_size)
        return tinyimagenet_dataloader_test(batch_size, img_size, id_dataset=id_dataset)
    elif name == 'imagenet':
        if train:
            return imagenet_train_loader(batch_size, input_shape=img_size)
        return imagenet_val_loader(batch_size, input_shape=img_size)
    elif name == 'dtd':
        return dtd_dataloader(batch_size, img_size, id_dataset=id_dataset)
    elif name == 'places365':
        return places365_dataloader(batch_size, img_size, id_dataset=id_dataset)
    # The four loaders below choose their split file from id_dataset, so it
    # has to be forwarded; without it their 'imagenet' splits were
    # unreachable through this function.
    elif name == 'inaturalist':
        return inaturalist_dataloader(batch_size, img_size, id_dataset=id_dataset)
    elif name == 'ninco':
        return ninco_dataloader(batch_size, img_size, id_dataset=id_dataset)
    elif name == 'ssbhard' or name == 'ssb-hard':
        return ssbhard_dataloader(batch_size, img_size, id_dataset=id_dataset)
    elif name == 'openimageo':
        return openimageo_dataloader(batch_size, img_size, id_dataset=id_dataset)
    elif name == 'svhn':
        return svhn_dataloader(batch_size, img_size)
    else:
        raise ValueError(f'Unknown dataset: {name}')