#!/usr/bin/env python3
"""Evaluate pool solvers, population-bests, and EoH optimal heuristic across CVRP datasets.

Features:
- Source A (pool): evaluate all solvers from pools.json
- Source B (population): for each population_generation_*.json, pick best-by-objective individual as a solver
- Source C (optimal): evaluate EoH optimal heuristic code
- Datasets: all .vrp files with .sol files (with known optimal values)
- Robust parallel evaluation with fallback; resume and incremental saves after each dataset

Outputs (under --out_dir):
- trend_gap_table.csv (for pool/population modes)
- For optimal mode: eoh_optimal_gap_table.csv (saved to baseline_results/)
"""

import os
import sys
import json
import argparse
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import pandas as pd
import re
import threading

# Make the project packages importable when this file is run as a script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
# Walk up from this file's directory.
# Expected layout (per the original author): testing -> cvrp -> problems -> heupsro -> project_root
_problem_dir = os.path.dirname(_script_dir)  # cvrp
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # heupsro
_project_root = os.path.dirname(_heupsro_dir)  # project root
_eoh_src_dir = os.path.join(_project_root, 'eoh', 'src')  # eoh/src for eoh module

# Each missing path is pushed to the FRONT of sys.path, so the one inserted
# last is searched first. When all three are inserted the resulting search
# order is: _project_root, _heupsro_dir, _eoh_src_dir.
if _eoh_src_dir not in sys.path:
    sys.path.insert(0, _eoh_src_dir)
if _heupsro_dir not in sys.path:
    sys.path.insert(0, _heupsro_dir)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from heupsro.core.config import HeuPSROConfig
from heupsro.problems.cvrp.evaluation import evaluate_solvers_on_instances_with_optcosts
from heupsro.problems.cvrp.initialization import get_nearest_neighbor_solver_code
from load_test_data import load_all_test_data

# Limit BLAS intra-op threads to avoid thread explosion and resource contention.
# setdefault() only fills in variables the caller has not already exported.
# NOTE(review): numpy and pandas are imported earlier in this file, so these
# variables may be set too late to affect native libraries already loaded in
# this process; presumably they still reach worker processes through the
# inherited environment -- confirm this is the intent.
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# Also disable dynamic threading and set other vendor libs to single-thread
os.environ.setdefault("OMP_DYNAMIC", "FALSE")
os.environ.setdefault("MKL_DYNAMIC", "FALSE")
os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1")
os.environ.setdefault("BLIS_NUM_THREADS", "1")

# Make joblib behavior explicit for better stability across platforms
os.environ.setdefault("JOBLIB_START_METHOD", "spawn")

# Use memory-backed tmp for faster IO.
# NOTE(review): the path is set without checking that /dev/shm exists, so on
# platforms without it TMPDIR will point at a missing directory -- confirm.
os.environ.setdefault("TMPDIR", "/dev/shm")


