import torch
from torch.utils.data import Dataset

import numpy as np

import os

class NCaltech101(Dataset):
    """N-Caltech101 frame dataset backed by per-sample ``.npz`` archives.

    Expected on-disk layout::

        {root}/NCALTECH101/frames_number_10_split_by_number/<class_name>/<sample file>

    Class indices are assigned by sorting the class-directory names. Within
    each class, the sample files are sorted by name. The first
    ``int(num_files * train_ratio)`` files form the training split and the
    rest form the test split. Sorting makes the split independent of the
    arbitrary order in which ``os.listdir`` returns entries.

    NOTE(review): the directory name suggests frames produced by an
    event-to-frame integration step (10 frames, split by event number),
    e.g. SpikingJelly's — confirm against the preprocessing script.

    Args:
        root: Directory that contains the ``NCALTECH101`` folder.
        train: If True, expose the training split; otherwise the test split.
        transform: Optional callable applied to the float frame tensor.
        target_transform: Optional callable applied to the integer label.
        download: Unused by this class; nothing is downloaded.
        train_ratio: Fraction of each class's files assigned to the training
            split. Must lie in ``[0, 1]``. Defaults to 0.9.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, root: str, train=True, transform=None, target_transform=None,
                 download=False, train_ratio: float = 0.9):
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f'train_ratio must be in [0, 1], got {train_ratio}')

        self.filepath = f'{root}/NCALTECH101/frames_number_10_split_by_number'
        # Only directories are classes. A stray file here (e.g. a README)
        # would otherwise make the per-class os.listdir below fail.
        self.clslist = sorted(
            name for name in os.listdir(self.filepath)
            if os.path.isdir(os.path.join(self.filepath, name))
        )

        self.dvs_filelist = []
        self.targets = []

        for label, cls in enumerate(self.clslist):
            cls_dir = os.path.join(self.filepath, cls)
            # Sort so the train/test split is reproducible across filesystems.
            file_list = sorted(os.listdir(cls_dir))
            cut_idx = int(len(file_list) * train_ratio)
            selected = file_list[:cut_idx] if train else file_list[cut_idx:]
            self.dvs_filelist.extend(os.path.join(cls_dir, name) for name in selected)
            self.targets.extend([label] * len(selected))

        self.data_num = len(self.dvs_filelist)

        self.train = train
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(frames, label)`` for sample ``index``.

        ``frames`` is the ``'frames'`` array stored in the sample's archive,
        converted to a float32 tensor and then passed through ``transform``
        if one was given. ``label`` is the class index, passed through
        ``target_transform`` if one was given.
        """
        file_pth = self.dvs_filelist[index]
        label = self.targets[index]
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # The context manager closes it once 'frames' has been read into memory.
        with np.load(file_pth) as archive:
            data = torch.from_numpy(archive['frames']).float()

        if self.transform is not None:
            data = self.transform(data)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return data, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.data_num