import os
import pdb
import shutil

# import csv
import numpy as np
import cv2
import torch
import pandas as pd
from tqdm import tqdm
import scipy.io
from spconv.pytorch.utils import PointToVoxel
from dv import AedatFile
import numpy as np


def transform_points_to_voxels(data_dict=None, voxel_generator=None, device=torch.device("cuda:0"),
                               max_voxels=None):
    """
    Voxelize a point set and keep the most populated voxels.

    Args:
        data_dict: mapping with a 'points' entry (array indexable by a NumPy
            permutation). The entry is replaced in place by a randomly
            shuffled copy of the points.
        voxel_generator: callable taking a points tensor and returning
            ``(voxels, coordinates, num_points)`` (the script in this file
            builds a spconv ``PointToVoxel`` for this purpose).
        device: torch device the points and generator outputs are moved to.
        max_voxels: maximum number of voxels to keep. When ``None`` the
            module-level ``save_voxel`` value is used, which is the original
            behaviour.

    Returns:
        ``(coor, features)``: the coordinates of the kept voxels with their
        three columns reversed, and feature channel 3 of every point slot of
        the kept voxels.

    Raises:
        KeyError: if ``data_dict`` has no 'points' entry.
        NameError: if ``max_voxels`` is ``None`` and no global ``save_voxel``
            has been defined.
    """
    # Avoid a shared mutable default; an empty dict still raises KeyError below.
    if data_dict is None:
        data_dict = {}
    if max_voxels is None:
        # Backward-compatible fallback to the global set by the script section.
        max_voxels = save_voxel

    # Shuffle the points and store the shuffled order back into the dict.
    points = data_dict['points']
    shuffle_idx = np.random.permutation(points.shape[0])
    points = points[shuffle_idx]
    data_dict['points'] = points

    # Run the voxel generator on the target device.
    voxel_output = voxel_generator(torch.as_tensor(points).to(device))

    # voxels:      per-voxel point slots  (presumably [M, max_points, C] - TODO confirm)
    # coordinates: per-voxel grid indices (presumably [M, 3] - TODO confirm)
    # num_points:  number of valid points in each voxel; unused slots are padded
    voxels, coordinates, num_points = voxel_output
    voxels = voxels.to(device)
    coordinates = coordinates.to(device)
    num_points = num_points.to(device)

    if num_points.shape[0] < max_voxels:
        # Fewer voxels than the budget: keep all of them.
        features = voxels[:, :, 3]
        coor = coordinates
    else:
        # Keep the max_voxels voxels that contain the most points.
        _, voxels_idx = torch.topk(num_points, max_voxels)
        # Channel 3 of every point slot is the initial voxel feature
        # (presumably the event polarity - TODO confirm).
        features = voxels[voxels_idx][:, :, 3]
        coor = coordinates[voxels_idx]

    # Reverse the coordinate columns (the original note says y,x,t -> t,x,y).
    # Indexing builds a new tensor, so the generator's output is not modified.
    coor = coor[:, [2, 1, 0]]
    return coor, features


