import argparse
from pathlib import Path
import sys
import numpy as np
import cv2

# Add this script's own directory to the import search path before the
# `core.*` imports below run (presumably the `core` package sits next to this
# script -- confirm against the repository layout).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.evaluator import apply_max_depth_mask
from core.io import load_depth_npy, resize_to_match
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Default for --root. This is a placeholder path; main() reads ground truth
# from <root>/depth and images from <root>/images.
ROOT_DEFAULT = "path/to/s3li_dataset/"
# Sequence sub-directory names that _collect_pairs looks for under the GT,
# image and prediction roots.
SEQUENCES = ["crater", "crater_inout", "landmarks", "loops", "mapping", "traverse_1", "traverse_2"]
# Default for --max-depth (presumably metres -- confirm against the dataset).
MAX_DEPTH_DEFAULT = 30.0

def _valid_gt_mask(gt: np.ndarray) -> np.ndarray:
    """
    Return a boolean mask of usable ground-truth pixels.

    A pixel is usable when its value is finite and strictly positive; zeros,
    negatives, NaN and +/-inf are all rejected.
    """
    is_finite = np.isfinite(gt)
    is_positive = gt > 0
    return np.logical_and(is_finite, is_positive)

def _normalize_depth(
    depth: np.ndarray,
    mask: np.ndarray,
    mode: str = "percentile",
    p_low: float = 2.0,
    p_high: float = 98.0,
    fixed_min: float = 0.0,
    fixed_max: float = 30.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Linearly map the masked, finite entries of ``depth`` into [0.0, 1.0].

    Args:
        depth: Array of values to normalise (converted to float64).
        mask: Selects the pixels to normalise. Any dtype is accepted; it is
            interpreted as boolean (non-zero means selected).
        mode: ``"percentile"`` takes the range from the ``p_low`` / ``p_high``
            percentiles of the selected pixels of this array. Any other value
            is treated as ``"fixed"`` and uses ``fixed_min`` / ``fixed_max``,
            so different arrays share one scale.
        p_low, p_high: Percentiles (0-100) used in percentile mode.
        fixed_min, fixed_max: Range bounds used in fixed mode.

    Returns:
        ``(norm, valid)``. ``norm`` is a float64 array shaped like ``depth``
        with values clipped to [0, 1]; pixels that are not valid are 0.
        ``valid`` is the boolean mask of pixels that were normalised. If no
        pixel is selected, or the range is non-finite or empty
        (``hi <= lo``), ``norm`` is all zeros.
    """
    depth = np.asarray(depth, dtype=np.float64)

    # Force a boolean mask: with an integer mask (e.g. uint8 0/1) the `&`
    # below would yield an integer array, and indexing with it would perform
    # integer fancy indexing instead of boolean selection.
    valid = np.isfinite(depth) & np.asarray(mask, dtype=bool)

    norm = np.zeros(depth.shape, dtype=np.float64)
    if not valid.any():
        return norm, valid

    if mode == "percentile":
        samples = depth[valid]
        lo = float(np.percentile(samples, p_low))
        hi = float(np.percentile(samples, p_high))
    else:
        lo = fixed_min
        hi = fixed_max

    # A non-finite or empty range would divide by zero or produce NaNs.
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        return norm, valid

    # Saturate values outside [lo, hi] before rescaling.
    norm[valid] = (np.clip(depth[valid], lo, hi) - lo) / (hi - lo)

    valid_after = valid & np.isfinite(norm)
    return np.clip(norm, 0.0, 1.0), valid_after


# Maps colormap names accepted on the command line to the OpenCV colormap
# constants that are passed to cv2.applyColorMap.
_OPENCV_CMAPS = {
    "viridis": cv2.COLORMAP_VIRIDIS,
    "magma":   cv2.COLORMAP_MAGMA,
    "inferno": cv2.COLORMAP_INFERNO,
    "plasma":  cv2.COLORMAP_PLASMA,
    "jet":     cv2.COLORMAP_JET,
    "turbo":   cv2.COLORMAP_TURBO,
}
# Every value accepted by --cmap. "gray16" is not an OpenCV colormap:
# _save_single handles it separately and writes a 16-bit grayscale image.
# The remaining entries are the OpenCV names in alphabetical order.
ALL_CMAPS = ["gray16"] + sorted(_OPENCV_CMAPS.keys())


def _render_grayscale_layer(
    depth: np.ndarray,
    mask: np.ndarray,
    norm_mode: str,
    p_low: float, p_high: float,
    fixed_min: float, fixed_max: float,
    brightness: float = 0.7,
) -> np.ndarray:
    """
    Render a 2-D depth map as a dimmed 3-channel grayscale image.

    The depth is normalised with ``_normalize_depth`` and each valid pixel is
    written as an equal-valued triple, so the result can serve as a neutral
    background under a coloured overlay. Invalid pixels stay black.

    Args:
        depth: 2-D depth array.
        mask: Pixels to render.
        norm_mode, p_low, p_high, fixed_min, fixed_max: Forwarded to
            ``_normalize_depth``.
        brightness: Multiplier applied to the 0-255 gray level. The product
            is clamped to [0, 255], so factors above 1 saturate to white
            instead of wrapping around in the uint8 cast.

    Returns:
        ``(H, W, 3)`` uint8 image.
    """
    norm, valid = _normalize_depth(depth, mask, mode=norm_mode,
                                   p_low=p_low, p_high=p_high,
                                   fixed_min=fixed_min, fixed_max=fixed_max)
    h, w = depth.shape

    # Pixels without a valid depth remain black.
    canvas = np.zeros((h, w, 3), dtype=np.uint8)

    # Clamp before the cast: a float above 255 (brightness > 1) or below 0
    # does not saturate when converted to uint8, it wraps or is undefined.
    gray_vals = np.clip(norm[valid] * 255 * brightness, 0, 255).astype(np.uint8)

    # Broadcast the single gray value across all three channels.
    canvas[valid] = gray_vals[:, None]

    return canvas


def _render_color_layer(
    depth: np.ndarray,
    mask: np.ndarray,
    norm_mode: str,
    p_low: float, p_high: float,
    fixed_min: float, fixed_max: float,
    cmap: str = "turbo",
    point_size: int = 3,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Render a 2-D depth map through an OpenCV colormap.

    The depth is normalised with ``_normalize_depth``, quantised to 8 bits and
    coloured with ``_OPENCV_CMAPS[cmap]``. Invalid pixels are black. When
    ``point_size`` is greater than 1, both the image and its mask are dilated
    with an elliptical kernel of that size so isolated points become visible.

    Returns:
        ``(colored, overlay_mask)``: the colour image and a boolean mask of
        the pixels it covers (dilated the same way as the image).
    """
    norm, valid = _normalize_depth(
        depth, mask,
        mode=norm_mode,
        p_low=p_low, p_high=p_high,
        fixed_min=fixed_min, fixed_max=fixed_max,
    )
    rows, cols = depth.shape

    # 8-bit colormap indices; invalid pixels keep index 0 for now.
    indices = np.zeros((rows, cols), dtype=np.uint8)
    indices[valid] = (norm[valid] * 255).astype(np.uint8)

    colored = cv2.applyColorMap(indices, _OPENCV_CMAPS[cmap])

    # Index 0 maps to a real colormap colour, so blank invalid pixels explicitly.
    colored[~valid] = 0

    if point_size <= 1:
        return colored, valid.copy()

    # Grow each point; the mask is grown with the same kernel so it matches
    # the enlarged footprint of the image.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_size, point_size))
    grown_colored = cv2.dilate(colored, kernel)
    grown_mask = cv2.dilate(valid.astype(np.uint8), kernel).astype(bool)
    return grown_colored, grown_mask


def _save_composite(
    dense_depth: np.ndarray,
    dense_mask: np.ndarray,
    sparse_depth: np.ndarray,
    sparse_mask: np.ndarray,
    out_png: Path,
    norm_mode: str,
    p_low: float, p_high: float,
    fixed_min: float, fixed_max: float,
    cmap: str = "turbo",
    point_size: int = 3,
    bg_brightness: float = 0.7,
) -> None:
    """
    Write a two-layer PNG: a grayscale dense background with a coloured
    sparse overlay on top.

    The background comes from ``_render_grayscale_layer`` (dense inputs) and
    the overlay from ``_render_color_layer`` (sparse inputs). Both layers use
    the same normalisation settings. The parent directory of ``out_png`` is
    created if needed.
    """
    out_png.parent.mkdir(parents=True, exist_ok=True)

    # Normalisation settings shared by both layers.
    norm_kwargs = {
        "norm_mode": norm_mode,
        "p_low": p_low,
        "p_high": p_high,
        "fixed_min": fixed_min,
        "fixed_max": fixed_max,
    }

    background = _render_grayscale_layer(
        dense_depth, dense_mask, brightness=bg_brightness, **norm_kwargs
    )
    overlay, overlay_mask = _render_color_layer(
        sparse_depth, sparse_mask, cmap=cmap, point_size=point_size, **norm_kwargs
    )

    # Overlay pixels replace the background wherever the overlay mask is set.
    # The background array is freshly created above, so it is updated in place.
    background[overlay_mask] = overlay[overlay_mask]

    cv2.imwrite(str(out_png), background)


def _save_single(
    depth: np.ndarray,
    mask: np.ndarray,
    out_png: Path,
    norm_mode: str,
    p_low: float, p_high: float,
    fixed_min: float, fixed_max: float,
    cmap: str = "turbo",
    point_size: int = 1,
) -> None:
    """
    Write a single-layer PNG of ``depth`` (no background or overlay).

    With ``cmap == "gray16"`` the normalised depth is stored as a 16-bit
    single-channel image. Otherwise it is quantised to 8 bits and coloured
    with ``_OPENCV_CMAPS[cmap]``, with invalid pixels set to black. When
    ``point_size`` is greater than 1 the image is dilated with an elliptical
    kernel of that size before writing. The parent directory of ``out_png``
    is created if needed.
    """
    out_png.parent.mkdir(parents=True, exist_ok=True)

    norm, valid_mask = _normalize_depth(
        depth.copy(), mask,
        mode=norm_mode,
        p_low=p_low, p_high=p_high,
        fixed_min=fixed_min, fixed_max=fixed_max,
    )

    if cmap == "gray16":
        # Full 16-bit range; invalid pixels stay 0.
        image = np.zeros_like(norm, dtype=np.uint16)
        image[valid_mask] = (norm[valid_mask] * 65535).astype(np.uint16)
    else:
        indices = np.zeros_like(norm, dtype=np.uint8)
        indices[valid_mask] = (norm[valid_mask] * 255).astype(np.uint8)
        image = cv2.applyColorMap(indices, _OPENCV_CMAPS[cmap])
        # Index 0 maps to a real colour, so blank invalid pixels explicitly.
        image[~valid_mask] = [0, 0, 0]

    # The dilation and write steps are the same for both output formats.
    if point_size > 1:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_size, point_size))
        image = cv2.dilate(image, kernel)

    cv2.imwrite(str(out_png), image)


