#!/usr/bin/env python3
"""
gt_visualize.py — Visualize Ground-Truth map sequences.

Renders ground-truth depth maps as colour-mapped PNGs, applying the same validity
masks and depth caps as the evaluation scripts so that GT images can be compared
side by side with prediction visualizations.
Outputs to: Outputs/GroundTruths/{DatasetName}/
"""

import os

# Pin the native math/threading backends (OpenMP, OpenBLAS, MKL, vecLib,
# numexpr) to one thread each. This runs before numpy and cv2 are imported
# below so the libraries see these values when they load; the script gets its
# parallelism from worker processes instead of intra-library threads.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        "1",
    )
)

import argparse
import inspect
import sys
from pathlib import Path

import numpy as np
import cv2

# Per the OpenCV documentation, 0 disables OpenCV's own threading optimisations
# so its functions run sequentially, matching the single-thread env settings.
cv2.setNumThreads(0)

# Directory containing this script. It is appended to sys.path so the local
# `core` package imported below can be found.
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.io import load_depth_npy, load_depth_png, read_pfm
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Root directory for all rendered GT images; one subdirectory per dataset.
OUTPUT_ROOT = SCRIPT_DIR / "Outputs" / "GroundTruths"

# Dataset locations. These are placeholder paths — point them at the local
# copies of each dataset before running.
LUNARSIM_GT_DIR = "path/to/lunarsim-dataset/LunarSim-Final/depth"
LUSNAR_ROOT     = "path/to/lusnar-dataset"
CHANGE_GT_DIR   = "path/to/change-dataset/depth"
CHERI_GT_DIR    = "path/to/cheri-dataset/depth"
ETNA_GT_DIR     = "path/to/etna-dataset/subsampled_datasets/Dataset1/depth"
S3LI_ROOT       = "path/to/s3li_dataset/clean_benchmark_dataset_latest"

# S3LI sequence names scanned by visualize_s3li_gt.
S3LI_SEQUENCES = ["crater", "crater_inout", "landmarks", "loops",
                   "mapping", "traverse_1", "traverse_2"]

# GT validity limits applied before colour mapping; intended to mirror the
# evaluation scripts' masking. Units are the GT files' native depth units
# (presumably metres — confirm). Pixels above a *_MAX_DEPTH are masked out,
# and LuSNAR values at or above LUSNAR_SKY_SENTINEL are treated as sky.
# NOTE(review): CHERI_MAX_DEPTH is not referenced elsewhere in this file; the
# Cheri cap comes only from --cheri-max-depth (default: no cap). Confirm
# whether the evaluation caps Cheri at this value.
LUSNAR_SKY_SENTINEL = 65500.0
LUSNAR_MAX_DEPTH    = 50.0
CHANGE_MAX_DEPTH    = 25.0
CHERI_MAX_DEPTH     = 17.0
ETNA_MAX_DEPTH      = 15.0
S3LI_MAX_DEPTH      = 30.0

# Colormap names accepted by --cmap, mapped to their OpenCV constants.
_OPENCV_CMAPS = {
    "viridis": cv2.COLORMAP_VIRIDIS,
    "magma":   cv2.COLORMAP_MAGMA,
    "inferno": cv2.COLORMAP_INFERNO,
    "plasma":  cv2.COLORMAP_PLASMA,
    "jet":     cv2.COLORMAP_JET,
    "turbo":   cv2.COLORMAP_TURBO,
}