def count_events_in_window(timestamps, t_start, t_end):
    """
    Count the events between the first timestamp >= t_start and the first
    timestamp >= t_end.

    Args:
        timestamps: 1-D NumPy array of event timestamps.
        t_start: start timestamp of the window.
        t_end: end timestamp of the window.

    Returns:
        The index distance between the two first matches as an int, or 0 when
        either bound has no matching event (for example when the window ends
        after the last recorded event).
    """
    start_hits = np.where(timestamps >= t_start)[0]
    end_hits = np.where(timestamps >= t_end)[0]
    if len(start_hits) == 0 or len(end_hits) == 0:
        return 0
    return int(end_hits[0] - start_hits[0])


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    # Read as a module-level global by transform_points_to_voxels.
    save_voxel = 10000
    device = torch.device("cuda:0")
    data_path = r"/media/user/a3a55be1-b252-4624-aa4f-fbb9404c4a7b/Mamba_FETrack/Mamba_FETrack-main/Mamba_FETrack/data/FE108/test"
    save_path = r"/media/user/a3a55be1-b252-4624-aa4f-fbb9404c4a7b/Mamba_FETrack/Mamba_FETrack-main/Mamba_FETrack/data/FE108/test"
    # data_path = r"/home/ioe/xxxxx/EventTracking/FE108/train"
    # save_path = r"/home/ioe/xxxxx/FE108/train"
    video_files = os.listdir(data_path)
    dvs_img_interval = 1
    # Not used by the density pass below; kept as the generator intended for
    # transform_points_to_voxels.
    voxel_generator = PointToVoxel(
        # Size of one voxel along each axis.
        vsize_xyz=[50, 10, 10],
        # Coordinate range covered by the voxel grid (min xyz, then max xyz).
        coors_range_xyz=[0, 0, 0, 1000, 345, 259],
        # Number of feature channels per point
        # (presumably t, x, y, polarity for 346x260 events - TODO confirm).
        num_point_features=4,
        # Upper bound on the number of generated voxels.
        max_num_voxels=40000,
        # Points sampled per voxel; voxels with fewer points are padded.
        max_num_points_per_voxel=16,
        device=device
    )

    # pair.txt maps a sequence folder name to a start-frame value. It is the
    # same for every sequence, so it is parsed once instead of once per folder.
    match_file = 'pair.txt'
    pair = {}
    with open(match_file, 'r') as pair_f:
        for line in pair_f:
            if not line.strip():
                continue  # tolerate blank lines
            seq_name, first_frame = line.split()
            pair[seq_name] = int(first_frame) + 1

    for foldName in video_files:
        print("==>> foldName: ", foldName)

        # Make sure the per-sequence output folders exist.
        mat_save = os.path.join(save_path, foldName, 'voxel/')
        save_path_aps = os.path.join(save_path, foldName, 'aps/')
        save_path_dvs = os.path.join(save_path, foldName, 'dvs/')
        for out_dir in (mat_save, save_path_aps, save_path_dvs):
            os.makedirs(out_dir, exist_ok=True)

        aedat4_file = 'events.aedat4'
        # aedat4_file = f"{foldName}.aedat4"
        print('filename:', aedat4_file)
        read_path = os.path.join(data_path, foldName, aedat4_file)

        start_frame = pair[foldName]
        img_path = os.path.join(data_path, foldName, 'img')
        frame_end = len(os.listdir(img_path))

        # Read everything needed from the recording while the file is open;
        # the streams cannot be read once the context manager has closed it.
        frame_interval_time = []
        with AedatFile(read_path) as f:
            for frame in f['frames']:
                # e.g. [1607928583387944, 1607928583410285]
                frame_interval_time.append([frame.timestamp_start_of_frame,
                                            frame.timestamp_end_of_frame])
            events = np.hstack(list(f['events'].numpy()))
            height, width = f['events'].size  # sensor resolution
        frame_timestamp = frame_interval_time

        timestamps = events['timestamp']
        num_pixels = height * width

        begin_frame = start_frame - 1
        end_frame = frame_end + start_frame - 1
        frame_densities = []  # one density value per processed frame
        for frame_no in range(begin_frame, end_frame):
            t_start, t_end = frame_timestamp[frame_no]
            idx_length = count_events_in_window(timestamps, t_start, t_end)
            # Event density = events in the frame window per sensor pixel.
            density = idx_length / num_pixels if num_pixels > 0 else 0
            frame_densities.append(density)

        # Save the densities, one value per line.
        density_save_path = os.path.join(save_path, foldName, f'{foldName}.txt')
        with open(density_save_path, 'w') as density_file:
            density_file.writelines(f"{density:.6f}\n" for density in frame_densities)
        print(f"Saved event densities to {density_save_path}")