def load_config_from_experiment(exp_dir: str) -> Optional[HeuPSROConfig]:
    """Build a ``HeuPSROConfig`` from ``<exp_dir>/config.json``.

    Starts from a default-constructed ``HeuPSROConfig`` and copies over every
    JSON key that already exists as an attribute on it; unknown keys are
    ignored.

    Args:
        exp_dir: Experiment directory expected to contain ``config.json``.

    Returns:
        The populated config, or ``None`` when the file does not exist or
        anything goes wrong while reading/applying it (a warning is printed
        in the failure case).
    """
    config_path = os.path.join(exp_dir, "config.json")
    if not os.path.exists(config_path):
        return None

    try:
        with open(config_path, "r") as handle:
            stored = json.load(handle)
        # Only overwrite attributes the config class already knows about.
        loaded = HeuPSROConfig()
        for field_name, field_value in stored.items():
            if not hasattr(loaded, field_name):
                continue
            setattr(loaded, field_name, field_value)
    except Exception as e:
        print(f"Warning: Failed to load config from {config_path}: {e}")
        return None
    return loaded


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    ``--mode`` is mandatory and selects the solver source (population, pool
    or optimal). ``--test_data_dir`` defaults to the ``TestingData`` folder
    located next to this script. Resume is on by default and can be turned
    off with ``--no-resume``.

    Returns:
        The parsed arguments namespace.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    default_data_dir = os.path.join(here, 'TestingData')

    parser = argparse.ArgumentParser(
        description="Evaluate pool solvers, population-bests, and/or EoH optimal heuristic",
    )
    add = parser.add_argument

    # Solver source selection
    add('--mode', type=str, required=True, choices=['population', 'pool', 'optimal'],
        help='Evaluation mode: population (from exp_without_pool), pool (from exp_with_pool), or optimal (EoH optimal heuristic)')
    add('--exp_with_pool', type=str, default=None,
        help='Experiment dir that has psro_results/pools.json (required for pool mode)')
    add('--exp_without_pool', type=str, default=None,
        help='Experiment dir that has solver_eoh/results/pops (required for population mode)')

    # Test data
    add('--test_data_dir', type=str, default=default_data_dir,
        help=(
            'Directory containing test instances. '
            'Default: TestingData root (includes all .vrp files with .sol files).'
        ))
    add('--limit_test_instances', type=int, default=None,
        help='Limit number of instances per dataset (default: None)')

    # Parallelism and time budget
    add('--n_jobs', type=int, default=-1)
    add('--backend', type=str, default='loky', choices=['loky','threading'])
    add('--prefer', type=str, default='processes', choices=['processes','threads'])
    add('--time_limit', type=int, default=30)

    # Resume behaviour: both flags write to the same ``resume`` destination.
    add('--resume', action='store_true', default=True,
        help='Resume from existing results (default: True, automatically enabled)')
    add('--no-resume', dest='resume', action='store_false',
        help='Disable resume, force re-evaluation of all solvers')

    add('--out_dir', type=str, default=None,
        help='Output dir (default: exp_with_pool/test, exp_without_pool/test, or this testing dir for optimal)')
    add('--select', type=str, default='min', choices=['min','max'],
        help='Population best by objective')
    add('--debug_ident', action='store_true')
    add('--verbose', action='store_true', help='Enable verbose output')

    # Quick-experiment limits
    add('--limit_pool_solvers', type=int, default=None, help='Evaluate only first N pool solvers')
    add('--limit_populations', type=int, default=None, help='Evaluate only first N populations')
    add('--exp_dir', type=str, default=None,
        help='Experiment dir to load config from (optional, for optimal mode)')

    return parser.parse_args()


def natural_key(s: str) -> List[Any]:
    """Return a sort key that orders embedded numbers numerically.

    The string is split on runs of decimal digits; digit runs become ``int``
    and everything else is lower-cased, so ``"gen_2"`` sorts before
    ``"gen_10"``.

    Args:
        s: The string (e.g. a file name or path) to build a key for.

    Returns:
        A list alternating lower-cased text chunks and integers.
    """
    # isdecimal() accepts exactly the characters ``\d`` captures. isdigit()
    # also accepts characters such as superscript digits, which int() rejects
    # with ValueError.
    return [int(tok) if tok.isdecimal() else tok.lower() for tok in re.split(r'(\d+)', s)]


def list_population_files(pops_dir: str) -> List[str]:
    """Collect the population snapshot files found directly in ``pops_dir``.

    A file qualifies when its name starts with ``population_generation_`` and
    ends with ``.json``.

    Args:
        pops_dir: Directory to scan (not recursive).

    Returns:
        Full paths of the matching files, ordered with ``natural_key``.
    """
    matches: List[str] = []
    for entry in os.listdir(pops_dir):
        if entry.startswith('population_generation_') and entry.endswith('.json'):
            matches.append(os.path.join(pops_dir, entry))
    return sorted(matches, key=natural_key)


