import torch
import numpy as np
import cv2
import shutil
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from scipy.optimize import linear_sum_assignment
from plot.datasets.utils import *
from plot.utils.io import load_frames, read_image
from plot.utils.misc import find_target_frame_idx, limit_frames
from plot.utils.geometry import meshgrid2d, depthmap_to_pts3d, rigid_points_registration_numpy
from plot.utils.processing import *
from plot.utils.viz import visualize_pcd, visualize_bev_and_box3d
from plot.datasets.base import PseudoLabeler



class BaseConfigVideo:
    """Settings consumed by the video pseudo-labeler through ``self.config``.

    Every entry is an annotated class-level declaration. ``dataset_root`` and
    ``pseudo_root`` have no default value, so they must be assigned by whoever
    builds the configuration before they are read.
    """
    dataset_root            : Path 
    pseudo_root             : Path
    frame_limit             : int = 10      # presumably frames kept on each side of the target: start=10, target=20, end=30 -> 21 frames in total
    target_frame_idx        : int = 20

    min_match_count         : int = 15
    distant_obj_threshold   : int = 50      # (in meters) points/objects deeper than this are dropped
    logit_threshold         : float = 0.35  # detections at or below this logit do not start a new global track (target frame excepted)
    match_iou_threshold     : float = 0.3   # minimum IoU to associate a tracked box with a detected box in the next frame
    merge_iou_threshold     : float = 0.1   # minimum mask IoU to merge two tracks in the global object dict
    valid_obj_threshold     : int = 5       # minimum number of frames in a global track before it may be propagated to the target frame

    # RANSAC registration parameters
    ransac_iters            : int = 600
    ransac_inlier_threshold : float = 0.1   # (in meters; for scale, 0.01 m = 1 cm)
    ransac_stop_threshold   : float = 0.1   # (in meters)

    frame_gap               : int = 1       # offset from the target frame to the adjacent frame used for direction calculation
    sky_height              : int = 100     # ignore pixels beyond this threshold (remove sky regions)

    # flow thresholds
    object_vis_threshold    : float = 0.6   # visibility threshold of a tracked object (used in deciding start and end)
    flow_vis_threshold      : float = 0.5   # visibility threshold of each tracked point
    scene_flow_threshold    : int = 10      # (in pixels) threshold to decide a stationary or moving cam
    object_flow_threshold   : int = 10      # (in pixels) threshold to decide an object as moving or not



class ObjectClass:
    """Mutable record describing one object inside one frame.

    Names that are annotated without a value (``idx``, ``logit``, ``mask`` ...)
    are declarations only and must be assigned before they are read. Names
    with a value act as defaults; the mutable ones (arrays, lists, dicts) are
    copied per instance in ``__init__`` so that in-place edits on one object
    never leak into another.
    """
    idx             : int
    is_moving       : bool = False                          # whether the object is moving or not
    pseudo_lidar    : np.ndarray = np.zeros((100, 3))       # completed pseudo lidar in shape (N, 3)
    cleaned_lidar   : np.ndarray = np.zeros((100, 3))       # processed pseudo lidar in shape (N, 3)
    label           : str = "Name"                          # object label name
    box             : List[int] = [0., 0., 0., 0.]          # (x1, y1, x2, y2)
    logit           : float
    center          : List[float]
    size            : List[float] = [0., 0., 0.]            # object dimension in [L, H, W]
    yaw             : float
    alpha           : float
    mask            : np.ndarray                            # object mask in shape [image_height, image_width]
    flow            : float                                 # mean flow for the object
    flows           : np.ndarray                            # object flows for all frames (num_frames, num_pixels, 2)
    tracks          : np.ndarray                            # object tracks in other frames (N, M, 2)
    tracks_vis      : np.ndarray                            # tracks visibility in other frames (N, M)
    start_frame_idx : int = None
    end_frame_idx   : int = None
    visible_frames  : list = []
    ignore          : bool = False
    occluded        : bool = False
    trfs            : Dict[int, np.ndarray] = {}            # object transformations {adjacent_frame_idx: array[4,4]}
    occluded_frames : list = []

    def __init__(self) -> None:
        # The class-level arrays/lists/dicts above are single shared objects.
        # Without per-instance copies, something like
        # ``obj.occluded_frames.append(i)`` would show up on every object.
        self.pseudo_lidar = self.pseudo_lidar.copy()
        self.cleaned_lidar = self.cleaned_lidar.copy()
        self.box = list(self.box)
        self.size = list(self.size)
        self.visible_frames = list(self.visible_frames)
        self.trfs = dict(self.trfs)
        self.occluded_frames = list(self.occluded_frames)


