import torch
import numpy as np
import cv2
import shutil
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from scipy.optimize import linear_sum_assignment
from plot.datasets.utils import *
from plot.utils.io import load_frames, read_image
from plot.utils.misc import find_target_frame_idx, limit_frames
from plot.utils.geometry import meshgrid2d, depthmap_to_pts3d, rigid_points_registration_numpy
from plot.utils.processing import *
from plot.utils.viz import visualize_pcd, visualize_bev_and_box3d



# Object category names. Not referenced in the code shown here; the name suggests these are
# the classes that have size priors — NOTE(review): confirm against the callers.
PRIOR_CLASSES = ['Car', 'Pedestrian', 'Cyclist']

def center_change_detection(center1, center2, threshold=2):
    """Tell whether two centers differ by at least ``threshold`` along index 1.

    Only the second coordinate is compared. ``compute_dir_from_object_motion``
    passes (x, z) pairs, so there this measures a change in depth (z).

    Returns:
        False when ``|center2[1] - center1[1]| < threshold``, True otherwise.
    """
    displacement = abs(center2[1] - center1[1])
    return not (displacement < threshold)


def compute_dir_from_object_motion(cur_pts: np.ndarray, adj_pts: np.ndarray, 
        trf_mat: np.ndarray, cur_index: int, adj_index: int):
    """Estimate an object's yaw from its displacement between two frames.

    Args:
        cur_pts: object points in the current frame, shape (N, 3).
        adj_pts: object points in the adjacent frame, shape (M, 3).
        trf_mat: optional 4x4 rigid transform applied to ``adj_pts`` first
            (to bring them into the current frame); skipped when None.
        cur_index: frame index of ``cur_pts``.
        adj_index: frame index of ``adj_pts``; decides which position is the
            earlier one in time.

    Returns:
        (yaw, moved): ``yaw`` is the negated result of
        ``calculate_yaw_from_position_changes(earlier, later)``; ``moved`` is
        ``center_change_detection`` on the two reference points with threshold 1.5.
    """
    if trf_mat is not None:
        rotation, translation = trf_mat[:3, :3], trf_mat[:3, -1:]
        adj_pts = (rotation @ adj_pts.T + translation).T

    # Reference point per frame: per-axis minimum over the x and z columns.
    cur_ref = cur_pts[:, [0, 2]].min(axis=0)
    adj_ref = adj_pts[:, [0, 2]].min(axis=0)

    # Order the two positions chronologically before measuring the heading.
    earlier, later = (cur_ref, adj_ref) if adj_index > cur_index else (adj_ref, cur_ref)
    yaw = -calculate_yaw_from_position_changes(earlier, later)
    return yaw, center_change_detection(cur_ref, adj_ref, threshold=1.5)


class BaseConfig:
    """Tunable parameters read by PseudoLabeler.

    Plain class with annotated class-level defaults. ``dataset_root`` and
    ``pseudo_root`` have no default and must be assigned before PseudoLabeler
    is constructed (its ``__init__`` wraps both in ``Path``).
    """
    dataset_root            : Path                  # contains frames/, calib/ and ImageSets/{train,val}.txt
    pseudo_root             : Path                  # contains foundation-model outputs; labels/results are written here
    frame_limit             : int = 10      # in total = 21 frames, start=10, target=20, end=30
    target_frame_idx        : int = 20
    match_iou_threshold     : float = 0.3   # box IoU threshold for associating tracked boxes with detected boxes (hungarian_match_masks)
    distant_obj_threshold   : int = 50      # (in meters) depth cutoff; frame_to_frame_registration drops background points farther than this

    # RANSAC registration parameters (object registration in pseudo_lidar_completion)
    ransac_iters            : int = 600
    ransac_inlier_threshold : float = 0.1   # (in meters; 0.01 m = 1 cm)
    ransac_stop_threshold   : float = 0.1   # (in meters)

    frame_gap               : int = 2       # initial offset to the adjacent frame used for direction calculation
    sky_height              : int = 100     # ignore pixels beyond this threshold (remove sky regions); not used in this class's visible code

    # flow thresholds
    object_vis_threshold    : float = 0.6   # visibility threshold of a tracked object (used in deciding start and end)
    flow_vis_threshold      : float = 0.5   # visibility threshold of each tracked point
    # bg_object_threshold     : float = 10    # (in pixels)
    scene_flow_threshold    : int = 10      # (in pixels) threshold to decide a stationary or moving cam
    object_flow_threshold   : int = 10      # (in pixels) threshold to decide an object as moving or not



