#!/usr/bin/env python3
"""Evaluate TSP GLS solvers (pool, population, or heuristic baselines)."""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd

# Absolute directory that contains this script, followed by three of its ancestors.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of the script directory.
_PROBLEM_DIR = os.path.dirname(_SCRIPT_DIR)
# Two levels above _PROBLEM_DIR.
_HEUPSRO_DIR = os.path.dirname(os.path.dirname(_PROBLEM_DIR))
# Parent of _HEUPSRO_DIR.
_PROJECT_ROOT = os.path.dirname(_HEUPSRO_DIR)

# Prepend both directories to sys.path (when not already present) before the
# project imports below run.
# NOTE(review): presumably this lets `heupsro.*` resolve when the file is run
# directly as a script -- confirm against the package layout.
for path in (_PROJECT_ROOT, _HEUPSRO_DIR):
    if path not in sys.path:
        sys.path.insert(0, path)

from heupsro.core.config import HeuPSROConfig
from heupsro.problems.tsp_gls.config import TSPGLSConfig
from heupsro.problems.tsp_gls.evaluation import evaluate_solvers_on_instances_with_optcosts
from heupsro.problems.tsp_gls.testing.baselines import config_baselines, registry
from heupsro.problems.tsp_gls.testing.utils import DatasetPack, load_tsplib_datasets, save_table
from heupsro.problems.tsp_gls.initialization import get_eoh_optimal_code


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace. ``--mode`` is the only required option; the
        dataset directory defaults to ``TestingData`` next to this script.
    """
    cli = argparse.ArgumentParser(description="Evaluate TSP GLS solvers.")
    add = cli.add_argument

    # What to evaluate.
    add("--mode", type=str, required=True, choices=["population", "pool", "heuristic", "optimal"], help="Evaluation mode.")

    # Where inputs come from and where outputs go.
    add("--exp_with_pool", type=str, default=None, help="Experiment dir containing psro_results/pools.json (pool mode).")
    add("--exp_without_pool", type=str, default=None, help="Experiment dir containing solver_eoh/results/pops (population mode).")
    add("--pkl_dir", type=str, default=os.path.join(_SCRIPT_DIR, "TestingData"), help="Directory of TSPLIB PKL datasets.")
    add("--max_instances", type=int, default=None, help="Cap number of instances per dataset.")
    add("--max_size", type=int, default=None, help="Skip datasets with problem size (cities) larger than this.")
    add("--out_dir", type=str, default=None, help="Custom output directory.")

    # Selection of methods / solvers.
    add("--methods", type=str, default=None, help="Comma-separated heuristic method names (heuristic mode).")
    add("--categories", type=str, default="heuristic", help="Comma-separated categories (heuristic mode).")
    add("--limit_pool_solvers", type=int, default=None, help="Evaluate first N pool solvers.")
    add("--limit_populations", type=int, default=None, help="Evaluate first N populations.")
    add("--select", type=str, default="min", choices=["min", "max"], help="Population selection objective.")

    # Parallel execution settings.
    add("--n_jobs", type=int, default=-1)
    add("--backend", type=str, default="loky", choices=["loky", "threading"])
    add("--prefer", type=str, default="processes", choices=["processes", "threads"])

    # Solver limits.
    add("--time_limit", type=int, default=60, help="Instance solver time limit.")
    add("--ite_max", type=int, default=1000, help="Max solver iterations.")
    add("--perturb_moves", type=int, default=1, help="Perturbation moves for GLS.")
    add("--oracle_timeout", type=int, default=30, help="Oracle timeout seconds.")

    # Output options.
    add("--save_cost_table", action="store_true", help="Also save mean cost table.")
    add("--methods_basename", type=str, default=None, help="Optional override for output basename.")

    return cli.parse_args()


def natural_key(s: str) -> List[Any]:
    """Return a sort key that orders embedded digit runs numerically.

    The string is split around runs of digits; digit chunks become ``int`` and
    every other chunk is lower-cased, so ``"f2"`` sorts before ``"f10"``.
    """
    key: List[Any] = []
    for chunk in re.split(r"(\d+)", s):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key


def list_population_files(pops_dir: str) -> List[str]:
    """List population snapshot files in ``pops_dir`` in natural order.

    Only entries named ``population_generation_*.json`` are kept. The returned
    paths are joined with ``pops_dir`` and sorted with ``natural_key`` applied
    to the full path.
    """
    matches: List[str] = []
    for entry in os.listdir(pops_dir):
        if not entry.startswith("population_generation_"):
            continue
        if not entry.endswith(".json"):
            continue
        matches.append(os.path.join(pops_dir, entry))
    return sorted(matches, key=natural_key)


def load_population_best_individual(pop_file: str, select: str = "min") -> Optional[Dict[str, Any]]:
    """Load a population JSON file and return its best individual.

    Args:
        pop_file: Path to a JSON file holding a list of individuals (dicts).
        select: ``"min"`` picks the smallest ``"objective"``; any other value
            picks the largest.

    Returns:
        The selected individual, or ``None`` when the file cannot be read or
        parsed, is not a list, or contains no dict entries.

    Individuals whose ``"objective"`` is missing, ``None``, non-numeric or NaN
    are ranked as the worst possible candidates instead of aborting the whole
    selection, so one unevaluated individual no longer discards a generation.
    """
    try:
        with open(pop_file, "r", encoding="utf-8") as f:
            population = json.load(f)
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        print(f" Warning: could not load population file {pop_file}: {exc}", file=sys.stderr)
        return None

    if not isinstance(population, list):
        return None
    candidates = [ind for ind in population if isinstance(ind, dict)]
    if not candidates:
        return None

    minimize = select == "min"
    worst = float("inf") if minimize else float("-inf")

    def objective_of(ind: Dict[str, Any]) -> float:
        value = ind.get("objective", worst)
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return worst
        if value != value:  # NaN never compares equal to itself
            return worst
        return value

    chooser = min if minimize else max
    return chooser(candidates, key=objective_of)


def load_pools(exp_dir: str) -> Dict[str, Any]:
    """Read and return the parsed ``psro_results/pools.json`` under ``exp_dir``.

    Errors from opening or parsing the file propagate to the caller.
    """
    pools_file = os.path.join(exp_dir, "psro_results", "pools.json")
    with open(pools_file, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def resolve_heuristic_specs(args: argparse.Namespace) -> List[registry.MethodSpec]:
    """Resolve the baseline method specs requested on the command line.

    Args:
        args: Parsed CLI namespace; reads ``args.methods`` and
            ``args.categories`` (both comma-separated strings).

    Returns:
        Specs looked up through ``registry.get``. When ``--methods`` is given
        it wins and the names are used in the order written. Otherwise the
        methods of each requested category from the default baseline config
        are returned, in the order the categories were written.

    Categories are de-duplicated with an order-preserving dict rather than a
    set, so the resulting method order (and the row order of the output table)
    is the same on every run. Unknown category names are reported instead of
    being dropped silently.
    """
    if args.methods:
        names = [name.strip() for name in args.methods.split(",") if name.strip()]
        return [registry.get(name) for name in names]

    cfg = config_baselines.default_config()
    family = {"heuristic": cfg.heuristics, "learning": cfg.learning}
    # dict.fromkeys keeps first-seen order while removing duplicates.
    categories = list(
        dict.fromkeys(cat.strip() for cat in args.categories.split(",") if cat.strip())
    )

    chosen = []
    for cat in categories:
        if cat not in family:
            print(f"  Unknown category '{cat}' ignored (known: {', '.join(family)}).")
            continue
        for name in family[cat]:
            chosen.append(registry.get(name))
    return chosen


def ensure_row(table: Dict[str, Dict[str, Optional[float]]], row_key: str, labels: Iterable[str]) -> None:
    """Make sure ``table[row_key]`` exists and has an entry for every label.

    Missing labels are initialised to ``None``; values already present are
    left untouched. The table is modified in place.
    """
    row = table.setdefault(row_key, {})
    for label in labels:
        if label not in row:
            row[label] = None


def load_existing_table(out_dir: str, basename: str) -> Optional[Dict[str, Dict[str, float]]]:
    """Load a previously saved result table from ``<out_dir>/<basename>.csv``.

    Args:
        out_dir: Directory that may contain the CSV.
        basename: File name without the ``.csv`` extension.

    Returns:
        ``{solver: {column: value}}`` holding only the non-missing cells, or
        ``None`` when the file does not exist, has no ``solver`` column, or
        cannot be read or parsed (a warning is printed in the latter cases).

    The ``solver`` column is read as text, so purely numeric solver names keep
    their string form and still match the string row keys used by callers.
    """
    cp = os.path.join(out_dir, f"{basename}.csv")
    if not os.path.exists(cp):
        return None
    try:
        # Without an explicit dtype pandas would infer numbers for an
        # all-numeric solver column and the keys would no longer be strings.
        df = pd.read_csv(cp, dtype={"solver": str})
        if "solver" not in df.columns:
            print(f" Warning: {cp} has no 'solver' column; ignoring it.", file=sys.stderr)
            return None
        value_columns = [col for col in df.columns if col != "solver"]
        table: Dict[str, Dict[str, float]] = {}
        for _, row in df.iterrows():
            solver = row["solver"]
            if pd.isna(solver):
                # A row without a solver name cannot be matched to anything.
                continue
            table[str(solver)] = {
                col: float(row[col])
                for col in value_columns
                if pd.notna(row[col])
            }
        return table
    except (OSError, ValueError, TypeError, KeyError) as exc:
        # pandas ParserError / EmptyDataError are ValueError subclasses.
        print(f" Warning: could not load existing table {cp}: {exc}", file=sys.stderr)
        return None


def evaluate_python_heuristic(
    spec: registry.MethodSpec,
    datasets: Dict[str, DatasetPack],
    labels: List[str],
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Run a pure-Python baseline on every dataset and average its results.

    Args:
        spec: Method spec; its ``handler`` is called once per instance.
        datasets: Dataset packs keyed by label.
        labels: Labels to evaluate, in order.

    Returns:
        ``(gap_row, cost_row)`` mapping each label to the mean percentage gap
        and the mean tour cost (``None`` when no value was collected). Both
        dicts are empty when a probe run on the first instance reports that
        the method fell back to nearest insertion.

    Raises:
        ValueError: If the spec has no handler or is marked unavailable.
    """
    run = spec.handler
    if run is None or not spec.available:
        raise ValueError(f"Handler missing for {spec.name}")

    # Probe the first instance of the first label to detect a fallback result
    # before spending time on the full evaluation.
    probe_pack = datasets.get(labels[0]) if labels and datasets else None
    if probe_pack and probe_pack.distance_matrices:
        probe_dist = probe_pack.distance_matrices[0]
        probe_coords = probe_pack.coordinates[0] if probe_pack.coordinates else None
        probe = _invoke_handler(spec, run, probe_dist, probe_coords)
        if probe.extras and probe.extras.get("note") == "fallback_nearest_insertion":
            print(f"  {spec.fullname} ({spec.name}) fallback to nearest insertion, skip evaluation")
            return {}, {}

    mean_gaps: Dict[str, float] = {}
    mean_costs: Dict[str, float] = {}
    for label in labels:
        pack = datasets[label]
        tour_costs = []
        pct_gaps = []
        for position, matrix in enumerate(pack.distance_matrices):
            if position < len(pack.coordinates):
                coord = pack.coordinates[position]
            else:
                coord = None
            outcome = _invoke_handler(spec, run, matrix, coord)
            tour_costs.append(float(outcome.cost))

            if position < len(pack.optimal_costs):
                best_known = pack.optimal_costs[position]
            else:
                best_known = None
            # A gap is only defined against a known, finite, positive optimum.
            if best_known is not None and np.isfinite(best_known) and best_known > 0:
                pct_gaps.append(100.0 * (outcome.cost - best_known) / best_known)

        if tour_costs:
            mean_costs[label] = float(np.mean(tour_costs))
        else:
            mean_costs[label] = None
        if pct_gaps:
            mean_gaps[label] = float(np.mean(pct_gaps))
        else:
            mean_gaps[label] = None
    return mean_gaps, mean_costs