def load_population_best_individual(pop_file: str, *, select: str = 'min') -> Optional[Dict[str, Any]]:
    """Return the best individual stored in a population JSON file.

    The file must contain a JSON list of individuals (dicts). Individuals are
    ranked by their ``objective`` value. An objective that is absent, ``None``,
    non-numeric, NaN or infinite is treated as "worst", so a single invalid
    entry no longer causes the whole population to be discarded.

    Args:
        pop_file: Path to the population JSON file.
        select: ``'min'`` picks the smallest objective; any other value picks
            the largest.

    Returns:
        The selected individual, or ``None`` when the file cannot be read or
        parsed, is not a list, or contains no dict entries.
    """
    try:
        with open(pop_file, 'r', encoding='utf-8') as f:
            population = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None

    if not isinstance(population, list):
        return None
    candidates = [ind for ind in population if isinstance(ind, dict)]
    if not candidates:
        return None

    minimize = select == 'min'
    # Sentinel that can never beat a real objective for the chosen direction.
    worst = float('inf') if minimize else float('-inf')

    def objective_of(ind: Dict[str, Any]) -> float:
        value = ind.get('objective')
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return worst
        # The chained comparison is False for NaN and for +/-inf.
        if not (float('-inf') < value < float('inf')):
            return worst
        return value

    pick = min if minimize else max
    return pick(candidates, key=objective_of)


def load_pools(exp_dir: str) -> Dict[str, Any]:
    """Read and return the parsed contents of ``<exp_dir>/psro_results/pools.json``.

    Args:
        exp_dir: Experiment directory containing ``psro_results/pools.json``.

    Returns:
        The decoded JSON document.

    Errors from opening or decoding the file propagate to the caller.
    """
    with open(os.path.join(exp_dir, 'psro_results', 'pools.json'), 'r') as handle:
        pools = json.load(handle)
    return pools


def sanitize_column_name(dataset_label: str, instance_name: str) -> str:
    """Build a table column name of the form ``<dataset>__<instance>``.

    In both parts, every run of characters outside ``A-Z a-z 0-9 _ . -`` is
    collapsed into a single underscore.
    """
    unsafe_run = re.compile(r'[^A-Za-z0-9_.-]+')
    cleaned = [unsafe_run.sub('_', part) for part in (dataset_label, instance_name)]
    return "__".join(cleaned)


def build_instance_columns(datasets_dict: Dict[str, Tuple[List[Dict], List[float]]]) -> Tuple[Dict[str, List[str]], List[str]]:
    """Assign one unique table column name to every instance of every dataset.

    The column name is ``sanitize_column_name(label, instance_name)``; an
    instance without an ``instance_name`` falls back to ``inst<idx>``. When
    a name is already taken, ``_<idx>`` is appended (as before), and if that
    is taken too a further ``_<n>`` counter is added until the name is
    unique. Previously the ``_<idx>`` form could itself collide with an
    existing column (e.g. instances named ``a``, ``a_2``, ``a``), which made
    two instances share one column.

    Args:
        datasets_dict: Maps dataset label -> (instances, optimal values).
            Only the instances are read here.

    Returns:
        A pair ``(columns_by_dataset, all_columns)``: per-dataset column
        lists aligned with the instance order, and the flat list of all
        columns in dataset/instance order.
    """
    inst_cols_by_dataset: Dict[str, List[str]] = {}
    all_cols: List[str] = []
    seen: set = set()
    for label, (instances, _) in datasets_dict.items():
        cols: List[str] = []
        for idx, inst in enumerate(instances):
            inst_name = inst.get('instance_name') or f"inst{idx}"
            base = sanitize_column_name(label, str(inst_name))
            col = base
            if col in seen:
                col = f"{base}_{idx}"
                # The index-suffixed name may clash with a real instance name.
                bump = 1
                while col in seen:
                    col = f"{base}_{idx}_{bump}"
                    bump += 1
            seen.add(col)
            cols.append(col)
            all_cols.append(col)
        inst_cols_by_dataset[label] = cols
    return inst_cols_by_dataset, all_cols


