import os

import numpy as np
import tonic
from torch.utils.data import Dataset, Subset, ConcatDataset

from datasets.dataset_windows_tonic import DatasetWindows
from datasets.ncaltech101_tonic import NCaltech101
from datasets.ncars_tonic import NCars
from datasets.gen1_tonic import Gen1

def get_dataset(dataset_name, val_split=None, train_transform=None, val_transform=None, test_transform=None,
                data_root='/data/datasets'):
    """
    Returns instances of train, validation and test datasets of the supported datasets, specified by name.

    Names ending in ``_FULL`` merge the validation data into the training set (using ``train_transform``).
    A validation dataset is still returned for them because PyTorch Lightning requires one.

    :param dataset_name: name of the dataset, e.g. 'DVSGesture100ms', 'NCaltech101_FULL', 'NCars', 'Gen1'
    :param val_split: fraction of the training data held out for validation. It is only used by datasets
        without pre-defined splits ('NCars'), where it defaults to 0.2 when None.
    :param train_transform: transform applied to training samples
    :param val_transform: transform applied to validation samples
    :param test_transform: transform applied to test samples
    :param data_root: directory that contains the raw datasets and the split/index JSON files
    :return: tuple (train_dataset, val_dataset, test_dataset)
    :raises ValueError: if dataset_name is not supported
    """
    # Datasets with fixed train/val/test splits: val_split is ignored for them.
    presplit_datasets = {'DVSGesture100ms', 'DVSGesture100ms_FULL', 'NCaltech101', 'NCaltech101_100ms',
                         'NCaltech101_FULL', 'NCaltech101_100ms_FULL', 'NCars_FULL', 'Gen1', 'Gen1_FULL'}
    # Datasets whose validation set is carved out of the training set with val_split.
    unsplit_datasets = {'NCars'}

    # Validate the name first, so an unsupported name does not emit misleading val_split warnings.
    if dataset_name not in presplit_datasets | unsplit_datasets:
        raise ValueError(f'Dataset {dataset_name} is not supported.')

    # Compare against None (not falsiness) so that an explicit val_split=0 is respected.
    if val_split is not None and dataset_name in presplit_datasets:
        print("val_split was provided for a dataset with pre-defined splits. Split value will be overridden by dataset specific one.")
    elif val_split is None and dataset_name in unsplit_datasets:
        print("val_split was not provided for a dataset with no pre-defined splits. val_split = 0.2 will be used.")
        val_split = 0.2

    def path(*parts):
        """Join ``parts`` onto ``data_root``."""
        return os.path.join(data_root, *parts)

    if dataset_name == 'DVSGesture100ms':
        dvsg_train = tonic.datasets.DVSGesture(save_to=data_root, train=True)
        dvsg_test = tonic.datasets.DVSGesture(save_to=data_root, train=False)

        # Train and validation windows are both drawn from the DVSGesture training recordings,
        # selected by two different split files.
        train_split_path = path("dvsgesture-split-train-w01-s005-subset08.json")
        val_split_path = path("dvsgesture-split-train-w01-s005-subset02.json")
        test_split_path = path("dvsgesture-split-test-w01-s005.json")

        train_dataset = DatasetWindows(dvsg_train, split_path=train_split_path, transform=train_transform)
        val_dataset = DatasetWindows(dvsg_train, split_path=val_split_path, transform=val_transform)
        test_dataset = DatasetWindows(dvsg_test, split_path=test_split_path, transform=test_transform)
        # <xypt>
    elif dataset_name == 'DVSGesture100ms_FULL':
        dvsg_train = tonic.datasets.DVSGesture(save_to=data_root, train=True)
        dvsg_test = tonic.datasets.DVSGesture(save_to=data_root, train=False)

        train_split_path = path("dvsgesture-split-train-w01-s005.json")
        test_split_path = path("dvsgesture-split-test-w01-s005.json")

        train_dataset = DatasetWindows(dvsg_train, split_path=train_split_path, transform=train_transform)
        val_dataset = train_dataset  # Still have to return a val_dataset because PyTorch Lightning requires it
        test_dataset = DatasetWindows(dvsg_test, split_path=test_split_path, transform=test_transform)
        # <xypt>
    elif dataset_name == 'NCaltech101':
        train_dataset = NCaltech101(root=path('N-Caltech101', 'training'), transform=train_transform)
        val_dataset = NCaltech101(root=path('N-Caltech101', 'validation'), transform=val_transform)
        test_dataset = NCaltech101(root=path('N-Caltech101', 'testing'), transform=test_transform)
        # <xytp>
    elif dataset_name == 'NCaltech101_FULL':
        # NCaltech101 with no validation split
        ncaltech_train = NCaltech101(root=path('N-Caltech101', 'training'), transform=train_transform)
        ncaltech_val = NCaltech101(root=path('N-Caltech101', 'validation'), transform=train_transform)  # train_transform!
        train_dataset = ConcatDataset([ncaltech_train, ncaltech_val])
        val_dataset = train_dataset  # Still have to return a val_dataset because PyTorch Lightning requires it
        test_dataset = NCaltech101(root=path('N-Caltech101', 'testing'), transform=test_transform)
        # <xytp>
    elif dataset_name == 'NCaltech101_100ms':
        train_dataset = DatasetWindows(NCaltech101(root=path('N-Caltech101', 'training')),
                                       split_path=path('ncaltech101-split-train-w01.json'),
                                       transform=train_transform)
        val_dataset = DatasetWindows(NCaltech101(root=path('N-Caltech101', 'validation')),
                                     split_path=path('ncaltech101-split-val-w01.json'),
                                     transform=val_transform)
        test_dataset = DatasetWindows(NCaltech101(root=path('N-Caltech101', 'testing')),
                                      split_path=path('ncaltech101-split-test-w01.json'),
                                      transform=test_transform)
        # <xytp>
    elif dataset_name == 'NCaltech101_100ms_FULL':
        # NCaltech101_100ms with no validation split
        ncaltech_train = DatasetWindows(NCaltech101(root=path('N-Caltech101', 'training')),
                                        split_path=path('ncaltech101-split-train-w01.json'),
                                        transform=train_transform)
        ncaltech_val = DatasetWindows(NCaltech101(root=path('N-Caltech101', 'validation')),
                                      split_path=path('ncaltech101-split-val-w01.json'),
                                      transform=train_transform)  # train_transform!
        train_dataset = ConcatDataset([ncaltech_train, ncaltech_val])
        val_dataset = train_dataset  # Still have to return a val_dataset because PyTorch Lightning requires it
        test_dataset = DatasetWindows(NCaltech101(root=path('N-Caltech101', 'testing')),
                                      split_path=path('ncaltech101-split-test-w01.json'),
                                      transform=test_transform)
        # <xytp>
    elif dataset_name == 'NCars':
        # Two views of the same training files, differing only in transform; they are split into
        # disjoint index sets so no sample lands in both train and validation.
        ncars_train = NCars(root=path('ncars_original', 'train'), transform=train_transform)
        ncars_val = NCars(root=path('ncars_original', 'train'), transform=val_transform)
        train_dataset, val_dataset = random_split_disjoint(ncars_train, ncars_val, (1 - val_split))

        test_dataset = NCars(root=path('ncars_original', 'test'), transform=test_transform)
        # <txyp>
    elif dataset_name == 'NCars_FULL':
        # NCars with no validation split
        train_dataset = NCars(root=path('ncars_original', 'train'), transform=train_transform)
        # Still have to return a val_dataset because PyTorch Lightning requires it
        val_dataset = NCars(root=path('ncars_original', 'train'), transform=val_transform)
        test_dataset = NCars(root=path('ncars_original', 'test'), transform=test_transform)
        # <txyp>
    elif dataset_name == 'Gen1':
        gen1_dir = path('gen1', 'detection_dataset_duration_60s_ratio_1.0')
        train_dataset = Gen1(os.path.join(gen1_dir, 'train'), transform=train_transform,
                             window_size=0.1, valid_idx_path=path('gen1', 'idx_train_01.json'))
        val_dataset = Gen1(os.path.join(gen1_dir, 'val'), transform=val_transform,
                           window_size=0.1, valid_idx_path=path('gen1', 'idx_val_01.json'))
        test_dataset = Gen1(os.path.join(gen1_dir, 'test'), transform=test_transform,
                            window_size=0.1, valid_idx_path=path('gen1', 'idx_test_01.json'))
        # <txyp>
    else:  # dataset_name == 'Gen1_FULL' (the only remaining supported name)
        # Gen1 with the validation split merged into training
        gen1_dir = path('gen1', 'detection_dataset_duration_60s_ratio_1.0')
        gen1_train = Gen1(os.path.join(gen1_dir, 'train'), transform=train_transform,
                          window_size=0.1, valid_idx_path=path('gen1', 'idx_train_01.json'))
        # The validation recordings used for training get train_transform, like the other *_FULL datasets.
        gen1_val_for_train = Gen1(os.path.join(gen1_dir, 'val'), transform=train_transform,
                                  window_size=0.1, valid_idx_path=path('gen1', 'idx_val_01.json'))
        train_dataset = ConcatDataset([gen1_train, gen1_val_for_train])
        # Still have to return a val_dataset because PyTorch Lightning requires it
        val_dataset = Gen1(os.path.join(gen1_dir, 'val'), transform=val_transform,
                           window_size=0.1, valid_idx_path=path('gen1', 'idx_val_01.json'))
        test_dataset = Gen1(os.path.join(gen1_dir, 'test'), transform=test_transform,
                            window_size=0.1, valid_idx_path=path('gen1', 'idx_test_01.json'))
        # <txyp>

    return train_dataset, val_dataset, test_dataset