class FrameClass:
    """Mutable record holding the data attached to a single video frame.

    Names annotated without a value must be assigned before use. The dict
    defaults are replaced by private per-instance copies in ``__init__`` so
    that writing into one frame's ``objects`` (or any other dict) does not
    modify every other frame.
    """
    idx             : int
    name            : str = "Name"                          # frame name
    image           : np.ndarray                            # rgb image
    depth           : np.ndarray                            # depth image
    background_mask : np.ndarray
    background_tracks: np.ndarray                           # shape (num_frames, num_points, 3)
    background_tracks_vis: np.ndarray                       # shape (num_frames, num_points)
    points3d        : np.ndarray                            # shape (H, W, 3)

    objects         : Dict[int, ObjectClass] = {}           # objects of this frame, keyed by local object index
    frame_trfs      : Dict[int, np.ndarray] = {}
    prev_matched_indices: dict = {}                         # previous matched indices for the object in the adjacent frame
    matched_indices : dict = {}                             # mapping between target frame object indices and current frame object indices

    def __init__(self) -> None:
        # Class-level dicts are shared by all instances; give each frame its own
        # so in-place writes such as ``frame.objects[k] = obj`` stay local.
        self.objects = dict(self.objects)
        self.frame_trfs = dict(self.frame_trfs)
        self.prev_matched_indices = dict(self.prev_matched_indices)
        self.matched_indices = dict(self.matched_indices)


class VideoClass:
    """Mutable record holding scene-level data for one video.

    Names annotated without a value must be assigned before use. Mutable
    defaults (``intrinsic``, ``frames``, ``global_object_dict``) are copied
    per instance in ``__init__`` so that two videos never share the same
    frame dict or intrinsic matrix.
    """
    name            : str = "Name"                          # scene name
    intrinsic       : np.ndarray = np.eye(3)
    target_frame_idx: int = 20
    num_frames      : int = 41                              # num of frames exist in the scene
    fps             : int = 10                              # fps used to generate the video
    frames          : Dict[int, FrameClass] = {}            # frames keyed by integer frame index
    moving_cam      : bool = False
    image_size      : list                                  # frame size in (H, W)
    bg_flows        : np.ndarray
    scene_flow      : float                                 # average flow of the video
    global_object_dict: dict = {}                           # {global_id: {frame_idx: local_object_idx}}

    def __init__(self) -> None:
        # Class-level containers are shared by all instances; give each video
        # its own so in-place writes such as ``video.frames[i] = frame`` stay local.
        self.intrinsic = self.intrinsic.copy()
        self.frames = dict(self.frames)
        self.global_object_dict = dict(self.global_object_dict)


