import os
import pandas as pd
from torchvision.io import read_image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import csv
import matplotlib.pyplot as plt

class N_Caltech_Dataset(Dataset):
    """PyTorch ``Dataset`` that turns CSV event recordings into stacks of count frames.

    The annotations CSV (read with ``pd.read_csv``) supplies, per row, the event
    file name in column 0 (joined onto ``img_dir``) and the class label in
    column 1. Labels are remapped to contiguous ids ``0..K-1`` in order of first
    appearance, so a subset of the original classes still yields dense ids.

    Each item is ``(frames, label)``. ``frames`` is an int64 tensor of per-pixel
    event counts with shape ``(T, P, 240, 240)``, indexed as
    ``[frame, polarity, y, x]``. Here ``T <= num_of_frame``, each frame spans
    ``sample_window`` milliseconds, and ``P`` is at least 2.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None, sample_window=20, num_of_frame=10):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform  # applied to the frame tensor in __getitem__
        self.target_transform = target_transform  # applied to the remapped integer label
        self.sample_window = sample_window  # frame duration in milliseconds
        self.num_of_frame = num_of_frame  # maximum number of frames per sample

        # When using a subset of classes, map the original labels to new contiguous ids.
        self.label_mapping = {}
        self.remap_labels()

        # Keep only a multiple of 10 samples (drops up to 9 trailing rows).
        # NOTE(review): presumably so the dataset divides evenly into batches - confirm.
        remove_sub_ten = len(self.img_labels) // 10 * 10
        self.img_labels = self.img_labels.iloc[:remove_sub_ten]
        print("label amount:", len(self.img_labels))

    def __len__(self):
        return len(self.img_labels)

    def read_csv_event_data(self, file_name):
        """Parse one event recording stored as a CSV file with five rows.

        Expected layout, one row each:
            0: max_x, max_y, total_events
            1: x coordinates
            2: y coordinates
            3: polarity (sign) values
            4: timestamps

        Rows after the fifth are ignored. Missing rows yield empty lists / zeros.

        Returns:
            dict with keys ``event_x``, ``event_y``, ``event_sign``,
            ``event_time`` (lists of int) and ``total_events``, ``max_x``,
            ``max_y`` (ints taken from the header row).
        """
        max_x = max_y = total_events = 0
        # Rows 1-4 of the file, in order: x, y, polarity (sign), timestamp.
        event_rows = [[], [], [], []]

        with open(file_name, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            for row_cnt, row in enumerate(reader):
                if row_cnt == 0:
                    max_x = int(row[0])
                    max_y = int(row[1])
                    total_events = int(row[2])
                elif row_cnt <= 4:
                    event_rows[row_cnt - 1] = [int(value) for value in row]
                else:
                    break  # nothing after the fifth row is used

        events_x, events_y, events_sign, events_time = event_rows
        data_dict = {'event_x': events_x, 'event_y': events_y, "event_sign": events_sign, "event_time": events_time, "total_events": total_events, "max_x": max_x, "max_y": max_y}
        return data_dict

    def remap_labels(self):
        """Assign contiguous integer ids to labels (column 1), in order of first appearance.

        Labels already present in ``self.label_mapping`` keep their id. New
        labels get the next unused id.
        """
        for label in pd.unique(self.img_labels.iloc[:, 1]):
            if label not in self.label_mapping:
                self.label_mapping[label] = len(self.label_mapping)

    def get_time_slices(self, data_dict, sample_window):
        """Split the event stream into consecutive, non-overlapping time windows.

        Windows are half-open intervals ``[t0 + k*w, t0 + (k+1)*w)``, where ``t0``
        is the first timestamp and ``w = sample_window * 1000`` (milliseconds to
        the timestamp unit, presumably microseconds). Only windows ending
        strictly before the last timestamp are produced, so the trailing partial
        window is dropped.

        Assumes ``data_dict['event_time']`` is sorted ascending, as required by
        ``np.searchsorted``.

        Returns:
            ``(window_first_index, window_last_index)``: integer arrays of equal
            length. Window ``k`` covers events ``[first[k], last[k])``.

        Raises:
            ValueError: if ``sample_window`` is not positive (the window loop would
                never terminate).
        """
        event_time = data_dict['event_time']
        if len(event_time) == 0:
            empty = np.zeros(0, dtype=np.int64)
            return empty, empty.copy()

        step = sample_window * 1000  # to uS
        if step <= 0:
            raise ValueError("sample_window must be positive")

        sequence_end_time = event_time[-1]
        window_ending = []
        current_window_end = event_time[0] + step
        while current_window_end < sequence_end_time:
            window_ending.append(current_window_end)
            current_window_end += step

        window_last_index = np.searchsorted(event_time, window_ending)
        # Each window starts exactly where the previous one stopped. The end index
        # is exclusive, so starting at ``end + 1`` would drop the first event of
        # every later window (and give start > end for an empty window).
        window_first_index = np.concatenate(([0], window_last_index))[:-1]

        return window_first_index, window_last_index

    def to_dense_tensor(self, data_dict, sample_window=20, output_xy_shape=240, per_seq_len=10, num_polarities=2):
        """Accumulate events into a stack of per-window count frames.

        Args:
            data_dict: event dict as returned by ``read_csv_event_data``.
            sample_window: window length in milliseconds.
            output_xy_shape: side length of the square output frame. Events whose
                x or y lies outside ``[0, output_xy_shape)`` are dropped (cropped).
            per_seq_len: maximum number of frames to keep (the earliest windows).
            num_polarities: minimum size of the polarity axis. This keeps the
                channel count the same for every frame even when a window holds
                only one polarity. It is enlarged if the data contains larger values.

        Returns:
            int64 tensor of shape ``(T, P, output_xy_shape, output_xy_shape)``,
            indexed as ``[frame, polarity, y, x]``, with
            ``T = min(number of windows, per_seq_len)``.

        Raises:
            ValueError: if any polarity value is negative.
        """
        window_start, window_end = self.get_time_slices(data_dict, sample_window)
        if len(window_start) < per_seq_len:
            # Fewer full windows than requested: the result will have fewer frames.
            print("The number of frame per input datapoint is too big")
        window_start = window_start[:per_seq_len]
        window_end = window_end[:per_seq_len]

        xs = np.asarray(data_dict['event_x'], dtype=np.int64)
        ys = np.asarray(data_dict['event_y'], dtype=np.int64)
        ps = np.asarray(data_dict['event_sign'], dtype=np.int64)
        if ps.size and ps.min() < 0:
            raise ValueError("event_sign values must be non-negative polarity indices")
        n_pol = max(num_polarities, int(ps.max()) + 1) if ps.size else num_polarities

        # Built as [frame, x, y, polarity], then transposed to [frame, polarity, y, x].
        frames = torch.zeros((len(window_start), output_xy_shape, output_xy_shape, n_pol), dtype=torch.int64)
        for i, (start, end) in enumerate(zip(window_start, window_end)):
            x, y, p = xs[start:end], ys[start:end], ps[start:end]
            inside = (x >= 0) & (x < output_xy_shape) & (y >= 0) & (y < output_xy_shape)
            index = (torch.from_numpy(x[inside]), torch.from_numpy(y[inside]), torch.from_numpy(p[inside]))
            # accumulate=True sums repeated (x, y, p) hits into an event count.
            frames[i].index_put_(index, torch.ones(int(inside.sum()), dtype=torch.int64), accumulate=True)
        return frames.transpose(3, 1)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(frames, label)``.

        ``frames`` comes from ``to_dense_tensor``. ``transform`` (if given) is
        applied to it, and ``target_transform`` (if given) is applied to the
        remapped integer label.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        events = self.read_csv_event_data(img_path)
        image = self.to_dense_tensor(events, sample_window=self.sample_window, per_seq_len=self.num_of_frame)
        if self.transform:
            image = self.transform(image)
        label = self.label_mapping[self.img_labels.iloc[idx, 1]]
        if self.target_transform:
            label = self.target_transform(label)
        return image, label



