#!/usr/bin/env python3
"""Run TSP GLS PSRO experiments."""

import argparse
import os
import time
import json
from datetime import datetime
from typing import Optional

import sys

# Make the repository root importable; it is taken to be four ``dirname``
# hops above this file's absolute path.
project_root = os.path.abspath(__file__)
for _ in range(4):
    project_root = os.path.dirname(project_root)
sys.path.insert(0, project_root)

# Also expose eoh/src so modules inside it can be imported by top-level name.
eoh_src_dir = os.path.join(project_root, "eoh", "src")
if eoh_src_dir not in sys.path:
    sys.path.insert(0, eoh_src_dir)

# Runtime defaults applied before the project imports below:
#   * pin BLAS / OpenMP-style libraries to one intra-op thread to avoid thread
#     explosion and resource contention when many workers run in parallel;
#   * default joblib's process start method to "spawn";
#   * point temporary files at /dev/shm.
# ``setdefault`` means any value already exported in the environment wins.
for _env_key, _env_value in (
    ("OPENBLAS_NUM_THREADS", "1"),
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("NUMEXPR_NUM_THREADS", "1"),
    ("OMP_DYNAMIC", "FALSE"),
    ("MKL_DYNAMIC", "FALSE"),
    ("VECLIB_MAXIMUM_THREADS", "1"),
    ("BLIS_NUM_THREADS", "1"),
    ("JOBLIB_START_METHOD", "spawn"),
    ("TMPDIR", "/dev/shm"),
):
    os.environ.setdefault(_env_key, _env_value)

from asro.core.controller import HeuPSROController
from asro.core.resume import clean_files_after_round
from asro.problems.tsp_gls.config import TSPGLSConfig
from asro.problems.tsp_gls.adapter import TSPProblemAdapter


