import os
import pickle
from tqdm import tqdm
import shutil
import numpy as np
import open3d as o3d
import torch
from pytorch3d.ops.knn import knn_points

# Root of the raw KITTI input: per-sequence calib.txt, image_2/ and velodyne/
# are read from here. Must end with '/' because paths are built by plain
# string concatenation.
data_root: str = "/mnt/e/Dataset/KITTI/"
# Destination for processed point clouds, copied images and exported
# calibration files (same trailing-'/' requirement).
path_output: str = "/mnt/e/Dataset/Image_KITTI/data/"
# Destination for the per-split metadata pickles (same trailing-'/' requirement).
dict_path: str = "/mnt/e/Dataset/Image_KITTI/metadata/"

def _read_calib(calib_file):
    """Parse a ``calib.txt`` of ``KEY: v0 ... v11`` lines into ``{KEY: (3, 4) array}``.

    Values are split on any whitespace (robust to repeated/trailing spaces and
    the line terminator) and converted explicitly with ``float``. Lines without
    a ``:`` separator (e.g. blank lines) are skipped. A row that does not hold
    exactly 12 numbers raises ``ValueError`` from ``reshape``.
    """
    calib = {}
    with open(calib_file, 'r') as file:
        for line in file:
            key, sep, values = line.partition(':')
            if not sep:
                continue
            row = [float(v) for v in values.split()]
            calib[key.strip()] = np.array(row, dtype = np.float64).reshape(3, 4)
    return calib


def prepare_intrin():
    """Create the output folder layout and export per-sequence calibration.

    For each sequence 00-10 this:
      * ensures ``path_output``, ``dict_path``, ``<path_output><seq>/image_2/``
        and ``<path_output><seq>/velodyne/`` exist;
      * reads ``<data_root><seq>/calib.txt``;
      * writes its 3x4 ``P2`` matrix (the one paired with the ``image_2``
        frames) to ``<path_output><seq>/calib.txt``;
      * writes its ``Tr`` matrix, padded with ``[0, 0, 0, 1]`` to a 4x4
        homogeneous LiDAR-to-camera transform, to
        ``<path_output><seq>/vel_to_cam.txt``.

    Rows are looked up by key ("P2", "Tr") rather than by line position, and
    every value is converted explicitly to float.

    Raises:
        KeyError: if a sequence's calib.txt has no ``P2`` or ``Tr`` row.
    """
    os.makedirs(path_output, exist_ok = True)
    os.makedirs(dict_path, exist_ok = True)

    for n in range(11):
        seq = str(n).zfill(2)
        seq_out = path_output + seq

        # makedirs also creates the missing parent <seq> directory.
        os.makedirs(seq_out + "/image_2/", exist_ok = True)
        os.makedirs(seq_out + "/velodyne/", exist_ok = True)

        calib = _read_calib(data_root + seq + "/calib.txt")

        intrin_mat = calib["P2"]

        vel_to_cam_mat = np.eye(4)
        vel_to_cam_mat[:3, :] = calib["Tr"]

        np.savetxt(seq_out + "/calib.txt", intrin_mat, fmt = "%.12e")
        np.savetxt(seq_out + "/vel_to_cam.txt", vel_to_cam_mat, fmt = "%.12e")

