"""
FoundationStereo High-Fidelity Depth Extractor

This script converts rectified, distortion-free stereo image pairs (generated by `rectify_stereo.py`)
into dense, high-quality disparity maps, intended as ground truth, using the SysCV FoundationStereo model.

Unlike monocular methods, which have to guess the scene scale, FoundationStereo computes metric disparity by matching image features between the two rectified (parallel) views.

It handles:
1. Distributing the pairs over a pool of workers on multiple GPUs, so that the work is not serialized by the GIL.
2. Masking sky/void regions, whose depth is effectively infinite, so that they do not propagate artifacts.
3. Statistically clipping or masking outlier disparities to handle isolated matching failures.

Input layout (under --input_path):
    pair1/left/<left_rectified>.png
    pair1/right/<right_rectified>.png
    ...

Usage:
    python create_fs_outputs.py \
        --input_path /path/to/rectified_images \
        --output_path /path/to/foundation_out \
        --foundation_repo /path/to/FoundationStereo \
        --ckpt /path/to/checkpoint/model_best_bp2.pth \
        --scale 0.5 --valid_iters 12 --get_pc 0
"""

from __future__ import annotations

import argparse
import os
import re
import sys
from typing import Iterable, List, Optional, Tuple, Dict, Set
import numpy as np
import cv2

try:
    from imageio import v3 as iio  # modern imageio
except Exception:  # pragma: no cover
    import imageio as iio  # type: ignore
from concurrent.futures import ThreadPoolExecutor, as_completed

# Ensure local imports work even if launched from another directory
sys.path.insert(0, os.path.dirname(__file__))

from run_foundation_stereo import run_foundation_stereo  # type: ignore
from masking import (
    sky_mask_watershed,
    sky_mask_from_disparity_biggest_component,
    top_row_sky_mask_biggest_component,
    apply_mask_and_save,
)
try:
    from masking import sky_mask_sam2  # type: ignore
    _HAS_SAM2 = True
except Exception:
    _HAS_SAM2 = False


ALLOWED_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")


try:
    from tqdm.auto import tqdm  # type: ignore
    _HAS_TQDM = True
except Exception:
    _HAS_TQDM = False


def _natural_pair_index(name: str) -> Optional[int]:
    """Return N for a name of the form ``pairN`` (case-insensitive), else ``None``.

    The whole name has to match: no prefix, suffix or separator is tolerated.
    """
    match = re.fullmatch(r"pair(\d+)", name, flags=re.IGNORECASE)
    if match is None:
        return None
    return int(match.group(1))


def _iter_pair_dirs(
    root: str,
    *,
    start: Optional[int],
    end: Optional[int],
    include: Optional[Set[int]] = None,
) -> List[Tuple[int, str]]:
    """List the ``pairN`` sub-directories of ``root``, ordered by their index N.

    Args:
        root: Folder that is scanned (non-recursively) for ``pairN`` entries.
        start: Smallest index to keep (inclusive); ``None`` disables the bound.
        end: Largest index to keep (inclusive); ``None`` disables the bound.
        include: Explicit set of indices to keep. When given, ``start`` and
            ``end`` are ignored.

    Returns:
        ``(index, path)`` tuples sorted by index. Entries sharing an index
        (e.g. ``pair1`` and ``PAIR1``) are further ordered by path so the
        result does not depend on the directory listing order.

    Raises:
        NotADirectoryError: If ``root`` is missing or is not a directory.
    """
    if not os.path.isdir(root):
        raise NotADirectoryError(f"--input_path does not exist or is not a directory: {root}")

    pairs: List[Tuple[int, str]] = []
    for name in os.listdir(root):
        idx = _natural_pair_index(name)
        if idx is None:
            continue
        if include is not None:
            if idx not in include:
                continue
        elif (start is not None and idx < start) or (end is not None and idx > end):
            continue
        path = os.path.join(root, name)
        # A regular file that merely happens to be called "pairN" is not a pair folder.
        if not os.path.isdir(path):
            continue
        pairs.append((idx, path))

    # Sorting the full tuple breaks index ties by path, keeping the order deterministic.
    pairs.sort()
    return pairs


def _find_single_image(folder: str) -> str:
    """Return the path of the one image in ``folder`` that should be used.

    Only regular files whose extension (case-insensitive) is listed in
    ``ALLOWED_EXTS`` are considered. A file whose stem ends in ``_rect``
    (case-insensitive) takes precedence; without any such file the folder has
    to contain exactly one image.

    Raises:
        FileNotFoundError: ``folder`` is missing or holds no image.
        RuntimeError: The choice is ambiguous (several ``*_rect`` files, or
            several images and none of them ``*_rect``).
    """
    if not os.path.isdir(folder):
        raise FileNotFoundError(f"Missing folder: {folder}")

    candidates: List[str] = [
        os.path.join(folder, entry)
        for entry in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, entry))
        and os.path.splitext(entry)[1].lower() in ALLOWED_EXTS
    ]
    if not candidates:
        raise FileNotFoundError(f"No images found in {folder}")

    rect_hits = [
        path
        for path in candidates
        if os.path.splitext(os.path.basename(path))[0].lower().endswith("_rect")
    ]
    if len(rect_hits) > 1:
        listing = "\n  ".join(os.path.basename(path) for path in rect_hits)
        raise RuntimeError(f"Ambiguous: multiple '*_rect' images in {folder}:\n  " + listing)
    if rect_hits:
        return rect_hits[0]

    # No '*_rect' file at all: only an unambiguous single image is acceptable.
    if len(candidates) != 1:
        raise RuntimeError(
            f"Ambiguous: expected exactly one rectified image in {folder}, found {len(candidates)} images"
        )
    return candidates[0]


def _ensure_dir(path: str) -> None:
    """Create ``path`` including missing parents; an existing directory is not an error."""
    os.makedirs(name=path, exist_ok=True)


def main(argv: Optional[List[str]] = None) -> int:
    p = argparse.ArgumentParser(description="Batch-create FoundationStereo outputs for pairN/left/right inputs")

    p.add_argument("--input_path", required=True, help="Root folder containing pairN/left and pairN/right subdirs")
    p.add_argument("--output_path", required=True, help="Root output folder; per-pair outputs go to output/pairN/")

    # FoundationStereo config
    p.add_argument("--foundation_repo", required=True, help="Path to local FoundationStereo repo")
    p.add_argument("--ckpt", required=True, help="Checkpoint path or directory for FoundationStereo")
    p.add_argument("--scale", type=float, default=0.5, help="run_demo.py --scale")
    p.add_argument("--valid_iters", type=int, default=12, help="run_demo.py --valid_iters")
    p.add_argument("--get_pc", type=int, default=0, help="run_demo.py --get_pc")
    p.add_argument("--hiera", type=int, default=0, help="run_demo.py --hiera (hierarchical inference)")

    # Depth map options (optional)
    p.add_argument("--focal_px", type=float, default=None, help="Camera focal length in pixels (for depth computation)")
    p.add_argument("--baseline_m", type=float, default=None, help="Stereo baseline in meters (for depth computation)")
    p.add_argument(
        "--depth_scale",
        type=float,
        default=1000.0,
        help="Scale for saving PNG depth (e.g., 1000 -> millimeters). Only used if focal and baseline are given.",
    )
    p.add_argument(
        "--depth_color_pmin",
        type=float,
        default=0.0,
        help="Lower percentile for depth color PNG scaling (ignored values are NaN/<=0)",
    )
    p.add_argument(
        "--depth_color_pmax",
        type=float,
        default=100.0,
        help="Upper percentile for depth color PNG scaling (ignored values are NaN/<=0)",
    )
    p.add_argument(
        "--disp_percentile",
        type=float,
        default=None,
        help="If set (e.g., 0.1), eliminate lower/upper percentiles from disparity when computing depth (set outside to NaN).",
    )
    # Outlier masking (preferred over percentile clamping): detect very small disparities (far depths) as outliers per-image
    p.add_argument(
        "--disp_outlier_method",
        type=str,
        default="none",
        choices=["none", "logdepth_iqr", "lof", "logdepth_otsu", "depth_hist_valley", "depth_25m_global"],
        help=(
            "Outlier masking method for far depths (very small disparities). "
            "'logdepth_iqr' uses Tukey high fence on -log(disp). "
            "'lof' uses Local Outlier Factor on a downscaled grid over features [x,y,-log(disp)]. "
            "'logdepth_otsu' finds a valley split in the -log(disp) histogram via Otsu thresholding. "
            "'depth_hist_valley' finds the primary and a secondary peak in the depth histogram and thresholds at the valley between them. "
            "'depth_25m_global' simply masks all depths > 25m (recommended simple approach)."
        ),
    )
    p.add_argument(
        "--disp_outlier_k",
        type=float,
        default=1.5,
        help="Fence multiplier k for outlier method (e.g., 1.5 for standard Tukey fence, 3.0 for conservative).",
    )
    p.add_argument(
        "--disp_outlier_lof_neighbors",
        type=int,
        default=35,
        help="n_neighbors for LOF (effective on downscaled grid of valid pixels).",
    )
    p.add_argument(
        "--disp_outlier_lof_contamination",
        type=float,
        default=0.02,
        help="Expected fraction of outliers for LOF (0<contamination<0.5).",
    )
    p.add_argument(
        "--disp_outlier_lof_use_xy",
        type=int,
        default=1,
        help="Whether to include normalized (x,y) spatial coordinates in LOF features (1=yes, 0=no).",
    )
    p.add_argument(
        "--disp_outlier_lof_target_points",
        type=int,
        default=80000,
        help="Target number of pixels for LOF via downscaling (performance control).",
    )
    p.add_argument(
        "--disp_outlier_restrict_low",
        type=int,
        default=1,
        help="When 1, only keep outliers whose disparity is below the global median (focus on far-depth spikes).",
    )
    p.add_argument(
        "--depth_hist_plateau_margin_pct",
        type=float,
        default=3.0,
        help=(
            "For depth_hist_valley: margin percentage to shift the threshold to the right of the plateau start. "
            "Recommended 1-5 (%)."
        ),
    )
    # No depth clipping; PNGs are saved from raw depths scaled by --depth_scale.

    # (Removed) No intrinsic file will be written; depth uses provided focal_px/baseline_m uniformly.

    # Disparity PNG options
    p.add_argument(
        "--disp_scale",
        type=float,
        default=256.0,
        help="Scale to convert disparity to uint16 PNG (u16 = disp * disp_scale).",
    )

    # Range/behavior
    p.add_argument("--start_pair", type=int, default=None, help="Start index (inclusive), e.g., 1")
    p.add_argument("--end_pair", type=int, default=None, help="End index (inclusive), e.g., 168")
    p.add_argument(
        "--pairs",
        type=str,
        default=None,
        help=(
            "Comma-separated list of pair indices or ranges to process (overrides start/end). "
            "Examples: --pairs 1,5,10-12"
        ),
    )
    p.add_argument("--overwrite", action="store_true", help="Re-run even if output pair folder already exists")
    p.add_argument(
        "--gpus",
        type=str,
        default="",
        help=(
            "Comma-separated GPU ordinals to use for parallel runs (relative to current CUDA_VISIBLE_DEVICES). "
            "Example: --gpus 0,1,2. Leave empty to run serially."
        ),
    )

    # Masking options
    p.add_argument(
        "--mask_methods",
        type=str,
        default="watershed,disp_component",
        help=(
            "Comma-separated masking methods to run: "
            "watershed, disp_component, top_row_disp_component, top_row or sam2"
        ),
    )
    p.add_argument(
        "--max_workers",
        type=int,
        default=None,
        help=(
            "Maximum concurrent workers. Defaults to number of GPUs when --gpus is set, otherwise 1."
        ),
    )

    p.add_argument(
        "--sky_threshold",
        type=int,
        default=25,
        help="Intensity threshold for top-row sky masking (0-255). Default 25.",
    )

    args = p.parse_args(argv)

    input_root = os.path.abspath(args.input_path)
    output_root = os.path.abspath(args.output_path)
    _ensure_dir(output_root)

    def _parse_pairs(spec: Optional[str]) -> Optional[Set[int]]:
        """Turn a ``--pairs`` value such as ``"1,5,10-12"`` into a set of indices.

        Spaces and empty items are ignored. A range ``a-b`` is inclusive and may
        be given in either order. Returns ``None`` when ``spec`` is empty or
        ``None``; exits with ``SystemExit`` on an item that cannot be parsed.
        """
        if not spec:
            return None

        selected: Set[int] = set()
        for piece in spec.replace(" ", "").split(","):
            piece = piece.strip()
            if not piece:
                continue
            if "-" not in piece:
                try:
                    selected.add(int(piece))
                except Exception:
                    raise SystemExit(f"Invalid index in --pairs: '{piece}'")
                continue
            try:
                first, last = (int(part) for part in piece.split("-", 1))
                selected.update(range(min(first, last), max(first, last) + 1))
            except Exception:
                raise SystemExit(f"Invalid range in --pairs: '{piece}'")
        return selected

    include_pairs = _parse_pairs(args.pairs)

    pairs = _iter_pair_dirs(input_root, start=args.start_pair if include_pairs is None else None, end=args.end_pair if include_pairs is None else None, include=include_pairs)
    if not pairs:
        raise SystemExit("No pairN directories found under input_path.")

    if include_pairs is not None:
        print(f"Found {len(pairs)} selected pairs under {input_root} (from --pairs)")
    else:
        print(f"Found {len(pairs)} pairs under {input_root}")

    jobs: List[Tuple[int, str, str, str, str]] = []  # (idx, pair_dir, left_img, right_img, out_dir)
    post_only_jobs: List[Tuple[int, str, str, str, bool, bool]] = []  # (idx, out_dir, disp_path, left_img, need_disp, need_depth)
    successes = 0
    failures = 0

    for idx, pair_dir in pairs:
        pair_name = f"pair{idx}"
        left_dir = os.path.join(pair_dir, "left")
        right_dir = os.path.join(pair_dir, "right")
        try:
            left_img = _find_single_image(left_dir)
            right_img = _find_single_image(right_dir)
        except Exception as e:
            if _HAS_TQDM:
                tqdm.write(f"[SKIP] {pair_name}: {e}")
            else:
                print(f"[SKIP] {pair_name}: {e}")
            failures += 1
            continue

        out_dir = os.path.join(output_root, pair_name)
        if os.path.isdir(out_dir) and not args.overwrite:
            disp_candidate = os.path.join(out_dir, "disp.npy")
            if os.path.isfile(disp_candidate):
                # check missing disp artifacts
                need_disp = not (os.path.isfile(os.path.join(out_dir, "disp_u16.png")) and os.path.isfile(os.path.join(out_dir, "invalid_mask.png")))
                # check missing depth artifacts
                need_depth = (args.focal_px is not None and args.baseline_m is not None)
                if need_depth:
                    depth_npy = os.path.join(out_dir, "depth_m.npy")
                    depth_png_main = os.path.join(out_dir, "depth_mm.png" if abs((args.depth_scale or 1000.0) - 1000.0) < 1e-3 else "depth_scaled.png")
                    if os.path.isfile(depth_npy) and os.path.isfile(depth_png_main):
                        need_depth = False

                post_only_jobs.append((idx, out_dir, disp_candidate, left_img, need_disp, need_depth))
                if _HAS_TQDM:
                    tqdm.write(f"[TODO] {pair_name}: postprocess existing disp -> disp:{need_disp} depth:{need_depth} + mask")
                else:
                    print(f"[TODO] {pair_name}: postprocess existing disp -> disp:{need_disp} depth:{need_depth} + mask")
                continue

                if _HAS_TQDM:
                    tqdm.write(f"[SKIP] {pair_name}: output exists at {out_dir}")
                else:
                    print(f"[SKIP] {pair_name}: output exists at {out_dir}")
                successes += 1
                continue

        _ensure_dir(out_dir)
        jobs.append((idx, pair_dir, left_img, right_img, out_dir))

    if not jobs:
        print("No jobs to run after skipping; exiting.")
        return 0 if failures == 0 else 1

    # Parallelization config
    gpu_list: List[str] = []
    if args.gpus.strip():
        gpu_list = [g.strip() for g in args.gpus.split(",") if g.strip() != ""]
    max_workers = args.max_workers if args.max_workers is not None else (len(gpu_list) if gpu_list else 1)
    max_workers = max(1, min(max_workers, len(jobs)))

    total_tasks = len(jobs) + len(post_only_jobs)
    if _HAS_TQDM:
        pbar = tqdm(total=total_tasks, desc="Pairs", unit="task")
    else:
        pbar = None

    def _save_depth_histogram_valley_debug(
        centers: np.ndarray,
        counts: np.ndarray,
        *,
        p1_idx: int,
        p2_idx: Optional[int],
        valley_idx: int,
        out_dir: str,
        png_name: str = "depth_hist_valley_debug.png",
        csv_name: str = "depth_hist_valley_debug.csv",
        plateau_start_idx: Optional[int] = None,
        plateau_end_idx: Optional[int] = None,
    ) -> None:
        """Dump a histogram as CSV and as an annotated line plot (best effort).

        ``centers``/``counts`` are written to ``out_dir/csv_name`` and drawn on a
        420x1000 white canvas saved as ``out_dir/png_name``. Vertical markers are
        drawn for ``p1_idx`` ("P1"), ``p2_idx`` ("P2"), ``valley_idx``
        ("Threshold") and, when given, the plateau start/end bins; any index
        outside ``[0, len(centers))`` is silently left out. Every failure is
        reported as a ``[WARN]`` line instead of being raised.
        """
        try:
            # Raw numbers first, so they survive even if the rendering fails.
            with open(os.path.join(out_dir, csv_name), "w") as fh:
                fh.write("depth_center_m,count\n")
                fh.writelines(f"{float(c):.6f},{int(n)}\n" for c, n in zip(centers, counts))

            black = (0, 0, 0)
            font = cv2.FONT_HERSHEY_SIMPLEX
            canvas_h, canvas_w = 420, 1000
            canvas = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
            x_lo, x_hi = 60, canvas_w - 20
            y_top, y_base = 30, canvas_h - 50
            span_x = x_hi - x_lo
            span_y = y_base - y_top
            peak = max(int(np.max(counts)) if counts.size > 0 else 1, 1)
            n_bins = len(centers)

            def _col(i: int) -> int:
                # Pixel column of bin i; the max() guards the single-bin case.
                return int(x_lo + (i / max(1, n_bins - 1)) * span_x)

            # Axes.
            cv2.line(canvas, (x_lo, y_base), (x_hi, y_base), black, 1)
            cv2.line(canvas, (x_lo, y_base), (x_lo, y_top), black, 1)

            # X ticks, roughly ten of them.
            for i in range(0, n_bins, max(1, n_bins // 10)):
                cx = _col(i)
                cv2.line(canvas, (cx, y_base), (cx, y_base + 5), black, 1)
                cv2.putText(canvas, f"{centers[i]:.1f}", (cx - 10, y_base + 20), font, 0.4, black, 1, cv2.LINE_AA)

            # Y ticks at 0, 1/4, ..., 4/4 of the tallest bin.
            for j in range(5):
                tick_val = int(j * peak / 4)
                cy = int(y_base - (tick_val / peak) * span_y)
                cv2.line(canvas, (x_lo - 5, cy), (x_lo, cy), black, 1)
                cv2.putText(canvas, str(tick_val), (5, cy + 4), font, 0.4, black, 1, cv2.LINE_AA)

            # Histogram curve.
            curve = [[_col(i), int(y_base - (int(n) / peak) * span_y)] for i, n in enumerate(counts)]
            if len(curve) >= 2:
                cv2.polylines(canvas, [np.array(curve, dtype=np.int32)], False, (50, 100, 220), 2)

            # Peak / threshold markers; a missing second peak maps to -1 and is thereby skipped.
            markers = (
                (p1_idx, (0, 180, 0), "P1"),
                (-1 if p2_idx is None else p2_idx, (0, 0, 180), "P2"),
                (valley_idx, black, "Threshold"),
            )
            for bin_i, colour, text in markers:
                if not (0 <= bin_i < n_bins):
                    continue
                cx = _col(bin_i)
                cy = int(y_base - (float(counts[bin_i]) / peak) * span_y)
                cv2.line(canvas, (cx, y_base), (cx, y_top), colour, 1)
                cv2.circle(canvas, (cx, cy), 4, colour, -1)
                cv2.putText(canvas, text, (cx + 5, y_top + 15), font, 0.5, colour, 1, cv2.LINE_AA)

            # Optional plateau bounds, labelled on separate text rows.
            plateau_marks = (
                (plateau_start_idx, (0, 160, 0), "plateau_start", 35),
                (plateau_end_idx, (0, 0, 200), "plateau_end", 55),
            )
            for bin_i, colour, text, text_dy in plateau_marks:
                if bin_i is not None and 0 <= bin_i < n_bins:
                    cx = _col(bin_i)
                    cv2.line(canvas, (cx, y_base), (cx, y_top), colour, 1)
                    cv2.putText(canvas, text, (cx + 5, y_top + text_dy), font, 0.5, colour, 1, cv2.LINE_AA)

            cv2.putText(canvas, "Depth histogram (valley threshold)", (x_lo, 22), font, 0.6, black, 2, cv2.LINE_AA)
            cv2.imwrite(os.path.join(out_dir, png_name), canvas)
        except Exception as e:
            print(f"[WARN] Failed to save depth histogram debug: {e}")

    def _save_disp_artifacts(disp_path: str, out_dir: str) -> Tuple[str, str]:
        """Derive the disparity PNG and the invalid-pixel mask from a ``disp.npy``.

        A pixel is invalid when its disparity is non-finite or <= 0, or when the
        matching column ``x - disp`` lies outside ``[0, W)``. Invalid pixels are
        stored as 0 in ``disp_u16.png`` (uint16, ``disp * args.disp_scale``
        clipped to 65535; a missing or non-positive scale falls back to 256).
        The mask is written as ``invalid_mask.png`` (0/255) and
        ``invalid_mask.npy`` (bool).

        Returns (disp_png_path, invalid_mask_path).

        Raises:
            RuntimeError: If ``disp_path`` cannot be loaded.
        """
        try:
            loaded = np.load(disp_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load disparity at {disp_path}: {e}")

        disp_map = loaded.astype(np.float32, copy=False)
        # Strip a singleton leading or trailing axis (1xHxW or HxWx1).
        if disp_map.ndim == 3 and disp_map.shape[0] == 1:
            disp_map = disp_map[0]
        if disp_map.ndim == 3 and disp_map.shape[-1] == 1:
            disp_map = disp_map[..., 0]

        height, width = disp_map.shape[:2]
        # Column index of every pixel, as an (H, W) view without materialising a grid.
        col_idx = np.broadcast_to(np.arange(width), (height, width))
        right_col = col_idx - disp_map
        usable = np.isfinite(disp_map) & (disp_map > 0) & (right_col >= 0) & (right_col < width)
        invalid = ~usable

        cleaned = disp_map.copy()
        cleaned[invalid] = 0.0
        if args.disp_scale and args.disp_scale > 0:
            factor = float(args.disp_scale)
        else:
            factor = 256.0
        as_u16 = np.clip(cleaned * factor, 0, 65535).astype(np.uint16)

        disp_png = os.path.join(out_dir, "disp_u16.png")
        mask_png = os.path.join(out_dir, "invalid_mask.png")
        iio.imwrite(disp_png, as_u16)
        iio.imwrite(mask_png, (invalid.astype(np.uint8) * 255))
        # The boolean mask is additionally kept as .npy for downstream use.
        np.save(os.path.join(out_dir, "invalid_mask.npy"), invalid.astype(bool))
        return disp_png, mask_png

    def _save_depth_color_png(depth_m: np.ndarray, out_dir: str, filename: str) -> Optional[str]:
        """Write a TURBO-colormapped rendering of ``depth_m`` to ``out_dir/filename``.

        The color scale spans the full min..max range of the valid pixels
        (finite and > 0); all other pixels are painted black. Returns the
        written path, or ``None`` after printing a warning if anything fails.
        """
        try:
            usable = np.isfinite(depth_m) & (depth_m > 0)
            if np.any(usable):
                # Full valid range, no percentile trimming.
                lo = float(np.min(depth_m[usable]))
                hi = float(np.max(depth_m[usable]))
                if not np.isfinite(hi) or hi <= lo:
                    hi = lo + 1e-6  # constant image: avoid dividing by zero
                unit = np.clip((np.nan_to_num(depth_m, nan=0.0) - lo) / (hi - lo), 0.0, 1.0)
                rendered = cv2.applyColorMap((unit * 255).astype(np.uint8), cv2.COLORMAP_TURBO)
                rendered[~usable] = (0, 0, 0)
            else:
                rendered = np.zeros((*depth_m.shape, 3), dtype=np.uint8)
            target = os.path.join(out_dir, filename)
            cv2.imwrite(target, rendered)
            return target
        except Exception as e:
            print(f"[WARN] failed to save depth color PNG '{filename}': {e}")
            return None

    def _save_depth_distribution_plot(depth_m: np.ndarray, out_dir: str, filename_png: str, filename_csv: str, *, bin_width: float = 0.5, pmax: float = 100.0) -> None:
        """Save the histogram of valid depths as a CSV and as a line plot (best effort).

        Valid pixels are the finite, positive entries of ``depth_m``. They are
        binned from 0 up to their maximum (rounded up to a multiple of the bin
        width) in steps of ``bin_width``; a non-positive ``bin_width`` falls back
        to 0.5. Without valid pixels a placeholder image and a header-only CSV
        are written. ``pmax`` is accepted but not used by this implementation.
        Every failure is reported as a ``[WARN]`` line instead of being raised.
        """
        try:
            black = (0, 0, 0)
            font = cv2.FONT_HERSHEY_SIMPLEX
            usable = np.isfinite(depth_m) & (depth_m > 0)
            canvas_h, canvas_w = 420, 1000
            canvas = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
            cv2.rectangle(canvas, (0, 0), (canvas_w - 1, canvas_h - 1), (230, 230, 230), 1)
            png_path = os.path.join(out_dir, filename_png)
            csv_path = os.path.join(out_dir, filename_csv)

            if not np.any(usable):
                cv2.putText(canvas, "No valid depth values", (30, 220), font, 0.8, black, 2, cv2.LINE_AA)
                cv2.imwrite(png_path, canvas)
                with open(csv_path, "w") as fh:
                    fh.write("depth_center_m,count\n")
                return

            # Bin edges cover the full valid range, no percentile trimming.
            samples = depth_m[usable]
            width_m = float(bin_width) if bin_width > 0 else 0.5
            upper = np.ceil(float(np.max(samples)) / width_m) * width_m
            counts, edges = np.histogram(samples, bins=np.arange(0.0, upper + width_m * 0.5, width_m))
            centers = (edges[:-1] + edges[1:]) * 0.5

            with open(csv_path, "w") as fh:
                fh.write("depth_center_m,count\n")
                fh.writelines(f"{c:.6f},{int(n)}\n" for c, n in zip(centers, counts))

            x_lo, x_hi = 60, canvas_w - 20
            y_top, y_base = 30, canvas_h - 50
            span_x = x_hi - x_lo
            span_y = y_base - y_top
            peak = max(int(np.max(counts)) if counts.size > 0 else 1, 1)
            n_bins = len(centers)

            def _col(i: int) -> int:
                # Pixel column of bin i; the max() guards the single-bin case.
                return int(x_lo + (i / max(1, n_bins - 1)) * span_x)

            # Axes.
            cv2.line(canvas, (x_lo, y_base), (x_hi, y_base), black, 1)
            cv2.line(canvas, (x_lo, y_base), (x_lo, y_top), black, 1)

            # X ticks, roughly ten of them.
            for i in range(0, n_bins, max(1, n_bins // 10)):
                cx = _col(i)
                cv2.line(canvas, (cx, y_base), (cx, y_base + 5), black, 1)
                cv2.putText(canvas, f"{centers[i]:.1f}", (cx - 10, y_base + 20), font, 0.4, black, 1, cv2.LINE_AA)

            # Y ticks at 0, 1/4, ..., 4/4 of the tallest bin.
            for j in range(5):
                tick_val = int(j * peak / 4)
                cy = int(y_base - (tick_val / peak) * span_y)
                cv2.line(canvas, (x_lo - 5, cy), (x_lo, cy), black, 1)
                cv2.putText(canvas, str(tick_val), (5, cy + 4), font, 0.4, black, 1, cv2.LINE_AA)

            # Histogram curve.
            curve = [[_col(i), int(y_base - (int(n) / peak) * span_y)] for i, n in enumerate(counts)]
            if len(curve) >= 2:
                cv2.polylines(canvas, [np.array(curve, dtype=np.int32)], False, (50, 100, 220), 2)

            cv2.putText(canvas, "Depth distribution (counts)", (x_lo, 20), font, 0.7, black, 2, cv2.LINE_AA)
            cv2.imwrite(png_path, canvas)
        except Exception as e:
            print(f"[WARN] Failed to save depth distribution plot: {e}")

    def _save_depth_from_disp(disp_path: str, out_dir: str) -> Optional[Tuple[str, Optional[str]]]:
        """If focal_px and baseline_m provided, compute and save depth.

        Returns tuple of (depth_npy_path, depth_png_path or None) if saved, else None.
        """
        if args.focal_px is None or args.baseline_m is None:
            return None

        try:
            disp = np.load(disp_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load disparity at {disp_path}: {e}")


        disp = disp.astype(np.float32, copy=False)

        if disp.ndim == 3 and disp.shape[0] == 1:
            disp = disp[0]
        if disp.ndim == 3 and disp.shape[-1] == 1:
            disp = disp[..., 0]

        # Build disparity used for depth with optional percentile cropping
        disp_for_depth = disp.copy()
        valid = np.isfinite(disp_for_depth) & (disp_for_depth > 0)
        if args.disp_percentile is not None and (args.disp_outlier_method == "none"):
            try:
                p = float(args.disp_percentile)
                if 0 < p < 50 and np.any(valid):
                    lo = np.percentile(disp_for_depth[valid], p)
                    hi = np.percentile(disp_for_depth[valid], 100 - p)

                    disp_for_depth = np.where(np.isfinite(disp_for_depth), np.clip(disp_for_depth, lo, hi), disp_for_depth)
                    valid = np.isfinite(disp_for_depth) & (disp_for_depth > 0)
            except Exception as e:
                print(f"[WARN] disp_percentile cropping failed: {e}")

        # If percentile clipping was applied, save the clipped disparity for analysis
        if args.disp_percentile is not None and (args.disp_outlier_method == "none"):
            try:
                np.save(os.path.join(out_dir, "disp_clipped.npy"), disp_for_depth)
                disp_vis = disp_for_depth.copy()
                inv = ~(np.isfinite(disp_vis) & (disp_vis > 0))
                disp_vis[inv] = 0.0
                dscale = float(args.disp_scale) if args.disp_scale and args.disp_scale > 0 else 256.0
                disp_u16 = np.clip(disp_vis * dscale, 0, 65535).astype(np.uint16)
                iio.imwrite(os.path.join(out_dir, "disp_u16_clipped.png"), disp_u16)
            except Exception as e:
                print(f"[WARN] failed to save clipped disparity artifacts: {e}")


        depth_m = np.full_like(disp_for_depth, np.nan, dtype=np.float32)
        depth_m[valid] = (args.focal_px * args.baseline_m) / disp_for_depth[valid]

        # Optional: detect low-disparity outliers (far depths); save mask for later combination
        if args.disp_outlier_method != "none":
            try:
                # Valid disparity pixels
                v = np.isfinite(disp) & (disp > 0)
                # Exclude invalid_mask if available
                try:
                    inv_path = os.path.join(out_dir, "invalid_mask.npy")
                    if os.path.isfile(inv_path):
                        invm = np.load(inv_path).astype(bool)
                        if invm.shape != disp.shape:
                            invm = cv2.resize(
                                invm.astype(np.uint8), (disp.shape[1], disp.shape[0]), interpolation=cv2.INTER_NEAREST
                            ).astype(bool)
                        v = v & (~invm)
                except Exception:
                    pass

                if np.any(v):
                    om = np.zeros_like(disp, dtype=bool)
                    method = args.disp_outlier_method
                    if method == "logdepth_iqr":

                        ld = -np.log(np.clip(disp[v], 1e-12, None))
                        q1 = float(np.percentile(ld, 25))
                        q3 = float(np.percentile(ld, 75))
                        iqr = max(q3 - q1, 1e-6)
                        fence = q3 + float(args.disp_outlier_k) * iqr
                        thr_disp = float(np.exp(-fence))  # ld > fence  <=> disp < exp(-fence)
                        om[v] = disp[v] < thr_disp
                        stat_lines = [
                            "method=logdepth_iqr\n",
                            f"k={args.disp_outlier_k}\n",
                            f"q1={q1}\nq3={q3}\n",
                            f"iqr={iqr}\n",
                            f"fence={fence}\n",
                            f"thr_disp={thr_disp}\n",
                        ]
                    elif method == "logdepth_otsu":

                        ld = -np.log(np.clip(disp[v], 1e-12, None)).astype(np.float32)

                        ld_lo = float(np.percentile(ld, 1))
                        ld_hi = float(np.percentile(ld, 99))
                        if not np.isfinite(ld_hi) or ld_hi <= ld_lo:
                            ld_hi = ld_lo + 1e-6
                        ld_u8 = np.clip((ld - ld_lo) / (ld_hi - ld_lo), 0.0, 1.0)
                        ld_u8 = (ld_u8 * 255.0 + 0.5).astype(np.uint8)
                        # cv2.threshold expects an image; Otsu returns threshold in [0,255]
                        _ret, th = cv2.threshold(ld_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

                        ld_thr = ld_lo + (float(th) / 255.0) * (ld_hi - ld_lo)
                        thr_disp = float(np.exp(-ld_thr))
                        om[v] = disp[v] < thr_disp
                        stat_lines = [
                            "method=logdepth_otsu\n",
                            f"ld_lo={ld_lo}\nld_hi={ld_hi}\n",
                            f"otsu_th_u8={th}\n",
                            f"ld_thr={ld_thr}\n",
                            f"thr_disp={thr_disp}\n",
                        ]
                    elif method == "lof":
                        try:
                            from sklearn.neighbors import LocalOutlierFactor
                        except Exception as e:
                            print(f"[WARN] sklearn not available for LOF, falling back to logdepth_iqr: {e}")

                            ld = -np.log(np.clip(disp[v], 1e-12, None))
                            q1 = float(np.percentile(ld, 25))
                            q3 = float(np.percentile(ld, 75))
                            iqr = max(q3 - q1, 1e-6)
                            fence = q3 + float(args.disp_outlier_k) * iqr
                            thr_disp = float(np.exp(-fence))
                            om[v] = disp[v] < thr_disp
                            stat_lines = [
                                "method=logdepth_iqr(fallback)\n",
                                f"k={args.disp_outlier_k}\n",
                                f"q1={q1}\nq3={q3}\n",
                                f"iqr={iqr}\n",
                                f"fence={fence}\n",
                                f"thr_disp={thr_disp}\n",
                            ]
                        else:
                            H, W = disp.shape
                            target = max(10000, int(args.disp_outlier_lof_target_points))
                            total = H * W
                            scale = (total / float(target)) ** 0.5 if total > target else 1.0
                            dsH = max(16, int(round(H / scale)))
                            dsW = max(16, int(round(W / scale)))
                            disp_ds = cv2.resize(disp, (dsW, dsH), interpolation=cv2.INTER_AREA)
                            # Build invalid_ds similarly to exclude
                            invm_local = np.zeros_like(disp, dtype=bool) if 'invm' not in locals() else invm
                            inv_ds = cv2.resize(invm_local.astype(np.uint8), (dsW, dsH), interpolation=cv2.INTER_NEAREST).astype(bool)
                            v_ds = np.isfinite(disp_ds) & (disp_ds > 0) & (~inv_ds)
                            # Features: [x_norm, y_norm, ld_norm]
                            ys, xs = np.where(v_ds)
                            if ys.size > 1:
                                x_norm = xs.astype(np.float32) / float(dsW)
                                y_norm = ys.astype(np.float32) / float(dsH)
                                ld = -np.log(np.clip(disp_ds[v_ds], 1e-12, None)).astype(np.float32)
                                # robust scale ld to [0,1]
                                ld_q1 = float(np.percentile(ld, 1))
                                ld_q99 = float(np.percentile(ld, 99))
                                denom = max(ld_q99 - ld_q1, 1e-6)
                                ld_n = np.clip((ld - ld_q1) / denom, 0.0, 1.0)
                                use_xy = int(args.disp_outlier_lof_use_xy) != 0
                                if use_xy:
                                    feats = np.stack([x_norm, y_norm, ld_n], axis=1)
                                else:
                                    feats = ld_n.reshape(-1, 1)
                                nn = int(args.disp_outlier_lof_neighbors)
                                nn = max(5, min(nn, feats.shape[0] - 1))
                                contamination = float(args.disp_outlier_lof_contamination)
                                contamination = min(max(contamination, 1e-4), 0.49)
                                lof = LocalOutlierFactor(n_neighbors=nn, contamination=contamination)
                                y_pred = lof.fit_predict(feats)  # -1 outlier, 1 inlier
                                om_ds = np.zeros((dsH, dsW), dtype=bool)
                                om_ds[ys, xs] = (y_pred == -1)
                                # Upsample to original size
                                om_up = cv2.resize(om_ds.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
                                om = om_up
                                stat_lines = [
                                    "method=lof\n",
                                    f"neighbors={nn}\n",
                                    f"contamination={contamination}\n",
                                    f"downscaled=({dsH},{dsW}) from ({H},{W}) target={target}\n",
                                    f"use_xy={use_xy}\n",
                                ]
                            else:
                                stat_lines = ["method=lof\n", "status=insufficient_valid_points_after_downscale\n"]
                    elif method == "depth_hist_valley":
                        # Build depth map (meters) for valid pixels only
                        depth_tmp = np.full_like(disp, np.nan, dtype=np.float32)
                        depth_tmp[v] = (args.focal_px * args.baseline_m) / disp[v]
                        dv = depth_tmp[v]
                        dv = dv[np.isfinite(dv) & (dv > 0)]
                        if dv.size < 10:
                            stat_lines = [
                                "method=depth_hist_valley\n",
                                "status=insufficient_valid_depth_values\n",
                            ]
                            om = np.zeros_like(disp, dtype=bool)
                        else:
                            # Robust histogram range to avoid extreme tails
                            hi = float(np.percentile(dv, 99.5)) if dv.size > 200 else float(np.max(dv))
                            lo = float(np.percentile(dv, 0.5)) if dv.size > 200 else float(np.min(dv))
                            if not np.isfinite(hi) or hi <= 0:
                                hi = float(np.max(dv))
                            if not np.isfinite(lo) or lo < 0:
                                lo = 0.0
                            if hi <= lo:
                                hi = lo + 1e-6
                            # Choose number of bins to provide ~0.5m resolution while keeping at least 32 bins
                            bw = 0.5
                            nbins = max(32, int(np.ceil((hi - lo) / bw)))
                            counts, edges = np.histogram(dv, bins=nbins, range=(lo, hi))
                            centers = (edges[:-1] + edges[1:]) * 0.5
                            # Smooth counts with a small kernel to stabilize extrema
                            counts_f = counts.astype(np.float32)
                            if counts_f.size >= 5:
                                ker = np.array([1, 2, 3, 2, 1], dtype=np.float32)
                                ker /= ker.sum()
                                counts_s = np.convolve(counts_f, ker, mode="same")
                            else:
                                counts_s = counts_f
                            # Primary peak: global maximum
                            p1_idx = int(np.argmax(counts_s)) if counts_s.size > 0 else 0
                            p1_depth = float(centers[p1_idx]) if centers.size > 0 else float(np.median(dv))

                            # First try: find a long near-zero plateau anywhere to the right of the primary peak
                            max_c = float(np.max(counts_s)) if counts_s.size > 0 else 1.0
                            zero_rel = 1e-3  # 0.1% of the primary height considered near-zero
                            zero_thr = max(1.0, max_c * zero_rel)
                            min_len_bins = 5  # require at least ~2.5m if bw=0.5
                            best_len = 0
                            best_start = None
                            cur_len = 0
                            cur_start = None
                            # Scan from just after the primary peak
                            for i in range(p1_idx + 1, counts_s.size):
                                if counts_s[i] <= zero_thr:
                                    if cur_start is None:
                                        cur_start = i
                                        cur_len = 1
                                    else:
                                        cur_len += 1
                                else:
                                    if cur_start is not None and cur_len > best_len:
                                        best_len = cur_len
                                        best_start = cur_start
                                    cur_start = None
                                    cur_len = 0
                            # tail check
                            if cur_start is not None and cur_len > best_len:
                                best_len = cur_len
                                best_start = cur_start

                            used_plateau = False
                            if best_start is not None and best_len >= min_len_bins:
                                # Threshold at the left boundary of the plateau
                                thr_idx = int(best_start)
                                # Apply margin: shift threshold slightly to the right by a percentage of depth
                                margin_pct = float(getattr(args, "depth_hist_plateau_margin_pct", 3.0))
                                margin_pct = max(0.0, min(20.0, margin_pct))  # clamp to sensible range
                                base_depth = float(centers[thr_idx])
                                thr_depth = base_depth * (1.0 + margin_pct / 100.0)
                                om = np.zeros_like(disp, dtype=bool)
                                valid_depth = np.isfinite(depth_tmp) & (depth_tmp > 0)
                                om[valid_depth] = depth_tmp[valid_depth] >= max(thr_depth, 0.0)
                                # For debug/annotation, define end index
                                plateau_end = int(best_start + best_len - 1)
                                valley_idx = thr_idx
                                p2_idx = None
                                try:
                                    _save_depth_histogram_valley_debug(
                                        centers=centers,
                                        counts=counts,
                                        p1_idx=p1_idx,
                                        p2_idx=p2_idx,
                                        valley_idx=valley_idx,
                                        out_dir=out_dir,
                                        png_name="depth_hist_valley_debug.png",
                                        csv_name="depth_hist_valley_debug.csv",
                                        plateau_start_idx=thr_idx,
                                        plateau_end_idx=plateau_end,
                                    )
                                except Exception as _e_dbg:
                                    print(f"[WARN] Failed to save depth_hist_valley debug plot: {_e_dbg}")
                                stat_lines = [
                                    "method=depth_hist_valley\n",
                                    "mode=plateau\n",
                                    f"primary_peak_depth={p1_depth}\n",
                                    f"plateau_start_depth={base_depth}\n",
                                    f"plateau_end_depth={float(centers[min(plateau_end, centers.size-1)])}\n",
                                    f"thr_depth={thr_depth}\n",
                                    f"margin_pct={margin_pct}\n",
                                    f"hist_bins={nbins}\n",
                                    f"range_lo={lo}\n",
                                    f"range_hi={hi}\n",
                                ]
                                used_plateau = True

                            if not used_plateau:
                                # As requested: if no plateau is found, do NOT apply threshold masking to avoid information loss
                                stat_lines = [
                                    "method=depth_hist_valley\n",
                                    "mode=no-plateau\n",
                                    "status=skip_thresholding\n",
                                    f"primary_peak_depth={p1_depth}\n",
                                ]
                                om = np.zeros_like(disp, dtype=bool)
                    elif method == "depth_25m_global":
                        # Simple global 25m threshold: mask all depths > 25m
                        depth_tmp = np.full_like(disp, np.nan, dtype=np.float32)
                        depth_tmp[v] = (args.focal_px * args.baseline_m) / disp[v]
                        om = np.zeros_like(disp, dtype=bool)
                        valid_depth = np.isfinite(depth_tmp) & (depth_tmp > 0)
                        om[valid_depth] = depth_tmp[valid_depth] > 25.0
                        stat_lines = [
                            "method=depth_25m_global\n",
                            "threshold_m=25.0\n",
                            "description=mask_all_depths_greater_than_25m\n",
                        ]
                    else:
                        stat_lines = ["method=unknown\n"]

                    # Optional: restrict to low disparities (below global median) to focus on far-depth spikes
                    # Skip this additional restriction for depth-based methods (depth_hist_valley, depth_25m_global)
                    if method not in ("depth_hist_valley", "depth_25m_global") and int(args.disp_outlier_restrict_low) != 0 and np.any(v):
                        try:
                            dmed = float(np.median(disp[v]))
                            om = om & (disp <= dmed)
                            # record in stats
                            stat_lines.append(f"restrict_low=true dmed={dmed}\n")
                        except Exception:
                            stat_lines.append("restrict_low=failed\n")
                    else:
                        stat_lines.append("restrict_low=false\n")

                    outlier_mask = om
                    # Save mask artifacts (generic)
                    np.save(os.path.join(out_dir, "outlier_mask_lowdisp.npy"), outlier_mask)
                    iio.imwrite(os.path.join(out_dir, "outlier_mask_lowdisp.png"), (outlier_mask.astype(np.uint8) * 255))
                    # Save stats (generic)
                    with open(os.path.join(out_dir, "outlier_mask_stats.txt"), "w") as f:
                        total = int(v.sum())
                        masked = int(outlier_mask.sum())
                        for line in stat_lines:
                            f.write(line)
                        f.write(f"valid={total}\nmasked={masked}\nmasked_pct={100.0*masked/max(1,total):.3f}%\n")
                    try:
                        print(f"[INFO] Saved outlier stats: valid={total}, masked={masked}, masked_pct={100.0*masked/max(1,total):.3f}% -> {os.path.join(out_dir, 'outlier_mask_stats.txt')}")
                    except Exception:
                        pass

                    # Also save Otsu-specific artifacts if applicable
                    if method == "logdepth_otsu":
                        np.save(os.path.join(out_dir, "outlier_mask_lowdisp_otsu.npy"), outlier_mask)
                        iio.imwrite(
                            os.path.join(out_dir, "outlier_mask_lowdisp_otsu.png"), (outlier_mask.astype(np.uint8) * 255)
                        )
                        with open(os.path.join(out_dir, "outlier_mask_stats_otsu.txt"), "w") as f2:
                            total = int(v.sum())
                            masked = int(outlier_mask.sum())
                            for line in stat_lines:
                                f2.write(line)
                            f2.write(
                                f"valid={total}\nmasked={masked}\nmasked_pct={100.0*masked/max(1,total):.3f}%\n"
                            )
                        try:
                            print(f"[INFO] Saved Otsu outlier stats to {os.path.join(out_dir, 'outlier_mask_stats_otsu.txt')}")
                        except Exception:
                            pass
                    elif method == "depth_hist_valley":
                        # Save method-specific mask and stats for inspection
                        np.save(os.path.join(out_dir, "outlier_mask_depth_hist_valley.npy"), outlier_mask)
                        iio.imwrite(
                            os.path.join(out_dir, "outlier_mask_depth_hist_valley.png"), (outlier_mask.astype(np.uint8) * 255)
                        )
                        with open(os.path.join(out_dir, "outlier_mask_stats_depth_hist_valley.txt"), "w") as f3:
                            total = int(v.sum())
                            masked = int(outlier_mask.sum())
                            for line in stat_lines:
                                f3.write(line)
                            f3.write(
                                f"valid={total}\nmasked={masked}\nmasked_pct={100.0*masked/max(1,total):.3f}%\n"
                            )
                        try:
                            print(
                                f"[INFO] Saved depth_hist_valley outlier stats to {os.path.join(out_dir, 'outlier_mask_stats_depth_hist_valley.txt')}"
                            )
                        except Exception:
                            pass
                    elif method == "depth_25m_global":
                        # Save method-specific mask and stats for inspection
                        np.save(os.path.join(out_dir, "outlier_mask_depth_25m_global.npy"), outlier_mask)
                        iio.imwrite(
                            os.path.join(out_dir, "outlier_mask_depth_25m_global.png"), (outlier_mask.astype(np.uint8) * 255)
                        )
                        with open(os.path.join(out_dir, "outlier_mask_stats_depth_25m_global.txt"), "w") as f4:
                            total = int(v.sum())
                            masked = int(outlier_mask.sum())
                            for line in stat_lines:
                                f4.write(line)
                            f4.write(
                                f"valid={total}\nmasked={masked}\nmasked_pct={100.0*masked/max(1,total):.3f}%\n"
                            )
                        try:
                            print(
                                f"[INFO] Saved depth_25m_global stats to {os.path.join(out_dir, 'outlier_mask_stats_depth_25m_global.txt')}"
                            )
                        except Exception:
                            pass
                else:
                    print("[INFO] Outlier masking skipped: no valid disparities")
            except Exception as e:
                print(f"[WARN] Outlier masking failed: {e}")

        # Save depth in meters as npy (optionally with outliers masked)
        depth_npy_path = os.path.join(out_dir, "depth_m.npy")
        np.save(depth_npy_path, depth_m)

        # Create and save 25m depth threshold mask (True where depth > 25m)
        depth_25m_mask = np.zeros_like(depth_m, dtype=bool)
        valid_depth_for_mask = np.isfinite(depth_m) & (depth_m > 0)
        depth_25m_mask[valid_depth_for_mask] = depth_m[valid_depth_for_mask] > 25.0
        np.save(os.path.join(out_dir, "depth_25m_mask.npy"), depth_25m_mask)
        iio.imwrite(os.path.join(out_dir, "depth_25m_mask.png"), (depth_25m_mask.astype(np.uint8) * 255))
        print(f"[INFO] Saved 25m depth threshold mask: {int(depth_25m_mask.sum())} pixels masked (> 25m)")

        # Save a 16-bit PNG (scaled)
        png_path: Optional[str] = None
        if args.depth_scale is not None and args.depth_scale > 0:
            dm = depth_m.copy()
            # No depth clipping
            dm_scaled = dm * float(args.depth_scale)
            dm_scaled = np.where(np.isfinite(dm_scaled) & (dm_scaled > 0), dm_scaled, 0.0)
            dm_scaled = np.clip(dm_scaled, 0, 65535).astype(np.uint16)
            png_path = os.path.join(out_dir, "depth_mm.png" if abs(args.depth_scale - 1000.0) < 1e-3 else "depth_scaled.png")
            iio.imwrite(png_path, dm_scaled)

        # Save colorized depth PNG (unmasked)
        _ = _save_depth_color_png(depth_m, out_dir, "depth_color.png")

        return depth_npy_path, png_path

    def _postprocess_disp_and_depth(disp_path: str, out_dir: str, need_disp: bool, need_depth: bool) -> str:
        msgs = []
        if need_disp:
            d_png, m_png = _save_disp_artifacts(disp_path, out_dir)
            msgs.append(f"disp_png={os.path.basename(d_png)}, mask={os.path.basename(m_png)}")
        if need_depth:
            info = _save_depth_from_disp(disp_path, out_dir)
            if info is not None:
                npy, png = info
                msg = f"depth_npy={os.path.basename(npy)}"
                if png:
                    msg += f", depth_png={os.path.basename(png)}"
                msgs.append(msg)
        return "; ".join(msgs) if msgs else "no-op"

    def _apply_masks(left_img_path: str, disp_path: str, out_dir: str, methods_csv: str) -> Dict[str, Dict[str, str]]:
        # Load disparity to get target size
        disp = np.load(disp_path)
        disp = disp.astype(np.float32, copy=False)
        if disp.ndim == 3 and disp.shape[0] == 1:
            disp = disp[0]
        if disp.ndim == 3 and disp.shape[-1] == 1:
            disp = disp[..., 0]
        H, W = disp.shape[:2]

        # Load left image once and resize
        img = cv2.imread(left_img_path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Could not read left image: {left_img_path}")
        img_rs = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)

        methods = [m.strip().lower() for m in (methods_csv or "").split(",") if m.strip()]
        results: Dict[str, Dict[str, str]] = {}

        # Load invalid mask (if exists) to always include in final masking
        invalid_mask_path_npy = os.path.join(out_dir, "invalid_mask.npy")
        invalid_mask: Optional[np.ndarray]
        if os.path.isfile(invalid_mask_path_npy):
            invalid_mask = np.load(invalid_mask_path_npy).astype(bool)
            if invalid_mask.shape != (H, W):
                invalid_mask = cv2.resize(invalid_mask.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
        else:
            # If missing, recompute a minimal invalid mask from disp (non-finite or <=0)
            invalid_mask = ~np.isfinite(disp) | (disp <= 0)

        # Load outlier mask (if created during depth computation by methods like logdepth_otsu, depth_hist_valley, etc.)
        outlier_mask_path = os.path.join(out_dir, "outlier_mask_lowdisp.npy")
        outlier_mask: Optional[np.ndarray] = None
        if os.path.isfile(outlier_mask_path):
            try:
                outlier_mask = np.load(outlier_mask_path).astype(bool)
                if outlier_mask.shape != (H, W):
                    outlier_mask = cv2.resize(outlier_mask.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
                print(f"[INFO] Loaded outlier mask: {int(outlier_mask.sum())} pixels masked")
            except Exception as e:
                print(f"[WARN] Failed to load outlier mask: {e}")

        # Load 25m depth threshold mask (created during depth computation if using depth_25m_global or always-on)
        depth_25m_mask_path = os.path.join(out_dir, "depth_25m_mask.npy")
        depth_25m_mask: Optional[np.ndarray] = None
        if os.path.isfile(depth_25m_mask_path):
            try:
                depth_25m_mask = np.load(depth_25m_mask_path).astype(bool)
                if depth_25m_mask.shape != (H, W):
                    depth_25m_mask = cv2.resize(depth_25m_mask.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
                print(f"[INFO] Loaded 25m depth mask: {int(depth_25m_mask.sum())} pixels (> 25m)")
            except Exception as e:
                print(f"[WARN] Failed to load 25m depth mask: {e}")

        def _ensure_bool_mask(m: np.ndarray) -> np.ndarray:
            """Return ``m`` as a boolean mask of shape ``(H, W)``.

            Non-boolean input is cast to ``bool``. If the shape differs from ``(H, W)`` of the
            enclosing scope, the mask is resized with nearest-neighbour interpolation.
            """
            mask = m if m.dtype == bool else m.astype(bool)
            if mask.shape == (H, W):
                return mask
            # cv2.resize takes the target size as (width, height).
            resized = cv2.resize(mask.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST)
            return resized.astype(bool)

        def _apply_total_for_method(method_key: str, sky_mask_bool: np.ndarray):
            """Write the sky mask, the combined mask and the masked disparity/depth outputs for one method.

            The combined ("total") mask is the union of ``sky_mask_bool``, ``invalid_mask`` and, when
            they are not ``None``, ``outlier_mask`` and ``depth_25m_mask`` from the enclosing scope.
            All files are written into ``out_dir`` with ``method_key`` embedded in their names.

            Args:
                method_key: Label embedded in every output file name.
                sky_mask_bool: Sky mask for this method; it is OR-ed element-wise with the other masks.
            """

            def _out(file_name: str) -> str:
                # Every artifact of this helper lives directly under out_dir.
                return os.path.join(out_dir, file_name)

            # Sky-only mask, stored as .npy and as a 0/255 PNG.
            np.save(_out(f"{method_key}_sky_mask.npy"), sky_mask_bool.astype(bool))
            iio.imwrite(_out(f"{method_key}_sky_mask.png"), sky_mask_bool.astype(np.uint8) * 255)

            # TOTAL mask: sky OR invalid, plus whichever optional masks were loaded.
            total_mask = sky_mask_bool | invalid_mask
            for extra_mask in (outlier_mask, depth_25m_mask):
                if extra_mask is not None:
                    total_mask = total_mask | extra_mask
            np.save(_out(f"{method_key}_total_mask.npy"), total_mask.astype(bool))
            iio.imwrite(_out(f"{method_key}_total_mask.png"), total_mask.astype(np.uint8) * 255)

            # Disparity with every masked pixel replaced by NaN.
            masked_disp = disp.copy()
            masked_disp[total_mask] = np.nan
            np.save(_out(f"disp_masked_{method_key}_total.npy"), masked_disp)

            # Colour preview of the masked disparity, normalised between its 1st and 99th percentiles.
            try:
                usable = np.isfinite(masked_disp) & (masked_disp > 0)
                if not np.any(usable):
                    preview = np.zeros((*masked_disp.shape, 3), dtype=np.uint8)
                else:
                    lo = float(np.percentile(masked_disp[usable], 1))
                    hi = float(np.percentile(masked_disp[usable], 99))
                    hi = max(hi, lo + 1e-6)  # keep the normalisation denominator non-zero
                    filled = np.nan_to_num(masked_disp, nan=0.0)
                    unit = np.clip((filled - lo) / (hi - lo), 0.0, 1.0)
                    preview = cv2.applyColorMap((unit * 255).astype(np.uint8), cv2.COLORMAP_TURBO)
                    preview[~usable] = (0, 0, 0)  # masked/invalid pixels are drawn black
                cv2.imwrite(_out(f"disp_masked_{method_key}_total_color.png"), preview)
            except Exception as e:
                print(f"[WARN] Failed to save masked disparity color visualization for {method_key}: {e}")

            # Depth needs both calibration values; without them only the disparity outputs exist.
            if args.focal_px is None or args.baseline_m is None:
                print(f"[INFO] Skipped depth-from-masked for {method_key}: focal_px/baseline_m not provided")
                return

            try:
                masked_depth = np.full_like(masked_disp, np.nan, dtype=np.float32)
                has_disp = np.isfinite(masked_disp) & (masked_disp > 0)
                # depth = focal_px * baseline_m / disparity, only where the disparity survived masking
                masked_depth[has_disp] = (args.focal_px * args.baseline_m) / masked_disp[has_disp]
                np.save(_out(f"depth_m_masked_{method_key}_total.npy"), masked_depth)

                # 16-bit grayscale export; a missing or zero depth_scale falls back to 1000.0.
                scaled = masked_depth * float(args.depth_scale or 1000.0)
                scaled = np.where(np.isfinite(scaled) & (scaled > 0), scaled, 0.0)
                gray_u16 = np.clip(scaled, 0, 65535).astype(np.uint16)
                cv2.imwrite(_out(f"depth_mm_masked_grayscale_{method_key}_total.png"), gray_u16)

                # Colourised depth via the shared helper.
                _save_depth_color_png(masked_depth, out_dir, f"depth_mm_masked_colored_{method_key}_total.png")
            except Exception as e:
                print(f"[WARN] Failed to compute/save depth from masked disparity for {method_key}: {e}")

        # Method: watershed
        if "watershed" in methods:
            try:
                sky = sky_mask_watershed(img_rs)
                sky = _ensure_bool_mask(sky)
                # Save distribution WITHOUT outlier masking: use invalid OR sky only
                m_wo_outlier = (sky | invalid_mask)
                try:
                    depth_for_plot = np.full((H, W), np.nan, dtype=np.float32)
                    if args.focal_px is not None and args.baseline_m is not None:
                        valid_md = np.isfinite(disp) & (disp > 0) & (~m_wo_outlier)
                        depth_for_plot[valid_md] = (args.focal_px * args.baseline_m) / disp[valid_md]
                    _save_depth_distribution_plot(depth_for_plot, out_dir, "depth_distribution_watershed.png", "depth_distribution_watershed.csv")
                    # Save distribution AFTER outlier elimination, if used
                    if outlier_mask is not None and args.disp_outlier_method in ("logdepth_otsu", "depth_hist_valley", "depth_25m_global"):
                        m_with_outlier = m_wo_outlier | outlier_mask
                        depth_for_plot2 = np.full((H, W), np.nan, dtype=np.float32)
                        if args.focal_px is not None and args.baseline_m is not None:
                            valid_md2 = np.isfinite(disp) & (disp > 0) & (~m_with_outlier)
                            depth_for_plot2[valid_md2] = (args.focal_px * args.baseline_m) / disp[valid_md2]
                        suffix = {
                            "logdepth_otsu": "after_otsu",
                            "depth_hist_valley": "after_depth_hist_valley",
                            "depth_25m_global": "after_25m_global"
                        }.get(args.disp_outlier_method, "after_outlier")
                        _save_depth_distribution_plot(depth_for_plot2, out_dir, f"depth_distribution_watershed_{suffix}.png", f"depth_distribution_watershed_{suffix}.csv", pmax=100.0)
                except Exception as e:
                    print(f"[WARN] Failed to save watershed depth distribution: {e}")
                _apply_total_for_method("watershed", sky)
                results["watershed_total"] = {
                    "mask_total": os.path.join(out_dir, "watershed_total_mask.png"),
                }
            except Exception as e:
                print(f"[WARN] Watershed masking failed: {e}")

        # Method: disparity biggest component (low disparity/invalid with top connection)
        if "disp_component" in methods:
            try:
                sky = sky_mask_from_disparity_biggest_component(disp, percentile=10.0, restrict_to_top=True)
                sky = _ensure_bool_mask(sky)
                # Distribution without outlier
                m_wo_outlier = (sky | invalid_mask)
                try:
                    depth_for_plot = np.full((H, W), np.nan, dtype=np.float32)
                    if args.focal_px is not None and args.baseline_m is not None:
                        valid_md = np.isfinite(disp) & (disp > 0) & (~m_wo_outlier)
                        depth_for_plot[valid_md] = (args.focal_px * args.baseline_m) / disp[valid_md]
                    _save_depth_distribution_plot(depth_for_plot, out_dir, "depth_distribution_dispcomp.png", "depth_distribution_dispcomp.csv")
                    # Save distribution AFTER outlier elimination, if used
                    if outlier_mask is not None and args.disp_outlier_method in ("logdepth_otsu", "depth_hist_valley", "depth_25m_global"):
                        m_with_outlier = m_wo_outlier | outlier_mask
                        depth_for_plot2 = np.full((H, W), np.nan, dtype=np.float32)
                        if args.focal_px is not None and args.baseline_m is not None:
                            valid_md2 = np.isfinite(disp) & (disp > 0) & (~m_with_outlier)
                            depth_for_plot2[valid_md2] = (args.focal_px * args.baseline_m) / disp[valid_md2]
                        suffix = {
                            "logdepth_otsu": "after_otsu",
                            "depth_hist_valley": "after_depth_hist_valley",
                            "depth_25m_global": "after_25m_global"
                        }.get(args.disp_outlier_method, "after_outlier")
                        _save_depth_distribution_plot(depth_for_plot2, out_dir, f"depth_distribution_dispcomp_{suffix}.png", f"depth_distribution_dispcomp_{suffix}.csv", pmax=100.0)
                except Exception as e:
                    print(f"[WARN] Failed to save dispcomp depth distribution: {e}")
                _apply_total_for_method("dispcomp", sky)
                results["disp_component_total"] = {
                    "mask_total": os.path.join(out_dir, "dispcomp_total_mask.png"),
                }
            except Exception as e:
                print(f"[WARN] Disp-component masking failed: {e}")

        # Method: SAM2 (optional; skipped when the SAM2 backend is not available)
        if "sam2" in methods:
            if not _HAS_SAM2:
                print("[INFO] SAM2 not available; skipping sam2 masking.")
            else:
                try:
                    sky = sky_mask_sam2(img_rs, disp)
                    sky = _ensure_bool_mask(sky)
                    # Distribution without outlier
                    m_wo_outlier = (sky | invalid_mask)
                    try:
                        depth_for_plot = np.full((H, W), np.nan, dtype=np.float32)
                        if args.focal_px is not None and args.baseline_m is not None:
                            valid_md = np.isfinite(disp) & (disp > 0) & (~m_wo_outlier)
                            depth_for_plot[valid_md] = (args.focal_px * args.baseline_m) / disp[valid_md]
                        _save_depth_distribution_plot(depth_for_plot, out_dir, "depth_distribution_sam2.png", "depth_distribution_sam2.csv")
                        # Save distribution AFTER outlier elimination, if used
                        if outlier_mask is not None and args.disp_outlier_method in ("logdepth_otsu", "depth_hist_valley", "depth_25m_global"):
                            m_with_outlier = m_wo_outlier | outlier_mask
                            depth_for_plot2 = np.full((H, W), np.nan, dtype=np.float32)
                            if args.focal_px is not None and args.baseline_m is not None:
                                valid_md2 = np.isfinite(disp) & (disp > 0) & (~m_with_outlier)
                                depth_for_plot2[valid_md2] = (args.focal_px * args.baseline_m) / disp[valid_md2]
                            suffix = {
                                "logdepth_otsu": "after_otsu",
                                "depth_hist_valley": "after_depth_hist_valley",
                                "depth_25m_global": "after_25m_global"
                            }.get(args.disp_outlier_method, "after_outlier")
                            _save_depth_distribution_plot(depth_for_plot2, out_dir, f"depth_distribution_sam2_{suffix}.png", f"depth_distribution_sam2_{suffix}.csv", pmax=100.0)
                    except Exception as e:
                        print(f"[WARN] Failed to save sam2 depth distribution: {e}")
                    _apply_total_for_method("sam2", sky)
                    results["sam2_total"] = {
                        "mask_total": os.path.join(out_dir, "sam2_total_mask.png"),
                    }
                except Exception as e:
                    print(f"[WARN] SAM2 masking failed: {e}")

        # Method: top-row black-region components from left image (rectified)
        if "top_row_disp_component" in methods or "top_row" in methods or "top_row_sky" in methods or "top_row_sky_mask_biggest_component" in methods:
            try:
                sky = top_row_sky_mask_biggest_component(img_rs, threshold=args.sky_threshold, top_rows=2, morph_open=1)
                sky = _ensure_bool_mask(sky)
                # Distribution without outlier
                m_wo_outlier = (sky | invalid_mask)
                try:
                    depth_for_plot = np.full((H, W), np.nan, dtype=np.float32)
                    if args.focal_px is not None and args.baseline_m is not None:
                        valid_md = np.isfinite(disp) & (disp > 0) & (~m_wo_outlier)
                        depth_for_plot[valid_md] = (args.focal_px * args.baseline_m) / disp[valid_md]
                    _save_depth_distribution_plot(depth_for_plot, out_dir, "depth_distribution_toprow.png", "depth_distribution_toprow.csv")
                    # Save distribution AFTER outlier elimination, if used
                    if outlier_mask is not None and args.disp_outlier_method in ("logdepth_otsu", "depth_hist_valley", "depth_25m_global"):
                        m_with_outlier = m_wo_outlier | outlier_mask
                        depth_for_plot2 = np.full((H, W), np.nan, dtype=np.float32)
                        if args.focal_px is not None and args.baseline_m is not None:
                            valid_md2 = np.isfinite(disp) & (disp > 0) & (~m_with_outlier)
                            depth_for_plot2[valid_md2] = (args.focal_px * args.baseline_m) / disp[valid_md2]
                        suffix = {
                            "logdepth_otsu": "after_otsu",
                            "depth_hist_valley": "after_depth_hist_valley",
                            "depth_25m_global": "after_25m_global"
                        }.get(args.disp_outlier_method, "after_outlier")
                        _save_depth_distribution_plot(depth_for_plot2, out_dir, f"depth_distribution_toprow_{suffix}.png", f"depth_distribution_toprow_{suffix}.csv", pmax=100.0)
                except Exception as e:
                    print(f"[WARN] Failed to save toprow depth distribution: {e}")
                _apply_total_for_method("toprow", sky)
                results["toprow_total"] = {
                    "mask_total": os.path.join(out_dir, "toprow_total_mask.png"),
                }
            except Exception as e:
                print(f"[WARN] Top-row component masking failed: {e}")

        return results

    def _mask_extra_artifacts(mask: np.ndarray, out_dir: str, H: int, W: int, *, suffix: str) -> None:
        """Write masked copies of the disparity PNG and depth arrays found in ``out_dir``.

        Every artifact is optional; files that do not exist are skipped silently.

        Args:
            mask: Index mask applied to each artifact; selected pixels are
                overwritten. Assumed to be a boolean (H, W) array -- callers
                should normalise it before passing it in.
            out_dir: Directory holding the source artifacts and receiving the
                masked copies.
            H: Expected image height in pixels.
            W: Expected image width in pixels.
            suffix: Method tag appended to every output file name.

        Outputs (each only when its source file exists):
            disp_u16_masked_<suffix>.png    -- masked pixels set to 0; written
                only when the PNG is readable and already has shape (H, W).
            depth_m_masked_<suffix>.npy     -- masked pixels set to NaN.
            depth_meter_masked_<suffix>.npy -- masked pixels set to NaN.

        The two depth outputs are not produced for ``suffix == "toprow_total"``.
        """
        # Disparity u16: only masked when it already matches the target size (no resize here).
        disp_u16_path = os.path.join(out_dir, "disp_u16.png")
        if os.path.isfile(disp_u16_path):
            dpng = cv2.imread(disp_u16_path, cv2.IMREAD_UNCHANGED)
            if dpng is not None and dpng.shape[:2] == (H, W):
                dpng_masked = dpng.copy()
                dpng_masked[mask] = 0
                cv2.imwrite(os.path.join(out_dir, f"disp_u16_masked_{suffix}.png"), dpng_masked)

        # Avoid duplicating the final masked depth for the toprow_total path;
        # per the original note it is computed explicitly from masked disparity.
        if suffix == "toprow_total":
            return

        # "depth_m" is our depth in meters, "depth_meter" is the demo's depth in meters.
        # Both get identical treatment: bring to (H, W), NaN-out the mask, save.
        # Masking must happen whether or not a resize was needed; previously the
        # demo depth was only masked and saved when its shape differed from (H, W).
        for stem in ("depth_m", "depth_meter"):
            depth_path = os.path.join(out_dir, f"{stem}.npy")
            if not os.path.isfile(depth_path):
                continue
            depth = np.load(depth_path).astype(np.float32, copy=False)
            if depth.shape != (H, W):
                # Nearest-neighbour keeps original depth samples instead of blending them.
                depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST)
            depth_masked = depth.copy()
            depth_masked[mask] = np.nan
            np.save(os.path.join(out_dir, f"{stem}_masked_{suffix}.npy"), depth_masked)


    def _run_job(job_idx: int, left_img: str, right_img: str, out_dir: str, gpu_id: Optional[str]):
        """Run FoundationStereo for one stereo pair and post-process the result.

        Args:
            job_idx: Pair index, used only to build the ``pair<N>`` log label.
            left_img: Path to the left image.
            right_img: Path to the right image.
            out_dir: Directory that receives all outputs for this pair.
            gpu_id: GPU identifier forwarded to the runner; ``None`` is logged
                as ``default``.

        Returns:
            Path of the disparity ``.npy`` file for this pair.

        Raises:
            FileNotFoundError: The runner returned no usable path and
                ``<out_dir>/disp.npy`` does not exist either.

        A failure during post-processing is reported as a warning and does not
        fail the job.
        """
        # Use tqdm.write when tqdm is available, plain print otherwise.
        emit = tqdm.write if _HAS_TQDM else print
        pair_name = f"pair{job_idx}"
        gpu_label = gpu_id if gpu_id is not None else 'default'
        emit(f"[RUN ] {pair_name}:\n  left : {left_img}\n  right: {right_img}\n  out  : {out_dir}\n  gpu  : {gpu_label}")

        disp_path = run_foundation_stereo(
            foundation_repo=args.foundation_repo,
            left_path=left_img,
            right_path=right_img,
            ckpt=args.ckpt,
            out_dir=out_dir,
            scale=args.scale,
            valid_iters=args.valid_iters,
            get_pc=args.get_pc,
            hiera=args.hiera,
            intrinsic_file=None,
            # Optional settings are forwarded as-is; None stays None.
            focal_px=args.focal_px,
            baseline_m=args.baseline_m,
            disp_percentile=args.disp_percentile,
            python_exe=sys.executable,
            gpu_id=gpu_id,
        )

        # The runner may not hand back a usable path; fall back to the conventional file name.
        if not (disp_path and os.path.isfile(disp_path)):
            fallback = os.path.join(out_dir, "disp.npy")
            if not os.path.isfile(fallback):
                raise FileNotFoundError(
                    f"disp.npy was not found after run in {out_dir}. Ensure run_demo.py saves disp.npy."
                )
            disp_path = fallback

        # Post-process: disparity artifacts (png/mask) and optional depth.
        # Best-effort: any failure is logged and the disparity path is still returned.
        try:
            post_msg = _postprocess_disp_and_depth(
                disp_path,
                out_dir,
                need_disp=True,  # a fresh run always produces the disparity artifacts
                need_depth=(args.focal_px is not None and args.baseline_m is not None),
            )
            if post_msg and post_msg != "no-op":
                emit(f"[INFO] {pair_name}: {post_msg}")
        except Exception as de:
            emit(f"[WARN] {pair_name}: postprocess failed: {de}")

        return disp_path

    # Execute jobs
    if max_workers == 1:
        # Serial execution
        # First handle post-only jobs
        for (idx, out_dir, disp_path, left_img_post, need_disp, need_depth) in post_only_jobs:
            try:
                post_msg = _postprocess_disp_and_depth(disp_path, out_dir, need_disp, need_depth)
                # Apply the configured mask methods regardless of need_* (uses existing artifacts)
                _apply_masks(left_img_post, disp_path, out_dir, args.mask_methods)
                if _HAS_TQDM:
                    tqdm.write(f"[DONE] pair{idx}: post -> {post_msg}")
                else:
                    print(f"[DONE] pair{idx}: post -> {post_msg}")
                successes += 1
            except Exception as e:
                if _HAS_TQDM:
                    tqdm.write(f"[FAIL] pair{idx}: post-only: {e}")
                else:
                    print(f"[FAIL] pair{idx}: post-only: {e}")
                failures += 1
            finally:
                if pbar is not None:
                    pbar.update(1)

        for (idx, _pair_dir, left_img, right_img, out_dir) in jobs:
            try:
                disp_path = _run_job(idx, left_img, right_img, out_dir, gpu_id=(gpu_list[0] if gpu_list else None))
                # Apply the configured mask methods to the newly generated outputs
                _apply_masks(left_img, disp_path, out_dir, args.mask_methods)
                if _HAS_TQDM:
                    tqdm.write(f"[DONE] pair{idx}: disp -> {disp_path}")
                else:
                    print(f"[DONE] pair{idx}: disp -> {disp_path}")
                successes += 1
            except Exception as e:
                if _HAS_TQDM:
                    tqdm.write(f"[FAIL] pair{idx}: {e}")
                else:
                    print(f"[FAIL] pair{idx}: {e}")
                failures += 1
            finally:
                if pbar is not None:
                    pbar.update(1)
        if pbar is not None:
            pbar.close()
    else:
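        # Parallel execution across GPUs (round-robin assignment)
        # NOTE(review): unlike the serial path, 'run' jobs submitted here never call _apply_masks;
        # confirm whether masks should also be applied after _run_job in parallel mode.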
        if not gpu_list:
            # No GPUs specified but max_workers > 1: allow CPU-parallel (not recommended for GPU inference)
            gpu_list = [None] * max_workers  # type: ignore

        futures = []
        future_to_idx: dict = {}
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            # Post-only jobs run on CPU; assign None GPU
            for (idx, out_dir, disp_path, left_img_post, need_disp, need_depth) in post_only_jobs:
                def _post_and_mask(dpath=disp_path, odir=out_dir, limg=left_img_post, nd=need_disp, ndep=need_depth):
                    msg = _postprocess_disp_and_depth(dpath, odir, nd, ndep)
                    _apply_masks(limg, dpath, odir, args.mask_methods)
                    return msg
                fut = ex.submit(_post_and_mask)
                futures.append(fut)
                future_to_idx[fut] = (idx, 'post')

            for i, (idx, _pair_dir, left_img, right_img, out_dir) in enumerate(jobs):
                gpu_id = gpu_list[i % len(gpu_list)]
                fut = ex.submit(_run_job, idx, left_img, right_img, out_dir, gpu_id)
                futures.append(fut)
                future_to_idx[fut] = (idx, 'run')
            for fut in as_completed(futures):
                try:
                    result = fut.result()
                    idx, kind = future_to_idx[fut]
                    if kind == 'run':
                        disp_path = result
                        if _HAS_TQDM:
                            tqdm.write(f"[DONE] pair{idx}: disp -> {disp_path}")
                        else:
                            print(f"[DONE] pair{idx}: disp -> {disp_path}")
                    else:
                        # post-only jobs return only the post-processing message, which is not printed here
                        if _HAS_TQDM:
                            tqdm.write(f"[DONE] pair{idx}: post + mask")
                        else:
                            print(f"[DONE] pair{idx}: post + mask")
                    successes += 1
                except Exception as e:
                    idx_kind = future_to_idx.get(fut, ('?', ''))
                    idx = idx_kind[0]
                    if _HAS_TQDM:
                        tqdm.write(f"[FAIL] pair{idx}: {e}")
                    else:
                        print(f"[FAIL] pair{idx}: {e}")
                    failures += 1
                finally:
                    if pbar is not None:
                        pbar.update(1)
        if pbar is not None:
            pbar.close()

    print(f"\nCompleted. Success: {successes}, Fail: {failures}, Total: {len(pairs)}")
    return 0 if failures == 0 else 1


if __name__ == "__main__":
    # Propagate main()'s return value (0 on success, 1 if any pair failed) as the process exit status.
    sys.exit(main())
