#!/usr/bin/env python3
"""Run PSRO experiments with experiment management and monitoring."""

import argparse
import os
import time
import json
import shutil
from datetime import datetime
from typing import Optional

import sys
import os

# Add the project root (the grandparent of this script's directory) to the Python
# path so the ``heupsro`` package below is importable when run as a script.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

# Thread-pool environment variables generally only take effect if they are set
# before the native libraries initialize, so this runs before the ``heupsro``
# imports below. ``setdefault`` keeps any value the caller already exported.
_SINGLE_THREAD_ENV = {
    # Limit BLAS intra-op threads to avoid thread explosion and resource contention
    "OPENBLAS_NUM_THREADS": "1",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    # Also disable dynamic threading and set other vendor libs to single-thread
    "OMP_DYNAMIC": "FALSE",
    "MKL_DYNAMIC": "FALSE",
    "VECLIB_MAXIMUM_THREADS": "1",
    "BLIS_NUM_THREADS": "1",
    # Make joblib behavior explicit for better stability across platforms
    "JOBLIB_START_METHOD": "spawn",
}
for _name, _value in _SINGLE_THREAD_ENV.items():
    os.environ.setdefault(_name, _value)

# Use memory-backed tmp for faster IO, but only where it actually exists and is
# writable (typically Linux). Elsewhere (macOS, Windows, restricted containers)
# leave TMPDIR untouched so this process and its children keep a usable temp dir.
_SHM_DIR = "/dev/shm"
if os.path.isdir(_SHM_DIR) and os.access(_SHM_DIR, os.W_OK | os.X_OK):
    os.environ.setdefault("TMPDIR", _SHM_DIR)

from heupsro.core.config import HeuPSROConfig
from heupsro.core.controller import HeuPSROController
from heupsro.core.resume import clean_files_after_round