def _pred_candidates_for_sequence(pred_seq_dir: Path, pred_pattern: str) -> list[Path]:
    """
    Find the prediction files of one sequence.

    Searches, in order, the sequence directory itself, its ``depth_npy``
    sub-directory and its ``depth`` sub-directory, and returns the sorted
    matches of ``pred_pattern`` from the first location that has any.
    Returns an empty list when none of them has a match.
    """
    search_dirs = (
        pred_seq_dir,
        pred_seq_dir / "depth_npy",
        pred_seq_dir / "depth",
    )
    for folder in search_dirs:
        matches = sorted(folder.glob(pred_pattern))
        if matches:
            return matches
    return []


def _collect_pairs(
    gt_root: Path,
    img_root: Path,
    pred_root: Path,
    pred_pattern: str,
) -> tuple[list[tuple[Path, Path]], list[tuple[Path, Path, Path]]]:
    """
    Match ground-truth files to predictions (and images) by file stem.

    For every name in ``SEQUENCES`` that has both a ground-truth and a
    prediction directory, each ``*.npy`` ground-truth file is paired with the
    prediction of the same stem found by ``_pred_candidates_for_sequence``.

    The image directory is optional: it only feeds ``shaded_pairs``, so a
    missing ``img_root / seq`` no longer drops the sequence's GT/pred pairs.

    Returns:
        ``(pairs, shaded_pairs)``. ``pairs`` holds ``(gt, pred)`` for every
        match. ``shaded_pairs`` holds ``(gt, pred, image)`` for the subset
        that also has a ``.png`` / ``.jpg`` / ``.jpeg`` image of the same
        stem.
    """
    pairs: list[tuple[Path, Path]] = []
    shaded_pairs: list[tuple[Path, Path, Path]] = []

    for seq in SEQUENCES:
        gt_dir = gt_root / seq
        img_dir = img_root / seq
        pred_dir = pred_root / seq

        # Only ground truth and predictions are required to form a pair.
        if not gt_dir.is_dir() or not pred_dir.is_dir():
            continue

        pred_map = {
            p.stem: p for p in _pred_candidates_for_sequence(pred_dir, pred_pattern)
        }

        # If several extensions share a stem, the later extension wins.
        img_map: dict[str, Path] = {}
        if img_dir.is_dir():
            for ext in (".png", ".jpg", ".jpeg"):
                for img in img_dir.glob(f"*{ext}"):
                    img_map[img.stem] = img

        for gt_path in sorted(gt_dir.glob("*.npy")):
            pred_path = pred_map.get(gt_path.stem)
            if pred_path is None:
                continue

            pairs.append((gt_path, pred_path))

            img_path = img_map.get(gt_path.stem)
            if img_path is not None:
                shaded_pairs.append((gt_path, pred_path, img_path))

    return pairs, shaded_pairs


