import torch
import numpy as np
import cv2
import shutil
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from scipy.optimize import linear_sum_assignment
from plot.utils.io import load_frames, read_image
from plot.utils.misc import find_target_frame_idx, limit_frames
from plot.utils.geometry import meshgrid2d, depthmap_to_pts3d, rigid_points_registration_numpy
from plot.utils.viz import visualize_pcd, visualize_bev_and_box3d
from typing import List, Dict, Any, Tuple
from plot.utils.processing import create_ego_box3d, rotate_y, tracked_to_box, get_iou_matrix, ry2alpha
from plot.datasets.utils import calculate_yaw_from_position_changes, remove_outliers_by_percentile, ransac_registration, merge_and_filter_pseudo_lidars, ransac_registration_gpu
from plot.datasets.utils import compute_dir_with_PCA, compute_size_from_axis_aligned_points



# Class names kept by ``PseudoLabeler.filter_gsam_detections``; detector labels are
# compared against this list after ``.strip('.').title()`` normalisation.
KITTI_CLASSES = ["Car"]

def center_change_detection(center1, center2, threshold=2):
    """Tell whether the second coordinate (index 1) of two centers differs enough.

    In ``compute_dir_from_object_motion`` the centers are built from point columns
    0 and 2, so index 1 there is the third point coordinate (depth-like axis).

    Args:
        center1: first center; only ``center1[1]`` is read.
        center2: second center; only ``center2[1]`` is read.
        threshold: minimum absolute difference that counts as a change.

    Returns:
        False when ``|center2[1] - center1[1]| < threshold``, True otherwise
        (including when the difference is NaN).
    """
    shift = abs(center2[1] - center1[1])
    # ``not (shift < threshold)`` rather than ``shift >= threshold`` so a NaN shift
    # still counts as a change.
    return not shift < threshold


def compute_dir_from_object_motion(cur_pts: np.ndarray, adj_pts: np.ndarray,
        trf_mat: np.ndarray, cur_index: int, adj_index: int):
    """Estimate a yaw angle from an object's displacement between two frames.

    Args:
        cur_pts: (N, 3) object points in the current frame.
        adj_pts: (M, 3) object points in the adjacent frame.
        trf_mat: optional 4x4 rigid transform applied to ``adj_pts`` before the
            comparison; skipped when None.
        cur_index: frame index of ``cur_pts``.
        adj_index: frame index of ``adj_pts``; the two positions are passed to the
            yaw helper ordered by frame index (lower index first).

    Returns:
        ``(yaw, changed)``: ``yaw`` is the negated result of
        ``calculate_yaw_from_position_changes`` applied to the per-axis minima of
        point columns 0 and 2 (x and z, assuming camera coordinates — confirm);
        ``changed`` is ``center_change_detection`` on those minima with threshold 1.5.
    """
    if trf_mat is not None:
        rotation = trf_mat[:3, :3]
        translation = trf_mat[:3, -1:]
        adj_pts = (rotation @ adj_pts.T + translation).T

    planar_columns = [0, 2]
    cur_center = cur_pts[:, planar_columns].min(axis=0)
    adj_center = adj_pts[:, planar_columns].min(axis=0)

    # Lower frame index goes first.
    if adj_index > cur_index:
        first, second = cur_center, adj_center
    else:
        first, second = adj_center, cur_center
    yaw = -calculate_yaw_from_position_changes(first, second)

    return yaw, center_change_detection(cur_center, adj_center, threshold=1.5)