class ObjectClass:
    """Mutable record for one detected object in a video frame.

    Fields are annotated class attributes. Those without a default (e.g.
    ``center``, ``yaw``, ``mask``) must be assigned before they are read.
    Mutable defaults (arrays, lists, dicts) are copied per instance in
    ``__init__``. Without that copy every instance would share one object, so
    an in-place update such as ``obj.occluded_frames.append(i)`` would leak
    into every other ObjectClass.
    """
    idx             : int
    is_moving       : bool = False                          # whether the object is moving or not
    pseudo_lidar    : np.ndarray = np.zeros((100, 3))       # completed pseudo lidar in shape (N, 3)
    cleaned_lidar   : np.ndarray = np.zeros((100, 3))       # filtered pseudo lidar (second output of merge_and_filter_pseudo_lidars)
    label           : str = "Name"                          # object label name
    box             : List[float] = [0., 0., 0., 0.]        # (x1, y1, x2, y2)
    logit           : float
    center          : List[float]                           # 3D box center (x, y, z), written to KITTI location
    size            : List[float] = [0., 0., 0.]            # object dimension in [L, H, W]
    yaw             : float                                 # rotation around the y axis
    alpha           : float                                 # observation angle derived from yaw and the 2D box
    mask            : np.ndarray                            # object mask in shape [image_height, image_width]
    flow            : float
    flows           : np.ndarray                            # object flows for all frames (num_frames, num_pixels, 2)
    tracks          : np.ndarray                            # object tracks in other frames (N, M, 2)
    tracks_vis      : np.ndarray                            # tracks visibility in other frames (N, M)
    start_frame_idx : int = None
    end_frame_idx   : int = None
    visible_frames  : list
    ignore          : bool = False
    occluded        : bool = False
    trfs            : Dict[int, np.ndarray] = {}            # {frame_idx: 4x4 transform into the target frame, or None}
    occluded_frames : list = []                             # frames with too few valid correspondences for registration

    # Class-level defaults that are mutable and therefore must not be shared between instances.
    _MUTABLE_DEFAULTS = ("pseudo_lidar", "cleaned_lidar", "box", "size", "trfs", "occluded_frames")

    def __init__(self) -> None:
        for name in self._MUTABLE_DEFAULTS:
            # list, dict and ndarray all provide .copy(); the defaults are flat, so a
            # shallow copy is enough. Reading from type(self) honours subclass overrides.
            setattr(self, name, getattr(type(self), name).copy())


class FrameClass:
    """Mutable record for one video frame: image, depth, 3D points, tracks and objects.

    ``objects`` and ``matched_indices`` are copied per instance in ``__init__``.
    Without that copy every frame would share (and mutate) the same class-level
    dict.
    """
    idx             : int
    name            : str = "Name"                          # frame name
    image           : np.ndarray                            # rgb image
    depth           : np.ndarray                            # depth image
    background_mask : np.ndarray
    background_tracks: np.ndarray                           # shape (num_points, 3)
    background_tracks_vis: np.ndarray                       # shape (num_points)
    points3d        : np.ndarray                            # shape (H, W, 3)
    objects         : Dict[int, "ObjectClass"] = {}         # objects detected in this frame
    matched_indices : dict = {}                             # mapping between target frame object indices and current frame object indices

    # Class-level defaults that are mutable and therefore must not be shared between instances.
    _MUTABLE_DEFAULTS = ("objects", "matched_indices")

    def __init__(self) -> None:
        for name in self._MUTABLE_DEFAULTS:
            # Shallow copy of the (empty) default dict; type(self) honours subclass overrides.
            setattr(self, name, getattr(type(self), name).copy())

