# Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/VoxFormer/blob/main/LICENSE

#!/usr/bin/env python3
# This file is covered by the LICENSE file in the root of this project.

import argparse
import os

import numpy as np
from tqdm.contrib.concurrent import process_map

from collections import deque
import shutil
from numpy.linalg import inv
import time
import mapping
import torch
#F
import torch.nn.functional as F

import sys 
# import atomic_max_custom
# from scipy.stats import norm


# ray_voxel_intersection_root="/path/ray_voxel_intersection"
# # for i in range(0, 466615):
# ray_voxel_intersection_list=[]
# from tqdm import tqdm
# for i in tqdm(range(0, 466616)):
#   ray_voxel_intersection_file = os.path.join(ray_voxel_intersection_root, str(i).zfill(8) + ".npy")
#   ray_voxel_intersection = np.load(ray_voxel_intersection_file)
#   ray_voxel_intersection_list.append(ray_voxel_intersection)

# Rolling buffer of past frames used for temporal fusion by the ``__main__``
# block; entries are dicts with keys "output_tensor" and "pose".
# (Module scope is already global, so no ``global`` statement is needed.)
history = deque()


# Root folder holding the precomputed per-sequence ray/voxel intersection
# tables (``<seq>_array_list.pkl``). The default is a placeholder path; set the
# RAY_VOXEL_INTERSECTION_ROOT environment variable to override it without
# editing this file.
ray_voxel_intersection_root = os.environ.get(
    "RAY_VOXEL_INTERSECTION_ROOT",
    "/path/VoxFormer_UQ/ray_voxel_intersection/semantickitti/")
import pickle

# Module-level placeholders, initialised to None.
ray_voxel_intersection_list_ori = None
ray_voxel_intersection_list = None


a = 0  # not read anywhere in this file
  # print(ray_voxel_intersection)
  # print(np.max(ray_voxel_intersection))
  # print(np.min(ray_voxel_intersection))
  # print(np.mean(ray_voxel_intersection))
  # print(np.std(ray_voxel_intersection))
  # print(caculate_cdf(0, np.mean(ray_voxel_intersection), np.std(ray_voxel_intersection)))
 

def pack(array):
  """Pack a run of boolean flags into bytes, most-significant bit first.

  The input is flattened; flag ``8*k + j`` becomes bit ``7 - j`` of output
  byte ``k``. The packed bytes are returned as a ``uint8`` numpy array.
  """
  flags = array.reshape((-1))

  # Begin with the least-significant flag of every byte, then OR in the
  # remaining seven flags at increasing shifts.
  packed = flags[7::8]
  for shift in range(1, 8):
    packed = packed | (flags[7 - shift::8] << shift)

  return np.array(packed, dtype=np.uint8)

def parse_calibration(filename):
  """ read calibration file with given filename

      Each non-blank line has the form ``KEY: v0 v1 ... v11``; the twelve
      values fill the top 3x4 block (row-major) of a 4x4 matrix whose last
      row is [0, 0, 0, 1]. Blank lines are skipped.

      Returns
      -------
      dict
          Calibration matrices as 4x4 numpy arrays, keyed by KEY.
  """
  calib = {}

  # ``with`` guarantees the file is closed even if a malformed line raises.
  with open(filename) as calib_file:
    for line in calib_file:
      line = line.strip()
      if not line:  # tolerate blank / trailing empty lines
        continue
      # maxsplit=1: only the first ':' separates the key from the values.
      key, content = line.split(":", 1)
      values = [float(v) for v in content.split()]

      pose = np.zeros((4, 4))
      pose[0, 0:4] = values[0:4]
      pose[1, 0:4] = values[4:8]
      pose[2, 0:4] = values[8:12]
      pose[3, 3] = 1.0

      calib[key] = pose

  return calib


def parse_poses(filename, calibration):
  """ read poses file with per-scan poses from given filename

      Each non-blank line holds twelve values forming the top 3x4 block
      (row-major) of a 4x4 pose; the pose is then conjugated by the
      calibration's "Tr" matrix as ``inv(Tr) @ pose @ Tr``. Blank lines are
      skipped.

      Returns
      -------
      list
          list of poses as 4x4 numpy arrays, one per non-blank line.
  """
  poses = []

  # ``with`` closes the file (the previous version never closed it).
  with open(filename) as pose_file:
    Tr = calibration["Tr"]
    Tr_inv = inv(Tr)

    for line in pose_file:
      values = [float(v) for v in line.split()]
      if not values:  # tolerate blank / trailing empty lines
        continue

      pose = np.zeros((4, 4))
      pose[0, 0:4] = values[0:4]
      pose[1, 0:4] = values[4:8]
      pose[2, 0:4] = values[8:12]
      pose[3, 3] = 1.0

      poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))

  return poses






