import os

# Cap the native thread pools of OpenMP / OpenBLAS / MKL / Accelerate (vecLib)
# / numexpr at one thread each. These variables are typically read when the
# libraries are first loaded, so they are set here, before numpy and cv2 are
# imported. Parallelism comes from the worker-process pool used in main().
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import argparse

"""
Evaluate LuSNAR depth predictions against masked ground-truth depth.
Predictions are least-squares aligned to the GT before metrics are computed, resolving monocular scale ambiguity.
"""
import json
from pathlib import Path
import sys
import numpy as np
import cv2

# Disable OpenCV's internal threading (0 = run sequentially); parallelism
# comes from the worker-process pool instead.
cv2.setNumThreads(0)

# Make the `core` package next to this script importable when the script is
# run directly (it backs the `core.*` imports below).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.evaluator import apply_max_depth_mask, build_distance_masks, compute_metrics_by_region
from core.io import read_pfm, load_depth_npy, resize_to_match
from core.masks import label_mask_from_hex, load_or_create_dark_mask
from core.metrics import compute_metrics
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Placeholder defaults; pass --root / --mask-root to point at a real LuSNAR copy.
ROOT_DEFAULT = "path/to/lusnar-dataset"
MASK_ROOT_DEFAULT = "path/to/lusnar-dataset"
# GT depth limit handed to apply_max_depth_mask. Presumably metres, judging by
# the "0-5m"-style bin labels printed at the end of main() — confirm.
MAX_DEPTH = 50.0
# (lower, upper) GT-depth bounds per distance bin; bound inclusivity is decided
# by core.evaluator.build_distance_masks.
DISTANCE_BINS = [(0.0, 5.0), (5.0, 15.0), (15.0, 50.0)]
# One name per DISTANCE_BINS entry, in the same order.
BIN_NAMES = ["near", "medium", "far"]

# Semantic class colours in the label images, matched by label_mask_from_hex.
# NOTE(review): presumably "#RRGGBB" strings converted to BGR by that helper — confirm.
REGOLITH_HEX = "#BB469C"
CRATER_HEX = "#7800C8"
ROCK_HEX = "#E8FA50"


def _mean_dict(dicts: list[dict]) -> dict:
    if not dicts:
        return {}
    keys = dicts[0].keys()
    out = {}
    for k in keys:
        vals = np.array([d[k] for d in dicts], dtype=np.float64)
        out[k] = float(np.nanmean(vals))
    return out


def _mean_region_dict(region_list: list[dict]) -> dict:
    """Average per-frame ``{region: {metric: value}}`` dicts region by region.

    Regions are collected from every frame in first-seen order, not just the
    first frame. Each region's metric dicts are averaged with ``_mean_dict``,
    using only the frames that report that region.

    Args:
        region_list: One ``{region_name: metric_dict}`` mapping per frame.

    Returns:
        ``{region_name: averaged_metric_dict}``; ``{}`` for an empty input.
    """
    if not region_list:
        return {}
    regions = dict.fromkeys(name for frame in region_list for name in frame)
    return {
        name: _mean_dict([frame[name] for frame in region_list if name in frame])
        for name in regions
    }


def _write_experiment(output_dir: Path, experiment: str, payload: dict) -> None:
    exp_dir = output_dir / experiment
    exp_dir.mkdir(parents=True, exist_ok=True)
    json_path = exp_dir / "results.json"
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _valid_gt_mask(gt: np.ndarray) -> np.ndarray:
    return np.isfinite(gt) & (gt > 0) & (gt < 65500.0)


def _load_label(path: Path) -> np.ndarray:
    """Read a semantic label image from *path* as a 3-channel BGR array.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode the file.
    """
    label_bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if label_bgr is not None:
        return label_bgr
    raise FileNotFoundError(f"Could not load label image: {path}")


def _eval_semantic(pred: np.ndarray, gt: np.ndarray, base_mask: np.ndarray, label_bgr: np.ndarray) -> dict:
    """Compute metrics separately for each semantic class.

    A class region is the set of label pixels matching that class's colour
    (via ``label_mask_from_hex``), intersected with *base_mask*.

    Returns:
        ``{"regolith": ..., "rock": ..., "crater": ...}`` mapping class name
        to the ``compute_metrics`` result on that region.
    """
    class_colours = (
        ("regolith", REGOLITH_HEX),
        ("rock", ROCK_HEX),
        ("crater", CRATER_HEX),
    )
    region_masks = {
        class_name: label_mask_from_hex(label_bgr, hex_code)
        for class_name, hex_code in class_colours
    }
    return {
        class_name: compute_metrics(pred, gt, base_mask & region)
        for class_name, region in region_masks.items()
    }