class VideoClass:
    """Mutable record for one video scene: camera intrinsics, frames and motion stats.

    ``intrinsic`` and ``frames`` are copied per instance in ``__init__``.
    Without that copy every video would share the same class-level array and
    dict.
    """
    name            : str = "Name"                          # scene name
    intrinsic       : np.ndarray = np.eye(3)                # 3x3 camera intrinsic matrix
    target_frame_idx: int = 20
    num_frames      : int = 41                              # num of frames exist in the scene
    fps             : int = 10                              # fps used to generate the video
    frames          : Dict[int, "FrameClass"] = {}          # frames keyed by frame index
    moving_cam      : bool = False
    image_size      : list                                  # (H, W)
    bg_flows        : np.ndarray
    scene_flow      : float                                 # average flow of the video

    # Class-level defaults that are mutable and therefore must not be shared between instances.
    _MUTABLE_DEFAULTS = ("intrinsic", "frames")

    def __init__(self) -> None:
        for name in self._MUTABLE_DEFAULTS:
            # ndarray and dict both provide .copy(); type(self) honours subclass overrides.
            setattr(self, name, getattr(type(self), name).copy())


class PseudoLabeler:
    """Build 3D pseudo labels for video scenes from foundation-model outputs.

    Inputs are read from ``config.dataset_root`` (``frames/``, ``calib/``,
    ``ImageSets/``) and ``config.pseudo_root`` (``unidepthv1/``,
    ``gsam_frames/``, ``alltracker/``; judging by the names these hold depth,
    segmentation and point-track predictions).

    For the target frame of each video:
    - object tracks are lifted to 3D and registered across frames into a
      denser pseudo-LiDAR;
    - each object's yaw, size and center are estimated;
    - the results are written as KITTI-format labels under
      ``pseudo_root / "pseudo_labels_val"``.

    NOTE(review): ``load_information`` is a stub, and ``yaw_correction`` is not
    defined in this class as shown. Presumably both are supplied by a
    subclass; confirm.
    """

    def __init__(self, config: "BaseConfig") -> None:
        """Store the config, resolve input/output directories and load the scene list.

        Side effects: creates the pseudo-label and image-result directories.
        """
        self.config = config
        self.dataset_root = Path(config.dataset_root)
        self.pseudo_root = Path(config.pseudo_root)

        # original video infos
        self.img_dir = self.dataset_root / "frames"
        self.calib_dir = self.dataset_root / "calib"

        # foundation model outputs
        self.depth_dir = self.pseudo_root / "unidepthv1"
        self.gsam_dir = self.pseudo_root / "gsam_frames"
        self.track_dir = self.pseudo_root / "alltracker"

        # save directories
        self.pseudo_label_dir = self.pseudo_root / "pseudo_labels_val"
        self.pseudo_results_dir = self.pseudo_root / "image_results_val"
        self.pseudo_label_dir.mkdir(parents=True, exist_ok=True)
        self.pseudo_results_dir.mkdir(parents=True, exist_ok=True)

        # Validation split; call self.select_train_scenes() to label the train split instead.
        self.scenes = self.select_test_scenes()

    def pseudo_lidar_completion(self, frame_lists: List[int], video: "VideoClass", object: "ObjectClass",
                                tgt_frame_idx: int = None, tgt_obj_idx: int = None):
        """Aggregate an object's 3D points from several frames into one target frame.

        For every source frame:
        1. The object's tracked 2D points in the source and target frames are
           lifted to 3D with each frame's ``points3d`` map.
        2. A rigid source->target transform is estimated with RANSAC.
        3. If the source frame has a detection matched to ``tgt_obj_idx`` and
           the registration cost is below 0.5, that detection's full masked
           point cloud is transformed into the target frame and accumulated.

        Args:
            frame_lists: frame indices to aggregate from (may include the target).
            video: provides per-frame ``points3d``, ``objects``, ``matched_indices``
                and ``image_size`` (H, W).
            object: target-frame object; ``tracks``/``tracks_vis`` are indexed by frame.
            tgt_frame_idx: frame to register into; defaults to ``video.target_frame_idx``.
            tgt_obj_idx: key of ``object`` in the target frame, used to look up its
                matched detection in each source frame.

        Side effects:
            - Sets ``object.trfs`` to ``{frame_idx: 4x4 transform or None}``.
            - Sets ``object.pseudo_lidar`` and ``object.cleaned_lidar``.
            - Appends to ``object.occluded_frames`` every frame with fewer than
              10 valid correspondences.
        """
        if tgt_frame_idx is None:
            # Without a target the target tracks/points were previously unbound (NameError).
            tgt_frame_idx = video.target_frame_idx
        tgt_frame = video.frames[tgt_frame_idx]
        tgt_pts2d = object.tracks[tgt_frame_idx]

        pseudo_lidars = []
        trfs = {frame_idx: None for frame_idx in frame_lists}
        for src_frame_idx in frame_lists:
            src_frame = video.frames[src_frame_idx]
            src_pts2d = object.tracks[src_frame_idx]

            if src_frame_idx == tgt_frame_idx:
                trfs[src_frame_idx] = np.eye(4)
                pseudo_lidars.append(tgt_frame.points3d[src_pts2d[:, 1], src_pts2d[:, 0], :])
                continue

            # Keep correspondences that are inside both frames and confidently tracked.
            vis_mask = self.filter_visibility_and_confidence(
                tgt_pts2d, src_pts2d, object.tracks_vis[src_frame_idx], *video.image_size)
            tgt_pts3d = tgt_frame.points3d[tgt_pts2d[vis_mask, 1], tgt_pts2d[vis_mask, 0], :]
            src_pts3d = src_frame.points3d[src_pts2d[vis_mask, 1], src_pts2d[vis_mask, 0], :]
            # Too few valid correspondences: treat the object as occluded in this frame.
            if len(tgt_pts3d) < 10 or len(src_pts3d) < 10:
                object.occluded_frames.append(src_frame_idx)
                continue

            # Percentile-based outlier removal, used only to speed up registration.
            # The helper's second output is combined as a boolean mask.
            _, tgt_inlier_mask = remove_outliers_by_percentile(tgt_pts3d, 10, 85)
            _, src_inlier_mask = remove_outliers_by_percentile(src_pts3d, 10, 85)
            inlier_mask = tgt_inlier_mask & src_inlier_mask
            # Count the surviving points: len() of a boolean mask is the total count, so
            # the old guard never stopped filtering down to too few points for RANSAC.
            if np.count_nonzero(inlier_mask) > 20:
                tgt_pts3d, src_pts3d = tgt_pts3d[inlier_mask], src_pts3d[inlier_mask]

            trf, _, cost = ransac_registration(
                src_pts3d, tgt_pts3d, self.config.ransac_iters,
                self.config.ransac_inlier_threshold, self.config.ransac_stop_threshold)

            # Only aggregate when this frame has its own detection of the object
            # and the registration is good enough.
            src_obj_idx = src_frame.matched_indices.get(tgt_obj_idx, None)
            if src_obj_idx is not None and cost < 0.5:
                src_obj_mask = src_frame.objects[src_obj_idx].mask
                src_obj_pts3d = src_frame.points3d[src_obj_mask, :]
                pseudo_lidars.append((trf[:3, :3] @ src_obj_pts3d.T + trf[:3, -1:]).T)
            trfs[src_frame_idx] = trf

        pseudo_lidar, cleaned_lidar = merge_and_filter_pseudo_lidars(pseudo_lidars, object.label)
        object.trfs = trfs
        object.pseudo_lidar = pseudo_lidar
        object.cleaned_lidar = cleaned_lidar

    def target_frame_object_registration(self, video: "VideoClass"):
        """Run ``pseudo_lidar_completion`` for every object of the target frame.

        Each object is registered forward and backward from all of its visible
        frames into the target frame.
        """
        for tgt_obj_idx, tgt_object in video.frames[video.target_frame_idx].objects.items():
            self.pseudo_lidar_completion(tgt_object.visible_frames, video, tgt_object,
                                         video.target_frame_idx, tgt_obj_idx)

    def frame_to_frame_registration(self, video: "VideoClass", object: "ObjectClass", target_frame_idx: int):
        """Estimate background (camera) motion between the target frame and a nearby frame.

        The adjacent frame is chosen as follows:
        - Start at ``target_frame_idx + config.frame_gap``.
        - Step back by 2 while the candidate is past the object's last visible
          frame, equals the target, or is occluded.
        - Then step forward by 1 while it is negative, equals the target, or is
          occluded.

        Both frames' background tracks are then lifted to 3D. Points are kept
        only if they are visible/confident and closer than
        ``config.distant_obj_threshold``. The remaining points are registered
        with RANSAC.

        Returns:
            ``(adjacent_frame_idx, f2f_mat)``. ``f2f_mat`` maps adjacent-frame
            points into the target frame. It is None when the object has no
            registration transform for the chosen frame.
        """
        occluded = set(object.occluded_frames)
        last_visible = object.visible_frames[-1]
        adjacent_frame_idx = target_frame_idx + self.config.frame_gap
        # backward search (fixed step of 2, independent of frame_gap)
        while (adjacent_frame_idx > last_visible or adjacent_frame_idx == target_frame_idx
               or adjacent_frame_idx in occluded):
            adjacent_frame_idx -= 2
        # forward search
        while (adjacent_frame_idx < 0 or adjacent_frame_idx == target_frame_idx
               or adjacent_frame_idx in occluded):
            adjacent_frame_idx += 1

        # Missing key or a None transform both mean the object cannot be compared
        # across these frames; a None transform would otherwise reach np.linalg.inv.
        if object.trfs.get(adjacent_frame_idx) is None:
            return adjacent_frame_idx, None

        tgt_frame = video.frames[target_frame_idx]
        adj_frame = video.frames[adjacent_frame_idx]
        tgt_points2d = tgt_frame.background_tracks
        src_points2d = adj_frame.background_tracks

        vis_mask = self.filter_visibility_and_confidence(
            tgt_points2d, src_points2d, adj_frame.background_tracks_vis, *video.image_size)
        tgt_points3d = tgt_frame.points3d[tgt_points2d[vis_mask, 1], tgt_points2d[vis_mask, 0], :]
        src_points3d = adj_frame.points3d[src_points2d[vis_mask, 1], src_points2d[vis_mask, 0], :]

        # Distant background points have unreliable depth; exclude them.
        max_depth = self.config.distant_obj_threshold
        depth_mask = (tgt_points3d[:, 2] < max_depth) & (src_points3d[:, 2] < max_depth)
        f2f_mat, _, _ = ransac_registration(src_points3d[depth_mask], tgt_points3d[depth_mask],
                                            300, 0.2, 0.2, n_samples=10)
        return adjacent_frame_idx, f2f_mat

    def compute_object_direction(self, video: "VideoClass"):
        """Set ``yaw`` and ``alpha`` for every non-ignored target-frame object.

        Moving objects: yaw comes from the object's displacement between the
        target frame and an adjacent frame, after background motion is
        compensated. The object is re-labelled as not moving when no transform
        is available, or when the camera moves but the object's depth barely
        changes.

        Non-moving objects (or moving ones that yield no yaw): yaw comes from
        PCA of the pseudo-LiDAR.

        NOTE(review): relies on ``self.yaw_correction``, which is not defined
        in this class as shown.
        """
        tgt_frame_idx = video.target_frame_idx
        for obj in video.frames[tgt_frame_idx].objects.values():
            if obj.ignore:
                continue
            yaw = None
            if obj.is_moving:
                adj_frame_idx, f2f_mat = self.frame_to_frame_registration(video, obj, tgt_frame_idx)
                if f2f_mat is None:
                    # e.g. the object has only one mask in the whole video
                    obj.is_moving = False
                else:
                    tgt_obj_pts = remove_outliers_by_percentile(obj.cleaned_lidar, 10, 80)[0]
                    # Move the target-frame object points back to where they were in the adjacent frame.
                    adj_obj_trf = np.linalg.inv(obj.trfs[adj_frame_idx])
                    adj_obj_pts = (adj_obj_trf[:3, :3] @ tgt_obj_pts.T + adj_obj_trf[:3, -1:]).T
                    yaw, center_change = compute_dir_from_object_motion(
                        tgt_obj_pts, adj_obj_pts, f2f_mat, tgt_frame_idx, adj_frame_idx)
                    yaw = self.yaw_correction(yaw, 15)
                    # With a moving camera, an unchanged depth means the object is stationary.
                    if video.moving_cam and not center_change:
                        obj.is_moving = False
            if not obj.is_moving or yaw is None:
                yaw = self.yaw_correction(-compute_dir_with_PCA(obj.pseudo_lidar), 25)
            obj.yaw = yaw
            obj.alpha = self.calculate_alpha(yaw, obj.box, video.intrinsic)

    def compute_object_size(self, video: "VideoClass", size_priors: Dict[str, List[float]] = None):
        """Estimate ``size`` ([L, H, W]) and ``center`` of each non-ignored target-frame object.

        Objects whose mask covers <= 650 pixels are marked ``ignore``.

        Otherwise:
        1. The cleaned pseudo-LiDAR is rotated by ``-yaw`` so it is axis-aligned.
        2. Its extent gives (L, H, W) and an axis-aligned center.
        3. If a prior (L, H, W) exists for the object's label, the measured
           size is adjusted:
           - Implausible sizes (any side below 0.5 for Pedestrian, 1.0
             otherwise) are replaced by the prior.
           - Heights off by >= 0.2 from the prior are also replaced by the
             prior.
           - Otherwise the measured H is kept and L, W are the prior scaled by
             ``H / H_prior``.
        4. The center is recomputed from the final box and rotated back by
           ``yaw``.

        Requires ``obj.yaw``, so ``compute_object_direction`` must run first.

        Args:
            size_priors: ``{label: (L, H, W)}``. Defaults to ``self.size_priors``
                when defined; without a prior the measured size is kept.
                NOTE(review): the original referenced undefined
                ``L_prior/H_prior/W_prior`` names; confirm the real prior source.
        """
        if size_priors is None:
            size_priors = getattr(self, "size_priors", None) or {}
        for obj in video.frames[video.target_frame_idx].objects.values():
            if obj.ignore:
                # compute_object_direction skips ignored objects, so they have no yaw.
                continue
            if np.sum(obj.mask) <= 650:
                obj.ignore = True
                continue

            # make the points axis-aligned
            axis_aligned_lidar = (rotate_y(-obj.yaw) @ obj.cleaned_lidar.T).T
            (L, H, W), (x3d, y3d, z3d) = compute_size_from_axis_aligned_points(axis_aligned_lidar)

            prior = size_priors.get(obj.label)
            if prior is not None:
                L_prior, H_prior, W_prior = prior
                reasonable_size = 0.5 if obj.label == 'Pedestrian' else 1.0
                if L < reasonable_size or H < reasonable_size or W < reasonable_size:
                    L, H, W = L_prior, H_prior, W_prior
                elif abs(H - H_prior) < 0.2:
                    scale = H / H_prior
                    L, W = L_prior * scale, W_prior * scale
                else:
                    L, H, W = L_prior, H_prior, W_prior

            # recompute the center from the final size, then rotate back by yaw
            axis_aligned_box = create_ego_box3d(x3d, y3d, z3d, L, H, W, yaw=0)
            ego_aligned_box = (rotate_y(obj.yaw) @ axis_aligned_box.T).T
            obj.size = [float(L), float(H), float(W)]
            obj.center = [float(v) for v in ego_aligned_box.mean(axis=0)]

    def hungarian_match_masks(self, video: "VideoClass"):
        """Associate target-frame objects with each frame's detections by box IoU.

        For every frame:
        1. Each target-frame object's tracked points (visibility >
           ``flow_vis_threshold``) are turned into a box with ``tracked_to_box``.
        2. The IoU matrix against the frame's detected boxes is solved as a
           linear-sum assignment that maximizes IoU.
        3. Assigned pairs with IoU > ``match_iou_threshold`` are kept. In the
           target frame itself every assigned pair is kept.

        The result is stored as ``frame.matched_indices = {target object key:
        frame object key}``. Dict keys, not list positions, are stored so
        ``frame.objects[...]`` lookups stay correct for non-contiguous keys.
        """
        target_frame = video.frames[video.target_frame_idx]
        H, W, _ = target_frame.image.shape
        target_keys = list(target_frame.objects.keys())
        target_objects = list(target_frame.objects.values())
        for frame_idx in range(video.num_frames):
            frame = video.frames[frame_idx]
            tracked_boxes = [
                tracked_to_box(obj.tracks[frame_idx],
                               obj.tracks_vis[frame_idx] > self.config.flow_vis_threshold, W, H)
                for obj in target_objects
            ]
            frame_keys = list(frame.objects.keys())
            original_boxes = [obj.box for obj in frame.objects.values()]
            if not original_boxes or not tracked_boxes:
                continue
            iou_mat = get_iou_matrix(np.stack(tracked_boxes, axis=0), np.stack(original_boxes, axis=0))  # (tracked, orig)
            tracked_indices, original_indices = linear_sum_assignment(iou_mat, maximize=True)
            is_target = frame_idx == video.target_frame_idx
            matched_indices = {}
            for t_idx, o_idx in zip(tracked_indices, original_indices):
                if is_target or iou_mat[t_idx, o_idx] > self.config.match_iou_threshold:
                    matched_indices[target_keys[t_idx]] = frame_keys[o_idx]
            frame.matched_indices = matched_indices

    def remove_duplicate_boxes(self, gsam_results: Dict[str, np.ndarray], iou_threshold: float = 0.9):
        """Drop near-duplicate detections, keeping the first occurrence.

        Boxes are truncated to int for the IoU computation. A detection is
        dropped when its IoU with an already kept box exceeds ``iou_threshold``.
        (The former assignment-based version could pair two duplicates with
        each other and drop both.)

        Returns:
            ``{"masks", "boxes", "logits", "labels"}``, each stacked in the
            original order, or empty lists when there are no boxes.
        """
        num_boxes = len(gsam_results['boxes'])
        if num_boxes < 1:
            return {"masks": [], "boxes": [], "logits": [], "labels": []}
        int_boxes = np.stack([list(map(int, gsam_results['boxes'][i])) for i in range(num_boxes)], axis=0)
        iou_mat = get_iou_matrix(int_boxes, int_boxes)

        keep = []
        for idx in range(num_boxes):
            if all(iou_mat[idx, kept] <= iou_threshold for kept in keep):
                keep.append(idx)

        return {
            key: np.stack([gsam_results[key][i] for i in keep])
            for key in ("masks", "boxes", "logits", "labels")
        }

    def compute_attributes(self, video: "VideoClass", visualize=True):
        """Estimate yaw, then size/center, save KITTI labels and optionally a BEV visualisation."""
        self.compute_object_direction(video)
        self.compute_object_size(video)
        target_objects = video.frames[video.target_frame_idx].objects
        self.save_kitti_labels(target_objects, self.pseudo_label_dir / f"{video.name}.txt")
        if visualize:
            save_path = self.pseudo_results_dir / f"{video.name}.png"
            self.visualize_pseudo_lidar_with_box(video, target_objects, video.target_frame_idx, save_path)

    def save_kitti_labels(self, objects: Dict[int, "ObjectClass"], save_path: Path):
        """Write non-ignored objects as KITTI label lines to ``save_path``.

        Columns are: type, truncated (always 0), occluded (0/1), alpha,
        bbox (x1 y1 x2 y2), dimensions (H W L), location (x y z) and rotation_y.
        KITTI's location y is the box bottom, so ``center[1] + H / 2`` is
        written.
        """
        truncated = 0
        labels = []
        for obj in objects.values():
            if obj.ignore:
                continue
            y = obj.center[1] + obj.size[1] / 2     # bottom center in KITTI labels
            occluded = 1 if obj.occluded else 0
            labels.append(
                f"{obj.label} {truncated} {occluded} {obj.alpha:.2f} "
                f"{obj.box[0]:.2f} {obj.box[1]:.2f} {obj.box[2]:.2f} {obj.box[3]:.2f} "
                f"{obj.size[1]:.2f} {obj.size[2]:.2f} {obj.size[0]:.2f} "
                f"{obj.center[0]:.2f} {y:.2f} {obj.center[2]:.2f} {obj.yaw:.2f}\n"
            )
        with open(save_path, "w") as f:
            f.writelines(labels)

    def visualize_pseudo_lidar_with_box(self, video: "VideoClass", objects: Dict[int, "ObjectClass"], frame_idx: int, save_path: str):
        """Render non-ignored, non-occluded objects' 3D boxes on the frame image and save to ``save_path``."""
        image = video.frames[frame_idx].image
        boxes, movings = [], []
        for obj in objects.values():
            if obj.ignore or obj.occluded:
                continue
            boxes.append([*obj.center, *obj.size, obj.yaw])
            movings.append(obj.is_moving)
        visualize_bev_and_box3d(image, boxes, movings, video.intrinsic, None, save_path)

    def calculate_alpha(self, yaw: float, box: list, intrinsic: np.ndarray) -> float:
        """Convert yaw to the observation angle via ``ry2alpha``.

        Uses the box's horizontal center, fx = ``intrinsic[0, 0]`` and
        cx = ``intrinsic[0, 2]``.
        """
        return ry2alpha(yaw, (box[2] + box[0]) / 2, intrinsic[0, 0], intrinsic[0, 2])

    def create_visibility_mask(self, points2d, points2d_corr, frame_width, frame_height):
        """Return a boolean mask of correspondences whose two (x, y) points both lie inside the frame.

        Negative coordinates are rejected too: numpy would otherwise wrap them
        around when they are later used as pixel indices.
        """
        def _inside(points):
            return ((points[:, 0] >= 0) & (points[:, 0] < frame_width) &
                    (points[:, 1] >= 0) & (points[:, 1] < frame_height))
        return _inside(points2d) & _inside(points2d_corr)

    def filter_visibility_and_confidence(self, points1, points2, tracks_vis, H, W):
        """Mask of correspondences that are inside an (H, W) frame and have visibility > ``flow_vis_threshold``."""
        in_frame = self.create_visibility_mask(points1, points2, W, H)
        return in_frame & (np.asarray(tracks_vis) > self.config.flow_vis_threshold)

    def aggreage_flows(self, flows, flows_vis):
        """Mean absolute flow over entries with visibility > ``flow_vis_threshold``.

        ``flows`` is expected as [N, M, 2] with ``flows_vis`` [N, M]. Returns
        NaN (with a numpy warning) when nothing is visible.
        """
        vis_mask = flows_vis > self.config.flow_vis_threshold
        return np.mean(np.abs(flows[vis_mask]))

    def _read_scene_split(self, split: str):
        """Read scene names, one per line, from ``dataset_root/ImageSets/{split}.txt``."""
        return (self.dataset_root / "ImageSets" / f"{split}.txt").read_text().splitlines()

    def select_train_scenes(self):
        """Scene names of the training split."""
        return self._read_scene_split("train")

    def select_test_scenes(self):
        """Scene names of the validation split."""
        return self._read_scene_split("val")

    def check_already_labeled(self, scene_name):
        """True if a label file for ``scene_name`` already exists in ``pseudo_label_dir``."""
        return (self.pseudo_label_dir / f"{scene_name}.txt").exists()

    def __len__(self):
        return len(self.scenes)

    def load_information(self, scene_name):
        """Hook for loading a scene; a no-op here (presumably overridden by subclasses)."""
        pass

    def __getitem__(self, index) -> str:
        """Return the scene name at ``index`` (not a loaded VideoClass)."""
        return self.scenes[index]
    