def _normalize_depth(depth: np.ndarray, mask: np.ndarray, p_low: float = 2.0, p_high: float = 98.0) -> tuple[np.ndarray, np.ndarray]:
    """
    Linearly rescale depth to [0.0, 1.0] between two robust percentiles.

    The ``p_low`` / ``p_high`` percentiles are computed over valid pixels only
    (finite and selected by ``mask``). Values outside that span are clipped to
    0.0 / 1.0, so a few extreme samples do not compress the rest of the range.

    Args:
        depth: Depth map (real numeric array or array-like); read as float64.
        mask: Validity mask with the same shape as ``depth``. Any dtype is
            accepted and interpreted as boolean (non-zero = valid).
        p_low: Lower percentile, mapped to 0.0.
        p_high: Upper percentile, mapped to 1.0.

    Returns:
        ``(norm, valid)``. ``norm`` is a float64 array shaped like ``depth``,
        with values in [0.0, 1.0] and 0.0 outside ``valid``. ``valid`` is the
        boolean mask of pixels that received a normalized value. ``norm`` is
        all zeros when there are no valid pixels, or when the percentile span
        is empty or non-finite (e.g. constant depth).
    """
    depth = np.asarray(depth, dtype=np.float64)
    # Coerce to bool: an integer mask combined with `&` would stay integer and
    # would then act as a fancy *index* array below instead of a selection mask.
    mask = np.asarray(mask, dtype=bool)
    valid = np.isfinite(depth) & mask

    if not np.any(valid):
        return np.zeros_like(depth), valid

    vals = depth[valid]
    lo = float(np.percentile(vals, p_low))
    hi = float(np.percentile(vals, p_high))

    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        # Degenerate span (e.g. all valid pixels share one value): nothing to stretch.
        return np.zeros_like(depth), valid

    # Only valid pixels are clipped/scaled, so NaNs elsewhere never enter the math.
    norm = np.zeros_like(depth)
    norm[valid] = (np.clip(depth[valid], lo, hi) - lo) / (hi - lo)

    # Defensive: with finite lo < hi this equals `valid`, but it keeps the
    # returned mask honest should the arithmetic above ever yield NaN/inf.
    valid_after = valid & np.isfinite(norm)
    return np.clip(norm, 0.0, 1.0), valid_after

