import torch
import torch.distributions
from torch.utils.data import Dataset
from torchvision import datasets
from utils.datasets.paths import get_svhn_path
from utils.datasets.svhn_augmentation import get_SVHN_augmentation
from utils.datasets.tinyImages import TinyImagesDataset
from utils.datasets.combo_dataset import ComboDataset
from ssl_utils.svhn_validation_extra_split import SVHNValidationExtraSplit

# Fallback batch sizes used when no explicit batch_size is passed:
# the first for training-style splits, the second for all other splits.
DEFAULT_TRAIN_BATCHSIZE, DEFAULT_TEST_BATCHSIZE = 128, 128

def get_SVHNTinyCombo(split='extra-split', shuffle=None, batch_size=None, augm_type='none', num_workers=1,
                      config_dict=None):
    """Build a DataLoader over SVHN combined with the 80M Tiny Images dataset.

    The SVHN split and the Tiny Images dataset share the same augmentation
    transform and are wrapped together in a ``ComboDataset``.

    Args:
        split: SVHN split to load. ``'extra-split'`` is loaded through
            ``SVHNValidationExtraSplit``. Any other value is forwarded as the
            ``split`` argument of ``torchvision.datasets.SVHN``.
        shuffle: Whether the loader shuffles. When ``None``, defaults to ``True``
            for ``'train'``, ``'extra'`` and ``'extra-split'``, and ``False``
            otherwise.
        batch_size: Loader batch size. When ``None``, defaults to
            ``DEFAULT_TRAIN_BATCHSIZE`` for the same training-style splits and
            to ``DEFAULT_TEST_BATCHSIZE`` otherwise.
        augm_type: Augmentation name forwarded to ``get_SVHN_augmentation``.
        num_workers: Number of DataLoader workers. Values greater than 1 are
            rejected.
        config_dict: Optional dict. When given, it is updated in place with a
            summary of the loader configuration.

    Returns:
        A ``torch.utils.data.DataLoader`` over the combined dataset.

    Raises:
        ValueError: If ``num_workers > 1``.
    """
    # Multi-worker loading is disallowed; the error message attributes this to a
    # bug in the Tiny Images implementation.
    if num_workers > 1:
        raise ValueError('Bug in the current multithreaded tinyimages implementation')

    # Splits treated as training data for the shuffle / batch-size defaults.
    is_train_split = split in ('train', 'extra', 'extra-split')

    if batch_size is None:
        batch_size = DEFAULT_TRAIN_BATCHSIZE if is_train_split else DEFAULT_TEST_BATCHSIZE
    if shuffle is None:
        shuffle = is_train_split

    svhn_path = get_svhn_path()

    # Passed to get_SVHN_augmentation as its config_dict (presumably filled in
    # there) and reported below under 'Augmentation'.
    augm_config = {}
    transform = get_SVHN_augmentation(augm_type, config_dict=augm_config)

    if split == 'extra-split':
        svhn = SVHNValidationExtraSplit(svhn_path, 'extra-split', transform)
    else:
        svhn = datasets.SVHN(svhn_path, split=split, transform=transform)

    tiny_images = TinyImagesDataset(transform)
    dataset = ComboDataset([svhn, tiny_images])
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=num_workers)

    if config_dict is not None:
        config_dict['Dataset'] = 'SVHN + 80M Tiny Images'
        config_dict['SVHN Split'] = split
        config_dict['Shuffle'] = shuffle
        # NOTE(review): 'Batch out_size' looks like a rename artifact of 'Batch size';
        # kept unchanged because consumers of config_dict may rely on this exact key.
        config_dict['Batch out_size'] = batch_size
        config_dict['Augmentation'] = augm_config

    return loader
