import argparse
import os
import json
import pickle

import cv2
import numpy as np
from numpy.typing import ArrayLike
import torch

import matplotlib
matplotlib.use("Agg")  # set backend before importing pyplot
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from scipy.signal import savgol_filter

from trajs.traj_distribution import traj_fid_mtr
from trajs.traj_alignment import (
    ade as ade_fn, fde as fde_fn, success_rate, hausdorff, ndtw,
    dynamic_consistency, sdtw
)
from trajs.traj_quality import (
    feasibility_rate, jerk_rms, curvature_rms, yaw_rate_rms_deg, trajectory_consistency,
    comfort_score_3, comfort_raw, comfort_raw_gt, speed_score,
    get_others_quality
)
from objects.missing_and_occ import (
    disappeared_suddenly, get_missing_per_scene, get_missing_per_agent
)
from objects.stability import stability_metric


# ----------------------------- Sheet row printer -----------------------------

def print_sheet_row(metrics, include_header=False):
    """
    Print a single CSV-like row that can be pasted into Google Sheets and split
    with =SPLIT(..., ","). Column order:
    Distribution → Alignment → Ego Quality → Others.

    Args:
        metrics: nested dict of metric groups (``distribution``, ``alignment``,
            ``ego_quality``, ``agent_2_ego_distribution`` and
            ``agent_quality`` / ``gt_quality``). ``None`` or missing groups
            and keys produce empty cells.
        include_header: when true, a ``=SPLIT`` row with the column names is
            printed before the value row.
    """
    m = metrics or {}
    dist   = (m.get('distribution') or {})
    align  = (m.get('alignment') or {})
    ego    = (m.get('ego_quality') or {})
    # Others: prefer agent_*, otherwise fall back to gt_quality
    others_dist = (m.get('agent_2_ego_distribution') or {})
    others_qual = (m.get('agent_quality') or m.get('gt_quality') or {})

    def first_present(group, *keys):
        """Return the first value stored under ``keys`` that is not None.

        An explicit ``is not None`` test is used (rather than ``a or b``) so
        that a legitimate metric value of 0.0 is not discarded.
        """
        for key in keys:
            value = group.get(key)
            if value is not None:
                return value
        return None

    def triple(x):
        """Split a length>=3 sequence into its first three items, else Nones."""
        if isinstance(x, np.ndarray) and x.ndim == 1:
            x = x.tolist()
        if isinstance(x, (list, tuple)) and len(x) >= 3:
            return x[0], x[1], x[2]
        return None, None, None

    # Comfort triples; both spellings of the key are accepted for every group
    ego_j, ego_a, ego_y = triple(first_present(ego, 'comfort-raw', 'comfort_raw'))
    oth_j, oth_a, oth_y = triple(first_present(others_qual, 'comfort_raw', 'comfort-raw'))

    # Curvature key compatibility
    curvature_keys = ('geometry-curvature', 'geometry_curvature', 'curvature')
    ego_cur = first_present(ego, *curvature_keys)
    oth_cur = first_present(others_qual, *curvature_keys)

    # traj_fid key compatibility
    fid_keys = ('traj_fid', 'traj_fid_mtr', 'traj_fid_metric')
    dist_fid = first_present(dist, *fid_keys)
    oth_fid  = first_present(others_dist, *fid_keys)

    cols = [
        # Distribution
        ('traj_fid_mtr', dist_fid),

        # Alignment
        ('sDTW_score', first_present(align, 'sDTW_score', 'sdtw_score')),
        ('motion_wasserstein_consistency', align.get('motion_wasserstein_consistency')),
        ('ade', align.get('ade')),
        ('fde', align.get('fde')),

        # Ego Quality
        ('jerk_per_m', ego_j),
        ('acc_per_m',  ego_a),
        ('yaw_per_m',  ego_y),
        ('speed_score', ego.get('speed_score')),
        ('curvature',  ego_cur),

        # Others Distribution to Ego
        ('others_traj_fid_mtr', oth_fid),

        # Others Quality
        ('others_jerk_per_m', oth_j),
        ('others_acc_per_m',  oth_a),
        ('others_yaw_per_m',  oth_y),
        ('others_speed_score', others_qual.get('speed_score')),
        ('others_curvature',   oth_cur),
    ]

    def fmt(v):
        """Format one cell: blank for None, 4 decimals for numbers."""
        if v is None:
            return ""
        if isinstance(v, (int, float, np.integer, np.floating)):
            return f"{v:.4f}"
        return str(v)

    headers = ",".join(k for k, _ in cols)
    values  = ",".join(fmt(v) for _, v in cols)

    if include_header:
        print(f'=SPLIT("{headers}", ",")')
    print(f'=SPLIT("{values}", ",")')


# ----------------------------- Geometry utils -------------------------------