def create_experiment_dir(experiment_name: str, base_dir: str = "experiments") -> str:
    """Create a fresh, timestamped experiment directory and return its path.

    The directory is ``<base_dir>/<experiment_name>_<YYYYmmdd_HHMMSS>``. If that
    path already exists (e.g. two runs with the same name started within the
    same second), a numeric suffix ``_1``, ``_2``, ... is appended so concurrent
    runs never share, and overwrite, each other's directory.

    Args:
        experiment_name: Prefix for the directory name. It may contain path
            separators, in which case intermediate directories are created.
        base_dir: Parent directory for all experiments (created if missing).

    Returns:
        Path of the newly created directory (relative if ``base_dir`` is).
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = os.path.join(base_dir, f"{experiment_name}_{timestamp}")

    parent = os.path.dirname(stem)
    if parent:
        os.makedirs(parent, exist_ok=True)

    candidate = stem
    suffix = 0
    while True:
        try:
            # os.mkdir is atomic: it fails if another run created the path first.
            os.mkdir(candidate)
            return candidate
        except FileExistsError:
            suffix += 1
            candidate = f"{stem}_{suffix}"


def save_experiment_config(cfg: HeuPSROConfig, exp_dir: str) -> None:
    """Persist the experiment configuration as ``config.json`` inside ``exp_dir``.

    Every instance attribute of ``cfg`` is written (pretty-printed with a
    2-space indent), so newly added config fields are picked up automatically.
    An existing ``config.json`` is overwritten.
    """
    destination = os.path.join(exp_dir, "config.json")
    with open(destination, "w") as out_file:
        snapshot = cfg.__dict__  # instance attributes, i.e. the dataclass fields
        json.dump(snapshot, out_file, indent=2)


# Top-level bookkeeping files carried over verbatim when cloning an experiment.
# console.log is included so the new run's console output is appended after the old one.
_COPIED_TOP_LEVEL_FILES = (
    "config.json",
    "progress.log",
    "console.log",
    "experiment_metadata.json",
    "diversity_report.json",
)

# EoH sub-directories, mapped to the checkpoint "eoh_state" key that records the
# generation each one had reached at the checkpointed round.
_EOH_GENERATION_KEYS = {
    "solver_eoh": "solver_eoh_generation",
    "generator_eoh": "generator_eoh_generation",
}

_POP_FILE_PREFIX = "population_generation_"
_POP_FILE_SUFFIX = ".json"


def _replace_tree(source_dir: str, target_dir: str) -> None:
    """Copy ``source_dir`` to ``target_dir``, deleting any existing target first."""
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    shutil.copytree(source_dir, target_dir)


def _read_checkpoint_generations(source_exp_dir: str, up_to_round: int) -> dict:
    """Return ``{eoh_dir: max_generation or None}`` from the round's checkpoint.

    Best effort: a missing or unreadable checkpoint yields ``None`` limits (so no
    population files are pruned) and prints a warning.
    """
    no_limits = {eoh_dir: None for eoh_dir in _EOH_GENERATION_KEYS}
    checkpoint_path = os.path.join(
        source_exp_dir, "psro_results", f"checkpoint_round_{up_to_round}.json"
    )
    if not os.path.exists(checkpoint_path):
        print(f"  ⚠️ Warning: {checkpoint_path} not found; EoH population files will not be pruned")
        return no_limits
    try:
        with open(checkpoint_path, 'r') as f:
            eoh_state = json.load(f).get("eoh_state", {})
        limits = {eoh_dir: eoh_state.get(key, None) for eoh_dir, key in _EOH_GENERATION_KEYS.items()}
    except Exception as e:
        print(f"  ⚠️ Warning: Could not read checkpoint to get generation limits: {e}")
        return no_limits
    print(f"  📊 Checkpoint round {up_to_round}: solver_gen={limits['solver_eoh']}, "
          f"generator_gen={limits['generator_eoh']}")
    return limits


def _prune_population_files(pops_dir: str, max_generation: int, label: str) -> None:
    """Delete ``population_generation_<n>.json`` files in ``pops_dir`` with n > max_generation."""
    for pop_file in os.listdir(pops_dir):
        if not (pop_file.startswith(_POP_FILE_PREFIX) and pop_file.endswith(_POP_FILE_SUFFIX)):
            continue
        try:
            gen_num = int(pop_file.split("_")[2].split(".")[0])
        except (ValueError, IndexError):
            # If we can't parse the generation number, keep the file
            continue
        if gen_num > max_generation:
            os.remove(os.path.join(pops_dir, pop_file))
            print(f"  🗑️  Removed {label}/results/pops/{pop_file} (gen {gen_num} > {max_generation})")


def _copy_psro_results(source_exp_dir: str, target_exp_dir: str) -> None:
    """Mirror ``psro_results/`` from source to target; the target always exists afterwards."""
    source_psro = os.path.join(source_exp_dir, "psro_results")
    target_psro = os.path.join(target_exp_dir, "psro_results")
    if os.path.isdir(source_psro):
        # Replace wholesale; rounds after the target are cleaned up afterwards.
        _replace_tree(source_psro, target_psro)
        print("Copied psro_results directory")
    else:
        os.makedirs(target_psro, exist_ok=True)


def copy_experiment_state(source_exp_dir: str, target_exp_dir: str, up_to_round: int) -> None:
    """
    Copy experiment state from source to target, including all files up to specified round.

    Top-level logs/metadata, the EoH directories, ``logs/`` and ``psro_results/``
    are copied. EoH population files newer than the generation recorded in
    ``checkpoint_round_<up_to_round>.json`` are deleted from the copy. Finally
    ``clean_files_after_round`` removes artifacts of later rounds from the target.

    Args:
        source_exp_dir: Source experiment directory (must exist)
        target_exp_dir: Target experiment directory (will be created if doesn't exist)
        up_to_round: Copy all checkpoints and state up to this round (inclusive)

    Raises:
        FileNotFoundError: ``source_exp_dir`` does not exist. Without this the caller
            would find no saved state and silently start a fresh experiment.
        ValueError: source and target are the same directory. Copying would
            ``rmtree`` the source's own sub-directories.
    """
    if not os.path.isdir(source_exp_dir):
        raise FileNotFoundError(f"Source experiment directory not found: {source_exp_dir}")
    if os.path.realpath(source_exp_dir) == os.path.realpath(target_exp_dir):
        raise ValueError(f"Source and target experiment directories are the same: {source_exp_dir}")

    os.makedirs(target_exp_dir, exist_ok=True)

    for filename in _COPIED_TOP_LEVEL_FILES:
        source_file = os.path.join(source_exp_dir, filename)
        if os.path.exists(source_file):
            shutil.copy2(source_file, os.path.join(target_exp_dir, filename))

    # Copy EoH directories, then drop population files beyond the checkpointed generation.
    generation_limits = _read_checkpoint_generations(source_exp_dir, up_to_round)
    for eoh_dir, max_generation in generation_limits.items():
        source_eoh = os.path.join(source_exp_dir, eoh_dir)
        if not os.path.exists(source_eoh):
            continue
        target_eoh = os.path.join(target_exp_dir, eoh_dir)
        _replace_tree(source_eoh, target_eoh)
        pops_dir = os.path.join(target_eoh, "results", "pops")
        if max_generation is not None and os.path.exists(pops_dir):
            _prune_population_files(pops_dir, max_generation, eoh_dir)

    source_logs = os.path.join(source_exp_dir, "logs")
    if os.path.exists(source_logs):
        _replace_tree(source_logs, os.path.join(target_exp_dir, "logs"))

    _copy_psro_results(source_exp_dir, target_exp_dir)

    # Now clean files after target round (using the resume cleanup function)
    print(f"Cleaning files from rounds after {up_to_round}...")
    clean_files_after_round(target_exp_dir, up_to_round)

    print(f"Copied experiment state up to round {up_to_round} from {source_exp_dir} to {target_exp_dir}")


def log_experiment_progress(exp_dir: str, round_num: int, controller: "HeuPSROController",
                            start_time: float, end_time: float) -> None:
    """Append one JSON line describing a finished round to ``<exp_dir>/progress.log``.

    The entry records pool sizes, the utility-matrix shape, the current meta-game
    solution, and the ids of every pool member. The solution is computed with the
    solver named by ``controller.cfg.meta_game_solver``: ``"alpha_rank"`` selects
    Alpha-Rank, anything else (default ``"ne"``) selects Nash equilibrium. A short
    summary is also printed.

    Args:
        exp_dir: Experiment directory that holds ``progress.log``.
        round_num: Round number recorded in the entry.
        controller: Controller whose ``pools`` and ``meta`` are summarized.
        start_time: Round start time, as returned by ``time.time()``.
        end_time: Round end time, as returned by ``time.time()``.
    """
    log_path = os.path.join(exp_dir, "progress.log")
    duration = end_time - start_time
    pools = controller.pools
    utility_matrix_shape = controller.meta.utilities.shape

    # Use the same meta-game solver as iterate_one_round. Resolve it and its display
    # label once so the log message and the console output agree.
    cfg = getattr(controller, "cfg", None)
    meta_solver = getattr(cfg, "meta_game_solver", "ne")
    solver_label = "Alpha-Rank" if meta_solver == "alpha_rank" else "NE"

    try:
        if meta_solver == "alpha_rank":
            sigma_h, sigma_g = controller.meta.solve_alpha_rank(
                alpha=getattr(cfg, "alpha_rank_alpha", 15.0),
                num_iters=getattr(cfg, "alpha_rank_num_iters", 10_000),
                tol=getattr(cfg, "alpha_rank_tol", 1e-10),
            )
        else:
            sigma_h, sigma_g = controller.meta.solve_ne()
        ne_info = f"σ_H={sigma_h.tolist()}, σ_G={sigma_g.tolist()}"
    except Exception as e:
        # Progress logging is best effort: record why no solution is shown
        # instead of aborting the round.
        ne_info = f"{solver_label} not available: {e}"

    log_entry = {
        "round": round_num,
        "timestamp": datetime.now().isoformat(),
        "duration_seconds": duration,
        "n_solvers": pools.n_solvers,
        "n_generators": pools.n_generators,
        "utility_matrix_shape": utility_matrix_shape,
        "ne_info": ne_info,
        "pool_info": {
            "solvers": [{"id": s.program_id, "source": s.metadata.get("source", "unknown")}
                        for s in pools.solver_pool],
            "generators": [{"id": g.program_id, "algorithm": g.algorithm, "has_code": bool(g.code)}
                           for g in pools.generator_pool],
        },
    }

    # One JSON object per line, appended so earlier rounds are preserved.
    with open(log_path, "a") as f:
        f.write(json.dumps(log_entry) + "\n")

    print(f"Round {round_num} completed in {duration:.2f}s")
    print(f"  {solver_label}: {ne_info}")


class _Tee:
    """File-like wrapper that mirrors every write to ``stream`` and to a log file.

    The log file is opened in append mode with line buffering, so a resumed
    experiment keeps adding to the same ``console.log``. Attributes the wrapper
    does not define itself (``isatty``, ``fileno``, ``encoding``, ...) are
    forwarded to the wrapped stream so libraries that probe ``sys.stdout``
    keep working.
    """

    def __init__(self, stream, file_path):
        self.stream = stream
        self.file = open(file_path, "a", buffering=1, encoding="utf-8")

    def write(self, data):
        written = self.stream.write(data)
        # Objects that captured this Tee (e.g. logging handlers) may still write
        # after close(); keep echoing to the stream instead of raising.
        if not self.file.closed:
            self.file.write(data)
        return written

    def flush(self):
        self.stream.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Close only the log file; the wrapped stream stays open."""
        try:
            self.file.close()
        except Exception:
            pass

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard the instance
        # attributes so a half-initialized object cannot recurse forever.
        if name in ("stream", "file"):
            raise AttributeError(name)
        return getattr(self.stream, name)


def _prepare_experiment_dir(cfg, resume_from, resume_from_experiment, resume_from_round) -> str:
    """Choose the directory for this run, cloning source state for ``resume_from_experiment``."""
    if resume_from_experiment:
        if resume_from_round is None:
            raise ValueError("--resume_from_round must be specified when using --resume_from_experiment")
        if not os.path.isdir(resume_from_experiment):
            # Fail before creating an empty directory; otherwise no saved state
            # would be found and the run would silently start from scratch.
            raise FileNotFoundError(f"Source experiment directory not found: {resume_from_experiment}")

        # Create new experiment directory and copy state from the source experiment
        exp_dir = create_experiment_dir(cfg.experiment_name)
        print(f"Creating new experiment from {resume_from_experiment} in {exp_dir}")
        print(f"Copying state up to round {resume_from_round}")
        copy_experiment_state(resume_from_experiment, exp_dir, resume_from_round)
    elif resume_from:
        exp_dir = resume_from
        print(f"Resuming experiment from {exp_dir}")
    else:
        exp_dir = create_experiment_dir(cfg.experiment_name)
        print(f"Starting new experiment in {exp_dir}")
    return exp_dir


def _restore_or_initialize(controller, exp_dir, resume_requested, resume_from_round) -> int:
    """Resume saved state under ``psro_results/`` (or initialize) and return the last completed round."""
    saved_pools_path = os.path.join(exp_dir, "psro_results", "pools.json")
    if not (resume_requested and os.path.exists(saved_pools_path)):
        if resume_requested:
            print(f"⚠️ No saved state found at {saved_pools_path}; starting from scratch instead of resuming.")
        print("Initializing from scratch...")
        controller.initialize()
        controller.save()  # Save initial state
        return 0

    print("Resuming from saved state...")
    if resume_from_round is not None:
        # Clean files after the target round BEFORE resuming so the on-disk
        # state matches the round being restored.
        print(f"Cleaning files from rounds after {resume_from_round} before resume...")
        clean_files_after_round(exp_dir, resume_from_round)
        controller.resume(exp_dir, resume_from_round=resume_from_round)
        print(f"  Resumed from checkpoint at round {resume_from_round}, will continue from round {resume_from_round + 1}")
        return resume_from_round

    # Auto-detect from latest checkpoint, then clean any inconsistent later files.
    controller.resume(exp_dir)
    if hasattr(controller, 'iteration'):
        detected_round = controller.iteration
        print(f"  🧹 Cleaning files from rounds after {detected_round} to ensure consistency...")
        clean_files_after_round(exp_dir, detected_round)
        print(f"  Resumed from checkpoint at round {detected_round}, will continue from round {detected_round + 1}")
        return detected_round
    return controller.pools.n_solvers + controller.pools.n_generators - 2


def _is_due(iteration, frequency) -> bool:
    """True when ``iteration`` is a multiple of a positive ``frequency``; 0/None/negative disables."""
    return frequency is not None and frequency > 0 and iteration % frequency == 0


def _write_experiment_summary(cfg, controller, exp_dir, final_iteration, total_duration) -> None:
    """Write ``experiment_summary.json`` with pool sizes, timing and key config values."""
    summary = {
        "experiment_name": cfg.experiment_name,
        "total_rounds": final_iteration,
        "total_duration_seconds": total_duration,
        "final_pool_size": {
            "solvers": controller.pools.n_solvers,
            "generators": controller.pools.n_generators
        },
        "final_utility_matrix_shape": controller.meta.utilities.shape,
        "config": {
            "max_rounds": cfg.max_rounds,
            "n_pop": cfg.n_pop,
            "n_cities": cfg.n_cities,
            "seed": cfg.seed
        }
    }
    summary_path = os.path.join(exp_dir, "experiment_summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)


def _run_psro(cfg, exp_dir, resume_requested, resume_from_round) -> None:
    """Build the controller, restore/initialize state, run the rounds and write the summary."""
    controller = HeuPSROController(cfg, out_dir=exp_dir)
    start_round = _restore_or_initialize(controller, exp_dir, resume_requested, resume_from_round)

    # iterate_one_round() increments controller.iteration first, so starting after
    # round N makes the first call execute round N+1.
    print(f"Starting from round {start_round} (will execute rounds {start_round + 1} onwards)")
    print(f"Initial pool: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")

    total_start_time = time.time()
    last_attempted = start_round  # fallback round count if the controller has no ``iteration``
    for round_num in range(start_round, cfg.max_rounds):
        last_attempted = round_num + 1
        round_start_time = time.time()
        try:
            controller.iterate_one_round()
            round_end_time = time.time()
            current_iteration = controller.iteration

            if _is_due(current_iteration, cfg.log_frequency):
                log_experiment_progress(exp_dir, current_iteration, controller,
                                        round_start_time, round_end_time)
            if _is_due(current_iteration, cfg.save_frequency):
                controller.save()
        except KeyboardInterrupt:
            current_iteration = getattr(controller, 'iteration', round_num + 1)
            print(f"\nExperiment interrupted at round {current_iteration}")
            print("Saving current state...")
            controller.save()
            break
        except Exception as e:
            current_iteration = getattr(controller, 'iteration', round_num + 1)
            print(f"Error in round {current_iteration}: {e}")
            print("Saving current state...")
            controller.save()
            raise

    total_duration = time.time() - total_start_time

    # Final save and summary
    controller.save()
    final_iteration = getattr(controller, 'iteration', last_attempted)
    _write_experiment_summary(cfg, controller, exp_dir, final_iteration, total_duration)

    print(f"\n{'='*50}")
    print("Experiment completed!")
    print(f"Total duration: {total_duration:.2f} seconds")
    print(f"Final pool: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")
    print(f"Results saved to: {exp_dir}")
    print(f"{'='*50}")


def run_experiment(cfg: HeuPSROConfig, resume_from: Optional[str] = None,
                   resume_from_experiment: Optional[str] = None,
                   resume_from_round: Optional[int] = None) -> str:
    """
    Run PSRO experiment with monitoring and checkpointing.

    While the experiment runs, stdout and stderr are mirrored into
    ``<exp_dir>/console.log``. Both streams are restored (and the log file
    closed) when this function returns or raises. Warnings are suppressed
    process-wide via ``warnings.filterwarnings("ignore")``.

    Args:
        cfg: Experiment configuration
        resume_from: Resume from existing experiment directory (same directory)
        resume_from_experiment: Source experiment directory to copy state from
        resume_from_round: Round number to resume from (required with resume_from_experiment)

    Returns:
        The experiment directory.

    Raises:
        ValueError: ``resume_from_experiment`` given without ``resume_from_round``.
        FileNotFoundError: ``resume_from_experiment`` does not exist.
    """
    exp_dir = _prepare_experiment_dir(cfg, resume_from, resume_from_experiment, resume_from_round)

    # Save configuration
    save_experiment_config(cfg, exp_dir)

    # Setup logging: suppress warnings, capture all stdout/stderr to file
    import warnings
    warnings.filterwarnings("ignore")

    log_path = os.path.join(exp_dir, "console.log")
    original_stdout, original_stderr = sys.stdout, sys.stderr
    stdout_tee = _Tee(original_stdout, log_path)
    stderr_tee = _Tee(original_stderr, log_path)
    sys.stdout, sys.stderr = stdout_tee, stderr_tee
    try:
        _run_psro(cfg, exp_dir, bool(resume_from or resume_from_experiment), resume_from_round)
    except Exception:
        import traceback
        # Once the streams are restored below, a traceback printed by the caller or
        # the interpreter no longer reaches console.log, so record it here.
        stderr_tee.file.write(traceback.format_exc())
        raise
    finally:
        sys.stdout, sys.stderr = original_stdout, original_stderr
        stdout_tee.close()
        stderr_tee.close()

    return exp_dir


def _str2bool(v):
    """Parse a permissive boolean CLI value ("yes"/"no", "true"/"false", "1"/"0", ...)."""
    if isinstance(v, bool):
        return v
    v = str(v).lower()
    if v in ("yes", "true", "t", "1", "y"):
        return True
    if v in ("no", "false", "f", "0", "n"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


# (argparse dest, HeuPSROConfig attribute) for every flag that overrides a config field.
_CONFIG_OVERRIDES = (
    ("experiment_name", "experiment_name"),
    ("max_rounds", "max_rounds"),
    ("pop_size", "pop_size"),
    ("n_pop", "n_pop"),
    ("ec_operators", "ec_operators"),
    ("n_cities", "n_cities"),
    ("instance_solver_time_limit", "instance_solver_time_limit"),
    ("tsp_solver_max_iterations", "tsp_solver_max_iterations"),
    ("eoh_eval_n_instances", "eoh_eval_n_instances"),
    ("seed", "seed"),
    ("debug", "debug_mode"),
    ("eval_n_instances", "eval_n_instances"),
    ("eoh_management_strategy", "eoh_management_strategy"),
    ("evolution_context_enabled", "evolution_context_enabled"),
    ("disable_generator_evolution", "disable_generator_evolution"),
    ("psro_use_latest_only", "psro_use_latest_only"),
    ("min_simple_ratio", "min_simple_ratio"),
    ("meta_game_solver", "meta_game_solver"),
    ("alpha_rank_alpha", "alpha_rank_alpha"),
    ("alpha_rank_num_iters", "alpha_rank_num_iters"),
    ("alpha_rank_tol", "alpha_rank_tol"),
)


def _build_arg_parser(default_cfg) -> argparse.ArgumentParser:
    """Create the CLI parser.

    Every flag that maps onto a config field defaults to ``None``, so
    ``_build_config`` can tell "not passed" apart from "passed the default value".
    ``default_cfg`` is only used to render defaults in help texts.
    """
    parser = argparse.ArgumentParser(description="Run PSRO experiments")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to a JSON file with full HeuPSROConfig; CLI flags override fields if provided")
    parser.add_argument("--experiment_name", type=str, default=None,
                        help="Name of the experiment")
    parser.add_argument("--max_rounds", type=int, default=None,
                        help="Maximum number of PSRO rounds")
    parser.add_argument("--pop_size", type=int, default=None,
                        help="Population size for each generation")
    parser.add_argument("--n_pop", type=int, default=None,
                        help="Number of EoH iterations for each BR approximation")
    parser.add_argument("--ec_operators", type=str, nargs='+', default=None,
                        help="EOH operator set, e.g., --ec_operators e1 e2 m1 m2 m3")
    parser.add_argument("--n_cities", type=int, default=None,
                        help="Number of cities in TSP instances")
    parser.add_argument("--instance_solver_time_limit", type=int, default=None,
                        help="Time limit for TSP solver in seconds")
    parser.add_argument("--tsp_solver_max_iterations", type=int, default=None,
                        help="Maximum iterations for local search")
    parser.add_argument("--eoh_eval_n_instances", type=int, default=None,
                        help="Number of instances for EOH individual evaluation")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed")
    parser.add_argument("--debug", action="store_true", default=None,
                        help="Enable debug mode")
    parser.add_argument("--resume", type=str, default=None,
                        help="Resume from existing experiment directory (same directory)")
    parser.add_argument("--resume_from_experiment", type=str, default=None,
                        help="Source experiment directory to copy state from (creates new experiment)")
    parser.add_argument("--resume_from_round", type=int, default=None,
                        help="Round number to resume from (required when using --resume_from_experiment)")
    parser.add_argument("--eval_n_instances", type=int, default=None,
                        help="--eval_n_instances")
    parser.add_argument("--eoh_management_strategy", type=str, default=None,
                        help='EOH management strategy: "pop_greedy" | "pop_diverse"')
    parser.add_argument("--evolution_context_enabled", type=_str2bool, default=None,
                        help="Enable evolution context for prompts (true/false)")
    parser.add_argument("--disable_generator_evolution", type=_str2bool, default=None,
                        help="If true, skip generator BR and keep uniform generator only")
    parser.add_argument("--psro_use_latest_only", type=_str2bool, default=None,
                        help="If true, only use latest strategies in pool for BR (instead of Nash equilibrium mixture)")
    parser.add_argument("--min_simple_ratio", type=float, default=None,
                        help="Lower bound for simple/baseline generator sampling share in PSRO")
    parser.add_argument("--meta_game_solver", type=str, default=None,
                        choices=["ne", "alpha_rank"],
                        help=f"Meta-game solver method: 'ne' (Nash equilibrium) or 'alpha_rank' (default: {default_cfg.meta_game_solver})")
    parser.add_argument("--alpha_rank_alpha", type=float, default=None,
                        help=f"Alpha-Rank selection intensity (default: {default_cfg.alpha_rank_alpha})")
    parser.add_argument("--alpha_rank_num_iters", type=int, default=None,
                        help=f"Alpha-Rank maximum iterations (default: {default_cfg.alpha_rank_num_iters})")
    parser.add_argument("--alpha_rank_tol", type=float, default=None,
                        help=f"Alpha-Rank convergence tolerance (default: {default_cfg.alpha_rank_tol})")
    return parser


def _build_config(args, default_cfg):
    """Build the run config: defaults <- optional ``--config`` JSON <- explicitly passed CLI flags."""
    if args.config:
        with open(args.config, "r") as f:
            cfg_data = json.load(f)
        cfg = HeuPSROConfig(**{**default_cfg.__dict__, **cfg_data})
    else:
        cfg = HeuPSROConfig(**default_cfg.__dict__)

    # Only flags the user actually passed are applied (their parser default is None),
    # so values loaded from --config are not silently reset to built-in defaults.
    for dest, attr in _CONFIG_OVERRIDES:
        value = getattr(args, dest)
        if value is not None:
            setattr(cfg, attr, value)

    # The CLI previously always set this field, with 60 s as the fallback when the
    # config class does not define it; keep that guarantee.
    if not hasattr(cfg, "instance_solver_time_limit"):
        cfg.instance_solver_time_limit = 60

    # Ensure frequent saving/logging unless overridden by config JSON
    if getattr(cfg, "save_frequency", None) is None:
        cfg.save_frequency = 1
    if getattr(cfg, "log_frequency", None) is None:
        cfg.log_frequency = 1
    return cfg


def main():
    """Parse CLI arguments, build the HeuPSROConfig, and run the experiment."""
    default_cfg = HeuPSROConfig()
    parser = _build_arg_parser(default_cfg)
    args = parser.parse_args()

    # Validate arguments before doing any work
    if args.resume_from_experiment and args.resume:
        raise ValueError("Cannot use both --resume and --resume_from_experiment. Use --resume_from_experiment to create a new experiment with copied state.")

    cfg = _build_config(args, default_cfg)

    exp_dir = run_experiment(cfg, resume_from=args.resume,
                             resume_from_experiment=args.resume_from_experiment,
                             resume_from_round=args.resume_from_round)
    print(f"Experiment results saved to: {exp_dir}")


if __name__ == "__main__":
    main()
