import torch
import torch.distributions
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np
import os

from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from utils.datasets.paths import get_tiny_images_files
from utils.datasets.tinyImages import _load_cifar_exclusion_idcs, TINY_LENGTH

# Module-level default batch sizes for train / test loaders.
# NOTE(review): neither constant is referenced in the visible part of this file
# (get_80MTinyImages_subset uses its own default of batch_size=100) -- presumably
# they are imported elsewhere; confirm before changing or removing them.
DEFAULT_TRAIN_BATCHSIZE = 128
DEFAULT_TEST_BATCHSIZE = 128

def get_80MTinyImages_subset(num_samples, batch_size=100, augm_type='default', shuffle=True, cutout_window=16, num_workers=1,
                      size=32, exclude_cifar=False, exclude_cifar10_1=False, config_dict=None):
    """Build a DataLoader over a fixed-size subset of the 80M Tiny Images data.

    Args:
        num_samples: Number of images in the subset; forwarded to
            ``TinyImagesSubset``, which uses it to pick the split file.
        batch_size: Batch size of the returned loader.
        augm_type: Augmentation name forwarded to ``get_cifar10_augmentation``.
        shuffle: Whether the loader reshuffles the subset every epoch.
        cutout_window: Forwarded to ``get_cifar10_augmentation`` as
            ``cutout_window``.
        num_workers: Number of DataLoader worker processes.
        size: Forwarded to ``get_cifar10_augmentation`` as ``out_size``.
        exclude_cifar: Forwarded to ``TinyImagesSubset``.
        exclude_cifar10_1: Forwarded to ``TinyImagesSubset``.
        config_dict: Optional dict that is filled in place with a description
            of the loader configuration. Left untouched when ``None``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(image, 0)`` pairs.

    Raises:
        ValueError: Raised by ``TinyImagesSubset`` when the split file for the
            requested configuration does not exist (this function never asks
            the dataset to generate one).
    """
    # Separate dict handed to the augmentation factory; presumably it records
    # the augmentation settings there -- it is stored under 'Augmentation' below.
    augm_config = {}
    transform = get_cifar10_augmentation(augm_type, cutout_window=cutout_window, out_size=size, config_dict=augm_config)

    dataset_out = TinyImagesSubset(transform, num_samples,
                                    exclude_cifar=exclude_cifar, exclude_cifar10_1=exclude_cifar10_1)

    loader = torch.utils.data.DataLoader(dataset_out, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=num_workers)

    if config_dict is not None:
        config_dict['Dataset'] = '80M Tiny Images Subset'
        config_dict['Samples'] = num_samples
        config_dict['Shuffle'] = shuffle
        # NOTE(review): the key 'Batch out_size' looks like a search/replace
        # artifact of 'Batch size'; kept verbatim because consumers may read it.
        config_dict['Batch out_size'] = batch_size
        config_dict['Exclude CIFAR'] = exclude_cifar
        config_dict['Exclude CIFAR10.1'] = exclude_cifar10_1
        config_dict['Augmentation'] = augm_config

    return loader

def _generate_split(num_samples, exclude_cifar, exclude_cifar10_1, additional_excluded_idcs=None):
    """Draw a random subset of 80M Tiny Images indices, skipping excluded ones.

    Args:
        num_samples: Maximum number of indices to draw.
        exclude_cifar: Forwarded to ``_load_cifar_exclusion_idcs``.
        exclude_cifar10_1: Forwarded to ``_load_cifar_exclusion_idcs``.
        additional_excluded_idcs: Optional extra indices to remove from the
            pool before sampling.

    Returns:
        A 1-D integer tensor of distinct indices in random order. It holds
        ``num_samples`` entries, or fewer when fewer indices remain after the
        exclusions.
    """
    print(f'Generating 80M split - Exclude Cifar {exclude_cifar} - Exclude Cifar10_1 {exclude_cifar10_1}')
    exclusion_idcs = _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1)

    # A boolean mask needs one byte per image instead of the eight a long
    # tensor would take; with TINY_LENGTH entries that difference is large.
    available_mask = torch.ones(TINY_LENGTH, dtype=torch.bool)
    available_mask[exclusion_idcs] = False

    if additional_excluded_idcs is not None:
        available_mask[additional_excluded_idcs] = False

    # nonzero returns shape (N, 1); squeeze only that trailing axis so that a
    # single surviving index stays 1-D (a bare squeeze() would yield a 0-d
    # tensor and len() below would raise).
    available_samples = torch.nonzero(available_mask, as_tuple=False).squeeze(1)

    idcs = available_samples[torch.randperm(len(available_samples))[:num_samples]]
    return idcs

# Code from https://github.com/hendrycks/outlier-exposure
class TinyImagesSubset(Dataset):
    """Fixed subset of the 80M Tiny Images file, exposed as a torch Dataset.

    The subset is defined by an index tensor stored in a ``.pt`` file whose
    name encodes ``num_samples`` and the two exclusion flags. Every item is
    returned together with the constant label ``0``.
    """

    def __init__(self, transform_base, num_samples, exclude_cifar=False, exclude_cifar10_1=False, generate_split=False):
        """Open the image file and load (or create) the subset indices.

        Args:
            transform_base: Transform applied after conversion to a PIL image;
                when ``None``, ``transforms.ToTensor()`` is used instead.
            num_samples: Size of the subset; part of the split file name.
            exclude_cifar: Exclusion flag; part of the split file name and
                forwarded to ``_generate_split``.
            exclude_cifar10_1: Exclusion flag; part of the split file name and
                forwarded to ``_generate_split``.
            generate_split: If True, a missing split file is generated and
                saved; if False, a missing split file is an error.

        Raises:
            ValueError: The split file is missing and ``generate_split`` is
                False.
        """
        self.data_location = get_tiny_images_files(False)
        # Read-only memory map, viewed as one flat uint8 row per image.
        flat_bytes = np.memmap(self.data_location, mode='r', dtype='uint8', order='C')
        self.memap = flat_bytes.reshape(TINY_LENGTH, -1)

        filename = f'tiny_images_split_{num_samples}_cifar_{exclude_cifar}_cifar10_1_{exclude_cifar10_1}.pt'
        if not os.path.isfile(filename):
            if not generate_split:
                raise ValueError(f'Indices file {filename} could not be found and generate_split is set to False')
            # No cached split yet: draw one and persist it for later runs.
            subset_idcs = _generate_split(num_samples, exclude_cifar, exclude_cifar10_1)
            torch.save(subset_idcs, filename)
        else:
            subset_idcs = torch.load(filename)

        # The pipeline always starts from a PIL image; only the last stage varies.
        final_stage = transforms.ToTensor() if transform_base is None else transform_base
        self.transform = transforms.Compose([transforms.ToPILImage(), final_stage])
        self.exclude_cifar = exclude_cifar

        self.included_indices = subset_idcs
        self.length = len(subset_idcs)
        print(f'80M Tiny Images Subset- Length {self.length}')

    def __getitem__(self, ii):
        """Return ``(image, 0)`` for position ``ii`` of the subset."""
        row = self.memap[self.included_indices[ii]]
        # Each row is reshaped to 32x32x3 using Fortran (column-major) order.
        img = row.reshape(32, 32, 3, order="F")

        if self.transform is None:
            return img, 0  # 0 is the class
        return self.transform(img), 0  # 0 is the class

    def __len__(self):
        """Return the number of images in the subset."""
        return self.length



