import math
import torch
import random
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder, MNIST
import warnings
import os
import torchvision
from os import listdir
import numpy as np
import time
from os.path import isfile, join
from utils.autoaugment import Cutout,CIFAR10Policy
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from spikingjelly.datasets import cifar10_dvs
from torchvision.datasets import DatasetFolder
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
import spikingjelly.datasets
import PIL

warnings.filterwarnings('ignore')


def build_cifar(cutout=False, use_cifar10=True, download=False, data_dir="/dataset"):
    """Build the CIFAR-10 or CIFAR-100 train/test datasets.

    Training pipeline:
    1. ``RandomCrop(32, padding=4)``
    2. ``RandomHorizontalFlip``
    3. ``CIFAR10Policy``
    4. ``ToTensor``
    5. optional ``Cutout`` (1 hole, length 16)
    6. per-channel ``Normalize``

    The test pipeline is only ``ToTensor`` + ``Normalize`` with the same
    statistics.

    Args:
        cutout: when True, insert ``Cutout(n_holes=1, length=16)`` after
            ``ToTensor`` in the training pipeline.
        use_cifar10: build CIFAR-10 when True, CIFAR-100 otherwise.
        download: forwarded to the torchvision dataset constructors.
        data_dir: dataset root directory.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    # Pick the dataset class and its (mean, std) normalization statistics once.
    if use_cifar10:
        dataset_cls = CIFAR10
        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    else:
        dataset_cls = CIFAR100
        stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
    ]
    if cutout:
        # Cutout runs on the tensor, before normalization.
        train_steps.append(Cutout(n_holes=1, length=16))
    train_steps.append(transforms.Normalize(*stats))
    train_transform = transforms.Compose(train_steps)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*stats),
    ])

    train_set = dataset_cls(root=data_dir, train=True, download=download,
                            transform=train_transform)
    test_set = dataset_cls(root=data_dir, train=False, download=download,
                           transform=test_transform)
    return train_set, test_set

def build_mnist(download=False, data_dir='./raw/'):
    """Build the MNIST train/test datasets with a plain ``ToTensor`` transform.

    Args:
        download: forwarded to the torchvision ``MNIST`` constructors.
        data_dir: dataset root directory. Defaults to ``'./raw/'``, the
            location previously hard-coded here, so existing callers are
            unaffected.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    # ToTensor holds no per-sample state, so one instance can serve both splits.
    to_tensor = transforms.ToTensor()
    train_dataset = MNIST(root=data_dir, train=True, download=download, transform=to_tensor)
    val_dataset = MNIST(root=data_dir, train=False, download=download, transform=to_tensor)
    return train_dataset, val_dataset

from datasets import load_dataset
def build_tiny_imagenet(path='/data/dataset/tiny-imagenet-200/'):
    """Build Tiny-ImageNet train/validation datasets from an ImageFolder layout.

    Training images get a 64x64 random crop (padding 8) and a random
    horizontal flip. Both splits are converted to tensors and normalized with
    the ImageNet mean/std.

    Args:
        path: dataset root containing ``train/`` and ``valid/``
            subdirectories, each laid out for ``ImageFolder`` (one
            subdirectory per class). Defaults to the location previously
            hard-coded here, so existing callers are unaffected.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    # NOTE(review): the official tiny-imagenet-200 archive ships ``val/images``
    # plus ``val_annotations.txt`` rather than per-class folders; this assumes
    # the validation split was reorganised into ``valid/<class>/`` — confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = ImageFolder(os.path.join(path, 'train'), transform_train)
    val_dataset = ImageFolder(os.path.join(path, 'valid'), transform_test)
    return train_dataset, val_dataset


class DVSCifar10(Dataset):
    """DVS-CIFAR10 samples stored as one ``<index>.pt`` file per sample.

    Each file is loaded with ``torch.load`` and unpacked as ``(data, target)``.
    Each frame ``data[t]`` is converted to a PIL image, resized to 48x48 and
    converted back to a tensor. The frames are then stacked along a new dim 0.

    Args:
        root: directory holding the ``.pt`` files (``~`` is expanded).
        train: stored on the instance; not read by this class.
        transform: used only as an on/off flag. When not ``None``, a random
            flip of dim 3 (p=0.5) and a random roll of -5..5 positions along
            dims 2 and 3 are applied.
        target_transform: optional callable applied to the loaded target;
            its result must support ``.long()``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        # Per-frame pipeline: tensor -> PIL image -> 48x48 -> tensor.
        self.imgx = transforms.ToPILImage()
        self.resize = transforms.Resize(size=(48, 48))
        self.tensorx = transforms.ToTensor()

    def __getitem__(self, index):
        """Return ``(frames, target)`` for sample ``index``.

        ``frames`` is the stack of resized frames along dim 0. ``target`` is
        cast to ``long`` with a trailing singleton dimension squeezed away.
        """
        # NOTE(review): torch.load unpickles the file unless the installed torch
        # enforces weights_only; only point ``root`` at trusted data.
        raw, target = torch.load(f'{self.root}/{index}.pt')
        # NOTE(review): ToPILImage scales float frames by 255 and casts to
        # uint8, so values outside [0, 1] would wrap — confirm stored frames
        # are normalized.
        frames = torch.stack(
            [self.tensorx(self.resize(self.imgx(raw[t, ...]))) for t in range(raw.size(0))],
            dim=0,
        )

        if self.transform is not None:
            # RNG draw order (flip decision, then two offsets) is significant
            # for reproducibility under a fixed seed.
            if random.random() > 0.5:
                frames = torch.flip(frames, dims=(3,))
            shift_a = random.randint(-5, 5)
            shift_b = random.randint(-5, 5)
            frames = torch.roll(frames, shifts=(shift_a, shift_b), dims=(2, 3))

        if self.target_transform is not None:
            target = self.target_transform(target)
        return frames, target.long().squeeze(-1)

    def __len__(self):
        """Count the entries in ``root``; every entry is assumed to be a sample file."""
        return len(os.listdir(self.root))


