import argparse

"""
Evaluation script that scores raw model depth predictions against masked ground-truth depth maps.
Predictions are aligned to the ground truth by least squares before metrics are computed, which resolves the scale ambiguity of monocular depth estimates.
"""
import json
from pathlib import Path
import sys
import numpy as np

# Add this file's directory to sys.path so the `core.*` imports below can
# resolve regardless of how the module is loaded (assumes the `core` package
# sits next to this file -- TODO confirm).
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.append(str(SCRIPT_DIR))

from core.alignment import least_squares_align
from core.convert import inverse_to_depth
from core.evaluator import apply_max_depth_mask, build_distance_masks, compute_metrics_by_region
from core.io import list_sorted_files, load_depth_npy, resize_to_match
from core.masks import load_or_create_dark_mask
from core.metrics import compute_metrics
from core.multiprocessing_utils import default_worker_count, run_process_pool_as_completed

# Default dataset locations. These are placeholder paths, used only when
# --gt-dir / --img-dir / --mask-dir are not given on the command line.
GT_DIR_DEFAULT = "path/to/etna-dataset/depth"
IMG_DIR_DEFAULT = "path/to/etna-dataset/images"
MASK_DIR_DEFAULT = "path/to/etna-dataset/masks"
# Depth limit handed to apply_max_depth_mask for every evaluated pair
# (metres, going by the summary line printed at the end of main()).
MAX_DEPTH = 15.0
# Ground-truth depth ranges for the "distance" experiment; entry i of
# DISTANCE_BINS is labelled by entry i of BIN_NAMES.
DISTANCE_BINS = [(0.0, 5.0), (5.0, 10.0), (10.0, 15.0)]
BIN_NAMES = ["near", "medium", "far"]


def _mean_dict(dicts: list[dict]) -> dict:
    if not dicts:
        return {}
    keys = dicts[0].keys()
    out = {}
    for k in keys:
        vals = np.array([d[k] for d in dicts], dtype=np.float64)
        out[k] = float(np.nanmean(vals))
    return out


def _mean_region_dict(region_list: list[dict]) -> dict:
    if not region_list:
        return {}
    keys = region_list[0].keys()
    out = {}
    for k in keys:
        sub = [d[k] for d in region_list if k in d]
        out[k] = _mean_dict(sub)
    return out


def _write_experiment(output_dir: Path, experiment: str, payload: dict) -> None:
    exp_dir = output_dir / experiment
    exp_dir.mkdir(parents=True, exist_ok=True)
    json_path = exp_dir / "results.json"
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def _valid_gt_mask(gt: np.ndarray) -> np.ndarray:
    return np.isfinite(gt) & (gt > 0)


def _pack_distance(results: dict) -> dict:
    """Reshape averaged per-region results into the distance-experiment layout.

    Args:
        results: Mapping from region name to averaged metrics; expected to use
            the key "overall" plus the names listed in ``BIN_NAMES``.

    Returns:
        ``{"overall": ..., "bins": {bin_name: ...}}`` where any region missing
        from ``results`` is represented by an empty dict.
    """
    overall = results.get("overall", {})
    per_bin = {}
    for bin_name in BIN_NAMES:
        per_bin[bin_name] = results.get(bin_name, {})
    return {"overall": overall, "bins": per_bin}


