import argparse

"""
Evaluation script mapping raw model predictions logically against strictly masked Ground Truth.
Implements least-squares scale alignment organically resolving monocular scaling ambiguities.
"""
import json
from pathlib import Path
import sys
import numpy as np

# Directory containing this script; added to sys.path so the sibling ``core``
# package can be imported regardless of the current working directory.
# This must execute before the ``core.*`` imports that follow. Because the path is
# appended (not prepended), an earlier ``core`` module on sys.path would win.
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.evaluator import apply_max_depth_mask, build_distance_masks, compute_metrics_by_region
from core.io import load_depth_npy, resize_to_match
from core.masks import load_or_create_dark_mask
from core.metrics import compute_metrics
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

ROOT_DEFAULT = "path/to/s3li_dataset"
SEQUENCES = ["crater", "crater_inout", "landmarks", "loops", "mapping", "traverse_1", "traverse_2"]
MAX_DEPTH = 30.0
DISTANCE_BINS = [(0.0, 5.0), (5.0, 15.0), (15.0, 30.0)]
BIN_NAMES = ["near", "medium", "far"]
DARK_THRESHOLD = 50
KERNEL_SIZE = 9


def _mean_dict(dicts: list[dict]) -> dict:
    if not dicts:
        return {}
    keys = dicts[0].keys()
    out = {}
    for k in keys:
        vals = np.array([d[k] for d in dicts], dtype=np.float64)
        out[k] = float(np.nanmean(vals))
    return out


def _mean_region_dict(region_list: list[dict]) -> dict:
    """Average per-image region metrics into one metrics dict per region.

    Args:
        region_list: one mapping per image of ``region name -> metrics dict``.

    Returns:
        ``region name -> averaged metrics`` (computed with ``_mean_dict``). Regions
        are the union over all images in first-seen order, not just those of the first
        image. Each region is averaged only over the images that report it. An empty
        input returns ``{}``.
    """
    # Previously only region_list[0]'s keys were used, so a region missing from the
    # first image (e.g. no pixels in a distance bin) vanished from the whole report.
    region_names = dict.fromkeys(name for per_image in region_list for name in per_image)
    return {
        name: _mean_dict([per_image[name] for per_image in region_list if name in per_image])
        for name in region_names
    }


def _write_experiment(output_dir: Path, experiment: str, payload: dict) -> None:
    exp_dir = output_dir / experiment
    exp_dir.mkdir(parents=True, exist_ok=True)
    json_path = exp_dir / "results.json"
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _pack_distance(results: dict) -> dict:
    """Reshape flat region metrics into ``{"overall": ..., "bins": {name: ...}}``.

    Bins are emitted in BIN_NAMES order. A region missing from ``results`` is
    reported as an empty dict.
    """
    per_bin = {}
    for bin_name in BIN_NAMES:
        per_bin[bin_name] = results.get(bin_name, {})
    overall = results.get("overall", {})
    return {"overall": overall, "bins": per_bin}


def _valid_gt_mask(gt: np.ndarray) -> np.ndarray:
    return np.isfinite(gt) & (gt > 0)


def _pred_candidates_for_sequence(pred_seq_dir: Path, pred_pattern: str) -> list[Path]:
    root_files = sorted(pred_seq_dir.glob(pred_pattern))
    if root_files:
        return root_files
    depth_npy_files = sorted((pred_seq_dir / "depth_npy").glob(pred_pattern))
    return depth_npy_files


def _image_map(img_dir: Path) -> dict[str, Path]:
    """Map image stem -> path for the RGB frames (.png/.jpg/.jpeg) in ``img_dir``.

    Extensions are matched case-insensitively, so ``frame.PNG`` is found too.
    When several images share a stem, ``.jpeg`` beats ``.jpg``, which beats ``.png``.
    This is the same precedence the earlier per-extension glob loop produced
    (later extensions overwrote earlier ones).
    """
    rank = {".png": 0, ".jpg": 1, ".jpeg": 2}
    best: dict[str, tuple[int, Path]] = {}
    # Sorted for determinism when two files differ only by extension case.
    for img in sorted(img_dir.iterdir()):
        img_rank = rank.get(img.suffix.lower())
        if img_rank is None or not img.is_file():
            continue
        current = best.get(img.stem)
        if current is None or img_rank >= current[0]:
            best[img.stem] = (img_rank, img)
    return {stem: path for stem, (_, path) in best.items()}


