import argparse
import cv2
import glob
import matplotlib
matplotlib.use("Agg")           # must be set before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import sys
import math
import random
import json
import torch
import pickle
from numpy.typing import ArrayLike

from objects.missing_and_occ import (
    disappeared_suddenly, get_missing_per_scene, get_missing_per_agent
)

# ----------------------------- Geometry utils -----------------------------

def gt_2_ego(gt_xy, heading=None, k_ahead=1, min_step=1):
    """
    Convert a global (x, y) trajectory to a local ego frame whose origin is the first point.

    The trajectory is translated so that its first point is (0, 0) and rotated by
    -theta, so that the heading direction lands on +x.

    Args:
        gt_xy:    (T, 2) float array of global (x, y) positions, T >= 1.
        heading:  optional heading angle in radians (python float or 0-d tensor).
                  When None it is estimated from gt_xy[k] - gt_xy[0] with
                  k = min(k_ahead, T - 1). If that displacement is shorter than
                  ``min_step``, the first later point at least ``min_step`` away
                  is used (falling back to the last point).
        k_ahead:  index offset used for the heading estimate.
        min_step: minimum displacement length accepted for the heading estimate.

    Returns:
        gt_local_xy: (T, 2) torch tensor in the ego frame (x, y); heading along +x.
        gt_local_yx: (T, 2) numpy array holding (-y, x) of the ego frame, i.e. the
                     heading lies along the second column (for visualization).
        theta:       heading angle (rad, 0-d tensor) used for the rotation.
    """
    gt = torch.from_numpy(np.asarray(gt_xy))  # (T, 2); shares memory, never written to

    origin = gt[0]
    rel_gt = gt - origin

    if heading is None:
        k = min(k_ahead, len(gt) - 1)
        v = gt[k] - origin
        if torch.linalg.norm(v) < min_step:
            # Too short to define a direction: take the first point far enough away.
            for j in range(1, len(gt)):
                v = gt[j] - origin
                if torch.linalg.norm(v) >= min_step:
                    break
        theta = torch.atan2(v[1], v[0])  # angle w.r.t. +x
    else:
        # torch.cos/sin reject python floats, and matmul rejects mixed dtypes,
        # so coerce the supplied heading to a tensor of the trajectory's dtype.
        theta = torch.as_tensor(heading, dtype=gt.dtype)

    c, s = torch.cos(theta), torch.sin(theta)
    R = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])  # CCW rotation by theta

    # Row vectors multiplied on the right by R == rotating each point by -theta.
    gt_local = torch.matmul(rel_gt, R)
    gt_local_xy = gt_local.clone()
    gt_local[:, [0, 1]] = gt_local[:, [1, 0]]  # RHS fancy-index is a copy, so this swap is safe
    gt_local[:, 0] = -gt_local[:, 0]
    gt_yx = gt_local.numpy()

    return gt_local_xy, gt_yx, theta


def ego_y_2_x(ego_xz):
    """
    Map each point (a, b, ...) to (b, -a, ...).

    The first two columns are swapped and the new second column is negated;
    any further columns are passed through unchanged. The input is copied, so
    the caller's array is never modified.

    Args:
        ego_xz: (N, >=2) array-like of points.

    Returns:
        (N, >=2) float64 numpy array.
    """
    # np.array (not np.asarray) guarantees a copy; the old asarray + from_numpy
    # path aliased float64 inputs and mutated them in place.
    out = np.array(ego_xz, dtype=float)
    out[:, [0, 1]] = out[:, [1, 0]]  # RHS fancy-index is a copy, so the swap is safe
    out[:, 1] = -out[:, 1]
    return out


def umeyama_2d(X: np.ndarray, Y: np.ndarray, with_scale=True):
    """
    Similarity (Umeyama) alignment of two matched 2D point sets.

    Args:
        X, Y: (N, 2) arrays of corresponding points, N >= 2.
        with_scale: estimate a scale factor; otherwise the scale is fixed to 1.0.

    Returns:
        (s, R, t) with R a 2x2 proper rotation (det = +1) and t of shape (2,),
        such that Y ≈ s * R @ X + t.
    """
    n = len(X)
    mean_x = X.mean(0)
    mean_y = Y.mean(0)
    x_c = X - mean_x
    y_c = Y - mean_y

    cov = x_c.T @ y_c / n
    U, _, Vt = np.linalg.svd(cov)
    rot = Vt.T @ U.T
    if np.linalg.det(rot) < 0:
        # Reflection: flip the axis belonging to the smallest singular value.
        Vt[1] *= -1
        rot = Vt.T @ U.T

    scale = 1.0
    if with_scale:
        var_x = (x_c ** 2).sum() / n
        scale = (y_c * (rot @ x_c.T).T).sum() / (n * var_x)

    trans = mean_y - scale * rot @ mean_x
    return scale, rot, trans


def slam_align_to_gt_fix_origin(
    pred_xyz: ArrayLike,   # (N, 3) SLAM: (x_right, y, z_forward)
    gt_local: ArrayLike,   # (N, 2) from gt_2_ego; first frame is (0,0)
    with_scale: bool = True,
):
    """
    Estimate scale s and rotation R aligning a SLAM trajectory to the GT local frame.

    Only the (x, z) columns of ``pred_xyz`` are used. The prediction is shifted so
    its first frame is at the origin, and no translation is estimated, so the
    aligned trajectory also starts at (0, 0). ``gt_local`` is used as-is and is
    expected to already start at the origin.

    Args:
        pred_xyz: (N, 3) SLAM positions.
        gt_local: (N, 2) GT positions in the local ego frame.
        with_scale: estimate a scale factor; otherwise s = 1.0.

    Returns:
        aligned: (N, 2) prediction expressed in the GT local frame.
        s: scale factor. It is 1.0 when the prediction never moves, because the
           scale is undefined then; previously this case produced NaN.
        R: (2, 2) proper rotation matrix.
    """
    P = np.asarray(pred_xyz, float)[:, [0, 2]]  # use (x, z)
    Q = np.asarray(gt_local, float)
    n = len(P)

    P_rel = P - P[0]
    Q_rel = Q  # GT is already anchored at the origin

    H = P_rel.T @ Q_rel / n
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        # Reflection: flip the axis belonging to the smallest singular value.
        Vt[1] *= -1
        R = Vt.T @ U.T

    s = 1.0
    if with_scale:
        varP = (P_rel ** 2).sum() / n
        if varP > 0:
            s = (Q_rel * (R @ P_rel.T).T).sum() / (n * varP)
        # else: stationary prediction -> 0/0; keep s = 1.0 so `aligned` stays finite.

    aligned = (s * (R @ P_rel.T)).T
    return aligned, s, R


def apply_sr(trajectory_xz, s, R):
    """
    Scale and rotate a trajectory given in (x_right, z_forward).

    Each point p is mapped to ``s * R @ p``; no translation is applied.

    Args:
        trajectory_xz: (N, 2) array-like of points.
        s: scale factor.
        R: (2, 2) rotation matrix.

    Returns:
        (N, 2) float numpy array, i.e. the points expressed in the GT local frame.
    """
    points = np.asarray(trajectory_xz, float)
    rotated = np.matmul(R, points.T)
    return np.transpose(s * rotated)


# ----------------------------- Smoothing -----------------------------------
from scipy.signal import savgol_filter

def smooth_traj_sg(xy, dt=0.1, win_sec=0.4, poly=3):
    """
    Savitzky–Golay smoothing for a 2D trajectory (smoothed along axis 0).

    Window and order selection:
    - The window length is ``round(win_sec / dt)`` samples, bumped up to the next
      odd number.
    - If the window is longer than T, it is clamped to the largest odd length <= T.
    - The polynomial order is lowered to ``max(1, window - 2)`` whenever it is not
      below ``window - 1``.
    - A 4-sample input always uses window 3 and order 1.

    If the final window is shorter than 3 samples, the input is returned
    unsmoothed (converted to a float array).

    Args:
        xy: (T, 2) array-like trajectory.
        dt: sample period in seconds.
        win_sec: window duration in seconds.
        poly: requested polynomial order.

    Returns:
        (T, 2) float numpy array.
    """
    pts = np.asarray(xy, float)
    n_samples = pts.shape[0]

    window = int(round(win_sec / dt))
    window = window if window % 2 else window + 1
    if window > n_samples:
        window = n_samples if n_samples % 2 else n_samples - 1

    if n_samples == 4:
        window, order = 3, 1
    elif poly >= window - 1:
        order = max(1, window - 2)
    else:
        order = poly

    if window < 3:
        return pts

    return savgol_filter(pts, window_length=window, polyorder=order, axis=0, mode="interp")


# ----------------------------- IoU helper ----------------------------------

def max_iou_box(query_box: np.ndarray, boxes: np.ndarray, return_index: bool = False):
    """
    Return the box in `boxes` with maximum IoU to `query_box`.

    Boxes are corner-format (x1, y1, x2, y2). Only the first four columns of
    `boxes` are used for the IoU, but the full row is returned.

    Args:
        query_box: array-like whose first four entries are (x1, y1, x2, y2).
        boxes: (M, >=4) array-like of candidate boxes; a single 1-D box is also
            accepted and treated as one candidate.
        return_index: also return the row index of the best box.

    Returns:
        (best_box, best_iou), or (best_box, best_iou, best_idx) if `return_index`.
        Ties resolve to the lowest index.

    Raises:
        ValueError: if `boxes` is empty.
    """
    query_box = np.asarray(query_box, dtype=float)
    boxes = np.asarray(boxes, dtype=float)
    if boxes.size == 0:
        raise ValueError("`boxes` is empty.")
    # A lone 1-D box previously failed on boxes[:, 0]; promote it to one row.
    boxes = np.atleast_2d(boxes)

    ix1 = np.maximum(query_box[0], boxes[:, 0])
    iy1 = np.maximum(query_box[1], boxes[:, 1])
    ix2 = np.minimum(query_box[2], boxes[:, 2])
    iy2 = np.minimum(query_box[3], boxes[:, 3])

    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)

    query_area = (query_box[2] - query_box[0]) * (query_box[3] - query_box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Small epsilon avoids division by zero for degenerate (zero-area) boxes.
    iou = inter / (query_area + boxes_area - inter + 1e-8)

    best_idx = int(iou.argmax())
    best_iou = float(iou[best_idx])
    best_box = boxes[best_idx]

    return (best_box, best_iou, best_idx) if return_index else (best_box, best_iou)


# ----------------------------- Visualization -------------------------------

import matplotlib as mpl
import matplotlib.gridspec as gridspec

def visualize_trajectory(
    trajectory,
    outdir,
    gt=None,
    gt_pred=None,
    others=None,
    others_gt=None,
    map_scale=1,
    car_length=4.5,
    car_width=2.0,
    draw_polygon=0,
    yaw0=0,
):
    """
    Plot the ego trajectory and optional reference curves on a 2x2 grid and save it.

    Panels: top-left = everything combined, top-right = GT-related curves,
    bottom-left = ego trajectory only, bottom-right = other agents. Axis limits
    always include at least X in [-30, 30] and Z in [-10, 50], and the X range
    is symmetric about 0.

    Args:
        trajectory: (T, 2) ego trajectory ("Ego_of_gen_video"), columns plotted as (X, Z).
        outdir: output image path passed to ``plt.savefig`` (a file path, despite the name).
        gt: optional (T, 2) trajectory labelled "Ego_of_gt_traj".
        gt_pred: optional (T, 2) trajectory labelled "Ego_of_gt_video".
        others: optional iterable of (T_i, 2) agent trajectories ("Others").
        others_gt: optional iterable of (T_i, 2) GT agent tracks, plotted with columns
            swapped (column 1 on X, column 0 on Z). Only the second track whose
            column-0 displacement exceeds 10 is drawn, as "FocalGT".
        map_scale, yaw0: unused; kept for interface compatibility.
        car_length, car_width: footprint size forwarded to ``draw_car``.
        draw_polygon: if truthy, also draw footprints via ``draw_car``, which must be
            provided by the project (it is not defined in this module).
    """
    locX = trajectory[:, 0]
    locZ = trajectory[:, 1]

    mpl.rc("figure", facecolor="white")
    plt.style.use("seaborn-v0_8-whitegrid")

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8), dpi=100)
    traj_main_plt = axes[0, 0]
    gt_plt = axes[0, 1]
    ego_plt = axes[1, 0]
    others_plt = axes[1, 1]
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    line_kw = dict(linewidth=1, markersize=4)

    # Running data extent of everything plotted: [min_x, max_x, min_y, max_y].
    extent = [np.min(locX), np.max(locX), np.min(locZ), np.max(locZ)]

    def grow_extent(xs, ys):
        """Widen `extent` so it also covers the points (xs, ys)."""
        extent[0] = min(extent[0], np.min(xs))
        extent[1] = max(extent[1], np.max(xs))
        extent[2] = min(extent[2], np.min(ys))
        extent[3] = max(extent[3], np.max(ys))

    traj_main_plt.set_title("Trajectory (Z, X)", y=1.0)

    # Ego trajectory ("Ego_of_gen_video").
    if draw_polygon:
        draw_car(traj_main_plt, locX, locZ, car_length, car_width, color=colors[0], zorder=1)
        draw_car(ego_plt, locX, locZ, car_length, car_width, color=colors[0], zorder=1)
    traj_main_plt.plot(locX, locZ, ".-", label="Ego_of_gen_video", zorder=6, color=colors[0], **line_kw)
    ego_plt.plot(locX, locZ, ".-", label="Ego_of_gen_video", zorder=6, color=colors[0], **line_kw)

    # "Ego_of_gt_traj".
    if gt is not None:
        gx, gy = gt[:, 0], gt[:, 1]
        if draw_polygon:
            draw_car(traj_main_plt, gx, gy, car_length, car_width, color=colors[1], zorder=1)
            draw_car(gt_plt, gx, gy, car_length, car_width, color=colors[1], zorder=1)
        traj_main_plt.plot(gx, gy, ".-", label="Ego_of_gt_traj", zorder=1, color=colors[1], **line_kw)
        gt_plt.plot(gx, gy, ".-", label="Ego_of_gt_traj", zorder=1, color=colors[1], **line_kw)
        grow_extent(gx, gy)

    # "Ego_of_gt_video".
    if gt_pred is not None:
        px, py = gt_pred[:, 0], gt_pred[:, 1]
        if draw_polygon:
            # Match this curve's line color; colors[1] belongs to "Ego_of_gt_traj".
            draw_car(traj_main_plt, px, py, car_length, car_width, color=colors[5], zorder=1)
            draw_car(gt_plt, px, py, car_length, car_width, color=colors[5], zorder=1)
        gt_plt.plot(px, py, ".-", label="Ego_of_gt_video", zorder=1, color=colors[5], **line_kw)
        grow_extent(px, py)

    # Other agents; only the first curve carries the legend label.
    if others is not None:
        for idx, other in enumerate(others):
            label_kw = {"label": "Others"} if idx == 0 else {}
            for ax in (traj_main_plt, others_plt):
                ax.plot(other[:, 0], other[:, 1], ".-", zorder=3, color=colors[2], **label_kw, **line_kw)
            grow_extent(other[:, 0], other[:, 1])

    # GT agents: draw only the second track whose column-0 displacement exceeds 10.
    if others_gt is not None:
        count = 0
        for o_gt in others_gt:
            if not abs(o_gt[-1, 0] - o_gt[0, 0]) > 10:
                continue
            count += 1
            if count != 2:
                continue
            ox, oy = o_gt[:, 1], o_gt[:, 0]  # columns swapped for plotting
            if draw_polygon:
                draw_car(traj_main_plt, ox, oy, car_length, car_width, color=colors[4], zorder=1)
                draw_car(gt_plt, ox, oy, car_length, car_width, color=colors[4], zorder=1)
            traj_main_plt.plot(ox, oy, ".-", label="FocalGT", zorder=2, color=colors[4], **line_kw)
            gt_plt.plot(ox, oy, ".-", label="FocalGT", zorder=2, color=colors[4], **line_kw)
            grow_extent(ox, oy)
            break  # later tracks can never be drawn

    min_x, max_x, min_y, max_y = extent
    min_x = min(-30, min_x)
    max_x = max(30, max_x)
    x_ = max(abs(min_x), abs(max_x))  # symmetric X range
    min_y = min(-10, min_y)
    max_y = max(50, max_y)

    def set_plot(ax):
        """Apply shared labels and limits, and add a legend if anything is labelled."""
        ax.set_xlabel("X")
        ax.set_ylabel("Z")
        ax.set_xlim([-x_ - 3, x_ + 3])
        ax.set_ylim([min_y - 3, max_y + 3])
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc=1, title="Legend", borderaxespad=0.0, fontsize="medium", frameon=True)

    for ax in (traj_main_plt, gt_plt, ego_plt, others_plt):
        set_plot(ax)

    plt.savefig(outdir)
    plt.close(fig)


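# NOTE: `draw_car` is called by visualize_trajectory when draw_polygon is truthy, but it is
# not defined or imported in this module; provide it from your project before enabling that option.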


# ----------------------------- Main ----------------------------------------

if __name__ == "__main__":
    import shutil  # only needed by this script entry point

    parser = argparse.ArgumentParser(description="Unidepth")
    parser.add_argument("--root_path", type=str)
    parser.add_argument("--outdir", type=str, default="./vis_depth")
    parser.add_argument("--video_path", type=str)
    parser.add_argument("--split", type=str, default="gt")
    parser.add_argument("--action", type=str, default="free")
    parser.add_argument("--smooth_slam", type=int, default=0)
    parser.add_argument("--eval_agents", type=int, default=0)
    parser.add_argument("--use_droid", type=int, default=0)
    args = parser.parse_args()

    split = args.split
    act_dir = args.action if split != "gt" else ""
    runs = [r for r in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir, r))]

    method = "unidepth" if args.use_droid == 0 else "droid_unidepth"
    print(f"Using method: {method}")
    print(f"Total runs: {len(runs)}")

    # normpath + basename: a trailing "/" on --outdir used to yield an empty name.
    save_base = os.path.join("outputs", os.path.basename(os.path.normpath(args.outdir)))
    os.makedirs(save_base, exist_ok=True)

    vis_base = os.path.join(save_base, f"{method}-vis-{split}_{act_dir}")
    os.makedirs(vis_base, exist_ok=True)

    gts = []            # local GT trajectories
    gt_paths = {}       # "<parent>+<name>" of each GT path -> GT path
    agents_traj = []
    gt_agents_traj = []
    agents_bbox = []
    gt_agents_bbox = []
    agents_label = []
    valid_ego_runs = []
    valid_agents_runs = []

    # --root_path is a JSON file; each iterated entry is treated as a GT directory path.
    with open(args.root_path, "r") as f:
        gt_json = json.load(f)

    if "general" in args.outdir:
        for gt_base in gt_json:
            gt_path = os.path.join(gt_base, "ego_motion_0620.npy")
            # NOTE: allow_pickle=True can execute code stored in the file; trusted data only.
            gt = np.load(gt_path, allow_pickle=True)
            gt_local_xy, gt_local_yx, theta = gt_2_ego(gt[:101, :2])

            if args.smooth_slam:
                gt_local_xy = smooth_traj_sg(gt_local_xy, dt=0.1, win_sec=0.4, poly=3)
            gts.append(gt_local_xy)
    else:
        for gt_base in gt_json:
            key = gt_base.split("/")[-2] + "+" + gt_base.split("/")[-1]
            gt_paths[key] = gt_base

    for run in runs:
        if split == "gt":
            log_base = os.path.join(args.outdir, run, split, method)
        else:
            log_base = os.path.join(args.outdir, run, split, act_dir, method)

        if args.eval_agents:
            try:
                # NOTE: pickle.load can execute arbitrary code; these must be trusted local outputs.
                with open(log_base + "-estimate_agents_traj.pkl", "rb") as f:
                    agent_traj = pickle.load(f)
                with open(log_base + "-estimate_agents_bbox.pkl", "rb") as f:
                    agent_bbox = pickle.load(f)
                with open(log_base + "-estimate_agents_bbox_label.pkl", "rb") as f:
                    agent_label = pickle.load(f)
            except Exception:
                print("Skip: no agents for this run.")
                continue

            # Trim each track at the first frame gap larger than 10.
            # NOTE(review): the trimmed tracks are not used (the untrimmed boxes are
            # stored below); only the "Trim track" messages have an effect.
            trans_agent_bbox = []
            for boxes in agent_bbox:
                ids = [box[0] for box in boxes]
                boxes_valid = []
                for i, id_ in enumerate(ids):
                    boxes_valid.append(boxes[i])
                    if i < len(ids) - 1 and id_ + 10 < ids[i + 1]:
                        print(f"Trim track at frame {id_}")
                        break
                trans_agent_bbox.append(boxes_valid)

            agents_bbox.append(agent_bbox)  # keep original behavior
            agents_label.append(agent_label)
            valid_agents_runs.append(run)

        valid_ego_runs.append(run)

    metrics = {}

    # ------------------------- Abnormal disappearance ------------------------
    if args.eval_agents:
        missing_rate = []
        num_agent = 0
        num_missing_agent = 0
        glm_dir = os.path.join(save_base, "glm_input", split, act_dir)
        # Remove stale outputs without a shell: paths with spaces or shell
        # metacharacters made `rm -rf` delete the wrong directories.
        shutil.rmtree(glm_dir, ignore_errors=True)
        os.makedirs(glm_dir, exist_ok=True)
        print(f"Valid runs (with agents): {len(valid_agents_runs)}; bbox lists: {len(agents_bbox)}")

        for scene_id, scene_agents in enumerate(agents_bbox):
            is_missing = []
            num_agent += len(scene_agents)

            if split == "gt":
                gen_video_path_ = os.path.join(gt_paths[valid_agents_runs[scene_id]], "CAM_F0")
            else:
                gen_video_path_ = os.path.join(args.video_path, valid_agents_runs[scene_id], split, act_dir, "images")

            gen_video_path = sorted([os.path.join(gen_video_path_, f) for f in os.listdir(gen_video_path_)])
            video_img_dict = {img_id: img for img_id, img in enumerate(gen_video_path)}

            for agent_id, scene_agent in enumerate(scene_agents):
                # Boxes of every *other* agent in this scene, grouped by frame id.
                other_boxes_by_frame = {}
                for other in scene_agents:
                    if other is scene_agent:
                        continue
                    for f_id, bbox in other:
                        other_boxes_by_frame.setdefault(f_id, []).append(bbox)

                glm_this = os.path.join(glm_dir, f"run_{scene_id}", f"{agent_id}")
                missing = disappeared_suddenly(
                    scene_agent,
                    other_boxes_by_frame,
                    video_img_dict,
                    glm_this,
                    img_size=(1024, 576),
                    edge_margin=0.1,
                    iou_threshold=0.5,
                    min_track_len=2,
                )
                print(f"Run {scene_id} - agent {agent_id} missing: {missing}")
                is_missing.append((scene_agent, missing))

            missing_rate.append((get_missing_per_scene(is_missing), is_missing))
            num_missing_agent += get_missing_per_agent(is_missing)

        assert len(missing_rate) == len(valid_agents_runs)

        missing_rate_vals = [m[0] for m in missing_rate]
        missing_rate_mean = np.nanmean(missing_rate_vals, axis=0)
        metrics["objects-feasibility"] = {}
        metrics["objects-feasibility"]["object_missing_error_scene"] = 1 - missing_rate_mean
        # Guard: with zero agents the ratio is undefined (was a ZeroDivisionError).
        metrics["objects-feasibility"]["object_missing_error_agent"] = (
            num_missing_agent / num_agent if num_agent else float("nan")
        )

    print(metrics)
    track = "align" if "align" in args.root_path else "general"
    os.makedirs("./abnormal", exist_ok=True)
    with open(f"./abnormal/{split}_{track}.json", "w") as f:
        json.dump(metrics, f, indent=4)