def _process_single_pair(task: tuple[str, str, str | None, str, bool, float, int, int]) -> dict:
    """Score one prediction / ground-truth pair.

    Args:
        task: Flat tuple ``(gt_path, pred_path, img_path, mask_dir,
            take_inverse, eps, dark_threshold, kernel_size)``. ``img_path`` may
            be ``None``, in which case the shaded metrics are skipped.

    Returns:
        A dict with three entries:
          * "classic": metrics over all valid pixels,
          * "distance": metrics per region as produced by
            ``compute_metrics_by_region`` for the ``DISTANCE_BINS`` masks,
          * "shaded": metrics restricted to the dark-pixel mask, or ``None``
            when no image path was supplied.
    """
    (
        gt_file,
        pred_file,
        img_file,
        mask_root,
        take_inverse,
        eps,
        dark_threshold,
        kernel_size,
    ) = task

    gt_path = Path(gt_file)
    pred_path = Path(pred_file)
    img_path = None if img_file is None else Path(img_file)
    mask_dir = Path(mask_root)

    gt = load_depth_npy(str(gt_path))
    # Bring the prediction onto the ground-truth grid before any comparison.
    pred = resize_to_match(load_depth_npy(str(pred_path)), gt)

    if not take_inverse:
        # Prediction is already depth: only require finite values.
        pred_depth = pred.astype(np.float64)
        pred_valid = np.isfinite(pred_depth)
    else:
        # Prediction is inverse depth: convert it and take the validity mask
        # reported by the converter.
        pred_depth, pred_valid = inverse_to_depth(pred, eps=eps)

    # Evaluate only where the ground truth is finite and positive AND the
    # prediction is valid, then apply the MAX_DEPTH limit
    # (presumably drops pixels whose GT depth exceeds it -- confirm in
    # core.evaluator.apply_max_depth_mask).
    valid_mask = apply_max_depth_mask(gt, _valid_gt_mask(gt) & pred_valid, MAX_DEPTH)

    # Least-squares alignment of the prediction to the ground truth over the
    # valid pixels; the two extra return values are not needed here.
    aligned_pred, _, _ = least_squares_align(pred_depth, gt, valid_mask)

    classic_metrics = compute_metrics(aligned_pred, gt, valid_mask)
    region_masks = build_distance_masks(gt, valid_mask, DISTANCE_BINS, BIN_NAMES)
    distance_metrics = compute_metrics_by_region(aligned_pred, gt, valid_mask, region_masks)

    scores = {
        "classic": classic_metrics,
        "distance": distance_metrics,
        "shaded": None,
    }

    if img_path is None:
        return scores

    # The dark mask is cached per image under mask_dir, keyed by the image stem.
    dark_mask = load_or_create_dark_mask(
        str(img_path),
        str(mask_dir / f"{img_path.stem}.npy"),
        threshold=dark_threshold,
        kernel_size=kernel_size,
    )
    scores["shaded"] = compute_metrics(aligned_pred, gt, valid_mask & dark_mask)
    return scores


