#!/usr/bin/env python3
"""
CVRP-specific find_h_best script

Find the best solver based on sigma_g (considering min_simple_ratio adjustment) in the experiment configuration.
Use the same method as during training to calculate the oracle (refer to the oracle configuration in the experiment config).
python heupsro/problems/cvrp/testing/find_h_best.py \
    --experiment /path/to/experiment \
    --round 8  # optional, default to use latest round
    --output h_best_results.csv  # optional, default to output to experiment/test/h_best_results.csv
"""

import os
import sys
import json
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Optional, Tuple

# Derive the surrounding directories from this file's absolute location so the
# script can be launched directly (e.g. `python heupsro/problems/cvrp/testing/find_h_best.py`).
_script_dir = os.path.dirname(os.path.abspath(__file__))
_problem_dir = os.path.dirname(_script_dir)  # cvrp
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # heupsro
_project_root = os.path.dirname(_heupsro_dir)  # project root

# Prepend the project root to sys.path (only if absent) so that the
# `heupsro.*` imports that follow can be resolved.
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from heupsro.experiments.resume_experiment import load_experiment_config
from heupsro.problems.cvrp.evolution.solver.prob import CVRPSolverProblem


def load_pools(experiment_dir: str) -> dict:
    """Read and return the parsed contents of ``psro_results/pools.json``.

    Args:
        experiment_dir: experiment directory containing ``psro_results``.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: if ``pools.json`` does not exist.
    """
    pools_file = os.path.join(experiment_dir, "psro_results", "pools.json")
    if os.path.exists(pools_file):
        with open(pools_file, 'r', encoding='utf-8') as handle:
            pools = json.load(handle)
        return pools
    raise FileNotFoundError(f"找不到 pools.json: {pools_file}")