def build_dvscifar(path):
    """Build DVS-CIFAR10 train/test datasets from ``path/train`` and ``path/test``.

    Only the training split passes ``transform=True``, which turns on
    ``DVSCifar10``'s random flip/roll augmentation.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    return (
        DVSCifar10(root=path + '/train', transform=True),
        DVSCifar10(root=path + '/test'),
    )

def build_imagenet(path):
    """Build ImageNet train/val ``ImageFolder`` datasets rooted at ``path``.

    Train: ``RandomResizedCrop(224)`` + horizontal flip. Val: ``Resize(256)``
    + ``CenterCrop(224)``. Both splits are converted to tensors and normalized
    with the ImageNet mean/std.

    Args:
        path: directory containing ``train/`` and ``val/`` ImageFolder trees.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_set = ImageFolder(os.path.join(path, 'train'), train_tf)
    val_set = ImageFolder(os.path.join(path, 'val'), val_tf)
    return train_set, val_set

class _FrameTransformDataset(Dataset):
    """Wrap a dataset of ``(frames, label)`` pairs and transform the frames.

    Frames are converted with ``torch.as_tensor(..., dtype=torch.float32)``
    before ``transform`` runs. This lets tensor-based torchvision transforms
    work on array-valued samples.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        frames, label = self.dataset[index]
        # NOTE(review): presumably the spikingjelly frame datasets yield numpy
        # arrays; as_tensor accepts arrays and tensors alike.
        frames = torch.as_tensor(frames, dtype=torch.float32)
        if self.transform is not None:
            frames = self.transform(frames)
        return frames, label

    def __len__(self):
        return len(self.dataset)


def build_dvs128(path='../../dataset/dvs128/', T=10):
    """Build DVS128 Gesture train/test datasets of T-frame tensors resized to 48x48.

    Args:
        path: spikingjelly ``DVS128Gesture`` root directory.
        T: number of frames each sample is integrated into
            (``frames_number`` with ``split_by='number'``).

    Returns:
        ``(trainset, testset)``. Each yields ``(frames, label)``, where frames
        is a float32 tensor after the split's transform.
    """
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(128, scale=(0.6, 1.0),
                                     interpolation=transforms.InterpolationMode.NEAREST),
        transforms.Resize(size=(48, 48)),
        # NOTE(review): RandomHorizontalFlip was left disabled here; presumably
        # because mirroring can change a gesture's meaning — confirm.
        transforms.RandomRotation(degrees=30),
        # NOTE(review): with T_remain equal to the frame count T this likely
        # keeps every frame; confirm whether fewer frames were intended.
        spikingjelly.datasets.RandomTemporalDelete(T_remain=T, batch_first=False),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(size=(48, 48)),
    ])
    train_set = DVS128Gesture(root=path, train=True, data_type='frame', frames_number=T, split_by='number')
    test_set = DVS128Gesture(root=path, train=False, data_type='frame', frames_number=T, split_by='number')

    # Wrap so frames become float tensors before the tensor transforms run.
    trainset = _FrameTransformDataset(train_set, transform_train)
    testset = _FrameTransformDataset(test_set, transform_test)
    return trainset, testset
if __name__ == '__main__':
    # Script smoke test: build the MNIST datasets, downloading them if missing.
    train_set, test_set = build_mnist(
        download=True,
    )