def _resolve_atomic_max():
  """Return ``atomic_max_custom.atomic_max`` if importable, else None.

  ``atomic_max_custom`` is an optional compiled extension (its top-level
  import is commented out in this file). Callers fall back to
  ``torch.Tensor.scatter_reduce_`` when it is unavailable.
  """
  try:
    import atomic_max_custom
  except ImportError:
    return None
  return atomic_max_custom.atomic_max


def _flat_voxel_index(volume, x_idx, y_idx, z_idx):
  """Row-major linear index of (x, y, z) into the 3-D ``volume``."""
  _, size_y, size_z = volume.shape
  return (x_idx * size_y + y_idx) * size_z + z_idx


def _scatter_max_(volume, x_idx, y_idx, z_idx, values, atomic_max_fn=None):
  """In place, set ``volume[x, y, z] = max(volume[x, y, z], value)`` per entry.

  Uses ``atomic_max_fn`` (argument order of ``atomic_max_custom.atomic_max``)
  when given. Otherwise it uses ``scatter_reduce_(reduce="amax",
  include_self=True)`` on a flat view, which keeps the current voxel value in
  the maximum. ``volume`` must be contiguous so ``view(-1)`` aliases its
  storage.
  """
  if atomic_max_fn is not None:
    atomic_max_fn(volume, x_idx, y_idx, z_idx, values)
    return
  flat_idx = _flat_voxel_index(volume, x_idx, y_idx, z_idx)
  volume.view(-1).scatter_reduce_(0, flat_idx, values, reduce="amax", include_self=True)


def _scatter_sum_(volume, x_idx, y_idx, z_idx, values):
  """In place, ``volume[x, y, z] += value`` per entry, accumulating duplicates.

  ``volume`` must be contiguous so ``view(-1)`` aliases its storage.
  """
  flat_idx = _flat_voxel_index(volume, x_idx, y_idx, z_idx)
  volume.view(-1).scatter_add_(0, flat_idx, values)


def _grid_to_ego_matrix():
  """Return the fixed 4x4 matrix conjugating ego-motion in the grid warp.

  The matrix is ``scale @ shift @ rot``:
    * ``rot``: a +90 degree rotation about z;
    * ``shift``: a -1 shift along x;
    * ``scale``: 25.6 on x/y and 3.2 (= 25.6 / 8) on z.

  25.6 and 3.2 are half of the 51.2 m x/y and 6.4 m z extents voxelised in
  ``__main__``. The matrix presumably maps ``affine_grid``'s normalised
  [-1, 1] coordinates to metres -- TODO confirm.
  """
  theta = np.identity(4)

  angle = np.pi / 2
  rot = np.identity(4)
  rot[0, 0] = np.cos(angle)
  rot[0, 1] = -np.sin(angle)
  rot[1, 0] = np.sin(angle)
  rot[1, 1] = np.cos(angle)
  theta = np.matmul(rot, theta)

  shift = np.identity(4)
  shift[0, 3] = -1
  theta = np.matmul(shift, theta)

  scale = np.identity(4) * 25.6
  scale[2, 2] /= 8
  scale[3, 3] = 1
  theta = np.matmul(scale, theta)
  return theta


def _warp_to_current(prev_volume, prev_pose, cur_pose, pre_mat, aft_mat, device):
  """Resample a past frame's (X, Y, Z) volume into the current frame.

  Args:
    prev_volume: numpy array with the past frame's voxel volume.
    prev_pose, cur_pose: 4x4 poses as returned by ``parse_poses``.
    pre_mat, aft_mat: ``_grid_to_ego_matrix()`` and its inverse.
    device: torch device on which the resampling runs.

  Returns:
    Tensor of shape ``prev_volume.shape`` on ``device``. It is produced by
    ``grid_sample`` with default (linear, zero-padding) settings and
    ``align_corners=True``.
  """
  volume = torch.tensor(prev_volume, device=device)
  # (X, Y, Z) -> (N=1, C=1, D=Z, H=X, W=Y), the 5-D layout grid_sample expects.
  grid_input = volume.permute(2, 0, 1)[None, None]

  # Relative motion past -> current, conjugated by the grid<->metric matrix
  # so that it acts on affine_grid coordinates.
  motion = np.linalg.inv(cur_pose) @ prev_pose
  theta = np.matmul(aft_mat, np.matmul(motion, pre_mat))
  theta = torch.from_numpy(theta[:3, :]).unsqueeze(0)
  theta = theta.to(dtype=grid_input.dtype, device=grid_input.device)

  grid = F.affine_grid(theta, grid_input.size(), align_corners=True)
  warped = F.grid_sample(grid_input, grid, align_corners=True)
  # (1, 1, Z, X, Y) -> (X, Y, Z)
  return warped[0, 0].permute(1, 2, 0)


