import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
import copy
import math
import os

# ---------------------------------------------------------------------------
# Ground removal (RANSAC) + DBSCAN clustering over a sliding window of frames.
#
# For every frame index >= FIRST_FRAME of every sequence, the script merges the
# 'global_points' of the current frame and the (NUM_FRAMES - 1) preceding
# frames, marks ground points, clusters the remaining points and writes:
#   * <SINGLE_FRAME_OUT_ROOT>/<seq>/<frame>.npy - labels of the current frame
#   * <WINDOW_OUT_ROOT>/<seq>/<frame>.npy       - labels of the merged window
# Labels are float64; -1 means ground, noise, or a cluster that was too small.
# ---------------------------------------------------------------------------

INPUT_ROOT = "XXXXXXXXXX/OSTTA-GOO/OOD_result/nuscene"
SELF_CAR_ROOT = "XXXXXXXXXX/OSTTA-GOO/OOD_result/nuscene_selfcar"
SINGLE_FRAME_OUT_ROOT = "XXXXXXXXXX/OSTTA-GOO/OOD_result/nuscenes_dbscan"
WINDOW_OUT_ROOT = "XXXXXXXXXX/OSTTA-GOO/OOD_result/nuscenes_dbscan_preframe3"

NUM_FRAMES = 3           # frames merged per window: current + 2 previous
FIRST_FRAME = 3          # frames with a smaller index are skipped
GRID_SPLIT = 10          # the xy bounding box is cut into GRID_SPLIT x GRID_SPLIT cells

# Every metric threshold below is divided by COORD_SCALE before use.
# NOTE(review): presumably the voxel size in metres (points stored in voxel
# units) - confirm against the code that wrote 'global_points'.
COORD_SCALE = 0.05
GROUND_Z_OFFSET = 0.8            # subtracted from the last entry of 'self_car'
PLANE_DISTANCE_THRESHOLD = 0.2   # RANSAC inlier distance (before scaling)
RANSAC_POINTS = 3
RANSAC_ITERATIONS = 1000
MIN_LOW_POINTS = 10              # a cell needs this many low points to be fitted
MIN_NORMAL_Z = 0.9               # minimum |z| component of an accepted plane normal
DBSCAN_EPS = 0.25                # DBSCAN neighbourhood radius (before scaling)
DBSCAN_MIN_POINTS = 25
MIN_CLUSTER_SIZE = 50            # smaller DBSCAN clusters are relabelled to -1


def load_frame_window(seq_name: str, frame_idx: int, n_frames: int):
    """Load and stack 'global_points' of ``frame_idx`` and its predecessors.

    Frames are read from ``<INPUT_ROOT>/<seq_name>/<index>.npy``, newest first,
    so the rows of the current frame come first in the stacked array.

    Returns:
        (points, n_current): the stacked point array and the number of rows
        that belong to the current frame.

    Raises:
        FileNotFoundError: if one of the frame files does not exist.
    """
    clouds = []
    for back in range(n_frames):
        frame_path = os.path.join(INPUT_ROOT, seq_name, str(frame_idx - back) + ".npy")
        # SECURITY: allow_pickle=True executes pickled code on load; only use
        # on files produced by this project.
        frame = np.load(frame_path, allow_pickle=True).item()
        clouds.append(frame['global_points'])
    return np.concatenate(clouds, axis=0), clouds[0].shape[0]


def load_ground_z_limit(seq_name: str, frame_idx: int) -> float:
    """Return the height below which points are candidates for local ground.

    Computed as ``(self_car[-1] - GROUND_Z_OFFSET) / COORD_SCALE`` where
    ``self_car`` is read from ``<SELF_CAR_ROOT>/<seq_name>/<frame_idx>.npy``.
    NOTE(review): the last entry of 'self_car' is presumably the ego z - confirm.
    """
    self_path = os.path.join(SELF_CAR_ROOT, seq_name, str(frame_idx) + ".npy")
    # SECURITY: see load_frame_window() regarding allow_pickle.
    self_car = np.load(self_path, allow_pickle=True).item()['self_car'].reshape(-1)
    return (self_car[-1] - GROUND_Z_OFFSET) / COORD_SCALE


