import argparse

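"""
Evaluate LunarSim depth predictions against PNG ground truth, scoring only valid pixels.
Each prediction is least-squares aligned to its ground truth before metrics are computed,
which compensates for the scale ambiguity of monocular depth estimation. Metrics are
reported over all valid pixels ("classic") and over the dark-mask subset ("shaded").
"""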
import json
from pathlib import Path
import sys
import numpy as np

# Add this script's own directory to sys.path so the `core.*` imports below can
# be resolved when the file is run directly (presumably `core` lives next to
# this script — confirm if the layout changes).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.io import list_sorted_files, load_depth_npy, load_depth_png, resize_to_match
from core.masks import load_or_create_dark_mask
from core.metrics import compute_metrics
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Placeholder dataset locations used as CLI defaults; override them with
# --gt-dir / --img-dir / --mask-dir for a real dataset checkout.
GT_DIR_DEFAULT = "path/to/lunarsim-dataset/depth"
IMG_DIR_DEFAULT = "path/to/lunarsim-dataset/images"
MASK_DIR_DEFAULT = "path/to/lunarsim-dataset/masks"


def _mean_dict(dicts: list[dict]) -> dict:
    if not dicts:
        return {}
    keys = dicts[0].keys()
    out = {}
    for k in keys:
        vals = np.array([d[k] for d in dicts], dtype=np.float64)
        out[k] = float(np.nanmean(vals))
    return out


def _write_experiment(output_dir: Path, experiment: str, payload: dict) -> None:
    exp_dir = output_dir / experiment
    exp_dir.mkdir(parents=True, exist_ok=True)
    json_path = exp_dir / "results.json"
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _process_single_pair(task: tuple[str, str, str | None, str, bool, float, int, int]) -> dict:
    """Score one prediction against its ground truth (run as a pool task).

    ``task`` is a flat, picklable tuple::

        (pred_path, gt_path, img_path_or_None, mask_dir,
         take_inverse, eps, dark_threshold, kernel_size)

    The prediction is resized to the GT, optionally converted from inverse
    depth, least-squares aligned to the GT on the valid pixels, and scored.

    Returns:
        ``{"classic": metrics over all valid pixels,
        "shaded": metrics over valid pixels inside the dark mask, or None
        when no RGB image path was supplied}``.
    """
    (
        pred_str,
        gt_str,
        img_str,
        masks_str,
        take_inverse,
        eps,
        dark_threshold,
        kernel_size,
    ) = task

    pred_file = Path(pred_str)
    gt_file = Path(gt_str)
    image_file = None if img_str is None else Path(img_str)
    masks_root = Path(masks_str)

    gt = load_depth_png(str(gt_file))
    raw_pred = resize_to_match(load_depth_npy(str(pred_file)), gt)

    # Inverse-depth (disparity-like) outputs must be converted to depth before
    # a least-squares fit against metric GT makes sense.
    if not take_inverse:
        depth = raw_pred.astype(np.float64)
        pred_ok = np.isfinite(depth)
    else:
        depth, pred_ok = inverse_to_depth(raw_pred, eps=eps)

    # A pixel counts only if GT is positive and finite and the prediction is usable.
    valid = (gt > 0) & pred_ok & np.isfinite(gt)
    aligned, _, _ = least_squares_align(depth, gt, valid)
    classic = compute_metrics(aligned, gt, valid)

    if image_file is None:
        return {"classic": classic, "shaded": None}

    # Restrict scoring to pixels where the dark mask is set — presumably the
    # unlit/shadowed regions (confirm against core.masks). The alignment above
    # was fitted on all valid pixels, not only on this subset.
    dark = load_or_create_dark_mask(
        str(image_file),
        str(masks_root / f"{image_file.stem}.npy"),
        threshold=dark_threshold,
        kernel_size=kernel_size,
    )
    return {"classic": classic, "shaded": compute_metrics(aligned, gt, valid & dark)}


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for this script."""
    ap = argparse.ArgumentParser(description="Evaluate LunarSim predictions against PNG GT with least-squares alignment.")
    ap.add_argument("--pred-dir", required=True, help="Directory containing .npy predictions")
    ap.add_argument("--take-inverse", action="store_true", help="Interpret predictions as inverse depth and convert to depth")
    ap.add_argument("--output-dir", required=True, help="Directory to save evaluation results")
    ap.add_argument("--pred-pattern", default="*.npy", help="Glob pattern for prediction files")
    ap.add_argument("--gt-dir", default=GT_DIR_DEFAULT, help="Ground-truth PNG directory")
    ap.add_argument("--img-dir", default=IMG_DIR_DEFAULT, help="RGB image directory")
    ap.add_argument("--mask-dir", default=MASK_DIR_DEFAULT, help="Directory for cached dark masks")
    ap.add_argument("--dark-threshold", type=int, default=50, help="Dark pixel threshold")
    ap.add_argument("--kernel-size", type=int, default=9, help="Morphology kernel size for dark mask")
    ap.add_argument("--eps", type=float, default=1e-6, help="Small epsilon for inverse depth")
    ap.add_argument(
        "--workers",
        type=int,
        default=default_worker_count(),
        help="Number of worker processes (default: os.cpu_count() - 1, minimum 1)",
    )
    return ap.parse_args()


def _report_progress(completed: int, total: int) -> None:
    """Print a progress line every 50 completed pairs and once at the end."""
    if completed % 50 == 0 or completed == total:
        print(f"Processed {completed}/{total}")


def _build_payload(
    *,
    experiment: str,
    take_inverse: bool,
    num_pairs: int,
    pred_dir: Path,
    gt_dir: Path,
    img_dir: Path,
    overall: dict,
) -> dict:
    """Assemble the JSON payload for one experiment.

    The classic and shaded experiments share this schema; only the
    experiment name, the pair count and the averaged metrics differ.
    ``max_depth`` and ``protocol_b`` are always ``None`` in this script.
    """
    return {
        "dataset": "lunarsim",
        "experiment": experiment,
        "take_inverse": take_inverse,
        "num_pairs": num_pairs,
        "pred_dir": str(pred_dir),
        "gt_dir": str(gt_dir),
        "img_dir": str(img_dir),
        "max_depth": None,
        "protocol_a": {"overall": overall},
        "protocol_b": None,
    }


def main() -> None:
    """Evaluate all prediction/GT pairs and write classic and shaded results.

    Prediction, GT and image files are paired by position in sorted order,
    not by filename, and truncated to the shortest list. Each pair is scored
    in a worker process by ``_process_single_pair``. The per-pair metrics are
    averaged and written to ``<output-dir>/classic/results.json`` and
    ``<output-dir>/shaded/results.json``.

    Raises:
        SystemExit: if no prediction files or no GT PNGs are found.
    """
    args = _parse_args()

    pred_dir = Path(args.pred_dir)
    output_dir = Path(args.output_dir)
    gt_dir = Path(args.gt_dir)
    img_dir = Path(args.img_dir)
    mask_dir = Path(args.mask_dir)

    # Keep regular files only. A directory matching the pattern would
    # otherwise be handed to the .npy loader and crash a worker.
    pred_files = sorted(p for p in pred_dir.glob(args.pred_pattern) if p.is_file())
    gt_files = list_sorted_files(str(gt_dir), (".png",))
    img_files = list_sorted_files(str(img_dir), (".png", ".jpg", ".jpeg"))

    if not pred_files:
        raise SystemExit(f"No prediction files found in {pred_dir} with pattern {args.pred_pattern}")
    if not gt_files:
        raise SystemExit(f"No GT PNG files found in {gt_dir}")
    if not img_files:
        print(f"[WARN] No RGB images found in {img_dir}; shaded metrics will be skipped.")

    # Pairing is positional, so only the common prefix of the lists is used.
    # Images are optional; when absent they do not limit the pair count.
    n = min(len(pred_files), len(gt_files), len(img_files) if img_files else len(gt_files))
    if len(pred_files) != len(gt_files):
        print(f"[WARN] Count mismatch: preds={len(pred_files)} gt={len(gt_files)}; using first {n}.")
    if img_files and len(img_files) != len(gt_files):
        print(f"[WARN] Count mismatch: images={len(img_files)} gt={len(gt_files)}; using first {n}.")

    # Plain strings/scalars only, so each task pickles cheaply to the workers.
    tasks = [
        (
            str(pred_files[i]),
            str(gt_files[i]),
            str(img_files[i]) if img_files else None,
            str(mask_dir),
            args.take_inverse,
            args.eps,
            args.dark_threshold,
            args.kernel_size,
        )
        for i in range(n)
    ]

    workers = max(1, int(args.workers))

    overall_metrics: list[dict] = []
    shaded_metrics: list[dict] = []
    # Results may arrive out of task order (as-completed). Only
    # order-independent means are computed from them, so that is harmless.
    for result in run_process_pool_as_completed(
        _process_single_pair,
        tasks,
        workers=workers,
        progress_callback=_report_progress,
    ):
        overall_metrics.append(result["classic"])
        if result["shaded"] is not None:
            shaded_metrics.append(result["shaded"])

    common = {
        "take_inverse": args.take_inverse,
        "pred_dir": pred_dir,
        "gt_dir": gt_dir,
        "img_dir": img_dir,
    }
    _write_experiment(
        output_dir,
        "classic",
        _build_payload(experiment="classic", num_pairs=n, overall=_mean_dict(overall_metrics), **common),
    )
    # _mean_dict returns {} for an empty list, e.g. when no RGB images exist.
    _write_experiment(
        output_dir,
        "shaded",
        _build_payload(
            experiment="shaded",
            num_pairs=len(shaded_metrics),
            overall=_mean_dict(shaded_metrics),
            **common,
        ),
    )
    print(f"Saved results to {output_dir}")


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
