import os
import numpy as np
from tqdm import tqdm
from pathlib import Path
from plot.datasets.utils import read_kitti_calib
from plot.utils.io import load_frames
from plot.utils.geometry import meshgrid2d, depthmap_to_pts3d, rigid_points_registration_numpy
from plot.datasets.utils import merge_object_masks, ransac_registration, ransac_registration_gpu
from plot.utils.processing import erode_mask
from plot.utils.misc import find_target_frame_idx, limit_frames

# Number of top image rows removed from the target-frame background mask
# (presumably sky, whose depth/tracks are unreliable — TODO confirm).
sky_height = 200
# Correspondences whose z coordinate is >= this value in either frame are
# dropped before registration (units follow the depth maps; presumably metres).
depth_threshold = 30


def filter_visibility(points1, points2, tracks_vis, H, W, vis_threshold=0.9):
    """Select track correspondences that lie inside the image in both frames and are visible.

    Args:
        points1, points2: pixel coordinates of shape (N, 2) for the same N tracks in
            two frames; column 0 is x (compared against ``W``) and column 1 is y
            (compared against ``H``).
        tracks_vis: per-track visibility scores, shape (N,).
        H, W: image height and width in pixels.
        vis_threshold: a track is kept only if its score is strictly greater than this.

    Returns:
        Boolean mask of shape (N,), True where the correspondence is usable.
    """
    points1 = np.asarray(points1)
    points2 = np.asarray(points2)

    def _inside(pts):
        # Negative coordinates must be rejected explicitly: when later used as array
        # indices they would silently wrap around to the opposite image border.
        return (pts[:, 0] >= 0) & (pts[:, 0] < W) & (pts[:, 1] >= 0) & (pts[:, 1] < H)

    return _inside(points1) & _inside(points2) & (np.asarray(tracks_vis) > vis_threshold)


def _load_points3d(depth_path: Path, intrinsic) -> np.ndarray:
    """Load the ``depth`` array from an ``.npz`` file and back-project it with ``depthmap_to_pts3d``.

    ``intrinsic[0, 0]``, ``intrinsic[0, 2]`` and ``intrinsic[1, 2]`` are forwarded
    (presumably fx, cx, cy of the calibration matrix — TODO confirm).
    """
    # np.load on an .npz returns an NpzFile holding an open file handle; close it deterministically.
    with np.load(depth_path) as data:
        depth = data['depth']
    return depthmap_to_pts3d(depth, intrinsic[0, 0], intrinsic[0, 2], intrinsic[1, 2])


def _target_background_mask(mask_path: Path, H: int, W: int) -> np.ndarray:
    """Build the boolean (H, W) background mask for the target frame.

    When more than one object mask is stored, the merged object mask is inverted and
    passed through ``erode_mask(..., 15, 15)``; otherwise every pixel starts as
    background. The top ``sky_height`` rows are always excluded.
    """
    # NOTE(review): allow_pickle=True unpickles the file contents — only use on trusted,
    # locally generated data.
    obj_masks = np.load(mask_path, allow_pickle=True).item()['masks']
    # NOTE(review): with exactly one stored mask nothing is masked out (threshold is > 1,
    # kept from the original); confirm whether `> 0` was intended.
    if len(obj_masks) > 1:
        bg_mask = erode_mask(~merge_object_masks(obj_masks), 15, 15)
    else:
        bg_mask = np.ones((H, W), dtype=bool)
    # Guarantee a boolean mask so `arr[:, bg_mask]` is boolean (not integer) indexing.
    bg_mask = np.asarray(bg_mask, dtype=bool)
    bg_mask[:sky_height, :] = False
    return bg_mask