def evaluate_solver_codes(
    row_keys: List[str],
    solver_codes: List[str],
    datasets: Dict[str, DatasetPack],
    labels: List[str],
    config: HeuPSROConfig,
) -> Tuple[Dict[Tuple[int, int], Tuple[Optional[float], Optional[float]]], List[str]]:
    """Evaluate solver codes on the instances of all requested datasets.

    Instances of every label are flattened into one list (pairing each
    coordinate set with its optimal cost) and passed in a single call to
    ``evaluate_solvers_on_instances_with_optcosts`` with ``return_format="raw"``.

    Args:
        row_keys: Not used by this function.
        solver_codes: Source code of the solvers to evaluate.
        datasets: Dataset packs keyed by label.
        labels: Labels whose instances are evaluated, in order.
        config: Configuration forwarded to the evaluator.

    Returns:
        ``(results, instance_to_label)``: the evaluator's raw result and, for
        every flattened instance index, the label it came from. ``({}, [])``
        when there are no solver codes or no instances.
    """
    if not solver_codes:
        return {}, []

    flat_instances = []
    flat_optima = []
    owners: List[str] = []
    for label in labels:
        pack = datasets[label]
        pairs = list(zip(pack.coordinates, pack.optimal_costs))
        flat_instances.extend(coords for coords, _ in pairs)
        flat_optima.extend(opt for _, opt in pairs)
        owners.extend([label] * len(pairs))

    if not flat_instances:
        return {}, []

    raw = evaluate_solvers_on_instances_with_optcosts(
        solver_codes=solver_codes,
        instances=flat_instances,
        optimal_costs=flat_optima,
        config=config,
        return_format="raw",
    )
    return raw, owners


