import argparse
from pathlib import Path
import sys
import re
import numpy as np
import cv2

# Append this script's own directory to the module search path before the
# project imports that follow (presumably the `core` package lives next to
# this file — verify against the repository layout).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.io import load_depth_npy, resize_to_match
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Default value of the --gt-dir option. The literal looks like a placeholder
# path; pass --gt-dir to point at the real ground-truth depth directory.
GT_DIR_DEFAULT = "path/to/etna-dataset/subsampled_datasets/Dataset1/depth"

def _normalize_depth(
    depth: np.ndarray,
    mask: np.ndarray,
    p_low: float = 1.0,
    p_high: float = 99.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Normalize a depth map to the range [0, 1] for visualization.

    The display range is taken from the ``p_low`` / ``p_high`` percentiles of
    the valid pixels, so a few extreme values cannot flatten the contrast of
    the whole image. When the lower bound is strictly positive the values are
    mapped on a logarithmic scale (more contrast for small depths); otherwise
    a linear mapping is used.

    Args:
        depth: Depth values; converted to float64 internally.
        mask: Pixels to consider. Any array that can be interpreted as
            booleans is accepted (non-zero means "use this pixel").
        p_low: Lower percentile (0-100) defining the bottom of the range.
        p_high: Upper percentile (0-100) defining the top of the range.

    Returns:
        A ``(norm, valid)`` tuple. ``norm`` is a float64 array shaped like
        ``depth`` with values in [0, 1]; pixels that are not valid are 0.
        ``valid`` is a boolean array marking the finite, unmasked pixels.
        If no pixel is valid, or the percentile range is empty or non-finite,
        ``norm`` is all zeros.
    """
    depth = np.asarray(depth, dtype=np.float64)
    # Coerce to bool: with an integer mask, `isfinite(depth) & mask` would be
    # an integer array and `depth[valid]` would silently turn into integer
    # fancy indexing instead of boolean selection.
    mask = np.asarray(mask, dtype=bool)
    valid = np.isfinite(depth) & mask

    if not np.any(valid):
        return np.zeros_like(depth, dtype=np.float64), valid

    # Robust display range taken from the valid pixels only.
    lo_raw, hi_raw = np.percentile(depth[valid], [p_low, p_high])
    lo = float(lo_raw)
    hi = float(hi_raw)

    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        return np.zeros_like(depth, dtype=np.float64), valid

    # Saturate everything outside the percentile range.
    clipped = np.clip(depth, lo, hi)
    norm = np.zeros_like(depth, dtype=np.float64)

    # The log mapping is only defined for a strictly positive range.
    use_log = lo > 0.0
    if use_log:
        lo_t = np.log(lo)
        hi_t = np.log(hi)
        if hi_t > lo_t:
            norm[valid] = (np.log(clipped[valid]) - lo_t) / (hi_t - lo_t)
        else:
            # Bounds collapsed after the log transform; use the linear path.
            use_log = False

    # Linear mapping for ranges that touch zero or negative values.
    if not use_log:
        norm[valid] = (clipped[valid] - lo) / (hi - lo)

    valid_after = valid & np.isfinite(norm)
    return np.clip(norm, 0.0, 1.0), valid_after


def _save_outputs(
    depth: np.ndarray,
    mask: np.ndarray,
    out_png: Path,
    p_low: float,
    p_high: float,
    cmap: str = "gray16"
) -> None:
    """Render a masked depth map to an image file.

    The depth values are normalized with ``_normalize_depth`` and written
    either as a single-channel 16-bit image (``cmap="gray16"``) or as an
    8-bit image colored with one of the OpenCV colormaps. Pixels outside the
    valid mask are written as black in both modes.

    Args:
        depth: Depth values to render.
        mask: Pixels to render; non-zero / True means "keep".
        out_png: Destination file. Missing parent directories are created.
        p_low: Lower percentile forwarded to ``_normalize_depth``.
        p_high: Upper percentile forwarded to ``_normalize_depth``.
        cmap: ``"gray16"`` or one of ``"viridis"``, ``"magma"``,
            ``"inferno"``, ``"plasma"``.

    Raises:
        KeyError: If ``cmap`` is not one of the supported names.
        OSError: If OpenCV reports that the image could not be written.
    """
    out_png.parent.mkdir(parents=True, exist_ok=True)

    # `~mask` is only a logical NOT for boolean arrays.
    mask = np.asarray(mask, dtype=bool)

    # Work on a float64 copy: NaN cannot be stored in integer arrays, and the
    # caller's array must not be modified.
    depth_to_save = np.array(depth, dtype=np.float64)
    depth_to_save[~mask] = np.nan

    norm, valid_mask = _normalize_depth(depth_to_save, mask, p_low=p_low, p_high=p_high)

    if cmap == "gray16":
        # 16 bits keep far more gray levels than an 8-bit image would.
        image = np.zeros_like(norm, dtype=np.uint16)
        image[valid_mask] = (norm[valid_mask] * 65535).astype(np.uint16)
    else:
        out_8bit = np.zeros_like(norm, dtype=np.uint8)
        out_8bit[valid_mask] = (norm[valid_mask] * 255).astype(np.uint8)

        cmap_dict = {
            "viridis": cv2.COLORMAP_VIRIDIS,
            "magma": cv2.COLORMAP_MAGMA,
            "inferno": cv2.COLORMAP_INFERNO,
            "plasma": cv2.COLORMAP_PLASMA,
        }

        image = cv2.applyColorMap(out_8bit, cmap_dict[cmap])

        # The colormap assigns a color to value 0 as well, so invalid pixels
        # are forced to black explicitly.
        image[~valid_mask] = [0, 0, 0]

    # cv2.imwrite signals failure through its return value; surface it
    # instead of silently producing no file.
    if not cv2.imwrite(str(out_png), image):
        raise OSError(f"Failed to write image: {out_png}")


def _numeric_stem_key(stem: str) -> str | None:
    """Parses sequential filename digits natively facilitating pairing alignments."""
    parts = re.findall(r"\d+", stem)
    return parts[-1] if parts else None

def _build_unique_map(files: list[Path], key_fn) -> tuple[dict[str, Path], list[str]]:
    """Index ``files`` by ``key_fn(path.stem)``, keeping the first file per key.

    Files whose key is ``None`` are skipped. When several files share a key,
    only the first one encountered is kept and the key is reported.

    Returns:
        A ``(mapping, duplicates)`` tuple: the key -> path mapping and the
        sorted list of keys that occurred more than once.
    """
    first_seen: dict[str, Path] = {}
    repeated: set[str] = set()
    for path in files:
        key = key_fn(path.stem)
        if key is None:
            continue
        if key in first_seen:
            repeated.add(key)
        else:
            first_seen[key] = path
    return first_seen, sorted(repeated)


def _collect_pairs(
    pred_files: list[Path],
    gt_files: list[Path],
    pair_mode: str,
) -> tuple[list[tuple[Path, Path]], str, set[str], set[str], list[str], list[str]]:
    """Pair prediction files with ground-truth files by file name.

    Two strategies exist: ``"exact"`` matches identical file stems, and
    ``"digits"`` matches the last run of digits in each stem. With
    ``pair_mode="auto"`` the exact strategy is tried first and the digits
    strategy is used only when exact matching finds no pair at all. Any mode
    other than ``"auto"`` / ``"exact"`` uses the digits strategy.

    Returns:
        ``(pairs, mode, pred_keys, gt_keys, pred_dups, gt_dups)`` where
        ``pairs`` are ``(prediction, ground_truth)`` paths ordered by key,
        ``mode`` is the strategy actually used (``"exact"`` or ``"digits"``),
        the key sets hold every key seen on each side, and the duplicate
        lists hold keys shared by more than one file on that side.
    """

    def _pair_by(key_fn, label: str):
        # Index both sides with the same key function and join on shared keys.
        pred_map, pred_repeats = _build_unique_map(pred_files, key_fn)
        gt_map, gt_repeats = _build_unique_map(gt_files, key_fn)
        shared = sorted(pred_map.keys() & gt_map.keys())
        matched = [(pred_map[key], gt_map[key]) for key in shared]
        return matched, label, set(pred_map.keys()), set(gt_map.keys()), pred_repeats, gt_repeats

    if pair_mode in {"auto", "exact"}:
        exact_result = _pair_by(lambda s: s, "exact")
        # "exact" always returns here; "auto" only when something matched.
        if pair_mode == "exact" or exact_result[0]:
            return exact_result

    # Fallback: match on the trailing frame number in the file names.
    return _pair_by(_numeric_stem_key, "digits")


def _process_single_pair(task: tuple) -> None:
    """Render one prediction / ground-truth pair to a PNG file.

    ``task`` is a flat tuple (so it can be handed to a worker process):
    ``(pred_path, gt_path, take_inverse, eps, max_depth, png_dir, p_low,
    p_high, cmap)`` with the paths given as strings.

    Steps: load both arrays, resize the prediction to the ground truth,
    optionally convert the prediction with ``inverse_to_depth``, build the
    validity mask from the ground truth (positive, finite, at most
    ``max_depth`` when that is not ``None``) combined with the prediction's
    own validity, align the prediction to the ground truth with
    ``least_squares_align``, and save the aligned prediction as
    ``<png_dir>/<prediction stem>.png``.
    """
    (
        pred_path_str,
        gt_path_str,
        take_inverse,
        eps,
        max_depth,
        png_dir_str,
        p_low,
        p_high,
        cmap,
    ) = task

    pred_path = Path(pred_path_str)
    gt_path = Path(gt_path_str)
    destination = Path(png_dir_str) / f"{pred_path.stem}.png"

    # Bring the prediction to the ground-truth resolution before comparing.
    pred_raw = load_depth_npy(str(pred_path))
    gt = load_depth_npy(str(gt_path))
    pred_raw = resize_to_match(pred_raw, gt)

    if not take_inverse:
        pred_depth = pred_raw.astype(np.float64)
        pred_valid = np.isfinite(pred_depth)
    else:
        # Prediction is stored as inverse depth; the helper also reports
        # which pixels could be converted.
        pred_depth, pred_valid = inverse_to_depth(pred_raw, eps=eps)

    # Only pixels with a usable ground-truth value take part in the
    # alignment and in the rendered image.
    gt_f64 = gt.astype(np.float64)
    eval_mask = (gt > 0) & pred_valid & np.isfinite(gt_f64)
    if max_depth is not None:
        eval_mask &= (gt_f64 <= max_depth)

    # The two fit parameters returned alongside the aligned map are not
    # needed for rendering.
    aligned, _scale, _shift = least_squares_align(pred_depth, gt_f64, eval_mask)

    _save_outputs(aligned, eval_mask, destination, p_low=p_low, p_high=p_high, cmap=cmap)


def main() -> None:
    """Command-line entry point.

    Pairs prediction ``.npy`` files from ``--input-path`` with ground-truth
    ``.npy`` files from ``--gt-dir``, renders every pair with
    ``_process_single_pair`` in a process pool, and writes the images to
    ``<output-path>/png_<cmap>_eval_mask_detail``.

    Exits with an error message when a directory is missing, no files are
    found, no pair can be formed, or the percentile options are not a valid
    ascending range within [0, 100].
    """
    ap = argparse.ArgumentParser(description="Map arrays systematically processing logical inferences for qualitative Etna reviews.")
    ap.add_argument("--input-path", "--input-dir", dest="input_path", required=True)
    ap.add_argument("--output-path", "--output-dir", dest="output_path", required=True)
    ap.add_argument("--take-inverse", action="store_true")
    ap.add_argument("--pred-pattern", default="*.npy")
    ap.add_argument("--pair-mode", choices=("auto", "exact", "digits"), default="auto")
    ap.add_argument("--gt-dir", default=GT_DIR_DEFAULT)
    ap.add_argument("--max-depth", type=float, default=15.0)
    ap.add_argument("--p-low", type=float, default=1.0)
    ap.add_argument("--p-high", type=float, default=99.0)
    ap.add_argument("--eps", type=float, default=1e-6)
    ap.add_argument("--cmap", choices=["gray16", "viridis", "magma", "inferno", "plasma"], default="gray16")
    ap.add_argument("--workers", type=int, default=default_worker_count())
    args = ap.parse_args()

    # Reject unusable percentile settings here instead of failing inside
    # every worker (out-of-range values) or silently writing all-black
    # images (empty or inverted range).
    if not 0.0 <= args.p_low < args.p_high <= 100.0:
        ap.error("--p-low and --p-high must satisfy 0 <= p-low < p-high <= 100")

    input_dir = Path(args.input_path)
    output_dir = Path(args.output_path)
    gt_dir = Path(args.gt_dir)

    # is_dir() is already False for paths that do not exist.
    if not input_dir.is_dir():
        raise SystemExit(f"Input directory does not exist: {input_dir}")
    if not gt_dir.is_dir():
        raise SystemExit(f"GT directory does not exist: {gt_dir}")

    pred_files = sorted(input_dir.glob(args.pred_pattern))
    if not pred_files:
        raise SystemExit("No prediction files found")

    gt_files = sorted(gt_dir.glob("*.npy"))
    if not gt_files:
        raise SystemExit("No GT NPY files found")

    pairs, matched_mode, pred_keys, gt_keys, pred_dups, gt_dups = _collect_pairs(
        pred_files, gt_files, pair_mode=args.pair_mode,
    )
    if not pairs:
        raise SystemExit("No matching prediction/GT pairs found.")

    # Report how the files were paired so dropped inputs do not go unnoticed.
    print(f"Pairing mode: {matched_mode} ({len(pairs)} pairs from {len(pred_files)} prediction files)")
    for side, dup_keys in (("prediction", pred_dups), ("GT", gt_dups)):
        if dup_keys:
            shown = ", ".join(dup_keys[:10])
            suffix = ", ..." if len(dup_keys) > 10 else ""
            print(
                f"Warning: {len(dup_keys)} duplicate {side} key(s), first file kept: {shown}{suffix}",
                file=sys.stderr,
            )
    unmatched_pred = pred_keys - gt_keys
    if unmatched_pred:
        print(
            f"Warning: {len(unmatched_pred)} prediction key(s) have no GT match and are skipped",
            file=sys.stderr,
        )

    png_dir = output_dir / f"png_{args.cmap}_eval_mask_detail"

    # Tasks are flat tuples of plain values so they can be sent to workers.
    tasks = [
        (str(pred_path), str(gt_path), args.take_inverse, args.eps,
         args.max_depth, str(png_dir), args.p_low, args.p_high, args.cmap)
        for pred_path, gt_path in pairs
    ]

    workers = max(1, int(args.workers))

    # Print progress every 50 finished tasks and once at the end.
    def _on_progress(completed: int, total: int) -> None:
        if completed % 50 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    run_process_pool_as_completed(_process_single_pair, tasks, workers=workers, progress_callback=_on_progress)
    print(f"Saved {len(pairs)} outputs to {png_dir}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()