import numpy as np
import time
import multiprocessing as mp
from tqdm import tqdm
from pathlib import Path
from plot.utils.io import load_frames, read_image
from plot.utils.misc import find_target_frame_idx, limit_frames
from plot.utils.geometry import meshgrid2d, depthmap_to_pts3d
from plot.utils.processing import *
from plot.datasets.utils import *
from plot.datasets.base_global import *


# Detection labels kept by KITTIPseudoLabeler.load_information: a detection
# whose normalized label (trailing/leading '.' stripped, title-cased) is not
# in this list is skipped.
KITTI_CLASSES = ['Car', 'Pedestrian', 'Cyclist']


class KITTIPseudoLabeler(PseudoLabelerGlobal):
    """Pseudo-labeler for KITTI scenes.

    ``load_information`` reads the precomputed per-frame depth maps, GSAM
    detections and dense tracks of a scene and packs them into a
    ``VideoClass``; ``run_pipeline`` then feeds that record through the
    matching/attribute stages, which are not defined in this class
    (presumably provided by ``PseudoLabelerGlobal``).
    """

    def __init__(self, dataset_root: Path, pseudo_root: Path) -> None:
        """Set the KITTI-specific configuration and initialise the base class.

        Args:
            dataset_root: Dataset root, stored as ``config.dataset_root``.
            pseudo_root: Pseudo-label root, stored as ``config.pseudo_root``.
        """
        # NOTE(review): this binds the ``BaseConfigVideo`` object itself (no
        # call, no copy), so the assignments below are visible to every other
        # user of that object - confirm the sharing is intended.
        kitti_config = BaseConfigVideo
        kitti_config.dataset_root = dataset_root
        kitti_config.pseudo_root = pseudo_root
        kitti_config.target_frame_idx = 20
        kitti_config.match_iou_threshold = 0.4
        kitti_config.scene_flow_threshold = 8
        kitti_config.distant_obj_threshold = 50
        kitti_config.sky_height = 150
        kitti_config.frame_gap = 3
        super().__init__(kitti_config)

    def load_information(self, scene_name):
        """Load every per-frame input of one scene into a ``VideoClass``.

        Args:
            scene_name: Scene identifier; used as the calibration file stem
                and as the sub-directory name under the image, depth, GSAM
                and track directories.

        Returns:
            The populated ``VideoClass``, or ``None`` (after printing
            "No Object Found") when the target frame has no object masks or
            when no frame of the scene contains a detection.
        """
        intrinsic = read_kitti_calib(self.calib_dir / f"{scene_name}.txt")
        _, frame_lists_original = load_frames(self.img_dir / scene_name)
        target_frame_idx_original = find_target_frame_idx(frame_lists_original, self.config.target_frame_idx)
        frame_lists = limit_frames(frame_lists_original, target_frame_idx_original, self.config.frame_limit)
        target_frame_idx = find_target_frame_idx(frame_lists, self.config.target_frame_idx)

        target_frame_obj_masks = np.load(self.gsam_dir / scene_name / f"{frame_lists[target_frame_idx].stem}.npy", allow_pickle=True).item()['masks']
        if len(target_frame_obj_masks) < 1:
            print("No Object Found")
            return None

        video = VideoClass()
        video.fps = 10
        video.name = scene_name
        video.intrinsic = intrinsic
        video.num_frames = len(frame_lists)
        video.target_frame_idx = target_frame_idx

        overall_flows = []
        image_size = None  # (H, W) of the dense tracks; set by frames that have detections
        for frame_idx, frame in enumerate(frame_lists):
            # An NpzFile keeps its file open and decompresses an array on every
            # key access, so read the depth map once and close the archive.
            with np.load(self.depth_dir / scene_name / f"{frame.stem}.npz") as depth_file:
                depth_map = depth_file['depth']
            gsam_results = np.load(self.gsam_dir / scene_name / f"{frame.stem}.npy", allow_pickle=True).item()
            gsam_results = self.remove_duplicate_boxes(gsam_results)

            frame_attrs = FrameClass()
            frame_attrs.idx = frame_idx
            frame_attrs.name = frame
            frame_attrs.image = read_image(frame)
            frame_attrs.points3d = depthmap_to_pts3d(depth_map, intrinsic[0, 0], intrinsic[0, 2], intrinsic[1, 2])

            # Frames without detections keep their image/points but contribute
            # neither objects nor a background flow sample.
            if len(gsam_results['masks']) < 1:
                frame_attrs.objects = {}
                video.frames[frame_idx] = frame_attrs
                continue

            with np.load(self.track_dir / scene_name / f"{frame.stem}_dense.npz") as dense_results:
                dense_tracks = dense_results['tracks']
                dense_tracks_vis = dense_results['visibility']
            _, H, W, _ = dense_tracks.shape
            image_size = (H, W)
            grid_xy = meshgrid2d(W, H).reshape(H, W, 2).numpy()
            # Displacement of every tracked pixel relative to its grid position.
            dense_flows = dense_tracks - grid_xy[None, ...]

            # Background = everything outside the (eroded complement of the)
            # merged object masks, with the top ``sky_height`` rows removed.
            frame_obj_mask_merged = merge_object_masks(gsam_results['masks'])
            frame_bg_mask = erode_mask(~frame_obj_mask_merged, 15, 15)
            frame_bg_mask[:self.config.sky_height, :] = 0
            scene_flow = self.aggreage_flows(dense_flows[:, frame_bg_mask], dense_tracks_vis[:, frame_bg_mask])
            overall_flows.append(scene_flow)
            if frame_idx == target_frame_idx:
                frame_attrs.background_tracks = (dense_flows[:, frame_bg_mask] + grid_xy[None, frame_bg_mask]).astype(int)
                frame_attrs.background_tracks_vis = dense_tracks_vis[:, frame_bg_mask]

            frame_attrs.objects = self._collect_frame_objects(
                gsam_results, depth_map, dense_flows, dense_tracks_vis, grid_xy, scene_flow
            )
            video.frames[frame_idx] = frame_attrs

        if image_size is None:
            # No frame had a detection after duplicate removal, so there is no
            # track resolution and no flow sample to work with.
            print("No Object Found")
            return None

        video.image_size = image_size
        if np.mean(overall_flows) > self.config.scene_flow_threshold:
            video.moving_cam = True
        return video

    def _collect_frame_objects(self, gsam_results, depth_map, dense_flows, dense_tracks_vis, grid_xy, scene_flow):
        """Build the ``ObjectClass`` entries for the detections of one frame.

        Args:
            gsam_results: Detection dict with 'boxes', 'masks', 'logits' and
                'labels' entries indexed per detection.
            depth_map: Depth map of the frame, indexed as ``[row, column]``.
            dense_flows: Dense tracks minus the pixel grid for this frame.
            dense_tracks_vis: Visibility array matching ``dense_flows``.
            grid_xy: Pixel grid that was subtracted to obtain ``dense_flows``.
            scene_flow: Aggregated background flow of the frame.

        Returns:
            Dict mapping consecutive indices (0, 1, ...) to the kept objects.
        """
        objects = {}
        for obj_idx, box in enumerate(gsam_results['boxes']):
            # Cheap label filter first, so rejected classes are never eroded.
            label = str(gsam_results['labels'][obj_idx]).strip('.').title()
            if label not in KITTI_CLASSES:
                continue

            raw_mask = gsam_results['masks'][obj_idx]
            obj_mask = adaptive_erode_mask(raw_mask, 4, 2, 4, 2)
            # Fall back to the raw mask when erosion leaves 10 pixels or fewer.
            kept_mask = obj_mask if np.sum(obj_mask) > 10 else raw_mask
            if np.sum(kept_mask) < 50:
                continue

            # Skip objects whose box-centre depth is below 3.5 (too close to the camera).
            x1, y1, x2, y2 = list(map(int, box))
            if depth_map[int((y1 + y2) / 2), int((x1 + x2) / 2)] < 3.5:
                continue

            object_attrs = ObjectClass()
            object_attrs.idx = len(objects)
            object_attrs.box = [x1, y1, x2, y2]
            object_attrs.logit = float(gsam_results['logits'][obj_idx])
            object_attrs.label = label
            object_attrs.mask = kept_mask

            # NOTE(review): flow and tracks are sampled with the eroded mask even
            # when ``object_attrs.mask`` fell back to the raw mask, so they can be
            # (nearly) empty for such objects - confirm whether they should use
            # ``object_attrs.mask`` instead.
            object_attrs.flow = self.aggreage_flows(dense_flows[:, obj_mask], dense_tracks_vis[:, obj_mask])
            object_attrs.tracks = (dense_flows[:, obj_mask] + grid_xy[obj_mask][None, ...]).astype(int)
            object_attrs.tracks_vis = dense_tracks_vis[:, obj_mask]

            # An object whose flow differs from the background flow by more
            # than the threshold is considered to be moving.
            # NOTE(review): ``object_flow_threshold`` is not set in ``__init__``
            # here; presumably it is a default on the base config - verify.
            if np.abs(object_attrs.flow - scene_flow) > self.config.object_flow_threshold:
                object_attrs.is_moving = True

            objects[object_attrs.idx] = object_attrs
        return objects

    def run_pipeline(self, scene_name):
        """Load one scene and run the labeling stages on it.

        Does nothing further when ``load_information`` returns ``None``.

        Args:
            scene_name: Scene identifier forwarded to ``load_information``.
        """
        video = self.load_information(scene_name)
        if video is None:
            return
        self.match_forward(video)
        self.get_global_object_matches(video)
        self.add_missing_objects(video)
        self.adj_frame_object_registration(video)
        self.compute_attributes(video)



if __name__ == '__main__':
    # Placeholder locations: replace with the real dataset / pseudo-label roots.
    dataset_root = "PATH_TO_DATA"
    pseudo_root = "PATH_TO_DATA"

    labeler = KITTIPseudoLabeler(dataset_root, pseudo_root)
    # Iterating the labeler yields scene names; only scenes that are not
    # labeled yet are processed.
    for scene_name in tqdm(labeler):
        if not labeler.check_already_labeled(scene_name):
            labeler.run_pipeline(scene_name)