def create_kitti_dict(split):
    """Build the per-pair metadata list for *split* and pickle it under ``dict_path``.

    Reads GeoTransformer's ``./GeoTransformer/<split>.pkl``, a list of dicts of
    which ``seq_id``, ``frame0``, ``frame1`` and ``transform`` are used. For
    every pair, one record is produced with:

    * ``seq_id``, ``frame0``, ``frame1``, ``transform`` copied through;
    * ``img0``/``img1`` (``<seq>/image_2/<frame:06d>.png``) and
      ``pcd0``/``pcd1`` (``<seq>/velodyne/<frame:06d>.bin``) relative paths;
    * ``camera_matrix``: ``calib.txt @ vel_to_cam.txt`` as written by
      ``prepare_intrin`` (3x4), later used to project homogeneous LiDAR points;
    * ``vel_to_cam``: the 4x4 matrix loaded from ``vel_to_cam.txt``.

    The list is written to ``<dict_path><split>.pkl``.
    """
    # NOTE: pickle.load can execute arbitrary code; only use trusted metadata.
    with open("./GeoTransformer/" + split + ".pkl", 'rb') as f:  # Metadata from GeoTransformer
        metadata_list = pickle.load(f)

    # Calibration is per sequence, so load each sequence's files only once
    # instead of re-reading them for every pair.
    calib_cache = {}

    data_dict = []
    for meta in tqdm(metadata_list):
        seq = str(meta["seq_id"]).zfill(2)
        rgb_rel = seq + "/image_2/"
        pc_rel = seq + "/velodyne/"
        name0 = str(meta["frame0"]).zfill(6)
        name1 = str(meta["frame1"]).zfill(6)

        if seq not in calib_cache:
            calib_cache[seq] = (
                np.loadtxt(path_output + seq + "/calib.txt"),
                np.loadtxt(path_output + seq + "/vel_to_cam.txt"),
            )
        int_mat, tr_mat = calib_cache[seq]

        data_dict.append({
            "seq_id": meta["seq_id"],
            "frame0": meta["frame0"],
            "frame1": meta["frame1"],
            "img0": rgb_rel + name0 + ".png",
            "img1": rgb_rel + name1 + ".png",
            "pcd0": pc_rel + name0 + ".bin",
            "pcd1": pc_rel + name1 + ".bin",
            "transform": meta["transform"],
            # Fresh arrays per record so no two records alias the cached buffers.
            "camera_matrix": int_mat @ tr_mat,
            "vel_to_cam": tr_mat.copy(),
        })

    # save dictionary as pickle in output path
    with open(dict_path + split + ".pkl", "wb") as f:
        pickle.dump(data_dict, f)

def load_KITTI(pc_file):
    """Read a velodyne ``.bin`` scan into an Open3D point cloud.

    The file is read as a flat float32 array and viewed as rows of four values.
    Only the first three columns (x, y, z) are kept; the fourth is discarded.
    """
    raw = np.fromfile(pc_file, dtype = np.float32)
    xyz = raw.reshape(-1, 4)[:, :3]
    return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))

def process_point_cloud(pc, camera_matrix, vel_to_cam,
                        img_width = 1226, img_height = 370, voxel_size = 0.3):
    """Keep the LiDAR points that project into the image, then voxel-downsample.

    Args:
        pc: Open3D point cloud (LiDAR frame); it is not modified.
        camera_matrix: 3x4 matrix mapping homogeneous LiDAR points to
            homogeneous pixel coordinates ``(u*w, v*w, w)``.
        vel_to_cam: 4x4 LiDAR-to-camera transform; row 2 of its product with a
            homogeneous point gives the camera-frame depth.
        img_width, img_height: a point is kept when ``0 <= u < img_width`` and
            ``0 <= v < img_height``. The defaults reproduce the previously
            hard-coded 1226x370. NOTE: image_2 resolution differs between some
            KITTI odometry sequences (e.g. 00-02 are 1241x376), so pass the
            real size to avoid clipping the image border.
        voxel_size: voxel edge length passed to ``voxel_down_sample``.

    Returns:
        A new, downsampled Open3D point cloud.
    """
    points = np.asarray(pc.points)
    points_h = np.concatenate((points, np.ones((points.shape[0], 1))), axis = 1)

    # Remove the points behind the camera (non-positive camera-frame depth).
    depth = (vel_to_cam @ points_h.T)[2]
    points_h = points_h[depth > 0]

    # Remove the points that project outside the image. Vectorized masks
    # replace the original per-point Python loops.
    uv = camera_matrix @ points_h.T
    u = uv[0] / uv[2]
    v = uv[1] / uv[2]
    in_image = (u >= 0) & (u < img_width) & (v >= 0) & (v < img_height)
    points_h = points_h[in_image]

    pc_clip = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_h[:, :3]))
    return pc_clip.voxel_down_sample(voxel_size)