def gt_2_ego(gt_xy, heading=None, k_ahead=1, min_step=1):
    """
    Convert global (x, y) trajectory to a local ego frame with origin at t=0.

    Args:
        gt_xy:    (T, 2) array of global positions.
        heading:  optional heading angle in radians (Python/NumPy scalar or
                  tensor). When None it is estimated from the trajectory.
        k_ahead:  frame index used for the first heading estimate.
        min_step: minimum displacement from frame 0 for that estimate to be
                  used; otherwise the first later frame that reaches it is
                  taken (or the last frame when none does).
    Returns:
        gt_local_xy: (T, 2) torch tensor in the ego frame, heading along +x
        gt_local_yx: (T, 2) NumPy array holding (-y, x) for visualization
        theta:       heading angle (rad) used for rotation, as a tensor
    """
    gt = torch.from_numpy(np.asarray(gt_xy))  # (T, 2)
    if not gt.is_floating_point():
        # norms and rotations below need a floating dtype
        gt = gt.double()

    # translate to origin
    rel_gt = gt - gt[0]

    if heading is None:
        # robust heading estimate using multiple steps if needed
        k = min(k_ahead, len(gt) - 1)
        v = gt[k] - gt[0]
        if torch.linalg.norm(v) < min_step:
            for j in range(1, len(gt)):
                v = gt[j] - gt[0]
                if torch.linalg.norm(v) >= min_step:
                    break
        theta = torch.atan2(v[1], v[0])  # angle w.r.t. +x
    elif torch.is_tensor(heading):
        theta = heading
    else:
        # plain floats are not accepted by torch.cos / torch.sin
        theta = torch.as_tensor(heading, dtype=rel_gt.dtype)

    # Right-multiplying row vectors by this matrix rotates them by -theta,
    # which puts the heading direction on the local +x axis.
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    R = torch.tensor([
        [cos_t, -sin_t],
        [sin_t,  cos_t]
    ], dtype=rel_gt.dtype)

    gt_local = torch.matmul(rel_gt, R)
    gt_local_xy = gt_local.clone()

    # visualization layout: first column -y, second column x
    gt_yx = torch.stack([-gt_local[:, 1], gt_local[:, 0]], dim=1).numpy()

    return gt_local_xy, gt_yx, theta


def ego_y_2_x(ego_xz):
    """Swap (x, z) → (z, x) with sign flip on the second component.

    For each row the first two columns (a, b) become (b, -a); any further
    columns are kept. The result is always a new float array, so the
    caller's data is never modified.
    """
    # np.array (unlike np.asarray) always copies, which keeps the in-place
    # column operations below from leaking into the input.
    out = np.array(ego_xz, dtype=float)
    out[:, [0, 1]] = out[:, [1, 0]]
    out[:, 1] = -out[:, 1]
    return out


# ----------------------------- Umeyama 2-D ----------------------------------

def umeyama_2d(X: np.ndarray, Y: np.ndarray, with_scale=True):
    """
    Similarity alignment of two matched 2-D point sets (Umeyama).

    X, Y: (N, 2) arrays of corresponding points (N ≥ 2).
    Returns (s, R, t) with R of shape (2, 2) and t of shape (2,) such that
    Y ≈ s·R·X + t. With ``with_scale=False`` the scale is fixed to 1.0.
    """
    n_pts = len(X)
    centroid_x = X.mean(0)
    centroid_y = Y.mean(0)
    dx = X - centroid_x
    dy = Y - centroid_y

    # cross-covariance of the centred point sets
    cov = dx.T @ dy / n_pts
    left, _, right_t = np.linalg.svd(cov)

    rot = right_t.T @ left.T
    if np.linalg.det(rot) < 0:
        # a negative determinant is a reflection: flip the last right
        # singular vector and rebuild the rotation
        right_t[1] *= -1
        rot = right_t.T @ left.T

    scale = 1.0
    if with_scale:
        var_x = (dx**2).sum() / n_pts
        scale = (dy * (rot @ dx.T).T).sum() / (n_pts * var_x)

    shift = centroid_y - scale * rot @ centroid_x
    return scale, rot, shift


def slam_align_to_gt_fix_origin(
    pred_xyz:   ArrayLike,      # (N, D) predicted positions
    gt_local:   ArrayLike,      # (N, D) from gt_2_ego → first frame (0,0)
    with_scale: bool = True,
) -> tuple:
    """
    Estimate scale s and rotation R (keep the first frame at the origin).

    The prediction is shifted so its first frame is the origin; ``gt_local``
    is used as given, i.e. it is expected to start at the origin already.
    No centroids are subtracted, so the rotation is about the origin.

    Args:
        pred_xyz:   predicted trajectory, same shape as ``gt_local``.
        gt_local:   reference trajectory in the local frame.
        with_scale: estimate a least-squares scale; otherwise s = 1.0.
    Returns:
        (aligned, s, R): aligned prediction (N, D), scale, rotation matrix.
    Raises:
        ValueError: if the inputs are not 2-D arrays of identical shape.
    """
    P = np.asarray(pred_xyz, float)
    Q = np.asarray(gt_local, float)
    if P.ndim != 2 or P.shape != Q.shape:
        raise ValueError(
            f"pred and gt must be 2-D arrays of the same shape, got {P.shape} and {Q.shape}"
        )

    P_rel = P - P[0]
    Q_rel = Q

    H = P_rel.T @ Q_rel / len(P_rel)
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:          # resolve reflection
        Vt[1] *= -1
        R = Vt.T @ U.T

    s = 1.0
    if with_scale:
        varP = (P_rel**2).sum() / len(P_rel)
        denom = len(P_rel) * varP
        # A stationary prediction has zero spread: keep s = 1.0 instead of
        # producing NaN from 0/0.
        if denom > 0:
            s = (Q_rel * (R @ P_rel.T).T).sum() / denom

    aligned = (s * (R @ P_rel.T)).T
    return aligned, s, R