def _save_visualization(
    depth: np.ndarray,
    mask: np.ndarray,
    out_path: Path,
    cmap_name: str = "inferno",
    p_low: float = 2.0,
    p_high: float = 98.0,
    point_size: int = 1,
    invert_colormap: bool = False,
) -> None:
    """
    Colour-map a depth map and write it to ``out_path``.

    Args:
        depth: Depth map; normalized with :func:`_normalize_depth`.
        mask: Validity mask, same shape as ``depth``. Pixels outside it, or
            non-finite in ``depth``, are written as black.
        out_path: Destination image; parent directories are created. OpenCV
            picks the encoder from the file extension.
        cmap_name: Key into ``_OPENCV_CMAPS``; unknown names fall back to inferno.
        p_low, p_high: Percentiles forwarded to :func:`_normalize_depth`.
        point_size: When > 1, valid pixels are grown with an elliptical
            ``point_size`` x ``point_size`` dilation so sparse samples stay visible.
        invert_colormap: Map small depths to the high end of the colormap.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    norm_depth, valid_mask = _normalize_depth(depth, mask, p_low, p_high)

    # Inversion puts near pixels at the bright end of the colormap.
    if invert_colormap:
        norm_depth = norm_depth.copy()
        norm_depth[valid_mask] = 1.0 - norm_depth[valid_mask]

    # Quantize to colormap indices [0, 255]; invalid pixels stay at 0.
    out_8bit = np.zeros_like(norm_depth, dtype=np.uint8)
    out_8bit[valid_mask] = (norm_depth[valid_mask] * 255).astype(np.uint8)

    if point_size > 1:
        # Dilate the index map and the mask *before* colour mapping. Dilating
        # the BGR image instead takes a per-channel max, which blends adjacent
        # points into colours that are not on the colormap. Invalid pixels hold
        # index 0, so they never exceed a valid neighbour's index.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_size, point_size))
        out_8bit = cv2.dilate(out_8bit, kernel)
        valid_mask = cv2.dilate(valid_mask.astype(np.uint8), kernel).astype(bool)

    cmap_code = _OPENCV_CMAPS.get(cmap_name, cv2.COLORMAP_INFERNO)
    colored = cv2.applyColorMap(out_8bit, cmap_code)

    # applyColorMap also colours index 0; force masked-out pixels to pure black.
    colored[~valid_mask] = [0, 0, 0]

    # cv2.imwrite signals failure via its return value rather than raising.
    if not cv2.imwrite(str(out_path), colored):
        raise OSError(f"cv2.imwrite failed to write {out_path}")

def _process_lunarsim_single(task: tuple) -> None:
    """
    Render one LunarSim PNG depth frame.

    ``task`` is ``(gt_path, out_path, cmap, p_low, p_high)`` as built by
    :func:`visualize_lunarsim_gt`. A pixel is kept when it is strictly positive
    and finite; LunarSim applies no upper depth cap.
    """
    src, dst, cmap_name, lo_pct, hi_pct = task
    raw = load_depth_png(src)
    depth = raw.astype(np.float64)
    keep = (raw > 0) & np.isfinite(depth)
    _save_visualization(depth, keep, Path(dst), cmap_name=cmap_name, p_low=lo_pct, p_high=hi_pct)

def visualize_lunarsim_gt(workers: int, cmap: str, p_low: float, p_high: float) -> None:
    """
    Queue every LunarSim PNG ground-truth frame for rendering.

    LunarSim GT is masked only by positivity/finiteness (no depth ceiling).
    Prints a ``[SKIP]`` line and returns when the GT directory is missing or
    contains no ``*.png`` files.

    Args:
        workers: Number of worker processes for the pool.
        cmap: Key into ``_OPENCV_CMAPS``.
        p_low, p_high: Percentiles used for colour normalization.
    """
    gt_dir = Path(LUNARSIM_GT_DIR)
    out_dir = OUTPUT_ROOT / "LunarSim"
    if not gt_dir.exists():
        print(f"  [SKIP] LunarSim GT dir not found: {gt_dir}")
        return

    gt_files = sorted(gt_dir.glob("*.png"))
    if not gt_files:
        print(f"  [SKIP] LunarSim: no *.png GT files in {gt_dir}")
        return

    tasks = [(str(p), str(out_dir / f"{p.stem}.png"), cmap, p_low, p_high) for p in gt_files]
    print(f"  LunarSim: {len(tasks)} GT frames → {out_dir}")
    _run_pool(tasks, _process_lunarsim_single, workers)

def _process_lusnar_single(task: tuple) -> None:
    """
    Render one LuSNAR PFM depth frame.

    ``task`` is ``(gt_path, out_path, cmap, p_low, p_high)``. A pixel is kept
    when it is finite, strictly positive, below ``LUSNAR_SKY_SENTINEL`` (sky)
    and no deeper than ``LUSNAR_MAX_DEPTH``.
    """
    src, dst, cmap_name, lo_pct, hi_pct = task
    depth = read_pfm(src).astype(np.float64)
    in_range = (depth > 0) & (depth < LUSNAR_SKY_SENTINEL) & (depth <= LUSNAR_MAX_DEPTH)
    keep = np.isfinite(depth) & in_range
    _save_visualization(depth, keep, Path(dst), cmap_name=cmap_name, p_low=lo_pct, p_high=hi_pct)


def visualize_lusnar_gt(workers: int, cmap: str, p_low: float, p_high: float) -> None:
    """
    Queue LuSNAR PFM ground-truth frames from sequences Moon_1 .. Moon_9.

    Frames are read from ``<LUSNAR_ROOT>/Moon_<k>/image0/depth/*.pfm``; missing
    sequences are skipped. Masking (sky sentinel plus the explicit
    ``LUSNAR_MAX_DEPTH`` cap) is applied per frame by
    :func:`_process_lusnar_single`. The summary line reports how many
    sequences actually contributed frames. Prints a ``[SKIP]`` line when the
    root is missing or no frames are found.
    """
    root = Path(LUSNAR_ROOT)
    out_dir = OUTPUT_ROOT / "LuSNAR"
    if not root.exists():
        print(f"  [SKIP] LuSNAR root not found: {root}")
        return

    tasks = []
    seqs_found = 0
    for seq in range(1, 10):
        seq_name = f"Moon_{seq}"
        gt_dir = root / seq_name / "image0" / "depth"
        if not gt_dir.exists():
            continue
        seq_files = sorted(gt_dir.glob("*.pfm"))
        if seq_files:
            seqs_found += 1
        for p in seq_files:
            out_path = out_dir / seq_name / f"{p.stem}.png"
            tasks.append((str(p), str(out_path), cmap, p_low, p_high))

    if not tasks:
        print(f"  [SKIP] LuSNAR: no *.pfm GT files under {root}/Moon_*/image0/depth")
        return
    print(f"  LuSNAR: {len(tasks)} GT frames ({seqs_found} sequences) → {out_dir}")
    _run_pool(tasks, _process_lusnar_single, workers)