def save_table(out_dir: str, table: Dict[str, Dict[str, float]], basename: str) -> None:
    """Write ``table`` to ``<out_dir>/<basename>.csv``.

    The first column is ``solver`` (the row key); the remaining columns are
    the sorted union of all keys found in the rows. Values are written with
    six decimals; missing/``None`` values become empty cells.

    The file is written to a temporary sibling and then moved into place, so
    an interrupted run cannot leave a half-written CSV that a later resume
    would try to read. Fields are emitted through the ``csv`` module so a row
    key containing a comma or quote is escaped correctly.

    Args:
        out_dir: Destination directory (created if necessary).
        table: Mapping of row key -> {column -> value or None}.
        basename: File name without the ``.csv`` extension.
    """
    import csv

    os.makedirs(out_dir, exist_ok=True)
    cp = os.path.join(out_dir, f"{basename}.csv")
    tmp_path = f"{cp}.tmp"
    cols = sorted({k for row in table.values() for k in row})
    try:
        with open(tmp_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["solver"] + cols)
            for rk, row in table.items():
                cells = []
                for c in cols:
                    value = row.get(c)
                    cells.append("" if value is None else f"{float(value):.6f}")
                writer.writerow([rk] + cells)
        os.replace(tmp_path, cp)
    finally:
        # Only present if writing failed before the replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def load_existing_table(out_dir: str, basename: str) -> Optional[Dict[str, Dict[str, float]]]:
    """Load a previously saved results table from ``<out_dir>/<basename>.csv``.

    The ``solver`` column supplies the row keys and is always read as text,
    so an all-numeric key is not turned into an integer. Empty cells are
    omitted from the row dicts; present cells are returned as plain floats.
    Rows without a solver key are skipped.

    Args:
        out_dir: Directory containing the CSV.
        basename: File name without the ``.csv`` extension.

    Returns:
        Mapping of row key -> {column -> value}, or ``None`` when the file
        does not exist or cannot be read/parsed. In the failure case a
        warning is printed, because callers will start from an empty table
        and later overwrite the file.
    """
    cp = os.path.join(out_dir, f"{basename}.csv")
    if not os.path.exists(cp):
        return None
    try:
        df = pd.read_csv(cp, dtype={'solver': str})
        value_cols = [col for col in df.columns if col != 'solver']
        table: Dict[str, Dict[str, float]] = {}
        for _, row in df.iterrows():
            solver = row['solver']
            if pd.isna(solver):
                continue
            table[solver] = {col: float(row[col]) for col in value_cols if pd.notna(row[col])}
        return table
    except Exception as e:
        print(f"Warning: Failed to load existing results from {cp}: {e}")
        return None