def apply_sr(trajectory_xz, s, R):
    """Rotate every point of a trajectory by R, then scale the result by s.

    ``trajectory_xz`` is converted to a float array of row-wise points; the
    returned array has the same layout.
    """
    points = np.asarray(trajectory_xz, float)
    rotated_cols = R @ points.T      # points as columns
    return (s * rotated_cols).T


# ----------------------------- Smoothing ------------------------------------

def smooth_traj_sg(xy, dt=0.1, win_sec=0.4, poly=3):
    """
    Smooth a trajectory along its first axis with a Savitzky–Golay filter.

    The window covers ``win_sec / dt`` samples, made odd and shrunk to fit the
    trajectory length; the polynomial order is lowered when the window is too
    small for it. Inputs that leave a window shorter than 3 samples are
    returned unfiltered (as a float array).
    """
    pts = np.asarray(xy, float)
    n_frames = pts.shape[0]

    # requested window in samples, bumped to the next odd number if even
    window = int(round(win_sec / dt))
    window += 1 - window % 2

    # the window may not exceed the trajectory: largest odd length that fits
    if window > n_frames:
        window = n_frames if n_frames % 2 else n_frames - 1

    if n_frames == 4:
        # four samples always use the smallest valid window with a line fit
        window, poly = 3, 1
    elif poly >= window - 1:
        poly = max(1, window - 2)

    if window < 3:
        return pts

    return savgol_filter(pts, window_length=window, polyorder=poly, axis=0, mode="interp")


# ----------------------------- IoU helper -----------------------------------

def max_iou_box(query_box: np.ndarray, boxes: np.ndarray, return_index: bool = False):
    """
    Pick the candidate box that overlaps `query_box` the most (by IoU).

    Boxes are read as [x1, y1, x2, y2]; `boxes` holds one candidate per row.
    Returns (best_box, best_iou), extended with best_idx when
    return_index=True. Raises ValueError when `boxes` has no entries.
    """
    query = np.asarray(query_box, dtype=float)
    candidates = np.asarray(boxes, dtype=float)
    if candidates.size == 0:
        raise ValueError("`boxes` is empty.")

    # corners of the intersection rectangle with every candidate
    left   = np.maximum(query[0], candidates[:, 0])
    top    = np.maximum(query[1], candidates[:, 1])
    right  = np.minimum(query[2], candidates[:, 2])
    bottom = np.minimum(query[3], candidates[:, 3])

    # disjoint boxes give negative extents, clipped to zero overlap
    overlap = np.clip(right - left, 0, None) * np.clip(bottom - top, 0, None)

    area_query = (query[2] - query[0]) * (query[3] - query[1])
    area_cands = (candidates[:, 2] - candidates[:, 0]) * (candidates[:, 3] - candidates[:, 1])

    # the small epsilon guards against a zero union
    scores = overlap / (area_query + area_cands - overlap + 1e-8)

    winner = int(scores.argmax())
    result = (candidates[winner], float(scores[winner]))
    if return_index:
        return result + (winner,)
    return result


# ----------------------------- Visualization --------------------------------
# NOTE: draw_car(...) must be provided elsewhere in your project if draw_polygon=1.

