import os
import pdb
# import csv
import numpy as np
import cv2
import torch
import pandas as pd
from tqdm import tqdm
import scipy.io
from spconv.pytorch.utils import PointToVoxel
from dv import AedatFile
import numpy as np


def transform_points_to_voxels(data_dict={}, voxel_generator=None, device=torch.device("cuda:0")):
    """
    将点云转换为voxel,调用spconv的VoxelGeneratorV2
    """
    points = data_dict['points']
    # 将points打乱
    shuffle_idx = np.random.permutation(points.shape[0])
    points = points[shuffle_idx]
    data_dict['points'] = points

    # 使用spconv生成voxel输出
    points = torch.as_tensor(data_dict['points']).to(device)
    voxel_output = voxel_generator(points)

    # 假设一份点云数据是N*4，那么经过pillar生成后会得到三份数据
    # voxels代表了每个生成的voxel数据，维度是[M, 5, 4]
    # coordinates代表了每个生成的voxel所在的zyx轴坐标，维度是[M,3]
    # num_points代表了每个生成的voxel中有多少个有效的点维度是[m,]，因为不满5会被0填充
    voxels, coordinates, num_points = voxel_output
    voxels = voxels.to(device)
    coordinates = coordinates.to(device)
    num_points = num_points.to(device)
    # 选event数量在前5000的voxel  8000 from(4k+,6k+)
    # print(torch.where(num_points>=16)[0].shape)
    if num_points.shape[0] < save_voxel:
        features = voxels[:, :, 3]
        coor = coordinates[:, :]
    else:
        _, voxels_idx = torch.topk(num_points, save_voxel)
        # 将每个voxel的1024个p拼接作为voxel初始特征   16
        features = voxels[voxels_idx][:, :, 3]
        # 前5000个voxel的三维坐标
        coor = coordinates[voxels_idx]
    # 将y.x.t改为t,x,y
    coor[:, [0, 1, 2]] = coor[:, [2, 1, 0]]

    return coor, features