def random_split_disjoint(d1: Dataset, d2: Dataset, split_size: float, rng=None):
    """
    Randomly splits the indices of two equally-indexed datasets into disjoint subsets.

    ``d1`` and ``d2`` are expected to index the same samples (e.g. the same files with different
    transforms). A random ``split_size`` fraction of the indices is drawn for ``d1``, and all remaining
    indices go to ``d2``, so no index appears in both returned subsets.

    :param d1: first dataset; receives ``int(split_size * len(d1))`` randomly chosen indices
    :param d2: second dataset; receives the remaining indices. Must have the same length as ``d1``.
    :param split_size: fraction of the indices assigned to ``d1``, between 0 and 1 (inclusive)
    :param rng: optional ``numpy.random.Generator`` or ``numpy.random.RandomState`` used for the draw.
        When None, the global ``numpy.random`` state is used, so ``np.random.seed`` still controls it.
    :return: tuple ``(Subset(d1, idx_sub1), Subset(d2, idx_sub2))`` with disjoint index arrays
    :raises ValueError: if ``split_size`` is outside [0, 1] (or NaN), or the datasets differ in length
    """
    # The chained comparison also rejects NaN, which `split_size < 0 or split_size > 1` let through.
    if not 0 <= split_size <= 1:
        raise ValueError(f'split_size must be between 0 and 1, got {split_size}.')
    # d2 is indexed with d1's index range; a length mismatch would only fail later, at access time.
    if len(d1) != len(d2):
        raise ValueError(f'Datasets must have the same length to be split disjointly, '
                         f'got {len(d1)} and {len(d2)}.')

    sampler = np.random if rng is None else rng
    split_len = int(split_size * len(d1))
    idx = np.arange(len(d1))
    idx_sub1 = sampler.choice(idx, split_len, replace=False)
    # Every index not drawn for d1 goes to d2 (np.setdiff1d returns it sorted).
    idx_sub2 = np.setdiff1d(idx, idx_sub1)

    return Subset(d1, idx_sub1), Subset(d2, idx_sub2)
