import torch
import torch.distributions
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path
from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from .tiny_images_subset import get_80MTinyImages_subset
import os
import utils.datasets as dl
from utils.datasets.combo_dataset import ComboDataset
from .lsun_subset import get_LSUN_scenes_subset

from .cifar_subsets import CIFARSubset

def get_CIFAR10_subset_plus_OD(split, samples_per_class, od_dataset, od_samples,
                               batch_size=128, augm_type='default', shuffle=True,
                               cutout_window=16, size=32, num_workers=8, config_file=None):
    """Build a DataLoader over a CIFAR-10 subset combined with an out-of-distribution (OD) dataset.

    Args:
        split: Split name forwarded to ``CIFARSubset`` (its accepted values are
            defined there).
        samples_per_class: Size of the first CIFAR-10 subset split per class.
            Only 100 and 400 are supported.
        od_dataset: Which OD source to combine with. One of ``'lsun_subset'``,
            ``'tinyImageNet'``, ``'normalNoise'``, ``'uniformNoise'`` or
            ``'tinyImages_subset'``.
        od_samples: Requested number of OD samples. For ``'lsun_subset'`` it is
            divided by 10 and passed as ``samples_per_class``. For
            ``'tinyImageNet'`` it is not used.
        batch_size: Batch size of the returned loader (also forwarded to the OD
            loader factories).
        augm_type: Augmentation type forwarded to the CIFAR transform and to the
            OD loader factories.
        shuffle: Whether the returned loader shuffles. It is also forwarded to
            the ``'tinyImages_subset'`` factory.
        cutout_window: Cutout window size for the CIFAR augmentation.
        size: Output image size.
        num_workers: Number of worker processes of the returned loader.
        config_file: Optional dict. It is passed as ``config_dict`` to the OD
            factories, and a summary of this dataset is written into it.

    Returns:
        A ``torch.utils.data.DataLoader`` over
        ``ComboDataset((cifar_subset, od_dataset))``.

    Raises:
        NotImplementedError: If ``samples_per_class`` or ``od_dataset`` is not
            supported.
    """
    # Per-split sample counts for each supported ``samples_per_class``. Each
    # triple sums to 5000 (presumably CIFAR-10's per-class training count).
    # ``generate_idcs=False`` below suggests CIFARSubset relies on precomputed
    # indices for exactly these triples, so other values are rejected rather
    # than derived.
    supported_splits = {
        400: [400, 500, 4100],
        100: [100, 500, 4400],
    }

    # Lazy factories for the OD loaders; only the selected one is ever called.
    # Only ``.dataset`` of the returned loader is used. Batching and shuffling
    # are done by the combined loader built below.
    od_factories = {
        'lsun_subset': lambda: get_LSUN_scenes_subset(
            'train', samples_per_class=od_samples // 10,
            batch_size=batch_size, shuffle=True, augm_type=augm_type,
            augm_class='imagenet', size=size, config_dict=config_file),
        'tinyImageNet': lambda: dl.get_TinyImageNet(
            'test', batch_size=batch_size, shuffle=True, augm_type=augm_type,
            size=size, config_dict=config_file),
        'normalNoise': lambda: dl.get_noise_dataset(
            od_samples, type='normal', batch_size=batch_size, augm_type=augm_type,
            size=size, config_dict=config_file),
        'uniformNoise': lambda: dl.get_noise_dataset(
            od_samples, type='uniform', batch_size=batch_size, augm_type=augm_type,
            size=size, config_dict=config_file),
        'tinyImages_subset': lambda: get_80MTinyImages_subset(
            od_samples, batch_size=batch_size, augm_type=augm_type, shuffle=shuffle, size=size,
            exclude_cifar=True, exclude_cifar10_1=True, config_dict=config_file),
    }

    # Validate both selectors before building transforms or loading any data.
    if samples_per_class not in supported_splits:
        raise NotImplementedError(
            f'samples_per_class={samples_per_class!r} is not supported; '
            f'expected one of {sorted(supported_splits)}')
    if od_dataset not in od_factories:
        raise NotImplementedError(
            f'od_dataset={od_dataset!r} is not supported; '
            f'expected one of {sorted(od_factories)}')

    # Presumably get_cifar10_augmentation records its settings in augm_config
    # through config_dict. The dict is stored in config_file below.
    augm_config = {}
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size, config_dict=augm_config)

    samples_per_class_per_split = list(supported_splits[samples_per_class])
    dataset = CIFARSubset(split, 'CIFAR10', samples_per_class_per_split, transform, generate_idcs=False)

    od_loader = od_factories[od_dataset]()

    combo_dataset = ComboDataset((dataset, od_loader.dataset))
    loader = torch.utils.data.DataLoader(combo_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    if config_file is not None:
        # The OD factory may already have written into config_file. These keys
        # are assigned afterwards, so they take precedence.
        config_file['Dataset'] = f'CIFAR10-Subset + {od_dataset}'
        config_file['Batch size'] = batch_size
        config_file['Samples per class'] = samples_per_class
        config_file['OD samples'] = od_samples
        config_file['Augmentation'] = augm_config

    return loader