def _check_sequence_files(seq: str, gt_files: list[Path], pred_files: list[Path]) -> None:
    """Abort unless every GT file in ``seq`` has exactly one same-stem prediction.

    Raises:
        SystemExit: on a stem mismatch (missing/extra predictions) or, when the stems
            agree, on a count mismatch (duplicate prediction stems).
    """
    gt_stems = {p.stem for p in gt_files}
    pred_stems = {p.stem for p in pred_files}
    # The stem check runs first because it names the offending files. The count
    # check then only fires for duplicate prediction stems.
    if gt_stems != pred_stems:
        missing_preds = sorted(gt_stems - pred_stems)
        extra_preds = sorted(pred_stems - gt_stems)
        raise SystemExit(
            f"Prediction/GT stem mismatch for sequence '{seq}'. "
            f"Missing preds: {missing_preds[:5]}{'...' if len(missing_preds) > 5 else ''}; "
            f"Extra preds: {extra_preds[:5]}{'...' if len(extra_preds) > 5 else ''}."
        )
    if len(pred_files) != len(gt_files):
        raise SystemExit(
            f"Prediction/GT count mismatch for sequence '{seq}': "
            f"pred={len(pred_files)} gt={len(gt_files)}. "
            "Ensure each GT file has one prediction file."
        )


def _collect_pairs(
    gt_root: Path,
    img_root: Path,
    pred_root: Path,
    pred_pattern: str,
) -> tuple[list[tuple[Path, Path]], list[tuple[Path, Path, Path]]]:
    """Match GT depth files with predictions (and RGB images) for every sequence.

    For each sequence in SEQUENCES, GT files are ``<gt_root>/<seq>/*.npy`` and
    predictions are found by ``_pred_candidates_for_sequence``. Images are looked up
    by stem in ``<img_root>/<seq>``.

    Returns:
        ``(pairs, shaded_pairs)``: ``pairs`` holds ``(gt_path, pred_path)`` for every
        GT file. ``shaded_pairs`` holds ``(gt_path, pred_path, img_path)`` for the
        subset that also has an RGB image.

    Raises:
        SystemExit: if a sequence's predictions do not match its GT files one-to-one.
    """
    pairs: list[tuple[Path, Path]] = []
    shaded_pairs: list[tuple[Path, Path, Path]] = []

    for seq in SEQUENCES:
        gt_dir = gt_root / seq
        img_dir = img_root / seq
        pred_dir = pred_root / seq

        missing_dirs = [str(d) for d in (gt_dir, img_dir, pred_dir) if not d.is_dir()]
        if missing_dirs:
            # Skipping is intentional (partial datasets), but it must be visible:
            # otherwise reported metrics silently cover fewer sequences.
            print(f"Skipping sequence '{seq}': missing {', '.join(missing_dirs)}", file=sys.stderr)
            continue

        gt_files = sorted(gt_dir.glob("*.npy"))
        pred_files = _pred_candidates_for_sequence(pred_dir, pred_pattern)
        _check_sequence_files(seq, gt_files, pred_files)

        pred_map = {p.stem: p for p in pred_files}
        img_map = _image_map(img_dir)

        for gt_path in gt_files:
            # Stems were verified identical above, so the lookup cannot fail.
            pred_path = pred_map[gt_path.stem]
            pairs.append((gt_path, pred_path))

            img_path = img_map.get(gt_path.stem)
            if img_path is not None:
                shaded_pairs.append((gt_path, pred_path, img_path))

    return pairs, shaded_pairs