def _process_single_pair(task: tuple) -> None:
    """
    Render one ground-truth/prediction pair to a PNG.

    ``task`` is the 13-tuple built in ``main``: GT path, prediction path,
    take_inverse, eps, max_depth, PNG base directory, mode, norm_mode, p_low,
    p_high, cmap, point_size and bg_brightness.

    Steps: load both arrays, resize the prediction to the GT, optionally
    convert inverse depth to depth, build the evaluation mask (valid GT, valid
    prediction, max-depth mask), align the prediction to the GT with
    ``least_squares_align``, then write
    ``<png_base_dir>/<sequence>/<pred stem>.png`` in the requested mode.
    """
    (gt_path_str, pred_path_str, take_inverse, eps,
     max_depth, png_base_dir_str, mode,
     norm_mode, p_low, p_high, cmap,
     point_size, bg_brightness) = task

    gt_path = Path(gt_path_str)
    pred_path = Path(pred_path_str)

    # The sequence name is the GT file's parent directory; the file name
    # follows the prediction.
    out_path = Path(png_base_dir_str) / gt_path.parent.name / f"{pred_path.stem}.png"

    pred = load_depth_npy(str(pred_path))
    gt = load_depth_npy(str(gt_path))

    # Bring the prediction onto the GT grid before any per-pixel comparison.
    pred = resize_to_match(pred, gt)

    if take_inverse:
        pred_depth, pred_valid = inverse_to_depth(pred, eps=eps)
    else:
        pred_depth = pred.astype(np.float64)
        pred_valid = np.isfinite(pred_depth)

    gt_f64 = gt.astype(np.float64)

    # Pixels used for alignment: valid GT and valid prediction, then the
    # max-depth mask from the evaluator.
    eval_mask = apply_max_depth_mask(gt_f64, _valid_gt_mask(gt_f64) & pred_valid, max_depth)

    # Only the aligned prediction is used; scale and shift are not needed here.
    aligned, _scale, _shift = least_squares_align(pred_depth, gt_f64, eval_mask)

    # Dense layer: every valid prediction pixel inside [0, max_depth].
    dense_depth = np.clip(aligned, 0.0, max_depth)
    within_range = (aligned >= 0.0) & (aligned <= max_depth)
    dense_mask = pred_valid & np.isfinite(aligned) & within_range

    # Sparse layer: the aligned prediction restricted to the evaluation mask.
    sparse_depth = aligned
    sparse_mask = eval_mask

    # Settings common to all three output modes.
    shared = {
        "norm_mode": norm_mode,
        "p_low": p_low,
        "p_high": p_high,
        "fixed_min": 0.0,
        "fixed_max": max_depth,
        "cmap": cmap,
    }

    if mode == "composite":
        _save_composite(
            dense_depth, dense_mask,
            sparse_depth, sparse_mask,
            out_path,
            point_size=point_size,
            bg_brightness=bg_brightness,
            **shared,
        )
    elif mode == "dense_only":
        # Dense maps are never dilated.
        _save_single(dense_depth, dense_mask, out_path, point_size=1, **shared)
    else:
        _save_single(sparse_depth, sparse_mask, out_path, point_size=point_size, **shared)