def _eval_distance(
    pred: np.ndarray,
    gt: np.ndarray,
    base_mask: np.ndarray,
    distance_bins: list[tuple[float, float | None]],
    bin_names: list[str],
) -> dict:
    """Compute metrics per ground-truth distance bin.

    Region masks are built by ``build_distance_masks`` from the GT, *base_mask*,
    the bin bounds and their names. Metrics for each region are then computed
    by ``compute_metrics_by_region``.
    """
    per_region_masks = build_distance_masks(gt, base_mask, distance_bins, bin_names)
    metrics_by_region = compute_metrics_by_region(pred, gt, base_mask, per_region_masks)
    return metrics_by_region


def _pack_distance(results: dict) -> dict:
    """Reshape flat region results into ``{"overall": ..., "bins": {name: ...}}``.

    Missing entries (``"overall"`` or any name in ``BIN_NAMES``) become ``{}``.
    Any other keys in *results* are dropped.
    """
    overall = results.get("overall", {})
    per_bin = {}
    for bin_name in BIN_NAMES:
        per_bin[bin_name] = results.get(bin_name, {})
    return {"overall": overall, "bins": per_bin}


def _pred_candidates_for_sequence(pred_seq_dir: Path) -> list[Path]:
    root_files = sorted(pred_seq_dir.glob("*.npy"))
    if root_files:
        return root_files
    return sorted((pred_seq_dir / "depth_npy").glob("*.npy"))


def _collect_sequence_pairs(root: Path, pred_root: Path) -> list[tuple[Path, Path, Path, Path]]:
    """Match GT depth, prediction, label and RGB image files across Moon_1..Moon_9.

    Expected layout for each sequence ``Moon_k``::

        <root>/Moon_k/image0/depth/*.pfm                  ground-truth depth
        <root>/Moon_k/image0/label/*.png                  semantic labels
        <root>/Moon_k/image0/images/*.{png,jpg,jpeg}      RGB images
        <pred_root>/Moon_k/*.npy  (or .../depth_npy/*.npy) predictions

    Files are matched by filename stem. A sequence is skipped when its GT,
    image or prediction directory is missing. Frames that lack a matching
    label or image are skipped, and a warning listing them goes to stderr.

    Returns:
        ``(gt_path, pred_path, label_path, img_path)`` tuples, grouped by
        sequence and ordered by sorted GT path within each sequence.

    Raises:
        SystemExit: if a sequence's predictions do not match its GT files
            one-to-one by stem.
    """
    pairs = []

    for seq in range(1, 10):
        seq_name = f"Moon_{seq}"
        seq_dir = root / seq_name / "image0"
        gt_dir = seq_dir / "depth"
        label_dir = seq_dir / "label"
        img_dir = seq_dir / "images"
        pred_dir = pred_root / seq_name

        if not gt_dir.exists() or not pred_dir.exists() or not img_dir.exists():
            continue

        gt_files = sorted(gt_dir.glob("*.pfm"))
        pred_files = _pred_candidates_for_sequence(pred_dir)

        if len(pred_files) != len(gt_files):
            raise SystemExit(
                f"Prediction/GT count mismatch for sequence '{seq_name}': "
                f"pred={len(pred_files)} gt={len(gt_files)}. "
                "Ensure each GT file has one prediction file."
            )

        gt_stems = {p.stem for p in gt_files}
        pred_stems = {p.stem for p in pred_files}
        if gt_stems != pred_stems:
            missing_preds = sorted(gt_stems - pred_stems)
            extra_preds = sorted(pred_stems - gt_stems)
            raise SystemExit(
                f"Prediction/GT stem mismatch for sequence '{seq_name}'. "
                f"Missing preds: {missing_preds[:5]}{'...' if len(missing_preds) > 5 else ''}; "
                f"Extra preds: {extra_preds[:5]}{'...' if len(extra_preds) > 5 else ''}."
            )

        pred_map = {p.stem: p for p in pred_files}
        label_map = {p.stem: p for p in label_dir.glob("*.png")}
        img_map: dict[str, Path] = {}
        # If a stem exists with several extensions, the later extension in
        # this tuple wins.
        for ext in (".png", ".jpg", ".jpeg"):
            for img in img_dir.glob(f"*{ext}"):
                img_map[img.stem] = img

        skipped: list[str] = []
        for gt_path in gt_files:
            stem = gt_path.stem
            label_path = label_map.get(stem)
            img_path = img_map.get(stem)
            if label_path is None or img_path is None:
                skipped.append(stem)
                continue
            # The stem check above guarantees a prediction for every GT stem.
            pairs.append((gt_path, pred_map[stem], label_path, img_path))

        if skipped:
            # Surface dropped frames instead of silently shrinking the evaluation set.
            preview = ", ".join(skipped[:5]) + (", ..." if len(skipped) > 5 else "")
            print(
                f"Warning: '{seq_name}': skipped {len(skipped)} frame(s) with no "
                f"matching label/image: {preview}",
                file=sys.stderr,
            )
    return pairs


