#!/usr/bin/env python3
"""Unified entry-point to execute CVRP baseline suites.

Output format matches evaluate.py: per-instance gaps saved to CSV.
"""

from __future__ import annotations

import argparse
import csv
import os
import re
import sys
from typing import Dict, List, Tuple

# Add the project root (and the heupsro directory) to sys.path so that the
# absolute ``heupsro.problems...`` imports below resolve when this file is
# executed directly as a script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
# Go up: baselines -> testing -> cvrp -> problems -> heupsro -> project_root
_testing_dir = os.path.dirname(_script_dir)  # testing
_problem_dir = os.path.dirname(_testing_dir)  # cvrp
_problems_dir = os.path.dirname(_problem_dir)  # problems
_heupsro_dir = os.path.dirname(_problems_dir)  # heupsro
_project_root = os.path.dirname(_heupsro_dir)  # project root

# Each insert targets index 0, so when neither entry is present yet the project
# root ends up AHEAD of the heupsro directory (it is searched first). Entries
# already on sys.path are not re-added and keep their existing position.
if _heupsro_dir not in sys.path:
    sys.path.insert(0, _heupsro_dir)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

import numpy as np

from heupsro.problems.cvrp.testing.load_test_data import load_all_test_data
from heupsro.problems.cvrp.testing.baselines import config_baselines, registry


# Any run of characters outside this safe set is collapsed to one underscore.
_UNSAFE_COLUMN_CHARS = re.compile(r'[^A-Za-z0-9_.-]+')


def sanitize_column_name(dataset_label: str, instance_name: str) -> str:
    """Generate a safe column name ``<dataset>__<instance>`` for per-instance gaps.

    Every run of characters outside ``[A-Za-z0-9_.-]`` in either part is
    replaced by a single underscore, so the result contains no commas, quotes
    or whitespace.
    """
    safe_ds = _UNSAFE_COLUMN_CHARS.sub('_', dataset_label)
    safe_inst = _UNSAFE_COLUMN_CHARS.sub('_', instance_name)
    return f"{safe_ds}__{safe_inst}"


def build_instance_columns(
    datasets_dict: Dict[str, Tuple[List[Dict], List[float]]]
) -> Tuple[Dict[str, List[str]], List[str]]:
    """Assign a unique column name to every instance of every dataset.

    Names come from :func:`sanitize_column_name`. Instances without a truthy
    ``instance_name`` fall back to ``inst<idx>``. When a name is already taken
    (duplicate instance names, or labels that sanitize to the same string),
    ``_<idx>`` is appended. If that is still taken, a further ``_<n>`` counter
    is added until the name is unique.

    Args:
        datasets_dict: Mapping of dataset label -> (instances, optimal values).
            Only the instance dicts are read.

    Returns:
        ``(cols_by_dataset, all_cols)``: per-dataset column lists aligned with
        each dataset's instance order, and every column in iteration order.
    """
    inst_cols_by_dataset: Dict[str, List[str]] = {}
    all_cols: List[str] = []
    seen: set = set()
    for label, (instances, _) in datasets_dict.items():
        cols: List[str] = []
        for idx, inst in enumerate(instances):
            inst_name = inst.get('instance_name') or f"inst{idx}"
            col = sanitize_column_name(label, inst_name)
            if col in seen:
                base = f"{col}_{idx}"
                col, n = base, 1
                # The indexed name may itself already exist (e.g. an instance
                # literally named "<name>_<idx>"). A duplicate column would make
                # two instances share, and overwrite, one gap cell.
                while col in seen:
                    col = f"{base}_{n}"
                    n += 1
            seen.add(col)
            cols.append(col)
            all_cols.append(col)
        inst_cols_by_dataset[label] = cols
    return inst_cols_by_dataset, all_cols


