#!/usr/bin/env python3
"""Run BP Online PSRO experiments."""

import argparse
import os
import time
import json
import shutil
from datetime import datetime
from typing import Optional

import sys

# Make the repository root importable: walk up four path components from this
# file's absolute location and prepend the result to the import path.
project_root = os.path.abspath(__file__)
for _ in range(4):
    project_root = os.path.dirname(project_root)
sys.path.insert(0, project_root)

# The ``eoh`` modules live under <root>/eoh/src; expose that directory as well,
# but only once.
eoh_src_dir = os.path.join(project_root, 'eoh', 'src')
if eoh_src_dir not in sys.path:
    sys.path.insert(0, eoh_src_dir)

# Keep BLAS / OpenMP / numexpr libraries at one intra-op thread each so that
# parallel workers do not oversubscribe the CPU. ``setdefault`` semantics:
# anything already exported in the environment wins. These are set before the
# project imports below because such runtimes typically read them at load time.
_ENV_DEFAULTS = {
    "OPENBLAS_NUM_THREADS": "1",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "OMP_DYNAMIC": "FALSE",
    "MKL_DYNAMIC": "FALSE",
    "VECLIB_MAXIMUM_THREADS": "1",
    "BLIS_NUM_THREADS": "1",
    # Ask joblib to start workers with the "spawn" method.
    "JOBLIB_START_METHOD": "spawn",
    # Temporary files go to /dev/shm (RAM-backed on typical Linux hosts).
    "TMPDIR": "/dev/shm",
}
for _env_key, _env_value in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_value)

from heupsro.core.controller import HeuPSROController
from heupsro.core.resume import clean_files_after_round
from heupsro.problems.bp_online.config import BPOnlineConfig
from heupsro.problems.bp_online.adapter import BPOnlineProblemAdapter