def aggregate_solver_results(
    row_keys: List[str],
    labels: List[str],
    raw_results: Dict[Tuple[int, int], Tuple[Optional[float], Optional[float]]],
    instance_to_label: List[str],
) -> Tuple[Dict[str, Dict[str, Optional[float]]], Dict[str, Dict[str, Optional[float]]]]:
    """Average raw per-instance results into per-solver, per-label means.

    Args:
        row_keys: Row key of each solver; position ``i`` matches solver index
            ``i`` in ``raw_results``.
        labels: Labels to report for every solver.
        raw_results: ``(solver_idx, instance_idx) -> (cost, gap)``.
        instance_to_label: Label of each instance index.

    Returns:
        ``(gap_updates, cost_updates)``, each ``{row_key: {label: mean}}``.
        A mean is ``None`` when no finite value was available; ``None`` and
        non-finite costs or gaps are ignored independently of each other.
    """
    # Group instance indices by the label they belong to.
    positions_by_label: Dict[str, List[int]] = {}
    for position, owner in enumerate(instance_to_label):
        positions_by_label.setdefault(owner, []).append(position)

    def mean_or_none(values: List[float]) -> Optional[float]:
        return float(np.mean(values)) if values else None

    gap_updates: Dict[str, Dict[str, Optional[float]]] = {}
    cost_updates: Dict[str, Dict[str, Optional[float]]] = {}
    for solver_idx, row_key in enumerate(row_keys):
        gap_row: Dict[str, Optional[float]] = {}
        cost_row: Dict[str, Optional[float]] = {}
        for label in labels:
            fetched = [
                raw_results.get((solver_idx, inst_idx))
                for inst_idx in positions_by_label.get(label, [])
            ]
            present = [entry for entry in fetched if entry]
            finite_costs = [
                float(cost) for cost, _ in present
                if cost is not None and np.isfinite(cost)
            ]
            finite_gaps = [
                float(gap) for _, gap in present
                if gap is not None and np.isfinite(gap)
            ]
            cost_row[label] = mean_or_none(finite_costs)
            gap_row[label] = mean_or_none(finite_gaps)
        gap_updates[row_key] = gap_row
        cost_updates[row_key] = cost_row
    return gap_updates, cost_updates