def estimate_cam_pose(root: Path, scene_name: str, alltracker_dir: Path):
    """Estimate a pose for each frame of a scene relative to a target frame.

    Background pixels of the target frame are followed through the dense tracks. For
    every other frame, the tracked pixels are lifted to 3D using both frames' depth
    maps, filtered by visibility and depth, and registered with
    ``ransac_registration_gpu`` (source frame -> target frame).

    Args:
        root: dataset root containing ``calib/``, ``frames/``, ``gsam_frames/`` and
            ``unidepthv1/`` sub-directories.
        scene_name: scene identifier used to build every per-scene path.
        alltracker_dir: directory holding ``<scene_name>/<scene_name>_dense.npz`` with
            ``tracks`` and ``visibility`` arrays.

    Returns:
        dict mapping each frame's file stem to the matrix returned by the registration
        (presumably 4x4); the target frame maps to ``np.eye(4)``.
    """
    intrinsic = read_kitti_calib(root / "calib" / f"{scene_name}.txt")
    frames, frame_lists_orig = load_frames(root / "frames" / scene_name)
    # NOTE(review): the meaning of the literals 20 and 10 is defined by
    # find_target_frame_idx / limit_frames.
    tgt_frame_idx_orig = find_target_frame_idx(frame_lists_orig, 20)
    frame_lists = limit_frames(frame_lists_orig, tgt_frame_idx_orig, 10)
    # Re-locate the target inside the trimmed list; the dense tracks are indexed with this
    # list's ordering below (assumes they were computed on the same trimmed list).
    tgt_frame_idx = find_target_frame_idx(frame_lists, 20)
    H, W, _ = frames[0].shape

    tgt_stem = frame_lists[tgt_frame_idx].stem
    depth_dir = root / "unidepthv1" / scene_name

    tgt_bg_mask = _target_background_mask(root / "gsam_frames" / scene_name / f"{tgt_stem}.npy", H, W)
    tgt_points3d = _load_points3d(depth_dir / f"{tgt_stem}.npz", intrinsic)

    with np.load(alltracker_dir / scene_name / f"{scene_name}_dense.npz") as dense_results:
        dense_tracks = dense_results['tracks']
        dense_tracks_vis = dense_results['visibility']

    # `tracks` already holds absolute pixel positions (flow + grid == tracks), so select
    # the background pixels directly instead of materialising a full flow array.
    tgt_bg_tracks = dense_tracks[:, tgt_bg_mask].astype(int)
    tgt_bg_tracks_vis = dense_tracks_vis[:, tgt_bg_mask]
    tgt_pts2d = tgt_bg_tracks[tgt_frame_idx]

    cam_poses = {}
    for frame_idx, frame_path in enumerate(frame_lists):
        if frame_idx == tgt_frame_idx:
            # The target frame is the reference coordinate system.
            cam_poses[tgt_stem] = np.eye(4)
            continue

        points3d = _load_points3d(depth_dir / f"{frame_path.stem}.npz", intrinsic)
        src_pts2d = tgt_bg_tracks[frame_idx]

        # Keep correspondences accepted by filter_visibility (image bounds + visibility score).
        vis_mask = filter_visibility(src_pts2d, tgt_pts2d, tgt_bg_tracks_vis[frame_idx], H, W)
        # Point maps are indexed [row=y, col=x].
        src_pts3d = points3d[src_pts2d[vis_mask, 1], src_pts2d[vis_mask, 0], :]
        tgt_pts3d = tgt_points3d[tgt_pts2d[vis_mask, 1], tgt_pts2d[vis_mask, 0], :]
        # Drop points that are far away in either frame.
        depth_mask = (src_pts3d[:, 2] < depth_threshold) & (tgt_pts3d[:, 2] < depth_threshold)

        # RANSAC registration for a robust source -> target rigid transform.
        f2f_mat, _, _ = ransac_registration_gpu(
            src_pts3d[depth_mask], tgt_pts3d[depth_mask], 300, 0.5, 0.5, n_samples=10
        )
        cam_poses[frame_path.stem] = f2f_mat
    return cam_poses



if __name__ == '__main__':
    # Local import: only the command-line entry point needs argparse.
    import argparse

    parser = argparse.ArgumentParser(
        description="Estimate camera poses for every scene listed in <root>/ImageSets/val.txt."
    )
    parser.add_argument(
        "root",
        nargs="?",
        default="PATH_TO_DATA",
        help="dataset root directory (defaults to the PATH_TO_DATA placeholder)",
    )
    args = parser.parse_args()

    root = Path(args.root)
    tracks_dir = root / "alltracker"

    with open(root / "ImageSets" / "val.txt", encoding="utf-8") as f:
        # Skip blank lines so a stray empty line does not become a scene named "".
        scene_names = [name for name in (line.strip() for line in f) if name]

    for scene_name in tqdm(scene_names):
        pose_dir = root / "poses" / scene_name
        pose_dir.mkdir(parents=True, exist_ok=True)
        cam_poses = estimate_cam_pose(root, scene_name, tracks_dir)
        # Keys are frame file stems: write one <stem>.npy per frame.
        for frame_stem, pose in cam_poses.items():
            np.save(pose_dir / f"{frame_stem}.npy", pose)