import torch
import torch.distributions
from torchvision import datasets, transforms
from torchvision.datasets.vision import VisionDataset
from torch.utils.data import DataLoader, Dataset
from utils.datasets.preproc import Gray
from utils.datasets.semisupervised_dataset import SemiSupervisedDataset, SemiSupervisedSampler
import numpy as np
#from .auto_augmen_old import AutoAugment
from utils.datasets.autoaugment import CIFAR10Policy
from utils.datasets.cutout import Cutout
from utils.datasets.utils import GaussianSmoothing
from PIL import Image
import pickle

from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path, get_base_data_dir
from utils.datasets.cifar_augmentation import get_cifar10_augmentation

# NOTE(review): presumably the validation / train fractions used when the
# split files were generated (they sum to 1.0); neither constant is referenced
# in this section -- confirm usage before relying on them.
VAL_RATIO = 0.2
TRAIN_RATIO = 0.8

# Fallback batch sizes used by the loader builders below when batch_size is
# None: DEFAULT_TRAIN_BATCHSIZE for the train split, DEFAULT_TEST_BATCHSIZE
# otherwise.
DEFAULT_TRAIN_BATCHSIZE = 128
DEFAULT_TEST_BATCHSIZE = 128

def _generate_split(labels, val_ratio, num_classes):
    labels_tensor = torch.LongTensor(labels)

    samples_per_class = int(len(labels) / num_classes)
    assert len(labels) % num_classes == 0
    val_per_class = int(samples_per_class * val_ratio)
    train_per_class = int(samples_per_class - val_per_class)

    print(f'Samples per class {samples_per_class} - Train {train_per_class} - Validation {val_per_class}')

    train_idcs = torch.zeros(train_per_class * num_classes, dtype=torch.long)
    val_idcs = torch.zeros(val_per_class * num_classes, dtype=torch.long)

    for class_idx in range(num_classes):
        class_idcs = torch.nonzero(labels_tensor == class_idx, as_tuple=False).squeeze()
        assert class_idcs.shape[0] == samples_per_class

        shuffled_idcs = class_idcs[torch.randperm(samples_per_class)]
        train_idcs[class_idx*train_per_class:(class_idx+1)*train_per_class] = shuffled_idcs[:train_per_class]
        val_idcs[class_idx*val_per_class:(class_idx+1)*val_per_class] = shuffled_idcs[train_per_class:]

    #Validate:
    train_labels = labels_tensor[train_idcs]
    validation_labels = labels_tensor[val_idcs]

    for class_idx in range(num_classes):
        assert torch.sum(train_labels == class_idx) == train_per_class
        assert torch.sum(validation_labels == class_idx) == val_per_class

    print('Split generation completed')

    return train_idcs, val_idcs

class CIFAR10TrainValidationSplit(torch.utils.data.Dataset):
    """Train or validation subset carved out of the CIFAR-10 *training* set.

    The subset is described by a precomputed index collection read with
    ``torch.load`` from ``cifar10_train_split.pth`` (``train`` truthy) or
    ``cifar10_validation_split.pth`` (otherwise); both paths are relative, so
    they resolve against the current working directory.

    Args:
        path: Root directory given to ``datasets.CIFAR10`` (download=True).
        train: Picks the train split file when truthy, the validation one otherwise.
        transform: Transform forwarded to the underlying CIFAR10 dataset.
    """

    def __init__(self, path, train, transform):
        self.cifar = datasets.CIFAR10(path, train=True, transform=transform, download=True)

        if not train:
            self.idcs = torch.load('cifar10_validation_split.pth')
            print(f'Cifar10 Validation split - Length {len(self.idcs)}')
        else:
            self.idcs = torch.load('cifar10_train_split.pth')
            print(f'Cifar10 Train split - Length {len(self.idcs)}')

        # Labels of the selected samples, in split order.
        self.targets = [self.cifar.targets[sample_idx] for sample_idx in self.idcs]
        self.length = len(self.idcs)

    def __getitem__(self, index):
        # Translate the split-local position into an index of the full training set.
        return self.cifar[self.idcs[index]]

    def __len__(self):
        return self.length