def visualize_trajectory(trajectory, out_path, gt=None, gt_pred=None,
                         others=None, others_gt=None, map_scale=1,
                         car_length=4.5, car_width=2.0, draw_polygon=0, yaw0=0):
    """
    Plot ego / GT / other-agent trajectories on a 2x2 grid and save the figure.

    Panels: [0, 0] overlay of the tracks, [0, 1] GT tracks, [1, 0] ego of the
    generated video, [1, 1] other agents. All panels share the same limits.

    Args:
        trajectory: (T, 2) ego trajectory of the generated video.
        out_path:   file the figure is written to.
        gt:         optional (T, 2) GT ego trajectory.
        gt_pred:    optional (T, 2) ego trajectory estimated from the GT video.
        others:     optional iterable of (T_i, 2) other-agent trajectories.
        others_gt:  optional iterable of GT agent tracks; only the second
                    track whose first column changes by more than 10 between
                    its first and last frame is drawn (columns swapped).
        map_scale, yaw0: accepted but not used by this function.
        car_length, car_width: forwarded to draw_car when draw_polygon is set.
        draw_polygon: when truthy, car outlines are drawn with draw_car, which
                    is not defined in this file and must exist at module level.

    Every trajectory is converted to a float NumPy array first, so lists and
    CPU torch tensors are accepted as well.
    """
    trajectory = np.asarray(trajectory, dtype=float)
    if gt is not None:
        gt = np.asarray(gt, dtype=float)
    if gt_pred is not None:
        gt_pred = np.asarray(gt_pred, dtype=float)

    locX = trajectory[:, 0]
    locZ = trajectory[:, 1]

    mpl.rc("figure", facecolor="white")
    plt.style.use("seaborn-v0_8-whitegrid")

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8), dpi=100)

    # try/finally: the figure is released even when drawing or saving fails
    try:
        traj_main_plt = axes[0, 0]
        gt_plt        = axes[0, 1]
        ego_plt       = axes[1, 0]
        others_plt    = axes[1, 1]

        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        line_kw = dict(linewidth=1, markersize=4)

        # running extent of everything drawn so far
        bounds = {
            'min_x': np.min(locX), 'max_x': np.max(locX),
            'min_y': np.min(locZ), 'max_y': np.max(locZ),
        }

        def _cover(xs, ys):
            """Grow the shared plot extent so that (xs, ys) is inside it."""
            bounds['max_x'] = max(bounds['max_x'], np.max(xs))
            bounds['min_x'] = min(bounds['min_x'], np.min(xs))
            bounds['max_y'] = max(bounds['max_y'], np.max(ys))
            bounds['min_y'] = min(bounds['min_y'], np.min(ys))

        traj_main_plt.set_title("Trajectory (Z, X)", y=1.0)
        if draw_polygon:
            draw_car(traj_main_plt, locX, locZ, car_length, car_width, color=colors[0], zorder=1)
            draw_car(ego_plt,       locX, locZ, car_length, car_width, color=colors[0], zorder=1)

        for ax in (traj_main_plt, ego_plt):
            ax.plot(locX, locZ, ".-", label="Ego_of_gen_video", zorder=6, color=colors[0], **line_kw)

        if gt is not None:
            if draw_polygon:
                draw_car(traj_main_plt, gt[:, 0], gt[:, 1], car_length, car_width, color=colors[1], zorder=1)
                draw_car(gt_plt,        gt[:, 0], gt[:, 1], car_length, car_width, color=colors[1], zorder=1)
            for ax in (traj_main_plt, gt_plt):
                ax.plot(gt[:, 0], gt[:, 1], ".-", label="Ego_of_gt_traj", zorder=1, color=colors[1], **line_kw)
            _cover(gt[:, 0], gt[:, 1])

        if gt_pred is not None:
            if draw_polygon:
                draw_car(traj_main_plt, gt_pred[:, 0], gt_pred[:, 1], car_length, car_width, color=colors[1], zorder=1)
                draw_car(gt_plt,        gt_pred[:, 0], gt_pred[:, 1], car_length, car_width, color=colors[1], zorder=1)
            # the line itself is only shown in the GT panel
            gt_plt.plot(gt_pred[:, 0], gt_pred[:, 1], ".-", label="Ego_of_gt_video", zorder=1, color=colors[5], **line_kw)
            _cover(gt_pred[:, 0], gt_pred[:, 1])

        if others is not None:
            for idx, other in enumerate(others):
                other = np.asarray(other, dtype=float)
                # label only the first track so the legend has a single entry
                label_kw = {"label": "Others"} if idx == 0 else {}
                for ax in (traj_main_plt, others_plt):
                    ax.plot(other[:, 0], other[:, 1], ".-", zorder=3, color=colors[2], **label_kw, **line_kw)
                _cover(other[:, 0], other[:, 1])

        if others_gt is not None:
            count = 0
            for o_gt in others_gt:
                o_gt = np.asarray(o_gt, dtype=float)
                if abs(o_gt[-1, 0] - o_gt[0, 0]) <= 10:
                    continue
                count += 1
                if count != 2:
                    continue
                # columns are swapped for plotting: column 1 → X, column 0 → Z
                focal_x, focal_z = o_gt[:, 1], o_gt[:, 0]
                if draw_polygon:
                    draw_car(traj_main_plt, focal_x, focal_z, car_length, car_width, color=colors[4], zorder=1)
                    draw_car(gt_plt,        focal_x, focal_z, car_length, car_width, color=colors[4], zorder=1)
                for ax in (traj_main_plt, gt_plt):
                    ax.plot(focal_x, focal_z, ".-", label="FocalGT", zorder=2, color=colors[4], **line_kw)
                _cover(focal_x, focal_z)

        # minimum field of view: X symmetric and at least ±30, Z at least [-10, 50]
        min_x = min(-30, bounds['min_x'])
        max_x = max(30, bounds['max_x'])
        x_ = max(abs(min_x), abs(max_x))
        min_y = min(-10, bounds['min_y'])
        max_y = max(50, bounds['max_y'])

        for ax in (traj_main_plt, gt_plt, ego_plt, others_plt):
            ax.set_xlabel("X")
            ax.set_ylabel("Z")
            ax.set_xlim([-x_ - 3, x_ + 3])
            ax.set_ylim([min_y - 3, max_y + 3])
            handles, _labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend(loc=1, title="Legend", borderaxespad=0., fontsize="medium", frameon=True)

        fig.savefig(out_path)
    finally:
        plt.close(fig)


# --------------------------------- Main -------------------------------------