def _invoke_handler(
    spec: registry.MethodSpec,
    handler,
    dist: np.ndarray,
    coords: np.ndarray | None,
):
    """Call handler with optional coordinates support."""
    if spec.requires_coords:
        if coords is None:
            raise ValueError(f"{spec.name} requires coordinates but got None")
        return handler(distance_matrix=dist, coordinates=coords)
    try:
        return handler(dist)
    except TypeError:
        return handler(distance_matrix=dist, coordinates=coords)


def build_config(args: argparse.Namespace) -> TSPGLSConfig:
    """Create a default ``TSPGLSConfig`` and apply the CLI overrides to it.

    Solver limits, oracle timeout and parallel-execution settings are copied
    from ``args`` onto the corresponding config attributes.
    """
    overrides = (
        ("instance_solver_time_limit", args.time_limit),
        ("tsp_solver_max_iterations", args.ite_max),
        ("tsp_solver_perturbation_moves", args.perturb_moves),
        ("oracle_timeout", args.oracle_timeout),
        ("parallel_backend", args.backend),
        ("parallel_prefer", args.prefer),
        ("eval_n_jobs", args.n_jobs),
    )
    cfg = TSPGLSConfig()
    for attribute, value in overrides:
        setattr(cfg, attribute, value)
    return cfg


