import os

from sklearn.model_selection import train_test_split
from spikingjelly.datasets import cifar10_dvs
from torch.utils.data import Dataset, Subset

# def split_dataset(dataset: Dataset, train: bool, train_size: float = 0.9) -> tuple:
#     train_size = int(train_size * len(dataset))
#     test_size = len(dataset) - train_size

#     torch.manual_seed(42)
#     train_set, val_set = torch.utils.data.random_split(dataset, [train_size, test_size])
#     return train_set if train else val_set

def split_dataset(
    dataset: Dataset,
    train: bool,
    train_size: float = 0.9,
    random_state: int = 42,
) -> Dataset:
    """Return the train or held-out part of a stratified split of *dataset*.

    The split is computed over sample *indices* (stratified by
    ``dataset.targets``) and returned as a lazy ``Subset`` view, so samples
    are only loaded -- and the dataset's transform only run -- when indexed.
    The shuffle is seeded with *random_state*. Calling this once with
    ``train=True`` and once with ``train=False`` (same seed, same dataset)
    therefore yields two disjoint subsets.

    Args:
        dataset: Map-style dataset supporting ``len()`` and exposing a
            ``targets`` sequence with one label per sample (used for
            stratification).
        train: If True return the training subset, otherwise the held-out one.
        train_size: Fraction of samples assigned to the training subset.
        random_state: Seed for the shuffle. It must be identical across the
            train and held-out calls for the two subsets not to overlap.

    Returns:
        A ``torch.utils.data.Subset`` of *dataset*.

    Raises:
        ValueError: From ``train_test_split`` when the requested sizes are
            invalid (e.g. an empty held-out set).
    """
    n_samples = len(dataset)
    n_train = int(train_size * n_samples)
    n_test = n_samples - n_train

    # Split indices rather than the dataset itself: handing the dataset to
    # sklearn would materialize every sample (and apply transforms once).
    train_idx, test_idx = train_test_split(
        list(range(n_samples)),
        train_size=n_train,
        test_size=n_test,
        shuffle=True,
        stratify=dataset.targets,
        random_state=random_state,
    )
    return Subset(dataset, train_idx if train else test_idx)

def CIFAR10DVS(
    root,
    train=True,
    transform=None,
    target_transform=None,
    download=False,
    frames_number=10,
    split_by='number',
) -> Dataset:
    """Build the train or held-out split of spikingjelly's CIFAR10-DVS in frame form.

    The dataset lives in ``<root>/cifar10_dvs``, which is created if missing.
    It is loaded with ``data_type='frame'``, and the result is split with
    ``split_dataset``.

    Args:
        root: Parent directory; the dataset directory is ``<root>/cifar10_dvs``.
        train: Return the training split if True, else the held-out split.
        transform: Forwarded to spikingjelly's ``transform`` argument.
        target_transform: Forwarded to spikingjelly's ``target_transform``
            argument.
        download: Accepted for signature compatibility only; it is not used
            by this function.
        frames_number: Forwarded to spikingjelly's ``frames_number``
            (default 10).
        split_by: Forwarded to spikingjelly's ``split_by`` (default
            ``'number'``).

    Returns:
        The selected split, as returned by ``split_dataset``.
    """
    data_dir = os.path.join(root, 'cifar10_dvs')
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(data_dir, exist_ok=True)

    dataset = cifar10_dvs.CIFAR10DVS(
        data_dir,
        data_type='frame',
        frames_number=frames_number,
        split_by=split_by,
        transform=transform,
        target_transform=target_transform,
    )
    return split_dataset(dataset, train=train)