#!/usr/bin/env python3
"""Command-line runner for TSP baseline methods."""

from __future__ import annotations

import argparse
import inspect
import json
import os
from typing import Dict, List

import numpy as np

from . import config_baselines, registry
from ..utils import DatasetPack, load_tsplib_datasets


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the normal CLI behaviour; passing
            an explicit list is useful for tests and programmatic use.

    Returns:
        The parsed namespace with ``pkl_dir``, ``categories``, ``methods``,
        ``max_instances`` and ``output`` attributes.
    """
    # Default data directory is resolved relative to this file: two levels up
    # from its directory, then ``testing/TestingData``.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    tsp_root = os.path.dirname(os.path.dirname(base_dir))
    default_pkl_dir = os.path.join(tsp_root, "testing", "TestingData")

    parser = argparse.ArgumentParser(description="Evaluate TSP baseline methods.")
    parser.add_argument(
        "--pkl_dir",
        type=str,
        default=default_pkl_dir,
        help="Directory containing TSPLIB-style PKL files.",
    )
    parser.add_argument(
        "--categories",
        type=str,
        default="heuristic",
        help="Comma-separated categories to include (heuristic, learning).",
    )
    parser.add_argument(
        "--methods",
        type=str,
        default=None,
        help="Explicit comma-separated method names (overrides categories).",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional limit on instances per dataset.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional path to store JSON results.",
    )
    return parser.parse_args(argv)


def resolve_methods(args: argparse.Namespace) -> List[registry.MethodSpec]:
    """Turn the CLI selection into an ordered list of method specs.

    If ``--methods`` is given, those names are looked up in the registry and
    ``--categories`` is ignored. Otherwise every method listed under each
    requested category of the default baseline config is selected.

    The result follows the order in which categories / names were given on the
    command line, with duplicates removed, so repeated runs evaluate and print
    methods in the same order.

    Args:
        args: Parsed CLI namespace (reads ``methods`` and ``categories``).

    Returns:
        The selected specs; may be empty if no known category was requested.
    """
    if args.methods:
        names = [name.strip() for name in args.methods.split(",") if name.strip()]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return [registry.get(name) for name in dict.fromkeys(names)]

    cfg = config_baselines.default_config()
    family_map = {
        "heuristic": cfg.heuristics,
        "learning": cfg.learning,
    }
    # Use an insertion-ordered dict rather than a set: set iteration order of
    # strings depends on per-process hash randomization, which made the method
    # order (and thus the printed report) differ between runs.
    categories = dict.fromkeys(
        cat.strip() for cat in args.categories.split(",") if cat.strip()
    )
    names: List[str] = []
    for cat in categories:
        if cat not in family_map:
            known = ", ".join(family_map)
            print(f"  Warning: unknown category '{cat}' ignored (known: {known}).")
            continue
        names.extend(family_map[cat])
    return [registry.get(name) for name in dict.fromkeys(names)]


def evaluate_method_on_datasets(
    spec: registry.MethodSpec,
    datasets: Dict[str, DatasetPack],
) -> Dict[str, Dict[str, float]]:
    """Run one method on every instance of every dataset and summarise it.

    Args:
        spec: The method to evaluate; its handler is invoked per instance via
            ``_run_handler``.
        datasets: Mapping from dataset label to its pack of instances.

    Returns:
        ``{label: {"mean_cost": ..., "mean_gap": ...}}``. ``mean_gap`` is the
        mean percentage gap ``100 * (cost - opt) / opt`` over instances whose
        known optimum is finite and positive. Either value is NaN when nothing
        was collected for that dataset. Returns ``{}`` if the method is marked
        unavailable.

    Raises:
        RuntimeError: if the method is available but has no handler.
    """
    if not spec.available:
        print(f"  Skipping {spec.fullname} ({spec.name}): marked unavailable.")
        return {}
    if spec.handler is None:
        raise RuntimeError(f"Method {spec.name} has no handler.")

    summary: Dict[str, Dict[str, float]] = {}
    for label, pack in datasets.items():
        # Coordinates / optima may be shorter than the matrix list (handled by
        # the index checks below). Also tolerate a pack that carries None for
        # either field instead of crashing on len(None).
        coords = pack.coordinates
        if coords is None:
            coords = []
        optimal_costs = pack.optimal_costs
        if optimal_costs is None:
            optimal_costs = []

        costs: List[float] = []
        gaps: List[float] = []
        for idx, dist in enumerate(pack.distance_matrices):
            coord = coords[idx] if idx < len(coords) else None
            cost = float(_run_handler(spec, dist, coord).cost)
            costs.append(cost)
            opt = optimal_costs[idx] if idx < len(optimal_costs) else None
            if opt is not None and np.isfinite(opt) and opt > 0:
                gaps.append(100.0 * (cost - opt) / opt)
        summary[label] = {
            "mean_cost": float(np.mean(costs)) if costs else float("nan"),
            "mean_gap": float(np.mean(gaps)) if gaps else float("nan"),
        }
    return summary


def _run_handler(spec: registry.MethodSpec, dist: np.ndarray, coords: np.ndarray | None):
    """Safely invoke handler with/without coordinates."""
    if spec.requires_coords:
        if coords is None:
            raise ValueError(f"{spec.name} requires coordinates but got None.")
        return spec.handler(distance_matrix=dist, coordinates=coords)  # type: ignore[arg-type]
    try:
        return spec.handler(dist)  # type: ignore[arg-type]
    except TypeError:
        # Allow handlers that expect keyword args
        return spec.handler(distance_matrix=dist, coordinates=coords)  # type: ignore[arg-type]


def _finite_or_none(obj):
    """Recursively replace non-finite floats (NaN/inf) in dicts with None."""
    if isinstance(obj, dict):
        return {key: _finite_or_none(value) for key, value in obj.items()}
    if isinstance(obj, float) and not np.isfinite(obj):
        return None
    return obj


def main() -> None:
    """Load datasets, evaluate the selected baselines, print and optionally save results.

    Raises:
        RuntimeError: if no datasets could be loaded or no methods were selected.
    """
    args = parse_args()
    datasets = load_tsplib_datasets(args.pkl_dir, args.max_instances)
    if not datasets:
        raise RuntimeError(f"No datasets loaded from {args.pkl_dir}")

    specs = resolve_methods(args)
    if not specs:
        raise RuntimeError("No baseline methods selected")

    results: Dict[str, Dict[str, Dict[str, float]]] = {}
    for spec in specs:
        if spec.handler is None:
            print(f"  Skipping {spec.fullname} ({spec.name}): handler not implemented.")
            continue
        stats = evaluate_method_on_datasets(spec, datasets)
        results[spec.name] = stats
        if not stats:
            # Nothing was evaluated (e.g. method marked unavailable, which is
            # already reported); avoid printing an empty header.
            continue

        print(f"\n {spec.fullname} ({spec.name})")
        for label, metrics in stats.items():
            cost = metrics["mean_cost"]
            gap = metrics["mean_gap"]
            cost_str = "nan" if np.isnan(cost) else f"{cost:.2f}"
            gap_str = "nan" if np.isnan(gap) else f"{gap:.2f}%"
            print(f"  {label:<12} cost={cost_str:>8}  gap={gap_str:>8}")

    if args.output:
        out_dir = os.path.dirname(args.output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            # Means default to NaN when nothing was collected; json.dump would
            # emit a bare NaN token, which strict JSON parsers reject. Store
            # null instead and let allow_nan=False guard against regressions.
            json.dump(_finite_or_none(results), f, indent=2, allow_nan=False)
        print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    # NOTE: this module uses package-relative imports (``from . import ...``),
    # so it must be run as a module (``python -m <package>.<module>``); running
    # the file directly as a script will fail at import time.
    main()


