#!/usr/bin/env python3
"""
BP Online专用的find_h_best脚本

根据实验配置中的sigma_g（考虑min_simple_ratio调整），找到表现最好的solver。
使用与训练时相同的方法来求oracle（参照实验config中的oracle配置）。

用法:
  python heupsro/problems/bp_online/testing/find_h_best.py \
    --experiments /path/to/exp1 /path/to/exp2 ... \
    --eoh /path/to/eoh_exp1 /path/to/eoh_exp2 ... \
    --round 8  # 可选，默认使用最新的round
"""

import os
import sys
import json
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Optional, Tuple

# Put the project root on sys.path so that the absolute `heupsro.*` imports
# below resolve when this file is run directly as a script (see the usage
# example in the module docstring).
_script_dir = os.path.dirname(os.path.abspath(__file__))
_problem_dir = os.path.dirname(_script_dir)  # bp_online
_heupsro_dir = os.path.dirname(os.path.dirname(_problem_dir))  # heupsro
_project_root = os.path.dirname(_heupsro_dir)  # project root

# Insert at the front (once) so this checkout takes precedence over any
# other entry already on the path.
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from heupsro.experiments.resume_experiment import load_experiment_config
from heupsro.problems.bp_online.evolution.solver.prob import BPOnlineSolverProblem


def load_pools(experiment_dir: str) -> dict:
    """Return the parsed contents of ``<experiment_dir>/psro_results/pools.json``.

    Args:
        experiment_dir: Root directory of a PSRO experiment.

    Returns:
        The JSON document decoded by ``json.load``.

    Raises:
        FileNotFoundError: If ``pools.json`` does not exist at that location.
    """
    pools_path = os.path.join(experiment_dir, "psro_results", "pools.json")
    if os.path.exists(pools_path):
        with open(pools_path, 'r', encoding='utf-8') as handle:
            pools = json.load(handle)
        return pools
    raise FileNotFoundError(f"找不到 pools.json: {pools_path}")


