import argparse

"""
Evaluation script that scores raw model depth predictions against masked ground truth.
Predictions are aligned to the ground truth by least squares before metrics are computed,
which resolves the scale ambiguity of monocular depth estimates.
"""
import json
from pathlib import Path
import sys
import numpy as np

SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.evaluator import apply_max_depth_mask, build_distance_masks, compute_metrics_by_region
from core.io import load_depth_npy, resize_to_match
from core.masks import load_or_create_dark_mask
from core.metrics import compute_metrics
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Default dataset locations used as argparse defaults; they look like placeholders
# and can be overridden with --gt-dir / --img-dir / --mask-dir.
GT_DIR_DEFAULT = "path/to/cheri-dataset/depth"
IMG_DIR_DEFAULT = "path/to/cheri-dataset/images"
MASK_DIR_DEFAULT = "path/to/cheri-dataset/masks"
# Depth cutoff handed to apply_max_depth_mask when the evaluation mask is built.
MAX_DEPTH = 17.0
# (lower, upper) depth ranges for the per-distance breakdown; main() reports them
# in metres. BIN_NAMES labels the ranges in the same order.
DISTANCE_BINS = [(0.0, 5.0), (5.0, 10.0), (10.0, 17.0)]
BIN_NAMES = ["near", "medium", "far"]


def _mean_dict(dicts: list[dict]) -> dict:
    if not dicts:
        return {}
    keys = dicts[0].keys()
    out = {}
    for k in keys:
        vals = np.array([d[k] for d in dicts], dtype=np.float64)
        out[k] = float(np.nanmean(vals))
    return out


def _mean_region_dict(region_list: list[dict]) -> dict:
    """Average per-region metric dicts across samples.

    Args:
        region_list: One mapping per sample, from region name to that region's
            metric dict. A sample may omit regions.

    Returns:
        A dict mapping every region name seen in any sample to the result of
        ``_mean_dict`` over the samples that contain that region. An empty
        input list yields an empty dict.
    """
    if not region_list:
        return {}
    # Use the union of region names (first-seen order). Taking them from the
    # first sample only would silently drop any region that sample happens to
    # lack, and which sample comes first may vary between runs.
    region_names = dict.fromkeys(name for sample in region_list for name in sample)
    return {
        name: _mean_dict([sample[name] for sample in region_list if name in sample])
        for name in region_names
    }


def _write_experiment(output_dir: Path, experiment: str, payload: dict) -> None:
    exp_dir = output_dir / experiment
    exp_dir.mkdir(parents=True, exist_ok=True)
    json_path = exp_dir / "results.json"
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _pack_distance(results: dict) -> dict:
    """Reshape averaged region metrics into the ``overall`` / ``bins`` layout.

    Every name in ``BIN_NAMES`` appears under ``"bins"``; regions (and the
    ``"overall"`` entry) that are absent from *results* become empty dicts.
    """
    packed = {"overall": results.get("overall", {})}
    per_bin = {}
    for bin_name in BIN_NAMES:
        per_bin[bin_name] = results.get(bin_name, {})
    packed["bins"] = per_bin
    return packed


def _valid_gt_mask(gt: np.ndarray) -> np.ndarray:
    return np.isfinite(gt) & (gt > 0)


def _match_pairs(
    gt_dir: Path,
    pred_dir: Path,
    img_dir: Path,
    pred_pattern: str,
) -> tuple[list[tuple[Path, Path]], list[tuple[Path, Path, Path]]]:
    gt_files = sorted(gt_dir.glob("*.npy"))
    pred_files = sorted(pred_dir.glob(pred_pattern))
    pred_map = {p.stem: p for p in pred_files}
    img_map: dict[str, Path] = {}
    if img_dir.exists():
        for ext in (".png", ".jpg", ".jpeg"):
            for img in img_dir.glob(f"*{ext}"):
                img_map[img.stem] = img

    pairs = []
    shaded_pairs = []
    for gt_path in gt_files:
        pred_path = pred_map.get(gt_path.stem)
        if pred_path is None:
            continue
        pairs.append((gt_path, pred_path))
        img_path = img_map.get(gt_path.stem)
        if img_path is not None:
            shaded_pairs.append((gt_path, pred_path, img_path))
    return pairs, shaded_pairs