def _str2bool(v):
    """Convert string to boolean."""
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        v = v.lower()
    if v in ("yes", "true", "t", "1", "y"):
        return True
    elif v in ("no", "false", "f", "0", "n"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def create_experiment_dir(experiment_name: str, base_dir: str = "experiments") -> str:
    """Create a new, uniquely named experiment directory and return its path.

    The directory is ``<base_dir>/<experiment_name>_<YYYYmmdd_HHMMSS>``.
    Relative paths are resolved against the current working directory.

    The timestamp only has one-second resolution. To keep two runs with the
    same name, started in the same second, from sharing and overwriting one
    directory, ``_1``, ``_2``, ... is appended until a directory that did not
    exist before is created. Creating the final path component is atomic, so
    this also holds across concurrent processes.

    Args:
        experiment_name: Prefix of the directory name.
        base_dir: Parent directory, created if missing. Defaults to
            ``"experiments"``.

    Returns:
        Path of the newly created directory.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = os.path.join(base_dir, f"{experiment_name}_{timestamp}")
    exp_dir = stem
    suffix = 0
    while True:
        try:
            # Parents are created as needed; raises FileExistsError only if
            # exp_dir itself already exists.
            os.makedirs(exp_dir)
            return exp_dir
        except FileExistsError:
            suffix += 1
            exp_dir = f"{stem}_{suffix}"


def save_experiment_config(cfg: TSPGLSConfig, exp_dir: str) -> None:
    """Persist ``cfg`` as ``<exp_dir>/config.json``.

    Every instance attribute (``cfg.__dict__``) is written, including ones set
    dynamically after construction.

    The JSON text is produced before any file is touched and is then swapped in
    with ``os.replace``. A serialization error or an interrupted write therefore
    cannot leave a truncated ``config.json`` behind. This matters when resuming,
    where an existing ``config.json`` is rewritten and later read back.

    Raises:
        TypeError: if an attribute value is not JSON-serializable. Any existing
            ``config.json`` is left untouched.
    """
    config_path = os.path.join(exp_dir, "config.json")
    payload = json.dumps(cfg.__dict__, indent=2)
    tmp_path = config_path + ".tmp"
    with open(tmp_path, "w") as f:
        f.write(payload)
    # Atomic on POSIX and Windows when both paths are on the same filesystem.
    os.replace(tmp_path, config_path)


def log_experiment_progress(exp_dir: str, round_num: int, controller: HeuPSROController,
                            start_time: float, end_time: float) -> None:
    """Append a JSON progress record for one PSRO round and print a summary.

    The meta-game is re-solved with the method selected by
    ``controller.cfg.meta_game_solver``: ``"alpha_rank"`` uses Alpha-Rank,
    anything else uses ``solve_ne``. The resulting mixtures are then recorded.
    Solving is best-effort: any exception is recorded in the entry's
    ``ne_info`` field instead of aborting the run.

    Args:
        exp_dir: Experiment directory. The record is appended to
            ``<exp_dir>/progress.log`` as one JSON object per line.
        round_num: Round number stored in the record and printed.
        controller: Controller whose pools and meta-game are summarized.
        start_time: Round start timestamp in seconds (``time.time()``).
        end_time: Round end timestamp in seconds (``time.time()``).
    """
    log_path = os.path.join(exp_dir, "progress.log")
    duration = end_time - start_time

    n_solvers = controller.pools.n_solvers
    n_generators = controller.pools.n_generators
    utility_matrix_shape = controller.meta.utilities.shape

    # Use the same meta-game solver as iterate_one_round. The label mirrors
    # the branch actually taken below and is used in the log and stdout.
    meta_solver = getattr(controller.cfg, 'meta_game_solver', 'ne')
    solver_info = "Alpha-Rank" if meta_solver == "alpha_rank" else "NE"

    try:
        if meta_solver == "alpha_rank":
            sigma_h, sigma_g = controller.meta.solve_alpha_rank(
                alpha=getattr(controller.cfg, 'alpha_rank_alpha', 15.0),
                num_iters=getattr(controller.cfg, 'alpha_rank_num_iters', 10_000),
                tol=getattr(controller.cfg, 'alpha_rank_tol', 1e-10),
            )
        else:
            sigma_h, sigma_g = controller.meta.solve_ne()
        ne_info = f"σ_H={sigma_h.tolist()}, σ_G={sigma_g.tolist()}"
    except Exception as e:
        # Progress logging must never abort the experiment; record the failure.
        ne_info = f"{solver_info} not available: {e}"

    log_entry = {
        "round": round_num,
        "timestamp": datetime.now().isoformat(),
        "duration_seconds": duration,
        "n_solvers": n_solvers,
        "n_generators": n_generators,
        "utility_matrix_shape": utility_matrix_shape,
        "ne_info": ne_info,
        "pool_info": {
            "solvers": [{"id": s.program_id, "source": s.metadata.get("source", "unknown")}
                        for s in controller.pools.solver_pool],
            "generators": [{"id": g.program_id, "algorithm": g.algorithm, "has_code": bool(g.code)}
                           for g in controller.pools.generator_pool]
        }
    }

    with open(log_path, "a") as f:
        f.write(json.dumps(log_entry) + "\n")

    print(f"Round {round_num} completed in {duration:.2f}s")
    print(f"  {solver_info}: {ne_info}")


class _Tee:
    """File-like object that mirrors every write to a stream and a log file.

    Used to copy ``sys.stdout`` / ``sys.stderr`` into ``console.log``.
    Attributes not defined here (``isatty``, ``fileno``, ``encoding``, ...)
    are forwarded to the wrapped stream. Code that inspects the standard
    streams therefore does not fail with ``AttributeError``.
    """

    def __init__(self, stream, file_path):
        self.stream = stream
        # Line-buffered append: the log file is flushed at every newline.
        self.file = open(file_path, "a", buffering=1)

    def write(self, data):
        self.stream.write(data)
        self.file.write(data)

    def flush(self):
        self.stream.flush()
        self.file.flush()

    def close(self):
        # Best-effort: only the log file is closed; the wrapped stream is not ours.
        try:
            self.file.close()
        except Exception:
            pass

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard our own two
        # attributes so a partially initialized instance cannot recurse.
        if name in ("stream", "file"):
            raise AttributeError(name)
        return getattr(self.stream, name)


def _load_config_from_dir(exp_dir: str) -> Optional[TSPGLSConfig]:
    """Rebuild a ``TSPGLSConfig`` from ``<exp_dir>/config.json``.

    Starts from a default-constructed config and copies every JSON key that
    names a dataclass field of ``TSPGLSConfig``. Other keys are ignored,
    e.g. attributes that were only set dynamically before saving.

    Returns:
        The rebuilt config, or ``None`` if ``config.json`` does not exist.
    """
    config_path = os.path.join(exp_dir, "config.json")
    if not os.path.exists(config_path):
        return None
    print(f" Loading configuration from {config_path}")
    with open(config_path, "r") as fh:
        config_data = json.load(fh)
    loaded = TSPGLSConfig()
    known_fields = set(loaded.__dataclass_fields__)
    for key, value in config_data.items():
        if key in known_fields:
            setattr(loaded, key, value)
    return loaded


def run_experiment(cfg: TSPGLSConfig, resume_from: Optional[str] = None,
                   resume_from_experiment: Optional[str] = None,
                   resume_from_round: Optional[int] = None) -> str:
    """Run, or resume, a TSP GLS PSRO experiment.

    Modes:
      * ``resume_from_experiment`` set: create a new timestamped directory for
        a run derived from another experiment (requires ``resume_from_round``).
        Copying the source experiment's state is not implemented yet, so
        nothing is loaded from it.
      * ``resume_from`` set: continue in that directory. Its ``config.json``,
        when present, is reloaded and replaces ``cfg``.
      * neither: start a fresh experiment in a new timestamped directory.

    When resuming and ``psro_results/pools.json`` exists, files from rounds
    after the resume point are cleaned via ``clean_files_after_round``.
    Otherwise the controller is initialized from scratch.

    Side effects:
      * writes ``config.json``;
      * from that point on, replaces ``sys.stdout`` and ``sys.stderr`` with
        tees that also append to ``<exp_dir>/console.log``. They are not
        restored on return, so a final traceback is captured too;
      * silences all Python warnings process-wide;
      * saves the controller after every round;
      * writes ``experiment_summary.json``.

    Args:
        cfg: Experiment configuration. It may be replaced by a reloaded one.
        resume_from: Existing experiment directory to continue in.
        resume_from_experiment: Source experiment for a derived run.
        resume_from_round: Round to roll back to before resuming.

    Returns:
        The experiment directory path.

    Raises:
        ValueError: if ``resume_from_experiment`` is given without
            ``resume_from_round``.
        Exception: any error raised during a round is re-raised after the
            controller state has been saved.
    """
    if resume_from_experiment:
        if resume_from_round is None:
            raise ValueError("--resume_from_round must be specified when using --resume_from_experiment")

        exp_dir = create_experiment_dir(cfg.experiment_name)
        print(f" Creating new experiment from {resume_from_experiment} in {exp_dir}")
        # TODO: implement copy_experiment_state (as for TSP/BP). Until then the
        # new directory holds no copied state; say so instead of silently
        # behaving like a fresh run.
        print(f" Warning: copying state from {resume_from_experiment} is not implemented; "
              f"no checkpoint will be loaded from it")
        loaded_cfg = _load_config_from_dir(exp_dir)
        if loaded_cfg is not None:
            cfg = loaded_cfg
    elif resume_from:
        exp_dir = resume_from
        print(f"Resuming experiment from {exp_dir}")
        loaded_cfg = _load_config_from_dir(exp_dir)
        if loaded_cfg is not None:
            cfg = loaded_cfg
        else:
            print(f" Warning: config.json not found in {exp_dir}, using provided config")
    else:
        exp_dir = create_experiment_dir(cfg.experiment_name)
        print(f" Starting new experiment in {exp_dir}")

    save_experiment_config(cfg, exp_dir)

    # Setup logging
    import warnings
    warnings.filterwarnings("ignore")

    log_path = os.path.join(exp_dir, "console.log")
    sys.stdout = _Tee(sys.stdout, log_path)
    sys.stderr = _Tee(sys.stderr, log_path)

    # Create problem adapter and initialize controller
    problem_adapter = TSPProblemAdapter(cfg)
    controller = HeuPSROController(cfg, out_dir=exp_dir, problem_adapter=problem_adapter)

    # Resume or initialize
    saved_pools_path = os.path.join(exp_dir, "psro_results", "pools.json")
    if (resume_from or resume_from_experiment) and os.path.exists(saved_pools_path):
        print("Resuming from saved state...")

        if resume_from_round is not None:
            print(f"   Cleaning files from rounds after {resume_from_round} before resume...")
            clean_files_after_round(exp_dir, resume_from_round)
            controller.resume(exp_dir, resume_from_round=resume_from_round)
            start_round = resume_from_round
            print(f"   Resumed from checkpoint at round {resume_from_round}, will continue from round {resume_from_round + 1}")
        else:
            controller.resume(exp_dir)
            if hasattr(controller, 'iteration'):
                detected_round = controller.iteration
                print(f"   Cleaning files from rounds after {detected_round} to ensure consistency...")
                clean_files_after_round(exp_dir, detected_round)
                start_round = detected_round
                print(f"   Resumed from checkpoint at round {detected_round}, will continue from round {detected_round + 1}")
            else:
                # NOTE(review): heuristic round estimate from pool sizes; verify
                # it matches how each round grows the pools.
                start_round = controller.pools.n_solvers + controller.pools.n_generators - 2
    else:
        print("Initializing from scratch...")
        controller.initialize()
        controller.save()
        start_round = 0

    print(f" Starting from round {start_round} (will execute rounds {start_round + 1} onwards)")
    print(f" Initial pool: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")

    # A config reloaded from config.json keeps only declared dataclass fields,
    # so ``log_frequency`` may be absent if it was only set dynamically.
    # Fall back to logging every round; this also avoids a modulo by 0/None.
    log_frequency = getattr(cfg, "log_frequency", None) or 1

    # Run PSRO iterations
    total_start_time = time.time()
    # Fallback round count when the controller exposes no ``iteration``.
    # It stays defined even if the loop body never runs.
    last_round = start_round

    for round_num in range(start_round, cfg.max_rounds):
        last_round = round_num + 1
        print(f"\n{'='*60}")
        print(f" Starting PSRO Round {round_num + 1}")
        print(f"{'='*60}")

        round_start_time = time.time()

        try:
            controller.iterate_one_round()
            round_end_time = time.time()
            current_iteration = controller.iteration

            # Save after each round (like FSSP)
            controller.save()

            # Log progress if needed
            if current_iteration % log_frequency == 0:
                log_experiment_progress(exp_dir, current_iteration, controller,
                                        round_start_time, round_end_time)

            # Clean up old files after each round (like FSSP)
            clean_files_after_round(exp_dir, controller.iteration)

            print(f" Completed round {round_num + 1}")

        except KeyboardInterrupt:
            current_iteration = controller.iteration if hasattr(controller, 'iteration') else round_num + 1
            print(f"\nExperiment interrupted at round {current_iteration}")
            print("Saving current state...")
            controller.save()
            break
        except Exception as e:
            current_iteration = controller.iteration if hasattr(controller, 'iteration') else round_num + 1
            print(f"Error in round {current_iteration}: {e}")
            print("Saving current state...")
            controller.save()
            raise

    total_end_time = time.time()
    total_duration = total_end_time - total_start_time

    controller.save()

    # Save final summary
    final_iteration = controller.iteration if hasattr(controller, 'iteration') else last_round
    summary = {
        "experiment_name": cfg.experiment_name,
        "total_rounds": final_iteration,
        "total_duration_seconds": total_duration,
        "final_pool_size": {
            "solvers": controller.pools.n_solvers,
            "generators": controller.pools.n_generators
        },
        "final_utility_matrix_shape": controller.meta.utilities.shape,
        "config": {
            "max_rounds": cfg.max_rounds,
            "solver_n_pop": cfg.solver_n_pop,
            "generator_n_pop": cfg.generator_n_pop,
            "n_cities": cfg.n_cities,
            "seed": cfg.seed
        }
    }

    summary_path = os.path.join(exp_dir, "experiment_summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n Experiment completed in {total_duration:.2f} seconds")
    print(f"Final pool: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")
    print(f"Results saved to: {exp_dir}")

    return exp_dir


def _build_arg_parser(base_cfg: TSPGLSConfig, default_cfg: TSPGLSConfig) -> argparse.ArgumentParser:
    """Build the CLI parser, taking option defaults from ``base_cfg``.

    ``base_cfg`` is the configuration the CLI overrides: the ``--config`` file
    merged over the built-in defaults, or just the built-in defaults. An
    option not passed on the command line therefore keeps its ``base_cfg``
    value. ``default_cfg`` holds the built-in defaults and is only used in
    help text for options whose default is ``None``.
    """
    parser = argparse.ArgumentParser(description="Run TSP GLS PSRO experiments")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to a JSON file with full TSPGLSConfig")
    parser.add_argument("--experiment_name", type=str, default=base_cfg.experiment_name)
    parser.add_argument("--max_rounds", type=int, default=base_cfg.max_rounds)
    parser.add_argument("--pop_size", type=int, default=base_cfg.pop_size)
    parser.add_argument("--solver_n_pop", type=int, default=None,
                        help="Number of evolution generations for solver EOH (overrides config file)")
    parser.add_argument("--generator_n_pop", type=int, default=None,
                        help="Number of evolution generations for generator EOH (overrides config file)")
    parser.add_argument("--ec_operators", type=str, nargs='+', default=base_cfg.ec_operators)

    # TSP-specific arguments
    parser.add_argument("--n_cities", type=int, default=base_cfg.n_cities)
    parser.add_argument("--instance_solver_time_limit", type=int, default=getattr(base_cfg, 'instance_solver_time_limit', 60))
    parser.add_argument("--tsp_solver_time_limit", type=int, default=None, help="Deprecated: use --instance_solver_time_limit")
    parser.add_argument("--tsp_solver_max_iterations", type=int, default=base_cfg.tsp_solver_max_iterations)
    parser.add_argument("--oracle_type", type=str, default=getattr(base_cfg, 'oracle_type', 'concorde'),
                        choices=["lkh3", "concorde", "none"])
    parser.add_argument("--gap_oracle", type=str, default=None, help="Deprecated: use --oracle_type")
    parser.add_argument("--oracle_timeout", type=int, default=getattr(base_cfg, 'oracle_timeout', 30))
    parser.add_argument("--gap_oracle_timeout", type=int, default=None, help="Deprecated: use --oracle_timeout")
    parser.add_argument("--lkh_runs", type=int, default=base_cfg.lkh_runs)

    # Common arguments
    parser.add_argument("--eoh_eval_n_instances", type=int, default=base_cfg.eoh_eval_n_instances)
    parser.add_argument("--seed", type=int, default=base_cfg.seed)
    parser.add_argument("--debug", type=_str2bool, default=base_cfg.debug_mode)
    parser.add_argument("--eoh_management_strategy", type=str, default=base_cfg.eoh_management_strategy,
                        choices=["pop_greedy", "pop_diverse"])
    parser.add_argument("--evolution_context_enabled", type=_str2bool, default=base_cfg.evolution_context_enabled)
    parser.add_argument("--disable_generator_evolution", type=_str2bool, default=base_cfg.disable_generator_evolution)
    parser.add_argument("--psro_use_latest_only", type=_str2bool, default=base_cfg.psro_use_latest_only)
    parser.add_argument("--min_simple_ratio", type=float, default=base_cfg.min_simple_ratio)
    parser.add_argument("--resume_from", type=str, default=None,
                        help="Resume from existing experiment directory")
    parser.add_argument("--resume_from_experiment", type=str, default=None,
                        help="Resume from another experiment (creates new experiment)")
    parser.add_argument("--resume_from_round", type=int, default=None,
                        help="Round to resume from (required with --resume_from_experiment)")
    parser.add_argument("--eval_n_instances", type=int, default=base_cfg.eval_n_instances)

    # Meta-game solver arguments
    parser.add_argument("--meta_game_solver", type=str, default=None,
                        choices=["ne", "alpha_rank"],
                        help=f"Meta-game solver method: 'ne' (Nash equilibrium) or 'alpha_rank' (default: from config file or {default_cfg.meta_game_solver})")
    parser.add_argument("--alpha_rank_alpha", type=float, default=None,
                        help=f"Alpha-Rank selection intensity (default: from config file or {default_cfg.alpha_rank_alpha})")
    parser.add_argument("--alpha_rank_num_iters", type=int, default=None,
                        help=f"Alpha-Rank maximum iterations (default: from config file or {default_cfg.alpha_rank_num_iters})")
    parser.add_argument("--alpha_rank_tol", type=float, default=None,
                        help=f"Alpha-Rank convergence tolerance (default: from config file or {default_cfg.alpha_rank_tol})")
    return parser


def main():
    """Parse CLI arguments, build the config, and run a TSP GLS experiment.

    Precedence, lowest to highest: ``TSPGLSConfig`` defaults, the JSON file
    given by ``--config``, then options passed explicitly on the command line.

    Raises:
        ValueError: if both ``--resume_from`` and ``--resume_from_experiment``
            are given.
    """
    default_cfg = TSPGLSConfig()

    # Pass 1: locate --config only, so its values can serve as CLI defaults.
    # Without this, every option with a concrete default silently replaced
    # the corresponding value from the config file.
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("--config", type=str, default=None)
    pre_args, _ = config_parser.parse_known_args()

    # Load configuration
    if pre_args.config:
        with open(pre_args.config, "r") as f:
            cfg_data = json.load(f)
        cfg = TSPGLSConfig(**{**default_cfg.__dict__, **cfg_data})
    else:
        cfg = TSPGLSConfig(**default_cfg.__dict__)

    # Pass 2: full parse. Defaults now come from ``cfg``, so the assignments
    # below only change values the user actually passed.
    args = _build_arg_parser(cfg, default_cfg).parse_args()

    # Apply CLI overrides
    cfg.experiment_name = args.experiment_name
    cfg.max_rounds = args.max_rounds
    cfg.pop_size = args.pop_size
    # Only override solver_n_pop and generator_n_pop if explicitly provided via CLI
    if args.solver_n_pop is not None:
        cfg.solver_n_pop = args.solver_n_pop
    if args.generator_n_pop is not None:
        cfg.generator_n_pop = args.generator_n_pop
    cfg.ec_operators = args.ec_operators
    cfg.n_cities = args.n_cities
    # Deprecated aliases take precedence when given, as before.
    if args.tsp_solver_time_limit is not None:
        cfg.instance_solver_time_limit = args.tsp_solver_time_limit
    else:
        cfg.instance_solver_time_limit = args.instance_solver_time_limit
    cfg.tsp_solver_max_iterations = args.tsp_solver_max_iterations
    if args.gap_oracle is not None:
        cfg.oracle_type = args.gap_oracle
    else:
        cfg.oracle_type = args.oracle_type
    if args.gap_oracle_timeout is not None:
        cfg.oracle_timeout = args.gap_oracle_timeout
    else:
        cfg.oracle_timeout = args.oracle_timeout
    cfg.lkh_runs = args.lkh_runs
    cfg.eoh_eval_n_instances = args.eoh_eval_n_instances
    cfg.seed = args.seed
    cfg.debug_mode = args.debug
    cfg.eoh_management_strategy = args.eoh_management_strategy
    cfg.evolution_context_enabled = args.evolution_context_enabled
    cfg.disable_generator_evolution = args.disable_generator_evolution
    cfg.psro_use_latest_only = args.psro_use_latest_only
    cfg.min_simple_ratio = args.min_simple_ratio
    cfg.eval_n_instances = args.eval_n_instances
    # Only override meta_game_solver and alpha_rank parameters if explicitly provided via CLI
    if args.meta_game_solver is not None:
        cfg.meta_game_solver = args.meta_game_solver
    if args.alpha_rank_alpha is not None:
        cfg.alpha_rank_alpha = args.alpha_rank_alpha
    if args.alpha_rank_num_iters is not None:
        cfg.alpha_rank_num_iters = args.alpha_rank_num_iters
    if args.alpha_rank_tol is not None:
        cfg.alpha_rank_tol = args.alpha_rank_tol

    if not hasattr(cfg, 'save_frequency') or cfg.save_frequency is None:
        cfg.save_frequency = 1
    if not hasattr(cfg, 'log_frequency') or cfg.log_frequency is None:
        cfg.log_frequency = 1

    if args.resume_from_experiment and args.resume_from:
        raise ValueError("Cannot use both --resume_from and --resume_from_experiment")

    exp_dir = run_experiment(cfg, resume_from=args.resume_from,
                             resume_from_experiment=args.resume_from_experiment,
                             resume_from_round=args.resume_from_round)
    print(f"Experiment results saved to: {exp_dir}")


if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when this module is imported.
    main()