def _process_single_pair(task: tuple[str, str, str, str, str, bool, float, int, int]) -> dict:
    """Evaluate a single frame and return its per-frame metric dicts.

    ``task`` holds only plain values (paths as strings) because ``main`` hands
    it to a worker process::

        (gt_path, pred_path, label_path, img_path, mask_root,
         take_inverse, eps, dark_threshold, kernel_size)

    Pipeline:
      1. Read the GT (PFM) and the prediction (.npy, passed through
         ``resize_to_match`` against the GT).
      2. Optionally convert the prediction from inverse depth to depth.
      3. Build the evaluation mask: valid GT & valid prediction, then
         ``apply_max_depth_mask`` with ``MAX_DEPTH``.
      4. Least-squares align the prediction to the GT on that mask and
         compute metrics.

    Returns:
        Dict with keys ``"classic"``, ``"distance"``, ``"semantic"`` and
        ``"shaded"``.
    """
    (
        gt_str,
        pred_str,
        label_str,
        img_str,
        mask_root_str,
        take_inverse,
        eps,
        dark_threshold,
        kernel_size,
    ) = task
    gt_path, img_path, mask_root = Path(gt_str), Path(img_str), Path(mask_root_str)

    gt = read_pfm(str(gt_path))
    raw_pred = resize_to_match(load_depth_npy(str(Path(pred_str))), gt)
    label_img = _load_label(Path(label_str))

    # Put the prediction in depth space. With take_inverse, the raw output is
    # treated as inverse depth and converted by inverse_to_depth. Otherwise it
    # is used as-is and only non-finite values are marked invalid.
    if not take_inverse:
        pred_depth = raw_pred.astype(np.float64)
        pred_valid = np.isfinite(pred_depth)
    else:
        pred_depth, pred_valid = inverse_to_depth(raw_pred, eps=eps)

    # Pixels shared by every experiment: valid GT and prediction, with GT
    # depth limited via apply_max_depth_mask(MAX_DEPTH).
    eval_mask = apply_max_depth_mask(gt, _valid_gt_mask(gt) & pred_valid, MAX_DEPTH)

    # Dark-pixel mask derived from the RGB image and cached at
    # <mask_root>/<Moon_k>/image0/masks/<image stem>.npy. It only restricts
    # the "shaded" experiment.
    # NOTE(review): whether True marks dark or lit pixels is defined by
    # load_or_create_dark_mask — confirm. The cache path does not encode
    # dark_threshold/kernel_size.
    sequence = gt_path.parents[2].name
    cache_file = mask_root / sequence / "image0" / "masks" / f"{img_path.stem}.npy"
    dark_mask = load_or_create_dark_mask(
        str(img_path),
        str(cache_file),
        threshold=dark_threshold,
        kernel_size=kernel_size,
    )
    shaded_mask = eval_mask & dark_mask

    # Least-squares fit of the prediction to the GT over eval_mask. The two
    # extra return values (presumably the fitted coefficients) are unused.
    aligned, _, _ = least_squares_align(pred_depth, gt, eval_mask)

    return {
        "classic": compute_metrics(aligned, gt, eval_mask),
        "distance": _eval_distance(aligned, gt, eval_mask, DISTANCE_BINS, BIN_NAMES),
        "semantic": _eval_semantic(aligned, gt, eval_mask, label_img),
        "shaded": compute_metrics(aligned, gt, shaded_mask),
    }


