import torch
import numpy as np
import cv2
import shutil
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from scipy.optimize import linear_sum_assignment
from plot.datasets.utils import *
from plot.utils.io import load_frames, read_image, write_video, save_pcd
from plot.utils.misc import find_target_frame_idx, limit_frames
from plot.utils.geometry import yaw_to_rotation_matrix, rotation_matrix_to_yaw
from plot.utils.processing import *
from plot.utils.viz import visualize_pcd, visualize_bev_and_box3d, draw_point_tracks, get_2d_colors, draw_box2d
from plot.datasets.kitti.config import BaseConfig


# Class names of interest for the KITTI pseudo-labels.
# NOTE(review): not referenced anywhere in the visible part of this file.
KITTI_CLASSES = ['Car']

class LocalObject:
    """Single-frame observation of one object (a detection or a track-recovered box).

    Only the bookkeeping containers and the boolean flags are initialised by
    ``__init__``; every other attribute is assigned by the code that builds
    the object.
    """

    idx: int                            # index of the object within its frame
    ignore: bool                        # True when the object should be skipped
    occluded: bool                      # True when occluded in the perspective image
    is_moving: bool                     # True when the object is considered moving
    pseudo_lidar: np.ndarray            # completed pseudo lidar, shape (N, 3)
    cleaned_lidar: np.ndarray           # processed pseudo lidar, shape (N, 3)
    label: str                          # object label name
    box: List[int]                      # 2D box as (x1, y1, x2, y2)
    area: float                         # area of the 2D box
    logit: float                        # gsam confidence score
    center: List[float]                 # object center in the camera coordinate system
    size: List[float]                   # object size as [L, H, W]
    yaw: float                          # orientation in the camera coordinate system
    mask: np.ndarray                    # object mask, shape [image_height, image_width]
    flow: float                         # mean flow of the object
    tracks: np.ndarray                  # pixel tracks into other frames (num_frames, num_pixels, 2)
    tracks_vis: np.ndarray              # visibility of those tracks (num_frames, num_pixels)
    start_frame_idx: int                # frame of first appearance
    end_frame_idx: int                  # frame where the object was last seen

    trfs: Dict[int, np.ndarray]         # transformations {adjacent_frame_idx: array[4, 4]}
    visible_frames: list                # visible frames [start_frame_idx ... end_frame_idx]
    occluded_frames: list               # frames within the visible span where it is occluded

    def __init__(self) -> None:
        # All status flags start cleared.
        self.ignore = False
        self.occluded = False
        self.is_moving = False
        # Fresh, per-instance containers.
        self.trfs = {}
        self.visible_frames = []
        self.occluded_frames = []

class GlobalObject:
    """One object tracked across the video, aggregating its per-frame observations.

    The per-frame dictionaries are keyed by frame index. ``__init__`` creates
    empty containers and default scalar state; ``idx``, ``size`` and the
    lidar arrays are assigned later by the code that owns the object.
    """

    idx: int                            # global object index
    label: str                          # object label name
    size: List[float]                   # object size as [L, H, W] throughout the video
    is_moving: bool                     # True when the object is considered moving
    start_frame_idx: int                # frame of first appearance
    end_frame_idx: int                  # frame where the object was last seen
    visible_frames: List[int]           # visible frames [start_frame_idx ... end_frame_idx]
    occluded_frames: List[int]          # frames within the visible span where it is occluded

    pseudo_lidar: np.ndarray            # completed pseudo lidar, shape (N, 3)
    cleaned_lidar: np.ndarray           # processed pseudo lidar, shape (N, 3)

    logits: Dict[int, float]            # gsam confidence score per frame
    boxes: Dict[int, list]              # 2D boxes per frame {frame_idx: [x1, y1, x2, y2]}
    masks: Dict[int, np.ndarray]        # masks per frame

    flows: Dict[int, float]             # average flow, from first to last appearance
    tracks: Dict[int, np.ndarray]       # tracks (num_frames, num_pixels, 2) with num_frames < 30
    tracks_vis: Dict[int, np.ndarray]   # track visibility (num_frames, num_pixels) with num_frames < 30

    trfs: Dict[int, np.ndarray]         # transformations to the first frame {adjacent_frame_idx: array[4, 4]}
    adj_trfs: Dict[int, np.ndarray]

    yaws: Dict[int, float]              # orientation per frame
    centers: Dict[int, list]            # center per frame in the camera system {frame_idx: [x, y, z]}
    unmatched_times: int

    def __init__(self) -> None:
        # Scalar defaults.
        self.label = "Car"
        self.is_moving = False
        self.unmatched_times = 0
        # The frame span is unknown until the tracker fills it in.
        self.start_frame_idx = None
        self.end_frame_idx = None

        # Frame lists.
        self.visible_frames = []
        self.occluded_frames = []

        # Per-frame observations.
        self.logits = {}
        self.boxes = {}
        self.masks = {}
        self.flows = {}
        self.tracks = {}
        self.tracks_vis = {}

        # Per-frame geometry.
        self.trfs = {}
        self.adj_trfs = {}
        self.yaws = {}
        self.centers = {}

    

class Frame:
    """Per-frame container: image data, local objects and their matching results."""

    idx: int
    name: str = "Name"                      # frame name
    image: np.ndarray                       # rgb image
    points3d: np.ndarray                    # point cloud, shape (H, W, 3)
    background_flow: float
    objects: Dict[int, LocalObject]         # local objects of this frame
    matched_indices: Dict[int, int]         # mapping between target frame object indices and current frame object indices
    unmatched_dets: list
    unmatched_tracks: list
    local_to_global: Dict[int, int]

    def __init__(self) -> None:
        # Every frame gets its own empty lookup tables.
        self.objects, self.matched_indices, self.local_to_global = {}, {}, {}