def _process_change_single(task: tuple) -> None:
    """
    Render one Chang'e ``.npy`` depth frame.

    ``task`` is ``(gt_path, out_path, cmap, p_low, p_high)``. A pixel is kept
    when it is finite, strictly positive and at most ``CHANGE_MAX_DEPTH``.
    """
    src, dst, cmap_name, lo_pct, hi_pct = task
    depth = load_depth_npy(src).astype(np.float64)
    in_range = (depth > 0) & (depth <= CHANGE_MAX_DEPTH)
    keep = np.isfinite(depth) & in_range
    _save_visualization(depth, keep, Path(dst), cmap_name=cmap_name, p_low=lo_pct, p_high=hi_pct)

def visualize_change_gt(workers: int, cmap: str, p_low: float, p_high: float) -> None:
    """
    Queue every Chang'e ``*.npy`` ground-truth frame for rendering.

    Pixels beyond ``CHANGE_MAX_DEPTH`` are masked out per frame by
    :func:`_process_change_single`. Prints a ``[SKIP]`` line and returns when
    the GT directory is missing or holds no ``*.npy`` files.
    """
    gt_dir = Path(CHANGE_GT_DIR)
    out_dir = OUTPUT_ROOT / "Change"
    if not gt_dir.exists():
        print(f"  [SKIP] Chang'e GT dir not found: {gt_dir}")
        return
    gt_files = sorted(gt_dir.glob("*.npy"))
    if not gt_files:
        print(f"  [SKIP] Chang'e: no *.npy GT files in {gt_dir}")
        return

    tasks = [(str(p), str(out_dir / f"{p.stem}.png"), cmap, p_low, p_high) for p in gt_files]
    print(f"  Chang'e: {len(tasks)} GT frames → {out_dir}")
    _run_pool(tasks, _process_change_single, workers)


def _process_cheri_single(task: tuple) -> None:
    """
    Render one Cheri ``.npy`` depth frame.

    ``task`` is ``(gt_path, out_path, cmap, p_low, p_high, invert, max_depth)``.
    A pixel is kept when it is finite and strictly positive and, if
    ``max_depth`` is not None, at most ``max_depth``. ``invert`` is forwarded
    as ``invert_colormap`` so nearer pixels take the high end of the colormap.
    """
    src, dst, cmap_name, lo_pct, hi_pct, invert, max_depth = task
    depth = load_depth_npy(src).astype(np.float64)
    keep = np.isfinite(depth) & (depth > 0)
    if max_depth is not None:
        keep &= depth <= max_depth
    _save_visualization(
        depth,
        keep,
        Path(dst),
        cmap_name=cmap_name,
        p_low=lo_pct,
        p_high=hi_pct,
        invert_colormap=invert,
    )

def visualize_cheri_gt(workers: int, cmap: str, p_low: float, p_high: float, invert_cheri: bool = False, cheri_max_depth: float | None = None) -> None:
    """
    Queue every Cheri ``*.npy`` ground-truth frame for rendering.

    Args:
        workers, cmap, p_low, p_high: As for the other ``visualize_*`` functions.
        invert_cheri: Flip the normalized scale so nearer pixels land on the
            high end of the colormap (the CLI describes this as near = brighter).
        cheri_max_depth: Optional upper depth cap; ``None`` keeps every finite,
            positive pixel.

    Prints a ``[SKIP]`` line and returns when the GT directory is missing or
    holds no ``*.npy`` files.

    NOTE(review): the module constant CHERI_MAX_DEPTH is not used here, so with
    the CLI default the Cheri GT is uncapped. Confirm this matches the
    evaluation masking.
    """
    gt_dir = Path(CHERI_GT_DIR)
    out_dir = OUTPUT_ROOT / "Cheri"
    if not gt_dir.exists():
        print(f"  [SKIP] Cheri GT dir not found: {gt_dir}")
        return
    gt_files = sorted(gt_dir.glob("*.npy"))
    if not gt_files:
        print(f"  [SKIP] Cheri: no *.npy GT files in {gt_dir}")
        return

    tasks = [(str(p), str(out_dir / f"{p.stem}.png"), cmap, p_low, p_high, invert_cheri, cheri_max_depth) for p in gt_files]
    print(f"  Cheri: {len(tasks)} GT frames → {out_dir}")
    _run_pool(tasks, _process_cheri_single, workers)