def grid_cell_indices(points: np.ndarray, n_split: int) -> list:
    """Partition ``points`` into an ``n_split`` x ``n_split`` grid on the xy plane.

    The grid spans the xy bounding box of ``points``. Cells are half-open
    ``[lo, hi)`` except for the last row/column, which has no upper bound so
    that points lying exactly on the maximum x or y are not dropped.

    Returns:
        A list with one integer index array per non-empty cell, ordered by x
        cell first and y cell second. Empty input yields an empty list.
    """
    if points.shape[0] == 0:
        return []

    x = points[:, 0]
    y = points[:, 1]
    x_min = x.min()
    y_min = y.min()
    step_x = (x.max() - x_min) / n_split
    step_y = (y.max() - y_min) / n_split
    last = n_split - 1

    cells = []
    for ix in range(n_split):
        x_lo = x_min + step_x * ix
        in_x = x >= x_lo
        if ix != last:
            in_x &= x < x_lo + step_x
        if not in_x.any():
            continue
        for iy in range(n_split):
            y_lo = y_min + step_y * iy
            in_cell = in_x & (y >= y_lo)
            if iy != last:
                in_cell &= y < y_lo + step_y
            cell_idx = np.flatnonzero(in_cell)
            if cell_idx.size:
                cells.append(cell_idx)
    return cells


def local_ground_indices(points: np.ndarray, cell_idx: np.ndarray, z_limit: float):
    """Fit a ground plane inside one grid cell with RANSAC.

    Only points of the cell whose z is below ``z_limit`` take part in the fit.

    Returns:
        Indices (into ``points``) of the plane inliers, or ``None`` when the
        cell has fewer than MIN_LOW_POINTS low points or the fitted plane is
        not close to horizontal.
    """
    cell_points = points[cell_idx]
    low = np.flatnonzero(cell_points[:, 2] < z_limit)
    if low.size < MIN_LOW_POINTS:
        return None

    cell_cloud = o3d.geometry.PointCloud()
    cell_cloud.points = o3d.utility.Vector3dVector(cell_points[low])
    plane, inliers = cell_cloud.segment_plane(
        distance_threshold=(PLANE_DISTANCE_THRESHOLD / COORD_SCALE),
        ransac_n=RANSAC_POINTS,
        num_iterations=RANSAC_ITERATIONS,
    )

    # Reject planes whose normal is not dominated by its z component.
    nx, ny, nz = abs(plane[0]), abs(plane[1]), abs(plane[2])
    if nx > nz or ny > nz or nz < MIN_NORMAL_Z:
        return None

    # Map: plane inlier -> low-point position -> position in the full array.
    return cell_idx[low][np.asarray(inliers, dtype=np.intp)]


def compact_cluster_labels(labels: np.ndarray, min_size: int) -> np.ndarray:
    """Drop small clusters and renumber the survivors consecutively from 0.

    Clusters with fewer than ``min_size`` members become -1. The remaining
    clusters keep their relative order. Existing -1 entries are left untouched.
    A new array of the same dtype is returned; the input is not modified.
    """
    labels = np.asarray(labels)
    relabelled = labels.copy()
    clustered = labels >= 0
    if not clustered.any():
        return relabelled

    sizes = np.bincount(labels[clustered])
    keep = sizes >= min_size
    remap = np.full(sizes.shape[0], -1, dtype=labels.dtype)
    remap[keep] = np.arange(keep.sum(), dtype=labels.dtype)
    relabelled[clustered] = remap[labels[clustered]]
    return relabelled