def load_sigma_g(experiment_dir: str, round_num: Optional[int] = None) -> Tuple[np.ndarray, int]:
    """Load the generator mixture ``sigma_g`` from a ``nash_mixture_round_<N>.json`` file.

    Args:
        experiment_dir: Root directory of a PSRO experiment; the files are
            looked up in its ``psro_results`` sub-directory.
        round_num: Round to load. When ``None`` the file with the highest
            numeric round suffix is used.

    Returns:
        ``(sigma_g, round_num)`` where ``sigma_g`` is a float array built from
        ``data["generator_mixture"]["sigma_g"]`` and ``round_num`` is the round
        that was actually loaded.

    Raises:
        FileNotFoundError: If the requested round file is missing, or no
            ``nash_mixture_round_<N>.json`` file with a numeric ``N`` exists.
    """
    psro_dir = os.path.join(experiment_dir, "psro_results")
    prefix, suffix = "nash_mixture_round_", ".json"

    if round_num is not None:
        nash_file = os.path.join(psro_dir, f"{prefix}{round_num}{suffix}")
        if not os.path.exists(nash_file):
            raise FileNotFoundError(f"找不到 nash_mixture_round_{round_num}.json: {nash_file}")
    else:
        # Discover the latest round. Only names whose suffix is purely numeric
        # are considered, so stray files such as
        # "nash_mixture_round_backup.json" are ignored instead of crashing the
        # int() conversion.
        candidates = []
        for name in os.listdir(psro_dir):
            if name.startswith(prefix) and name.endswith(suffix):
                stem = name[len(prefix):-len(suffix)]
                if stem.isdecimal():
                    candidates.append((int(stem), name))
        if not candidates:
            raise FileNotFoundError(f"找不到任何 nash_mixture_round_*.json 文件在: {psro_dir}")
        round_num, latest_name = max(candidates)
        nash_file = os.path.join(psro_dir, latest_name)
        print(f"  使用最新的 nash_mixture 文件: nash_mixture_round_{round_num}.json")

    # Explicit encoding keeps this consistent with how pools.json is read.
    with open(nash_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Force a float array: a mixture serialized with integer entries
    # (e.g. [1, 0]) would otherwise yield an integer array and break later
    # fractional re-weighting.
    return np.array(data["generator_mixture"]["sigma_g"], dtype=float), round_num


def adjust_sigma_g_with_min_ratio(
    sigma_g_full: np.ndarray,
    min_simple_ratio: float,
    psro_use_latest_only: bool
) -> np.ndarray:
    """Re-weight ``sigma_g`` so that the first generator gets ``min_simple_ratio``.

    The original author notes this mirrors ``controller._evolve_solver_br``
    (not visible from this file).

    Args:
        sigma_g_full: Mixture weights over all generators; index 0 is the
            entry that receives ``min_simple_ratio``.
        min_simple_ratio: Weight assigned to index 0.
        psro_use_latest_only: When true, no adjustment is made and a copy of
            the input is returned.

    Returns:
        A new float array with the same length as the input:
        * unchanged copy when ``psro_use_latest_only`` is true or the input
          is empty;
        * ``[1.0]`` when there is a single generator;
        * otherwise ``min_simple_ratio`` at index 0 and the remaining
          ``1 - min_simple_ratio`` distributed over indices 1.. proportionally
          to their original weights (uniformly if those weights sum to 0).
    """
    # Always work on a float copy. np.zeros_like on an integer-typed mixture
    # (e.g. parsed from JSON "[1, 0, 0]") would create an integer array and
    # silently truncate every fractional weight to 0.
    sigma = np.array(sigma_g_full, dtype=float)

    if psro_use_latest_only:
        return sigma
    if sigma.size == 0:
        # Nothing to re-weight; avoid indexing into an empty array.
        return sigma
    if len(sigma) == 1:
        return np.array([1.0])

    remaining_weight = 1.0 - min_simple_ratio

    evolved = sigma[1:]
    evolved_total = evolved.sum()
    if evolved_total > 0:
        evolved_share = evolved / evolved_total
    else:
        # No usable mass on the evolved generators: spread it evenly.
        evolved_share = np.full(len(evolved), 1.0 / len(evolved))

    adjusted = np.empty_like(sigma)
    adjusted[0] = min_simple_ratio
    adjusted[1:] = evolved_share * remaining_weight
    return adjusted


def find_h_best(experiment_dir: str, round_num: Optional[int] = None) -> Tuple[str, float]:
    """Find the solver with the lowest score against the adjusted generator mixture.

    Steps: load the experiment config and pools, load ``sigma_g`` for the
    requested round, align/re-weight it (``min_simple_ratio``), drop
    generators with negligible weight, then score every solver in the pool
    with ``BPOnlineSolverProblem.evaluate_batch`` and pick the minimum.

    Args:
        experiment_dir: Root directory of a PSRO experiment.
        round_num: Round whose ``sigma_g`` is used; ``None`` selects the
            latest available round.

    Returns:
        ``(best_solver_id, best_score)`` where the id is the ``program_id`` of
        the best solver in ``pools.json``.

    Raises:
        FileNotFoundError: Propagated from the pool / nash-mixture loaders.
        ValueError: If the pool has no solvers or no usable generator.
    """
    print(f"\n{'='*60}")
    print(f"查找 h_best for experiment: {os.path.basename(experiment_dir)}")
    print(f"{'='*60}")

    # Load the experiment configuration object.
    cfg = load_experiment_config(experiment_dir)

    # Read the oracle settings straight from config.json; they are only
    # copied onto cfg when cfg does not already define them.
    config_json_path = os.path.join(experiment_dir, "config.json")
    oracle_type = 'lb'  # default for BP
    oracle_timeout = 0  # default for BP
    if os.path.exists(config_json_path):
        with open(config_json_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
        oracle_type = config_data.get('oracle_type', oracle_type)
        oracle_timeout = config_data.get('oracle_timeout', oracle_timeout)

    if not hasattr(cfg, 'oracle_type'):
        cfg.oracle_type = oracle_type
    if not hasattr(cfg, 'oracle_timeout'):
        cfg.oracle_timeout = oracle_timeout

    # NOTE(review): the two oracle lines below report the config.json values,
    # which may differ from cfg.* if cfg already defined them; the values the
    # problem actually uses are printed again before evaluation.
    print(f"  配置加载成功")
    print(f"    oracle_type: {oracle_type}")
    print(f"    oracle_timeout: {oracle_timeout}s")
    print(f"    bp_capacity: {getattr(cfg, 'bp_capacity', 100)}")
    print(f"    bp_num_items: {getattr(cfg, 'bp_num_items', 100)}")
    print(f"    eoh_eval_n_instances: {cfg.eoh_eval_n_instances}")

    # Load the solver / generator pools.
    pools_data = load_pools(experiment_dir)
    solvers = pools_data.get("solvers", [])
    generators = pools_data.get("generators", [])
    print(f"  Pools加载成功: {len(solvers)} solvers, {len(generators)} generators")

    # Fail fast with a clear message instead of a cryptic argmin/indexing
    # error after the expensive evaluation setup.
    if not solvers:
        raise ValueError("没有可评估的 solver")
    if not generators:
        raise ValueError("没有有效的 generator")

    # Load sigma_g for the requested (or latest) round.
    sigma_g_full, actual_round = load_sigma_g(experiment_dir, round_num)
    print(f"  使用 round {actual_round} 的 sigma_g")

    # Work in floating point: an integer-typed mixture would have its
    # fractional re-weighted entries truncated to 0 further down.
    sigma_g_full = np.asarray(sigma_g_full, dtype=float)

    # Align sigma_g with the generator pool (zero-pad or truncate), then
    # renormalize. Skip the division when no mass is left to avoid NaNs.
    n_generators = len(generators)
    if len(sigma_g_full) != n_generators:
        if len(sigma_g_full) < n_generators:
            sigma_g_full = np.pad(sigma_g_full, (0, n_generators - len(sigma_g_full)), 'constant')
        else:
            sigma_g_full = sigma_g_full[:n_generators]
        total_weight = sigma_g_full.sum()
        if total_weight > 0:
            sigma_g_full = sigma_g_full / total_weight

    # Apply the min_simple_ratio adjustment.
    min_simple_ratio = getattr(cfg, 'min_simple_ratio', 0.3)
    psro_use_latest_only = getattr(cfg, 'psro_use_latest_only', False)
    adjusted_sigma_g = adjust_sigma_g_with_min_ratio(sigma_g_full, min_simple_ratio, psro_use_latest_only)

    # Keep only generators that carry non-negligible weight and have code.
    weight_threshold = 1e-3
    generator_codes = []
    generator_ids = []
    filtered_weights = []

    for idx, (gen, weight) in enumerate(zip(generators, adjusted_sigma_g)):
        if weight > weight_threshold and gen.get("code"):
            generator_codes.append(gen["code"])
            generator_ids.append(idx)
            filtered_weights.append(weight)

    if not generator_codes:
        # Nothing survived the filter: fall back to the first generator alone.
        if generators[0].get("code"):
            generator_codes = [generators[0]["code"]]
            generator_ids = [0]
            filtered_weights = [1.0]
        else:
            raise ValueError("没有有效的 generator")

    # Renormalize the surviving weights so they sum to 1.
    filtered_weights = np.array(filtered_weights)
    filtered_weights = filtered_weights / filtered_weights.sum()

    # Collect solver code and ids in pool order.
    solver_codes = [s["code"] for s in solvers]
    solver_ids = [s["program_id"] for s in solvers]

    print(f"\n  评估 {len(solver_codes)} 个solvers...")

    # Build the evaluation problem; per the original author, capacity and
    # num_items are taken from the config rather than passed explicitly.
    problem = BPOnlineSolverProblem(
        config=cfg,
        generator_codes=generator_codes,
        generator_ids=generator_ids,
        generator_weights=filtered_weights,
        n_inst_eva=cfg.eoh_eval_n_instances
    )

    # Score every solver in one batch.
    print(f"  开始批量评估（使用oracle_type={problem.gap_oracle}, timeout={problem.gap_oracle_timeout}s）...")
    scores = problem.evaluate_batch(solver_codes)

    # Lower score is better.
    # NOTE(review): np.argmin selects a NaN entry if one is present; confirm
    # evaluate_batch never returns NaN for failed solvers.
    best_idx = int(np.argmin(scores))
    best_solver_id = solver_ids[best_idx]
    best_score = scores[best_idx]

    print(f"\n  评估结果:")
    for i, (solver_id, score) in enumerate(zip(solver_ids, scores)):
        marker = " <-- BEST" if i == best_idx else ""
        print(f"    {solver_id}: {score:.6f}{marker}")

    print(f"\n  ✅ 找到 h_best: {best_solver_id} (score: {best_score:.6f})")

    return best_solver_id, best_score


def load_trend_gap_table(experiment_dir: str) -> pd.DataFrame:
    """Read ``<experiment_dir>/test/trend_gap_table.csv`` into a DataFrame.

    If the file has no ``solver`` column it is re-read with the first column
    as the index, the index is turned back into a column, and a resulting
    ``index`` column is renamed to ``solver``.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    csv_path = os.path.join(experiment_dir, "test", "trend_gap_table.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"找不到 trend_gap_table.csv: {csv_path}")

    table = pd.read_csv(csv_path)
    if 'solver' in table.columns:
        return table

    # Fallback layout: treat the first column as the index and restore it as
    # a regular column. rename() leaves the frame unchanged when there is no
    # 'index' column to rename.
    reindexed = pd.read_csv(csv_path, index_col=0).reset_index()
    return reindexed.rename(columns={'index': 'solver'})


def extract_h_best_row(experiment_dir: str, round_num: Optional[int] = None) -> Tuple[pd.Series, str]:
    """Return the ``trend_gap_table.csv`` row that belongs to the experiment's h_best.

    The best solver id is determined with ``find_h_best`` and then looked up
    in the ``solver`` column using progressively looser matching:

    1. exact equality with the raw column values;
    2. for ids of the form ``h<number>``: string equality with the id itself
       or with its canonical ``h<int>`` spelling;
    3. literal substring containment of the id in the solver name.

    The first matching row wins in every step.

    Args:
        experiment_dir: Root directory of a PSRO experiment.
        round_num: Round passed through to ``find_h_best``.

    Returns:
        ``(row, experiment_name)``; ``row`` is an independent copy so callers
        may add fields to it, ``experiment_name`` is the directory basename.

    Raises:
        ValueError: If no row matches the h_best id.
    """
    h_best_id, _ = find_h_best(experiment_dir, round_num=round_num)
    df = load_trend_gap_table(experiment_dir)

    h_best_row = None
    if h_best_id in df['solver'].values:
        h_best_row = df[df['solver'] == h_best_id].iloc[0]
    else:
        solver_names = df['solver'].astype(str)

        if isinstance(h_best_id, str) and h_best_id.startswith('h'):
            try:
                canonical_id = f'h{int(h_best_id[1:])}'
            except ValueError:
                # Not of the form h<number>; fall through to substring match.
                canonical_id = None
            if canonical_id is not None:
                hits = df[solver_names.isin([h_best_id, canonical_id])]
                if len(hits) > 0:
                    h_best_row = hits.iloc[0]

        if h_best_row is None:
            # regex=False: the id is data, not a pattern. With the default
            # regex matching, ids containing characters such as '.', '(' or
            # '+' would mis-match or raise.
            hits = df[solver_names.str.contains(str(h_best_id), na=False, regex=False)]
            if len(hits) > 0:
                h_best_row = hits.iloc[0]

    if h_best_row is None:
        raise ValueError(f"在 trend_gap_table.csv 中找不到 h_best={h_best_id} 对应的行")

    exp_name = os.path.basename(experiment_dir.rstrip('/'))
    # Return a copy (as extract_from_trend_gap_table does) so callers can
    # assign new fields without writing into a slice of the table.
    return h_best_row.copy(), exp_name


def extract_from_trend_gap_table(experiment_dir: str) -> Tuple[pd.Series, str]:
    """Return a copy of the last row of ``trend_gap_table.csv`` (used for EoH runs).

    Returns:
        ``(last_row, experiment_name)`` where ``experiment_name`` is the
        basename of ``experiment_dir``.

    Raises:
        ValueError: If the table contains no rows.
    """
    table = load_trend_gap_table(experiment_dir)
    if table.shape[0] == 0:
        raise ValueError(f"trend_gap_table.csv 为空: {experiment_dir}/test/")

    final_row = table.iloc[-1].copy()
    return final_row, os.path.basename(experiment_dir.rstrip('/'))


def _process_experiments(exp_dirs, label, extract, success_message):
    """Extract one result row per experiment directory, logging progress.

    Args:
        exp_dirs: Experiment directory paths as given on the command line.
        label: Experiment kind shown in the progress line (e.g. "PSRO").
        extract: Callable ``(experiment_dir) -> (row, name)``.
        success_message: Callable ``(row) -> str`` printed after a row has
            been collected.

    Returns:
        List of rows (each tagged with an ``experiment`` field). Directories
        that do not exist or whose extraction fails are reported and skipped.
    """
    import traceback

    rows = []
    for exp_dir in exp_dirs:
        experiment_dir = os.path.abspath(exp_dir)
        if not os.path.exists(experiment_dir):
            print(f"⚠️  实验目录不存在: {experiment_dir}")
            continue

        exp_name = os.path.basename(experiment_dir.rstrip('/'))
        print(f"\n处理{label}实验: {exp_name}")

        # Top-level boundary: one broken experiment must not abort the whole
        # batch, so report the failure with its traceback and move on.
        try:
            row, _ = extract(experiment_dir)
            # Copy before tagging so we never write into a slice of the
            # source table.
            row = row.copy()
            row['experiment'] = exp_name
            rows.append(row)
            print(success_message(row))
        except Exception as e:
            print(f"  ❌ 失败: {e}")
            traceback.print_exc()
    return rows


def main():
    """Command-line entry point.

    For every ``--experiments`` directory the h_best row is computed, for
    every ``--eoh`` directory the last row of its trend-gap table is taken;
    all collected rows are written to a single CSV file.
    """
    parser = argparse.ArgumentParser(
        description="BP Online专用的find_h_best脚本：根据sigma_g找到表现最好的solver"
    )
    parser.add_argument(
        '--experiments',
        type=str,
        nargs='+',
        default=[],
        help='PSRO实验目录路径（可以多个），对这些实验运行find_h_best'
    )
    parser.add_argument(
        '--eoh',
        type=str,
        nargs='+',
        default=[],
        help='EoH实验目录路径（可以多个），对这些实验取trend_gap_table最后一行'
    )
    parser.add_argument(
        '--round',
        type=int,
        default=None,
        help='从哪个round加载sigma_g（仅用于experiments，默认：最新的round）'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='输出CSV文件路径（默认：第一个experiment目录下的test/h_best_results.csv）'
    )

    args = parser.parse_args()

    if not args.experiments and not args.eoh:
        parser.error("至少需要指定 --experiments 或 --eoh 之一")

    all_rows = []

    # PSRO experiments: evaluate the solver pool and pick h_best.
    all_rows.extend(_process_experiments(
        args.experiments,
        "PSRO",
        lambda exp: extract_h_best_row(exp, round_num=args.round),
        lambda row: "  ✅ 完成",
    ))

    # EoH experiments: take the last row of trend_gap_table.
    all_rows.extend(_process_experiments(
        args.eoh,
        "EoH",
        extract_from_trend_gap_table,
        lambda row: f"  ✅ 完成，solver: {row['solver']}",
    ))

    if not all_rows:
        print("\n❌ 没有成功处理任何实验")
        return

    # Combine all rows into one table.
    result_df = pd.DataFrame(all_rows)

    # Resolve the output path: explicit --output, else the first PSRO
    # experiment, else the first EoH experiment.
    if args.output:
        output_path = args.output
    elif args.experiments:
        output_path = os.path.join(os.path.abspath(args.experiments[0]), 'test', 'h_best_results.csv')
    else:
        output_path = os.path.join(os.path.abspath(args.eoh[0]), 'test', 'h_best_results.csv')

    # A bare filename (e.g. --output results.csv) has an empty dirname, and
    # os.makedirs('') raises; only create the directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    result_df.to_csv(output_path, index=False)
    print(f"\n✅ 结果已保存到: {output_path}")
    print(f"   共 {len(result_df)} 个实验")


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()