def main() -> None:
    """Run the evaluation selected by ``--mode`` and save the result tables.

    Steps:
      1. Resolve the output directory (explicit ``--out_dir``, otherwise a
         mode-dependent default).
      2. Load the datasets and, if present, a previously saved gap table from
         the output directory so already-evaluated cells can be reused.
      3. Decide which dataset labels to consider in this run.
      4. Evaluate pool solvers, population-best solvers, the EoH optimal
         heuristic, or baseline heuristics, depending on the mode.
      5. Save the gap table, and the cost table when ``--save_cost_table`` is
         set, through ``save_table``.

    Raises:
        RuntimeError: If no datasets are loaded, or no heuristic is selected.
        ValueError: If the experiment directory required by the mode is missing.
        FileNotFoundError: If the population directory does not exist.
    """
    args = parse_args()

    # Normalize the experiment paths and fix the output directory up front so
    # that a previous result table can be found for resuming.
    exp_with_pool = os.path.abspath(args.exp_with_pool) if args.exp_with_pool else None
    exp_without_pool = os.path.abspath(args.exp_without_pool) if args.exp_without_pool else None

    if args.out_dir:
        out_dir = os.path.abspath(args.out_dir)
    else:
        if args.mode == "pool" and exp_with_pool:
            out_dir = os.path.join(exp_with_pool, "test")
        elif args.mode == "population" and exp_without_pool:
            out_dir = os.path.join(exp_without_pool, "test")
        elif args.mode == "optimal":
            out_dir = os.path.join(_SCRIPT_DIR, "baseline_results")
        else:
            out_dir = _SCRIPT_DIR

    datasets = load_tsplib_datasets(args.pkl_dir, args.max_instances, args.max_size)
    if not datasets:
        raise RuntimeError(f"No datasets loaded from {args.pkl_dir}")

    # Incremental evaluation: try to load an existing result CSV so that
    # - (solver, dataset) cells already present are not re-evaluated, and
    # - datasets that are new since the last run show up as additional columns.
    # NOTE(review): the resume file name is fixed per mode and does not honour
    # --methods_basename, while the output file name below does -- confirm
    # that this asymmetry is intended.
    if args.mode in ("pool", "population"):
        existing = load_existing_table(out_dir, "trend_gap_table")
        gap_table: Dict[str, Dict[str, Optional[float]]] = existing or {}
        if existing is not None:
            print(f" Loaded existing results from {out_dir}/trend_gap_table.csv (auto-resume)")
    elif args.mode == "optimal":
        existing = load_existing_table(out_dir, "eoh_optimal_gap_table")
        gap_table = existing or {}
        if existing is not None:
            print(f" Loaded existing results from {out_dir}/eoh_optimal_gap_table.csv (auto-resume)")
    elif args.mode == "heuristic":
        existing = load_existing_table(out_dir, "heuristic_gap_table")
        gap_table = existing or {}
        if existing is not None:
            print(f" Loaded existing results from {out_dir}/heuristic_gap_table.csv (auto-resume)")
    else:
        # Not reachable with the current --mode choices; kept as a safe default.
        gap_table = {}

    all_labels = list(datasets.keys())

    # Column labels that appear in any row of the loaded table.
    existing_labels: List[str] = []
    if gap_table:
        existing_labels = sorted({c for row in gap_table.values() for c in row.keys()})

    # Labels for which a file <pkl_dir>/tspuniform/<label>.pkl exists.
    uniform_labels: List[str] = []
    tspuniform_dir = os.path.join(args.pkl_dir, "tspuniform")
    for lbl in all_labels:
        p_candidate = os.path.join(tspuniform_dir, f"{lbl}.pkl")
        if os.path.exists(p_candidate):
            uniform_labels.append(lbl)

    if existing_labels:
        # When resuming, consider every uniform label plus any label that is
        # not yet a column of the loaded table.
        labels_to_eval = []

        labels_to_eval.extend([lbl for lbl in all_labels if lbl in set(uniform_labels)])

        for lbl in all_labels:
            if lbl not in existing_labels and lbl not in uniform_labels:
                labels_to_eval.append(lbl)

        # Fall back to all labels when the selection above is empty.
        labels = labels_to_eval if labels_to_eval else all_labels
    else:
        labels = all_labels

    # Keep the full set around: newly added solvers are evaluated on all of it.
    all_datasets = datasets.copy()
    datasets = {k: v for k, v in datasets.items() if k in labels}

    config = build_config(args)
    # NOTE(review): unlike gap_table, cost_table is not restored from disk, so
    # it only ever holds costs computed in this run -- confirm how save_table
    # treats an existing cost file.
    cost_table: Dict[str, Dict[str, Optional[float]]] = {}

    if args.mode == "pool":
        if not exp_with_pool:
            raise ValueError("--exp_with_pool is required for pool mode")
        pools = load_pools(exp_with_pool)
        solvers = pools.get("solvers", [])
        if args.limit_pool_solvers is not None:
            solvers = solvers[: args.limit_pool_solvers]
        # Row keys are "h<index in the pool list>"; entries without code are
        # skipped, which leaves gaps in the numbering.
        row_keys, solver_codes = [], []
        for idx, solver in enumerate(solvers):
            code = solver.get("code")
            if not code:
                continue
            row_key = f"h{idx}"
            row_keys.append(row_key)
            solver_codes.append(code)
        # Select solvers that are new, or that lack a value for any label.
        eval_row_keys: List[str] = []
        eval_solver_codes: List[str] = []
        has_new_solver = False
        for rk, code in zip(row_keys, solver_codes):
            if rk not in gap_table:
                eval_row_keys.append(rk)
                eval_solver_codes.append(code)
                has_new_solver = True
            else:
                row_vals = gap_table.get(rk, {})
                need_eval = any(
                    (lbl not in row_vals) or (row_vals.get(lbl) is None)
                    for lbl in labels
                )
                if need_eval:
                    eval_row_keys.append(rk)
                    eval_solver_codes.append(code)

        if eval_row_keys:
            if has_new_solver:
                # As soon as one solver is new, every selected solver is
                # evaluated on all datasets in a single batch.
                eval_labels = all_labels
                eval_datasets = all_datasets
                for rk in eval_row_keys:
                    if rk not in gap_table:
                        ensure_row(gap_table, rk, eval_labels)
                        ensure_row(cost_table, rk, eval_labels)
            else:
                # Only existing rows need filling: restrict to the chosen labels.
                eval_labels = labels
                eval_datasets = datasets
                for rk in eval_row_keys:
                    ensure_row(gap_table, rk, eval_labels)
                    ensure_row(cost_table, rk, eval_labels)

            total_instances = sum(len(pack.coordinates) for pack in eval_datasets.values())
            print(
                f" Evaluating {len(eval_row_keys)} solvers "
                f"on {len(eval_labels)} datasets ({total_instances} total instances)..."
            )
            raw_results, instance_to_label = evaluate_solver_codes(
                eval_row_keys, eval_solver_codes, eval_datasets, eval_labels, config
            )
            print(" Evaluation complete, processing results...")
            gap_updates, cost_updates = aggregate_solver_results(
                eval_row_keys, eval_labels, raw_results, instance_to_label
            )
            # Merge the new means over whatever the rows already contained.
            for rk in eval_row_keys:
                gap_table.setdefault(rk, {}).update(gap_updates.get(rk, {}))
                cost_table.setdefault(rk, {}).update(cost_updates.get(rk, {}))
        else:
            print(" All pool solvers already evaluated for all datasets, skipping evaluation.")
        basename = args.methods_basename or "trend_gap_table"

    elif args.mode == "population":
        if not exp_without_pool:
            raise ValueError("--exp_without_pool is required for population mode")
        pops_dir = os.path.join(exp_without_pool, "solver_eoh", "results", "pops")
        if not os.path.isdir(pops_dir):
            raise FileNotFoundError(f"{pops_dir} not found")
        pop_files = list_population_files(pops_dir)
        if args.limit_populations is not None:
            pop_files = pop_files[: args.limit_populations]
        # One solver per population file: its best individual. Files that
        # cannot be loaded or whose best individual has no code are skipped.
        row_keys, solver_codes = [], []
        for pf in pop_files:
            ind = load_population_best_individual(pf, select=args.select)
            if ind is None:
                continue
            code = ind.get("code")
            if not code:
                continue
            # Row key is "h<generation number>" taken from the file name.
            filename = os.path.splitext(os.path.basename(pf))[0]
            match = re.search(r'population_generation_(\d+)', filename)
            if match:
                gen_num = int(match.group(1))
                row_key = f"h{gen_num}"
            else:
                row_key = f"pop_{filename}"
            row_keys.append(row_key)
            solver_codes.append(code)
        # Same resume filter as in pool mode.
        eval_row_keys = []
        eval_solver_codes = []
        has_new_solver = False
        for rk, code in zip(row_keys, solver_codes):
            if rk not in gap_table:
                eval_row_keys.append(rk)
                eval_solver_codes.append(code)
                has_new_solver = True
            else:
                row_vals = gap_table.get(rk, {})
                need_eval = any(
                    (lbl not in row_vals) or (row_vals.get(lbl) is None)
                    for lbl in labels
                )
                if need_eval:
                    eval_row_keys.append(rk)
                    eval_solver_codes.append(code)

        if eval_row_keys:
            if has_new_solver:
                # A new solver forces evaluation on all datasets for the batch.
                eval_labels = all_labels
                eval_datasets = all_datasets
                for rk in eval_row_keys:
                    if rk not in gap_table:
                        ensure_row(gap_table, rk, eval_labels)
                        ensure_row(cost_table, rk, eval_labels)
            else:
                eval_labels = labels
                eval_datasets = datasets
                for rk in eval_row_keys:
                    ensure_row(gap_table, rk, eval_labels)
                    ensure_row(cost_table, rk, eval_labels)

            total_instances = sum(len(pack.coordinates) for pack in eval_datasets.values())
            print(
                f" Evaluating {len(eval_row_keys)} population solvers "
                f"on {len(eval_labels)} datasets ({total_instances} total instances)..."
            )
            raw_results, instance_to_label = evaluate_solver_codes(
                eval_row_keys, eval_solver_codes, eval_datasets, eval_labels, config
            )
            print(" Evaluation complete, processing results...")
            gap_updates, cost_updates = aggregate_solver_results(
                eval_row_keys, eval_labels, raw_results, instance_to_label
            )
            for rk in eval_row_keys:
                gap_table.setdefault(rk, {}).update(gap_updates.get(rk, {}))
                cost_table.setdefault(rk, {}).update(cost_updates.get(rk, {}))
        else:
            print(" All population solvers already evaluated for all datasets, skipping evaluation.")
        basename = args.methods_basename or "trend_gap_table"

    elif args.mode == "optimal":
        row_key = "eoh_optimal"
        if row_key in gap_table:
            row_vals = gap_table.get(row_key, {})
            need_eval = any(
                (lbl not in row_vals) or (row_vals.get(lbl) is None)
                for lbl in labels
            )
            if not need_eval:
                # Nothing to compute: re-save the loaded table and stop early.
                print(" EoH optimal already evaluated for all datasets, skipping evaluation.")
                basename = args.methods_basename or "eoh_optimal_gap_table"
                if not os.path.isdir(out_dir):
                    os.makedirs(out_dir, exist_ok=True)
                save_table(out_dir, gap_table, basename)
                if args.save_cost_table:
                    save_table(out_dir, cost_table, basename.replace("gap", "cost"))
                print(f"\n Saved results to {out_dir} ({basename})")
                return

        eoh_code = get_eoh_optimal_code()
        print(f"Using EoH optimal heuristic code (length: {len(eoh_code)} chars)")
        row_keys = [row_key]
        solver_codes = [eoh_code]
        ensure_row(gap_table, row_key, labels)
        ensure_row(cost_table, row_key, labels)
        total_instances = sum(len(pack.coordinates) for pack in datasets.values())
        print(f" Evaluating EoH optimal on {len(labels)} datasets ({total_instances} total instances)...")
        raw_results, instance_to_label = evaluate_solver_codes(row_keys, solver_codes, datasets, labels, config)
        print(" Evaluation complete, processing results...")
        gap_updates, cost_updates = aggregate_solver_results(row_keys, labels, raw_results, instance_to_label)
        gap_table[row_key].update(gap_updates.get(row_key, {}))
        cost_table[row_key].update(cost_updates.get(row_key, {}))
        basename = args.methods_basename or "eoh_optimal_gap_table"

    else:  # heuristic mode
        specs = resolve_heuristic_specs(args)
        if not specs:
            raise RuntimeError("No heuristic methods selected")
        # Split into methods evaluated through solver code and methods that
        # are called directly through their Python handler.
        solver_specs = [spec for spec in specs if spec.requires_solver and spec.code_provider]
        python_specs = [spec for spec in specs if not spec.requires_solver and spec.handler]

        # NOTE(review): Python handlers are run on every invocation; the
        # resume filter below is only applied to solver-based methods.
        for spec in python_specs:
            if not spec.available:
                print(f"  Skip {spec.fullname} ({spec.name}): marked unavailable.")
                continue
            ensure_row(gap_table, spec.name, labels)
            ensure_row(cost_table, spec.name, labels)
            gap_row, cost_row = evaluate_python_heuristic(spec, datasets, labels)
            gap_table[spec.name].update(gap_row)
            cost_table[spec.name].update(cost_row)

        if solver_specs:
            solver_specs = [spec for spec in solver_specs if spec.available]
            row_keys = [spec.name for spec in solver_specs]
            solver_codes = [spec.code_provider() for spec in solver_specs]  # type: ignore[call-arg]
            total_instances = sum(len(pack.coordinates) for pack in datasets.values())
            # NOTE(review): this message is printed before the resume filter,
            # so it counts solvers that may be skipped; a second message with
            # the filtered count follows below.
            print(f" Evaluating {len(row_keys)} heuristic solvers on {len(labels)} datasets ({total_instances} total instances)...")
            # Keep only solvers that are new or lack a value for any label.
            eval_row_keys = []
            eval_solver_codes = []
            for rk, code in zip(row_keys, solver_codes):
                if rk not in gap_table:
                    eval_row_keys.append(rk)
                    eval_solver_codes.append(code)
                else:
                    row_vals = gap_table.get(rk, {})
                    need_eval = any(
                        (lbl not in row_vals) or (row_vals.get(lbl) is None)
                        for lbl in labels
                    )
                    if need_eval:
                        eval_row_keys.append(rk)
                        eval_solver_codes.append(code)

            if eval_row_keys:
                total_instances = sum(len(pack.coordinates) for pack in datasets.values())
                print(
                    f" Evaluating {len(eval_row_keys)} heuristic solvers "
                    f"on {len(labels)} datasets ({total_instances} total instances)..."
                )
                raw_results, instance_to_label = evaluate_solver_codes(
                    eval_row_keys, eval_solver_codes, datasets, labels, config
                )
                print(" Evaluation complete, processing results...")
                gap_updates, cost_updates = aggregate_solver_results(
                    eval_row_keys, labels, raw_results, instance_to_label
                )
                for rk in eval_row_keys:
                    ensure_row(gap_table, rk, labels)
                    ensure_row(cost_table, rk, labels)
                    gap_table[rk].update(gap_updates.get(rk, {}))
                    cost_table[rk].update(cost_updates.get(rk, {}))
            else:
                print(" All heuristic solvers already evaluated for all datasets, skipping evaluation.")
        basename = args.methods_basename or "heuristic_gap_table"

    os.makedirs(out_dir, exist_ok=True)

    save_table(out_dir, gap_table, basename)
    if args.save_cost_table:
        # The cost file name is derived by replacing "gap" in the basename.
        save_table(out_dir, cost_table, basename.replace("gap", "cost"))

    print(f"\n Saved results to {out_dir} ({basename})")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


