import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
import copy
import math
import os
import time


# DBSCAN hyper-parameters. Both values are also embedded in the names of the
# output directories below, so results of different settings never collide.
eps = 0.25
min_points = 50

# Output directory for labels of the whole multi-frame window.
save_dir_preframe = f"XXXXXXXXXX/OSTTA-GOO/OOD_result/Ablation_kitti/kitti_dbscan_preframe3_{eps}_{min_points}"
# Output directory for labels restricted to the current (single) frame.
save_dir_singe = f"XXXXXXXXXX/OSTTA-GOO/OOD_result/Ablation_kitti/kitti_dbscan_{eps}_{min_points}"
for _out_dir in (save_dir_preframe, save_dir_singe):
    os.makedirs(_out_dir, exist_ok=True)

# Directory holding one "<frame index>.npy" record per frame; the number of
# entries found here bounds the frame indices that get processed.
path = "XXXXXXXXXX/OSTTA-GOO/OOD_result/kitti/"
files = os.listdir(path)

# ---------------------------------------------------------------------------
# For every frame: merge it with its predecessors, remove the ground with
# RANSAC plane fits (one per xy grid cell plus one global fit), cluster what
# is left with DBSCAN, and save the per-point cluster labels.
# ---------------------------------------------------------------------------

# Point coordinates are divided by this value right after loading, and every
# metric threshold below is divided by it too so both live in the same units.
# NOTE(review): presumably the voxel size used upstream -- confirm.
VOXEL_SIZE = 0.05
# Number of frames merged per window: the current one plus the previous ones.
WINDOW_SIZE = 3
# First frame index that is processed. Raise it to resume an interrupted run.
# NOTE(review): with WINDOW_SIZE = 3, frame 2 already has enough predecessors;
# 3 is kept to match the previous behaviour -- confirm whether that is intended.
START_FRAME = 3
# The xy bounding box of a window is cut into GRID_SPLITS x GRID_SPLITS cells.
GRID_SPLITS = 10
# A cell is only fitted if at least this many of its points lie below the
# height limit derived from `self_car` (see the main loop).
MIN_LOW_POINTS = 10
# Offset subtracted from the last `self_car` value to obtain that height limit.
GROUND_CLEARANCE = 1.0
# RANSAC inlier distance (before division by VOXEL_SIZE) and iteration count.
PLANE_DISTANCE = 0.1
RANSAC_ITERATIONS = 1000
# DBSCAN clusters with fewer points than this are relabelled as noise (-1).
MIN_CLUSTER_POINTS = 100
# Directory with one "<frame index>.npy" record per frame holding 'self_car'.
SELF_CAR_DIR = "XXXXXXXXXX/OSTTA-GOO/OOD_result/kitti_selfcar/"


def _load_record(file_path):
    """Return the single pickled object stored in the ``.npy`` file *file_path*."""
    # NOTE: allow_pickle=True runs arbitrary pickled code while loading; only
    # use it on files produced by a trusted pipeline.
    return np.load(file_path, allow_pickle=True).item()


def _iter_grid_cells(points, splits):
    """Yield, for each non-empty cell of a splits x splits xy grid, its point indices.

    The grid spans the xy bounding box of *points*. Neighbouring cells share
    their edge values and the last row/column has no upper bound, so every
    point (including those at the maximum x or y) lands in exactly one cell.
    Indices within a cell are in ascending order.
    """
    xs = points[:, 0]
    ys = points[:, 1]
    steps = np.arange(splits + 1)
    x_edges = xs.min() + (xs.max() - xs.min()) / splits * steps
    y_edges = ys.min() + (ys.max() - ys.min()) / splits * steps

    for col in range(splits):
        # The column mask does not depend on the row, so build it once.
        in_col = xs >= x_edges[col]
        if col < splits - 1:
            in_col &= xs < x_edges[col + 1]
        for row in range(splits):
            in_cell = in_col & (ys >= y_edges[row])
            if row < splits - 1:
                in_cell &= ys < y_edges[row + 1]
            cell_idx = np.flatnonzero(in_cell)
            if cell_idx.size:
                yield cell_idx