def load_sigma_g(experiment_dir: str, round_num: Optional[int] = None) -> Tuple[np.ndarray, int]:
    """Load the generator mixture ``sigma_g`` of one PSRO round.

    Args:
        experiment_dir: experiment directory containing ``psro_results``.
        round_num: round whose ``nash_mixture_round_<round>.json`` is read;
            when None, the file with the highest round number is used.

    Returns:
        ``(sigma_g, round_num)`` where ``sigma_g`` is a float array built from
        ``data["generator_mixture"]["sigma_g"]`` and ``round_num`` is the round
        that was actually loaded.

    Raises:
        FileNotFoundError: if the requested file, or any round file at all,
            cannot be found.
    """
    psro_dir = os.path.join(experiment_dir, "psro_results")
    prefix = "nash_mixture_round_"
    suffix = ".json"

    if round_num is not None:
        nash_file = os.path.join(psro_dir, f"nash_mixture_round_{round_num}.json")
        if not os.path.exists(nash_file):
            raise FileNotFoundError(f"nash_mixture_round_{round_num}.json not found: {nash_file}")
    else:
        # find latest round; files whose round part is not a plain integer
        # (e.g. "nash_mixture_round_final.json") are ignored instead of
        # crashing the int() conversion
        candidates = []
        for name in os.listdir(psro_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            stem = name[len(prefix):-len(suffix)]
            if stem.isascii() and stem.isdigit():
                candidates.append((int(stem), name))
        if not candidates:
            raise FileNotFoundError(f"no nash_mixture_round_*.json files found in: {psro_dir}")
        round_num, latest_name = max(candidates, key=lambda item: item[0])
        nash_file = os.path.join(psro_dir, latest_name)
        print(f"  using latest nash_mixture file: nash_mixture_round_{round_num}.json")

    with open(nash_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # force a float array so integer-valued JSON weights (e.g. [1, 0]) do not
    # produce an integer array that truncates later fractional assignments
    return np.array(data["generator_mixture"]["sigma_g"], dtype=float), round_num


def adjust_sigma_g_with_min_ratio(
    sigma_g_full: np.ndarray,
    min_simple_ratio: float,
    psro_use_latest_only: bool
) -> np.ndarray:
    """
    adjust sigma_g, considering min_simple_ratio (refer to controller._evolve_solver_br)

    The weight at index 0 (simple/baseline generator) is forced to
    ``min_simple_ratio``; the remaining ``1 - min_simple_ratio`` is spread over
    the other entries proportionally to their original weights (uniformly when
    those weights sum to zero or less).

    Args:
        sigma_g_full: full generator mixture (includes g0 at index 0)
        min_simple_ratio: minimum simple/baseline generator weight
        psro_use_latest_only: whether to only use latest strategy

    Returns:
        adjusted sigma_g (a new array; the input is never modified)
    """
    if psro_use_latest_only:
        # directly use the incoming distribution, skip simple/baseline special handling
        return sigma_g_full.copy()

    n_generators = len(sigma_g_full)
    if n_generators == 0:
        # nothing to redistribute; return an empty mixture instead of failing on index 0
        return np.zeros(0, dtype=float)
    if n_generators == 1:
        # only the simple/baseline generator exists, so it takes all the weight
        return np.array([1.0])

    # always allocate floats: zeros_like would inherit an integer dtype from
    # an integer input and silently truncate the fractional weights to 0
    adjusted_sigma_g_full = np.zeros_like(sigma_g_full, dtype=float)
    adjusted_sigma_g_full[0] = min_simple_ratio
    remaining_weight = 1.0 - min_simple_ratio

    evolved_weights = sigma_g_full[1:]
    evolved_total = evolved_weights.sum()
    if evolved_total > 0:
        normalized_evolved = evolved_weights / evolved_total
    else:
        # no usable mass on the evolved generators: share the remainder equally
        normalized_evolved = np.ones(len(evolved_weights)) / len(evolved_weights)
    adjusted_sigma_g_full[1:] = normalized_evolved * remaining_weight

    return adjusted_sigma_g_full


def find_h_best(experiment_dir: str, round_num: Optional[int] = None) -> Tuple[str, float]:
    """
    find the best solver ID for given sigma_g (considering min_simple_ratio adjustment)

    use the same method as during training to calculate the oracle (refer to the oracle configuration in the experiment config).

    Steps: load the experiment config and pools, load sigma_g for the round,
    align its length with the generator pool, apply the min_simple_ratio
    adjustment, drop generators with negligible weight, then score every
    solver with CVRPSolverProblem.evaluate_batch and pick the one with the
    lowest score.

    Args:
        experiment_dir: experiment directory
        round_num: which round to load sigma_g (if None, use latest)

    Returns:
        (best_solver_id, best_score)

    Raises:
        FileNotFoundError: if pools.json or the nash mixture file is missing.
        ValueError: if no generator with code survives filtering, or the
            solver pool is empty.
    """
    print(f"\n{'='*60}")
    print(f"find h_best for experiment: {os.path.basename(experiment_dir)}")
    print(f"{'='*60}")

    # load configuration
    cfg = load_experiment_config(experiment_dir)

    # directly read oracle configuration from config.json (because load_experiment_config may not correctly read)
    config_json_path = os.path.join(experiment_dir, "config.json")
    oracle_type = 'ortools'  # default value
    oracle_timeout = 60  # default value
    if os.path.exists(config_json_path):
        # explicit UTF-8, consistent with how pools.json is read
        with open(config_json_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
        oracle_type = config_data.get('oracle_type', oracle_type)
        oracle_timeout = config_data.get('oracle_timeout', oracle_timeout)

    # ensure config object has oracle_type and oracle_timeout attributes
    # (values already present on cfg are left untouched)
    if not hasattr(cfg, 'oracle_type'):
        cfg.oracle_type = oracle_type
    if not hasattr(cfg, 'oracle_timeout'):
        cfg.oracle_timeout = oracle_timeout

    print(f"  configuration loaded successfully")
    print(f"    oracle_type: {oracle_type}")
    print(f"    oracle_timeout: {oracle_timeout}s")
    print(f"    num_customers: {getattr(cfg, 'num_customers', 100)}")
    print(f"    vehicle_capacity: {getattr(cfg, 'vehicle_capacity', 100)}")
    print(f"    eoh_eval_n_instances: {cfg.eoh_eval_n_instances}")

    # load pools
    pools_data = load_pools(experiment_dir)
    solvers = pools_data.get("solvers", [])
    generators = pools_data.get("generators", [])
    print(f"  pools loaded successfully: {len(solvers)} solvers, {len(generators)} generators")

    # load sigma_g
    sigma_g_full, actual_round = load_sigma_g(experiment_dir, round_num)
    print(f"  using round {actual_round} sigma_g")
    print(f"    sigma_g length: {len(sigma_g_full)}")
    print(f"    sigma_g sum: {sigma_g_full.sum():.6f}")

    # validate and adjust sigma_g length
    if len(sigma_g_full) != len(generators):
        print(f"    sigma_g length({len(sigma_g_full)}) != generators number({len(generators)}), adjusting")
        if len(sigma_g_full) < len(generators):
            sigma_g_full = np.pad(sigma_g_full, (0, len(generators) - len(sigma_g_full)), 'constant')
        else:
            sigma_g_full = sigma_g_full[:len(generators)]
        sigma_g_total = sigma_g_full.sum()
        if sigma_g_total > 0:
            sigma_g_full = sigma_g_full / sigma_g_total
        else:
            # dividing by a zero sum would fill the mixture with NaN; keep the
            # zeros so the fallbacks further down decide which generators to use
            print(f"    sigma_g sums to zero after length adjustment, skipping normalization")

    # adjust sigma_g (considering min_simple_ratio)
    min_simple_ratio = getattr(cfg, 'min_simple_ratio', 0.3)
    psro_use_latest_only = getattr(cfg, 'psro_use_latest_only', False)
    adjusted_sigma_g = adjust_sigma_g_with_min_ratio(sigma_g_full, min_simple_ratio, psro_use_latest_only)
    print(f"  adjusted sigma_g:")
    print(f"    min_simple_ratio: {min_simple_ratio}")
    print(f"    psro_use_latest_only: {psro_use_latest_only}")
    print(f"    adjusted sigma_g sum: {adjusted_sigma_g.sum():.6f}")

    # filter generators with weight 0
    # NOTE(review): the threshold is said to mirror the controller's value — confirm against the controller
    weight_threshold = 1e-3
    generator_codes = []
    generator_ids = []
    filtered_weights = []

    for idx, (gen, weight) in enumerate(zip(generators, adjusted_sigma_g)):
        if weight > weight_threshold and gen.get("code"):
            generator_codes.append(gen["code"])
            generator_ids.append(idx)
            filtered_weights.append(weight)

    # if no generators after filtering, keep the first generator
    if len(generator_codes) == 0:
        if len(generators) > 0 and generators[0].get("code"):
            generator_codes = [generators[0]["code"]]
            generator_ids = [0]
            filtered_weights = [1.0]
        else:
            raise ValueError("no valid generators after filtering")

    # normalize weights (every kept weight is > weight_threshold, so the sum is positive)
    filtered_weights = np.array(filtered_weights)
    filtered_weights = filtered_weights / filtered_weights.sum()

    print(f"  using generators: {len(generator_codes)}")
    print(f"    generator_ids: {generator_ids}")
    print(f"    filtered_weights: {filtered_weights}")

    # get solver codes
    solver_codes = [s["code"] for s in solvers]
    solver_ids = [s["program_id"] for s in solvers]
    if not solver_codes:
        # fail with a clear message instead of an opaque argmin error on an empty score list
        raise ValueError("no solvers found in pools.json")

    print(f"\n evaluating {len(solver_codes)} solvers...")
    print(f"    solver_ids: {solver_ids}")

    # create CVRPSolverProblem (using oracle configuration in experiment config)
    problem = CVRPSolverProblem(
        config=cfg,
        generator_codes=generator_codes,
        generator_ids=generator_ids,
        generator_weights=filtered_weights,
        num_customers=getattr(cfg, 'num_customers', 100),
        n_inst_eva=cfg.eoh_eval_n_instances
    )

    # evaluate all solvers (using oracle configuration in config)
    print(f"  evaluating all solvers (using oracle_type={problem.gap_oracle}, timeout={problem.gap_oracle_timeout}s)...")
    scores = problem.evaluate_batch(solver_codes)

    # find the best solver: the LOWEST score wins (argmin).
    # NOTE(review): the previous comment here claimed "higher score is better",
    # which contradicts argmin — confirm the direction of evaluate_batch scores.
    best_idx = int(np.argmin(scores))
    best_solver_id = solver_ids[best_idx]
    best_score = scores[best_idx]

    print(f"\n evaluation results:")
    for i, (solver_id, score) in enumerate(zip(solver_ids, scores)):
        marker = " <-- BEST" if i == best_idx else ""
        print(f"    {solver_id}: {score:.6f}{marker}")

    print(f"\n  found h_best: {best_solver_id} (score: {best_score:.6f})")

    return best_solver_id, best_score


def load_trend_gap_table(experiment_dir: str) -> pd.DataFrame:
    """Load ``test/trend_gap_table.csv`` and make sure it has a ``solver`` column.

    When the CSV has no column named ``solver``, its first column is treated
    as the solver identifier and renamed to ``solver`` (whether or not that
    column had a header name).

    Args:
        experiment_dir: experiment directory containing ``test``.

    Returns:
        The table as a DataFrame.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    csv_path = os.path.join(experiment_dir, "test", "trend_gap_table.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"trend_gap_table.csv not found: {csv_path}")

    df = pd.read_csv(csv_path)
    if 'solver' not in df.columns and len(df.columns) > 0:
        # use the first column as solver; renaming by position covers both an
        # unnamed first column ("Unnamed: 0") and one with another header name
        df = df.rename(columns={df.columns[0]: 'solver'})

    return df


def extract_h_best_row(experiment_dir: str, round_num: Optional[int] = None) -> Tuple[pd.Series, str]:
    """
    extract the row corresponding to h_best from the experiment

    The solver id returned by find_h_best is looked up in the ``solver``
    column of trend_gap_table.csv, trying in order: a direct value match, an
    exact string match (also accepting the ``h<number>`` form with leading
    zeros removed), and finally a literal substring match.

    Args:
        experiment_dir: experiment directory
        round_num: which round to load sigma_g (if None, use latest)

    Returns:
        (row corresponding to h_best, experiment name)

    Raises:
        ValueError: if no row of the table matches the h_best id.
    """
    # find h_best
    h_best_id, _ = find_h_best(experiment_dir, round_num=round_num)

    # load trend_gap_table.csv
    df = load_trend_gap_table(experiment_dir)

    # find the row corresponding to h_best
    h_best_row = None

    # try to match program_id directly
    if h_best_id in df['solver'].values:
        h_best_row = df[df['solver'] == h_best_id].iloc[0]
    else:
        solver_names = df['solver'].astype(str)

        # exact string match, which also covers ids whose type differs from
        # the column dtype; for "h<number>" ids accept the normalized number
        accepted_names = {str(h_best_id)}
        if isinstance(h_best_id, str) and h_best_id.startswith('h'):
            try:
                accepted_names.add(f'h{int(h_best_id[1:])}')
            except ValueError:
                pass
        exact_rows = df[solver_names.isin(accepted_names)]
        if len(exact_rows) > 0:
            h_best_row = exact_rows.iloc[0]

        # if still not found, try to find the row containing the ID
        if h_best_row is None:
            # regex=False: the id is plain text, not a pattern, so characters
            # such as "." or "(" in it must be matched literally
            matching_rows = df[solver_names.str.contains(str(h_best_id), na=False, regex=False)]
            if len(matching_rows) > 0:
                h_best_row = matching_rows.iloc[0]

    if h_best_row is None:
        raise ValueError(f"row corresponding to h_best={h_best_id} not found in trend_gap_table.csv")

    # extract experiment name
    exp_name = os.path.basename(experiment_dir.rstrip('/'))

    # return a copy so callers can add fields without touching the table
    return h_best_row.copy(), exp_name


def extract_from_trend_gap_table(experiment_dir: str) -> Tuple[pd.Series, str]:
    """
    Return the final row of trend_gap_table.csv (used for eoh experiments).

    Args:
        experiment_dir: experiment directory

    Returns:
        (copy of the last row, experiment name)

    Raises:
        ValueError: if the table has no rows.
    """
    table = load_trend_gap_table(experiment_dir)
    n_rows = len(table)
    if n_rows < 1:
        raise ValueError(f"trend_gap_table.csv is empty: {experiment_dir}/test/")

    experiment_name = os.path.basename(experiment_dir.rstrip('/'))
    final_row = table.iloc[n_rows - 1].copy()
    return final_row, experiment_name


def main():
    """Command-line entry point.

    Collects one result row per experiment — the h_best row for every
    ``--experiments`` directory and the last trend_gap_table row for every
    ``--eoh`` directory — tags each row with its experiment name and writes
    them all to a single CSV file. A failing experiment is reported with its
    traceback and skipped; the remaining ones are still processed.
    """
    import traceback

    parser = argparse.ArgumentParser(
        description="CVRP specific find_h_best script: find the best solver based on sigma_g"
    )
    parser.add_argument(
        '--experiments',
        type=str,
        nargs='+',
        default=[],
        help='PSRO experiment directory paths (can be multiple), run find_h_best for these experiments'
    )
    parser.add_argument(
        '--eoh',
        type=str,
        nargs='+',
        default=[],
        help='EoH experiment directory paths (can be multiple), extract the last row from trend_gap_table for these experiments'
    )
    parser.add_argument(
        '--round',
        type=int,
        default=None,
        help='which round to load sigma_g (only used for experiments, default: latest round)'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='output CSV file path (default: test/h_best_results.csv in the first experiment directory)'
    )

    args = parser.parse_args()

    if not args.experiments and not args.eoh:
        parser.error("at least one of --experiments or --eoh is required")

    all_rows = []

    # process PSRO experiments (using find_h_best)
    for exp_dir in args.experiments:
        experiment_dir = os.path.abspath(exp_dir)
        if not os.path.exists(experiment_dir):
            print(f"  experiment directory not found: {experiment_dir}")
            continue

        exp_name = os.path.basename(experiment_dir.rstrip('/'))
        print(f"\n processing PSRO experiment: {exp_name}")

        try:
            h_best_row, _ = extract_h_best_row(experiment_dir, round_num=args.round)
            h_best_row['experiment'] = exp_name
            all_rows.append(h_best_row)
            print(f"   completed")
        except Exception as e:
            # best effort: report the failure and continue with the next experiment
            print(f"   failed: {e}")
            traceback.print_exc()

    # process EoH experiments (extract the last row from trend_gap_table)
    for exp_dir in args.eoh:
        experiment_dir = os.path.abspath(exp_dir)
        if not os.path.exists(experiment_dir):
            print(f"  experiment directory not found: {experiment_dir}")
            continue

        exp_name = os.path.basename(experiment_dir.rstrip('/'))
        print(f"\n processing EoH experiment: {exp_name}")

        try:
            h_best_row, _ = extract_from_trend_gap_table(experiment_dir)
            h_best_row['experiment'] = exp_name
            all_rows.append(h_best_row)
            print(f"   completed, solver: {h_best_row['solver']}")
        except Exception as e:
            # best effort: report the failure and continue with the next experiment
            print(f"   failed: {e}")
            traceback.print_exc()

    if not all_rows:
        print("\n no successful experiments processed")
        return

    # merge all results
    result_df = pd.DataFrame(all_rows)

    # determine output path
    if args.output:
        output_path = args.output
    elif args.experiments:
        output_path = os.path.join(os.path.abspath(args.experiments[0]), 'test', 'h_best_results.csv')
    else:
        output_path = os.path.join(os.path.abspath(args.eoh[0]), 'test', 'h_best_results.csv')

    # a bare filename (e.g. --output h_best_results.csv) has an empty dirname,
    # and os.makedirs('') raises FileNotFoundError, so only create real directories
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    result_df.to_csv(output_path, index=False)
    print(f"\n results saved to: {output_path}")
    print(f"    contains {len(result_df)} experiments")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