if __name__ == '__main__':
    save_voxel = 5000
    use_mode = 'frame_exposure_time'
    device = torch.device("cuda:0")
    # data_path = r"/home/ioe/xxxxx/EventTracking/COESOT/train"
    data_path = r"/media/user/a3a55be1-b252-4624-aa4f-fbb9404c4a7b/Mamba_FETrack/Mamba_FETrack-main/Mamba_FETrack/data/FELT/test"
    save_path = r"/media/user/a3a55be1-b252-4624-aa4f-fbb9404c4a7b/Mamba_FETrack/Mamba_FETrack-main/Mamba_FETrack/data/FELT/test"
    # save_path = r"/home/ioe/xxxxx/COESOT/train"
    video_files = os.listdir(data_path)
    dvs_img_interval = 1
    voxel_generator = PointToVoxel(
        # 给定每个voxel的长宽高  [0.05, 0.05, 0.1]
        vsize_xyz=[50, 10, 10],  # [0.2, 0.25, 0.16]  # [50, 10, 10]  [50, 35, 26] 因此坐标范围（20,20,20）  (20, 34/35, 26)
        # 给定点云的范围 [  0.  -40.   -3.   70.4  40.    1. ]
        coors_range_xyz=[0, 0, 0, 1000, 345, 259],
        # 给定每个点云的特征维度，这里是x，y，z，r 其中r是激光雷达反射强度       # 346x260  t,x,y
        num_point_features=4,
        # 最多选取多少个voxel，训练16000，推理40000
        max_num_voxels=16000,  # 16000
        # 给定每个pillar中有采样多少个点，不够则补0  因此我将neg voxel改为-1;
        max_num_points_per_voxel=16,  # 1024
        device=device
    )
      
    for videoID in range(len(video_files)):
        foldName = video_files[videoID]
        print(f'[{videoID+1}/{len(video_files)}]  start process the foldName :{foldName}')
        fileLIST = os.listdir(os.path.join(data_path, foldName))
        
        if not os.path.exists(os.path.join(save_path, foldName, foldName+'_voxel')):
            os.mkdir(os.path.join(save_path, foldName, foldName+'_voxel'))
        # else:
        #     print(f'[{videoID+1}/{len(video_files)}] ==> skip the foldName :{foldName}')
        #     continue
        mat_save = os.path.join(save_path, foldName, foldName+'_voxel/')
        
        for FileID in range(len(fileLIST)):
            videoName = fileLIST[FileID]
            if 'aps' in videoName: 
                aps_path = os.path.join(save_path, foldName, videoName)  
                frame_num = len(os.listdir(aps_path))  
            if 'voxel' in videoName: 
                voxel_path = os.path.join(save_path, foldName, videoName) 
                voxel_num = len(os.listdir(voxel_path)) 
      
        # if voxel_num == frame_num:
        #     print(foldName)
        #     continue
        # else:
        #     continue
        
        frame_densities = [] # 初始化用于存储密度的列表
        aedat4_file = foldName + '.aedat4'
        read_path = os.path.join(data_path, foldName, aedat4_file)

        # read aeda4;
        frame_all = []
        frame_exposure_time = []
        frame_interval_time = []
        with AedatFile(read_path) as f:
            # print(f.names)
            for frame in f['frames']:
                frame_all.append(frame.image)
                frame_exposure_time.append([frame.timestamp_start_of_exposure,
                                            frame.timestamp_end_of_exposure])  ## [1607928583397102, 1607928583401102]
                frame_interval_time.append([frame.timestamp_start_of_frame,
                                            frame.timestamp_end_of_frame])  ## [1607928583387944, 1607928583410285]
        if use_mode == 'frame_exposure_time':
            frame_timestamp = frame_exposure_time
        elif use_mode == 'frame_interval_time':
            frame_timestamp = frame_interval_time
        frame_num = len(frame_timestamp)

        events = np.hstack([packet for packet in f['events'].numpy()])
        height, width = f['events'].size # 获取图像高度和宽度

        t_all = torch.tensor(events['timestamp']).unsqueeze(1).to(device)
        x_all = torch.tensor(events['x']).unsqueeze(1).to(device)
        y_all = torch.tensor(events['y']).unsqueeze(1).to(device)
        p_all = torch.tensor(events['polarity']).unsqueeze(1).to(device)

        for frame_no in range(0, int(frame_num / dvs_img_interval) - 1):
            start_idx_list = np.where(events['timestamp'] >= frame_timestamp[frame_no][0])[0]
            end_idx_list = np.where(events['timestamp'] >= frame_timestamp[frame_no][1])[0]

            # 处理边界情况，如果找不到时间戳对应的事件
            if len(start_idx_list) == 0 or len(end_idx_list) == 0:
                start_idx = 0
                end_idx = 0
                idx_length = 0
                sub_event = np.array([]) # 确保 sub_event 是空数组
            else:
                start_idx = start_idx_list[0]
                end_idx = end_idx_list[0]
                sub_event = events[start_idx:end_idx]
                idx_length = end_idx - start_idx

            # 计算事件密度
            density = idx_length / (height * width) if (height * width) > 0 else 0
            frame_densities.append(density)

            # t = t_all[start_idx: end_idx]
            # if start_idx == end_idx:
            #     time_length = 0
            #     t = ((t-t).float() / time_length) * 1000
            #     scipy.io.savemat(mat_save + 'frame{:0>4d}.mat'.format(frame_no), mdict={'coor': np.zeros([100, 3]),
            #                      'features': np.zeros([100, 16])})  # coor: numpy.nan;   features: numpy.nan
            #     print('empty event frame ', frame_no)
            #     continue
            # else:
            #     time_length = t[-1] - t[0]
            #     # rescale the timestampes to start from 0 up to 1000
            #     t = ((t-t[0]).float() / time_length) * 1000
            # all_idx = np.where(sub_event['polarity'] != -1)      # all event
            # neg_idx = np.where(sub_event['polarity'] == 0)      # neg event
            # t = t[all_idx]
            # x = x_all[all_idx]
            # y = y_all[all_idx]
            # p = p_all[all_idx]
            # p[neg_idx] = -1     # negtive voxel change from 0 to -1. because after append 0 operation.
            # current_events = torch.cat((t, x, y, p), dim=1)
            # # if current_events.shape[0] < 10:   # remove it
            # #     continue
            # data_dict = {'points': current_events}

            # coor, features = transform_points_to_voxels(data_dict=data_dict, voxel_generator=voxel_generator,
            #                                             device=device)
            # # pdb.set_trace()
            # coor = coor.cpu().numpy()
            # features = features.cpu().numpy()
            # # print('coor', coor)
            # scipy.io.savemat(mat_save + 'frame{:0>4d}.mat'.format(frame_no), mdict={'coor': coor, 'features': features})   # coor: Nx(t,x,y);   features:Nx32 or Nx10024
        
        # 保存事件密度到文件
        density_save_path = os.path.join(save_path, foldName, f'{foldName}.txt')
        with open(density_save_path, 'w') as density_file:
            for i, density in enumerate(frame_densities):
                density_file.write(f"{density:.6f}\n")
        print(f"Saved event densities to {density_save_path}")