def save_table(out_dir: str, table: Dict[str, Dict[str, float]], basename: str) -> None:
    """Save ``table`` as ``<out_dir>/<basename>.csv`` (same format as evaluate.py).

    The header is ``solver`` followed by the union of all row keys, sorted
    alphabetically. Each data row starts with its table key. Values are written
    with six decimals, and missing or ``None`` cells are left empty.

    Args:
        out_dir: Output directory; created if it does not exist.
        table: Mapping of row label (solver name) -> {column: value or None}.
        basename: File name without the ``.csv`` extension.
    """
    os.makedirs(out_dir, exist_ok=True)
    cp = os.path.join(out_dir, f"{basename}.csv")
    cols = sorted({k for row in table.values() for k in row})
    # csv.writer quotes any field containing a delimiter or quote (e.g. a
    # solver name with a comma), so the file stays parseable. lineterminator
    # keeps plain "\n" endings; csv's default would be "\r\n".
    with open(cp, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["solver", *cols])
        for rk, row in table.items():
            cells = []
            for c in cols:
                value = row.get(c)
                cells.append("" if value is None else f"{value:.6f}")
            writer.writerow([rk, *cells])


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the baseline runner.

    The defaults for ``--test_data_dir`` and ``--output_dir`` are resolved
    relative to the parent of this file's directory. Both are normalized to
    absolute paths, so printed locations contain no ``..`` components.

    Returns:
        Namespace with ``test_data_dir``, ``categories``, ``methods``,
        ``max_instances`` and ``output_dir``.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.join(base_dir, "..")
    default_data = os.path.abspath(os.path.join(parent_dir, "TestingData"))
    default_output = os.path.abspath(os.path.join(parent_dir, "baseline_results"))
    parser = argparse.ArgumentParser(description="Run CVRP baseline methods.")
    parser.add_argument(
        "--test_data_dir",
        type=str,
        default=default_data,
        help="Directory containing CVRP benchmark files (.vrp and .sol).",
    )
    parser.add_argument(
        "--categories",
        type=str,
        default="heuristic",
        help="Comma-separated categories to include (heuristic,learning).",
    )
    parser.add_argument(
        "--methods",
        type=str,
        default=None,
        help="Comma-separated explicit method names (overrides categories).",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional cap on instances per dataset.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=default_output,
        help="Output directory for CSV results (default: baseline_results/).",
    )
    return parser.parse_args()


def resolve_methods(args: argparse.Namespace) -> List[registry.MethodSpec]:
    """Resolve the baseline method specs to run from the CLI arguments.

    If ``--methods`` is given, those names are looked up in the registry and
    ``--categories`` is ignored. Otherwise every method listed for each
    requested category in ``config_baselines.default_config()`` is used.
    The result keeps the order in which methods were requested and drops
    repeats, so each method runs once and the run order is reproducible.

    Args:
        args: Parsed CLI namespace (reads ``methods`` and ``categories``).

    Returns:
        The selected method specs; may be empty.
    """
    if args.methods:
        names = [m.strip() for m in args.methods.split(",") if m.strip()]
        return [registry.get(name) for name in dict.fromkeys(names)]

    # dict.fromkeys keeps the user's order. Iterating a set of strings would
    # tie the run (and CSV row) order to per-process hash randomization.
    categories = list(
        dict.fromkeys(c.strip() for c in args.categories.split(",") if c.strip())
    )
    cfg = config_baselines.default_config()
    family_map = {"heuristic": cfg.heuristics, "learning": cfg.learning}

    method_names: List[str] = []
    for cat in categories:
        if cat not in family_map:
            print(f"⚠️  Unknown category '{cat}' ignored (expected one of: {', '.join(family_map)})")
            continue
        method_names.extend(family_map[cat])
    # A method listed under several categories is run only once.
    return [registry.get(name) for name in dict.fromkeys(method_names)]


def compute_gap(solution_cost: float, optimal_cost: float | None) -> float | None:
    """Return the percentage gap of ``solution_cost`` above ``optimal_cost``.

    gap = (solution_cost - optimal_cost) / optimal_cost * 100

    Args:
        solution_cost: Cost reported by the baseline.
        optimal_cost: Reference cost (from a .sol file or an oracle).

    Returns:
        The gap in percent. Returns ``0.0`` when the solution is cheaper than
        the reference. Returns ``None`` when the reference is missing,
        non-finite or non-positive, or when the computed gap is not finite.
    """
    has_valid_reference = (
        optimal_cost is not None
        and np.isfinite(optimal_cost)
        and optimal_cost > 0
    )
    if not has_valid_reference:
        return None

    # Undercutting the reference is not expected; clamp it to a zero gap
    # instead of reporting a negative percentage.
    if solution_cost < optimal_cost:
        return 0.0

    relative_excess = (solution_cost - optimal_cost) / optimal_cost
    gap = relative_excess * 100.0
    if not np.isfinite(gap):
        return None
    return float(gap)


