import csv
import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from os import listdir
from os.path import isfile, join


class DVS_Gesture(Dataset):
    """Dataset of DVS Gesture event recordings stored as per-label text files.

    Expects ``dataset_path/<label>/`` sub-directories named ``0`` ..
    ``n_classes - 1``; every regular file inside one is a sample carrying
    that label. Each file is parsed with ``np.loadtxt`` and its first four
    columns are read, per row, as x, y, polarity and time-step of one event.
    """

    def __init__(self, dataset_path, n_steps, n_classes=11):
        """Index all sample files (no event data is loaded here).

        Args:
            dataset_path: root directory holding one sub-directory per label.
            n_steps: number of time bins; events whose time-step falls
                outside ``[0, n_steps)`` are dropped in ``__getitem__``.
            n_classes: number of label sub-directories to scan (default 11).

        Raises:
            FileNotFoundError: if a label sub-directory is missing.
        """
        self.path = dataset_path
        self.n_steps = n_steps
        self.samples = []
        self.labels = []
        for label in range(n_classes):
            class_dir = join(self.path, str(label))
            # listdir order is filesystem-dependent; sorting makes sample
            # indices (and any random_split on top of them) reproducible.
            for name in sorted(listdir(class_dir)):
                filename = join(class_dir, name)
                if isfile(filename):
                    self.samples.append(filename)
                    self.labels.append(label)

    def __getitem__(self, index):
        """Load sample ``index`` as a dense binary spike tensor.

        Returns:
            ``(data, label)`` where ``data`` is a float32 tensor of shape
            ``(2, 128, 128, n_steps)`` holding 1 at every ``[p, x, y, t]``
            that has an event and 0 elsewhere.

        Raises:
            IndexError: if an event's x/y/polarity lies outside the tensor.
        """
        data_path = self.samples[index]
        label = self.labels[index]
        data = np.zeros((2, 128, 128, self.n_steps), dtype=np.float32)

        # ndmin=2 keeps a one-event file as a (1, C) array instead of
        # collapsing it to 1-D, which would break the column indexing below.
        events = np.loadtxt(data_path, ndmin=2)
        if events.size:
            # astype truncates toward zero, exactly like int() per element.
            xs = events[:, 0].astype(np.int64)
            ys = events[:, 1].astype(np.int64)
            ps = events[:, 2].astype(np.int64)
            ts = events[:, 3].astype(np.int64)

            in_window = (ts >= 0) & (ts < self.n_steps)
            data[ps[in_window], xs[in_window], ys[in_window], ts[in_window]] = 1
            # Flag files with negative time-steps once instead of per event.
            if (ts < 0).any():
                print(data_path)

        return torch.from_numpy(data), label

    def __len__(self):
        """Return the number of indexed sample files."""
        return len(self.samples)


def get_dvs_gesture(data_path, network_config, generator=None):
    """Build the DVS Gesture datasets plus a half/half split of the train set.

    Args:
        data_path: root directory expected to contain ``train/`` and
            ``test/`` sub-directories in the layout ``DVS_Gesture`` reads.
            The root is created if missing, but nothing is downloaded.
        network_config: mapping from which ``'n_steps'`` is read.
        generator: optional ``torch.Generator`` forwarded to
            ``random_split`` so the train/validation split is reproducible;
            ``None`` keeps torch's default global RNG (original behavior).

    Returns:
        ``(all_train_set, train_set, val_set, test_set)``, where
        ``train_set`` and ``val_set`` partition ``all_train_set`` into
        ``len // 2`` samples and the remainder.
    """
    print("loading DVS Gestures")
    # exist_ok avoids the check-then-create race of exists()+mkdir().
    os.makedirs(data_path, exist_ok=True)
    n_steps = network_config['n_steps']
    all_train_set = DVS_Gesture(os.path.join(data_path, 'train'), n_steps)
    test_set = DVS_Gesture(os.path.join(data_path, 'test'), n_steps)

    n_total = len(all_train_set)
    n_half = n_total // 2
    split_kwargs = {} if generator is None else {"generator": generator}
    train_set, val_set = torch.utils.data.random_split(
        all_train_set, [n_half, n_total - n_half], **split_kwargs
    )
    print(f"training samples loaded: {len(all_train_set)}")
    print(f"testing samples loaded: {len(test_set)}")
    print(f"arch training samples loaded: {len(train_set)}")
    print(f"arch validating samples loaded: {len(val_set)}")
    return all_train_set, train_set, val_set, test_set
