from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets

from .data_path import global_data_path


def create_train_and_test_data_loaders(perform_data_augment, dataset_name, batch_size, test_batch_size, num_worker=1):
    """Build loaders over the full training split and the held-out test split.

    Args:
        perform_data_augment: forwarded to ``get_torchvision_dataset``.
        dataset_name: dataset identifier accepted by ``get_torchvision_dataset``.
        batch_size: batch size of the training loader, which is shuffled and
            drops its last incomplete batch.
        test_batch_size: batch size of the evaluation loader, which is not
            shuffled and keeps its last batch.
        num_worker: number of DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_data, test_data = get_torchvision_dataset(perform_data_augment, dataset_name)
    return (
        create_one_data_loader(train_data, batch_size, shuffle=True, drop_last=True,
                               num_worker=num_worker, pin_memory=True),
        create_one_data_loader(test_data, test_batch_size, shuffle=False, drop_last=False,
                               num_worker=num_worker, pin_memory=True),
    )


def create_train_and_val_data_loaders(perform_data_augment, dataset_name, batch_size, test_batch_size,
                                      split_train_val=0.8, num_worker=1):
    """Build train/validation loaders by splitting the training split of a dataset.

    The first ``int(len(train) * split_train_val)`` training samples form the
    training subset and the remaining ones form the validation subset. The
    validation subset is written into the second dataset object returned by
    ``get_torchvision_dataset``, replacing that object's own samples while
    keeping its transform.

    Args:
        perform_data_augment: forwarded to ``get_torchvision_dataset``.
        dataset_name: 'cifar10', 'cifar100' or 'svhn'.
        batch_size: batch size of the training loader, which is shuffled and
            drops its last incomplete batch.
        test_batch_size: batch size of the validation loader, which is not
            shuffled and keeps its last batch.
        split_train_val: fraction of the training split kept for training;
            must satisfy ``0 < split_train_val <= 1``.
        num_worker: number of DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``split_train_val`` is outside ``(0, 1]``.
    """
    # Validate before loading any data. A negative fraction would silently take
    # the cut point from the end of the array, and a fraction > 1 would silently
    # yield an empty validation set.
    if not 0 < split_train_val <= 1:
        raise ValueError(f'split_train_val must be in (0, 1], got {split_train_val!r}')

    train_set, val_set = get_torchvision_dataset(perform_data_augment, dataset_name)

    train_set_length = int(len(train_set) * split_train_val)
    # SVHN stores its labels in `.labels`; the CIFAR datasets use `.targets`.
    label_attr = 'labels' if dataset_name == 'svhn' else 'targets'
    all_labels = getattr(train_set, label_attr)

    # Deterministic split: no shuffling is applied before slicing. The
    # validation slices are taken while train_set still holds the full arrays.
    val_set.data = train_set.data[train_set_length:]
    setattr(val_set, label_attr, all_labels[train_set_length:])
    train_set.data = train_set.data[:train_set_length]
    setattr(train_set, label_attr, all_labels[:train_set_length])

    print(f'DEBUG: train set length: {len(train_set)}; val set length: {len(val_set)}')

    train_loader = create_one_data_loader(train_set, batch_size, True, True, num_worker, True)
    val_loader = create_one_data_loader(val_set, test_batch_size, False, False, num_worker, True)
    return train_loader, val_loader


def create_one_data_loader(dataset, batch_size, shuffle=True, drop_last=False, num_worker=1, pin_memory=True):
    """Wrap ``dataset`` in a ``DataLoader`` configured with the given options.

    Args:
        dataset: any map-style dataset accepted by ``DataLoader``.
        batch_size: number of samples per batch.
        shuffle: whether samples are reshuffled every epoch.
        drop_last: whether a final incomplete batch is discarded.
        num_worker: number of worker processes (passed as ``num_workers``).
        pin_memory: whether batches are copied into pinned memory.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_worker,
        pin_memory=pin_memory,
    )


def get_torchvision_dataset(perform_data_augment, dataset_name):
    """Load the training and test splits of a supported torchvision dataset.

    Args:
        perform_data_augment: if truthy, the training split is transformed with
            a random 32x32 crop from a 4-pixel reflect-padded image and a random
            horizontal flip before ToTensor + normalization; otherwise it uses
            ToTensor + normalization only, like the test split.
        dataset_name: one of ``'cifar10'``, ``'cifar100'`` or ``'svhn'``.

    Returns:
        ``(train_set, val_set)``: the dataset objects for the training split and
        the test split, rooted at ``global_data_path`` (downloaded if missing).

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    # Per-channel mean/std on the 0-255 pixel scale; divided by 255 below
    # because ToTensor produces values in [0, 1].
    channel_stats = {
        'cifar10': ([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        'cifar100': ([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        'svhn': ([129.3, 124.1, 112.4], [68.2, 65.4, 70.4]),
    }
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if dataset_name not in channel_stats:
        raise ValueError(f'unsupported dataset {dataset_name!r}; expected one of {sorted(channel_stats)}')
    mean, std = channel_stats[dataset_name]

    normalize = transforms.Normalize(mean=[x / 255.0 for x in mean],
                                     std=[x / 255.0 for x in std])
    if perform_data_augment:
        # Pad and crop on the PIL image with built-in transforms only (no
        # lambda). This keeps the transform picklable, which DataLoader worker
        # processes started with 'spawn' need. It also avoids a lossy
        # float -> uint8 -> float round trip.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    dataset_cls = getattr(datasets, dataset_name.upper())
    # SVHN selects its split with `split=`, while the CIFAR datasets use `train=`.
    if dataset_name == 'svhn':
        train_kwargs, test_kwargs = {'split': 'train'}, {'split': 'test'}
    else:
        train_kwargs, test_kwargs = {'train': True}, {'train': False}

    train_set = dataset_cls(global_data_path, download=True, transform=transform_train, **train_kwargs)
    val_set = dataset_cls(global_data_path, download=True, transform=transform_test, **test_kwargs)
    return train_set, val_set

