import torch
import random
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder, MNIST
import warnings
import os
import torchvision
from os import listdir
import numpy as np
import time
from os.path import isfile, join

# Module-level side effect: importing this file installs an "ignore" filter
# with no category/message restriction, so warnings are silenced process-wide.
warnings.filterwarnings('ignore')


class _Cutout:
    """Zero out randomly placed square patches of an image tensor.

    The last two dimensions of the input are treated as height and width;
    the same mask is broadcast over any leading dimensions.

    Args:
        n_holes: number of patches to cut out of each image.
        length: side length, in pixels, of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a mask holding ``n_holes`` zeroed squares."""
        height, width = img.size(-2), img.size(-1)
        mask = torch.ones(height, width, dtype=img.dtype, device=img.device)
        half = self.length // 2
        for _ in range(self.n_holes):
            # The patch centre may land anywhere in the image, so patches
            # near a border are clipped and end up smaller than `length`.
            center_y = random.randrange(height)
            center_x = random.randrange(width)
            top, bottom = max(center_y - half, 0), min(center_y + half, height)
            left, right = max(center_x - half, 0), min(center_x + half, width)
            mask[top:bottom, left:right] = 0
        return img * mask


def build_cifar(cutout=False, use_cifar10=True, download=False):
    """Build the CIFAR-10 or CIFAR-100 train/validation datasets.

    The training pipeline is random crop (32, padding 4), random horizontal
    flip, ``ToTensor``, optional cutout, then per-channel normalisation.
    The validation pipeline is ``ToTensor`` followed by the same normalisation.

    Args:
        cutout: when true, one 16x16 cutout patch is applied to every
            training image (after ``ToTensor``, before ``Normalize``).
        use_cifar10: select CIFAR-10 when true, CIFAR-100 otherwise.
        download: forwarded to the torchvision dataset constructors.

    Returns:
        tuple: ``(train_dataset, val_dataset)``, both rooted at ``./raw/``.
    """
    # Pick the dataset class and its normalisation constants in one place so
    # the two variants cannot drift apart.
    if use_cifar10:
        dataset_cls = CIFAR10
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    else:
        dataset_cls = CIFAR100
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

    aug = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    if cutout:
        # `cutout` is the boolean flag, not a callable: instantiate the
        # transform class instead of calling the flag itself.
        aug.append(_Cutout(n_holes=1, length=16))
    aug.append(transforms.Normalize(mean, std))

    transform_train = transforms.Compose(aug)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = dataset_cls(root='./raw/', train=True,
                                download=download, transform=transform_train)
    val_dataset = dataset_cls(root='./raw/', train=False,
                              download=download, transform=transform_test)

    return train_dataset, val_dataset

def build_mnist(download=False):
    """Build the MNIST train/validation datasets rooted at ``./raw/``.

    Both splits use a plain ``ToTensor`` transform (a separate instance each).

    Args:
        download: forwarded to the torchvision ``MNIST`` constructor.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    # Training split first, then the held-out split.
    train_dataset, val_dataset = [
        MNIST(root='./raw/', train=is_train, download=download,
              transform=transforms.ToTensor())
        for is_train in (True, False)
    ]
    return train_dataset, val_dataset


class DVSCifar10(Dataset):
    """Dataset of pre-processed samples stored one per file as ``<index>.pt``.

    Each file is loaded with ``torch.load`` and unpacked as a
    ``(data, target)`` pair. Every sample is resized to 48x48; the random
    augmentation is optional (see ``transform``).
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        """
        Args:
            root: directory holding the ``<index>.pt`` files (``~`` is expanded).
            train: stored on the instance; not otherwise used by this class.
            transform: used purely as a flag. Any value other than ``None``
                (including ``False``) enables the random flip/roll
                augmentation; the value itself is never called.
            target_transform: optional callable applied to the loaded target.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        # Despite its name, this resize is applied to every sample,
        # regardless of `train`.
        self.test_transform = transforms.Compose([transforms.Resize(size=(48, 48))])

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (data, target) where target is the loaded target cast to long.
        """
        # Join with the directory so `root` works with or without a trailing
        # separator (plain string concatenation required one).
        sample_path = os.path.join(self.root, '{}.pt'.format(index))
        data, target = torch.load(sample_path)

        data = self.test_transform(data)
        if self.transform is not None:
            # Flip along dim 3 with probability 0.5, then roll dims 2 and 3
            # by independent offsets in [-5, 5]; this requires `data` to
            # have at least four dimensions.
            if random.random() > 0.5:
                data = torch.flip(data, dims=(3,))
            off1 = random.randint(-5, 5)
            off2 = random.randint(-5, 5)
            data = torch.roll(data, shifts=(off1, off2), dims=(2, 3))

        if self.target_transform is not None:
            target = self.target_transform(target)

        return data, target.long()

    def __len__(self):
        """Return the number of directory entries in ``root``.

        Every entry is counted, so the directory is assumed to contain only
        the sample files.
        """
        return len(os.listdir(self.root))