def _process_single_pair(task: tuple[str, str, str | None, str, bool, float]) -> dict:
    """Evaluate one GT/prediction pair (executed by the process-pool workers).

    Args:
        task: ``(gt_path, pred_path, img_path_or_None, mask_root, take_inverse, eps)``.
            Paths travel as plain strings and are turned back into ``Path`` here.

    Returns:
        A dict with three keys:
        - ``"classic"``: ``compute_metrics`` over all valid pixels.
        - ``"distance"``: ``compute_metrics_by_region`` over the DISTANCE_BINS masks.
        - ``"shaded"``: ``compute_metrics`` over valid pixels that are also inside the
          dark mask, or ``None`` when the pair has no RGB image.
    """
    gt_path_str, pred_path_str, img_path_str, mask_root_str, take_inverse, eps = task

    gt_path = Path(gt_path_str)
    pred_path = Path(pred_path_str)
    img_path = None if img_path_str is None else Path(img_path_str)
    mask_root = Path(mask_root_str)

    gt = load_depth_npy(str(gt_path))
    # Bring the prediction onto the GT pixel grid before any per-pixel comparison.
    pred = resize_to_match(load_depth_npy(str(pred_path)), gt)

    if take_inverse:
        # Predictions are inverse depth; convert to depth so alignment and metrics
        # operate in depth space. inverse_to_depth also reports which pixels are usable.
        pred_depth, pred_valid = inverse_to_depth(pred, eps=eps)
    else:
        pred_depth = pred.astype(np.float64)
        pred_valid = np.isfinite(pred_depth)

    # Evaluate only where GT is finite and positive and the prediction is usable.
    valid_mask = _valid_gt_mask(gt) & pred_valid
    # Apply the MAX_DEPTH cap (presumably drops GT pixels beyond MAX_DEPTH; confirm in
    # core.evaluator.apply_max_depth_mask).
    valid_mask = apply_max_depth_mask(gt, valid_mask, MAX_DEPTH)

    # Least-squares alignment of the prediction to GT over the valid pixels. The two
    # discarded return values are presumably the fitted alignment parameters.
    aligned_pred, _, _ = least_squares_align(pred_depth, gt, valid_mask)

    region_masks = build_distance_masks(gt, valid_mask, DISTANCE_BINS, BIN_NAMES)
    result = {
        "classic": compute_metrics(aligned_pred, gt, valid_mask),
        "distance": compute_metrics_by_region(aligned_pred, gt, valid_mask, region_masks),
        "shaded": None,
    }

    if img_path is not None:
        # Dark masks live under <mask_root>/<sequence>/<image stem>.npy.
        # load_or_create_dark_mask presumably reuses that file when present and
        # otherwise derives the mask from the RGB image.
        mask_path = mask_root / gt_path.parent.name / f"{img_path.stem}.npy"
        dark_mask = load_or_create_dark_mask(
            str(img_path),
            str(mask_path),
            threshold=DARK_THRESHOLD,
            kernel_size=KERNEL_SIZE,
        )
        # Intersect (not exclude): the shaded metrics are computed on valid pixels
        # inside the dark mask. The bool cast guards against an integer (0/1 or 0/255)
        # mask, which would turn `&` into an integer array. That array would no longer
        # act as a boolean mask if compute_metrics indexes with it.
        # NOTE(review): assumes dark_mask has the GT's shape. The RGB image is not
        # resized here, so confirm that image and depth resolutions match.
        shaded_base = valid_mask & np.asarray(dark_mask, dtype=bool)
        result["shaded"] = compute_metrics(aligned_pred, gt, shaded_base)

    return result


