import argparse
from pathlib import Path
import sys
import re
import numpy as np
import cv2

# Append this script's own directory to sys.path before the `core.*` imports
# below run (presumably the `core` package lives next to this file — confirm).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.io import load_depth_npy, resize_to_match
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Default value of the --gt-dir option (looks like a placeholder path —
# override it on the command line or edit it for the local dataset).
GT_DIR_DEFAULT = "path/to/change-dataset/depth"


def _normalize_depth(
    depth: np.ndarray,
    mask: np.ndarray,
    p_low: float = 1.0,
    p_high: float = 99.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Map depth values into [0, 1] using a robust percentile window.

    Only pixels that are finite and selected by ``mask`` contribute. The
    ``p_low``/``p_high`` percentiles of those pixels define the window, values
    are clipped to it, and the result is rescaled on a log axis when the lower
    bound is strictly positive (linear axis otherwise).

    Returns:
        A ``(norm, valid)`` pair. ``norm`` is a float64 array shaped like
        ``depth`` with 0.0 at every non-selected pixel; ``valid`` is the
        boolean mask of pixels that carry a normalized value. When no pixel is
        usable, or the percentile window is degenerate, ``norm`` is all zeros.
    """
    depth64 = depth.astype(np.float64)
    usable = np.isfinite(depth64) & mask
    norm = np.zeros_like(depth64, dtype=np.float64)

    if not np.any(usable):
        return norm, usable

    samples = depth64[usable]
    lo = float(np.percentile(samples, p_low))
    hi = float(np.percentile(samples, p_high))

    window_ok = np.isfinite(lo) and np.isfinite(hi) and hi > lo
    if not window_ok:
        return norm, usable

    # Clip just the selected samples; unselected pixels keep their 0.0.
    inside = np.clip(samples, lo, hi)

    scaled = None
    if lo > 0.0:
        log_lo = np.log(lo)
        log_hi = np.log(hi)
        # Guard against log(lo) == log(hi) after rounding; fall back to linear.
        if log_hi > log_lo:
            scaled = (np.log(inside) - log_lo) / (log_hi - log_lo)
    if scaled is None:
        scaled = (inside - lo) / (hi - lo)

    norm[usable] = scaled
    still_valid = usable & np.isfinite(norm)
    return np.clip(norm, 0.0, 1.0), still_valid


def _save_outputs(
    depth: np.ndarray,
    mask: np.ndarray,
    out_png: Path,
    p_low: float,
    p_high: float,
    cmap: str = "gray16"
) -> None:
    """
    Normalize a depth map and write it to ``out_png`` as an image.

    Args:
        depth: Depth values; any numeric dtype (converted to float64 internally).
        mask: Boolean array selecting the pixels to visualize.
        out_png: Destination file; parent directories are created if missing.
        p_low: Lower percentile forwarded to ``_normalize_depth``.
        p_high: Upper percentile forwarded to ``_normalize_depth``.
        cmap: ``"gray16"`` writes a uint16 grayscale image; ``"viridis"``,
            ``"magma"``, ``"inferno"`` or ``"plasma"`` writes an 8-bit image
            colored with the matching OpenCV colormap.

    Raises:
        KeyError: If ``cmap`` is not one of the supported names.
        OSError: If OpenCV reports that the image could not be written.
    """
    out_png.parent.mkdir(parents=True, exist_ok=True)

    # Copy into float64 so NaN can mark masked pixels even when the input
    # array has an integer dtype (NaN cannot be stored in integer arrays).
    depth_to_save = np.array(depth, dtype=np.float64)
    depth_to_save[~mask] = np.nan

    norm, valid_mask = _normalize_depth(depth_to_save, mask, p_low=p_low, p_high=p_high)

    if cmap == "gray16":
        image = np.zeros_like(norm, dtype=np.uint16)
        image[valid_mask] = (norm[valid_mask] * 65535).astype(np.uint16)
    else:
        out_8bit = np.zeros_like(norm, dtype=np.uint8)
        out_8bit[valid_mask] = (norm[valid_mask] * 255).astype(np.uint8)

        cmap_dict = {
            "viridis": cv2.COLORMAP_VIRIDIS,
            "magma": cv2.COLORMAP_MAGMA,
            "inferno": cv2.COLORMAP_INFERNO,
            "plasma": cv2.COLORMAP_PLASMA,
        }

        image = cv2.applyColorMap(out_8bit, cmap_dict[cmap])
        image[~valid_mask] = [0, 0, 0]  # Masked regions remain strict black

    # cv2.imwrite signals failure through its return value instead of raising;
    # surface it so a failed write is not reported as a saved output.
    if not cv2.imwrite(str(out_png), image):
        raise OSError(f"Failed to write image: {out_png}")


def _numeric_stem_key(stem: str) -> str | None:
    parts = re.findall(r"\d+", stem)
    return parts[-1] if parts else None


def _build_unique_map(
    files: list[Path],
    key_fn,
) -> tuple[dict[str, Path], list[str]]:
    """
    Index ``files`` by ``key_fn(path.stem)``, keeping the first file per key.

    Files whose key is None are ignored. Keys produced by more than one file
    are reported in the second return value (sorted, without repeats); the
    later files for such a key are not stored.
    """
    first_seen: dict[str, Path] = {}
    repeated: set[str] = set()
    for path in files:
        key = key_fn(path.stem)
        if key is None:
            continue
        if key not in first_seen:
            first_seen[key] = path
        else:
            repeated.add(key)
    return first_seen, sorted(repeated)


def _collect_pairs(
    pred_files: list[Path],
    gt_files: list[Path],
    pair_mode: str,
) -> tuple[list[tuple[Path, Path]], str, set[str], set[str], list[str], list[str]]:
    """
    Pair prediction files with ground-truth files by file stem.

    With ``pair_mode`` "exact" the full stems must be equal. With "auto" exact
    matching is tried first and used if it yields at least one pair; otherwise
    (and for any other mode value) files are paired by the last run of digits
    in their stems.

    Returns:
        ``(pairs, mode_used, pred_keys, gt_keys, pred_duplicates, gt_duplicates)``
        where ``mode_used`` is "exact" or "digits" and ``pairs`` is ordered by
        sorted key.
    """

    def _match(key_fn, label: str):
        # Build both indexes with the same key function and join on shared keys.
        pred_map, pred_dupes = _build_unique_map(pred_files, key_fn)
        gt_map, gt_dupes = _build_unique_map(gt_files, key_fn)
        shared = sorted(pred_map.keys() & gt_map.keys())
        matched = [(pred_map[key], gt_map[key]) for key in shared]
        return matched, label, set(pred_map), set(gt_map), pred_dupes, gt_dupes

    if pair_mode in {"auto", "exact"}:
        exact_result = _match(lambda s: s, "exact")
        # "exact" is returned even when empty; "auto" only when something matched.
        if pair_mode == "exact" or exact_result[0]:
            return exact_result

    return _match(_numeric_stem_key, "digits")


def _process_single_pair(task: tuple) -> None:
    """
    Render one prediction/ground-truth pair to a PNG.

    ``task`` is a 9-tuple: prediction path, GT path, inverse-depth flag, eps,
    max depth (or None), output directory, lower percentile, upper percentile
    and colormap name. Paths are passed as strings. The output file is named
    after the prediction file's stem.
    """
    (pred_file, gt_file, as_inverse, epsilon,
     depth_cap, out_dir, pct_lo, pct_hi, colormap) = task

    pred_path = Path(pred_file)
    gt_path = Path(gt_file)

    # Read both arrays, then pass the prediction through resize_to_match with
    # the GT as reference (presumably so their spatial sizes agree).
    prediction = load_depth_npy(str(pred_path))
    ground_truth = load_depth_npy(str(gt_path))
    prediction = resize_to_match(prediction, ground_truth)

    if not take_inverse_flag(as_inverse):
        pred_depth = prediction.astype(np.float64)
        pred_ok = np.isfinite(pred_depth)
    else:
        pred_depth, pred_ok = inverse_to_depth(prediction, eps=epsilon)

    gt_depth = ground_truth.astype(np.float64)
    # Evaluate only where GT is positive and finite and the prediction is usable.
    eval_mask = (ground_truth > 0) & pred_ok & np.isfinite(gt_depth)
    if depth_cap is not None:
        eval_mask &= (gt_depth <= depth_cap)

    # Least-squares scale/shift alignment of the prediction to the GT over the
    # evaluation mask; only the aligned map is needed for visualization.
    aligned, _scale, _shift = least_squares_align(pred_depth, gt_depth, eval_mask)

    target = Path(out_dir) / f"{pred_path.stem}.png"
    _save_outputs(aligned, eval_mask, target,
                  p_low=pct_lo, p_high=pct_hi, cmap=colormap)


def take_inverse_flag(value) -> bool:
    """Return the truthiness of the inverse-depth flag carried in a task tuple."""
    return bool(value)


def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, pairs prediction ``.npy`` files with ground-truth
    ``.npy`` files, and renders every pair to a PNG under
    ``<output>/png_<cmap>_eval_mask_detail`` using a process pool.

    Exits with an error message when the percentile range is invalid, when a
    required directory is missing, or when no files / no pairs are found.
    Duplicate keys and predictions without a ground-truth match are reported
    on stderr instead of being dropped silently.
    """
    ap = argparse.ArgumentParser(
        description="Process .npy depth predictions to high-detail PNGs exactly as they enter compute_metrics.",
    )
    ap.add_argument("--input-path", "--input-dir", dest="input_path", required=True,
                    help="Directory containing prediction .npy files")
    ap.add_argument("--output-path", "--output-dir", dest="output_path", required=True,
                    help="Directory to save processed outputs")
    ap.add_argument("--take-inverse", action="store_true",
                    help="Interpret predictions as inverse depth and convert to metric depth before alignment")
    ap.add_argument("--pred-pattern", default="*.npy",
                    help="Glob pattern for prediction files")
    ap.add_argument("--pair-mode", choices=("auto", "exact", "digits"), default="auto",
                    help="How to pair prediction and GT files by name (default: auto)")
    ap.add_argument("--gt-dir", default=GT_DIR_DEFAULT,
                    help="Ground-truth NPY directory")
    ap.add_argument("--max-depth", type=float, default=25.0,
                    help="Max distance clipping in meters for evaluation mask (default: 25.0)")
    ap.add_argument("--p-low", type=float, default=1.0,
                    help="Lower percentile for robust normalization (default: 1.0)")
    ap.add_argument("--p-high", type=float, default=99.0,
                    help="Upper percentile for robust normalization (default: 99.0)")
    ap.add_argument("--eps", type=float, default=1e-6,
                    help="Small epsilon for inverse depth conversion")
    ap.add_argument("--cmap", choices=["gray16", "viridis", "magma", "inferno", "plasma"], default="gray16",
                    help="Colormap for the output image (default: gray16)")
    ap.add_argument("--workers", type=int, default=default_worker_count(),
                    help="Number of worker processes (default: os.cpu_count() - 1, minimum 1)")
    args = ap.parse_args()

    # Reject bad percentiles up front: out-of-range values would only fail
    # later inside the workers, and an inverted range yields all-black images.
    if not 0.0 <= args.p_low < args.p_high <= 100.0:
        ap.error(
            "--p-low and --p-high must satisfy 0 <= p-low < p-high <= 100 "
            f"(got {args.p_low} and {args.p_high})"
        )

    input_dir = Path(args.input_path)
    output_dir = Path(args.output_path)
    gt_dir = Path(args.gt_dir)

    # is_dir() is already False for paths that do not exist.
    if not input_dir.is_dir():
        raise SystemExit(f"Input directory does not exist: {input_dir}")
    if not gt_dir.is_dir():
        raise SystemExit(f"GT directory does not exist: {gt_dir}")

    pred_files = sorted(input_dir.glob(args.pred_pattern))
    if not pred_files:
        raise SystemExit(f"No prediction files found in {input_dir} with pattern {args.pred_pattern}")

    gt_files = sorted(gt_dir.glob("*.npy"))
    if not gt_files:
        raise SystemExit(f"No GT NPY files found in {gt_dir}")

    pairs, matched_mode, pred_keys, gt_keys, pred_dups, gt_dups = _collect_pairs(
        pred_files,
        gt_files,
        pair_mode=args.pair_mode,
    )
    if not pairs:
        raise SystemExit("No matching prediction/GT pairs found.")

    print(f"Pairing mode: {matched_mode} ({len(pairs)} pairs)")

    # Only the first file per key is kept, so make dropped files visible.
    for label, dups in (("prediction", pred_dups), ("GT", gt_dups)):
        if dups:
            print(
                f"Warning: {len(dups)} duplicate {label} key(s), first file kept: "
                f"{', '.join(dups)}",
                file=sys.stderr,
            )

    unmatched_pred = sorted(pred_keys - gt_keys)
    if unmatched_pred:
        print(
            f"Warning: {len(unmatched_pred)} prediction key(s) have no GT match "
            f"and are skipped: {', '.join(unmatched_pred)}",
            file=sys.stderr,
        )

    png_dir = output_dir / f"png_{args.cmap}_eval_mask_detail"

    # Paths are passed as strings inside plain tuples, one tuple per pair.
    tasks = [
        (
            str(pred_path), str(gt_path),
            args.take_inverse, args.eps,
            args.max_depth, str(png_dir),
            args.p_low, args.p_high, args.cmap,
        )
        for pred_path, gt_path in pairs
    ]

    workers = max(1, int(args.workers))

    def _on_progress(completed: int, total: int) -> None:
        # Report every 50 items and once more at the end.
        if completed % 50 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    run_process_pool_as_completed(
        _process_single_pair, tasks, workers=workers,
        progress_callback=_on_progress,
    )

    print(f"Saved {len(pairs)} outputs to {png_dir}")


if __name__ == "__main__":
    main()