def main() -> None:
    """Command-line entry point: evaluate every matched frame and write results.

    Frames found by ``_collect_sequence_pairs`` are evaluated in worker
    processes by ``_process_single_pair``. Per-frame metrics are averaged and
    written to ``<output-dir>/<experiment>/results.json`` for the "classic",
    "distance", "shaded" and "semantic" experiments.

    Raises:
        SystemExit: if no GT/prediction/label/image quadruples are found, or
            if predictions do not match GT files (from _collect_sequence_pairs).
    """
    ap = argparse.ArgumentParser(description="Evaluate Lusnar predictions against PFM GT with semantic masks.")
    ap.add_argument("--pred-dir", required=True, help="Directory containing Moon_1..Moon_9 subfolders with .npy predictions")
    ap.add_argument("--take-inverse", action="store_true", help="Interpret predictions as inverse depth and convert to depth")
    ap.add_argument("--output-dir", required=True, help="Directory to save evaluation results")
    ap.add_argument("--root", default=ROOT_DEFAULT, help="Lusnar dataset root")
    ap.add_argument("--mask-root", default=MASK_ROOT_DEFAULT, help="Root directory for cached dark masks")
    ap.add_argument("--dark-threshold", type=int, default=50, help="Dark pixel threshold")
    ap.add_argument("--kernel-size", type=int, default=9, help="Morphology kernel size for dark mask")
    ap.add_argument("--eps", type=float, default=1e-6, help="Small epsilon for inverse depth")
    ap.add_argument(
        "--workers",
        type=int,
        default=default_worker_count(),
        help="Number of worker processes (default: os.cpu_count() - 1, minimum 1)",
    )
    args = ap.parse_args()

    root = Path(args.root)
    pred_root = Path(args.pred_dir)
    output_dir = Path(args.output_dir)
    mask_root = Path(args.mask_root)

    pairs = _collect_sequence_pairs(root, pred_root)
    if not pairs:
        raise SystemExit("No matching GT/Prediction/Label/Image quadruples found.")

    aligned_classic = []
    aligned_distance = []
    aligned_semantic = []
    aligned_shaded = []

    workers = max(1, int(args.workers))
    # Tasks carry only strings/scalars so they can be sent to worker processes.
    tasks = [
        (
            str(gt_path),
            str(pred_path),
            str(label_path),
            str(img_path),
            str(mask_root),
            args.take_inverse,
            args.eps,
            args.dark_threshold,
            args.kernel_size,
        )
        for gt_path, pred_path, label_path, img_path in pairs
    ]

    # Report progress every 100 completed frames and once at the end.
    def _on_progress(completed: int, total: int) -> None:
        if completed % 100 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    # Results arrive in completion order, not task order. Only their averages
    # are reported below.
    for result in run_process_pool_as_completed(
        _process_single_pair,
        tasks,
        workers=workers,
        progress_callback=_on_progress,
    ):
        aligned_classic.append(result["classic"])
        aligned_distance.append(result["distance"])
        aligned_semantic.append(result["semantic"])
        aligned_shaded.append(result["shaded"])

    common_meta = {
        "dataset": "lusnar",
        "take_inverse": args.take_inverse,
        "num_pairs": len(pairs),
        "pred_dir": str(pred_root),
        "root": str(root),
        "max_depth": MAX_DEPTH,
    }

    classic_payload = {
        **common_meta,
        "experiment": "classic",
        "protocol_a": {"overall": _mean_dict(aligned_classic)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "classic", classic_payload)

    distance_payload = {
        **common_meta,
        "experiment": "distance",
        "distance_bins": {name: [lo, hi] for name, (lo, hi) in zip(BIN_NAMES, DISTANCE_BINS)},
        "protocol_a": _pack_distance(_mean_region_dict(aligned_distance)),
        "protocol_b": None,
    }
    _write_experiment(output_dir, "distance", distance_payload)

    shaded_payload = {
        **common_meta,
        "experiment": "shaded",
        "protocol_a": {"overall": _mean_dict(aligned_shaded)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "shaded", shaded_payload)

    semantic_payload = {
        **common_meta,
        "experiment": "semantic",
        "protocol_a": {"classes": _mean_region_dict(aligned_semantic)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "semantic", semantic_payload)

    # Derive the summary from the constants so it cannot drift from the bins
    # actually evaluated. An open-ended bin (upper bound None) prints as "<lo>m+".
    bin_summary = ", ".join(
        f"{name} {lo:g}m+" if hi is None else f"{name} {lo:g}-{hi:g}m"
        for name, (lo, hi) in zip(BIN_NAMES, DISTANCE_BINS)
    )
    print(f"Distance bins: {bin_summary}")
    print(f"Saved results to {output_dir}")


# Run main() only when executed as a script, not when the module is imported
# (e.g. re-imported by spawned worker processes).
if __name__ == "__main__":
    main()