class BaseConfig:
    """Tunable parameters of the pseudo-labeling pipeline.

    ``dataset_root`` and ``pseudo_root`` have no default and must be assigned
    before the config is handed to ``PseudoLabeler``.
    """

    # Input roots.
    dataset_root: Path
    pseudo_root: Path

    # Temporal window around the target frame: with target 20 and limit 10 the
    # window spans frames 10..30 (21 frames in total).
    frame_limit: int = 10
    target_frame_idx: int = 20
    # IoU needed to associate a tracked mask/box with an original detection.
    match_iou_threshold: float = 0.3
    # Objects deeper than this many meters are dropped.
    distant_obj_threshold: int = 50

    # RANSAC registration parameters.
    ransac_iters: int = 600
    ransac_inlier_threshold: float = 0.1  # meters (0.01 m = 1 cm)
    ransac_stop_threshold: float = 0.1  # meters

    # Offset of the adjacent frame used for the direction calculation.
    frame_gap: int = 2
    # Ignore pixels beyond this threshold (removes sky regions).
    sky_height: int = 100

    # Flow / visibility thresholds.
    object_vis_threshold: float = 0.6  # visibility of a tracked object (start/end decision)
    flow_vis_threshold: float = 0.5  # visibility of each tracked point
    scene_flow_threshold: int = 10  # pixels; stationary vs. moving camera
    object_flow_threshold: int = 10  # pixels; stationary vs. moving object



class ObjectClass:
    """State of one tracked object, filled in step by step by the pipeline.

    Only the attributes assigned in ``__init__`` are guaranteed to exist on a
    fresh instance; the other annotated attributes are set later by the caller.
    The class-level defaults of the mutable attributes are kept for backward
    compatibility, but every instance receives its own copies in ``__init__``.
    Without that, an in-place edit on one object (e.g. ``obj.size[0] = ...``)
    would silently change every other object.
    """
    idx             : int
    is_moving       : bool                                  # whether the object is moving or not
    pseudo_lidar    : np.ndarray = np.zeros((100, 3))       # completed pseudo lidar in shape (N, 3)
    cleaned_lidar   : np.ndarray = np.zeros((100, 3))       # presumably the filtered pseudo lidar, (N, 3)
    label           : str = "Name"                          # object label name
    box             : List[float] = [0., 0., 0., 0.]        # (x1, y1, x2, y2)
    logit           : float
    center          : List[float]                           # 3D box center (x, y, z)
    size            : List[float] = [0., 0., 0.]            # object dimension in [L, H, W]
    yaw             : float
    alpha           : float
    mask            : np.ndarray                            # object mask in shape [image_height, image_width]
    flow            : float
    flows           : np.ndarray                            # object flows for all frames (num_frames, num_pixels, 2)
    tracks          : np.ndarray                            # object tracks in other frames (N, M, 2)
    tracks_vis      : np.ndarray                            # tracks visibility in other frames (N, M)
    start_frame_idx : int
    end_frame_idx   : int
    visible_frames  : list                                  # frame indices in which the object is visible
    ignore          : bool                                  # if True, the object is skipped when labels are saved
    occluded        : bool
    trfs            : Dict[int, np.ndarray]                 # object transformations {adjacent_frame_idx: array[4,4]}
    occluded_frames : list                                  # frames with too few reliable track correspondences

    def __init__(self) -> None:
        self.trfs = {}
        self.occluded_frames = []
        self.start_frame_idx = None
        self.end_frame_idx = None
        self.ignore = False
        self.occluded = False
        self.is_moving = False
        # Fresh per-instance copies of the mutable class-level defaults.
        self.pseudo_lidar = np.zeros((100, 3))
        self.cleaned_lidar = np.zeros((100, 3))
        self.box = [0., 0., 0., 0.]
        self.size = [0., 0., 0.]


class FrameClass:
    """State of a single video frame.

    Each instance starts with its own empty ``objects`` and ``matched_indices``
    dicts; the other annotated attributes are assigned later by the pipeline.
    """
    idx: int
    name: str = "Name"  # frame name
    image: np.ndarray  # RGB image
    points3d: np.ndarray  # per-pixel 3D points, shape (H, W, 3)
    objects: Dict[int, ObjectClass]  # object index -> ObjectClass
    # target-frame object index -> index of the corresponding object in this frame
    matched_indices: dict
    pose: np.ndarray

    def __init__(self) -> None:
        self.objects, self.matched_indices = {}, {}