def main() -> None:
    """Command-line entry point: evaluate a prediction directory and write JSON summaries.

    Steps:
      1. Parse the command-line arguments.
      2. List prediction, ground-truth and image files. Files are paired by
         position in their sorted listings (i-th with i-th), not by name, so
         all directories must sort consistently.
      3. Score every pair with ``_process_single_pair`` through
         ``run_process_pool_as_completed``.
      4. Average the per-pair metrics and write
         ``<output-dir>/{classic,distance,shaded}/results.json``.

    Raises:
        SystemExit: If no prediction files or no ground-truth files are found.
    """
    ap = argparse.ArgumentParser(description="Evaluate Etna predictions against NPY GT with least-squares alignment.")
    ap.add_argument("--pred-dir", required=True, help="Directory containing .npy predictions")
    ap.add_argument("--take-inverse", action="store_true", help="Interpret predictions as inverse depth and convert to depth")
    ap.add_argument("--output-dir", required=True, help="Directory to save evaluation results")
    ap.add_argument("--pred-pattern", default="*.npy", help="Glob pattern for prediction files")
    ap.add_argument("--gt-dir", default=GT_DIR_DEFAULT, help="Ground-truth NPY directory")
    ap.add_argument("--img-dir", default=IMG_DIR_DEFAULT, help="RGB image directory")
    ap.add_argument("--mask-dir", default=MASK_DIR_DEFAULT, help="Directory for cached dark masks")
    ap.add_argument("--dark-threshold", type=int, default=50, help="Dark pixel threshold")
    ap.add_argument("--kernel-size", type=int, default=9, help="Morphology kernel size for dark mask")
    ap.add_argument("--eps", type=float, default=1e-6, help="Small epsilon for inverse depth")
    ap.add_argument(
        "--workers",
        type=int,
        default=default_worker_count(),
        help="Number of worker processes (default: os.cpu_count() - 1, minimum 1)",
    )
    args = ap.parse_args()

    pred_dir = Path(args.pred_dir)
    output_dir = Path(args.output_dir)
    gt_dir = Path(args.gt_dir)
    img_dir = Path(args.img_dir)
    mask_dir = Path(args.mask_dir)

    pred_files = sorted(pred_dir.glob(args.pred_pattern))
    gt_files = list_sorted_files(str(gt_dir), (".npy",))
    img_files = list_sorted_files(str(img_dir), (".png",))

    if not pred_files:
        raise SystemExit(f"No prediction files found in {pred_dir} with pattern {args.pred_pattern}")
    if not gt_files:
        raise SystemExit(f"No GT NPY files found in {gt_dir}")
    if not img_files:
        # Images are optional: without them only the shaded experiment is lost.
        print(f"[WARN] No RGB images found in {img_dir}; shaded metrics will be skipped.")

    # Evaluate only as many pairs as every available listing can supply; an
    # empty image listing does not limit the count.
    n = min(len(pred_files), len(gt_files), len(img_files) if img_files else len(gt_files))
    if len(pred_files) != len(gt_files):
        print(f"[WARN] Count mismatch: preds={len(pred_files)} gt={len(gt_files)}; using first {n}.")
    if img_files and len(img_files) != len(gt_files):
        print(f"[WARN] Count mismatch: images={len(img_files)} gt={len(gt_files)}; using first {n}.")

    # Task tuples carry only strings and scalars so they can be handed to
    # worker processes; pairing is by index into the sorted listings.
    tasks = [
        (
            str(gt_files[i]),
            str(pred_files[i]),
            str(img_files[i]) if img_files else None,
            str(mask_dir),
            args.take_inverse,
            args.eps,
            args.dark_threshold,
            args.kernel_size,
        )
        for i in range(n)
    ]

    workers = max(1, int(args.workers))

    def _on_progress(completed: int, total: int) -> None:
        # Report every 50 pairs and once more when everything is done.
        if completed % 50 == 0 or completed == total:
            print(f"Processed {completed}/{total}")

    aligned_classic = []
    aligned_distance = []
    aligned_shaded = []

    # Only averages are computed below, so the order in which results arrive
    # from the pool does not matter.
    for result in run_process_pool_as_completed(
        _process_single_pair,
        tasks,
        workers=workers,
        progress_callback=_on_progress,
    ):
        aligned_classic.append(result["classic"])
        aligned_distance.append(result["distance"])
        if result["shaded"] is not None:
            aligned_shaded.append(result["shaded"])

    common_meta = {
        "dataset": "etna",
        "take_inverse": args.take_inverse,
        "num_pairs": n,
        "pred_dir": str(pred_dir),
        "gt_dir": str(gt_dir),
        "img_dir": str(img_dir),
        "max_depth": MAX_DEPTH,
    }

    classic_payload = {
        **common_meta,
        "experiment": "classic",
        "protocol_a": {"overall": _mean_dict(aligned_classic)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "classic", classic_payload)

    distance_payload = {
        **common_meta,
        "experiment": "distance",
        "distance_bins": {name: list(bin_range) for name, bin_range in zip(BIN_NAMES, DISTANCE_BINS)},
        "protocol_a": _pack_distance(_mean_region_dict(aligned_distance)),
        "protocol_b": None,
    }
    _write_experiment(output_dir, "distance", distance_payload)

    shaded_payload = {
        **common_meta,
        "experiment": "shaded",
        # Overrides the shared count: only pairs that had an image contribute.
        "num_pairs": len(aligned_shaded),
        "protocol_a": {"overall": _mean_dict(aligned_shaded)},
        "protocol_b": None,
    }
    _write_experiment(output_dir, "shaded", shaded_payload)

    # Build the summary from the constants actually used for binning so the
    # message cannot drift from DISTANCE_BINS / BIN_NAMES.
    bins_summary = ", ".join(
        f"{name} {lo:g}-{hi:g}m" for name, (lo, hi) in zip(BIN_NAMES, DISTANCE_BINS)
    )
    print(f"Distance bins: {bins_summary}")
    print(f"Saved results to {output_dir}")


# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
