import os
import warnings

import numpy as np
import PIL
import spikingjelly
import torch
import torchvision.transforms as transforms
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder, MNIST

from pack.autoaugment import Cutout
# from transforms import *
warnings.filterwarnings('ignore')

def build_cifar(use_cifar10=True):
    """Build the CIFAR training and validation datasets.

    The training pipeline is random crop (32, padding 4), random horizontal
    flip, AutoAugment with the CIFAR10 policy, ``ToTensor``, ``Cutout`` (one
    16-pixel hole) and finally per-channel normalisation. The validation
    pipeline is only ``ToTensor`` followed by the same normalisation.

    Args:
        use_cifar10: When truthy, build CIFAR-10; otherwise build CIFAR-100.
            The CIFAR10 AutoAugment policy is used in both cases.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple. Both datasets are created
        with ``download=False``, so the data is read from the fixed root
        directory rather than fetched.
    """
    # Everything that differs between the two datasets is selected here.
    if use_cifar10:
        dataset_cls = CIFAR10
        root = '/data/dataset/CIFAR10/'
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    else:
        dataset_cls = CIFAR100
        root = '/data/dataset/CIFAR100/'
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        # Cutout is placed after ToTensor and before normalisation.
        Cutout(n_holes=1, length=16),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = dataset_cls(root=root, train=True, download=False, transform=transform_train)
    val_dataset = dataset_cls(root=root, train=False, download=False, transform=transform_test)
    return train_dataset, val_dataset

class packaging_class(torch.utils.data.Dataset):
    """Dataset wrapper that converts samples to float tensors.

    Each ``(data, label)`` pair from the wrapped dataset has its data turned
    into a ``torch.FloatTensor`` and, if a transform was supplied, passed
    through that transform. Labels are returned unchanged.
    """

    def __init__(self, dataset, transform=None):
        """Store the wrapped dataset and the optional per-sample transform."""
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for the sample at ``index``."""
        sample, target = self.dataset[index]
        tensor = torch.FloatTensor(sample)
        if not self.transform:
            return tensor, target
        return self.transform(tensor), target

def trans_t(data):
    """Apply the random training-time augmentation to one sample.

    Steps, in order: random resized crop to 128 (scale 0.7-1.0, nearest
    interpolation), resize to 48x48, cast to float, a horizontal flip with
    probability 0.5, then ``function_nda``.

    Args:
        data: Tensor sample. The flip indexes dimension 3, so the input must
            have at least four dimensions (presumably (T, C, H, W) -- TODO
            confirm against the caller).

    Returns:
        The augmented sample as a float tensor.
    """
    crop = transforms.RandomResizedCrop(128, scale=(0.7, 1.0), interpolation=PIL.Image.NEAREST)
    data = crop(data)
    resize = transforms.Resize(size=(48, 48))
    data = resize(data).float()
    # `np` must be imported at module level (numpy); the flip decision is
    # drawn from NumPy's global random state.
    if np.random.random() > 0.5:
        data = torch.flip(data, dims=(3,))
    # NOTE(review): function_nda is not defined in the visible part of this
    # module -- it has to be provided elsewhere in the file.
    data = function_nda(data)
    return data.float()

def trans(data):
    """Resize a sample to 48x48 and return it as a float tensor.

    This is the deterministic counterpart of ``trans_t``: no random
    augmentation is applied.
    """
    to_48 = transforms.Resize(size=(48, 48))
    return to_48(data).float()

def build_dvscifar10(path='/data/dataset/CIFAR10DVS', T=10):
    """Return the CIFAR10-DVS train and test datasets, caching them on disk.

    The wrapped datasets are cached at ``<path>/train/Train<T>`` and
    ``<path>/test/Test<T>``. If both cache files exist they are loaded;
    otherwise the frame dataset is built with ``T`` frames per sample, split
    90/10 into train/test, wrapped with the training (``trans_t``) and test
    (``trans``) transforms, and saved to the cache.

    Args:
        path: Root directory of the CIFAR10-DVS data and of the cache.
        T: Number of frames per sample; also part of the cache file name.

    Returns:
        A ``(trainset, testset)`` tuple.
    """
    train_path = path + '/train/Train' + str(T)
    test_path = path + '/test/Test' + str(T)
    print(train_path)

    if os.path.exists(train_path) and os.path.exists(test_path):
        # NOTE(review): torch.load unpickles arbitrary objects -- only point
        # `path` at trusted cache files. PyTorch versions whose torch.load
        # defaults to weights_only=True may refuse these pickled dataset
        # objects; confirm against the installed version.
        trainset = torch.load(train_path)
        testset = torch.load(test_path)
        print('Load DVSCIFAR10 success')
        return trainset, testset

    dataset = CIFAR10DVS(root=path, data_type='frame', frames_number=T, split_by='number')
    trainset, testset = spikingjelly.datasets.split_to_train_test_set(
        train_ratio=0.9, origin_dataset=dataset, num_classes=10)
    trainset, testset = packaging_class(trainset, trans_t), packaging_class(testset, trans)

    # torch.save does not create missing parent directories, so make sure
    # <path>/train and <path>/test exist before writing the cache.
    for cache_file in (train_path, test_path):
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    torch.save(trainset, train_path)
    torch.save(testset, test_path)

    return trainset, testset