# run_agents_depth_track.py  — sanitized & parameterized
import argparse
import glob
import json
import math
import os
import random
import re

import cv2
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLOv10

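# NOTE: these UniDepth imports are module-level, so they execute when this file is
# loaded -- before main() appends --unidepth_repo to sys.path. The UniDepth package
# must therefore already be importable (e.g. via PYTHONPATH) when the script starts.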
from UniDepth.unidepth.utils import colorize  # image_grid unused
from UniDepth.scripts.dataset import *  # noqa: F403
from UniDepth.scripts.vo import *       # noqa: F403

depth_model = None
det_model = None


def init_depth_model(unidepth_repo: str,
                     version: str = "v2",
                     backbone: str = "vitl14",
                     device: str | None = None):
    """Initialize UniDepth from a local repo via torch.hub."""
    import sys
    if unidepth_repo not in sys.path:
        sys.path.append(unidepth_repo)

    global depth_model
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    depth_model = torch.hub.load(
        unidepth_repo,
        "UniDepth",
        version=version,
        backbone=backbone,
        pretrained=True,
        trust_repo=True,
        force_reload=True,
        source="local",
    )

    # UniDepth v2 specific settings
    depth_model.resolution_level = 0
    depth_model.interpolation_mode = "bilinear"
    depth_model = depth_model.to(torch.device(device))


def init_det_model(weights_path: str, device: str | None = None):
    """Build the YOLOv10 detector and publish it as the module-level ``det_model``.

    Args:
        weights_path: Passed unchanged to the ``YOLOv10`` constructor.
        device: ``None`` selects ``"cuda"`` when ``torch.cuda.is_available()`` and
            ``"cpu"`` otherwise. The model is moved explicitly only for ``"cuda"``;
            for any other value it stays wherever the constructor placed it.

    Returns:
        None. The detector is stored in the ``det_model`` global.
    """
    global det_model

    wants_cuda = torch.cuda.is_available() if device is None else device == "cuda"

    det_model = YOLOv10(weights_path)
    if wants_cuda:
        det_model = det_model.cuda()


def set_task_list(base_path, local_rank, gt_json=None, split="gt", action="free", all_id=8):
    """Build this worker's shard of image directories to process.

    The full scene list is shuffled with a fixed seed, split into ``WORLD_SIZE``
    chunks (one per ``RANK``, both read from the environment, defaulting to a
    single rank), and the rank's chunk is split again into ``all_id`` pieces, of
    which piece ``local_rank % all_id`` is returned.

    Args:
        base_path: Root directory whose entries are scenes (used when
            ``split != "gt"``).
        local_rank: Index of the local worker within the node.
        gt_json: Path of a JSON file holding the scene list (used when
            ``split == "gt"``). Presumably a JSON array of scene paths -- the
            content is shuffled in place, so it must be a mutable sequence.
        split: ``"gt"`` reads scenes from ``gt_json``; any other value builds
            ``<base_path>/<entry>/<split>/<action>`` for each directory entry.
        action: Sub-directory name used for non-``"gt"`` splits.
        all_id: Number of local workers each rank's chunk is divided among.

    Returns:
        List of paths: ``<scene>/CAM_F0`` for the ``"gt"`` split, otherwise
        ``<scene>/images``.

    Note:
        This reseeds the global ``random`` module state with 2026 on every call.
    """
    if split == "gt":
        with open(gt_json, "r") as f:
            scenes = json.load(f)
    else:
        # os.listdir() returns entries in arbitrary order. Every worker must
        # shuffle the *same* sequence for the shards to be disjoint and
        # complete, so sort the listing before building scene paths.
        scenes = [
            os.path.join(base_path, entry, split, action)
            for entry in sorted(os.listdir(base_path))
        ]

    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Fixed seed: all workers derive an identical permutation.
    random.seed(2026)
    random.shuffle(scenes)

    rank_chunk = np.array_split(scenes, world_size)[global_rank]
    local_chunk = np.array_split(rank_chunk, all_id)[local_rank % all_id]

    leaf = "CAM_F0" if split == "gt" else "images"
    runs = [os.path.join(scene, leaf) for scene in local_chunk]

    print(f"Total {len(runs)} data to process, local index {local_rank}")
    return runs


def reconstruct_global_trajectory(pixel_centers, depth_values, Ks, camera_poses, delta_d=None):
    """Back-project per-frame pixel centers to world coordinates.

    For frame ``i`` the pixel ``(x, y)`` is lifted to the camera frame as
    ``inv(Ks[i]) @ [x, y, 1] * depth`` and mapped to the world frame with
    ``R @ cam + T``, where ``(R, T) = camera_poses[i]``.

    Args:
        pixel_centers: Sequence of ``(x, y)`` pixel coordinates, one per frame.
        depth_values: Sequence of depths, one per frame. Never modified.
        Ks: Per-frame 3x3 intrinsic matrices, indexed like ``pixel_centers``.
        camera_poses: Per-frame ``(R, T)`` pairs with ``R`` of shape (3, 3) and
            ``T`` of shape (3,).
        delta_d: Optional per-frame offsets added to the depth before
            back-projection.

    Returns:
        List of ``[x, z]`` pairs: the first and last components of each world
        coordinate.
    """
    global_trajectory = []
    for i, ((x_c, y_c), Z_c) in enumerate(zip(pixel_centers, depth_values)):
        # Build a new value instead of using ``+=``: when the elements of
        # ``depth_values`` are ndarrays, an in-place add would silently modify
        # the caller's data.
        depth = Z_c if delta_d is None else Z_c + float(delta_d[i])

        ray = np.linalg.inv(Ks[i]) @ np.array([x_c, y_c, 1.0])
        cam_coord = ray * depth

        R, T = camera_poses[i]   # R: (3,3), T: (3,)
        world_coord = R @ cam_coord + T
        global_trajectory.append([world_coord[0], world_coord[-1]])  # x, z
    return global_trajectory


def det_obj(frame):
    """Run the module-level detector on one frame and summarise its boxes.

    Args:
        frame: Image array of shape (H, W, ...); passed directly to ``det_model``.
            NOTE(review): Ultralytics documents numpy-array inputs as BGR
            (OpenCV convention) while this pipeline describes the frame as RGB
            -- confirm which channel order callers actually pass.

    Returns:
        Tuple ``(mask, centers, bboxes, label_map)``:
          * ``mask``: uint8 (H, W) array, 255 everywhere except inside the box
            of every detection whose class id is in ``movable`` (set to 0,
            regardless of confidence).
          * ``centers``: ``[mid_x, mid_y]`` integer box centers for detections
            whose class id is in ``lookup`` and whose confidence is >= 0.3.
          * ``bboxes``: matching ``[x1, y1, x2, y2]`` integer boxes.
          * ``label_map``: dict keyed by ``"x1-y1-x2-y2"`` giving the class name.
    """
    global det_model
    result = det_model(frame)[0]

    boxes = result.boxes
    h, w = frame.shape[:2]
    mask = np.full((h, w), 255, dtype=np.uint8)

    # Class ids whose boxes are blanked out of the mask (presumably COCO
    # indices for people, vehicles and animals -- verify against the weights).
    movable = {0, 1, 2, 3, 4, 5, 6, 7, 8, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}
    # Subset of classes that are additionally reported with a label.
    lookup = {0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"}

    centers = []
    bboxes = []
    label_map = {}

    if boxes is None or len(boxes) == 0:
        return mask, centers, bboxes, label_map

    # Move everything to host memory once instead of issuing a device
    # transfer per box; ``tolist()`` yields the same Python floats ``.item()`` did.
    corners = boxes.data[:, :4].cpu().numpy()
    class_ids = boxes.cls.cpu().tolist()
    scores = boxes.conf.cpu().tolist()

    for xyxy, raw_cls, score in zip(corners, class_ids, scores):
        cls_id = int(raw_cls)
        if cls_id not in movable:
            continue

        x1, y1, x2, y2 = map(int, xyxy)
        mask[y1:y2, x1:x2] = 0

        if cls_id in lookup and score >= 0.3:
            mid_x = int((x1 + x2) / 2)
            mid_y = int((y1 + y2) / 2)
            centers.append([mid_x, mid_y])
            bboxes.append([x1, y1, x2, y2])
            label_map[f"{x1}-{y1}-{x2}-{y2}"] = lookup[cls_id]

    return mask, centers, bboxes, label_map


def estimate_depth_from_mask(depth, mask, method="median_top", use_percentile=False):
    """Reduce the depth values under a mask to a single robust estimate.

    Args:
        depth: 2-D depth map indexed like ``mask``.
        mask: 2-D mask; pixels with a value > 0 are selected.
        method: Name controlling the reduction, matched by substring:
            containing ``"top"`` restricts the samples to the upper half of the
            mask's vertical extent; containing ``"mean"`` averages the samples,
            otherwise the median is used (e.g. ``'median_top'``, ``'mean_top'``,
            ``'full_median'``, ``'mean'``).
        use_percentile: When true, discard samples outside the 2nd-98th
            percentile range before reducing.

    Returns:
        The estimate as a Python float in the units of ``depth``, or ``None``
        when the mask selects no pixels.
    """
    if not np.any(mask):
        return None

    ys, _ = np.where(mask)
    selected = mask > 0

    if "top" in method:
        y_min, y_max = ys.min(), ys.max()
        threshold = y_min + (y_max - y_min + 1) * 0.5
        # Keep only rows strictly above the vertical midpoint of the mask.
        row_index = np.arange(mask.shape[0])[:, None]
        selected = selected & (row_index < threshold)

    valid = depth[selected]
    if len(valid) == 0:
        return None

    if use_percentile:
        lo, hi = np.percentile(valid, [2, 98])
        trimmed = valid[(valid >= lo) & (valid <= hi)]
        # With very few distinct samples (e.g. two), the interpolated
        # percentiles fall strictly between them and the trim would drop
        # everything, yielding NaN. Fall back to the untrimmed samples.
        if trimmed.size:
            valid = trimmed

    if "mean" in method:
        return float(np.mean(valid))
    return float(np.median(valid))


def images_to_video(image_folder, output_video, fps=30):
    """Encode the images found in ``image_folder`` into a video file.

    Files whose lower-cased name ends in ``.png``, ``.jpg`` or ``.jpeg`` are
    collected, ordered by path, and written to ``output_video`` with the
    ``libx264`` codec at ``fps`` frames per second.
    """
    from moviepy.editor import ImageSequenceClip

    image_suffixes = (".png", ".jpg", ".jpeg")
    frame_paths = sorted(
        os.path.join(image_folder, entry)
        for entry in os.listdir(image_folder)
        if entry.lower().endswith(image_suffixes)
    )

    sequence = ImageSequenceClip(frame_paths, fps=fps)
    sequence.write_videofile(output_video, codec="libx264", verbose=False, logger=None)


def main():
    parser = argparse.ArgumentParser(description="Agents depth & tracking (sanitized)")
    parser.add_argument("--root_path", type=str, required=True)
    parser.add_argument("--outdir", type=str, default="./vis_depth")
    parser.add_argument("--gt_meta_path", type=str, default="")
    parser.add_argument("--split", type=str, default="gt", choices=["gt", "gen"])
    parser.add_argument("--action", type=str, default="free")
    parser.add_argument("--local_id", type=int, default=0)
    parser.add_argument("--all_id", type=int, default=8)

    # New: configurable repos/weights and optional device & intrinsics meta root
    parser.add_argument("--unidepth_repo", type=str, required=True,
                        help="Local UniDepth repo path for torch.hub and python imports.")
    parser.add_argument("--yolo_weights", type=str, required=True,
                        help="YOLOv10 weights path (e.g., yolov10x.pt).")
    parser.add_argument("--samurai_repo", type=str, required=True,
                        help="Local Samurai repo path providing samurai.scripts.demo.samurai_main.")
    parser.add_argument("--intrinsics_meta_root", type=str, default="",
                        help="Optional: if set, load intrinsics from this meta root instead of predicted.")
    parser.add_argument("--device", type=str, default=None, choices=[None, "cpu", "cuda"],
                        help="Override device; default auto-detect.")
    args = parser.parse_args()

    # Prepare import paths
    import sys
    if args.unidepth_repo not in sys.path:
        sys.path.append(args.unidepth_repo)
    if args.samurai_repo not in sys.path:
        sys.path.append(args.samurai_repo)

    # Deferred import after sys.path updated
    from samurai.scripts.demo import samurai_main  # noqa: E402

    print("Torch version:", torch.__version__)
    init_depth_model(args.unidepth_repo, version="v2", backbone="vitl14", device=args.device)
    init_det_model(args.yolo_weights, device=args.device)

    runs = set_task_list(
        args.root_path, int(args.local_id), args.gt_meta_path, args.split, args.action, args.all_id
    )

    split = args.split
    act_dir = args.action

    for run in runs:
        print(f"{args.local_id}: {run}")

        filenames = glob.glob(os.path.join(run, "**", "*"), recursive=True)
        filenames.sort()

        rgbs = []
        depths_raw = []
        points_3d = []
        rgb_intrinsics = []
        masks = []

        # initialize in outer scope for later usage
        person_car_bbox = []
        person_car_label = {}

        if split == "gt":
            s_name = run.split("/")[-3] + "+" + run.split("/")[-2]
            log_base = os.path.join(args.outdir, s_name, split, "unidepth")
        else:
            s_name = run.split("/")[-4]
            log_base = os.path.join(args.outdir, s_name, split, act_dir, "unidepth")

        depth_out_dir = os.path.join(log_base, "depth_frame")
        os.makedirs(depth_out_dir, exist_ok=True)

        seg_out_dir = os.path.join(log_base, "track_frame")
        os.makedirs(seg_out_dir, exist_ok=True)

        # Depth & intrinsics inference
        for k, filename in enumerate(filenames):
            if "conds_20" in args.action and k < 19:  # 20th as the first (81 total)
                continue

            rgb = np.array(Image.open(filename))

            if k == 0:
                mask, person_car, person_car_bbox, person_car_label = det_obj(rgb)
            else:
                mask, _, _, _ = det_obj(rgb)
            masks.append(mask)