def scatter_labels(is_ground: np.ndarray, rest_labels: np.ndarray) -> np.ndarray:
    """Expand labels of the non-ground points to a label per input point.

    ``rest_labels`` must be ordered like the non-ground points in the original
    array. Ground points get -1. The result is float64.
    """
    all_labels = -np.ones(is_ground.shape[0])
    all_labels[~is_ground] = rest_labels
    return all_labels


def process_frame(seq_name: str, frame_idx: int, single_dir: str, window_dir: str) -> None:
    """Run ground removal and clustering for one frame window and save labels.

    Nothing is written when no grid cell yields an acceptable ground plane;
    the frame index is printed three times instead.
    """
    z_limit = load_ground_z_limit(seq_name, frame_idx)
    merged_points, n_current = load_frame_window(seq_name, frame_idx, NUM_FRAMES)

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(merged_points)
    points = np.asarray(cloud.points)

    # ---- RANSAC: per-cell ground planes -----------------------------------
    ground_chunks = []
    for cell_idx in grid_cell_indices(points, GRID_SPLIT):
        found = local_ground_indices(points, cell_idx, z_limit)
        if found is not None:
            ground_chunks.append(found)

    if not ground_chunks:
        # No usable local ground plane: skip the frame, keep it visible in logs.
        print(frame_idx)
        print(frame_idx)
        print(frame_idx)
        return

    # ---- RANSAC: one plane over the whole window --------------------------
    _, global_inliers = cloud.segment_plane(
        distance_threshold=(PLANE_DISTANCE_THRESHOLD / COORD_SCALE),
        ransac_n=RANSAC_POINTS,
        num_iterations=RANSAC_ITERATIONS,
    )

    is_ground = np.zeros(points.shape[0], dtype=bool)
    is_ground[np.concatenate(ground_chunks)] = True
    is_ground[np.asarray(global_inliers, dtype=np.intp)] = True

    # ---- DBSCAN on everything that is not ground --------------------------
    rest_points = points[~is_ground]
    if rest_points.shape[0] == 0:
        # Every point was classified as ground; there is nothing to cluster.
        rest_labels = np.empty(0, dtype=np.int32)
    else:
        rest = o3d.geometry.PointCloud()
        rest.points = o3d.utility.Vector3dVector(rest_points)
        with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
            rest_labels = np.array(rest.cluster_dbscan(
                eps=(DBSCAN_EPS / COORD_SCALE),
                min_points=DBSCAN_MIN_POINTS,
                print_progress=True,
            ))

    rest_labels = compact_cluster_labels(rest_labels, MIN_CLUSTER_SIZE)
    # Largest label + 1; zero when no cluster survived.
    n_clusters = int(rest_labels.max()) + 1 if rest_labels.size else 0
    print(f"point cloud has {n_clusters} clusters")

    all_labels = scatter_labels(is_ground, rest_labels)

    frame_file = str(frame_idx) + ".npy"
    # The current frame was stacked first, so its labels are the leading rows.
    np.save(os.path.join(single_dir, frame_file), all_labels[:n_current])
    np.save(os.path.join(window_dir, frame_file), all_labels)


def main() -> None:
    """Process every sequence directory found under INPUT_ROOT."""
    for seq_name in sorted(os.listdir(INPUT_ROOT)):
        single_dir = os.path.join(SINGLE_FRAME_OUT_ROOT, seq_name)
        window_dir = os.path.join(WINDOW_OUT_ROOT, seq_name)
        os.makedirs(single_dir, exist_ok=True)
        os.makedirs(window_dir, exist_ok=True)

        # Frame files are addressed as "<index>.npy", so only the number of
        # directory entries is needed, not their names.
        n_frames = len(os.listdir(os.path.join(INPUT_ROOT, seq_name)))
        for frame_idx in range(FIRST_FRAME, n_frames):
            process_frame(seq_name, frame_idx, single_dir, window_dir)


if __name__ == "__main__":
    main()



        