class PseudoLabelerGlobal(PseudoLabeler):
    def __init__(self, config: BaseConfigVideo) -> None:
        """Initialise the base labeler, then derive and create the "nxn" working directories.

        ``track_dir`` is only recorded; the pseudo-label and image-result
        directories are created on disk if they do not exist yet.
        """
        super().__init__(config)
        root = self.pseudo_root
        self.track_dir = root / "alltracker_nxn"
        self.pseudo_label_dir = root / "pseudo_labels_nxn"
        self.pseudo_results_dir = root / "image_results_nxn"
        for output_dir in (self.pseudo_label_dir, self.pseudo_results_dir):
            output_dir.mkdir(parents=True, exist_ok=True)

    def match_forward(self, video: VideoClass):
        """Associate the objects of every frame with the objects of the next frame.

        For each consecutive pair (src, src + 1), every source object is
        projected into the next frame as a box built from its tracked points,
        and these boxes are assigned one-to-one to the detected boxes of the
        next frame by maximising total IoU. Pairs whose IoU reaches
        ``config.match_iou_threshold`` are written to
        ``video.frames[src + 1].prev_matched_indices`` as
        ``{source_object_key: next_frame_object_key}``.

        The stored values are keys of the respective ``objects`` dicts, not
        positions in the assignment matrix, because the consumer looks them up
        as dict keys. When the keys are 0..n-1 in insertion order the two
        coincide.
        """
        H, W = video.image_size
        vis_threshold = self.config.flow_vis_threshold
        iou_threshold = self.config.match_iou_threshold

        for src_frame_idx in range(video.num_frames - 1):
            tgt_frame_idx = src_frame_idx + 1
            src_objects = video.frames[src_frame_idx].objects
            tgt_frame = video.frames[tgt_frame_idx]

            # Always start from a fresh dict so a frame pair with nothing to
            # match never keeps a stale or shared mapping.
            matched_indices = {}
            tgt_frame.prev_matched_indices = matched_indices

            # forward registration (from the first frame to the second frame)
            tracked_keys, tracked_boxes = [], []
            for src_key, src_obj in src_objects.items():
                tracked_box = tracked_to_box(src_obj.tracks[tgt_frame_idx],
                                             src_obj.tracks_vis[tgt_frame_idx] > vis_threshold, W, H)
                # NOTE(review): assumes tracked_to_box may return None when no point is
                # visible, like tracked_to_mask whose callers check for None - confirm.
                if tracked_box is None:
                    continue
                tracked_keys.append(src_key)
                tracked_boxes.append(tracked_box)

            original_keys = list(tgt_frame.objects.keys())
            original_boxes = [tgt_frame.objects[key].box for key in original_keys]
            if not original_boxes or not tracked_boxes:
                continue

            iou_mat = get_iou_matrix(np.stack(tracked_boxes, axis=0), np.stack(original_boxes, axis=0))
            tracked_indices, original_indices = linear_sum_assignment(iou_mat, maximize=True)
            for t_idx, o_idx in zip(tracked_indices, original_indices):
                if iou_mat[t_idx, o_idx] >= iou_threshold:
                    # translate matrix positions back to object keys
                    matched_indices[tracked_keys[t_idx]] = original_keys[o_idx]


    def get_global_object_matches(self, video: VideoClass):
        """Chain per-frame objects into video-wide tracks and store them on ``video``.

        The result is written to ``video.global_object_dict`` with the layout
        ``{global_id: {frame_idx: local_object_idx}}`` and contiguous ids
        starting at 0. Nothing is returned.
        """
        H, W = video.image_size

        # Step 1: every object of frame 0 opens its own track
        tracks_by_gid = {gid: {0: local_idx}
                         for gid, local_idx in enumerate(video.frames[0].objects.keys())}
        next_gid = len(tracks_by_gid)

        # Step 2-3: walk forward through the video
        for frame_idx in range(1, video.num_frames):
            frame = video.frames[frame_idx]
            prev_to_cur = frame.prev_matched_indices  # {prev_local_idx: cur_local_idx}
            claimed = set()
            prev_frame_idx = frame_idx - 1

            # Step 2: extend tracks that were alive in the previous frame
            for track in tracks_by_gid.values():
                if prev_frame_idx not in track:
                    continue
                prev_local_idx = track[prev_frame_idx]
                if prev_local_idx not in prev_to_cur:
                    continue
                cur_local_idx = prev_to_cur[prev_local_idx]
                track[frame_idx] = cur_local_idx
                claimed.add(cur_local_idx)

            # Step 3: unclaimed objects open new tracks (always on the target
            # frame, otherwise only when confident enough)
            for cur_local_idx in frame.objects.keys():
                if cur_local_idx in claimed:
                    continue
                if (frame_idx == video.target_frame_idx) or \
                        (frame.objects[cur_local_idx].logit > self.config.logit_threshold):
                    tracks_by_gid[next_gid] = {frame_idx: cur_local_idx}
                    next_gid += 1

        # Step 4: reconcile fragmented tracks
        merge_iou_th = self.config.merge_iou_threshold

        def project_mask(src_f, src_idx, dst_f):
            """Rasterise the tracked points of an object of ``src_f`` in frame ``dst_f``."""
            ref_obj = video.frames[src_f].objects[src_idx]
            coords = ref_obj.tracks[dst_f]
            visible = ref_obj.tracks_vis[dst_f] > self.config.flow_vis_threshold
            return tracked_to_mask(coords, visible, W, H)

        def find_mergeable_pair():
            """Return the first (earlier_gid, later_gid) pair that should be merged, or None."""
            gids = sorted(tracks_by_gid.keys())
            for pos, gid_a in enumerate(gids):
                track_a = tracks_by_gid[gid_a]
                last_f = max(track_a.keys())
                last_idx = track_a[last_f]
                label_a = video.frames[last_f].objects[last_idx].label

                for gid_b in gids[pos + 1:]:
                    track_b = tracks_by_gid[gid_b]
                    first_f = min(track_b.keys())
                    first_idx = track_b[first_f]
                    label_b = video.frames[first_f].objects[first_idx].label

                    # B must not start before A ends (equal frames pass this
                    # check), and both ends must carry the same label
                    if last_f > first_f or label_a != label_b:
                        continue

                    # carry A's last mask forward to B's first frame
                    tracked_mask = project_mask(last_f, last_idx, first_f)
                    if tracked_mask is None or tracked_mask.sum() == 0:
                        continue

                    mask_b = video.frames[first_f].objects[first_idx].mask
                    if calculate_iou(tracked_mask, mask_b) >= merge_iou_th:
                        return gid_a, gid_b
            return None

        # merge one pair at a time and rescan from scratch after every merge
        while True:
            pair = find_mergeable_pair()
            if pair is None:
                break
            keep_gid, drop_gid = pair
            tracks_by_gid[keep_gid].update(tracks_by_gid.pop(drop_gid))

        # Step 5: reindex global IDs to be contiguous (0, 1, 2, ...)
        video.global_object_dict = {
            new_id: tracks_by_gid[old_id]
            for new_id, old_id in enumerate(sorted(tracks_by_gid.keys()))
        }

        
    def add_missing_objects(self, video: VideoClass):
        """Add missing objects to the target frame by propagating from adjacent frames.

        A global track is propagated when it has no entry for the target frame
        and spans at least ``config.valid_obj_threshold`` frames. The nearest
        frame of the track is used as reference (the earlier one on a tie);
        its tracked points are turned into a mask and a box in the target
        frame. The new object is stored under a fresh key in the target
        frame's ``objects`` and registered in the track. Tracks whose points
        yield no mask (or no box) in the target frame are skipped.
        """
        tgt_frame_idx = video.target_frame_idx
        H, W = video.image_size
        tgt_objects = video.frames[tgt_frame_idx].objects

        next_object_id = max(tgt_objects.keys(), default=-1) + 1
        for track in video.global_object_dict.values():
            if tgt_frame_idx in track or len(track) < self.config.valid_obj_threshold:
                continue

            candidate_frames = sorted(track.keys())
            prev_frame = max((f for f in candidate_frames if f < tgt_frame_idx), default=None)
            next_frame = min((f for f in candidate_frames if f > tgt_frame_idx), default=None)

            # pick the temporally closest frame, preferring the past on a tie
            if prev_frame is not None and \
                    (next_frame is None or tgt_frame_idx - prev_frame <= next_frame - tgt_frame_idx):
                ref_frame = prev_frame
            elif next_frame is not None:
                ref_frame = next_frame
            else:
                continue

            # tracking from reference frame to target frame
            obj = video.frames[ref_frame].objects[track[ref_frame]]
            coords = obj.tracks[tgt_frame_idx]
            visible = obj.tracks_vis[tgt_frame_idx] > self.config.flow_vis_threshold
            tracked_mask = tracked_to_mask(coords, visible, W, H)
            tracked_box = tracked_to_box(coords, visible, W, H)
            # NOTE(review): the box is assumed to be None in the same situations
            # as the mask; it is checked as well so no object is stored without a box.
            if tracked_mask is None or tracked_box is None:
                continue

            # Attribute construction
            new_obj = ObjectClass()
            new_obj.idx = next_object_id
            new_obj.is_moving = obj.is_moving
            new_obj.mask = tracked_mask.astype(bool)
            new_obj.label = obj.label
            new_obj.box = tracked_box
            new_obj.logit = getattr(obj, "logit", 1.0)
            new_obj.start_frame_idx = min(candidate_frames)
            new_obj.end_frame_idx = max(candidate_frames)
            new_obj.tracks = obj.tracks.copy()
            new_obj.tracks_vis = obj.tracks_vis.copy()
            # Own containers: ObjectClass declares these as class-level
            # defaults, which would otherwise be shared with other objects.
            new_obj.visible_frames = []
            new_obj.occluded_frames = []
            new_obj.trfs = {}

            # add to target frame and update global dict
            tgt_objects[next_object_id] = new_obj
            track[tgt_frame_idx] = next_object_id
            next_object_id += 1
    

    def adj_frame_object_registration(self, video: VideoClass):
        """Aggregate every target-frame object's points from all frames it appears in.

        For each global track that contains the target frame, the object is
        registered frame by frame from the past towards the target and from
        the future towards the target (via ``pseudo_lidar_completion``). The
        target-frame object then receives:

        * ``visible_frames``  - sorted frame indices of the track,
        * ``occluded_frames`` - sorted frames of this track whose step towards
          the target produced no transform, so later lookups on the target
          object see its own failures,
        * ``trfs``            - ``{frame_idx: 4x4}`` transforms into the target
          frame (identity for the target itself),
        * ``pseudo_lidar`` / ``cleaned_lidar`` - merged and filtered points.

        If no points could be gathered at all the object is flagged ``ignore``.
        """
        tgt_frame_idx = video.target_frame_idx

        def chain_to_target(step_trfs, frame_order):
            """Compose per-step transforms, nearest frame first, into frame -> target transforms."""
            chained, running = {}, np.eye(4)
            for frame_idx in frame_order:
                step = step_trfs[frame_idx]
                # NOTE(review): a missing step is skipped, so frames beyond it are
                # chained without that step's motion.
                if step is None:
                    continue
                running = running @ step
                chained[frame_idx] = running
            return chained

        for obj_dict in video.global_object_dict.values():
            # if that global object does not appear in the target frame, skip it
            if tgt_frame_idx not in obj_dict:
                continue

            tgt_obj = video.frames[tgt_frame_idx].objects[obj_dict[tgt_frame_idx]]
            tgt_obj.visible_frames = sorted(obj_dict.keys())
            past_frames = sorted(f for f in obj_dict if f < tgt_frame_idx)
            next_frames = sorted((f for f in obj_dict if f > tgt_frame_idx), reverse=True)

            pseudo_lidars, trfs, failed_frames = [], {}, []

            # past to target: steps go oldest -> target, so chain from the newest past frame back
            if past_frames:
                past_pts3d, past_trfs = self.pseudo_lidar_completion(past_frames + [tgt_frame_idx], video, obj_dict)
                failed_frames.extend(f for f, trf in past_trfs.items() if trf is None)
                if past_pts3d is not None and len(past_pts3d) > 0:
                    pseudo_lidars.append(past_pts3d)
                    trfs.update(chain_to_target(past_trfs, sorted(past_trfs, reverse=True)))

            # add target points (the mask can be empty for propagated objects)
            tgt_mask = tgt_obj.mask
            if tgt_mask.sum() > 0:
                pseudo_lidars.append(video.frames[tgt_frame_idx].points3d[tgt_mask, :])
                trfs[tgt_frame_idx] = np.eye(4)

            # future to target: steps go latest -> target, so chain from the earliest future frame on
            if next_frames:
                future_pts3d, future_trfs = self.pseudo_lidar_completion(next_frames + [tgt_frame_idx], video, obj_dict)
                failed_frames.extend(f for f, trf in future_trfs.items() if trf is None)
                if future_pts3d is not None and len(future_pts3d) > 0:
                    pseudo_lidars.append(future_pts3d)
                    trfs.update(chain_to_target(future_trfs, sorted(future_trfs)))

            # Give the target object its own list of frames without a usable
            # transform instead of relying on what other objects recorded.
            tgt_obj.occluded_frames = sorted(failed_frames)

            if not pseudo_lidars:
                tgt_obj.ignore = True
                continue
            pseudo_lidar, cleaned_lidar = merge_and_filter_pseudo_lidars(pseudo_lidars, tgt_obj.label)
            tgt_obj.trfs = trfs
            tgt_obj.pseudo_lidar = pseudo_lidar
            tgt_obj.cleaned_lidar = cleaned_lidar
            # visualize_pcd(cleaned_lidar)

    
    def pseudo_lidar_completion(self, frame_lists: List[int], video: VideoClass, obj_dict: dict = None):
        """Register an object step by step along ``frame_lists`` and accumulate its points.

        Args:
            frame_lists: frame indices in processing order; each frame is
                registered to the one that follows it, so the last entry is
                the frame the accumulated points end up in.
            video: video whose frames provide ``points3d`` and the objects.
            obj_dict: ``{frame_idx: local_object_idx}`` for the object. Despite
                the ``None`` default it is indexed unconditionally, so it must
                be provided.

        Returns:
            ``(pseudo_lidar, trfs)``. ``pseudo_lidar`` is an (N, 3) array or
            ``None`` when no step was accepted. ``trfs`` maps every frame
            except the last to the transform returned by ``ransac_registration``
            for the step to its successor, or ``None`` when the step was
            skipped for lack of visible points.
        """
        pseudo_lidar, trfs = None, {frame_idx: None for frame_idx in frame_lists[:-1]}
        for src_frame_idx, tgt_frame_idx in zip(frame_lists[:-1], frame_lists[1:]):
            src_frame = video.frames[src_frame_idx]
            src_obj = src_frame.objects[obj_dict[src_frame_idx]]

            src_pts2d = src_obj.tracks[src_frame_idx]
            tgt_pts2d = src_obj.tracks[tgt_frame_idx]

            # masking based on tracks visibility
            vis_mask = self.filter_visibility_and_confidence(tgt_pts2d, src_pts2d,
                                                             src_obj.tracks_vis[tgt_frame_idx], *video.image_size)
            src_pts3d = src_frame.points3d[src_pts2d[vis_mask, 1], src_pts2d[vis_mask, 0], :]
            tgt_pts3d = video.frames[tgt_frame_idx].points3d[tgt_pts2d[vis_mask, 1], tgt_pts2d[vis_mask, 0], :]

            # object is occluded in this frame
            if (len(tgt_pts3d) < 10) or (len(src_pts3d) < 10):
                # NOTE(review): ObjectClass declares occluded_frames as a class-level
                # list; unless this instance was given its own list, the append
                # mutates state that other objects also see.
                src_obj.occluded_frames.append(src_frame_idx)
                continue

            # outliers removal (only used for faster registration)
            _, src_noise_mask = remove_outliers_by_percentile(src_pts3d, 10, 85)
            _, tgt_noise_mask = remove_outliers_by_percentile(tgt_pts3d, 10, 85)
            noise_mask = tgt_noise_mask & src_noise_mask
            # Only use the filtered correspondences when enough of them survive.
            # Count the kept entries, not the mask length (which is the total
            # number of points and says nothing about how many remain).
            if np.count_nonzero(noise_mask) > 20:
                tgt_pts3d, src_pts3d = tgt_pts3d[noise_mask], src_pts3d[noise_mask]

            trf, _, cost = ransac_registration(src_pts3d, tgt_pts3d, self.config.ransac_iters,
                                               self.config.ransac_inlier_threshold, self.config.ransac_stop_threshold)
            trfs[src_frame_idx] = trf
            # only aggregate when the registration cost is low enough
            # NOTE(review): when a step is skipped or rejected here, the points
            # accumulated so far are not moved across that step.
            if cost < 0.5:
                full_src_pts3d = src_frame.points3d[src_obj.mask, :]
                src_pts3d_in_tgt = ((trf[:3, :3] @ full_src_pts3d.T) + trf[:3, -1:]).T

                if pseudo_lidar is None:
                    pseudo_lidar = src_pts3d_in_tgt
                else:
                    # send previous pseudo lidar to the current frame
                    pseudo_lidar = ((trf[:3, :3] @ pseudo_lidar.T) + trf[:3, -1:]).T
                    pseudo_lidar = np.concatenate([pseudo_lidar, src_pts3d_in_tgt], axis=0)
        return pseudo_lidar, trfs


    def frame_to_frame_registration(self, video: VideoClass, object: ObjectClass, target_frame_idx: int):
        """Pick a usable neighbouring frame for ``object`` and estimate the frame-to-frame transform.

        Returns:
            ``(adjacent_frame_idx, transform)``. ``transform`` is ``None`` when
            the object has no visible frames or holds no transform for the
            chosen frame, the 4x4 identity when the camera is not moving, and
            otherwise a 4x4 matrix built from the scaled rotation and
            translation that register the neighbour's background points onto
            the target's.
        """
        neighbor_idx = target_frame_idx + self.config.frame_gap
        if len(object.visible_frames) < 1:
            return neighbor_idx, None

        def unusable(frame_idx):
            # the target itself and frames flagged as occluded can never serve as neighbour
            return (frame_idx == target_frame_idx) or (frame_idx in object.occluded_frames)

        # backward search: step back while past the last visible frame or unusable
        while (neighbor_idx > object.visible_frames[-1]) or unusable(neighbor_idx):
            neighbor_idx -= 2

        # forward search: recover from a negative index or an unusable frame
        while (neighbor_idx < 0) or unusable(neighbor_idx):
            neighbor_idx += 1

        if neighbor_idx not in object.trfs:
            return neighbor_idx, None

        if not video.moving_cam:
            return neighbor_idx, np.eye(4)

        # background correspondences tracked from the target frame
        tgt_frame = video.frames[target_frame_idx]
        neighbor_pts2d = tgt_frame.background_tracks[neighbor_idx]
        target_pts2d = tgt_frame.background_tracks[target_frame_idx]

        keep = self.filter_visibility_and_confidence(target_pts2d, neighbor_pts2d,
                                                     tgt_frame.background_tracks_vis[neighbor_idx], *video.image_size)
        target_pts3d = tgt_frame.points3d[target_pts2d[keep, 1], target_pts2d[keep, 0], :]
        neighbor_pts3d = video.frames[neighbor_idx].points3d[neighbor_pts2d[keep, 1], neighbor_pts2d[keep, 0], :]

        # drop correspondences that are too far away in either frame
        max_depth = self.config.distant_obj_threshold
        near = (target_pts3d[:, 2] < max_depth) & (neighbor_pts3d[:, 2] < max_depth)
        R, t, s = rigid_points_registration_numpy(neighbor_pts3d[near], target_pts3d[near])

        transform = np.eye(4)
        transform[:3, :3] = s * R
        transform[:3, -1] = t
        return neighbor_idx, transform