def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, collects GT/prediction pairs for the S3LI sequences,
    and renders one PNG per pair in a process pool. Output goes to
    ``<output-dir>/png_<cmap>_<norm-mode>_<mode>/<sequence>/``.

    Exits with an error when the prediction or GT root is missing, when no
    pairs are found, or when ``--cmap gray16`` is combined with composite mode.
    """
    ap = argparse.ArgumentParser(description="Visualize S3LI natively injecting evaluation masks precisely replicating internal clipping behaviors.")
    ap.add_argument("--pred-dir", required=True,
                    help="Root directory with one sub-directory of predictions per sequence.")
    ap.add_argument("--output-dir", required=True,
                    help="Directory under which the PNG output folder is created.")
    ap.add_argument("--take-inverse", action="store_true",
                    help="Treat predictions as inverse depth and convert them to depth first.")
    ap.add_argument("--pred-pattern", default="*.npy",
                    help="Glob pattern for prediction files inside each sequence directory.")
    ap.add_argument("--root", default=ROOT_DEFAULT,
                    help="Dataset root; ground truth is read from <root>/depth.")
    ap.add_argument("--max-depth", type=float, default=MAX_DEPTH_DEFAULT,
                    help="Maximum depth for the evaluation mask, clipping and fixed normalisation.")
    ap.add_argument("--cmap", choices=ALL_CMAPS, default="turbo",
                    help="Colormap name; gray16 writes 16-bit grayscale and is not available in composite mode.")
    ap.add_argument("--norm-mode", choices=["percentile", "fixed"], default="percentile",
                    help="percentile: per-image range from --p-low/--p-high; fixed: range [0, max-depth].")
    ap.add_argument("--p-low", type=float, default=2.0,
                    help="Lower percentile for percentile normalisation.")
    ap.add_argument("--p-high", type=float, default=98.0,
                    help="Upper percentile for percentile normalisation.")
    ap.add_argument("--eps", type=float, default=1e-6,
                    help="Epsilon forwarded to the inverse-to-depth conversion (--take-inverse).")

    vis_group = ap.add_mutually_exclusive_group()
    vis_group.add_argument("--composite", action="store_true", default=True,
                           help="Grayscale dense background with a coloured sparse overlay (default).")
    vis_group.add_argument("--dense-only", action="store_true",
                           help="Render only the dense aligned prediction.")
    vis_group.add_argument("--sparse-only", action="store_true",
                           help="Render only the pixels inside the evaluation mask.")

    ap.add_argument("--point-size", type=int, default=3,
                    help="Dilation kernel size for sparse points; values of 1 or less disable dilation.")
    ap.add_argument("--bg-brightness", type=float, default=0.7,
                    help="Brightness multiplier for the grayscale background in composite mode.")
    ap.add_argument("--workers", type=int, default=default_worker_count(),
                    help="Number of worker processes (at least 1 is always used).")

    args = ap.parse_args()

    # --composite defaults to True, so the mode is decided by the other two flags.
    if args.dense_only:
        mode = "dense_only"
    elif args.sparse_only:
        mode = "sparse_only"
    else:
        mode = "composite"

    # The composite overlay looks its colormap up in _OPENCV_CMAPS, which has
    # no "gray16" entry. Reject the combination here rather than failing with
    # a KeyError inside the worker processes.
    if mode == "composite" and args.cmap not in _OPENCV_CMAPS:
        ap.error(
            f"--cmap {args.cmap} is not supported in composite mode; "
            f"use --dense-only or --sparse-only, or choose one of: "
            f"{', '.join(sorted(_OPENCV_CMAPS))}"
        )

    pred_root = Path(args.pred_dir)
    output_dir = Path(args.output_dir)
    root = Path(args.root)

    gt_root = root / "depth"
    img_root = root / "images"

    if not pred_root.exists() or not pred_root.is_dir():
        raise SystemExit(f"Prediction root directory does not exist: {pred_root}")
    if not gt_root.exists() or not gt_root.is_dir():
        raise SystemExit(f"GT root directory does not exist: {gt_root}")

    # Only the (gt, pred) pairs are needed; the image-backed triples are unused here.
    pairs, _ = _collect_pairs(gt_root, img_root, pred_root, args.pred_pattern)
    if not pairs:
        raise SystemExit("No matching GT/pred pairs found across S3LI sequences.")

    png_base_dir = output_dir / f"png_{args.cmap}_{args.norm_mode}_{mode}"

    # Paths are passed as strings inside plain tuples; the field order must
    # match the unpacking in _process_single_pair.
    tasks = [
        (str(gt_path), str(pred_path), args.take_inverse, args.eps,
         args.max_depth, str(png_base_dir), mode, args.norm_mode,
         args.p_low, args.p_high, args.cmap, args.point_size, args.bg_brightness)
        for gt_path, pred_path in pairs
    ]

    workers = max(1, int(args.workers))

    # Report progress every 100 items and once more at the end.
    def _on_progress(completed: int, total: int) -> None:
        if completed % 100 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    run_process_pool_as_completed(_process_single_pair, tasks, workers=workers, progress_callback=_on_progress)
    print(f"\nSaved {len(pairs)} outputs across sequences in {png_base_dir}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()