import torch
import cv2
import numpy as np
from tqdm import tqdm
from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
from plot.utils.io import read_image, read_image_paths
from plot.utils.misc import find_target_frame_idx, limit_frames
from plot.utils.geometry import meshgrid2d, depthmap_to_pts3d
from plot.datasets.utils import merge_object_masks, read_kitti_calib, scale_intrinsics
from plot.utils.processing import erode_mask, adaptive_erode_mask
from plot.datasets.kitti.config import BaseConfig
from plot.datasets.kitti.adjacent_base import PseudoLabeler, LocalObject, Frame, Video


class KITTIPseudoLabeler(PseudoLabeler):
    """Pseudo-labeler that loads per-scene KITTI data into a ``Video`` record.

    The directory attributes (``calib_dir``, ``img_dir``, ``depth_dir``,
    ``gsam_dir``, ``pose_dir``, ``track_dir``), ``config`` and the helper
    methods called below (``filter_gsam_detections``, ``aggreage_flows``,
    ``decide_moving_cam``, ``adjacent_frame_registration``,
    ``compute_attributes``) are not defined in this class; they are expected
    to come from ``PseudoLabeler``.
    """

    def __init__(self, dataset_root: Path, pseudo_root: Path) -> None:
        """Build a ``BaseConfig`` with the two roots and hand it to the base class.

        Args:
            dataset_root: Stored on the config as ``dataset_root``.
            pseudo_root: Stored on the config as ``pseudo_root``.
        """
        kitti_config = BaseConfig()
        kitti_config.dataset_root = dataset_root
        kitti_config.pseudo_root = pseudo_root
        super().__init__(kitti_config)

    def load_video_information(self, scene_name: str) -> None:
        """Load everything needed for one scene and store it in ``self.video``.

        For every frame in a window around the target frame this reads the
        image, depth map, detection/mask results, pose and dense tracks from
        disk, derives background and per-object flows, and finally fills the
        first-frame-relative (``video.poses``) and adjacent-frame
        (``video.rel_poses``) transforms.

        Args:
            scene_name: Scene identifier; used as the calibration file stem and
                as the sub-directory name under each data directory.
        """
        intrinsic = read_kitti_calib(self.calib_dir / f"{scene_name}.txt")

        # Trim the scene to a window around the target frame, then resolve the
        # target index again against the trimmed list.
        all_frame_paths = read_image_paths(self.img_dir / scene_name)
        tgt_idx_full = find_target_frame_idx(all_frame_paths, self.config.target_frame_idx)
        frame_lists = limit_frames(all_frame_paths, tgt_idx_full, self.config.frame_limit // 2)
        tgt_frame_idx = find_target_frame_idx(frame_lists, self.config.target_frame_idx)

        H, W, _ = read_image(frame_lists[0]).shape
        grid_xy = meshgrid2d(W, H).reshape(H, W, 2).numpy()

        video = Video()
        video.fps = 10
        video.name = scene_name
        video.intrinsic = intrinsic
        video.num_frames = len(frame_lists)
        video.frame_lists = frame_lists
        video.image_size = (H, W)
        video.meshgrid = grid_xy
        video.tgt_frame_idx = tgt_frame_idx

        tgt_poses = []
        for frame_idx, frame_path in enumerate(frame_lists):
            stem = frame_path.stem

            # Load depth; the archive is closed as soon as the array is read.
            with np.load(self.depth_dir / video.name / f"{stem}.npz") as depth_archive:
                depth = depth_archive['depth']
            image = read_image(frame_path)

            frame_attrs = Frame()
            frame_attrs.idx = frame_idx
            frame_attrs.name = stem
            frame_attrs.image = image
            # Back-project depth using intrinsic[0, 0], [0, 2] and [1, 2]
            # (presumably focal length and principal point -- TODO confirm).
            frame_attrs.points3d = depthmap_to_pts3d(
                depth, intrinsic[0, 0], intrinsic[0, 2], intrinsic[1, 2]
            )

            # Load detections and masks.
            # NOTE(security): allow_pickle=True unpickles the file contents;
            # only load result files that come from a trusted source.
            gsam_results = np.load(
                self.gsam_dir / video.name / f"{stem}.npy", allow_pickle=True
            ).item()
            gsam_results = self.filter_gsam_detections(gsam_results, depth)

            # Pose of this frame; combined into relative transforms after the loop.
            tgt_poses.append(np.load(self.pose_dir / video.name / f"{stem}.npy"))

            # Dense tracks -> flows (tracked position minus the pixel grid).
            with np.load(self.track_dir / video.name / f"{stem}_dense.npz") as dense_results:
                dense_tracks = dense_results['tracks']
                dense_tracks_vis = dense_results['visibility']
            dense_flows = dense_tracks - video.meshgrid[None, ...]

            # Background region: everything outside the (eroded complement of
            # the) object masks, or the whole image when nothing was detected.
            has_objects = len(gsam_results['masks']) >= 1
            if has_objects:
                frame_obj_mask = merge_object_masks(gsam_results['masks'])
                frame_bg_mask = erode_mask(~frame_obj_mask, 15, 15)
            else:
                frame_bg_mask = np.ones(video.image_size, dtype=bool)
            # Remove the sky region (top ``sky_height`` rows).
            frame_bg_mask[:self.config.sky_height, :] = 0
            # TODO: improve the logic based on looping through each frame and
            # only calculate on near depth points
            frame_attrs.background_flow = self.aggreage_flows(
                dense_flows[:, frame_bg_mask], dense_tracks_vis[:, frame_bg_mask]
            )

            if has_objects:
                obj_counter = 0
                # Decode 2D object attributes.
                for obj_idx, box in enumerate(gsam_results['boxes']):
                    orig_obj_mask = gsam_results['masks'][obj_idx]
                    obj_mask = adaptive_erode_mask(orig_obj_mask, 3, 1, 3, 1)

                    # Index the visibility array once for both size checks.
                    # Axis 0 is the leading axis of the visibility array
                    # (presumably frames -- TODO confirm); axis 1 is the
                    # number of entries selected by the mask.
                    obj_vis = dense_tracks_vis[:, obj_mask]
                    if obj_vis.shape[0] < 10 or obj_vis.shape[1] < 1:
                        continue

                    x1, y1, x2, y2 = map(int, box)
                    local_object = LocalObject()
                    local_object.idx = obj_counter
                    local_object.box = [x1, y1, x2, y2]
                    local_object.area = (x2 - x1) * (y2 - y1)
                    local_object.logit = float(gsam_results['logits'][obj_idx])
                    local_object.label = str(gsam_results['labels'][obj_idx]).strip('.').title()
                    # Fall back to the un-eroded mask when erosion leaves too little.
                    local_object.mask = obj_mask if np.sum(obj_mask) > 10 else orig_obj_mask
                    # Fresh slices are passed here so that whatever
                    # ``aggreage_flows`` does with its arguments cannot affect
                    # the arrays stored on the object below.
                    local_object.flow = self.aggreage_flows(
                        dense_flows[:, obj_mask], dense_tracks_vis[:, obj_mask]
                    )
                    local_object.tracks = (
                        dense_flows[:, obj_mask] + video.meshgrid[None, obj_mask]
                    ).astype(int)
                    local_object.tracks_vis = obj_vis
                    frame_attrs.objects[obj_counter] = local_object
                    obj_counter += 1

            video.frames[frame_idx] = frame_attrs

        video.moving_cam = self.decide_moving_cam(video)

        # Transforms relative to the first frame. The target-frame branch does
        # not multiply by its own pose, i.e. that pose is treated as identity.
        inv_first_pose = np.linalg.inv(tgt_poses[0])
        for frame_idx, pose in enumerate(tgt_poses):
            if frame_idx == 0:
                trf = np.eye(4)
            elif frame_idx == tgt_frame_idx:
                trf = inv_first_pose
            else:
                trf = inv_first_pose @ pose
            video.poses[frame_idx] = trf

        # Transforms relative to the previous frame, with the same identity
        # shortcut for the target frame on either side of the product.
        for frame_idx, pose in enumerate(tgt_poses):
            if frame_idx == 0:
                trf = np.eye(4)
            elif frame_idx == tgt_frame_idx:
                trf = np.linalg.inv(tgt_poses[frame_idx - 1])
            elif frame_idx - tgt_frame_idx == 1:
                trf = pose
            else:
                trf = np.linalg.inv(tgt_poses[frame_idx - 1]) @ pose
            video.rel_poses[frame_idx] = trf

        self.video = video

    def run_pipeline(self, scene_name: str) -> None:
        """Run the full pipeline for one scene.

        Loads the scene into ``self.video``, then calls
        ``adjacent_frame_registration`` and ``compute_attributes`` in order.

        Args:
            scene_name: Scene identifier forwarded to ``load_video_information``.
        """
        self.load_video_information(scene_name)
        self.adjacent_frame_registration()
        self.compute_attributes()



if __name__ == '__main__':
    # Placeholder locations; replace with real paths before running. They are
    # wrapped in Path to match the constructor's annotated parameter types.
    dataset_root = Path("PATH_TO_DATA")
    pseudo_root = Path("PATH_TO_DATA")
    labeler = KITTIPseudoLabeler(dataset_root, pseudo_root)

    # Only the scene whose name parses to 15 is processed; stop right after it.
    for scene_name in tqdm(labeler):
        if int(scene_name) != 15:
            continue
        labeler.run_pipeline(scene_name)
        break
