import torch
import torchvision
from spikingjelly.datasets.n_mnist import NMNIST
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
# from spikingjelly.datasets.cifar10_dvs_split import CIFAR10DVS_SPLIT
from spikingjelly.datasets import split_to_train_test_set

# Default root directory for each dataset, keyed by the same names that
# load_dataset() dispatches on. These are machine-specific Windows paths;
# adjust them for the local environment.
DATASET_DIR = dict(
    mnist='D:/datasets/mnist',
    nmnist='D:/datasets/nmnist',
    cifar10='D:/datasets/cifar10',
    cifar10dvs='D:/datasets/cifar10dvs',
)


# Per-channel mean/std used when CIFAR-10 normalization is requested.
_CIFAR10_MEAN = [0.4948052, 0.48568845, 0.446829745]
_CIFAR10_STD = [0.24580306, 0.24236229, 0.2603115]


def _make_loaders(train_set, test_set, batch_size):
    """Wrap a (train, test) dataset pair in DataLoaders with shared settings.

    The train loader drops the last incomplete batch so every training step
    sees exactly ``batch_size`` samples; the test loader keeps it so no test
    sample is skipped. Both loaders shuffle.
    """
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False)
    return train_data_loader, test_data_loader


def _mnist_splits(dataset_dir):
    """Return the MNIST (train, test) datasets transformed with ToTensor."""
    to_tensor = torchvision.transforms.ToTensor()
    train_set = torchvision.datasets.MNIST(
        root=dataset_dir, train=True, transform=to_tensor, download=True)
    test_set = torchvision.datasets.MNIST(
        root=dataset_dir, train=False, transform=to_tensor, download=False)
    return train_set, test_set


def _nmnist_splits(dataset_dir, T):
    """Return the N-MNIST (train, test) datasets as ``T`` frames per sample."""
    train_set = NMNIST(dataset_dir, train=True, data_type='frame',
                       frames_number=T, split_by='number')
    test_set = NMNIST(dataset_dir, train=False, data_type='frame',
                      frames_number=T, split_by='number')
    return train_set, test_set


def _cifar10_splits(dataset_dir, normalize):
    """Return the CIFAR-10 (train, test) datasets.

    Training adds random-crop (32, padding 4) and horizontal-flip augmentation.
    Both splits then go through the same ToTensor (+ optional Normalize) tail.
    """
    # Train and test must share the same value mapping: normalizing only one
    # split makes the model evaluate on a different input distribution than
    # the one it was trained on.
    tail = [torchvision.transforms.ToTensor()]
    if normalize:
        tail.append(torchvision.transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD))
    train_transform = torchvision.transforms.Compose(
        [torchvision.transforms.RandomCrop(32, padding=4),
         torchvision.transforms.RandomHorizontalFlip()] + tail)
    test_transform = torchvision.transforms.Compose(tail)
    train_set = torchvision.datasets.CIFAR10(
        root=dataset_dir, train=True, transform=train_transform, download=True)
    test_set = torchvision.datasets.CIFAR10(
        root=dataset_dir, train=False, transform=test_transform, download=False)
    return train_set, test_set


def _cifar10dvs_splits(dataset_dir, T):
    """Return a 90/10 (train, test) split of CIFAR10-DVS as ``T`` frames per sample."""
    full_set = CIFAR10DVS(dataset_dir, data_type='frame',
                          frames_number=T, split_by='number')
    # The dataset is loaded without a train/test flag, so the split is built
    # here: train ratio 0.9, 10 classes, random_split=True.
    # NOTE(review): with random_split=True the split presumably depends on the
    # global RNG state -- seed before calling if a reproducible split is needed.
    train_set, test_set = split_to_train_test_set(0.9, full_set, 10, True)
    return train_set, test_set


def load_dataset(dataset, dataset_dir, batch_size, T=20, *, normalize=False):
    """Build train and test DataLoaders for one of the supported datasets.

    Args:
        dataset: one of 'mnist', 'nmnist', 'cifar10', 'cifar10dvs'.
        dataset_dir: root directory of the dataset files.
        batch_size: batch size used by both loaders.
        T: number of frames per sample for the event-based datasets
            ('nmnist', 'cifar10dvs'); ignored for the others.
        normalize: 'cifar10' only -- apply per-channel mean/std normalization
            to BOTH splits. Defaults to False, which leaves inputs as produced
            by ToTensor, matching the training transform.

    Returns:
        ``(train_data_loader, test_data_loader)``. Both shuffle; the train
        loader drops the last partial batch, the test loader keeps it.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'mnist':
        train_set, test_set = _mnist_splits(dataset_dir)
    elif dataset == 'nmnist':
        train_set, test_set = _nmnist_splits(dataset_dir, T)
    elif dataset == 'cifar10':
        train_set, test_set = _cifar10_splits(dataset_dir, normalize)
    elif dataset == 'cifar10dvs':
        train_set, test_set = _cifar10dvs_splits(dataset_dir, T)
    else:
        raise ValueError(
            "Unknown dataset {!r}; expected one of 'mnist', 'nmnist', "
            "'cifar10', 'cifar10dvs'".format(dataset))
    return _make_loaders(train_set, test_set, batch_size)


if __name__ == '__main__':
    # Smoke test: build the CIFAR10-DVS loaders (20 frames per sample,
    # batch size 64) from a machine-specific local directory.
    load_dataset(
        'cifar10dvs',
        dataset_dir='C:/Users/jyzhang/Work/research/datasets/cifar10dvs',
        batch_size=64,
        T=20,
    )