class Video:
    """Scene-level container holding the frames, poses and scene-wide metadata."""

    name: str                                   # scene name
    image_size: list                            # image size as (H, W)
    intrinsic: np.ndarray                       # one intrinsic shared by the whole video
    num_frames: int                             # number of frames in the segment
    fps: int
    meshgrid: np.ndarray
    frame_lists: List[Path]
    tgt_frame_idx: int

    frames: Dict[int, Frame]                    # {frame_idx: Frame}
    global_objects: Dict[int, GlobalObject]     # dict of GlobalObjects (NOTE(review): key meaning not shown here)
    poses: Dict[int, np.ndarray]                # poses relative to the first frame (1 to 0)
    rel_poses: Dict[int, np.ndarray]
    moving_cam: bool = False
    bg_flows: np.ndarray
    background_flow: float                      # average flow of the video

    def __init__(self) -> None:
        # Pose tables.
        self.poses = {}
        self.rel_poses = {}
        # Frame and object tables.
        self.frames = {}
        self.global_objects = {}


class PseudoLabeler:
    def __init__(self, config: BaseConfig) -> None:
        """Resolve every input/output directory from ``config`` and reset tracking state.

        Reads ``config.dataset_root`` and ``config.pseudo_root``. The two
        output directories are created on disk if they do not exist yet, and
        the scene list is taken from ``select_test_scenes()``.
        """
        self.config = config
        self.dataset_root = Path(config.dataset_root)
        self.pseudo_root = Path(config.pseudo_root)

        # Inputs that come with the original videos.
        dataset_root = self.dataset_root
        self.img_dir = dataset_root / "frames"
        self.calib_dir = dataset_root / "calib"

        # Outputs of the foundation models.
        pseudo_root = self.pseudo_root
        self.depth_dir = pseudo_root / "unidepthv1"
        self.gsam_dir = pseudo_root / "gsam_frames"
        self.track_dir = pseudo_root / "alltracker"
        self.pose_dir = pseudo_root / "poses"

        # Directories this labeler writes into.
        self.pseudo_label_dir = pseudo_root / "pseudo_labels_nxn_adj"
        self.pseudo_results_dir = pseudo_root / "image_results_nxn_adj"
        for out_dir in (self.pseudo_label_dir, self.pseudo_results_dir):
            out_dir.mkdir(parents=True, exist_ok=True)

        # Currently wired to the test split; the train-split alternative was
        # `self.scenes = self.select_train_scenes()`.
        self.scenes = self.select_test_scenes()

        # Per-video tracking state, filled in while a video is processed.
        self.video: Video = None
        self.global_objects: Dict[int, GlobalObject] = {}
        self.global_obj_idx = 0

    
    def adjacent_frame_registration(self):
        # initialize from the first frame
        self.initialization()
        
        for src_frame_idx in range(self.video.num_frames - 1):
            tgt_frame_idx = src_frame_idx + 1

            matched_indices, unmatched_dets, unmatched_tracks = self.one_to_one_match(src_frame_idx, tgt_frame_idx)
            self.video.frames[tgt_frame_idx].matched_indices = matched_indices
            self.video.frames[tgt_frame_idx].unmatched_dets = unmatched_dets
            self.video.frames[tgt_frame_idx].unmatched_tracks = unmatched_tracks

            
            # for the new incoming detections, there are only two conditions
            # if it is matched to a previous track, add its observation
            # if not, initialize as a new object (WARNING: there might be some fragmented tracks)
            for new_det_idx, new_det in self.video.frames[tgt_frame_idx].objects.items():
                # if this object is unmatched with previous tracked object, add as a new global object
                if new_det_idx in unmatched_dets: 
                    self.init_global_object(new_det, tgt_frame_idx)
                # if this object is matched with previous frame objects, append the new observations
                else:
                    global_obj_idx = self.video.frames[src_frame_idx].local_to_global[matched_indices[new_det_idx]]
                    self.append_new_observations(new_det, global_obj_idx, src_frame_idx, tgt_frame_idx)

            # if there is unmatched tracks from the previous frame
            for track_idx in unmatched_tracks:
                # if the tracklet does not leave the frame but failed matching with new detections
                # might be because of occlusion in that frame or imperfect detections from gsam
                if not self.is_tracklet_exit(track_idx, src_frame_idx, tgt_frame_idx):
                    # print("append tracked", local_obj_idx, src_frame_idx, tgt_frame_idx)
                    self.append_tracked_observations(track_idx, src_frame_idx, tgt_frame_idx)
                # the tracked objects leaves the frame
                else:
                    # print("tracks exit", local_obj_idx, src_frame_idx, tgt_frame_idx)
                    self.tracklet_exit_handle(track_idx, src_frame_idx) 
        
        # if the tracklet does not exit throughout the video, use the last frame as the end frame idx
        for global_obj in self.global_objects.values():
            if global_obj.end_frame_idx is None:
                global_obj.end_frame_idx = self.video.num_frames - 1

        # if the track is less than the threshold, ignore it
        new_global_objects = {}
        self.global_obj_idx = 0
        for global_obj_idx, global_obj in self.global_objects.items():
            track_length = global_obj.end_frame_idx - global_obj.start_frame_idx
            if track_length > self.config.valid_obj_threshold:
                new_global_objects[self.global_obj_idx] = global_obj
                self.global_obj_idx += 1
        self.global_objects = new_global_objects


    def initialization(self) -> Tuple[int, Dict[int, GlobalObject]]:
        # reset global objects
        self.global_obj_idx, self.global_objects = 0, {}
        # init with first frame objects
        for local_obj in self.video.frames[0].objects.values():
            self.init_global_object(local_obj, 0)

    def init_global_object(self, local_object: LocalObject, frame_idx: int) -> GlobalObject:
        """Start a new global track from a single-frame observation.

        The new object is registered under the next free global index, the
        frame's ``local_to_global`` table is updated and the index counter is
        advanced.

        Args:
            local_object: observation that seeds the track.
            frame_idx: frame in which the observation was made; it becomes the
                track's ``start_frame_idx`` and reference frame.

        Returns:
            The newly created ``GlobalObject``, as the signature declares. The
            method previously returned ``None`` despite the annotation.
        """
        global_object = GlobalObject()
        global_object.idx = self.global_obj_idx
        global_object.start_frame_idx = frame_idx
        global_object.visible_frames.append(frame_idx)
        global_object.tracks[frame_idx] = local_object.tracks
        global_object.tracks_vis[frame_idx] = local_object.tracks_vis
        global_object.flows[frame_idx] = local_object.flow
        global_object.logits[frame_idx] = local_object.logit
        global_object.boxes[frame_idx] = local_object.box
        global_object.masks[frame_idx] = local_object.mask
        # The first frame is the reference: identity transform, and the pseudo
        # lidar is the frame's 3D points selected by the object mask.
        global_object.trfs[frame_idx] = np.eye(4)
        global_object.pseudo_lidar = self.video.frames[frame_idx].points3d[local_object.mask]

        self.global_objects[self.global_obj_idx] = global_object
        self.video.frames[frame_idx].local_to_global[local_object.idx] = self.global_obj_idx
        self.global_obj_idx += 1
        return global_object


    def one_to_one_match(self, src_frame_idx: int, tgt_frame_idx: int):
        H, W = self.video.image_size
        matched_indices, unmatched_detections, unmatched_tracks = {}, [], []

        tracked_boxes = []
        for src_object in self.video.frames[src_frame_idx].objects.values():
            tbox = tracked_to_box(src_object.tracks[tgt_frame_idx], src_object.tracks_vis[tgt_frame_idx] > self.config.flow_vis_threshold, W, H)
            tracked_boxes.append(tbox)

        # scale boxes according to the tracked result (make smaller)
        original_boxes = []
        for tgt_object in self.video.frames[tgt_frame_idx].objects.values():
            x1, y1, x2, y2 = tgt_object.box
            offset_x = 5 if x2 - x1 >= 10 else 3
            offset_y = 5 if y2 - y1 >= 10 else 3
            original_boxes.append([int(x1+offset_x), int(y1+offset_y), int(x2-offset_x), int(y2-offset_y)])

        if len(original_boxes) < 1 or len(tracked_boxes) < 1: 
            if len(original_boxes) < 1:
                unmatched_tracks = [t_idx for t_idx in range(len(tracked_boxes))]
            if len(tracked_boxes) < 1:
                unmatched_detections = [o_idx for o_idx in range(len(original_boxes))]
            return matched_indices, unmatched_detections, unmatched_tracks

        iou_mat = get_iou_matrix(np.stack(original_boxes, axis=0), np.stack(tracked_boxes, axis=0), add1=True)
        original_indices, tracked_indices = linear_sum_assignment(iou_mat, maximize=True)        
        matched_indices = {o_idx: t_idx for o_idx, t_idx in zip(original_indices, tracked_indices) if iou_mat[o_idx, t_idx] >= self.config.match_iou_threshold}
        unmatched_detections = [o_idx for o_idx in range(len(original_boxes)) if o_idx not in list(matched_indices.keys())]
        unmatched_tracks = [t_idx for t_idx in range(len(tracked_boxes)) if t_idx not in list(matched_indices.values())]
        return matched_indices, unmatched_detections, unmatched_tracks

    
    def append_new_observations(self, local_object: LocalObject, global_obj_idx: int, src_frame_idx: int, tgt_frame_idx: int):
        """Aggregate the observation of a matched detection into its global object.

        Copies the per-frame data of ``local_object`` into the global object,
        registers the object between the source and target frame, chains the
        result onto the transform stored for the source frame, and appends the
        transformed object points to the accumulated pseudo lidar.
        """
        self.video.frames[tgt_frame_idx].local_to_global[local_object.idx] = global_obj_idx

        tracked = self.global_objects[global_obj_idx]
        tracked.visible_frames.append(tgt_frame_idx)
        tracked.tracks[tgt_frame_idx] = local_object.tracks
        tracked.tracks_vis[tgt_frame_idx] = local_object.tracks_vis
        tracked.flows[tgt_frame_idx] = local_object.flow
        tracked.logits[tgt_frame_idx] = local_object.logit
        tracked.boxes[tgt_frame_idx] = local_object.box
        tracked.masks[tgt_frame_idx] = local_object.mask

        # Register the object between the previous and the current frame.
        step_trf, new_points = self.object_registration_two_frames(local_object, src_frame_idx, tgt_frame_idx)

        # Registration can fail (too few object points / high registration error);
        # fall back to the ego pose as the object pose and flag the frame as occluded.
        if step_trf is None:
            step_trf = self.video.rel_poses[tgt_frame_idx]
            tracked.occluded_frames.append(tgt_frame_idx)

        # Chain onto the source-frame transform to express the frame w.r.t. the first frame.
        to_first = tracked.trfs[src_frame_idx] @ step_trf
        tracked.trfs[tgt_frame_idx] = to_first

        # Accumulate pseudo lidar only when the registration returned points.
        if new_points is None:
            return
        new_points = self.transform_points(new_points, to_first)
        tracked.pseudo_lidar = np.concatenate([tracked.pseudo_lidar, new_points], axis=0)


    def append_tracked_observations(self, local_obj_idx: int, src_frame_idx: int, tgt_frame_idx: int):
        """Carry an unmatched tracklet into the target frame using its tracked points.

        A mask and a slightly enlarged box are derived from the source-frame
        tracks. The previous logit, flow and tracks are propagated unchanged,
        the ego motion is used as the object motion, and a recovered
        ``LocalObject`` is appended to the target frame so the track can be
        matched again in the next step.
        """
        H, W = self.video.image_size
        global_obj_idx = self.video.frames[src_frame_idx].local_to_global[local_obj_idx]
        tracked = self.global_objects[global_obj_idx]

        # Tracked points of the previous frame, as seen in the target frame.
        pts = tracked.tracks[src_frame_idx][tgt_frame_idx]
        pts_visible = tracked.tracks_vis[src_frame_idx][tgt_frame_idx] > self.config.flow_vis_threshold

        new_mask = tracked_to_mask(pts, pts_visible, W, H)
        bx1, by1, bx2, by2 = tracked_to_box(pts, pts_visible, W, H)
        # Pad the box by 5 px on every side.
        x1, y1, x2, y2 = int(bx1 - 5), int(by1 - 5), int(bx2 + 5), int(by2 + 5)
        new_box = (x1, y1, x2, y2)
        new_logit = tracked.logits[src_frame_idx]

        # The frame counts as occluded when too small a fraction of the tracked points is visible.
        if (np.sum(pts_visible) / len(pts)) < self.config.tracklet_occlusion_threshold:
            tracked.occluded_frames.append(tgt_frame_idx)
        else:
            tracked.visible_frames.append(tgt_frame_idx)

        tracked.logits[tgt_frame_idx] = new_logit
        tracked.boxes[tgt_frame_idx] = new_box
        tracked.masks[tgt_frame_idx] = new_mask
        tracked.unmatched_times += 1

        # TODO: for now the previous track results are simply propagated.
        new_flow = tracked.flows[src_frame_idx]
        new_tracks = tracked.tracks[src_frame_idx]
        new_tracks_vis = tracked.tracks_vis[src_frame_idx]
        tracked.flows[tgt_frame_idx] = new_flow
        tracked.tracks[tgt_frame_idx] = new_tracks
        tracked.tracks_vis[tgt_frame_idx] = new_tracks_vis

        # No object registration here (the object may be occluded): chain the ego motion instead.
        tracked.trfs[tgt_frame_idx] = tracked.trfs[src_frame_idx] @ self.video.rel_poses[tgt_frame_idx]

        # Expose the recovered object as a local object of the target frame.
        tgt_frame = self.video.frames[tgt_frame_idx]
        new_local_idx = len(tgt_frame.objects)
        recovered_object = LocalObject()
        recovered_object.idx = new_local_idx
        recovered_object.box = new_box
        recovered_object.area = (x2 - x1) * (y2 - y1)
        recovered_object.logit = new_logit
        recovered_object.mask = new_mask
        recovered_object.flow = new_flow
        recovered_object.tracks = new_tracks
        recovered_object.tracks_vis = new_tracks_vis
        tgt_frame.objects[new_local_idx] = recovered_object
        tgt_frame.local_to_global[new_local_idx] = global_obj_idx


    def is_tracklet_exit(self, local_obj_idx: int, src_frame_idx: int, tgt_frame_idx: int):
        H, W = self.video.image_size
        global_obj_idx = self.video.frames[src_frame_idx].local_to_global[local_obj_idx]
        tracks = self.global_objects[global_obj_idx].tracks[src_frame_idx][tgt_frame_idx]
        tracks_vis = self.global_objects[global_obj_idx].tracks_vis[src_frame_idx][tgt_frame_idx] > self.config.flow_vis_threshold
        vis_mask = (tracks[:, 0] > 0) & (tracks[:, 1] > 0) & (tracks[:, 0] < W) & (tracks[:, 1] < H)
        vis_mask = vis_mask & tracks_vis
        if (np.sum(vis_mask) / len(tracks)) < self.config.tracklet_exit_threshold:
            return True
        if self.global_objects[global_obj_idx].unmatched_times > self.config.unmatched_times_threshold:
            return True
        return False

    def tracklet_exit_handle(self, local_obj_idx: int, src_frame_idx: int):
        global_obj_idx = self.video.frames[src_frame_idx].local_to_global[local_obj_idx]
        self.global_objects[global_obj_idx].end_frame_idx = src_frame_idx


    def object_registration_two_frames(self, object: LocalObject, src_frame_idx: int, tgt_frame_idx: int):
        """Register one object between two frames from its tracked pixels.

        ``object`` is the target object, i.e. the new detection. Its tracked
        pixel positions in both frames are lifted to 3D through the frames'
        ``points3d`` maps and aligned with RANSAC.

        Returns:
            ``(None, None)`` when fewer than 10 point pairs survive the
            visibility filter. Otherwise ``(trf, points)``, where ``trf`` is
            the inverse of the estimated registration (target to source, i.e.
            to the previous frame) and ``points`` are the target frame's 3D
            points selected by ``object.mask``.
        """
        src_pts2d = object.tracks[src_frame_idx]
        tgt_pts2d = object.tracks[tgt_frame_idx]
        src_cloud = self.video.frames[src_frame_idx].points3d
        tgt_cloud = self.video.frames[tgt_frame_idx].points3d

        # Keep only the tracked pixels accepted by the visibility/confidence filter.
        keep = self.filter_visibility_and_confidence(src_pts2d, tgt_pts2d, object.tracks_vis[src_frame_idx])
        src_pts3d = src_cloud[src_pts2d[keep, 1], src_pts2d[keep, 0], :]
        tgt_pts3d = tgt_cloud[tgt_pts2d[keep, 1], tgt_pts2d[keep, 0], :]

        # Too few correspondences: the object is treated as occluded in this frame.
        if len(src_pts3d) < 10 or len(tgt_pts3d) < 10:
            return None, None

        # Outlier removal (only there to speed up the registration); applied
        # only when enough common inliers remain.
        inliers = remove_outliers_by_percentile(src_pts3d, 10, 85)[1] & remove_outliers_by_percentile(tgt_pts3d, 10, 85)[1]
        if np.sum(inliers) > 20:
            src_pts3d = src_pts3d[inliers]
            tgt_pts3d = tgt_pts3d[inliers]

        trf, _, cost = ransac_registration(
            src_pts3d,
            tgt_pts3d,
            self.config.ransac_iters,
            self.config.ransac_inlier_threshold,
            self.config.ransac_stop_threshold,
        )

        # High registration cost: replace the estimate by the inverse relative
        # ego pose (the alternative `return None, None` is currently disabled).
        if cost > 2:
            trf = np.linalg.inv(self.video.rel_poses[tgt_frame_idx])

        # Invert so that the result maps target to source (the previous frame).
        target_to_source = np.linalg.inv(trf)
        return target_to_source, tgt_cloud[object.mask, :]
    
    def save_pseudo_lidar(self):
        save_dir = self.pseudo_root / "pseudo_lidars" / f"{self.video.name}"
        save_dir.mkdir(parents=True, exist_ok=True)
        for global_obj_idx, global_object in self.global_objects.items():
            if self.video.tgt_frame_idx in list(global_object.trfs.keys()):
                tgt_trf = global_object.trfs[self.video.tgt_frame_idx]
                pseudo_lidar = self.transform_points(global_object.cleaned_lidar, np.linalg.inv(tgt_trf))
                save_pcd(pseudo_lidar, str(save_dir / f"{global_obj_idx}.ply"))

    def compute_attributes(self):
        """Compute yaw, size and center of every global object and export the target-frame labels.

        After cleaning and saving the pseudo lidar, each object gets its
        direction and size computed. Objects whose frame span contains the
        video's target frame contribute one label entry. The entries are
        written as a KITTI label file and rendered into a result image.
        """
        self.clean_pseudo_lidar()
        self.save_pseudo_lidar()

        labels = []
        for global_obj_idx, global_object in self.global_objects.items():
            self.compute_object_direction(global_obj_idx)
            self.compute_object_size(global_obj_idx)

            # Only objects alive at the target frame are exported.
            lifetime = range(global_object.start_frame_idx, global_object.end_frame_idx + 1)
            if self.video.tgt_frame_idx not in lifetime:
                continue

            tgt_frame_idx = self.video.tgt_frame_idx
            entry = {
                "label": global_object.label,
                "size": global_object.size,
                "box": global_object.boxes[tgt_frame_idx],
                "yaw": global_object.yaws[tgt_frame_idx],
                "center": global_object.centers[tgt_frame_idx],
            }
            labels.append(entry)

        image = self.video.frames[self.video.tgt_frame_idx].image
        label_path = self.pseudo_label_dir / f"{self.video.name}.txt"
        self.save_kitti_labels(labels, label_path)
        result_path = self.pseudo_results_dir / f"{self.video.name}.png"
        self.visualize_pseudo_lidar_with_box(image, labels, result_path)

    def search_valid_adjacent_frame_idx(self, tgt_frame_idx, global_obj_idx):
        tgt_object = self.global_objects[global_obj_idx]
        found = False
        # search forward
        for adj_frame_idx in range(tgt_frame_idx+1, self.video.num_frames):
            if tgt_object.trfs[adj_frame_idx] is None: continue
            if abs(adj_frame_idx - tgt_frame_idx) >= self.config.frame_gap:
                found = True
                break
        # search backward
        if not found:
            for adj_frame_idx in range(tgt_frame_idx-1, 0, -1):
                if tgt_object.trfs[adj_frame_idx] is None: continue
                if abs(adj_frame_idx - tgt_frame_idx) >= self.config.frame_gap:
                    found = True
                    break
        if not found:
            return None
        return adj_frame_idx


    def compute_object_direction(self, global_obj_idx: int):
        global_object = self.global_objects[global_obj_idx]
        start_frame_idx = global_object.start_frame_idx
        # end_frame_idx = global_object.end_frame_idx
        end_frame_idx = start_frame_idx + 4

        yaws = []

        # decide whether the object is moving or not
        is_moving = False
        obj_trf_to_first = global_object.trfs[end_frame_idx]
        ego_trf_to_first = self.video.poses[end_frame_idx]

        obj_motion = np.linalg.norm(obj_trf_to_first[[0, 2], -1])
        ego_motion = np.linalg.norm(ego_trf_to_first[[0, 2], -1])
        print(obj_motion, ego_motion, abs(ego_motion - obj_motion))

        if abs(ego_motion - obj_motion) > 2:
            is_moving = True
            for frame_idx in range(start_frame_idx, end_frame_idx+1):
                next_frame_idx = frame_idx + 1
                obj_trf = global_object.adj_trfs[next_frame_idx]
                if frame_idx == end_frame_idx:
                    next_frame_idx = frame_idx - 1
                    obj_trf = np.linalg.inv(global_object.adj_trfs[frame_idx])

                yaw = None
                self.global_objects[global_obj_idx].yaws[frame_idx] = yaw
            return

        if not is_moving:
            yaw, _ = compute_dir_with_PCA(global_object.pseudo_lidar)
        yaw = self.yaw_correction(yaw, 25)
        obj_rot = yaw_to_rotation_matrix(yaw)

        for frame_idx in range(start_frame_idx, global_object.end_frame_idx+1):
            # # slightly shift the yaw based on ego vehicle orientation
            ego_trf = self.video.poses[frame_idx]
            new_rot = ego_trf[:3, :3] @ obj_rot
            
            new_yaw = rotation_matrix_to_yaw(new_rot)

            self.global_objects[global_obj_idx].yaws[frame_idx] = new_yaw

    def compute_object_size(self, global_obj_idx: int):
        """Compute the global size of an object and its center in every frame of its span.

        The size ``[L, H, W]`` is measured once on the cleaned pseudo lidar
        after moving it into the start frame and rotating it by minus the
        start-frame yaw. For each frame the points are moved into that frame,
        a box of the global size is placed at the axis-aligned center and
        rotated back by the yaw, and the mean of that box is stored as the
        center. Frames whose transform is ``None`` take the 3D point under the
        2D box center instead. Objects whose center has ``-1 < x < 1``
        ("front car") get their yaw overwritten with ``-1.56``.
        """
        tracked = self.global_objects[global_obj_idx]
        first_idx = tracked.start_frame_idx

        # Global object size, measured in the start frame.
        start_points = self.transform_points(tracked.cleaned_lidar, np.linalg.inv(tracked.trfs[first_idx]))
        start_aligned = (rotate_y(-tracked.yaws[first_idx]) @ start_points.T).T
        (L, H, W), _ = compute_size_from_axis_aligned_points(start_aligned)
        tracked.size = [L, H, W]

        for frame_idx in range(first_idx, tracked.end_frame_idx + 1):
            trf = tracked.trfs[frame_idx]
            if trf is None:
                # No transform: read the 3D point under the 2D box center.
                box = tracked.boxes[frame_idx]
                cx, cy = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2
                x3d, y3d, z3d = self.video.frames[frame_idx].points3d[int(cy), int(cx), :]
            else:
                # Send the pseudo lidar to this frame and make it axis-aligned.
                frame_yaw = tracked.yaws[frame_idx]
                frame_points = self.transform_points(tracked.cleaned_lidar, np.linalg.inv(trf))
                frame_aligned = (rotate_y(-frame_yaw) @ frame_points.T).T
                # Axis-aligned center, then a box of the global size around it.
                _, (x3d, y3d, z3d) = compute_size_from_axis_aligned_points(frame_aligned)
                axis_aligned_box = create_ego_box3d(x3d, y3d, z3d, L, H, W, yaw=0)
                # Rotate the box back and use its mean as the center.
                ego_aligned_box = (rotate_y(frame_yaw) @ axis_aligned_box.T).T
                x3d, y3d, z3d = ego_aligned_box.mean(0)

            # front car
            if -1 < x3d < 1:
                tracked.yaws[frame_idx] = -1.56
            tracked.centers[frame_idx] = [x3d, y3d, z3d]

    def transform_points(self, points: np.ndarray, trf: np.ndarray):
        return ((trf[:3, :3] @ points.T) + trf[:3, -1:]).T

    def filter_gsam_detections(self, gsam_results: Dict[str, np.ndarray], depth_map: np.ndarray):
        boxes = []
        remaps = {}
        obj_counter = 0
        filtered_gsam_results = {"masks": [], "boxes": [], "logits": [], "labels": []}
        for obj_idx in range(len(gsam_results['boxes'])):
            x1, y1, x2, y2 = list(map(int, gsam_results['boxes'][obj_idx]))
            cx, cy = (x1+x2)//2, (y1+y2)//2
            center_depth = depth_map[cy, cx]
            area = (x2 - x1) * (y2 - y1)

            # filter small objects (based on masks size)
            if area < self.config.small_obj_threshold: continue

            # filter very near or far objects (based on depth)
            if center_depth < self.config.near_depth_threshold or center_depth > self.config.far_depth_threshold: continue

            # else aggregate the boxes
            boxes.append([x1, y1, x2, y2])
            remaps[obj_counter] = obj_idx
            obj_counter += 1

        # if no boxes return empty
        if len(boxes) > 0: 
            # hungarian match and find duplicate
            boxes = np.stack(boxes, axis=0)
            iou_mat = get_iou_matrix(boxes, boxes)
            diagonal = np.diag_indices(iou_mat.shape[0])
            iou_mat[diagonal] = 0
            indices1, indices2 = linear_sum_assignment(iou_mat, maximize=True)
            duplicated_indices = [id2 for id1, id2 in zip(indices1, indices2) if iou_mat[id1, id2] > self.config.duplicate_iou_threshold]

            masks, boxes, logits, labels = [], [], [], []
            for obj_idx in range(len(iou_mat)):
                orig_index = remaps[obj_idx]
                if obj_idx not in duplicated_indices:
                    masks.append(gsam_results['masks'][orig_index])
                    boxes.append(gsam_results['boxes'][orig_index])
                    logits.append(gsam_results['logits'][orig_index])
                    labels.append(gsam_results['labels'][orig_index])
            if len(boxes) < 1:
                return filtered_gsam_results
            filtered_gsam_results["masks"] = np.stack(masks)
            filtered_gsam_results["boxes"] = np.stack(boxes)
            filtered_gsam_results["logits"] = np.stack(logits)
            filtered_gsam_results["labels"] = np.stack(labels)      
        return filtered_gsam_results

    def save_kitti_labels(self, results, save_path: Path):
        """Write the kept entries of `results` to `save_path` in KITTI label text format.

        Args:
            results: iterable of dicts providing the keys 'label', 'box', 'size',
                'center' and 'yaw'.
            save_path: file to (over)write; one text line per kept entry.

        Entries are dropped when center[1] is outside [0, 2] ("floating" objects)
        or when center[2] is above 65 ("far" objects). The truncated and
        occluded fields are always written as 0.
        """
        truncated, occluded = 0, 0
        lines = []
        for entry in results:
            size, box = entry['size'], entry['box']
            yaw, center = entry['yaw'], entry['center']

            # skip floating objects and objects that are too far away
            if center[1] < 0 or center[1] > 2 or center[2] > 65:
                continue

            # KITTI stores the bottom-center of the box, so shift y by half of size[1]
            bottom_y = center[1] + size[1]/2
            alpha = self.calculate_alpha(yaw, box, self.video.intrinsic)

            # dimensions are emitted as size[1], size[2], size[0]
            # (height, width, length if `size` is ordered [L, H, W] -- TODO confirm for these dicts)
            lines.append("".join((
                f"{entry['label']} {truncated} {occluded} {alpha:.2f} ",
                f"{box[0]:.2f} {box[1]:.2f} {box[2]:.2f} {box[3]:.2f} ",
                f"{size[1]:.2f} {size[2]:.2f} {size[0]:.2f} ",
                f"{center[0]:.2f} {bottom_y:.2f} {center[2]:.2f} {yaw:.2f}\n",
            )))

        with open(save_path, "w") as f:
            f.write("".join(lines))

    def visualize_pseudo_lidar_with_box(self, image: np.ndarray, labels: List[dict], save_path: str):
        """Draw the 3D boxes of the kept labels through `visualize_bev_and_box3d`.

        Args:
            image: image forwarded unchanged to the visualizer.
            labels: dicts providing the keys 'center', 'size' and 'yaw'.
            save_path: output location forwarded unchanged to the visualizer.

        Labels are dropped when center[1] is outside [0, 2] ("floating" objects)
        or when center[2] is above 65 ("far" objects). Each kept label becomes
        one flat box `[*center, *size, yaw]`.
        """
        boxes = [
            [*label['center'], *label['size'], label['yaw']]
            for label in labels
            # keep only objects that are neither floating nor too far away
            if not (label['center'][1] < 0 or label['center'][1] > 2 or label['center'][2] > 65)
        ]
        visualize_bev_and_box3d(image, boxes, None, self.video.intrinsic, None, None, save_path)


    def clean_pseudo_lidar(self):
        """Fill `cleaned_lidar` for every entry of `self.global_objects`.

        For each object the pseudo lidar is range-filtered (depth 1..80) and then
        trimmed with `remove_outliers_by_percentile`, using the per-class bounds
        from `filter_thresholds` (falling back to the "Others" entry). A stage
        is only accepted when it leaves more than 10 points:
          * range filter leaves <= 10 points -> keep the raw pseudo lidar
          * outlier removal leaves <= 10 points -> keep the range-filtered points
          * otherwise -> keep the outlier-free points
        """
        for obj_key in list(self.global_objects.keys()):
            obj = self.global_objects[obj_key]
            raw_points = obj.pseudo_lidar
            cls_name = obj.label

            in_range, _ = range_filter(raw_points, min_depth=1, max_depth=80)
            if len(in_range) <= 10:
                # too few points survive the range filter: fall back to the raw cloud
                obj.cleaned_lidar = raw_points
                continue

            threshold_key = cls_name if cls_name in filter_thresholds else "Others"
            q_min, q_max = filter_thresholds[threshold_key]
            trimmed, _ = remove_outliers_by_percentile(in_range, q_min, q_max)
            obj.cleaned_lidar = trimmed if len(trimmed) > 10 else in_range


    def visualize_point_tracks(self):
        """Overlay the point tracks of the first frame's objects on every frame and save a video.

        Works on copies of the frame images, so the stored frames are not drawn
        on. The result is written with `write_video` into
        `<pseudo_root>/alltracker_vis/<video name>`.
        """
        height, width = self.info.image_size
        canvases = [self.video.frames[t].image.copy() for t in range(self.video.num_frames)]

        for _obj_idx, obj in self.video.frames[0].objects.items():
            # keep every 4th point only, otherwise the overlay is unreadable
            sparse_tracks = obj.tracks[:, ::4, :]
            sparse_vis = obj.tracks_vis[:, ::4]
            num_steps, _, _ = sparse_tracks.shape
            # colors are derived from the point positions at the first time step
            colors = get_2d_colors(sparse_tracks[0], height, width)

            for step in range(num_steps):
                draw_point_tracks(canvases[step], sparse_tracks[step], sparse_vis[step], colors)

        out_dir = self.pseudo_root / "alltracker_vis" / self.video.name
        out_dir.mkdir(parents=True, exist_ok=True)
        write_video(canvases, out_dir, self.video.name)

    def yaw_correction(self, yaw, offset=10):
        """Snap `yaw` to an axis-aligned direction when it lies close to one.

        Args:
            yaw: angle in radians.
            offset: half-width, in degrees, of the snapping window around each axis.

        Returns:
            deg2rad(90), deg2rad(-90), 0 or deg2rad(180) when `yaw` falls strictly
            inside the matching window (angles beyond +/-(180 - offset) degrees
            all map to +180 degrees); otherwise `yaw` unchanged.
        """
        # (axis in degrees, snapped value): south, north, east -- checked in this order
        snap_table = (
            (90, np.deg2rad(90)),
            (-90, np.deg2rad(-90)),
            (0, 0),
        )
        for axis_deg, snapped in snap_table:
            # open interval: an angle exactly `offset` degrees off the axis is not snapped
            if np.deg2rad(axis_deg - offset) < yaw < np.deg2rad(axis_deg + offset):
                return snapped

        # west wraps around +/-180 degrees, so it needs a two-sided test
        if yaw < np.deg2rad(offset - 180) or yaw > np.deg2rad(180 - offset):
            return np.deg2rad(180)
        return yaw

    def calculate_alpha(self, yaw: float, box: list, intrinsic: np.ndarray) -> float:
        """Convert `yaw` into an observation angle via `ry2alpha`.

        Args:
            yaw: object orientation passed through to `ry2alpha`.
            box: 2D box; only box[0] and box[2] are read, and their midpoint is used.
            intrinsic: matrix from which entries [0, 0] and [0, 2] are read
                (presumably fx and cx of a pinhole camera matrix -- TODO confirm).
        """
        box_mid_u = (box[2] + box[0]) / 2
        k00 = intrinsic[0, 0]
        k02 = intrinsic[0, 2]
        return ry2alpha(yaw, box_mid_u, k00, k02)

    def create_visibility_mask(self, points2d, points2d_corr, frame_width, frame_height):
        """Flag point pairs whose two endpoints both lie inside the frame.

        Args:
            points2d: array indexed as [:, 0] (x) and [:, 1] (y).
            points2d_corr: corresponding points, indexed the same way.
            frame_width: exclusive upper bound for x.
            frame_height: exclusive upper bound for y.

        Returns:
            Boolean array with one entry per row, True when
            0 <= x < frame_width and 0 <= y < frame_height holds for both points.
        """
        def inside(points):
            xs, ys = points[:, 0], points[:, 1]
            # the lower bound matters as well: points that left the frame over the
            # left/top edge have negative coordinates and must not count as visible
            return (xs >= 0) & (xs < frame_width) & (ys >= 0) & (ys < frame_height)

        return inside(points2d) & inside(points2d_corr)
    
    def filter_visibility_and_confidence(self, points1, points2, tracks_vis):
        """Select point pairs that pass `create_visibility_mask` and are confidently tracked.

        Args:
            points1: points handed to `create_visibility_mask` as the first set.
            points2: corresponding points handed over as the second set.
            tracks_vis: per-point score compared against
                `self.config.flow_vis_threshold`.

        Returns:
            Boolean array of length `len(points1)`; True where the pair passes the
            visibility mask and its score is strictly above the threshold.
        """
        height, width = self.video.image_size
        keep = np.zeros(len(points1), dtype=bool)
        in_frame = self.create_visibility_mask(points1, points2, width, height)
        # entries outside the frame stay False; the rest take the confidence test result
        keep[in_frame == 1] = tracks_vis[in_frame] > self.config.flow_vis_threshold
        return keep

    def aggreage_flows(self, flows, flows_vis):
        """Return the mean absolute flow component over the confidently tracked points.

        Args:
            flows: flow vectors, documented by the original author as shape [N, M, 2].
            flows_vis: per-point score; only points strictly above
                `self.config.flow_vis_threshold` contribute.

        Returns:
            Mean of the absolute values of all selected flow components.
        """
        confident = flows_vis > self.config.flow_vis_threshold
        magnitudes = np.abs(flows[confident])
        return np.mean(magnitudes)
    
    def select_train_scenes(self):
        """Return the scene names listed, one per line, in `<dataset_root>/ImageSets/train.txt`."""
        split_file = self.dataset_root / "ImageSets" / "train.txt"
        return split_file.read_text().splitlines()
    
    def select_test_scenes(self):
        """Return the scene names listed, one per line, in `<dataset_root>/ImageSets/val.txt`.

        Note that the "test" scenes are read from the validation split file.
        """
        split_file = self.dataset_root / "ImageSets" / "val.txt"
        return split_file.read_text().splitlines()
    
    def check_already_labeled(self, scene_name):
        """Return True when `<pseudo_label_dir>/<scene_name>.txt` already exists."""
        label_file = self.pseudo_label_dir / f"{scene_name}.txt"
        return label_file.exists()
    
    def decide_moving_cam(self, video: Video):
        """Decide whether the camera moves over the course of `video`.

        Only frames whose `background_flow` is strictly above
        `self.config.scene_flow_threshold` are aggregated; the camera counts as
        moving when the mean of those flows exceeds the same threshold.

        Args:
            video: object whose `frames` mapping holds frames with a scalar
                `background_flow`.

        Returns:
            True when the camera is considered moving, False otherwise
            (including when no frame passes the threshold).
        """
        threshold = self.config.scene_flow_threshold
        flows = [
            frame.background_flow
            for frame in video.frames.values()
            if frame.background_flow > threshold
        ]
        # no frame with enough motion: answer directly instead of averaging an
        # empty list, which makes numpy emit a RuntimeWarning and return NaN
        if not flows:
            return False
        return bool(np.mean(flows) > threshold)
    
    def __len__(self):
        """Return the number of entries in `self.scenes`."""
        scenes = self.scenes
        return len(scenes)
        
    def load_information(self, scene_name):
        """Placeholder hook: currently ignores `scene_name`, does nothing and returns None."""
        return None

    def __getitem__(self, index) -> Video:
        """Return the entry of `self.scenes` stored at `index`."""
        scenes = self.scenes
        return scenes[index]
    
