import argparse
import glob
import os
import re
import math
import json
import random

import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt  # (kept for parity; not directly used)

# UniDepth (repo added via CLI arg)
# YOLOv10
from ultralytics import YOLOv10

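# NOTE: the UniDepth imports below execute at module load time, i.e. before main() appends
# --unidepth_repo to sys.path. The directory that contains the `UniDepth` package must therefore
# already be importable (e.g. via PYTHONPATH) when this script starts.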
# from UniDepth.unidepth.models import UniDepthV1, UniDepthV2
from UniDepth.unidepth.utils import colorize, image_grid  # noqa: F401 (image_grid not used)
from UniDepth.scripts.dataset import *  # noqa: F403
from UniDepth.scripts.vo import *       # noqa: F403

# Module-level model handles; both start unset and are filled in by
# init_depth_model() / init_det_model() before the pipeline runs.
depth_model = det_model = None


def init_depth_model(unidepth_repo: str,
                     version: str = "v2",
                     backbone: str = "vitl14",
                     device: str | None = None):
    """
    Initialize UniDepth model from a local repo via torch.hub.
    """
    global depth_model
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    depth_model = torch.hub.load(
        unidepth_repo,
        "UniDepth",
        version=version,
        backbone=backbone,
        pretrained=True,
        trust_repo=True,
        force_reload=True,
        source="local",
    )
    # UniDepth v2 specific settings
    depth_model.resolution_level = 0
    depth_model.interpolation_mode = "bilinear"
    depth_model = depth_model.to(torch.device(device))


def init_det_model(weights_path: str, device: str | None = None):
    """
    Build the YOLOv10 detector from ``weights_path`` and publish it as the
    module-level ``det_model``.

    When ``device`` is ``None`` CUDA is chosen if available, otherwise CPU.
    The model is moved to the GPU only when the resolved device is exactly
    ``"cuda"``; for any other value it is left where YOLOv10 created it.
    """
    global det_model
    resolved = device
    if resolved is None:
        resolved = "cuda" if torch.cuda.is_available() else "cpu"

    det_model = YOLOv10(weights_path)
    wants_gpu = resolved == "cuda"
    if wants_gpu:
        det_model = det_model.cuda()


def set_task_list(base_path, local_rank, gt_json=None, split="gt", action="free", all_id=8):
    """
    Build the list of run directories this worker should process.

    Scenes are shuffled with a fixed seed, split across ``WORLD_SIZE`` global
    ranks (env ``RANK`` selects the chunk), and that chunk is split again into
    ``all_id`` pieces of which ``local_rank % all_id`` is returned.

    Args:
        base_path: Root directory listed for scenes when ``split != "gt"``.
        local_rank: Index of this worker within its global-rank chunk.
        gt_json: Path to a JSON file whose entries are scene paths; required
            when ``split == "gt"``.
        split: ``"gt"`` reads scenes from ``gt_json``; any other value builds
            ``<base_path>/<scene>/<split>/<action>`` for every entry in
            ``base_path``.
        action: Sub-directory name used for non-GT splits.
        all_id: Number of local shards per global rank.

    Returns:
        list[str]: ``<scene>/CAM_F0`` for the GT split, ``<scene>/images``
        otherwise.

    Raises:
        ValueError: if ``split == "gt"`` and ``gt_json`` is empty/``None``.
    """
    if split == "gt":
        if not gt_json:
            raise ValueError("gt_json is required when split == 'gt'")
        with open(gt_json, "r") as f:
            scenes = list(json.load(f))
        leaf = "CAM_F0"
    else:
        # os.listdir() returns entries in arbitrary order. Sort first so that
        # the seeded shuffle below yields the SAME permutation in every worker
        # process; otherwise shards could overlap or drop scenes.
        # e.g., <base>/<scene>/<split>/<action>
        scenes = [
            os.path.join(base_path, name, split, action)
            for name in sorted(os.listdir(base_path))
        ]
        leaf = "images"

    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Fixed seed: every worker must compute the identical shuffle.
    random.seed(2026)
    random.shuffle(scenes)

    rank_chunk = np.array_split(scenes, world_size)[global_rank]
    local_chunk = np.array_split(rank_chunk, all_id)[local_rank % all_id]

    # array_split yields numpy string scalars; hand plain str back to callers.
    runs = [os.path.join(str(scene), leaf) for scene in local_chunk]

    print(f"Total {len(runs)} data to process, local index {local_rank}")
    return runs


def reconstruct_global_trajectory(pixel_centers, depth_values, Ks, camera_poses, delta_d=None):
    """
    Back-project pixel centers with per-frame intrinsics and depths, then
    transform them with the per-frame camera pose.

    Args:
        pixel_centers: Iterable of ``(x, y)`` pixel coordinates, one per frame.
        depth_values: Iterable of depths, one per frame.
        Ks: Indexable of 3x3 intrinsic matrices (``Ks[i]`` is inverted).
        camera_poses: Indexable of ``(R, T)`` pairs; the point is mapped as
            ``R @ p + T``.
        delta_d: Optional per-frame depth offsets added to ``depth_values``.

    Returns:
        list: one array ``[x, z]`` (first and last component of the
        transformed point) per frame.

    The inputs are never modified. Iteration stops at the shorter of
    ``pixel_centers`` and ``depth_values``.
    """
    global_trajectory = []
    for i, ((x_c, y_c), Z_c) in enumerate(zip(pixel_centers, depth_values)):
        # Build a new value instead of ``Z_c += ...``: when the depth is an
        # ndarray element, the in-place add would silently change the
        # caller's depth_values.
        depth = Z_c if delta_d is None else Z_c + float(delta_d[i])

        K_inv = np.linalg.inv(Ks[i])
        pixel_coord = np.array([x_c, y_c, 1.0])
        cam_coord = (K_inv @ pixel_coord) * depth

        R, T = camera_poses[i]   # R: (3,3), T: (3,)
        world_coord = R @ cam_coord + T
        global_trajectory.append(world_coord[[0, -1]])  # x, z
    return global_trajectory


# Detector class ids treated as movable (numbering assumed to follow the COCO
# label map used by the stock YOLO weights — confirm for custom weights).
_MOVABLE_CLASS_IDS = frozenset({
    0, 1, 2, 3, 4, 5, 6, 7, 8,   # people & vehicles & boat
    14, 15, 16, 17, 18, 19, 20, 21, 22, 23,  # animals
})
_CAR_CLASS_ID = 2  # class 2: car


def det_obj(frame):
    """
    Run the detector on a single RGB frame.

    Args:
        frame: ``H x W x 3`` RGB image as a numpy array.

    Returns:
        tuple ``(mask, car_centers, car_bboxes)``:
        - mask: ``uint8`` array of shape ``(H, W)``; 255 everywhere except
          inside boxes of movable classes, which are set to 0.
        - car_centers: list of ``[x, y]`` integer box centers for cars.
        - car_bboxes: list of ``[x1, y1, x2, y2]`` integer boxes for cars.
    """
    # Ultralytics interprets raw numpy input as BGR (OpenCV convention), so
    # swap channels before inference. Non 3-channel input is passed through
    # untouched. The copy must be contiguous because the reversed view has
    # negative strides.
    if frame.ndim == 3 and frame.shape[2] == 3:
        det_input = np.ascontiguousarray(frame[:, :, ::-1])
    else:
        det_input = frame
    result = det_model(det_input)[0]

    h, w = frame.shape[:2]
    mask = np.full((h, w), 255, dtype=np.uint8)
    car_centers = []
    car_bboxes = []

    boxes = result.boxes
    if boxes is None or len(boxes) == 0:
        return mask, car_centers, car_bboxes

    # Move the tensors to host memory once instead of once per box.
    class_ids = boxes.cls.cpu().numpy().astype(int).tolist()
    corners = boxes.data[:, :4].cpu().numpy().astype(int).tolist()

    for cls_id, (x1, y1, x2, y2) in zip(class_ids, corners):
        if cls_id not in _MOVABLE_CLASS_IDS:
            continue

        # Clamp at 0: a negative index would otherwise be read as
        # "from the end" and blank the wrong region.
        mask[max(y1, 0):max(y2, 0), max(x1, 0):max(x2, 0)] = 0

        if cls_id == _CAR_CLASS_ID:
            car_centers.append([int((x1 + x2) / 2), int((y1 + y2) / 2)])
            car_bboxes.append([x1, y1, x2, y2])

    return mask, car_centers, car_bboxes


def drive_roi_mask(h, w, keep=0.5, side=0.03):
    """
    Build an ``(h, w)`` uint8 mask that is 255 inside a rectangular ROI and 0
    elsewhere.

    The ROI spans rows ``0 .. int(h * keep)`` (i.e. the upper part of the
    image) and columns ``int(w * side) .. int(w * (1 - side))``.
    """
    roi = np.zeros((h, w), np.uint8)
    row_end = int(h * keep)
    col_start = int(w * side)
    col_end = int(w * (1 - side))
    roi[:row_end, col_start:col_end] = 255
    return roi


def images_to_video(image_folder, output_video, fps=30):
    """
    Encode the PNG/JPEG files found directly inside ``image_folder`` into an
    H.264 video at ``output_video``.

    Frames are ordered by their sorted path; ``fps`` sets the frame rate.
    moviepy is imported lazily so it is only needed when a video is written.
    """
    from moviepy.editor import ImageSequenceClip

    frame_suffixes = (".png", ".jpg", ".jpeg")
    frame_paths = sorted(
        os.path.join(image_folder, entry)
        for entry in os.listdir(image_folder)
        if entry.lower().endswith(frame_suffixes)
    )

    sequence = ImageSequenceClip(frame_paths, fps=fps)
    sequence.write_videofile(output_video, codec="libx264", verbose=False, logger=None)


def main():
    """
    Command-line entry point: per-frame depth estimation followed by
    visual-odometry trajectory estimation for every run assigned to this
    worker.

    For each run directory the function
      1. runs the detector and UniDepth on every frame, writes one colourised
         depth PNG per frame into ``<log_base>/depth_frame`` and a
         ``depths.mp4`` preview, and
      2. passes frames, depths, predicted intrinsics and feature masks to the
         VO helpers (``DatasetHandler``, ``extract_features_dataset``, ... —
         expected to come from the star-imported UniDepth script modules) and
         pickles ``{"locs", "poses_3x3"}`` to
         ``<log_base>-estimate_ego_traj_0619.pkl``.

    Runs without frames, and runs whose trajectory estimation raises, are
    reported and skipped.
    """
    import pickle
    import sys

    parser = argparse.ArgumentParser(description="Depth & VO pipeline (sanitized)")
    parser.add_argument("--root_path", type=str, required=True)
    parser.add_argument("--outdir", type=str, default="./vis_depth")
    parser.add_argument("--gt_meta_path", type=str, default="")
    parser.add_argument("--split", type=str, default="gt", choices=["gt", "gen"])
    parser.add_argument("--action", type=str, default="free")
    parser.add_argument("--local_id", type=int, default=0)
    parser.add_argument("--all_id", type=int, default=8)

    # Newly added: configurable paths to avoid hard-coded personal paths
    parser.add_argument("--unidepth_repo", type=str, required=True,
                        help="Local UniDepth repo path for torch.hub (e.g., /path/to/UniDepth)")
    parser.add_argument("--yolo_weights", type=str, required=True,
                        help="Path to YOLOv10 weights (e.g., /path/to/yolov10x.pt)")
    parser.add_argument("--device", type=str, default=None, choices=[None, "cpu", "cuda"], help="Override device")

    args = parser.parse_args()

    # Fail fast, before the (slow) model loading: the GT split reads its scene
    # list from the metadata file.
    if args.split == "gt" and not args.gt_meta_path:
        parser.error("--gt_meta_path is required when --split is 'gt'")

    # add UniDepth repo to sys.path so that utils/scripts imports work
    if args.unidepth_repo not in sys.path:
        sys.path.append(args.unidepth_repo)

    print("Torch version:", torch.__version__)
    init_depth_model(args.unidepth_repo, version="v2", backbone="vitl14", device=args.device)
    init_det_model(args.yolo_weights, device=args.device)

    runs = set_task_list(
        args.root_path, int(args.local_id), args.gt_meta_path, args.split, args.action, args.all_id
    )

    split = args.split
    act_dir = args.action
    # Actions containing "conds_20" skip the first 19 files of every run
    # (presumably conditioning frames — confirm with the data layout).
    skip_leading = "conds_20" in args.action

    for run in runs:
        print(f"{args.local_id}: {run}")

        # The recursive glob also matches directories; keep regular files only
        # so Image.open() is never handed a directory.
        filenames = sorted(
            path
            for path in glob.glob(os.path.join(run, "**", "*"), recursive=True)
            if os.path.isfile(path)
        )

        rgbs = []
        depths_raw = []
        rgb_intrinsics = []
        masks = []

        if split == "gt":
            s_name = run.split("/")[-3] + "+" + run.split("/")[-2]
            log_base = os.path.join(args.outdir, s_name, split, "unidepth")
        else:
            s_name = run.split("/")[-4]
            log_base = os.path.join(args.outdir, s_name, split, act_dir, "unidepth")

        depth_out_dir = os.path.join(log_base, "depth_frame")
        os.makedirs(depth_out_dir, exist_ok=True)

        for k, filename in enumerate(filenames):
            if skip_leading and k < 19:
                continue

            # Force 3-channel RGB: grayscale / RGBA / palette inputs would
            # otherwise break the HWC -> CHW permute below. The context
            # manager closes the file handle promptly.
            with Image.open(filename) as img:
                rgb = np.array(img.convert("RGB"))

            mask, _car_centers, _car_bboxes = det_obj(rgb)
            # Combine the detector mask with the fixed ROI mask.
            h, w = rgb.shape[:2]
            mask_wh = drive_roi_mask(h, w)
            mask[mask_wh == 0] = 0
            masks.append(mask)

            rgbs.append(rgb.copy())
            rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)

            # UniDepth inference
            predictions = depth_model.infer(rgb_torch, None)
            depth = predictions["depth"].squeeze().cpu().numpy()
            depths_raw.append(depth.copy())

            intrinsics = predictions["intrinsics"].squeeze(0).cpu().numpy()
            rgb_intrinsics.append(intrinsics)

            depth_pred_col = colorize(depth, cmap="jet")
            # NOTE(review): cv2.imwrite assumes BGR channel order — confirm
            # what colorize() returns.
            cv2.imwrite(os.path.join(depth_out_dir, f"{k:05}.png"), depth_pred_col)

        if not rgbs:
            # Nothing to encode or track; the video writer and VO would fail.
            print(f"No frames found for {run}; skipping")
            continue

        # Write depth visualization video
        images_to_video(depth_out_dir, os.path.join(log_base, "depths.mp4"), fps=10)

        # VO pipeline
        dataset_handler = DatasetHandler(rgbs, depths_raw, rgb_intrinsics)
        check_data(dataset_handler, args.outdir)

        images = dataset_handler.images
        kp_list, des_list = extract_features_dataset(images, masks)
        matches, matches2 = match_features_dataset(des_list)

        # Optionally filter matches
        dist_threshold = 0.7
        matches = filter_matches_dataset(matches, dist_threshold, matches2)

        depth_maps = dataset_handler.depth_maps
        try:
            trajectory, poses_3x3 = estimate_trajectory(
                matches, kp_list, dataset_handler.k, depth_maps=depth_maps, dataset_handler=dataset_handler
            )
        except Exception as e:
            # Best effort: one failing run must not abort the whole shard.
            print(f"VO failed for {run}: {e}")
            continue

        # Keep components 0 and 2 of every trajectory column.
        locs = np.array(
            [
                [trajectory[:, i].item(0), trajectory[:, i].item(2)]
                for i in range(trajectory.shape[1])
            ],
            dtype=np.float32,
        )

        ego_traj = {"locs": locs, "poses_3x3": poses_3x3}
        os.makedirs(os.path.dirname(log_base), exist_ok=True)
        with open(log_base + "-estimate_ego_traj_0619.pkl", "wb") as f:
            pickle.dump(ego_traj, f)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