class VideoClass:
    """A video (scene) together with its per-frame state.

    ``__init__`` only creates an empty ``frames`` dict and sets ``moving_cam`` to
    False; the other annotated attributes are assigned later by the pipeline.
    """
    name: str = "Name"  # scene name
    # Camera intrinsic matrix; entries [0, 0] and [0, 2] are read as fx and cx by
    # PseudoLabeler.calculate_alpha.
    intrinsic: np.ndarray
    target_frame_idx: int = 20
    num_frames: int = 41  # number of frames in the scene
    fps: int = 10  # fps used to generate the video
    frames: Dict[int, FrameClass]  # frame index -> FrameClass
    moving_cam: bool  # whether the camera is considered moving
    image_size: list  # (H, W)
    scene_flow: float  # average flow of the video

    def __init__(self) -> None:
        self.moving_cam = False
        self.frames = {}


class PseudoLabeler:
    """Build KITTI-style 3D pseudo labels from foundation-model outputs.

    Per-scene inputs are read from ``config.dataset_root`` (frames, calib) and
    ``config.pseudo_root`` (depth, GSAM masks, point tracks, poses). One label
    file per scene is written to ``<pseudo_root>/pseudo_labels_val``. Loading a
    scene (``load_information``) is left to subclasses.
    """

    # Approximate per-class mean object dimensions (L, H, W) in meters. Used by
    # ``compute_object_size`` as a fallback and as the reference for scaling.
    # NOTE(review): these are the commonly cited KITTI training-set averages;
    # confirm they suit the dataset being labeled.
    SIZE_PRIORS: Dict[str, Tuple[float, float, float]] = {
        'Car': (3.88, 1.53, 1.63),
        'Pedestrian': (0.84, 1.76, 0.66),
        'Cyclist': (1.76, 1.74, 0.60),
    }
    # Minimum number of mask pixels for the point-cloud based size/center estimate.
    MIN_MASK_PIXELS = 650

    def __init__(self, config: BaseConfig) -> None:
        self.config = config
        self.dataset_root = Path(config.dataset_root)
        self.pseudo_root = Path(config.pseudo_root)

        # original video infos
        self.img_dir = self.dataset_root / "frames"
        self.calib_dir = self.dataset_root / "calib"

        # foundation model outputs
        self.depth_dir = self.pseudo_root / "unidepthv1"
        self.gsam_dir = self.pseudo_root / "gsam_frames"
        self.track_dir = self.pseudo_root / "alltracker"
        self.pose_dir = self.pseudo_root / "poses"

        # save directories (created eagerly)
        self.pseudo_label_dir = self.pseudo_root / "pseudo_labels_val"
        self.pseudo_results_dir = self.pseudo_root / "image_results_val"
        self.pseudo_label_dir.mkdir(parents=True, exist_ok=True)
        self.pseudo_results_dir.mkdir(parents=True, exist_ok=True)

        # use mode='train' to label the training split instead
        self.scenes = self.select_scenes(mode='val')

    def pseudo_lidar_completion(self, frame_lists: List[int], video: VideoClass, object: ObjectClass,
                                tgt_frame_idx: int = None, tgt_obj_idx: int = None):
        """Aggregate an object's 3D points from several frames into the target frame.

        For each frame in ``frame_lists`` the object's tracked 2D points are lifted
        to 3D in that frame and in the target frame. A rigid transform
        (source -> target) is then estimated with RANSAC. When the source frame has
        an original mask matched to ``tgt_obj_idx`` and the registration cost is
        below 1, all of that mask's 3D points are transformed into the target frame.

        Side effects on ``object``:
            * ``trfs``: {frame_idx: 4x4 source->target transform, or None when
              the frame was treated as occluded}
            * ``pseudo_lidar`` / ``cleaned_lidar``: merged point clouds
            * ``occluded_frames``: extended with frames that have fewer than 10
              usable correspondences
        """
        pseudo_lidars = []
        trfs = {frame_idx: None for frame_idx in frame_lists}
        tgt_frame = video.frames[tgt_frame_idx]
        # target-frame tracks do not depend on the source frame
        tgt_pts2d = object.tracks[tgt_frame_idx]

        for src_frame_idx in frame_lists:
            src_frame = video.frames[src_frame_idx]
            src_pts2d = object.tracks[src_frame_idx]

            if src_frame_idx == tgt_frame_idx:
                trfs[src_frame_idx] = np.eye(4)
                pseudo_lidars.append(tgt_frame.points3d[src_pts2d[:, 1], src_pts2d[:, 0], :])
                continue

            # keep correspondences that are inside the image and visible enough
            vis_mask = self.filter_visibility_and_confidence(
                tgt_pts2d, src_pts2d, object.tracks_vis[src_frame_idx], *video.image_size)
            tgt_pts3d = tgt_frame.points3d[tgt_pts2d[vis_mask, 1], tgt_pts2d[vis_mask, 0], :]
            src_pts3d = src_frame.points3d[src_pts2d[vis_mask, 1], src_pts2d[vis_mask, 0], :]
            # too few reliable correspondences: the object is occluded in this frame
            if (len(tgt_pts3d) < 10) or (len(src_pts3d) < 10):
                object.occluded_frames.append(src_frame_idx)
                continue

            # outlier removal (only used for faster registration); fall back to all
            # correspondences if too few survive
            _, tgt_noise_mask = remove_outliers_by_percentile(tgt_pts3d, 10, 85)
            _, src_noise_mask = remove_outliers_by_percentile(src_pts3d, 10, 85)
            noise_mask = tgt_noise_mask & src_noise_mask
            if np.sum(noise_mask) > 20:
                tgt_pts3d, src_pts3d = tgt_pts3d[noise_mask], src_pts3d[noise_mask]

            trf, _, cost = ransac_registration(
                src_pts3d, tgt_pts3d, self.config.ransac_iters,
                self.config.ransac_inlier_threshold, self.config.ransac_stop_threshold)

            # only aggregate if there is an original mask matched to this object
            src_obj_idx = src_frame.matched_indices.get(tgt_obj_idx, None)
            if (src_obj_idx is not None) and (cost < 1):
                src_obj_mask = src_frame.objects[src_obj_idx].mask
                src_mask_pts3d = src_frame.points3d[src_obj_mask, :]
                pseudo_lidars.append(((trf[:3, :3] @ src_mask_pts3d.T) + trf[:3, -1:]).T)
            trfs[src_frame_idx] = trf

        pseudo_lidar, cleaned_lidar = merge_and_filter_pseudo_lidars(pseudo_lidars, object.label)
        object.trfs = trfs
        object.pseudo_lidar = pseudo_lidar
        object.cleaned_lidar = cleaned_lidar

    def target_frame_object_registration(self, video: VideoClass):
        """Run ``pseudo_lidar_completion`` for every object of the target frame."""
        for tgt_obj_idx, tgt_object in video.frames[video.target_frame_idx].objects.items():
            # forward and backward registration from the visible frames to the target frame
            self.pseudo_lidar_completion(tgt_object.visible_frames, video, tgt_object, video.target_frame_idx, tgt_obj_idx)

    def search_valid_adjacent_frame_idx(self, tgt_frame_idx: int, object: ObjectClass):
        """Pick a frame near ``tgt_frame_idx`` to measure the object's motion against.

        The search starts at ``tgt_frame_idx + config.frame_gap``. It steps backward
        (by 2) while the candidate is past the last visible frame, equals the
        target, or is occluded. It then steps forward (by 1) while the candidate
        is negative, equals the target, or is occluded.

        Returns:
            The candidate frame index if ``object.trfs`` holds a transform for it,
            otherwise None (also None when the object has no visible frames).
        """
        if not object.visible_frames:
            return None
        last_visible = object.visible_frames[-1]
        occluded = set(object.occluded_frames)

        adj_frame_idx = tgt_frame_idx + self.config.frame_gap
        # backward search; NOTE(review): the step of 2 mirrors the default
        # frame_gap=2 (tgt+2 jumps straight to tgt-2) — confirm this is intended
        while (adj_frame_idx > last_visible) or (adj_frame_idx == tgt_frame_idx) or \
                (adj_frame_idx in occluded):
            adj_frame_idx -= 2

        # forward search
        while (adj_frame_idx < 0) or (adj_frame_idx == tgt_frame_idx) or \
                (adj_frame_idx in occluded):
            adj_frame_idx += 1

        # the caller inverts this transform, so a missing/None entry is unusable
        if object.trfs.get(adj_frame_idx) is None:
            return None
        return adj_frame_idx

    def frame_to_frame_registration(self, video: VideoClass, object: ObjectClass, target_frame_idx: int):
        """Return ``(adjacent_frame_idx, pose of that frame)``, or ``(None, None)`` if none is usable."""
        adjacent_frame_idx = self.search_valid_adjacent_frame_idx(target_frame_idx, object)
        if adjacent_frame_idx is None:
            return None, None
        pose = video.frames[adjacent_frame_idx].pose
        return adjacent_frame_idx, pose

    def compute_object_direction(self, video: VideoClass):
        """Estimate ``yaw`` and ``alpha`` for every non-ignored object of the target frame.

        Moving objects get their yaw from the displacement between the target
        frame and an adjacent frame, snapped to an axis within 15 degrees. They are
        re-flagged as stationary when no adjacent frame/pose is available, or when
        the camera moves but the object's depth does not change enough.
        Stationary objects (and moving ones without a motion yaw) get their yaw
        from PCA of the pseudo lidar, snapped within 25 degrees.
        """
        tgt_frame_idx = video.target_frame_idx
        for obj in video.frames[tgt_frame_idx].objects.values():
            if obj.ignore:
                continue
            yaw = None
            # if the object is identified as moving, use direction from object motion
            if obj.is_moving:
                adj_frame_idx, f2f_mat = self.frame_to_frame_registration(video, obj, tgt_frame_idx)
                if f2f_mat is not None:
                    tgt_obj_pts = remove_outliers_by_percentile(obj.cleaned_lidar, 10, 80)[0]
                    # trfs map source-frame points into the target frame; the inverse
                    # moves the target-frame points back into the adjacent frame
                    adj_obj_trf = np.linalg.inv(obj.trfs[adj_frame_idx])
                    adj_obj_pts = ((adj_obj_trf[:3, :3] @ tgt_obj_pts.T) + adj_obj_trf[:3, -1:]).T
                    yaw, center_change = compute_dir_from_object_motion(
                        tgt_obj_pts, adj_obj_pts, f2f_mat, tgt_frame_idx, adj_frame_idx)
                    yaw = self.yaw_correction(yaw, 15)
                    # if the object depth from the ego car does not change, assume stationary
                    if video.moving_cam and (not center_change):
                        obj.is_moving = False
                else:
                    # no usable adjacent frame (e.g. a single mask for the whole video)
                    obj.is_moving = False
            # if not moving, calculate direction from PCA
            if (not obj.is_moving) or (yaw is None):
                yaw = -compute_dir_with_PCA(obj.pseudo_lidar)
                yaw = self.yaw_correction(yaw, 25)
            obj.yaw = yaw
            obj.alpha = self.calculate_alpha(yaw, obj.box, video.intrinsic)

    def compute_object_size(self, video: VideoClass):
        """Estimate 3D size ``[L, H, W]`` and center for every target-frame object.

        Ignored objects are skipped; they have no yaw because
        ``compute_object_direction`` skips them. Objects whose mask has at most
        ``MIN_MASK_PIXELS`` pixels are marked ``ignore``, since their point cloud
        is too sparse.

        For the rest, the cleaned lidar is rotated to be axis aligned and measured.
        An implausibly small measurement falls back to the class prior. A height
        within 0.2 m of the prior keeps the measured height and scales the prior
        length/width by the same ratio. Anything else uses the prior.
        Sets ``obj.size`` and ``obj.center``.
        """
        for obj in video.frames[video.target_frame_idx].objects.values():
            if obj.ignore:
                continue
            if not np.sum(obj.mask) > self.MIN_MASK_PIXELS:
                # not enough points: the pseudo-lidar center is not reliable
                # (a mask/box center could be used here instead)
                obj.ignore = True
                continue

            label = str(obj.label).strip('.').title()
            L_prior, H_prior, W_prior = self.SIZE_PRIORS.get(label, self.SIZE_PRIORS['Car'])

            # make the points axis-aligned
            axis_aligned_lidar = (rotate_y(-obj.yaw) @ obj.cleaned_lidar.T).T
            # axis-aligned center and size
            (L, H, W), (x3d, y3d, z3d) = compute_size_from_axis_aligned_points(axis_aligned_lidar)
            reasonable_size = 0.5 if label == 'Pedestrian' else 1.0
            if L < reasonable_size or H < reasonable_size or W < reasonable_size:
                L, H, W = L_prior, H_prior, W_prior
            elif abs(H - H_prior) < 0.2:
                scale = H / H_prior
                L = L_prior * scale
                W = W_prior * scale
            else:
                L, H, W = L_prior, H_prior, W_prior

            # recompute the center from the estimated size, then rotate back
            axis_aligned_box = create_ego_box3d(x3d, y3d, z3d, L, H, W, yaw=0)
            ego_aligned_box = (rotate_y(obj.yaw) @ axis_aligned_box.T).T
            x3d, y3d, z3d = ego_aligned_box.mean(0)

            obj.size = [float(L), float(H), float(W)]
            obj.center = [float(x3d), float(y3d), float(z3d)]

    def hungarian_match_masks(self, video: VideoClass):
        """Associate target-frame objects with each frame's own detections.

        For every frame, the tracked points of each target-frame object (those
        with visibility above ``flow_vis_threshold``) are turned into a box. These
        boxes are matched to the frame's detection boxes, shrunk by a few pixels,
        with the Hungarian algorithm on IoU. A pair is kept when its IoU exceeds
        ``match_iou_threshold``; in the target frame every assigned pair is kept.

        Sets ``frame.matched_indices`` to {target-frame object key: frame object key}.
        Frames without detections (or without target objects) are left untouched.
        """
        target_frame = video.frames[video.target_frame_idx]
        H, W, _ = target_frame.image.shape
        target_keys = list(target_frame.objects.keys())
        for frame_idx in range(video.num_frames):
            frame = video.frames[frame_idx]
            tracked_boxes = [tracked_to_box(obj.tracks[frame_idx],
                                            obj.tracks_vis[frame_idx] > self.config.flow_vis_threshold, W, H)
                             for obj in target_frame.objects.values()]
            frame_keys = list(frame.objects.keys())
            original_boxes = []
            for obj in frame.objects.values():
                x1, y1, x2, y2 = obj.box
                # shrink detection boxes (by less for very small boxes)
                offset_x = 5 if x2 - x1 >= 10 else 3
                offset_y = 5 if y2 - y1 >= 10 else 3
                original_boxes.append([int(x1 + offset_x), int(y1 + offset_y),
                                       int(x2 - offset_x), int(y2 - offset_y)])
            if (len(original_boxes) < 1) or (len(tracked_boxes) < 1):
                continue
            iou_mat = get_iou_matrix(np.stack(tracked_boxes, axis=0), np.stack(original_boxes, axis=0), add1=True)  # (tracked, orig)
            tracked_indices, original_indices = linear_sum_assignment(iou_mat, maximize=True)
            is_target_frame = frame_idx == video.target_frame_idx
            # Hungarian indices are list positions; translate them back to the dict keys
            frame.matched_indices = {
                target_keys[t_idx]: frame_keys[o_idx]
                for t_idx, o_idx in zip(tracked_indices, original_indices)
                if is_target_frame or iou_mat[t_idx, o_idx] > self.config.match_iou_threshold
            }

    def filter_gsam_detections(self, gsam_results: Dict[str, np.ndarray], depth_map: np.ndarray):
        """Filter GSAM detections by class, size and depth, and remove duplicates.

        A detection is kept when its label (after ``.strip('.').title()``) is in
        ``KITTI_CLASSES``, its box area is at least 150 px, and the depth at the
        box center lies in [3.5, 80]. Among the survivors, box pairs with
        IoU > 0.6 are duplicates, and only one box of each duplicate group is kept.

        Returns:
            dict with keys "masks", "boxes", "logits", "labels" holding stacked
            arrays, or empty lists when no detection survives.
        """
        filtered_gsam_results = {"masks": [], "boxes": [], "logits": [], "labels": []}
        img_h, img_w = depth_map.shape[:2]
        kept_boxes, remaps = [], []
        for obj_idx in range(len(gsam_results['boxes'])):
            x1, y1, x2, y2 = map(int, gsam_results['boxes'][obj_idx])

            # filter only KITTI classes
            if str(gsam_results['labels'][obj_idx]).strip('.').title() not in KITTI_CLASSES:
                continue
            # filter small objects (based on box area)
            if (x2 - x1) * (y2 - y1) < 150:
                continue
            # filter very near or far objects (depth at the box center, clamped to the image)
            cx = min(max((x1 + x2) // 2, 0), img_w - 1)
            cy = min(max((y1 + y2) // 2, 0), img_h - 1)
            center_depth = depth_map[cy, cx]
            if center_depth < 3.5 or center_depth > 80:
                continue

            kept_boxes.append([x1, y1, x2, y2])
            remaps.append(obj_idx)

        if not kept_boxes:
            return filtered_gsam_results

        boxes = np.stack(kept_boxes, axis=0)
        iou_mat = get_iou_matrix(boxes, boxes)
        iou_mat[np.diag_indices(iou_mat.shape[0])] = 0  # a box is not its own duplicate
        indices1, indices2 = linear_sum_assignment(iou_mat, maximize=True)
        # The IoU matrix is symmetric, so a duplicate pair (i, j) is usually assigned
        # both ways (i->j and j->i). Dropping only the higher index of each
        # over-threshold pair keeps the lowest-indexed box of every duplicate group,
        # instead of discarding the whole group.
        duplicated_indices = {max(id1, id2) for id1, id2 in zip(indices1, indices2)
                              if iou_mat[id1, id2] > 0.6}
        keep = [orig_index for box_idx, orig_index in enumerate(remaps)
                if box_idx not in duplicated_indices]
        if not keep:
            return filtered_gsam_results
        for key in ("masks", "boxes", "logits", "labels"):
            filtered_gsam_results[key] = np.stack([gsam_results[key][i] for i in keep])
        return filtered_gsam_results

    def compute_attributes(self, video: VideoClass, visualize=True):
        """Compute yaw and size, save the KITTI label file, and optionally a visualization."""
        self.compute_object_direction(video)
        self.compute_object_size(video)
        self.save_kitti_labels(video.frames[video.target_frame_idx].objects, self.pseudo_label_dir / f"{video.name}.txt")
        if visualize:
            save_path = self.pseudo_results_dir / f"{video.name}.png"
            self.visualize_pseudo_lidar_with_box(video, video.frames[video.target_frame_idx].objects, video.target_frame_idx, save_path)

    def save_kitti_labels(self, objects: Dict[int, ObjectClass], save_path: Path):
        """Write non-ignored objects as KITTI label lines to ``save_path``.

        Columns: type, truncated (always 0), occluded (0/1), alpha, 2D box,
        dimensions (H, W, L), location (x, bottom y, z), rotation_y.
        """
        labels, truncated = [], 0
        for object in objects.values():
            if object.ignore:
                continue
            y = object.center[1] + object.size[1]/2     # bottom center in KITTI labels
            occluded = 1 if object.occluded else 0
            labels.append(
                f"{object.label} {truncated} {occluded} {object.alpha:.2f} " +
                f"{object.box[0]:.2f} {object.box[1]:.2f} {object.box[2]:.2f} {object.box[3]:.2f} " +
                f"{object.size[1]:.2f} {object.size[2]:.2f} {object.size[0]:.2f} " +
                f"{object.center[0]:.2f} {y:.2f} {object.center[2]:.2f} {object.yaw:.2f}\n"
            )
        with open(save_path, "w") as f:
            f.writelines(labels)

    def visualize_pseudo_lidar_with_box(self, video: VideoClass, objects: Dict[int, ObjectClass], frame_idx: int, save_path: str):
        """Render the boxes of non-ignored, non-occluded objects on ``frame_idx`` to ``save_path``."""
        image = video.frames[frame_idx].image
        boxes, movings = [], []
        for object in objects.values():
            if object.ignore or object.occluded:
                continue
            boxes.append([*object.center, *object.size, object.yaw])
            movings.append(object.is_moving)
        visualize_bev_and_box3d(image, boxes, movings, video.intrinsic, None, save_path)

    def yaw_correction(self, yaw, offset=10):
        """Snap ``yaw`` (radians) to 90, -90, 0 or 180 degrees when within ``offset`` degrees.

        The bounds are strict. Values near -180 degrees are snapped to +180 degrees.
        Any other yaw is returned unchanged.
        """
        # facing south (90 deg)
        if np.deg2rad(90-offset) < yaw < np.deg2rad(90+offset):
            yaw = np.deg2rad(90)
        # facing north
        elif np.deg2rad(-90-offset) < yaw < np.deg2rad(-90+offset):
            yaw = np.deg2rad(-90)
        # facing east
        elif np.deg2rad(-offset) < yaw < np.deg2rad(offset):
            yaw = 0
        # facing west
        elif (np.deg2rad(-180+offset) > yaw) or (yaw > np.deg2rad(180-offset)):
            yaw = np.deg2rad(180)
        return yaw

    def calculate_alpha(self, yaw: float, box: list, intrinsic: np.ndarray) -> float:
        """Observation angle via ``ry2alpha`` using the box's horizontal center and fx/cx of ``intrinsic``."""
        return ry2alpha(yaw, (box[2]+box[0])/2, intrinsic[0, 0], intrinsic[0, 2])

    def create_visibility_mask(self, points2d, points2d_corr, frame_width, frame_height):
        """Boolean mask of correspondences whose (x, y) pixels lie inside the image in both sets.

        Negative coordinates are rejected too; otherwise they would silently wrap
        around when used as array indices.
        """
        vis_mask = (points2d[:, 0] >= 0) & (points2d[:, 1] >= 0) & \
                   (points2d_corr[:, 0] >= 0) & (points2d_corr[:, 1] >= 0) & \
                   (points2d[:, 0] < frame_width) & (points2d[:, 1] < frame_height) & \
                   (points2d_corr[:, 0] < frame_width) & (points2d_corr[:, 1] < frame_height)
        return vis_mask

    def filter_visibility_and_confidence(self, points1, points2, tracks_vis, H, W):
        """Mask of correspondences inside the image whose visibility exceeds ``flow_vis_threshold``."""
        vis_mask = self.create_visibility_mask(points1, points2, W, H)
        return vis_mask & (tracks_vis > self.config.flow_vis_threshold)

    def aggreage_flows(self, flows, flows_vis):
        """Mean absolute flow over entries with visibility above ``flow_vis_threshold``.

        ``flows`` is expected in shape [N, M, 2]. NumPy returns NaN (with a
        RuntimeWarning) when no entry is visible. The misspelled name is kept
        for compatibility with existing callers.
        """
        vis_mask = flows_vis > self.config.flow_vis_threshold
        return np.mean(np.abs(flows[vis_mask]))

    def select_scenes(self, mode="train"):
        """Read scene names (one per line) from ``<dataset_root>/ImageSets/<mode>.txt``."""
        with open(self.dataset_root / "ImageSets" / f"{mode}.txt") as f:
            return f.read().splitlines()

    def check_already_labeled(self, scene_name):
        """True if a label file for ``scene_name`` already exists."""
        return (self.pseudo_label_dir / f"{scene_name}.txt").exists()

    def __len__(self):
        return len(self.scenes)

    def load_information(self, scene_name):
        """Build the VideoClass for ``scene_name``; left to subclasses."""
        pass

    def __getitem__(self, index) -> VideoClass:
        # NOTE(review): in this base class this returns the scene *name*, not a VideoClass
        return self.scenes[index]
    
