import torch
import random
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import os
import numpy as np

class CIFAR10DVS(Dataset):
    """CIFAR10-DVS dataset stored as one ``.pt`` file per sample.

    Expected on-disk layout::

        {root}/dvscifar10/train/{index}.pt
        {root}/dvscifar10/test/{index}.pt

    Each file holds a ``(data, target)`` pair as written by ``torch.save``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=None):
        """
        Args:
            root (str): Directory that contains the ``dvscifar10`` folder.
            train (bool): Use the ``train`` split if True, else ``test``.
            transform: Ignored. A split-specific pipeline is always built
                below; the parameter is kept for signature compatibility.
            target_transform (callable, optional): Applied to the target
                after it has been converted to a ``torch.long`` tensor.
            download: Ignored; the data must already exist on disk.
        """
        split = 'train' if train else 'test'
        self.root = os.path.join(root, 'dvscifar10', split)

        self.train = train
        self.target_transform = target_transform

        # Every split is resized to 48x48; the training split additionally
        # gets random-crop (4px padding) and horizontal-flip augmentation.
        resize = transforms.Resize((48, 48), transforms.InterpolationMode.BILINEAR, antialias=True)
        if self.train:
            self.transform = transforms.Compose([
                resize,
                transforms.RandomCrop(size=(48, 48), padding=4),
                transforms.RandomHorizontalFlip(),
            ])
        else:
            self.transform = transforms.Compose([resize])

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the sample; loads ``{self.root}/{index}.pt``.
        Returns:
            tuple: (data, target). ``data`` has its last axis moved to the
            front and ``self.transform`` applied. ``target`` is the class
            index as a ``torch.long`` tensor, passed through
            ``target_transform`` when one was given.
        """
        path = os.path.join(self.root, '{}.pt'.format(index))
        # NOTE(review): torch.load unpickles the file -- only use on trusted data.
        data, target = torch.load(path)
        # Move the last axis to the front. Presumably (C, H, W, T) -> (T, C, H, W)
        # so the image transforms see time as a leading batch-like axis --
        # TODO confirm the stored layout.
        data = data.permute([3, 0, 1, 2])
        data = self.transform(data)

        # as_tensor also accepts targets saved as plain Python ints.
        target = torch.as_tensor(target).long()
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target

    def __len__(self):
        # Count only sample files so stray entries in the split directory
        # do not yield indices that __getitem__ cannot load.
        return sum(1 for name in os.listdir(self.root) if name.endswith('.pt'))
