import os
from torch.utils.data import Dataset, DataLoader
import numpy as np

import cv2
import torch
from time import sleep, time
from torchvision.transforms import Resize
from multiprocessing import Pool
import pandas as pd


class Eventvot(Dataset):
    """Dataset of event voxel clips, sampled as random temporal windows and crops.

    ``data_path`` is listed once; each entry ``<name>`` is expected to be a
    directory containing ``<name>_voxel.npy``. Every item is a random window of
    ``cond_length + seq_length`` frames from that array. A random square crop
    is taken from the centred ``crop_size`` region and resized to ``out_size``.
    """

    def __init__(self, data_path, split='train', args=None, verbose=False):
        """Index the clips under ``data_path``.

        Args:
            data_path: directory whose entries are the per-clip directories.
            split: currently unused (kept for interface compatibility).
            args: currently unused (kept for interface compatibility).
            verbose: if True, ``__getitem__`` prints the ``.npy`` load time.
        """
        self.data_path = data_path
        self.data_name = sorted(os.listdir(data_path))
        # Same common prefix for every entry, so this stays index-aligned with data_name.
        self.data_ls = [os.path.join(data_path, name) for name in self.data_name]
        self.fps = 60
        self.crop_size = [720, 720]
        self.min_crop = 512          # smallest square crop side drawn in _sample_crop
        self.out_size = [128, 128]   # spatial size every crop is resized to
        self.num_bins = 3
        self.width = 1280
        self.height = 720
        self.device = 'cpu'
        self.cond_length = 3
        self.seq_length = 25
        self.verbose = verbose
        # NOTE(review): read relative to the CWD (not data_path) and not used inside
        # this class; `split` is ignored here — confirm this is intended.
        self.csv_data = pd.read_csv('train.csv')

    def __len__(self):
        """Return the number of clips (entries under ``data_path``)."""
        return len(self.data_ls)

    def __getname__(self):
        """Return the dataset tag placed in every item under ``'dataset'``."""
        return 'vot'

    def _sample_start_index(self, max_idx):
        """Return a random start so a full ``cond_length + seq_length`` window fits.

        Raises:
            ValueError: if the clip has fewer frames than one window.
        """
        window_len = self.cond_length + self.seq_length
        if max_idx < window_len:
            raise ValueError(
                f"clip has {max_idx} frames but a window needs {window_len}")
        # randint's upper bound is exclusive; +1 makes the last full window reachable.
        return np.random.randint(0, max_idx - window_len + 1)

    def _sample_crop(self):
        """Return ``(crop_x, crop_y, kernel)`` for a square crop inside the centred crop region."""
        kernel = np.random.randint(self.min_crop, min(self.crop_size))
        x0 = (self.width - self.crop_size[0]) // 2
        y0 = (self.height - self.crop_size[1]) // 2
        # Bound by the crop region (not the full frame) so the crop never runs
        # past the frame edge; +1 because randint's upper bound is exclusive.
        crop_x = np.random.randint(x0, x0 + self.crop_size[0] - kernel + 1)
        crop_y = np.random.randint(y0, y0 + self.crop_size[1] - kernel + 1)
        return crop_x, crop_y, kernel

    @staticmethod
    def _crop_and_normalize(frames, crop_x, crop_y, kernel):
        """Crop the last two axes and scale by the window's max absolute value.

        The scale is taken over the whole (uncropped) window. An all-zero window
        is left as zeros instead of becoming NaN.
        """
        max_norm = frames.abs().max()
        if max_norm == 0:
            max_norm = torch.ones_like(max_norm)
        cropped = frames[:, :, crop_y:crop_y + kernel, crop_x:crop_x + kernel]
        return cropped / max_norm

    def __getitem__(self, index):
        """Load clip ``index`` and return a random window and crop of it.

        Returns:
            dict with
              * ``"pixel_values"``: the last ``seq_length`` frames, scaled to
                [-1, 1] by the window's max absolute value, resized to ``out_size``;
              * ``"image"``: the first ``cond_length`` frames, mapped to [0, 1];
              * ``"dataset"``: the tag from ``__getname__``.
        """
        tic = time()
        event_pth = os.path.join(self.data_ls[index], self.data_name[index]) + '_voxel.npy'
        event_data = np.load(event_pth)
        toc = time()

        window_len = self.cond_length + self.seq_length
        start = self._sample_start_index(event_data.shape[0])
        crop_x, crop_y, kernel = self._sample_crop()

        # Indexing below assumes (frames, channels?, H, W) — TODO confirm channel axis.
        frames = torch.from_numpy(event_data[start:start + window_len])
        frames = self._crop_and_normalize(frames, crop_x, crop_y, kernel)
        frames = Resize(self.out_size, antialias=True)(frames)

        if self.verbose:
            print("Done:", f"Time: {toc - tic:.3f}s")

        return {
            "pixel_values": frames[self.cond_length:],
            "image": (frames[:self.cond_length] + 1) / 2,
            'dataset': self.__getname__(),
        }

if __name__ == '__main__':
    # Smoke test: load one item and report shapes, dataset size and wall time.
    # Usage: python <this file> DATA_PATH
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} DATA_PATH")

    tic = time()

    # Eventvot requires data_path; the original call omitted it and always raised TypeError.
    trainset = Eventvot(sys.argv[1])

    batch = trainset[0]
    event_tensor = batch['pixel_values']
    image = batch['image']
    print(event_tensor.shape, image.shape)
    print(len(trainset))
    toc = time()
    print("Done:", f"Time: {toc - tic:.3f}s")