def _limit_instances(
    dataset: Dict[str, Tuple[List[Dict], List[float]]],
    max_instances: int | None,
) -> Dict[str, Tuple[List[Dict], List[float]]]:
    """Keep at most ``max_instances`` instances/optimal values per dataset (None = no cap)."""
    if max_instances is None:
        return dataset
    return {
        label: (
            instances[:max_instances],
            optimal_values[:max_instances] if optimal_values else [],
        )
        for label, (instances, optimal_values) in dataset.items()
    }


def _evaluate_on_dataset(
    handler,
    dataset_label: str,
    instances: List[Dict],
    optimal_values: List[float] | None,
    inst_cols: List[str],
    row: Dict[str, float | None],
) -> None:
    """Run ``handler`` on one dataset's instances and store each computable gap in ``row``.

    An exception on one instance is reported and that instance is skipped
    (best effort), so a single failure does not abort the whole suite.
    """
    optimal_values = optimal_values or []
    print(f"  Dataset '{dataset_label}': {len(instances)} instances")
    if len(optimal_values) < len(instances):
        # zip() below stops at the shorter list; say so instead of silently
        # evaluating fewer instances than announced above.
        print(
            f"    Warning: only {len(optimal_values)} optimal values for "
            f"{len(instances)} instances; the remaining instances are skipped"
        )

    for inst_idx, (inst, opt_cost) in enumerate(zip(instances, optimal_values)):
        col_name = inst_cols[inst_idx]
        try:
            res = handler(inst)
            # The handler may return an object exposing ``.cost`` or a bare number.
            solution_cost = res.cost if hasattr(res, "cost") else float(res)
            gap = compute_gap(solution_cost, opt_cost)
        except Exception as e:
            print(f"    Error on instance {inst.get('instance_name', f'inst{inst_idx}')}: {e}")
            continue
        if gap is not None:
            row[col_name] = gap


def main() -> None:
    """Run the selected CVRP baselines and save per-instance gaps to CSV.

    Loads every dataset under ``--test_data_dir``, optionally capped by
    ``--max_instances``, and runs each selected baseline on every instance.
    Writes ``baseline_gap_table.csv`` to ``--output_dir`` with one row per
    baseline and one column per instance; a cell is empty where no gap was
    computed.

    Raises:
        RuntimeError: If no datasets are found or no baseline is selected.
    """
    args = parse_args()

    dataset = _limit_instances(
        load_all_test_data(test_data_dir=args.test_data_dir), args.max_instances
    )
    if not dataset:
        raise RuntimeError(f"No datasets found under {args.test_data_dir}")

    # Build column names (same format as evaluate.py)
    inst_cols_by_dataset, all_instance_cols = build_instance_columns(dataset)

    print(f"\n Loaded {len(dataset)} datasets:")
    for label, (instances, opt_values) in dataset.items():
        num_with_opt = sum(
            1 for opt in (opt_values or []) if opt is not None and np.isfinite(opt)
        )
        print(f"  - {label}: {len(instances)} instances ({num_with_opt} with optimal values)")

    specs = resolve_methods(args)
    if not specs:
        raise RuntimeError("No baseline methods selected")

    # baseline name -> {instance column: gap % or None}
    gap_table: Dict[str, Dict[str, float | None]] = {}

    for spec in specs:
        print(f"\n{'='*60}")
        print(f"Running baseline: {spec.name}")
        print(f"Description: {spec.description}")
        print(f"{'='*60}")

        handler = spec.handler
        if handler is None:
            print(f"❌ {spec.name} missing handler, skipping")
            continue

        # Pre-fill every column with None so that each instance column appears
        # in the CSV even when no gap is computed for it.
        row: Dict[str, float | None] = dict.fromkeys(all_instance_cols)
        gap_table[spec.name] = row

        for dataset_label, (instances, optimal_values) in dataset.items():
            _evaluate_on_dataset(
                handler,
                dataset_label,
                instances,
                optimal_values,
                inst_cols_by_dataset[dataset_label],
                row,
            )

        gaps = [g for g in row.values() if g is not None]
        if gaps:
            mean_gap = float(np.mean(gaps))
            print(f"   {spec.name}: mean gap = {mean_gap:.4f}% ({len(gaps)}/{len(all_instance_cols)} instances)")
        else:
            print(f"    {spec.name}: No valid gaps computed")

    # save_table creates the output directory itself.
    save_table(args.output_dir, gap_table, 'baseline_gap_table')

    print(f"\n Results saved to: {os.path.join(args.output_dir, 'baseline_gap_table.csv')}")
    print("   Format matches evaluate.py: per-instance gaps for all baselines")


# Run the baseline suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()