def get_CIFAR10TrainValidation(train=True, batch_size=None, shuffle=None, augm_type='none', cutout_window=16, num_workers=2, size=32, config_dict=None):
    """Build a DataLoader over the CIFAR-10 train or validation split.

    Both splits come from the CIFAR-10 *training* set via
    ``CIFAR10TrainValidationSplit``.

    Args:
        train: Load the train split if truthy, the validation split otherwise.
        batch_size: Batch size; ``None`` selects DEFAULT_TRAIN_BATCHSIZE for
            the train split and DEFAULT_TEST_BATCHSIZE otherwise.
        shuffle: Whether the loader shuffles; ``None`` means "shuffle iff train".
        augm_type: Forwarded to ``get_cifar10_augmentation`` as ``type``.
        cutout_window: Forwarded to ``get_cifar10_augmentation``.
        num_workers: Number of DataLoader worker processes.
        size: Forwarded to ``get_cifar10_augmentation`` as ``out_size``.
        config_dict: Optional dict that receives the dataset name, batch size
            and augmentation configuration.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    if batch_size is None:
        batch_size = DEFAULT_TRAIN_BATCHSIZE if train else DEFAULT_TEST_BATCHSIZE

    # augm_config is handed to the augmentation builder as its config_dict
    # (presumably filled with its settings) and reported in config_dict below.
    augm_config = {}
    transform = get_cifar10_augmentation(type=augm_type, cutout_window=cutout_window, out_size=size, config_dict=augm_config)
    # NOTE(review): with train=False this is the validation split of the
    # training set, not the test set; the message text is kept unchanged.
    if not train and augm_type != 'none':
        print('Warning: CIFAR10 test set with ref_data augmentation')

    if shuffle is None:
        shuffle = train

    path = get_CIFAR10_path()
    dataset = CIFAR10TrainValidationSplit(path, train=train, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=num_workers)

    if config_dict is not None:
        config_dict['Dataset'] = 'Cifar10TrainValidationSplit'
        # NOTE(review): 'Batch out_size' looks like a rename artifact of
        # 'Batch size'; kept as-is because consumers may read this key.
        config_dict['Batch out_size'] = batch_size
        config_dict['Augmentation'] = augm_config

    return loader



class CIFAR100TrainValidationSplit(torch.utils.data.Dataset):
    """Train or validation subset carved out of the CIFAR-100 *training* set.

    The subset is described by a precomputed index collection read with
    ``torch.load`` from ``cifar100_train_split.pth`` (``train`` truthy) or
    ``cifar100_validation_split.pth`` (otherwise); both paths are relative, so
    they resolve against the current working directory.

    Args:
        path: Root directory given to ``datasets.CIFAR100``.
        train: Picks the train split file when truthy, the validation one otherwise.
        transform: Transform forwarded to the underlying CIFAR100 dataset.
    """

    def __init__(self, path, train, transform):
        # download=True mirrors CIFAR10TrainValidationSplit: without it,
        # torchvision raises when the dataset files are missing under `path`
        # instead of fetching them.
        self.cifar = datasets.CIFAR100(path, train=True, transform=transform, download=True)
        if train:
            self.idcs = torch.load('cifar100_train_split.pth')
            print(f'Cifar100 Train split - Length {len(self.idcs)}')
        else:
            self.idcs = torch.load('cifar100_validation_split.pth')
            print(f'Cifar100 Validation split - Length {len(self.idcs)}')

        # Labels of the selected samples, in split order.
        self.targets = [self.cifar.targets[idx] for idx in self.idcs]

        self.length = len(self.idcs)

    def __getitem__(self, index):
        # Translate the split-local position into an index of the full training set.
        cifar_idx = self.idcs[index]
        return self.cifar[cifar_idx]

    def __len__(self):
        return self.length


def get_CIFAR100TrainValidation(train=True, batch_size=None, shuffle=None, augm_type='none', cutout_window=16, num_workers=2, size=32, config_dict=None):
    """Build a DataLoader over the CIFAR-100 train or validation split.

    Both splits come from the CIFAR-100 *training* set via
    ``CIFAR100TrainValidationSplit``.

    Args:
        train: Load the train split if truthy, the validation split otherwise.
        batch_size: Batch size; ``None`` selects DEFAULT_TRAIN_BATCHSIZE for
            the train split and DEFAULT_TEST_BATCHSIZE otherwise.
        shuffle: Whether the loader shuffles; ``None`` means "shuffle iff train".
        augm_type: Forwarded to ``get_cifar10_augmentation`` as ``type``.
        cutout_window: Forwarded to ``get_cifar10_augmentation``.
        num_workers: Number of DataLoader worker processes.
        size: Forwarded to ``get_cifar10_augmentation`` as ``out_size``.
        config_dict: Optional dict that receives the dataset name, batch size
            and augmentation configuration.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    if batch_size is None:
        batch_size = DEFAULT_TRAIN_BATCHSIZE if train else DEFAULT_TEST_BATCHSIZE

    # NOTE(review): the CIFAR-10 augmentation pipeline is reused for CIFAR-100;
    # confirm any dataset-specific settings (e.g. normalization) are appropriate.
    augm_config = {}
    transform = get_cifar10_augmentation(type=augm_type, cutout_window=cutout_window, out_size=size, config_dict=augm_config)

    # NOTE(review): with train=False this is the validation split of the
    # training set, not the test set; the message text is kept unchanged.
    if not train and augm_type != 'none':
        print('Warning: CIFAR100 test set with ref_data augmentation')

    if shuffle is None:
        shuffle = train

    path = get_CIFAR100_path()
    dataset = CIFAR100TrainValidationSplit(path, train=train, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=num_workers)

    if config_dict is not None:
        config_dict['Dataset'] = 'Cifar100TrainValidationSplit'
        # NOTE(review): 'Batch out_size' looks like a rename artifact of
        # 'Batch size'; kept as-is because consumers may read this key.
        config_dict['Batch out_size'] = batch_size
        config_dict['Augmentation'] = augm_config

    return loader