import numpy as np
import matplotlib.pyplot as plt
from dv import LegacyAedatFile
import os

# Directory prefix of the dataset. It is concatenated directly with the
# recording name (path + filename + '.aedat' / '_labels.csv'), so it must end
# with a path separator.
path = '/home/zhangwenrui/datasets/DvsGesture/'

# Largest time-bin index written to any sample so far; updated by splitData
# and printed once all recordings have been processed.
max_t = 0

# Gesture names, looked up by the zero-based class id (the first column of the
# labels file minus one). The order of the entries is therefore significant.
actionName = [
    'hand_clapping',
    'right_hand_wave',
    'left_hand_wave',
    'right_arm_clockwise',
    'right_arm_counter_clockwise',
    'left_arm_clockwise', 
    'left_arm_counter_clockwise',
    'arm_roll',
    'air_drums',
    'air_guitar',
    'other_gestures',
]

def readAedatEvent(filename):
    """Load every event of a legacy AEDAT recording into four parallel lists.

    Args:
        filename: Path of the recording, passed to ``LegacyAedatFile``.

    Returns:
        A tuple ``(x, y, polarity, t)`` of equally long lists holding one
        entry per event, in the order the file yields them. ``t`` is the
        event's ``timestamp`` divided by 1000 (presumably microseconds
        converted to milliseconds -- confirm against the ``dv`` docs).
    """
    columns = [], [], [], []
    addX, addY, addPolarity, addTime = (column.append for column in columns)

    with LegacyAedatFile(filename) as recording:
        for ev in recording:
            addX(ev.x)
            addY(ev.y)
            addPolarity(ev.polarity)
            addTime(ev.timestamp / 1000)

    return columns

def _binEvents(xs, ys, ps, ts, nBins=400, sensorSize=128):
    """Collapse events onto a (time bin, x, y, polarity) grid and list the occupied cells.

    Events that fall into the same cell are merged into a single entry.

    Args:
        xs: Event x coordinates; used as indices into an axis of length
            ``sensorSize``.
        ys: Event y coordinates; used as indices into an axis of length
            ``sensorSize``.
        ps: Event polarities; used as indices into an axis of length 2.
        ts: Non-negative event times expressed in time bins. The fractional
            part is truncated to get the bin index, which must be below
            ``nBins``.
        nBins: Number of time bins of the grid.
        sensorSize: Extent of the grid along x and along y.

    Returns:
        Integer array of shape ``(4, N)`` whose rows are x, y, polarity and
        time bin, sorted by time bin, then x, then y, then polarity.

    Raises:
        IndexError: If any index lies outside the grid.
    """
    occupied = np.zeros((nBins, sensorSize, sensorSize, 2), dtype=bool)
    # astype(int) is required for the polarity in particular: a boolean index
    # array would be interpreted as a mask rather than as 0/1 positions.
    occupied[ts.astype(int), xs.astype(int), ys.astype(int), ps.astype(int)] = True

    # np.nonzero reports cells in row-major order, so keeping time on the
    # first axis yields the (t, x, y, p) ordering without an explicit sort.
    tt, xx, yy, pp = np.nonzero(occupied)
    return np.stack([xx, yy, pp, tt])


def splitData(filename, path, index):
    """Cut one recording into per-gesture samples and write them under ``data/``.

    Reads ``path + filename + '.aedat'`` and ``path + filename + '_labels.csv'``.
    For every labelled gesture the events of (at most) the first ``windowLen``
    time units are binned and saved as a 4-row array (x, y, polarity, time
    bin) to ``data/train/`` when ``index < 24`` and to ``data/test/``
    otherwise. The gesture name and its labelled duration are printed, and the
    module-level ``max_t`` is raised to the largest time-bin index written.

    Args:
        filename: Recording name without extension.
        path: Directory prefix; concatenated as-is, so it must end with a
            path separator.
        index: Number used only to choose between the train and test folders.

    Raises:
        AssertionError: If a labelled interval contains no events.
        IndexError: If an event coordinate lies outside the 128x128 grid.
    """
    global max_t

    # Times below are in the unit produced by readAedatEvent (timestamp/1000,
    # presumably milliseconds -- confirm).
    windowLen = 6000                 # longest stretch of a gesture that is kept
    binLen = 15                      # width of one time bin
    nBins = windowLen // binLen      # 400 bins, so every bin index is < 400

    x, y, p, t = (np.array(column)
                  for column in readAedatEvent(path + filename + '.aedat'))

    # ndmin=2 keeps a single-row label file two-dimensional; without it
    # loadtxt returns a 1-D array and the column indexing below fails.
    labels = np.loadtxt(path + filename + '_labels.csv',
                        delimiter=',', skiprows=1, ndmin=2)
    labels[:, 0] -= 1  # make the class id zero-based so it indexes actionName

    os.makedirs('data/train/', exist_ok=True)
    os.makedirs('data/test/', exist_ok=True)
    subset = 'train' if index < 24 else 'test'

    lastAction = None
    for action, tst, ten in labels:
        # A label that repeats the previous one is skipped; this drops the
        # second arm_roll segment of each recording.
        if lastAction is not None and action == lastAction:
            continue

        # Bring the label times to the same unit as the event times.
        tst = int(tst / 1000)
        ten = int(ten / 1000)
        print(actionName[int(action)])

        ind = (t >= tst) & (t < min(ten, tst + windowLen))
        assert ind.any(), 'no events inside the labelled interval'
        ts = (t[ind] - tst) / binLen
        assert (ts >= 0).all()
        print(ten - tst)

        sample = _binEvents(x[ind], y[ind], p[ind], ts, nBins=nBins)
        if sample.shape[1]:
            max_t = max(max_t, int(sample[3].max()))
        lastAction = action

        # NOTE: despite the '.npy' suffix the file is plain text written by
        # np.savetxt, so it must be read back with np.loadtxt, not np.load.
        # Kept as-is so existing consumers of these files keep working.
        dest = 'data/' + subset + '/' + filename + '{:g}.npy'.format(action)
        np.savetxt(dest, sample)


if __name__ == '__main__':
    # Recordings are named 'userNN_<lighting>'. Every combination of user
    # 1..29 and lighting condition is tried; combinations without an .aedat
    # file under `path` are silently skipped.
    lightingConditions = (
        'fluorescent',
        'fluorescent_led',
        'lab',
        'led',
        'natural',
    )

    processed = 0
    for userId in np.arange(1, 30):
        for condition in lightingConditions:
            recording = 'user{:02d}_{}'.format(userId, condition)
            if not os.path.isfile(path + recording + '.aedat'):
                continue

            print(processed, recording)
            # The user number doubles as the train/test selector in splitData.
            splitData(recording, path, userId)
            processed += 1

    # Largest time-bin index encountered over all written samples.
    print(max_t)