def _str2bool(v):
    """Convert string to boolean."""
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        v = v.lower()
    if v in ("yes", "true", "t", "1", "y"):
        return True
    elif v in ("no", "false", "f", "0", "n"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def create_experiment_dir(experiment_name: str, base_dir: str = "experiments") -> str:
    """Create a brand-new, timestamped directory for an experiment.

    The directory is named ``{base_dir}/{experiment_name}_{YYYYmmdd_HHMMSS}``.
    The timestamp only has one-second resolution. If that path already exists
    (for example, two runs launched in the same second), a numeric suffix
    ``_1``, ``_2``, ... is appended. Independent runs therefore never share, and
    silently overwrite, the same directory.

    Args:
        experiment_name: Prefix for the directory name.
        base_dir: Parent directory. A relative path is resolved against the
            current working directory. Defaults to ``"experiments"``.

    Returns:
        Path of the newly created directory.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = f"{base_dir}/{experiment_name}_{timestamp}"
    exp_dir = stem
    suffix = 0
    while True:
        try:
            # exist_ok=False makes creation atomic: exactly one caller wins a name.
            os.makedirs(exp_dir)
            return exp_dir
        except FileExistsError:
            suffix += 1
            exp_dir = f"{stem}_{suffix}"


def save_experiment_config(cfg: "BPOnlineConfig", exp_dir: str) -> None:
    """Serialize ``cfg.__dict__`` as indented JSON to ``<exp_dir>/config.json``.

    The JSON is written to a temporary sibling file first and then moved over
    ``config.json`` with :func:`os.replace`. If serialization fails (e.g. a
    field is not JSON-serializable), the previous ``config.json`` stays intact
    instead of being truncated. This matters because an in-place resume
    rewrites the very file it loaded its configuration from.

    Raises:
        TypeError: if a config value cannot be JSON-encoded (re-raised from
            :func:`json.dump`); no partial file is left behind.
    """
    config_path = os.path.join(exp_dir, "config.json")
    tmp_path = config_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cfg.__dict__, f, indent=2)
        os.replace(tmp_path, config_path)
    except BaseException:
        # Discard the partial temp file; the old config.json is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


# Top-level files carried over verbatim (when present) when cloning an experiment.
_TOP_LEVEL_STATE_FILES = (
    "config.json",
    "progress.log",
    "console.log",
    "experiment_metadata.json",
    "diversity_report.json",
)


def _read_eoh_generations(source_exp_dir: str, up_to_round: int) -> dict:
    """Return ``{eoh_dir: max_generation}`` read from the round's checkpoint.

    Generations are ``None`` when the checkpoint is missing or unreadable,
    which disables population pruning for that EOH directory.
    """
    generations = {"solver_eoh": None, "generator_eoh": None}
    checkpoint_path = os.path.join(
        source_exp_dir, "psro_results", f"checkpoint_round_{up_to_round}.json"
    )
    if not os.path.exists(checkpoint_path):
        print(f"   Warning: checkpoint for round {up_to_round} not found; "
              f"EOH populations will not be pruned")
        return generations
    try:
        with open(checkpoint_path, "r") as f:
            checkpoint_data = json.load(f)
        eoh_state = checkpoint_data.get("eoh_state", {})
        generations["solver_eoh"] = eoh_state.get("solver_eoh_generation", None)
        generations["generator_eoh"] = eoh_state.get("generator_eoh_generation", None)
    except Exception as e:
        # Best-effort: an unreadable checkpoint only disables pruning.
        print(f"   Warning: Could not read checkpoint: {e}")
    return generations


def _prune_population_files(pops_dir: str, max_generation: int) -> None:
    """Delete ``population_generation_<k>.json`` files whose ``k`` exceeds ``max_generation``."""
    for pop_file in os.listdir(pops_dir):
        if not (pop_file.startswith("population_generation_") and pop_file.endswith(".json")):
            continue
        try:
            gen_num = int(pop_file.split("_")[2].split(".")[0])
        except (ValueError, IndexError):
            continue  # Unrecognised naming: leave the file alone.
        if gen_num > max_generation:
            os.remove(os.path.join(pops_dir, pop_file))


def _copy_tree_fresh(src: str, dst: str) -> bool:
    """Replace ``dst`` with a full copy of ``src``; return False (no-op) if ``src`` is absent."""
    if not os.path.exists(src):
        return False
    if os.path.exists(dst):
        shutil.rmtree(dst)
    shutil.copytree(src, dst)
    return True


def copy_experiment_state(source_exp_dir: str, target_exp_dir: str, up_to_round: int) -> None:
    """Clone an experiment's on-disk state into ``target_exp_dir`` as of ``up_to_round``.

    Copies the top-level state files, the ``solver_eoh``/``generator_eoh`` and
    ``logs`` directories, and the contents of ``psro_results`` (replacing
    whatever the target already has there). EOH population files from
    generations later than the ones recorded in
    ``psro_results/checkpoint_round_<up_to_round>.json`` are deleted. Finally,
    ``clean_files_after_round`` is applied to the target.

    Raises:
        FileNotFoundError: if ``source_exp_dir`` is not an existing directory.
            Previously this silently produced an empty clone, and the caller
            then started from scratch instead of resuming.
    """
    if not os.path.isdir(source_exp_dir):
        raise FileNotFoundError(f"Source experiment directory not found: {source_exp_dir}")
    os.makedirs(target_exp_dir, exist_ok=True)

    for name in _TOP_LEVEL_STATE_FILES:
        source_file = os.path.join(source_exp_dir, name)
        if os.path.exists(source_file):
            shutil.copy2(source_file, os.path.join(target_exp_dir, name))

    # EOH directories: copy in full, then drop populations newer than the checkpoint.
    max_generations = _read_eoh_generations(source_exp_dir, up_to_round)
    for eoh_dir, max_generation in max_generations.items():
        target_eoh = os.path.join(target_exp_dir, eoh_dir)
        if not _copy_tree_fresh(os.path.join(source_exp_dir, eoh_dir), target_eoh):
            continue
        pops_dir = os.path.join(target_eoh, "results", "pops")
        if max_generation is not None and os.path.exists(pops_dir):
            _prune_population_files(pops_dir, max_generation)

    _copy_tree_fresh(os.path.join(source_exp_dir, "logs"),
                     os.path.join(target_exp_dir, "logs"))

    # psro_results always exists in the target; its contents mirror the source.
    target_psro = os.path.join(target_exp_dir, "psro_results")
    os.makedirs(target_psro, exist_ok=True)
    source_psro = os.path.join(source_exp_dir, "psro_results")
    if os.path.exists(source_psro):
        for item in os.listdir(target_psro):
            item_path = os.path.join(target_psro, item)
            if os.path.isdir(item_path):
                shutil.rmtree(item_path)
            else:
                os.remove(item_path)
        for item in os.listdir(source_psro):
            source_item = os.path.join(source_psro, item)
            target_item = os.path.join(target_psro, item)
            if os.path.isdir(source_item):
                shutil.copytree(source_item, target_item, dirs_exist_ok=True)
            else:
                shutil.copy2(source_item, target_item)

    clean_files_after_round(target_exp_dir, up_to_round)
    print(f" Copied experiment state up to round {up_to_round}")


def log_experiment_progress(exp_dir: str, round_num: int, controller: "HeuPSROController",
                            start_time: float, end_time: float) -> None:
    """Append a one-line JSON progress record to ``<exp_dir>/progress.log`` and print a summary.

    The record holds the round number, a timestamp, the round duration, pool
    sizes, the utility-matrix shape, the current meta-strategies, and the IDs
    of the pool members. Meta-strategies come from the solver selected by
    ``controller.cfg.meta_game_solver`` ("alpha_rank" -> Alpha-Rank,
    otherwise NE). Solver failures are caught and recorded as text, so logging
    never aborts the experiment.

    Args:
        exp_dir: Experiment directory containing ``progress.log``.
        round_num: Round number to record.
        controller: Controller exposing ``pools``, ``meta`` and ``cfg``.
        start_time: Round start time, in ``time.time()`` seconds.
        end_time: Round end time, in ``time.time()`` seconds.
    """
    log_path = os.path.join(exp_dir, "progress.log")

    n_solvers = controller.pools.n_solvers
    n_generators = controller.pools.n_generators
    utility_matrix_shape = controller.meta.utilities.shape

    # Use the same meta-game solver as iterate_one_round. It is resolved before
    # the try so that success and failure messages use the same label.
    meta_solver = getattr(controller.cfg, 'meta_game_solver', 'ne')
    solver_info = "Alpha-Rank" if meta_solver == "alpha_rank" else "NE"
    try:
        if meta_solver == "alpha_rank":
            sigma_h, sigma_g = controller.meta.solve_alpha_rank(
                alpha=getattr(controller.cfg, 'alpha_rank_alpha', 15.0),
                num_iters=getattr(controller.cfg, 'alpha_rank_num_iters', 10_000),
                tol=getattr(controller.cfg, 'alpha_rank_tol', 1e-10),
            )
        else:
            sigma_h, sigma_g = controller.meta.solve_ne()
        ne_info = f"σ_H={sigma_h.tolist()}, σ_G={sigma_g.tolist()}"
    except Exception as e:
        # Best-effort: a meta-solver failure must not break progress logging.
        ne_info = f"{solver_info} not available: {e}"

    log_entry = {
        "round": round_num,
        "timestamp": datetime.now().isoformat(),
        "duration_seconds": end_time - start_time,
        "n_solvers": n_solvers,
        "n_generators": n_generators,
        "utility_matrix_shape": utility_matrix_shape,
        "ne_info": ne_info,
        "pool_info": {
            "solvers": [{"id": s.program_id, "source": s.metadata.get("source", "unknown")}
                        for s in controller.pools.solver_pool],
            "generators": [{"id": g.program_id, "algorithm": g.algorithm, "has_code": bool(g.code)}
                           for g in controller.pools.generator_pool],
        },
    }

    with open(log_path, "a") as f:
        f.write(json.dumps(log_entry) + "\n")

    print(f"Round {round_num} completed in {end_time - start_time:.2f}s")
    print(f"  {solver_info}: {ne_info}")


class _Tee:
    """Mirror writes to a console stream and to an append-mode UTF-8 log file.

    Attributes the Tee does not define itself (``isatty``, ``fileno``,
    ``encoding``, ...) are delegated to the wrapped console stream. Libraries
    that probe ``sys.stdout``/``sys.stderr`` therefore keep working after
    redirection.
    """

    def __init__(self, stream, file_path):
        self.stream = stream
        # Explicit UTF-8: the experiment prints non-ASCII text (e.g. "σ"),
        # which the platform default encoding may not be able to represent.
        self.file = open(file_path, "a", buffering=1, encoding="utf-8")

    def write(self, data):
        self.stream.write(data)
        self.file.write(data)

    def flush(self):
        self.stream.flush()
        self.file.flush()

    def close(self):
        try:
            self.file.close()
        except Exception:
            pass

    def __getattr__(self, name):
        # Only called for attributes not found normally. Guard the two
        # instance attributes to avoid infinite recursion before __init__.
        if name in ("stream", "file"):
            raise AttributeError(name)
        return getattr(self.stream, name)


def _load_config_from_dir(exp_dir: str) -> Optional[BPOnlineConfig]:
    """Rebuild a ``BPOnlineConfig`` from ``<exp_dir>/config.json``.

    Keys that are not dataclass fields are ignored. Fields absent from the
    file keep their dataclass defaults. Returns ``None`` when the file does not exist.
    """
    config_path = os.path.join(exp_dir, "config.json")
    if not os.path.exists(config_path):
        return None
    print(f" Loading configuration from {config_path}")
    with open(config_path, "r") as f:
        config_data = json.load(f)
    cfg = BPOnlineConfig()
    config_fields = {fld.name for fld in cfg.__dataclass_fields__.values()}
    for key, value in config_data.items():
        if key in config_fields:
            setattr(cfg, key, value)
    return cfg


def run_experiment(cfg: BPOnlineConfig, resume_from: Optional[str] = None,
                   resume_from_experiment: Optional[str] = None,
                   resume_from_round: Optional[int] = None) -> str:
    """Run (or resume) a BP Online PSRO experiment and return its directory.

    Start modes:

    * ``resume_from_experiment`` (requires ``resume_from_round``): clone the
      source experiment's state up to that round into a new timestamped
      directory and continue there. The clone's ``config.json`` (when present)
      replaces ``cfg``.
    * ``resume_from``: continue in place in an existing experiment directory.
      Its ``config.json`` replaces ``cfg`` when present.
      ``resume_from_round`` optionally rewinds to that round's checkpoint.
    * neither: start from scratch in a new timestamped directory.

    stdout/stderr are mirrored to ``<exp_dir>/console.log`` from this point
    on, and are not restored afterwards.

    Args:
        cfg: Experiment configuration (may be replaced by a stored one on resume).
        resume_from: Existing experiment directory to continue in place.
        resume_from_experiment: Experiment directory to clone from.
        resume_from_round: Round whose checkpoint to resume from.

    Returns:
        The experiment directory path.

    Raises:
        ValueError: ``resume_from_experiment`` given without ``resume_from_round``.
        FileNotFoundError: the directory to resume/clone from does not exist.
    """
    if resume_from_experiment:
        if resume_from_round is None:
            raise ValueError("--resume_from_round must be specified when using --resume_from_experiment")
        # Fail before creating an (otherwise empty) clone directory.
        if not os.path.isdir(resume_from_experiment):
            raise FileNotFoundError(f"Experiment to resume from not found: {resume_from_experiment}")
        exp_dir = create_experiment_dir(cfg.experiment_name)
        print(f" Creating new experiment from {resume_from_experiment} in {exp_dir}")
        copy_experiment_state(resume_from_experiment, exp_dir, resume_from_round)
        loaded_cfg = _load_config_from_dir(exp_dir)
        if loaded_cfg is not None:
            cfg = loaded_cfg
        else:
            print(f"  Warning: config.json not found in {exp_dir}, using provided config")
    elif resume_from:
        if not os.path.isdir(resume_from):
            raise FileNotFoundError(f"Experiment directory to resume not found: {resume_from}")
        exp_dir = resume_from
        print(f"Resuming experiment from {exp_dir}")
        loaded_cfg = _load_config_from_dir(exp_dir)
        if loaded_cfg is not None:
            cfg = loaded_cfg
        else:
            print(f"  Warning: config.json not found in {exp_dir}, using provided config")
    else:
        exp_dir = create_experiment_dir(cfg.experiment_name)
        print(f" Starting new experiment in {exp_dir}")

    save_experiment_config(cfg, exp_dir)

    # Report the meta-game solver configuration.
    meta_solver = getattr(cfg, 'meta_game_solver', 'ne')
    print(f" Meta-game solver configuration: {meta_solver}")

    # Setup logging: silence warnings and tee console output into console.log.
    import warnings
    warnings.filterwarnings("ignore")

    log_path = os.path.join(exp_dir, "console.log")
    sys.stdout = _Tee(sys.stdout, log_path)
    sys.stderr = _Tee(sys.stderr, log_path)

    # Create problem adapter and initialize controller
    problem_adapter = BPOnlineProblemAdapter(cfg)
    controller = HeuPSROController(cfg, out_dir=exp_dir, problem_adapter=problem_adapter)

    controller_meta_solver = getattr(controller.cfg, 'meta_game_solver', 'ne')
    print(f" Controller meta-game solver: {controller_meta_solver}")

    # Resume or initialize
    saved_pools_path = os.path.join(exp_dir, "psro_results", "pools.json")
    resuming = bool(resume_from or resume_from_experiment)
    if resuming and os.path.exists(saved_pools_path):
        print("Resuming from saved state...")

        if resume_from_round is not None:
            print(f"   Cleaning files from rounds after {resume_from_round} before resume...")
            # For cloned experiments copy_experiment_state already ran this;
            # repeating it is assumed to be idempotent.
            clean_files_after_round(exp_dir, resume_from_round)
            controller.resume(exp_dir, resume_from_round=resume_from_round)
            start_round = resume_from_round
            print(f"   Resumed from checkpoint at round {start_round}, will continue from round {start_round + 1}")
        else:
            controller.resume(exp_dir)
            if hasattr(controller, 'iteration'):
                detected_round = controller.iteration
                print(f"   Cleaning files from rounds after {detected_round} to ensure consistency...")
                clean_files_after_round(exp_dir, detected_round)
                start_round = detected_round
                print(f"   Resumed from checkpoint at round {start_round}, will continue from round {start_round + 1}")
            else:
                start_round = controller.pools.n_solvers + controller.pools.n_generators - 2
    else:
        if resuming:
            print(f"  Warning: no saved pools at {saved_pools_path}; initializing from scratch")
        print("Initializing from scratch...")
        controller.initialize()
        controller.save()
        start_round = 0

    print(f" Starting from round {start_round} (will execute rounds {start_round + 1} onwards)")
    print(f" Initial pool: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")

    # Run PSRO iterations
    total_start_time = time.time()
    # Fallback round count for controllers without an ``iteration`` attribute.
    # It is initialised here so it is defined even if no round runs.
    rounds_done = start_round

    for round_num in range(start_round, cfg.max_rounds):
        rounds_done = round_num + 1
        round_start_time = time.time()

        try:
            controller.iterate_one_round()
            round_end_time = time.time()
            current_iteration = controller.iteration

            if current_iteration % cfg.log_frequency == 0:
                log_experiment_progress(exp_dir, current_iteration, controller,
                                        round_start_time, round_end_time)

            if current_iteration % cfg.save_frequency == 0:
                controller.save()

        except KeyboardInterrupt:
            current_iteration = getattr(controller, 'iteration', rounds_done)
            print(f"\nExperiment interrupted at round {current_iteration}")
            print("Saving current state...")
            controller.save()
            break
        except Exception as e:
            current_iteration = getattr(controller, 'iteration', rounds_done)
            print(f"Error in round {current_iteration}: {e}")
            print("Saving current state...")
            controller.save()
            raise

    total_end_time = time.time()
    total_duration = total_end_time - total_start_time

    controller.save()

    # Save final summary
    final_iteration = getattr(controller, 'iteration', rounds_done)
    summary = {
        "experiment_name": cfg.experiment_name,
        "total_rounds": final_iteration,
        "total_duration_seconds": total_duration,
        "final_pool_size": {
            "solvers": controller.pools.n_solvers,
            "generators": controller.pools.n_generators
        },
        "final_utility_matrix_shape": controller.meta.utilities.shape,
        "config": {
            "max_rounds": cfg.max_rounds,
            "solver_n_pop": getattr(cfg, 'solver_n_pop', getattr(cfg, 'n_pop', 2)),
            "generator_n_pop": getattr(cfg, 'generator_n_pop', getattr(cfg, 'n_pop', 2)),
            "capacity": cfg.capacity,
            "num_items": cfg.num_items,
            "seed": cfg.seed
        }
    }

    summary_path = os.path.join(exp_dir, "experiment_summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n{'='*50}")
    print(f"Experiment completed!")
    print(f"Total duration: {total_duration:.2f} seconds")
    print(f"Final pool: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")
    print(f"Results saved to: {exp_dir}")
    print(f"{'='*50}")

    return exp_dir


def main():
    """Parse the command line, build a ``BPOnlineConfig``, and run the experiment.

    Configuration precedence, lowest to highest:

    1. ``BPOnlineConfig`` dataclass defaults;
    2. values from the JSON file given by ``--config``;
    3. options passed explicitly on the command line.

    ``--config`` is read in a pre-pass so that its values become the argparse
    defaults. Previously every option's dataclass default was applied
    unconditionally after loading the file, which silently discarded the
    file's values for every field with a CLI flag.
    """
    default_cfg = BPOnlineConfig()

    # Pre-pass: find --config only (everything else is left for the real parser).
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--config", type=str, default=None)
    pre_args, _ = pre_parser.parse_known_args()

    if pre_args.config:
        with open(pre_args.config, "r") as f:
            cfg_data = json.load(f)
        base_cfg = BPOnlineConfig(**{**default_cfg.__dict__, **cfg_data})
    else:
        base_cfg = BPOnlineConfig(**default_cfg.__dict__)

    # All option defaults come from base_cfg (dataclass defaults merged with --config).
    parser = argparse.ArgumentParser(description="Run BP Online PSRO experiments")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to a JSON file with full BPOnlineConfig")
    parser.add_argument("--experiment_name", type=str, default=base_cfg.experiment_name)
    parser.add_argument("--max_rounds", type=int, default=base_cfg.max_rounds)
    parser.add_argument("--pop_size", type=int, default=base_cfg.pop_size)
    parser.add_argument("--solver_n_pop", type=int, default=getattr(base_cfg, 'solver_n_pop', getattr(base_cfg, 'n_pop', 2)))
    parser.add_argument("--generator_n_pop", type=int, default=getattr(base_cfg, 'generator_n_pop', getattr(base_cfg, 'n_pop', 2)))
    parser.add_argument("--ec_operators", type=str, nargs='+', default=base_cfg.ec_operators)

    # BP Online-specific arguments
    parser.add_argument("--capacity", type=int, default=base_cfg.capacity)
    parser.add_argument("--num_items", type=int, default=base_cfg.num_items)
    parser.add_argument("--instance_solver_time_limit", type=int, default=getattr(base_cfg, 'instance_solver_time_limit', 5))
    parser.add_argument("--oracle_type", type=str, default=getattr(base_cfg, 'oracle_type', 'lb'),
                        choices=["lb", "none"])
    parser.add_argument("--oracle_timeout", type=int, default=getattr(base_cfg, 'oracle_timeout', 0))

    # Common arguments
    parser.add_argument("--eoh_eval_n_instances", type=int, default=base_cfg.eoh_eval_n_instances)
    parser.add_argument("--seed", type=int, default=base_cfg.seed)
    parser.add_argument("--debug", type=_str2bool, default=base_cfg.debug_mode)
    parser.add_argument("--eoh_management_strategy", type=str, default=base_cfg.eoh_management_strategy,
                        choices=["pop_greedy", "pop_diverse"])
    parser.add_argument("--evolution_context_enabled", type=_str2bool, default=base_cfg.evolution_context_enabled)
    parser.add_argument("--disable_generator_evolution", type=_str2bool, default=base_cfg.disable_generator_evolution)
    parser.add_argument("--psro_use_latest_only", type=_str2bool, default=base_cfg.psro_use_latest_only)
    parser.add_argument("--min_simple_ratio", type=float, default=base_cfg.min_simple_ratio)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--resume_from_experiment", type=str, default=None)
    parser.add_argument("--resume_from_round", type=int, default=None)
    parser.add_argument("--eval_n_instances", type=int, default=base_cfg.eval_n_instances)

    # Meta-game solver arguments
    parser.add_argument("--meta_game_solver", type=str, default=base_cfg.meta_game_solver,
                        choices=["ne", "alpha_rank"],
                        help=f"Meta-game solver method: 'ne' (Nash equilibrium) or 'alpha_rank' (default: {base_cfg.meta_game_solver})")
    parser.add_argument("--alpha_rank_alpha", type=float, default=base_cfg.alpha_rank_alpha,
                        help=f"Alpha-Rank selection intensity (default: {base_cfg.alpha_rank_alpha})")
    parser.add_argument("--alpha_rank_num_iters", type=int, default=base_cfg.alpha_rank_num_iters,
                        help=f"Alpha-Rank maximum iterations (default: {base_cfg.alpha_rank_num_iters})")
    parser.add_argument("--alpha_rank_tol", type=float, default=base_cfg.alpha_rank_tol,
                        help=f"Alpha-Rank convergence tolerance (default: {base_cfg.alpha_rank_tol})")

    args = parser.parse_args()

    if args.resume_from_experiment and args.resume:
        raise ValueError("Cannot use both --resume and --resume_from_experiment")

    # Apply CLI values. Unless given explicitly, each equals base_cfg's value.
    cfg = base_cfg
    cfg.experiment_name = args.experiment_name
    cfg.max_rounds = args.max_rounds
    cfg.pop_size = args.pop_size
    cfg.solver_n_pop = args.solver_n_pop
    cfg.generator_n_pop = args.generator_n_pop
    cfg.ec_operators = args.ec_operators
    cfg.capacity = args.capacity
    cfg.num_items = args.num_items
    cfg.instance_solver_time_limit = args.instance_solver_time_limit
    cfg.oracle_type = args.oracle_type
    cfg.oracle_timeout = args.oracle_timeout
    cfg.eoh_eval_n_instances = args.eoh_eval_n_instances
    cfg.seed = args.seed
    cfg.debug_mode = args.debug
    cfg.eoh_management_strategy = args.eoh_management_strategy
    cfg.evolution_context_enabled = args.evolution_context_enabled
    cfg.disable_generator_evolution = args.disable_generator_evolution
    cfg.psro_use_latest_only = args.psro_use_latest_only
    cfg.min_simple_ratio = args.min_simple_ratio
    cfg.eval_n_instances = args.eval_n_instances
    cfg.meta_game_solver = args.meta_game_solver
    cfg.alpha_rank_alpha = args.alpha_rank_alpha
    cfg.alpha_rank_num_iters = args.alpha_rank_num_iters
    cfg.alpha_rank_tol = args.alpha_rank_tol

    # Both frequencies are used as modulo divisors in run_experiment. Treat a
    # missing, None, or 0 value as "every round" instead of crashing.
    if not getattr(cfg, 'save_frequency', None):
        cfg.save_frequency = 1
    if not getattr(cfg, 'log_frequency', None):
        cfg.log_frequency = 1

    exp_dir = run_experiment(cfg, resume_from=args.resume,
                             resume_from_experiment=args.resume_from_experiment,
                             resume_from_round=args.resume_from_round)
    print(f"Experiment results saved to: {exp_dir}")


if __name__ == "__main__":
    main()