def main() -> None:
    """Evaluate the solvers selected by ``--mode`` and write a per-instance gap table.

    Steps:
      1. Parse/validate arguments and choose the output directory.
      2. Resolve runtime settings. With a ``config.json`` available the
         priority is: explicit command-line value > config value > default.
         A command-line value equal to the argparse default is treated as
         "not given". The resolved values are written back to the config
         object that is handed to the evaluator.
      3. Load the test datasets and build one table column per instance.
      4. Load any existing results table (used for resume) and drop columns
         that do not correspond to a currently loaded instance.
      5. Collect solver codes for the chosen mode (pool / population / optimal).
      6. For every dataset, batch-evaluate the solvers that still have missing
         cells (or all of them with ``--no-resume``) and save the table after
         each dataset.

    Raises:
        ValueError: A mode-specific experiment directory argument is missing.
        RuntimeError: No test instances could be loaded.
    """
    args = parse_args()

    # Validate mode-specific arguments
    if args.mode == 'pool' and not args.exp_with_pool:
        raise ValueError("--exp_with_pool is required for pool mode")
    if args.mode == 'population' and not args.exp_without_pool:
        raise ValueError("--exp_without_pool is required for population mode")

    exp_with_pool = os.path.abspath(args.exp_with_pool) if args.exp_with_pool else None
    exp_without_pool = os.path.abspath(args.exp_without_pool) if args.exp_without_pool else None

    # The optimal baseline is stored under its own file name.
    table_basename = 'eoh_optimal_gap_table' if args.mode == 'optimal' else 'trend_gap_table'

    # Set default output directory based on mode
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if args.out_dir:
        out_dir = os.path.abspath(args.out_dir)
    elif args.mode == 'optimal':
        # Baseline mode saves to baseline_results/
        out_dir = os.path.join(script_dir, 'baseline_results')
    elif args.mode == 'pool' and exp_with_pool:
        out_dir = os.path.join(exp_with_pool, 'test')
    elif args.mode == 'population' and exp_without_pool:
        out_dir = os.path.join(exp_without_pool, 'test')
    else:
        out_dir = script_dir
    os.makedirs(out_dir, exist_ok=True)

    # Try to load configuration from an experiment directory
    if args.exp_dir:
        config_exp_dir = args.exp_dir
    elif args.mode == 'pool':
        config_exp_dir = exp_with_pool
    elif args.mode == 'population':
        config_exp_dir = exp_without_pool
    else:  # optimal mode does not load from an experiment directory by default
        config_exp_dir = None
    config = load_config_from_experiment(config_exp_dir) if config_exp_dir else None

    # Priority: command line parameter > config parameter > default value
    if config:
        # Compare against the argparse defaults: only a value the user changed
        # overrides the config. (The previous check compared against the
        # non-default choice, which inverted the priority.)
        backend = args.backend if args.backend != 'loky' else getattr(config, 'parallel_backend', 'loky')
        prefer = args.prefer if args.prefer != 'processes' else getattr(config, 'parallel_prefer', 'processes')
        if args.n_jobs != -1:
            n_jobs = args.n_jobs
        else:
            n_jobs = getattr(config, 'test_n_jobs', None) or getattr(config, 'eval_n_jobs', -1)
            if n_jobs is None:
                n_jobs = -1

        time_limit = args.time_limit if args.time_limit != 30 else (
            getattr(config, 'test_time_limit', None) or
            getattr(config, 'instance_solver_time_limit', 30)
        )

        config_source = config_exp_dir if config_exp_dir else "unknown"
        print(f" Loaded configuration from {config_source}/config.json")
        print(f"   Using: time_limit={time_limit}")
        print(f"          n_jobs={n_jobs}, backend={backend}, prefer={prefer}")

        # Apply everything that was resolved (and printed) above; previously
        # only eval_n_jobs was written back. These are the same attributes the
        # no-config branch below sets.
        # NOTE(review): assumes the evaluator reads these attributes -- confirm.
        config.eval_n_jobs = n_jobs
        config.parallel_backend = backend
        config.parallel_prefer = prefer
        config.instance_solver_time_limit = time_limit
    else:
        backend = args.backend
        prefer = args.prefer
        n_jobs = args.n_jobs
        time_limit = args.time_limit

        config_source = config_exp_dir if config_exp_dir else "command-line"
        print(f"  No config.json found, using {config_source} arguments")

        # Create a default config carrying the command-line settings
        from heupsro.problems.cvrp.config import CVRPConfig
        config = CVRPConfig()
        config.instance_solver_time_limit = time_limit
        config.parallel_backend = backend
        config.parallel_prefer = prefer
        config.eval_n_jobs = n_jobs
        config.oracle_type = 'none'  # do not use oracle, directly use provided optimal_costs

    # Load test instances using load_test_data module
    datasets_dict = load_all_test_data(test_data_dir=args.test_data_dir)

    if not datasets_dict:
        raise RuntimeError(
            f"No test instances loaded. Please check:\n"
            f"  - test_data_dir: {args.test_data_dir}\n"
            f"  - Ensure .vrp files have corresponding .sol files"
        )

    # Limit instances per dataset if requested
    if args.limit_test_instances is not None and args.limit_test_instances > 0:
        for label in datasets_dict:
            instances, opt_values = datasets_dict[label]
            if len(instances) > args.limit_test_instances:
                datasets_dict[label] = (instances[:args.limit_test_instances], opt_values[:args.limit_test_instances])

    dataset_labels = list(datasets_dict.keys())
    # Build per-instance column names for every dataset
    inst_cols_by_dataset, all_instance_cols = build_instance_columns(datasets_dict)
    print(f"\n✅ Loaded {len(dataset_labels)} datasets:")
    for label in dataset_labels:
        instances, opt_values = datasets_dict[label]
        print(f"  - {label}: {len(instances)} instances -> columns: {len(inst_cols_by_dataset[label])}")

    # Solver sources, filled in according to the mode (parallel lists)
    all_row_keys = []
    all_solver_codes = []

    # Build result tables (only gap_table), starting from saved results if any
    gap_table: Dict[str, Dict[str, float]] = {}
    existing_table = load_existing_table(out_dir, table_basename)
    if existing_table:
        gap_table = existing_table
        print(f"✅ Loaded existing results from {out_dir}/{table_basename}.csv")
        print(f"   Found {len(gap_table)} solvers with existing results")

    # Remove old-format columns (if present); keep only per-instance columns
    if gap_table:
        valid_cols = set(all_instance_cols)
        for rk, row in gap_table.items():
            invalid_cols = [c for c in row.keys() if c not in valid_cols]
            for c in invalid_cols:
                row.pop(c, None)

    def ensure_row(row_key: str):
        """Make sure ``row_key`` has an entry (default None) for every instance column."""
        if row_key not in gap_table:
            gap_table[row_key] = {col: None for col in all_instance_cols}
        else:
            for col in all_instance_cols:
                gap_table[row_key].setdefault(col, None)

    # A: pool solvers
    if args.mode == 'pool':
        pools = load_pools(exp_with_pool)
        pool_solvers = pools.get('solvers', [])
        if args.limit_pool_solvers is not None and args.limit_pool_solvers > 0:
            pool_solvers = pool_solvers[:args.limit_pool_solvers]

        for idx, solver in enumerate(pool_solvers):
            code = solver.get('code', '')
            if not code:
                continue
            # idx is the position in the pool, so keys stay stable even when
            # entries without code are skipped.
            row_key = f"h{idx}"
            all_row_keys.append(row_key)
            all_solver_codes.append(code)
            ensure_row(row_key)

    # B: population bests
    elif args.mode == 'population':
        pop_best = []
        if exp_without_pool:
            pops_dir = os.path.join(exp_without_pool, 'solver_eoh', 'results', 'pops')
            if os.path.isdir(pops_dir):
                pop_files = list_population_files(pops_dir)
                if args.limit_populations is not None and args.limit_populations > 0:
                    pop_files = pop_files[:args.limit_populations]
                for pf in pop_files:
                    ind = load_population_best_individual(pf, select=args.select)
                    if ind is not None:
                        # extract generation number from file name: population_generation_{i}.json -> i
                        filename = os.path.splitext(os.path.basename(pf))[0]
                        match = re.search(r'population_generation_(\d+)', filename)
                        if match:
                            gen_num = int(match.group(1))
                            pop_best.append((gen_num, ind))
                        else:
                            # if cannot extract, use file name as fallback
                            pop_best.append((filename, ind))
            else:
                # Previously a missing directory silently produced an empty run.
                print(f"Warning: population directory not found: {pops_dir}")

        for gen_or_name, ind in pop_best:
            code = ind.get('code', '')
            if not code:
                continue
            # use h{i} format, consistent with pool mode
            if isinstance(gen_or_name, int):
                row_key = f"h{gen_or_name}"
            else:
                # if cannot extract generation number, use original format as fallback
                row_key = f"popbest_{gen_or_name}"
            all_row_keys.append(row_key)
            all_solver_codes.append(code)
            ensure_row(row_key)

    # C: EoH optimal heuristic
    elif args.mode == 'optimal':
        eoh_code = get_nearest_neighbor_solver_code()  # use nearest neighbor as baseline
        print(f"Using EoH optimal heuristic code (length: {len(eoh_code)} chars)")
        row_key = "eoh_optimal"
        all_row_keys.append(row_key)
        all_solver_codes.append(eoh_code)
        ensure_row(row_key)

    # Protect table updates/saves when using outer threading parallel
    tables_lock = threading.Lock()

    # Count what needs to be evaluated
    total_combinations = len(all_row_keys) * len(all_instance_cols)
    already_computed = 0
    need_compute = 0

    for row_key in all_row_keys:
        for col in all_instance_cols:
            if gap_table.get(row_key, {}).get(col) is not None:
                already_computed += 1
            else:
                need_compute += 1

    print(f"\n Evaluation Summary:")
    print(f"   Total combinations: {total_combinations} (solvers × instances)")
    print(f"   Already computed: {already_computed}")
    print(f"   Need to compute: {need_compute}")

    if already_computed > 0 and args.resume:
        print(f"    Resume mode: Will skip {already_computed} already computed results")
    elif not args.resume:
        print(f"     Resume disabled: Will re-evaluate all {total_combinations} combinations")

    # Evaluate each dataset separately (per-instance gaps)
    for dataset_label in dataset_labels:
        test_instances, test_opt_values = datasets_dict[dataset_label]
        inst_cols = inst_cols_by_dataset[dataset_label]

        # Check which solvers need evaluation
        eval_row_keys = []
        eval_solver_codes = []
        skipped_solvers = []

        for row_key, code in zip(all_row_keys, all_solver_codes):
            if not args.resume:
                # resume disabled: evaluate everything
                eval_row_keys.append(row_key)
                eval_solver_codes.append(code)
                continue
            # resume enabled: re-evaluate the solver on this dataset if any
            # of its cells for this dataset is missing
            missing_any = any(
                gap_table.get(row_key, {}).get(col) is None for col in inst_cols
            )
            if missing_any:
                eval_row_keys.append(row_key)
                eval_solver_codes.append(code)
            else:
                skipped_solvers.append(row_key)

        # Initialize table rows
        for row_key in all_row_keys:
            ensure_row(row_key)

        # print skipped information
        if skipped_solvers and args.resume:
            print(f"\n📋 Dataset '{dataset_label}':")
            print(f"   Skipping {len(skipped_solvers)} already computed solvers: {', '.join(skipped_solvers[:5])}{'...' if len(skipped_solvers) > 5 else ''}")

        # Batch evaluate
        if eval_row_keys and test_instances:
            print(f"\nEvaluating {len(eval_row_keys)} solvers on {len(test_instances)} instances from dataset '{dataset_label}' (per-instance gaps will be saved)")

            # first priority: check if optimal values are provided directly
            provided_optimal_count = sum(1 for opt in test_opt_values if opt is not None and np.isfinite(opt))
            total_instances = len(test_instances)

            if provided_optimal_count > 0:
                print(f"    Using provided optimal_costs: {provided_optimal_count}/{total_instances} instances have optimal values")
                print(f"    No oracle computation - directly using provided optimal_costs")
            else:
                print(f"     Warning: No valid optimal_costs provided for any instance")

            oracle_type = getattr(config, 'oracle_type', 'none')
            print(f"   Using batch evaluation with n_jobs={getattr(config, 'eval_n_jobs', -1)}, backend={getattr(config, 'parallel_backend', 'loky')}, oracle_type={oracle_type} (for gap calculation only)")

            # call evaluation function: directly use provided optimal_costs, no oracle computation
            raw_results = evaluate_solvers_on_instances_with_optcosts(
                solver_codes=eval_solver_codes,
                instances=test_instances,
                optimal_costs=test_opt_values,  # directly provided optimal values, use first
                config=config,
                return_format="raw",
            )

            print(f"    Batch evaluation completed, got {len(raw_results)} results")

            # Process results (per-instance). Result keys are
            # (solver index within this batch, instance index).
            with tables_lock:
                for solver_idx, row_key in enumerate(eval_row_keys):
                    for inst_idx, col_name in enumerate(inst_cols):
                        key = (solver_idx, inst_idx)
                        if key in raw_results:
                            gap = raw_results[key]  # raw_results[key] directly is gap value
                            if gap is not None:
                                gap_table[row_key][col_name] = float(gap)
                    # print brief results of first few instances for observation
                    sample_cols = inst_cols[:3]
                    sample_vals = []
                    for col_name in sample_cols:
                        val = gap_table[row_key].get(col_name)
                        if isinstance(val, (int, float)):
                            sample_vals.append(f"{col_name}={val:.4f}%")
                    if sample_vals:
                        print(f"  {row_key}: " + ", ".join(sample_vals) + ("..." if len(inst_cols) > len(sample_cols) else ""))

            # save intermediate results (after each dataset evaluation)
            save_table(out_dir, gap_table, table_basename)
        else:
            print(f"\n📋 Dataset '{dataset_label}': All {len(all_row_keys)} solvers already computed, skipping evaluation")

    # Final save (if there are any unsaved changes)
    save_table(out_dir, gap_table, table_basename)
    print(f"\n Final results saved to: {out_dir}")
    print(f"   - {table_basename}.csv")


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == '__main__':
    main()