def get_overlap_ratio_cuda(source, target, threshold, device = 'cuda'):
    """Return the fraction of ``source`` points with a ``target`` point within ``threshold``.

    Uses a single nearest-neighbour query with ``pytorch3d.ops.knn_points``.
    That function returns *squared* distances, so the comparison is against
    ``threshold ** 2``.

    Args:
        source, target: objects exposing ``.points`` convertible to an (N, 3)
            array (e.g. Open3D point clouds), expressed in a common frame.
        threshold: neighbour distance, in the units of the points, under which
            a source point counts as overlapping.
        device: torch device for the query (default ``'cuda'``).

    Returns:
        A Python float in [0, 1]. Returns 0.0 when either cloud is empty.
        The check happens before any tensor is created, which avoids a
        division by zero or an empty kNN query.
    """
    src = np.asarray(source.points)
    tgt = np.asarray(target.points)
    if src.shape[0] == 0 or tgt.shape[0] == 0:
        return 0.0

    src_t = torch.as_tensor(src, dtype = torch.float32, device = device)
    tgt_t = torch.as_tensor(tgt, dtype = torch.float32, device = device)

    sq_dist, _, _ = knn_points(src_t[None], tgt_t[None], K = 1)
    within = sq_dist[0, :, 0] <= threshold ** 2

    return within.sum().item() / src.shape[0]

def process_data(split, overlap_dist = 0.6, min_overlap = 0.05):
    """Clip, downsample and overlap-filter the pairs listed in ``<dict_path><split>.pkl``.

    For every pair, the following steps run in order:

    * Both scans are loaded from ``data_root``.
    * They are clipped to the image and voxel-downsampled with
      ``process_point_cloud``.
    * The overlap ratio of frame 0 against a copy of frame 1 transformed by
      ``meta['transform']`` is measured with ``get_overlap_ratio_cuda``.

    A pair whose overlap is at least ``min_overlap`` is kept:

    * it gets an ``'overlap'`` key;
    * both processed clouds are ``torch.save``-d as float32 tensors at
      ``path_output + meta['pcd0'/'pcd1']``;
    * the two images are copied to ``path_output`` unless already present.

    The input pickle is then OVERWRITTEN with only the kept pairs. Running this
    twice on the same split therefore re-filters an already filtered list.
    Finally, summary statistics are printed.

    Args:
        split: split name, e.g. ``'train'``.
        overlap_dist: neighbour distance for the overlap test (default 0.6).
        min_overlap: minimum overlap ratio for a pair to be kept (default 0.05).
    """
    with open(dict_path + split + ".pkl", 'rb') as f:
        metadata_list = pickle.load(f)

    res = []
    for meta in tqdm(metadata_list):
        camera_matrix = meta['camera_matrix']
        vel_to_cam = meta['vel_to_cam']
        pose = meta['transform']

        pc0 = process_point_cloud(load_KITTI(data_root + meta['pcd0']), camera_matrix, vel_to_cam)
        pc1 = process_point_cloud(load_KITTI(data_root + meta['pcd1']), camera_matrix, vel_to_cam)

        # Apply the pose to a copy so pc1 itself is saved untouched. Transforming
        # in place and back with inv(pose) would leak round-off into the output.
        pc1_aligned = o3d.geometry.PointCloud(pc1)
        pc1_aligned.transform(pose)
        c_overlap = get_overlap_ratio_cuda(pc0, pc1_aligned, overlap_dist)

        if c_overlap < min_overlap:
            continue

        meta['overlap'] = c_overlap
        res.append(meta)
        torch.save(torch.tensor(np.asarray(pc0.points), dtype = torch.float32), path_output + meta['pcd0'])
        torch.save(torch.tensor(np.asarray(pc1.points), dtype = torch.float32), path_output + meta['pcd1'])
        for img_key in ("img0", "img1"):
            dst = path_output + meta[img_key]
            if not os.path.exists(dst):
                shutil.copy(data_root + meta[img_key], dst)

    with open(dict_path + split + '.pkl', 'wb') as f:
        pickle.dump(res, f)

    print(f"Total data num: {len(metadata_list)}")
    print(f"Valid data num: {len(res)}")
    # Guard: with no surviving pairs the mean is undefined (was ZeroDivisionError).
    if res:
        avg_ov = sum(m['overlap'] for m in res) / len(res)
        print(f"Average overlap: {avg_ov}")
    else:
        print("Average overlap: n/a (no valid pairs)")

if __name__ == "__main__":
    # Export folder layout and calibration once, then build the metadata
    # and run the clip/downsample/overlap filter for each split in turn.
    prepare_intrin()
    for split_name in ('train', 'val', 'test'):
        create_kitti_dict(split_name)
        process_data(split_name)