def _process_single_pair(task: tuple[str, str, str | None, str, bool, float, int, int]) -> dict:
    """Evaluate a single ground-truth / prediction pair.

    Args:
        task: ``(gt_path, pred_path, img_path_or_None, mask_dir, take_inverse,
            eps, dark_threshold, kernel_size)``, with all paths given as strings.

    Returns:
        A dict with keys ``"classic"`` (metrics over the evaluation mask),
        ``"distance"`` (metrics per distance region) and ``"shaded"`` (metrics
        over the evaluation mask combined with the dark mask, or ``None`` when
        no image path was supplied).
    """
    (
        gt_name,
        pred_name,
        img_name,
        mask_dir_name,
        take_inverse,
        eps,
        dark_threshold,
        kernel_size,
    ) = task

    gt_file = Path(gt_name)
    pred_file = Path(pred_name)
    rgb_file = None if img_name is None else Path(img_name)
    mask_root = Path(mask_dir_name)

    gt = load_depth_npy(str(gt_file))
    raw_pred = load_depth_npy(str(pred_file))
    raw_pred = resize_to_match(raw_pred, gt)

    if not take_inverse:
        # Prediction is used as depth directly; only non-finite values are rejected.
        depth = raw_pred.astype(np.float64)
        depth_ok = np.isfinite(depth)
    else:
        # Prediction is treated as inverse depth; the helper returns the depth
        # map together with its own validity mask.
        depth, depth_ok = inverse_to_depth(raw_pred, eps=eps)

    # Evaluate only where the ground truth is finite and positive and the
    # prediction is valid, then apply the MAX_DEPTH cutoff via the helper.
    eval_mask = _valid_gt_mask(gt) & depth_ok
    eval_mask = apply_max_depth_mask(gt, eval_mask, MAX_DEPTH)

    # Least-squares alignment of the prediction to the ground truth over the
    # evaluation mask; only the aligned map is used, the two extra return
    # values are discarded.
    aligned, _unused_a, _unused_b = least_squares_align(depth, gt, eval_mask)

    classic_metrics = compute_metrics(aligned, gt, eval_mask)
    region_masks = build_distance_masks(gt, eval_mask, DISTANCE_BINS, BIN_NAMES)
    distance_metrics = compute_metrics_by_region(aligned, gt, eval_mask, region_masks)

    shaded_metrics = None
    if rgb_file is not None:
        # The dark mask is loaded from (or cached at) <mask_dir>/<image stem>.npy.
        # NOTE(review): which pixels the mask marks is decided by
        # load_or_create_dark_mask — confirm its polarity there.
        cached_mask = mask_root / f"{rgb_file.stem}.npy"
        dark_mask = load_or_create_dark_mask(
            str(rgb_file),
            str(cached_mask),
            threshold=dark_threshold,
            kernel_size=kernel_size,
        )
        shaded_metrics = compute_metrics(aligned, gt, eval_mask & dark_mask)

    return {
        "classic": classic_metrics,
        "distance": distance_metrics,
        "shaded": shaded_metrics,
    }