if __name__ == '__main__':
  # Pipeline:
  #   1. Load the per-sequence ray/voxel intersection table onto the GPU.
  #   2. For every scan, spread each point's Gaussian uncertainty along its
  #      ray over the voxels the ray's segments fall in.
  #   3. Fuse key frames (frame id % 5 == 0) with warped past frames.
  #   4. Save key frames as .npy.
  start_time = time.time()

  parser = argparse.ArgumentParser("./lidar2voxel.py")
  parser.add_argument(
      '--dataset',
      '-d',
      type=str,
      required=True,
      help='dataset folder containing all sequences in a folder called "sequences".',
  )

  parser.add_argument(
      '--output',
      '-o',
      type=str,
      required=True,
      help='output folder for generated sequence scans.',
  )
  parser.add_argument(
      '--num_seq',
      '-n',
      type=str,
      required=True,
      help='number of sequence',
  )

  parser.add_argument(
      '--model',
      '-m',
      type=str,
      default='msnet3dmax1std',
      help='depth model',
  )

  parser.add_argument(
      '--sequence_length',
      '-s',
      type=int,
      default=1,
      help='length of sequence, i.e., how many scans are concatenated.',
  )

  FLAGS, _ = parser.parse_known_args()
  dataset = FLAGS.dataset
  output = FLAGS.output
  num_seq = FLAGS.num_seq

  print("*" * 80)
  print(" dataset folder: ", FLAGS.dataset)
  print("  output folder: ", FLAGS.output)
  print("sequence length: ", FLAGS.sequence_length)
  print("*" * 80)

  # Validate CLI-derived settings before the expensive table load. Use
  # parser.error rather than ``assert``, which ``python -O`` strips.
  if FLAGS.sequence_length < 2:
    parser.error('sequence length should be at least 2 (e.g. 10), got {}'.format(FLAGS.sequence_length))

  # How per-ray contributions are combined inside one voxel.
  if "max" in FLAGS.model:
    ops_type = "max"
  elif "sum" in FLAGS.model:
    ops_type = "sum"
  else:
    raise NotImplementedError(FLAGS.model)
  # Interval half-width in standard deviations, read from the 4th-from-last
  # character of the model name (e.g. "1" in "msnet3dmax1std").
  thr = int(FLAGS.model[-4])
  # Resolve the optional CUDA extension once. If it is missing, fall back to
  # scatter_reduce_ instead of hitting a NameError mid-run.
  atomic_max = _resolve_atomic_max() if ops_type == "max" else None

  table_path = os.path.join(ray_voxel_intersection_root, str(num_seq).zfill(2) + '_array_list.pkl')
  # NOTE: pickle.load can execute arbitrary code; only load trusted files.
  with open(table_path, 'rb') as table_file:
    ray_rows = pickle.load(table_file)
  # Row layout (see slicing below): [start xyz, direction xyz, t_0 .. t_k].
  # The build steps are:
  #   * repeat the last t;
  #   * append 100;
  #   * right-pad with 1 so every row has the same length (+2 spare columns).
  # The result is a dense float64 table (no object-dtype intermediate).
  # Assumes each pickled row is a Python list of numbers -- TODO confirm.
  ray_rows = [row + [row[-1]] + [100] for row in ray_rows]
  max_len = max(len(row) for row in ray_rows) + 2
  ray_table = torch.tensor(
      np.array([row + [1] * (max_len - len(row)) for row in ray_rows], dtype=np.float64),
      device='cuda')
  del ray_rows
  device = ray_table.device

  sequences_dir = os.path.join(dataset, "sequences")
  input_folder = os.path.join(sequences_dir, num_seq)
  pseudo_lidar_files = sorted(name for name in os.listdir(input_folder) if name.endswith(".bin"))
  output_folder = os.path.join(output, "sequences_" + FLAGS.model + "_sweep" + str(FLAGS.sequence_length), num_seq, "voxels")
  os.makedirs(output_folder, exist_ok=True)

  calibration = parse_calibration(os.path.join(input_folder, "calib.txt"))
  poses = parse_poses(os.path.join(input_folder, "poses.txt"), calibration)

  # Voxel grid: 256 x 256 x 32 cells of 0.2 m covering
  # x in [0, 51.2], y in [-25.6, 25.6], z in [-2, 4.4].
  voxel_size = (0.2, 0.2, 0.2)
  area_extents = np.array([[0, 51.2], [-25.6, 25.6], [-2., 4.4]])
  grid_shape = (256, 256, 32)
  # Loop-invariant warp matrices, computed once.
  pre_mat = _grid_to_ego_matrix()
  aft_mat = np.linalg.inv(pre_mat)

  from tqdm import tqdm
  for i, f in tqdm(enumerate(pseudo_lidar_files), total=len(pseudo_lidar_files)):
    frame_id = os.path.splitext(f)[0]
    is_key_frame = int(frame_id) % 5 == 0

    scan_filename = os.path.join(input_folder, f)
    scan = np.fromfile(scan_filename, dtype=np.float32).reshape((-1, 4))
    scan_uq = np.load(scan_filename.replace("lidar", "depth").replace(".bin", "_uq.npy"))
    pose = poses[i]

    # Keep points strictly inside the grid extents. The intersection table
    # is indexed by the same point indices as the scan.
    in_box = ((area_extents[0, 0] < scan[:, 0]) & (scan[:, 0] < area_extents[0, 1]) &
              (area_extents[1, 0] < scan[:, 1]) & (scan[:, 1] < area_extents[1, 1]) &
              (area_extents[2, 0] < scan[:, 2]) & (scan[:, 2] < area_extents[2, 1]))
    filter_idx = np.where(in_box)[0]

    pts = torch.tensor(scan[filter_idx], device=device)
    point_uq = torch.tensor(scan_uq.reshape(-1)[filter_idx], device=device)
    # The table is only read, so index it directly (no per-frame clone).
    ray_info = ray_table[torch.from_numpy(filter_idx).to(device)]

    st = ray_info[:, :3]
    direction = ray_info[:, 3:6]
    t_samples = ray_info[:, 6:]

    # Ray parameter of each point, solved from the x coordinate only.
    mean_t = (pts[:, 0] - st[:, 0]) / direction[:, 0]
    # Its std: exp(uq / 2) (uq presumably a log-variance), rescaled to ray units.
    std_t = torch.exp(point_uq / 2) / direction[:, 0]

    gaussian_dist = torch.distributions.Normal(loc=mean_t.unsqueeze(1), scale=std_t.unsqueeze(1))
    cdf_values = gaussian_dist.cdf(t_samples)
    # Probability mass of each segment [t_j, t_{j+1}] along the ray.
    cdf_diff = cdf_values[:, 1:] - cdf_values[:, :-1]

    # Only segments lying fully within mean +- thr * std (clipped to [0, 1]) count.
    upper_bound = (mean_t + thr * std_t).clamp(max=1).unsqueeze(1)
    lower_bound = (mean_t - thr * std_t).clamp(min=0).unsqueeze(1)
    seg_start = t_samples[:, :-1]
    seg_end = t_samples[:, 1:]
    mask = ((seg_start >= lower_bound) & (seg_start <= upper_bound) &
            (seg_end >= lower_bound) & (seg_end <= upper_bound))

    # Voxel holding each segment midpoint. Use floor, not truncation toward
    # zero, so coordinates just below a lower extent don't fold into voxel 0.
    mid_t = (seg_start + seg_end) / 2
    x = st[:, 0:1] + direction[:, 0:1] * mid_t
    y = st[:, 1:2] + direction[:, 1:2] * mid_t
    z = st[:, 2:3] + direction[:, 2:3] * mid_t
    x_index = torch.floor((x - area_extents[0, 0]) / voxel_size[0]).long()
    y_index = torch.floor((y - area_extents[1, 0]) / voxel_size[1]).long()
    z_index = torch.floor((z - area_extents[2, 0]) / voxel_size[2]).long()

    # Masked-out segments are dropped here rather than scattered as zeros.
    # A zero is a no-op for both "max" (volume stays >= 0) and "sum".
    valid = (mask &
             (x_index >= 0) & (x_index < grid_shape[0]) &
             (y_index >= 0) & (y_index < grid_shape[1]) &
             (z_index >= 0) & (z_index < grid_shape[2]))

    output_tensor = torch.zeros(grid_shape, dtype=torch.float64, device=device)
    selected = (x_index[valid], y_index[valid], z_index[valid], cdf_diff[valid])
    if ops_type == "max":
      _scatter_max_(output_tensor, *selected, atomic_max_fn=atomic_max)
    else:
      _scatter_sum_(output_tensor, *selected)
    torch.cuda.synchronize()

    if is_key_frame:
      # Fuse with every buffered past frame, warped into the current pose.
      for past in history:
        warped = _warp_to_current(past["output_tensor"], past["pose"], pose, pre_mat, aft_mat, device)
        output_tensor = torch.max(output_tensor, warped)

    output_tensor = output_tensor.cpu().numpy()
    # NOTE(review): key frames push their *fused* volume into ``history``.
    # Later key frames therefore see frames older than sequence_length.
    # Confirm this is intended; otherwise store the pre-fusion volume.
    history.appendleft({
        "output_tensor": output_tensor.copy(),
        "pose": pose.copy()
    })
    if is_key_frame:
      np.save(os.path.join(output_folder, frame_id + ".npy"), output_tensor)
    # Keep at most sequence_length - 1 past frames.
    if len(history) >= FLAGS.sequence_length:
      history.pop()
    torch.cuda.empty_cache()

  print("finished.")
  print("execution time: {}".format(time.time() - start_time))
  