import argparse
from pathlib import Path
import sys
import re
import numpy as np
import cv2

# Put this script's own directory on the import path so the `core.*` imports below
# resolve when the file is run directly (presumably `core/` sits next to this script).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.io import load_depth_npy, resize_to_match
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Default for --gt-dir. This is a placeholder path; pass --gt-dir (or edit it) for a real run.
GT_DIR_DEFAULT = "path/to/cheri-dataset/depth"
# Default for --max-depth: GT pixels deeper than this are excluded from the alignment fit.
# Units follow the GT .npy files (presumably metres — confirm against the dataset).
MAX_DEPTH_DEFAULT = 17.0

def _valid_gt_mask(gt: np.ndarray) -> np.ndarray:
    """Mask strictly filtering metric GT space. Zeroes and infinities usually correlate to sky or missing sensor records."""
    return np.isfinite(gt) & (gt > 0)

def _normalize_depth_percentile(
    depth: np.ndarray,
    mask: np.ndarray,
    p_low: float = 2.0,
    p_high: float = 98.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Transforms Cheri's raw continuous metric depth distributions into bounded visuals [0.0 - 1.0].
    Isolating data via percentiles (e.g. 2nd to 98th) ensures that erratic spatial reflections
    or extreme distant data points do not wash out the colormap variation of the core scene.
    """
    depth = depth.astype(np.float64)
    valid = np.isfinite(depth) & mask

    if not np.any(valid):
        return np.zeros_like(depth, dtype=np.float64), valid

    # Calculate spatial percentile distributions filtering sensor extremes
    vals = depth[valid]
    lo = float(np.percentile(vals, p_low))
    hi = float(np.percentile(vals, p_high))

    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        lo = float(np.min(vals))
        hi = float(np.max(vals))
        if hi <= lo:
            return np.zeros_like(depth, dtype=np.float64), valid

    # Clip mathematically preventing bounds overflows natively truncating extremes
    clipped = np.clip(depth, lo, hi)
    norm = np.zeros_like(depth, dtype=np.float64)
    norm[valid] = (clipped[valid] - lo) / (hi - lo)

    valid_after = valid & np.isfinite(norm)
    return np.clip(norm, 0.0, 1.0), valid_after

def _save_outputs(
    depth: np.ndarray,
    mask: np.ndarray,
    out_png: Path,
    p_low: float,
    p_high: float,
    cmap: str = "magma",
    invert_colormap: bool = False,
) -> None:
    """Render *depth* to a PNG at *out_png*, drawing only the pixels selected by *mask*.

    Values are normalized to [0, 1] with ``_normalize_depth_percentile``, clipped to
    the ``p_low`` / ``p_high`` percentiles of the masked pixels.

    * With ``cmap="gray16"`` the result is written as a single-channel 16-bit PNG.
    * Otherwise it is quantized to 8 bits and passed through the named OpenCV
      colormap, and pixels outside the valid mask are painted black.

    Args:
        depth: Depth map to render. A float64 copy is made so masked pixels can hold NaN.
        mask: Pixels to render; coerced to bool.
        out_png: Destination file. Parent directories are created if needed.
        p_low: Lower percentile used as the normalization bound.
        p_high: Upper percentile used as the normalization bound.
        cmap: ``"gray16"`` or one of ``viridis``, ``magma``, ``inferno``, ``plasma``,
            ``turbo``.
        invert_colormap: If True, map normalized values ``v -> 1 - v`` before rendering.

    Raises:
        ValueError: If *cmap* is not supported.
        OSError: If OpenCV reports that the image could not be written.
    """
    colormaps = {
        "viridis": cv2.COLORMAP_VIRIDIS,
        "magma": cv2.COLORMAP_MAGMA,
        "inferno": cv2.COLORMAP_INFERNO,
        "plasma": cv2.COLORMAP_PLASMA,
        "turbo": cv2.COLORMAP_TURBO,
    }
    if cmap != "gray16" and cmap not in colormaps:
        raise ValueError(f"Unsupported cmap {cmap!r}; expected 'gray16' or one of {sorted(colormaps)}")

    out_png.parent.mkdir(parents=True, exist_ok=True)

    # Work on a float64 copy: writing NaN into an integer-typed depth array would raise.
    # Coerce the mask to bool, because `~mask` on an integer mask is a bitwise NOT.
    render_mask = np.asarray(mask, dtype=bool)
    depth_to_save = np.array(depth, dtype=np.float64)
    depth_to_save[~render_mask] = np.nan

    norm, valid_mask = _normalize_depth_percentile(depth_to_save, render_mask, p_low=p_low, p_high=p_high)
    if invert_colormap:
        # `norm` is a freshly allocated array, so flipping it in place is safe.
        norm[valid_mask] = 1.0 - norm[valid_mask]

    if cmap == "gray16":
        # 16-bit grayscale keeps more quantization levels than the 8-bit colormaps.
        image = np.zeros_like(norm, dtype=np.uint16)
        image[valid_mask] = (norm[valid_mask] * 65535).astype(np.uint16)
    else:
        levels = np.zeros_like(norm, dtype=np.uint8)
        levels[valid_mask] = (norm[valid_mask] * 255).astype(np.uint8)
        image = cv2.applyColorMap(levels, colormaps[cmap])
        # Paint invalid pixels black; otherwise they would take the colormap's level-0 color.
        image[~valid_mask] = [0, 0, 0]

    # cv2.imwrite signals failure through its return value rather than an exception.
    if not cv2.imwrite(str(out_png), image):
        raise OSError(f"cv2.imwrite failed to write {out_png}")

def _numeric_stem_key(stem: str) -> str | None:
    """Parses sequential identifiers dynamically supporting automated frame aggregation logic."""
    parts = re.findall(r"\d+", stem)
    return parts[-1] if parts else None

def _build_unique_map(files: list[Path], key_fn) -> tuple[dict[str, Path], list[str]]:
    """Generates filtering dictionary bounds enforcing deduplicated frame sets."""
    out: dict[str, Path] = {}
    duplicates: list[str] = []
    for p in files:
        key = key_fn(p.stem)
        if key is None:
            continue
        if key in out:
            duplicates.append(key)
            continue
        out[key] = p
    return out, sorted(set(duplicates))

def _collect_pairs(
    pred_files: list[Path],
    gt_files: list[Path],
    pair_mode: str,
) -> tuple[list[tuple[Path, Path]], str, set[str], set[str], list[str], list[str]]:
    """Match prediction files to GT files by filename stem.

    Modes:
        "exact": pair files whose stems are identical.
        "digits": pair files on the last run of digits in the stem, so e.g.
            ``frame_0042_pred`` matches ``0042``.
        "auto": use "exact" if at least one stem matches exactly, otherwise "digits".

    Any other *pair_mode* value behaves like "digits". On each side only the first file
    seen for a key is kept; later files with the same key are reported as duplicates.

    Returns:
        ``(pairs sorted by key, mode actually used, prediction keys, GT keys,
        duplicate prediction keys, duplicate GT keys)``.
    """

    def index_both(key_fn):
        # Index both sides with the same key function and intersect on the keys.
        pred_map, pred_repeats = _build_unique_map(pred_files, key_fn)
        gt_map, gt_repeats = _build_unique_map(gt_files, key_fn)
        shared = sorted(set(pred_map) & set(gt_map))
        matched = [(pred_map[k], gt_map[k]) for k in shared]
        return matched, set(pred_map), set(gt_map), pred_repeats, gt_repeats

    if pair_mode in {"auto", "exact"}:
        matched, pred_keys, gt_keys, pred_repeats, gt_repeats = index_both(lambda stem: stem)
        if matched or pair_mode == "exact":
            return matched, "exact", pred_keys, gt_keys, pred_repeats, gt_repeats

    matched, pred_keys, gt_keys, pred_repeats, gt_repeats = index_both(_numeric_stem_key)
    return matched, "digits", pred_keys, gt_keys, pred_repeats, gt_repeats

def _process_single_pair(task: tuple) -> None:
    """Align one prediction to its GT depth map and render the aligned result as a PNG.

    *task* is a flat tuple holding, in order:
    prediction path, GT path, take_inverse flag, eps, max_depth, PNG output directory,
    p_low, p_high, colormap name, dense_vis flag.

    Processing steps:

    1. Load the prediction and resize it to match the GT with ``resize_to_match``.
    2. If ``take_inverse`` is set, convert the prediction with ``inverse_to_depth``.
    3. Fit a scale/shift with ``least_squares_align`` over pixels where the GT is
       finite, positive and <= ``max_depth``, and the prediction is valid.
    4. Write the aligned map to ``<png_dir>/<prediction stem>.png``:

       * default: only the pixels used for the fit are drawn;
       * ``dense_vis``: every pixel with finite GT, a valid prediction and an aligned
         value in ``[0, max_depth]`` is drawn. This also includes pixels whose GT is
         non-positive or deeper than ``max_depth``.
    """
    (pred_path_str, gt_path_str, take_inverse, eps,
     max_depth, png_dir_str, p_low, p_high, cmap, dense_vis) = task

    pred_path = Path(pred_path_str)
    target_png = Path(png_dir_str) / f"{pred_path.stem}.png"

    # Bring the prediction onto the GT pixel grid before any per-pixel comparison.
    raw_pred = load_depth_npy(str(pred_path))
    raw_gt = load_depth_npy(str(Path(gt_path_str)))
    raw_pred = resize_to_match(raw_pred, raw_gt)

    if not take_inverse:
        pred_depth = raw_pred.astype(np.float64)
        pred_ok = np.isfinite(pred_depth)
    else:
        pred_depth, pred_ok = inverse_to_depth(raw_pred, eps=eps)

    gt_metric = raw_gt.astype(np.float64)
    fit_mask = _valid_gt_mask(gt_metric) & pred_ok & (gt_metric <= max_depth)

    aligned, scale, _shift = least_squares_align(pred_depth, gt_metric, fit_mask)

    if dense_vis:
        draw_mask = (
            np.isfinite(gt_metric)
            & pred_ok
            & np.isfinite(aligned)
            & (aligned >= 0.0)
            & (aligned <= max_depth)
        )
    else:
        draw_mask = fit_mask

    # A negative fitted scale flips the colormap for this frame.
    # NOTE(review): `aligned` is already fitted to GT depth, so flipping only some frames
    # may make near/far colors inconsistent across a sequence — confirm this is intended.
    _save_outputs(
        aligned,
        draw_mask,
        target_png,
        p_low=p_low,
        p_high=p_high,
        cmap=cmap,
        invert_colormap=(scale < 0.0),
    )

def main() -> None:
    """Parse CLI arguments, pair predictions with GT depth maps, and render PNGs.

    Each matched (prediction, GT) pair is processed by ``_process_single_pair`` in a
    process pool. Images are written to ``<output-path>/png_<cmap>_<dense|eval_mask>/``.
    Warnings about predictions that were not paired, and about duplicate keys, go to
    stderr.

    Raises:
        SystemExit: On invalid arguments, missing input/GT directories, no input
            files, or when no prediction/GT pairs can be matched.
    """
    ap = argparse.ArgumentParser(description="Map arrays systematically processing logical inferences for qualitative review.")
    ap.add_argument("--input-path", "--input-dir", dest="input_path", required=True)
    ap.add_argument("--output-path", "--output-dir", dest="output_path", required=True)
    ap.add_argument("--take-inverse", action="store_true")
    ap.add_argument("--pred-pattern", default="*.npy")
    ap.add_argument("--pair-mode", choices=("auto", "exact", "digits"), default="auto")
    ap.add_argument("--gt-dir", default=GT_DIR_DEFAULT)
    ap.add_argument("--max-depth", type=float, default=MAX_DEPTH_DEFAULT)
    ap.add_argument("--p-low", type=float, default=2.0)
    ap.add_argument("--p-high", type=float, default=98.0)
    ap.add_argument("--eps", type=float, default=1e-6)
    ap.add_argument("--cmap", choices=["gray16", "viridis", "magma", "inferno", "plasma", "turbo"], default="magma")
    ap.add_argument("--dense-vis", action="store_true")
    ap.add_argument("--workers", type=int, default=default_worker_count())
    args = ap.parse_args()

    # Validate numeric options up front.
    # - np.percentile rejects values outside [0, 100], so a bad value would make every
    #   worker fail.
    # - p_low >= p_high would silently degrade to min/max normalization.
    # The chained comparisons are also False for NaN, so NaN is rejected too.
    if not (0.0 <= args.p_low < args.p_high <= 100.0):
        ap.error(f"require 0 <= --p-low < --p-high <= 100 (got {args.p_low}, {args.p_high})")
    if not (args.max_depth > 0.0):
        ap.error(f"--max-depth must be positive (got {args.max_depth})")

    input_dir = Path(args.input_path)
    output_dir = Path(args.output_path)
    gt_dir = Path(args.gt_dir)

    # is_dir() is False for missing paths, so it also covers the existence check.
    if not input_dir.is_dir():
        raise SystemExit(f"Input directory does not exist: {input_dir}")
    if not gt_dir.is_dir():
        raise SystemExit(f"GT directory does not exist: {gt_dir}")

    pred_files = sorted(input_dir.glob(args.pred_pattern))
    if not pred_files:
        raise SystemExit(f"No prediction files found in {input_dir} matching {args.pred_pattern!r}")

    gt_files = sorted(gt_dir.glob("*.npy"))
    if not gt_files:
        raise SystemExit(f"No GT NPY files found in {gt_dir}")

    pairs, matched_mode, _pred_keys, _gt_keys, pred_dups, gt_dups = _collect_pairs(
        pred_files, gt_files, pair_mode=args.pair_mode,
    )
    if not pairs:
        raise SystemExit("No matching prediction/GT pairs found.")

    # Report what pairing discarded. Without this, unmatched files and duplicate keys
    # disappear silently and the final "Saved N" count is the only hint.
    print(f"Matched {len(pairs)} prediction/GT pairs ({matched_mode} mode).")
    skipped = len(pred_files) - len(pairs)
    if skipped:
        print(
            f"Warning: {skipped} prediction file(s) were not paired with a GT file and will be skipped.",
            file=sys.stderr,
        )
    for label, dups in (("prediction", pred_dups), ("GT", gt_dups)):
        if dups:
            preview = ", ".join(dups[:5]) + (", ..." if len(dups) > 5 else "")
            print(
                f"Warning: {len(dups)} duplicate {label} key(s) in {matched_mode} mode; "
                f"only the first file per key is used: {preview}",
                file=sys.stderr,
            )

    suffix = "dense" if args.dense_vis else "eval_mask"
    png_dir = output_dir / f"png_{args.cmap}_{suffix}"

    # Tasks are flat tuples of plain values (paths as str) for the worker processes.
    tasks = [
        (str(pred_path), str(gt_path), args.take_inverse, args.eps,
         args.max_depth, str(png_dir), args.p_low, args.p_high, args.cmap, args.dense_vis)
        for pred_path, gt_path in pairs
    ]

    workers = max(1, int(args.workers))

    # Print progress every 50 items (and at the end) to keep the output short.
    def _on_progress(completed: int, total: int) -> None:
        if completed % 50 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    run_process_pool_as_completed(_process_single_pair, tasks, workers=workers, progress_callback=_on_progress)
    print(f"\nSaved {len(pairs)} outputs to {png_dir}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