# class DVSCifar10(Dataset):
#     def __init__(self, root, train=True, transform=None, target_transform=None):
#         self.root = os.path.expanduser(root)
#         self.transform = transform
#         self.target_transform = target_transform
#         self.train = train
#         self.resize = transforms.Resize(size=(48, 48))  # 48 48
#         self.tensorx = transforms.ToTensor()
#         self.imgx = transforms.ToPILImage()

#     def __getitem__(self, index):
#         """
#         Args:
#             index (int): Index
#         Returns:
#             tuple: (image, target) where target is index of the target class.
#         """
#         data, target = torch.load(self.root + '/{}.pt'.format(index))
#         # print(data.shape)
#         # if self.train:
#         new_data = []
#         for t in range(data.size(-1)):
#             new_data.append(self.tensorx(self.resize(self.imgx(data[..., t]))))
#         data = torch.stack(new_data, dim=0)
#         if self.transform is not None:
#             flip = random.random() > 0.5
#             if flip:
#                 data = torch.flip(data, dims=(3,))
#             off1 = random.randint(-5, 5)
#             off2 = random.randint(-5, 5)
#             data = torch.roll(data, shifts=(off1, off2), dims=(2, 3))
#             #print(data.shape)

#         if self.target_transform is not None:
#             target = self.target_transform(target)
#         return data, target.long().squeeze(-1)

#     def __len__(self):
#         return len(os.listdir(self.root))


def build_dvscifar(path):
    """Build the DVS-CIFAR10 train/validation datasets found under ``path``.

    Training samples are read from ``<path>/train_4/`` and validation samples
    from ``<path>/test_4/``.

    Args:
        path: base directory containing the ``train_4`` and ``test_4`` folders.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    # NOTE(review): DVSCifar10.__getitem__ tests `self.transform is not None`,
    # so `transform=False` still switches the flip/roll augmentation ON for
    # the training split - confirm that this is the intended behaviour.
    return (
        DVSCifar10(root=path + '/train_4/', transform=False),
        DVSCifar10(root=path + '/test_4/'),
    )

def build_imagenet(root='/data_smr/dataset/ImageNet'):
    """Build ImageNet train/validation ``ImageFolder`` datasets.

    Training images get a random resized crop to 224 and a random horizontal
    flip; validation images are resized to 256 and centre-cropped to 224.
    Both are converted to tensors and normalised with the same constants.

    Args:
        root: directory containing the ``train`` and ``val`` sub-folders.
            Defaults to the location that was previously hard-coded, so
            existing callers are unaffected.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = ImageFolder(os.path.join(root, 'train'), train_transform)
    val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform)
    return train_dataset, val_dataset

if __name__ == '__main__':
    # When executed as a script, construct both DVS-CIFAR10 splits from the
    # relative 'preprocess/dvs-cifar10' directory.
    dvs_splits = build_dvscifar('preprocess/dvs-cifar10')
    train_set, test_set = dvs_splits