def _cellwise_ground_indices(points, z_limit):
    """Fit one RANSAC plane per grid cell and collect the accepted inliers.

    A cell is skipped when fewer than MIN_LOW_POINTS of its points have
    z < *z_limit*, or when the fitted plane's normal has an x or y component
    larger in magnitude than its z component (i.e. the plane is not close to
    horizontal). Returns a list of index arrays into *points*, one per
    accepted cell; the list is empty if no cell was accepted.
    """
    accepted = []
    for cell_idx in _iter_grid_cells(points, GRID_SPLITS):
        cell_points = points[cell_idx]
        if np.count_nonzero(cell_points[:, 2] < z_limit) < MIN_LOW_POINTS:
            continue

        cell_pcd = o3d.geometry.PointCloud()
        cell_pcd.points = o3d.utility.Vector3dVector(cell_points)
        plane, inliers = cell_pcd.segment_plane(
            distance_threshold=PLANE_DISTANCE / VOXEL_SIZE,
            ransac_n=3,
            num_iterations=RANSAC_ITERATIONS,
        )

        nx, ny, nz = (abs(c) for c in plane[:3])
        if nx > nz or ny > nz:
            continue

        # Map cell-local inlier positions back to indices of the full cloud.
        accepted.append(cell_idx[inliers])
    return accepted


def _drop_small_clusters(labels, min_size):
    """Relabel DBSCAN output so only clusters with >= *min_size* points remain.

    Surviving clusters are renumbered 0, 1, 2, ... in order of their original
    label; every other point gets -1. Returns ``(new_labels, n_clusters)``.
    An empty *labels* array yields an empty result and 0 clusters.
    """
    relabeled = np.full(labels.shape, -1, dtype=labels.dtype)
    clustered = labels >= 0
    if not clustered.any():
        return relabeled, 0

    sizes = np.bincount(labels[clustered])
    keep = sizes >= min_size
    n_kept = int(np.count_nonzero(keep))
    # remap[old_label] -> new compact label, or -1 for discarded clusters.
    remap = np.full(sizes.size, -1, dtype=labels.dtype)
    remap[keep] = np.arange(n_kept)
    relabeled[clustered] = remap[labels[clustered]]
    return relabeled, n_kept


for file_num in range(START_FRAME, len(files)):
    time_start = time.perf_counter()

    self_car = _load_record(os.path.join(SELF_CAR_DIR, f"{file_num}.npy"))['self_car']
    self_car = self_car.reshape(-1)

    # Current frame first, then progressively older frames.
    window_points = []
    for frames_back in range(WINDOW_SIZE):
        record = _load_record(os.path.join(path, f"{file_num - frames_back}.npy"))
        window_points.append(record['global_points'] / VOXEL_SIZE)
    # The current frame occupies the first rows of the merged cloud.
    n_current = window_points[0].shape[0]
    merged_points = np.concatenate(window_points, axis=0)

    pcd_global = o3d.geometry.PointCloud()
    pcd_global.points = o3d.utility.Vector3dVector(merged_points)
    # The colour is not read anywhere in this script.
    pcd_global.paint_uniform_color([0.5, 0.5, 0.5])
    all_points = np.asarray(pcd_global.points)

    ######## RANSAC ########
    # NOTE(review): self_car[-1] looks like the ego vehicle height -- confirm.
    z_limit = (self_car[-1] - GROUND_CLEARANCE) / VOXEL_SIZE
    cell_ground = _cellwise_ground_indices(all_points, z_limit)
    if not cell_ground:
        # Same outcome as before: the frame is skipped and nothing is saved.
        print(f"frame {file_num}: no grid cell produced a ground plane, skipping")
        continue

    _, global_ground = pcd_global.segment_plane(
        distance_threshold=PLANE_DISTANCE / VOXEL_SIZE,
        ransac_n=3,
        num_iterations=RANSAC_ITERATIONS,
    )

    # Ground = union of the per-cell inliers and the global plane inliers.
    is_ground = np.zeros(all_points.shape[0], dtype=bool)
    is_ground[np.concatenate(cell_ground)] = True
    is_ground[global_ground] = True
    ground_index = np.flatnonzero(is_ground)

    rest = pcd_global.select_by_index(ground_index, invert=True)
    ######## RANSAC ########

    ######## DBSCAN ########
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        raw_labels = np.array(
            rest.cluster_dbscan(eps=eps / VOXEL_SIZE, min_points=min_points, print_progress=True)
        )

    dbscan_labels, n_clusters = _drop_small_clusters(raw_labels, MIN_CLUSTER_POINTS)
    print(f"point cloud has {n_clusters} clusters")
    ######## DBSCAN ########

    # Scatter the labels of the non-ground points back onto the full cloud;
    # ground points keep -1. The array stays float64 like the previous output.
    dbscan_labels_all_points = np.full(all_points.shape[0], -1.0)
    dbscan_labels_all_points[~is_ground] = dbscan_labels

    np.save(os.path.join(save_dir_singe, f"{file_num}.npy"), dbscan_labels_all_points[:n_current])
    np.save(os.path.join(save_dir_preframe, f"{file_num}.npy"), dbscan_labels_all_points)

    time_end = time.perf_counter()
    print('time cost', time_end - time_start, 's')



    