if __name__ == '__main__':
    # ---- command line ----
    parser = argparse.ArgumentParser(description='Unidepth')
    # open(None) would fail below, so the GT list is mandatory
    parser.add_argument('--root_path', type=str, required=True,
                        help='JSON file listing the GT scene directories')
    parser.add_argument('--outdir', type=str, default='./vis_depth',
                        help='directory holding one sub-directory per run')
    parser.add_argument('--video_path', type=str,
                        help='root of the generated videos (<run>/<split>/<action>/images)')
    parser.add_argument('--split', type=str, default='gt')
    parser.add_argument('--action', type=str, default='free')
    parser.add_argument('--smooth_slam', type=int, default=0,
                        help='non-zero: Savitzky-Golay smooth the trajectories')
    parser.add_argument('--eval_agents', type=int, default=0,
                        help='non-zero: also evaluate other agents')
    parser.add_argument('--use_droid', type=int, default=0,
                        help='0: unidepth estimates, otherwise droid_unidepth')
    args = parser.parse_args()

    split = args.split
    # the 'gt' split has no per-action sub-directory
    act_dir = args.action if split != 'gt' else ''
    # sorted: os.listdir order is arbitrary, which made scene indices unstable
    runs = sorted(r for r in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir, r)))

    method = 'unidepth' if args.use_droid == 0 else 'droid_unidepth'
    print(f'Using method: {method}')
    print(f'Total runs: {len(runs)}')

    # normpath drops a trailing slash, which previously produced an empty name
    save_base = os.path.join('outputs', os.path.basename(os.path.normpath(args.outdir)))
    os.makedirs(save_base, exist_ok=True)

    vis_base = os.path.join(save_base, f'{method}-vis-{split}_{act_dir}')
    os.makedirs(vis_base, exist_ok=True)

    preds = []
    gts = []  # local

    agents_traj = []
    agents_bbox = []
    agents_label = []

    valid_ego_runs = []
    valid_agents_runs = []

    with open(args.root_path, 'r') as f:
        gt_json = json.load(f)

    # Map "<parent dir>+<scene dir>" to the GT scene directory. Built in every
    # mode because later stages index gt_paths whenever split == 'gt'.
    gt_paths = {
        os.path.basename(os.path.dirname(gt_base)) + '+' + os.path.basename(gt_base): gt_base
        for gt_base in gt_json
    }

    if 'general' in args.outdir:
        # NOTE(review): these GT trajectories follow gt_json order, while the
        # predictions follow the run-directory order — confirm that pairwise
        # metrics are meaningful in this mode.
        for gt_base in gt_json:
            gt_path = os.path.join(gt_base, 'ego_motion_0620.npy')
            gt = np.load(gt_path, allow_pickle=True)  # global
            gt_local_xy, gt_local_yx, theta = gt_2_ego(gt[:101, :2])
            if args.smooth_slam:
                gt_local_xy = smooth_traj_sg(gt_local_xy, dt=0.1, win_sec=0.4, poly=3)
            # gt_2_ego returns a torch tensor; keep only NumPy arrays in gts
            gts.append(np.asarray(gt_local_xy, dtype=float))

    # Per run: load the estimated ego trajectory, optionally align it to GT,
    # collect agent tracks, and write a visualization.
    for run in runs:
        s_name = run
        is_align = 'align' in args.outdir

        if split == 'gt':
            log_base = os.path.join(args.outdir, s_name, split, method)
        else:
            log_base = os.path.join(args.outdir, s_name, split, act_dir, method)

        # GT for alignment track
        if is_align:
            gt_base = gt_paths.get(run)
            if gt_base is None:
                print(f"Skipping run {run}: no GT entry found.")
                continue
            gt_path = os.path.join(gt_base, 'ego_motion_0620.npy')
            gt = np.load(gt_path, allow_pickle=True)
            gt_local_xy, gt_local_yx, theta = gt_2_ego(gt[:101, :2])
            # gt_2_ego returns a torch tensor; downstream code expects NumPy
            gt_local_xy = np.asarray(gt_local_xy, dtype=float)

        # Prediction
        # NOTE: pickle / allow_pickle execute arbitrary code from the file;
        # only load outputs produced by trusted runs.
        try:
            if args.use_droid == 0:
                with open(log_base+'-estimate_ego_traj_0619.pkl', 'rb') as f:
                    data = pickle.load(f)
                    pred = data['locs'].astype(np.float32)
            else:
                pred = np.load(log_base+'-estimate_ego_traj.npy', allow_pickle=True).astype(np.float32)
            pred_xy = ego_y_2_x(pred)   # to (x, y)
        except Exception:
            print("Skipping one run due to missing files.")
            continue

        # Align heading
        if is_align:
            try:
                pred_xy, s, r = slam_align_to_gt_fix_origin(pred_xy, gt_local_xy, with_scale=False)
            except Exception as exc:
                # previously dropped into pdb, which hangs unattended runs and
                # left s / r undefined for the agent tracks below
                print(f"Skipping run {run}: alignment to GT failed ({exc}).")
                continue

            # Ego trajectory estimated from the GT video of the same scene
            gt_log_base = os.path.join(args.outdir, s_name, 'gt', method)
            try:
                if args.use_droid == 0:
                    with open(gt_log_base+'-estimate_ego_traj_0619.pkl', 'rb') as f:
                        data = pickle.load(f)
                        gt_pred = data['locs'].astype(np.float32)
                else:
                    gt_pred = np.load(gt_log_base+'-estimate_ego_traj.npy', allow_pickle=True).astype(np.float32)
                gt_pred_xy = ego_y_2_x(gt_pred)
            except Exception:
                print("Skipping one run due to missing GT files.")
                continue

            gt_pred_xy, _, _ = slam_align_to_gt_fix_origin(gt_pred_xy, gt_local_xy, with_scale=False)
            if args.smooth_slam:
                gt_pred_xy  = smooth_traj_sg(gt_pred_xy,  dt=0.1, win_sec=0.4, poly=3)
                pred_xy     = smooth_traj_sg(pred_xy,     dt=0.1, win_sec=0.4, poly=3)
                gt_local_xy = smooth_traj_sg(gt_local_xy, dt=0.1, win_sec=0.4, poly=3)

            gts.append(gt_local_xy)
        else:
            if args.smooth_slam:
                pred_xy = smooth_traj_sg(pred_xy, dt=0.1, win_sec=0.4, poly=3)
            gt_local_xy = None
            gt_pred_xy = None

        preds.append(pred_xy)

        # Agents
        vis_traj = None
        if args.eval_agents:
            try:
                with open(log_base+'-estimate_agents_traj.pkl', 'rb') as f:
                    agent_traj = pickle.load(f)
                with open(log_base+'-estimate_agents_bbox.pkl', 'rb') as f:
                    agent_bbox = pickle.load(f)
                with open(log_base+'-estimate_agents_bbox_label.pkl', 'rb') as f:
                    agent_label = pickle.load(f)
            except Exception:
                # agent outputs are optional: the run still counts for ego metrics
                agent_traj = None

            if agent_traj is not None:
                trans_agent_traj = []
                for track in agent_traj:
                    try:
                        ids, traj, traj_label = track
                    except (TypeError, ValueError):
                        print(f"Skipping a malformed agent track in run {run}.")
                        continue
                    if traj_label != 'car':
                        continue
                    if len(ids) == 0:
                        continue
                    if len(traj) != ids[-1] + 1:
                        # keep the first consecutive segment (frame ids 0, 1, 2, ...)
                        ids_valid, traj_valid = [], []
                        for i, _id in enumerate(ids):
                            if _id != i:
                                break
                            ids_valid.append(i)
                            traj_valid.append(traj[i])
                    else:
                        ids_valid = ids
                        traj_valid = traj
                    if len(traj_valid) <= 10:
                        continue
                    traj_np = ego_y_2_x(traj_valid)
                    if is_align:
                        # same rotation/scale as the ego trajectory of this run
                        traj_np = apply_sr(traj_np, s, r)
                    if args.smooth_slam:
                        traj_np = smooth_traj_sg(traj_np, dt=0.1, win_sec=1.0, poly=2)
                    trans_agent_traj.append((ids_valid, traj_np))

                # Trim fragmented box tracks (keep first segment). The result
                # is currently not stored; the original boxes are kept below.
                trans_agent_bbox = []
                for boxes in agent_bbox:
                    ids = [box[0] for box in boxes]
                    ids_valid, boxes_valid = [], []
                    for i, _id in enumerate(ids):
                        ids_valid.append(_id)
                        boxes_valid.append(boxes[i])
                        # a gap of more than 10 frames ends the first segment
                        if i < len(ids) - 1 and _id + 10 < ids[i+1]:
                            print("Trim a fragmented track.")
                            break
                    trans_agent_bbox.append(boxes_valid)

                agents_traj.append(trans_agent_traj)
                # keep original bbox list (or switch to trans_agent_bbox if preferred)
                agents_bbox.append(agent_bbox)
                agents_label.append(agent_label)
                vis_traj = [t[1] for t in trans_agent_traj]
                valid_agents_runs.append(run)

        vis_path = os.path.join(vis_base, f'{s_name}.jpg')
        visualize_trajectory(pred_xy, vis_path, gt=gt_local_xy, gt_pred=gt_pred_xy, others=vis_traj)

        valid_ego_runs.append(run)

    assert len(valid_ego_runs) == len(preds), "Mismatch in number of valid runs and predictions."

    # Stack the collected trajectories for the batched metric functions.
    preds = np.array(preds)
    gts = np.array(gts)

    # ----------------------------- Metrics -----------------------------------

    def _mean_over_leading_axes(values):
        """NaN-aware mean over every axis except the last one."""
        return np.nanmean(values, axis=tuple(range(values.ndim - 1)))

    # Distribution and alignment of predictions against GT
    traj_fid = traj_fid_mtr(preds, gts, stride=1)

    ade_per_run = ade_fn(preds, gts, reduce='none')
    fde_per_run = fde_fn(preds, gts, reduce='none')
    dyn_cons_per_run = dynamic_consistency(preds, gts, reduce='none')
    sdtw_per_run = sdtw(preds, gts, threshold=8, reduce='none')

    # Quality of the predicted ego trajectories
    comfort_per_run = comfort_score_3(preds, reduce='none')
    curvature_per_run = curvature_rms(preds, reduce='none')
    consistency_per_run = trajectory_consistency(preds, reduce='none')
    speed_per_run = speed_score(preds, reduce='none')
    comfort_raw_per_run = comfort_raw(preds, reduce='none')

    # per-run log (kept separate; not dumped)
    metric_of_runs = {}
    for idx, run_name in enumerate(valid_ego_runs):
        entry = metric_of_runs.setdefault(run_name, {})
        entry['ego_alignment'] = {
            'ade': ade_per_run[idx],
            'fde': fde_per_run[idx],
            'motion_wasserstein_consistency': dyn_cons_per_run[idx],
            'sdtw': sdtw_per_run[idx],
        }
        entry['ego_quality'] = {
            'comfort_score': comfort_per_run[idx],
            'curvature': curvature_per_run[idx],
            'consistency': consistency_per_run[idx],
            'speed_score': speed_per_run[idx],
            'comfort_raw': comfort_raw_per_run[idx]
        }

    # The same quality statistics on the GT trajectories
    gt_comfort = comfort_score_3(gts, reduce='none')
    gt_curvature = curvature_rms(gts, reduce='none')
    gt_consistency = trajectory_consistency(gts, reduce='none')
    gt_speed = speed_score(gts, reduce='none')
    gt_comfort_raw = comfort_raw(gts, reduce='none')

    # Aggregated (NaN-aware mean) summary that is printed and dumped later
    metrics = {}
    metrics['distribution'] = {
        'traj_fid': traj_fid
    }
    metrics['alignment'] = {
        'sDTW_score': np.nanmean(sdtw_per_run),
        'motion_wasserstein_consistency': np.nanmean(dyn_cons_per_run),
        'ade': np.nanmean(ade_per_run),
        'fde': np.nanmean(fde_per_run),
    }
    metrics['ego_quality'] = {
        'comfort-raw': _mean_over_leading_axes(comfort_raw_per_run),
        'consistency-velo_acc': np.nanmean(consistency_per_run),
        'speed_score': np.nanmean(speed_per_run),
        'geometry-curvature': np.nanmean(curvature_per_run),
    }
    metrics['gt_quality'] = {
        'comfort-raw': _mean_over_leading_axes(gt_comfort_raw),
        'consistency-velo_acc': np.nanmean(gt_consistency),
        'speed_score': np.nanmean(gt_speed),
        'geometry-curvature': np.nanmean(gt_curvature),
    }

    # -------------------- Others / objects metrics ---------------------------

    if args.eval_agents:
        agent_quality, scene_quality, agent_consistency = get_others_quality(agents_traj, gts)
    else:
        agent_quality = None
        scene_quality = None
        agent_consistency = None

    # Object missing rate
    if args.eval_agents:
        import shutil  # only needed for clearing the debug output folders

        def _frames_dir(run_name):
            """Directory holding the video frames of ``run_name`` for this split."""
            if split == 'gt':
                return os.path.join(gt_paths[run_name], 'CAM_F0')
            return os.path.join(args.video_path, run_name, split, act_dir, 'images')

        missing_rate = []
        num_agent = 0
        num_missing_agent = 0
        glm_dir = os.path.join(save_base, 'glm_input', split, act_dir)
        print(f'Valid runs with agents: {len(valid_agents_runs)}; agent_bbox sets: {len(agents_bbox)}')
        for scene_id, scene_agents in enumerate(agents_bbox):
            is_missing = []
            num_agent += len(scene_agents)
            gen_video_path_ = _frames_dir(valid_agents_runs[scene_id])
            gen_video_path = sorted(os.path.join(gen_video_path_, f) for f in os.listdir(gen_video_path_))
            # frame index (position in sorted order) → image path
            video_img_dict = dict(enumerate(gen_video_path))

            for agent_id, scene_agent in enumerate(scene_agents):
                # Boxes of all *other* tracks, grouped by frame. Tracks are
                # excluded by index, so an identical duplicate track is not
                # dropped in place of the current one.
                other_boxes_by_frame = {}
                for other_id, other in enumerate(scene_agents):
                    if other_id == agent_id:
                        continue
                    for f_id, bbox in other:
                        other_boxes_by_frame.setdefault(f_id, []).append(bbox)

                # NOTE(review): this path has no scene id, so agents with the
                # same index in different scenes share it — confirm intended.
                glm_this = os.path.join(glm_dir, 'scene', f'{agent_id}')
                missing = disappeared_suddenly(
                    scene_agent,
                    other_boxes_by_frame,
                    video_img_dict,
                    glm_this,
                    img_size=(1024, 576),
                    edge_margin=0.1,
                    iou_threshold=0.5,
                    min_track_len=1
                )
                is_missing.append((scene_agent, missing))
            missing_rate.append((get_missing_per_scene(is_missing), is_missing))
            num_missing_agent += get_missing_per_agent(is_missing)

        assert len(missing_rate) == len(valid_agents_runs)

        # Optional: build debug videos
        from moviepy.editor import ImageSequenceClip

        def images_to_video(image_folder, output_video, fps=30):
            """Encode the .png/.jpg files of ``image_folder`` (sorted by name) as a video."""
            images = sorted(
                os.path.join(image_folder, img)
                for img in os.listdir(image_folder)
                if img.endswith((".png", ".jpg"))
            )
            clip = ImageSequenceClip(images, fps=fps)
            clip.write_videofile(output_video, codec="libx264", verbose=False, logger=None)

        # Start from empty output folders. shutil.rmtree replaces the former
        # `rm -rf` shell string, which broke on paths with spaces or shell
        # metacharacters (split / action come from the command line).
        missing_dir_video = os.path.join(save_base, 'missing_debug_video', split, act_dir)
        shutil.rmtree(missing_dir_video, ignore_errors=True)
        os.makedirs(missing_dir_video, exist_ok=True)
        missing_dir = os.path.join(save_base, 'missing_debug', split, act_dir)
        shutil.rmtree(missing_dir, ignore_errors=True)
        os.makedirs(missing_dir, exist_ok=True)

        for scene_idx, ms in enumerate(missing_rate):
            _, is_missing = ms
            # all (frame id, box) entries of the tracks flagged as missing
            m_bbox = []
            for box, missing_flag in is_missing:
                if missing_flag:
                    m_bbox.extend(box)

            missing_dir_this = os.path.join(missing_dir, f'scene_{scene_idx}')
            shutil.rmtree(missing_dir_this, ignore_errors=True)
            os.makedirs(missing_dir_this, exist_ok=True)

            img_dir = _frames_dir(valid_agents_runs[scene_idx])

            boxes_dict = {}
            for f_id, bbox in m_bbox:
                boxes_dict.setdefault(int(f_id), []).append(bbox)

            for fn in os.listdir(img_dir):
                img = cv2.imread(os.path.join(img_dir, fn))
                if img is None:
                    # not a readable image (cv2.imread returns None)
                    continue
                try:
                    frame_id = int(os.path.splitext(fn)[0])
                except ValueError:
                    # file name is not a frame number: copy it without boxes
                    frame_id = None
                for (x1, y1, x2, y2) in boxes_dict.get(frame_id, []):
                    # OpenCV needs integer pixel coordinates
                    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                cv2.imwrite(os.path.join(missing_dir_this, fn), img)

            images_to_video(missing_dir_this, os.path.join(missing_dir_video, f"scene_{scene_idx}.mp4"), fps=10)

        missing_rate_vals = [m[0] for m in missing_rate]
        missing_rate_vals = np.nanmean(missing_rate_vals, axis=0)

        metrics['objects-feasibility'] = {}
        metrics['objects-feasibility']['object_missing_error_scene'] = missing_rate_vals
        metrics['objects-feasibility']['object_missing_error_agent'] = num_missing_agent / max(1, num_agent)

    # Agent blur/stability (optional)
    if args.eval_agents:
        print(f'Runs for stability check: {len(valid_agents_runs)}; agent_bbox sets: {len(agents_bbox)}')

        # Placeholder for a scene without any usable track. Each track yields
        # three values (R, A, S); assumes they are scalars — TODO confirm.
        no_stability = np.full(3, np.nan)

        scenes_stability = []
        for scene_id, scene_agents in enumerate(agents_bbox):
            labels_this_scene = agents_label[scene_id]

            # Label keys have the form "x1-y1-x2-y2"
            candidate_box = []
            for bbox in labels_this_scene:
                c_x1, c_y1, c_x2, c_y2 = (int(x) for x in bbox.split('-'))
                candidate_box.append([c_x1, c_y1, c_x2, c_y2])
            candidate_box = np.array(candidate_box).astype(np.int32)

            if candidate_box.size == 0:
                # nothing to match against (max_iou_box would raise)
                print('Skip a scene without labelled boxes.')
                scenes_stability.append(no_stability)
                continue

            scene_stability = []
            for scene_agent in scene_agents:
                if len(scene_agent) < 2:
                    print('Skip a very short track.')
                    continue
                # match the first box of the track to its label
                frame_0_bbox = scene_agent[0][1]
                agent_bbox_np = np.array(frame_0_bbox).astype(np.int32)
                match_box, _ = max_iou_box(agent_bbox_np, candidate_box)
                match_box = match_box.astype(np.int32)
                key = f'{match_box[0]}-{match_box[1]}-{match_box[2]}-{match_box[3]}'
                label = labels_this_scene[key]

                if split == 'gt':
                    img_dir = os.path.join(gt_paths[valid_agents_runs[scene_id]], 'CAM_F0')
                else:
                    img_dir = os.path.join(args.video_path, valid_agents_runs[scene_id], split, act_dir, 'images')

                obj_stability = stability_metric(img_dir, scene_agent, label)
                first_sim = obj_stability['R']
                adj_sim = obj_stability['A']
                text_rel_sim = obj_stability['S']
                scene_stability.append([first_sim, adj_sim, text_rel_sim])

            if scene_stability:
                scenes_stability.append(np.nanmean(np.array(scene_stability), axis=0))
            else:
                # An empty list used to reduce to a scalar NaN, which made the
                # array built below ragged.
                scenes_stability.append(no_stability)

        scenes_stability = np.array(scenes_stability)
        scenes_stability = np.nanmean(scenes_stability, axis=0)
        metrics.setdefault('objects-feasibility', {})
        metrics['objects-feasibility']['object_stability'] = scenes_stability

    # Merge the optional agent metric groups into the summary.
    for extra_metrics in (agent_quality, scene_quality, agent_consistency):
        if extra_metrics is not None:
            metrics.update(extra_metrics)

    # ----------------------------- to-native & dump --------------------------

    def _native(x):
        """Recursively convert NumPy scalars/arrays to Python native types.

        Dicts, lists and tuples are walked as well, so values at any nesting
        depth become serializable by ``json.dump``.
        """
        if isinstance(x, np.generic):
            return x.item()
        if isinstance(x, np.ndarray):
            return x.item() if x.size == 1 else x.tolist()
        if isinstance(x, dict):
            return {
                (k.item() if isinstance(k, np.generic) else k): _native(v)
                for k, v in x.items()
            }
        if isinstance(x, list):
            return [_native(v) for v in x]
        if isinstance(x, tuple):
            return tuple(_native(v) for v in x)
        return x

    metrics = _native(metrics)

    print(f"Metrics for {split} - {act_dir}:")
    for key, subdict in metrics.items():
        if not isinstance(subdict, dict):
            # merged agent metrics are not guaranteed to be nested dicts
            print(f"{key}: {subdict}")
            continue
        print(f"{key}:")
        for sub_key, sub_val in subdict.items():
            if isinstance(sub_val, (float, int)):
                print(f"  {sub_key}: {sub_val:.4f}")
            else:
                print(f"  {sub_key}: {sub_val}")

    metrics_path = os.path.join(save_base, f'{method}-{split}-{act_dir}-smooth{args.smooth_slam}.json')
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=4)

    print_sheet_row(metrics, include_header=False)