def _process_etna_single(task: tuple) -> None:
    """
    Render one Etna ``.npy`` depth frame.

    ``task`` is ``(gt_path, out_path, cmap, p_low, p_high)``. A pixel is kept
    when it is finite, strictly positive and at most ``ETNA_MAX_DEPTH``.
    """
    src, dst, cmap_name, lo_pct, hi_pct = task
    depth = load_depth_npy(src).astype(np.float64)
    in_range = (depth > 0) & (depth <= ETNA_MAX_DEPTH)
    keep = np.isfinite(depth) & in_range
    _save_visualization(depth, keep, Path(dst), cmap_name=cmap_name, p_low=lo_pct, p_high=hi_pct)

def visualize_etna_gt(workers: int, cmap: str, p_low: float, p_high: float) -> None:
    """
    Queue every Etna ``*.npy`` ground-truth frame for rendering.

    Pixels beyond ``ETNA_MAX_DEPTH`` are masked out per frame by
    :func:`_process_etna_single`. Prints a ``[SKIP]`` line and returns when the
    GT directory is missing or holds no ``*.npy`` files.
    """
    gt_dir = Path(ETNA_GT_DIR)
    out_dir = OUTPUT_ROOT / "Etna"
    if not gt_dir.exists():
        print(f"  [SKIP] Etna GT dir not found: {gt_dir}")
        return
    gt_files = sorted(gt_dir.glob("*.npy"))
    if not gt_files:
        print(f"  [SKIP] Etna: no *.npy GT files in {gt_dir}")
        return

    tasks = [(str(p), str(out_dir / f"{p.stem}.png"), cmap, p_low, p_high) for p in gt_files]
    print(f"  Etna: {len(tasks)} GT frames → {out_dir}")
    _run_pool(tasks, _process_etna_single, workers)

def _process_s3li_single(task: tuple) -> None:
    """
    Render one sparse S3LI depth frame with enlarged points.

    ``task`` is ``(gt_path, out_path, cmap, p_low, p_high, point_size)``. A
    pixel is kept when it is finite, strictly positive and at most
    ``S3LI_MAX_DEPTH``. ``point_size`` is forwarded to the renderer, which
    dilates the sparse samples so they remain visible.
    """
    src, dst, cmap_name, lo_pct, hi_pct, dot_size = task
    depth = load_depth_npy(src).astype(np.float64)
    in_range = (depth > 0) & (depth <= S3LI_MAX_DEPTH)
    keep = np.isfinite(depth) & in_range
    _save_visualization(
        depth,
        keep,
        Path(dst),
        cmap_name=cmap_name,
        p_low=lo_pct,
        p_high=hi_pct,
        point_size=dot_size,
    )

def visualize_s3li_gt(workers: int, cmap: str, p_low: float, p_high: float, s3li_point_size: int = 3) -> None:
    """
    Queue S3LI ground-truth frames for rendering.

    S3LI GT is sparse, so each valid pixel is dilated to ``s3li_point_size``
    (clamped to at least 1) to stay visible. For each name in
    ``S3LI_SEQUENCES`` the first existing directory among
    ``<root>/depth/<seq>``, ``<root>/<seq>/depth`` and ``<root>/<seq>`` is
    scanned for ``*.npy`` frames. Prints a ``[SKIP]`` line when the root is
    missing or no frames are found.
    """
    s3li_point_size = max(1, int(s3li_point_size))
    root = Path(S3LI_ROOT)
    out_dir = OUTPUT_ROOT / "S3LI"
    if not root.exists():
        print(f"  [SKIP] S3LI root not found: {root}")
        return

    tasks = []
    for seq in S3LI_SEQUENCES:
        # Supported directory layouts, tried in order (this is not a recursive search).
        candidates = (root / "depth" / seq, root / seq / "depth", root / seq)
        gt_dir = next((d for d in candidates if d.exists()), None)
        if gt_dir is None:
            continue
        for p in sorted(gt_dir.glob("*.npy")):
            out_path = out_dir / seq / f"{p.stem}.png"
            tasks.append((str(p), str(out_path), cmap, p_low, p_high, s3li_point_size))

    if not tasks:
        print(f"  [SKIP] S3LI: no *.npy GT files found for {S3LI_SEQUENCES} under {root}")
        return
    print(f"  S3LI: {len(tasks)} GT frames → {out_dir}")
    _run_pool(tasks, _process_s3li_single, workers)