def main() -> None:
    """Command-line entry point: evaluate S3LI predictions and write JSON results.

    Collects GT/prediction pairs with ``_collect_pairs`` and evaluates each pair in a
    process pool via ``_process_single_pair``. The per-image metrics are then averaged
    and written as one ``results.json`` per experiment ("classic", "distance",
    "shaded") under ``--output-dir``.

    Raises:
        SystemExit: if no GT/prediction pairs or no GT/prediction/RGB triples are
            found, or if ``_collect_pairs`` rejects a sequence.
    """
    ap = argparse.ArgumentParser(description="Evaluate S3LI predictions with pooled sequence metrics.")
    ap.add_argument("--pred-dir", required=True, help="Directory containing sequence subfolders with .npy predictions")
    ap.add_argument("--take-inverse", action="store_true", help="Interpret predictions as inverse depth and convert to depth")
    ap.add_argument("--output-dir", required=True, help="Directory to save evaluation results")
    ap.add_argument("--pred-pattern", default="*.npy", help="Glob pattern for prediction files")
    ap.add_argument("--root", default=ROOT_DEFAULT, help="S3LI root directory")
    ap.add_argument("--eps", type=float, default=1e-6, help="Small epsilon for inverse depth")
    ap.add_argument(
        "--workers",
        type=int,
        default=default_worker_count(),
        help="Number of worker processes (default: os.cpu_count() - 1, minimum 1)",
    )
    args = ap.parse_args()

    pred_root = Path(args.pred_dir)
    output_dir = Path(args.output_dir)
    root = Path(args.root)

    gt_root = root / "depth"
    img_root = root / "images"
    mask_root = root / "masks"

    pairs, shaded_pairs = _collect_pairs(gt_root, img_root, pred_root, args.pred_pattern)

    if not pairs:
        raise SystemExit("No matching GT/pred pairs found across S3LI sequences.")
    if not shaded_pairs:
        raise SystemExit("No matching GT/pred/RGB triples found across S3LI sequences for shaded evaluation.")

    # Only pairs with an RGB image get an image path, and hence a shaded evaluation.
    shaded_image_map = {str(gt_path): str(img_path) for gt_path, _, img_path in shaded_pairs}
    tasks = [
        (
            str(gt_path),
            str(pred_path),
            shaded_image_map.get(str(gt_path)),
            str(mask_root),
            args.take_inverse,
            args.eps,
        )
        for gt_path, pred_path in pairs
    ]

    workers = max(1, int(args.workers))

    def _on_progress(completed: int, total: int) -> None:
        """Print progress every 100 pairs and once at the end."""
        if completed % 100 == 0 or completed == total:
            print(f"Processed {completed}/{total} pairs")

    aligned_classic: list[dict] = []
    aligned_distance: list[dict] = []
    aligned_shaded: list[dict] = []

    # Per-image evaluation runs in worker processes. Results are only averaged
    # afterwards, so the order in which they arrive does not matter.
    for result in run_process_pool_as_completed(
        _process_single_pair,
        tasks,
        workers=workers,
        progress_callback=_on_progress,
    ):
        aligned_classic.append(result["classic"])
        aligned_distance.append(result["distance"])
        if result["shaded"] is not None:
            aligned_shaded.append(result["shaded"])

    # _collect_pairs skips sequences whose directories are missing. Record which
    # ones actually contributed, since "sequences" lists every configured sequence.
    present = {gt_path.parent.name for gt_path, _ in pairs}
    evaluated_sequences = [seq for seq in SEQUENCES if seq in present]

    common_meta = {
        "dataset": "s3li",
        "take_inverse": args.take_inverse,
        "num_pairs": len(pairs),
        "num_shaded_pairs": len(shaded_pairs),
        "pred_dir": str(pred_root),
        "gt_dir": str(gt_root),
        "img_dir": str(img_root),
        "sequences": SEQUENCES,
        "evaluated_sequences": evaluated_sequences,
        "max_depth": MAX_DEPTH,
    }

    experiment_bodies = {
        "classic": {"protocol_a": {"overall": _mean_dict(aligned_classic)}},
        "distance": {
            "distance_bins": {name: list(bin_range) for name, bin_range in zip(BIN_NAMES, DISTANCE_BINS)},
            "protocol_a": _pack_distance(_mean_region_dict(aligned_distance)),
        },
        "shaded": {"protocol_a": {"overall": _mean_dict(aligned_shaded)}},
    }
    for experiment, body in experiment_bodies.items():
        # protocol_b is not computed by this script; it is written as null
        # (presumably kept for schema parity with other evaluation outputs).
        payload = {**common_meta, "experiment": experiment, **body, "protocol_b": None}
        _write_experiment(output_dir, experiment, payload)

    # Derived from the constants so the printout cannot drift from the real bins.
    bin_summary = ", ".join(f"{name} {lo:g}-{hi:g}m" for name, (lo, hi) in zip(BIN_NAMES, DISTANCE_BINS))
    print(f"Distance bins: {bin_summary}")
    print(f"Saved results to {output_dir}")


if __name__ == "__main__":
    main()