def main() -> None:
    """Command-line entry point: evaluate all matched pairs and write the results.

    Parses the CLI arguments, matches ground-truth files with predictions and
    RGB images, evaluates every pair in a process pool, averages the per-pair
    metrics and writes one ``results.json`` per experiment (``classic``,
    ``distance``, ``shaded``) below ``--output-dir``.

    Raises:
        SystemExit: If no GT/prediction pair matches, the RGB directory does
            not exist, or no GT/prediction/RGB triple matches.
    """
    ap = argparse.ArgumentParser(description="Evaluate Cheri predictions against NPY GT with least-squares alignment.")
    ap.add_argument("--pred-dir", required=True, help="Directory containing .npy predictions")
    ap.add_argument("--take-inverse", action="store_true", help="Interpret predictions as inverse depth and convert to depth")
    ap.add_argument("--output-dir", required=True, help="Directory to save evaluation results")
    ap.add_argument("--pred-pattern", default="*.npy", help="Glob pattern for prediction files")
    ap.add_argument("--gt-dir", default=GT_DIR_DEFAULT, help="Ground-truth NPY directory")
    ap.add_argument("--img-dir", default=IMG_DIR_DEFAULT, help="RGB image directory (used for matching)")
    ap.add_argument("--mask-dir", default=MASK_DIR_DEFAULT, help="Directory for cached dark masks")
    ap.add_argument("--dark-threshold", type=int, default=81, help="Dark pixel threshold")
    ap.add_argument("--kernel-size", type=int, default=9, help="Morphology kernel size for dark mask")
    ap.add_argument("--eps", type=float, default=1e-6, help="Small epsilon for inverse depth")
    ap.add_argument(
        "--workers",
        type=int,
        default=default_worker_count(),
        help="Number of worker processes (default: os.cpu_count() - 1, minimum 1)",
    )
    args = ap.parse_args()

    pred_dir = Path(args.pred_dir)
    output_dir = Path(args.output_dir)
    gt_dir = Path(args.gt_dir)
    img_dir = Path(args.img_dir)
    mask_dir = Path(args.mask_dir)

    pairs, shaded_pairs = _match_pairs(gt_dir, pred_dir, img_dir, args.pred_pattern)
    if not pairs:
        raise SystemExit("No matching GT/pred pairs found.")

    # The shaded experiment is mandatory, so a missing image directory or an
    # empty set of triples aborts the run before any work is scheduled.
    if not img_dir.exists():
        raise SystemExit(f"RGB directory not found: {img_dir}")
    if not shaded_pairs:
        raise SystemExit("No matching GT/pred/RGB triples found for shaded evaluation.")

    aligned_classic = []
    aligned_distance = []
    aligned_shaded = []

    # Tasks carry plain strings and scalars only. Pairs without an image get
    # None in the image slot and are skipped for the shaded experiment.
    image_map = {str(gt_path): str(img_path) for gt_path, _, img_path in shaded_pairs}
    tasks = [
        (
            str(gt_path),
            str(pred_path),
            image_map.get(str(gt_path)),
            str(mask_dir),
            args.take_inverse,
            args.eps,
            args.dark_threshold,
            args.kernel_size,
        )
        for gt_path, pred_path in pairs
    ]

    # Guard against zero or negative values passed on the command line.
    workers = max(1, int(args.workers))

    def _on_progress(completed: int, total: int) -> None:
        # Report every 100 pairs and once more when the last pair finishes.
        if completed % 100 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    # Each pair is independent, so evaluation is fanned out to worker
    # processes and the per-pair results are collected as they are yielded.
    for result in run_process_pool_as_completed(
        _process_single_pair,
        tasks,
        workers=workers,
        progress_callback=_on_progress,
    ):
        aligned_classic.append(result["classic"])
        aligned_distance.append(result["distance"])
        if result["shaded"] is not None:
            aligned_shaded.append(result["shaded"])

    common_meta = {
        "dataset": "cheri",
        "take_inverse": args.take_inverse,
        "num_pairs": len(pairs),
        "pred_dir": str(pred_dir),
        "gt_dir": str(gt_dir),
        "img_dir": str(img_dir),
        "max_depth": MAX_DEPTH,
    }

    classic_payload = {
        **common_meta,
        "experiment": "classic",
        "protocol_a": {"overall": _mean_dict(aligned_classic)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "classic", classic_payload)

    distance_payload = {
        **common_meta,
        "experiment": "distance",
        "distance_bins": {name: list(bin_range) for name, bin_range in zip(BIN_NAMES, DISTANCE_BINS)},
        "protocol_a": _pack_distance(_mean_region_dict(aligned_distance)),
        "protocol_b": None,
    }
    _write_experiment(output_dir, "distance", distance_payload)

    shaded_payload = {
        **common_meta,
        "experiment": "shaded",
        # Only pairs that had an RGB image contribute to the shaded averages.
        "num_pairs": len(aligned_shaded),
        "protocol_a": {"overall": _mean_dict(aligned_shaded)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "shaded", shaded_payload)

    # Build the summary from the constants so it cannot drift from the bins
    # actually used; with the current values this prints
    # "near 0-5m, medium 5-10m, far 10-17m".
    bin_summary = ", ".join(
        f"{name} {low:g}-{high:g}m" for name, (low, high) in zip(BIN_NAMES, DISTANCE_BINS)
    )
    print(f"Distance bins: {bin_summary}")
    print(f"Saved results to {output_dir}")


# Run the CLI only when executed as a script, so that merely importing this
# module never starts an evaluation run.
if __name__ == "__main__":
    main()