def _accepts_progress_callback(fn) -> bool:
    """Return True if ``fn`` can be called with a ``progress_callback=`` keyword."""
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        # Signature not introspectable: call without the optional keyword.
        return False
    return any(
        (
            p.name == "progress_callback"
            and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        )
        or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )


def _run_pool(tasks: list, worker_fn, workers: int) -> None:
    """
    Run ``worker_fn`` over ``tasks`` with ``run_process_pool_as_completed``.

    When the pool helper accepts ``progress_callback``, progress is printed
    every 50 completed tasks and at the end. Support is detected from the
    helper's signature up front. Catching ``TypeError`` from the call itself
    would also swallow TypeErrors raised by workers and silently re-run the
    whole batch.

    Args:
        tasks: One argument tuple per frame, passed to ``worker_fn``.
        worker_fn: Picklable top-level function that renders a single frame.
        workers: Number of worker processes.
    """
    def _on_progress(completed: int, total: int) -> None:
        if completed % 50 == 0 or completed == total:
            print(f"    [{completed}/{total}]")

    if _accepts_progress_callback(run_process_pool_as_completed):
        run_process_pool_as_completed(worker_fn, tasks, workers=workers, progress_callback=_on_progress)
    else:
        run_process_pool_as_completed(worker_fn, tasks, workers=workers)
    print(f"    Done ({len(tasks)} frames).")


# CLI dataset name -> visualize_* entry point. main() passes extra arguments
# to "s3li" (point size) and "cheri" (invert flag, max depth).
DATASET_FUNCS = {
    "lunarsim": visualize_lunarsim_gt,
    "lusnar":   visualize_lusnar_gt,
    "change":   visualize_change_gt,
    "cheri":    visualize_cheri_gt,
    "etna":     visualize_etna_gt,
    "s3li":     visualize_s3li_gt,
}
# Default for the positional `datasets` argument: every registered dataset.
ALL_DATASETS = list(DATASET_FUNCS.keys())

def main() -> None:
    """
    Parse CLI arguments and render GT visualizations for the requested datasets.

    Dataset names are matched case-insensitively against ``DATASET_FUNCS``.
    Repeated names are processed once. Unknown names are reported on stderr
    and skipped rather than being ignored silently.
    """
    ap = argparse.ArgumentParser(description="Visualize GT depth maps natively incorporating eval masking distributions.")
    ap.add_argument("datasets", nargs="*", default=ALL_DATASETS, help=f"Datasets to visualise (default: all). Choices: {ALL_DATASETS}")
    ap.add_argument("--cmap", default="inferno", choices=sorted(_OPENCV_CMAPS.keys()), help="OpenCV colormap name")
    ap.add_argument("--p-low", type=float, default=2.0, help="Lower percentile mapped to the start of the colormap.")
    ap.add_argument("--p-high", type=float, default=98.0, help="Upper percentile mapped to the end of the colormap.")
    ap.add_argument("--s3li-point-size", type=int, default=3, help="Dilation size (pixels) for sparse S3LI points.")
    ap.add_argument("--workers", type=int, default=default_worker_count(), help="Number of worker processes.")
    ap.add_argument("--invert-cheri", action="store_true", help="Near elements evaluate brighter yellow.")
    ap.add_argument("--cheri-max-depth", type=float, default=None, help="Optional depth cap for Cheri GT (default: none).")
    args = ap.parse_args()

    # dict.fromkeys drops repeated names while keeping command-line order.
    for name_lower in dict.fromkeys(name.lower() for name in args.datasets):
        if name_lower not in DATASET_FUNCS:
            print(f"[WARN] Unknown dataset '{name_lower}' skipped. Choices: {ALL_DATASETS}", file=sys.stderr)
            continue
        visualize = DATASET_FUNCS[name_lower]
        # Only S3LI and Cheri take dataset-specific extra arguments.
        if name_lower == "s3li":
            visualize(args.workers, args.cmap, args.p_low, args.p_high, args.s3li_point_size)
        elif name_lower == "cheri":
            visualize(args.workers, args.cmap, args.p_low, args.p_high, args.invert_cheri, args.cheri_max_depth)
        else:
            visualize(args.workers, args.cmap, args.p_low, args.p_high)

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
