#!/usr/bin/env python3
"""
Resume PSRO experiment from a previous checkpoint.

This script allows you to resume a PSRO experiment from any previous round,
continuing both the PSRO iteration count and EOH population states.
"""

import argparse
import json
import os
import sys
from typing import Optional
from datetime import datetime

# Add the project root (three dirname() hops above this file's absolute path) to the front of sys.path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from heupsro.core.config import HeuPSROConfig
from heupsro.core.controller import HeuPSROController


def _populate_config(cfg, cfg_data: dict):
    """Copy entries of ``cfg_data`` whose key names a dataclass field of ``cfg`` onto ``cfg``.

    Keys that do not correspond to a dataclass field are ignored.
    Returns the same ``cfg`` object for convenience.
    """
    field_names = {field.name for field in cfg.__dataclass_fields__.values()}
    for key, value in cfg_data.items():
        if key in field_names:
            setattr(cfg, key, value)
    return cfg


def detect_problem_adapter(cfg_data: dict):
    """Auto-detect the problem type from a config dict and build its adapter.

    Detection is based on marker keys and is checked in this order:
    BP Online (``bp_capacity`` or ``bp_num_items``), TSP (``n_cities``),
    FSSP (``num_items`` or ``num_machines``), CVRP (``num_customers`` or
    ``vehicle_capacity``). The first match wins.

    Legacy parameter names are aliased to their current names before the
    config object is populated. The aliasing is applied to a shallow copy,
    so the caller's ``cfg_data`` is left untouched.

    Args:
        cfg_data: Mapping of configuration keys to values (e.g. the parsed
            contents of an experiment's ``config.json``).

    Returns:
        A ``(problem_adapter, cfg)`` tuple, where ``cfg`` is the
        problem-specific config object the adapter was built from.

    Raises:
        ValueError: If none of the marker keys is present.
    """
    # Old key -> new key. Several old keys map to "oracle_timeout"; because a
    # new key is only filled when absent, the first old key (in this order)
    # that is present supplies the value.
    param_compat = {
        "gap_oracle_timeout_s": "oracle_timeout",
        "generator_lkh_timeout": "oracle_timeout",
        "gap_oracle_timeout": "oracle_timeout",
        "tsp_solver_time_limit": "instance_solver_time_limit",
        "oracle_parallel_n_jobs": "optimal_parallel_n_jobs",
    }

    # Work on a shallow copy so the aliasing below never modifies the dict
    # owned by the caller.
    cfg_data = dict(cfg_data)
    for old_key, new_key in param_compat.items():
        if old_key in cfg_data and new_key not in cfg_data:
            cfg_data[new_key] = cfg_data[old_key]

    # Problem modules are imported lazily so only the detected problem's
    # package needs to be importable.

    # BP Online problem (has bp_capacity or bp_num_items)
    if 'bp_capacity' in cfg_data or 'bp_num_items' in cfg_data:
        from heupsro.problems.bp_online.config import BPOnlineConfig
        from heupsro.problems.bp_online.adapter import BPProblemAdapter
        cfg = _populate_config(BPOnlineConfig(), cfg_data)
        return BPProblemAdapter(cfg), cfg

    # TSP problem (has n_cities)
    if 'n_cities' in cfg_data:
        from heupsro.problems.tsp_gls.config import TSPGLSConfig
        from heupsro.problems.tsp_gls.adapter import TSPProblemAdapter
        cfg = _populate_config(TSPGLSConfig(), cfg_data)
        return TSPProblemAdapter(cfg), cfg

    # FSSP problem (has num_items or num_machines)
    if 'num_items' in cfg_data or 'num_machines' in cfg_data:
        from heupsro.problems.fssp_gls.config import FSSPGLSConfig
        from heupsro.problems.fssp_gls.adapter import FSSPProblemAdapter
        cfg = _populate_config(FSSPGLSConfig(), cfg_data)
        return FSSPProblemAdapter(cfg), cfg

    # CVRP problem (has num_customers or vehicle_capacity)
    if 'num_customers' in cfg_data or 'vehicle_capacity' in cfg_data:
        from heupsro.problems.cvrp.config import CVRPConfig
        from heupsro.problems.cvrp.adapter import CVRPProblemAdapter
        cfg = _populate_config(CVRPConfig(), cfg_data)
        return CVRPProblemAdapter(cfg), cfg

    # No marker key matched: refuse to guess.
    raise ValueError(
        "Cannot auto-detect problem type. "
        "Expected 'bp_capacity' (BP Online), 'n_cities' (TSP), 'num_items'/'num_machines' (FSSP), "
        "or 'num_customers'/'vehicle_capacity' (CVRP) in config."
    )


def load_experiment_config(experiment_dir: str) -> HeuPSROConfig:
    """Rebuild the experiment's ``HeuPSROConfig`` from ``<experiment_dir>/config.json``.

    Legacy parameter names found in the file are aliased to their current
    names (only when the current name is absent), then every entry whose key
    is a dataclass field of ``HeuPSROConfig`` is assigned onto a fresh config.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``experiment_dir``.
    """
    config_file = os.path.join(experiment_dir, "config.json")
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"Config file not found: {config_file}")

    with open(config_file, "r") as handle:
        raw = json.load(handle)

    # Start from defaults and overlay whatever the file provides.
    cfg = HeuPSROConfig()

    # (legacy name, current name) pairs, applied in this order.
    legacy_aliases = (
        ("gap_oracle_timeout_s", "oracle_timeout"),
        ("generator_lkh_timeout", "oracle_timeout"),
        ("gap_oracle_timeout", "oracle_timeout"),
        ("tsp_solver_time_limit", "instance_solver_time_limit"),
        ("oracle_parallel_n_jobs", "optimal_parallel_n_jobs"),
    )
    for legacy_name, current_name in legacy_aliases:
        if legacy_name in raw and current_name not in raw:
            raw[current_name] = raw[legacy_name]

    known_fields = {field.name for field in cfg.__dataclass_fields__.values()}
    for name, value in raw.items():
        if name not in known_fields:
            continue
        setattr(cfg, name, value)

    return cfg


def get_experiment_summary(experiment_dir: str) -> dict:
    """Return a summary dict describing the experiment's saved state.

    If ``experiment_summary.json`` exists in ``experiment_dir`` its parsed
    contents are returned as-is. Otherwise the ``psro_results`` subdirectory
    is scanned for ``checkpoint_round_<N>*.json`` files and a dict with
    ``latest_round`` (largest N) and ``total_rounds`` (number of such files)
    is returned. An empty dict means neither source yielded anything.
    """
    summary_file = os.path.join(experiment_dir, "experiment_summary.json")
    if os.path.exists(summary_file):
        with open(summary_file, "r") as handle:
            return json.load(handle)

    # No summary file: derive what we can from the checkpoint file names.
    results_dir = os.path.join(experiment_dir, "psro_results")
    if not os.path.exists(results_dir):
        return {}

    rounds = []
    for entry in os.listdir(results_dir):
        is_checkpoint = entry.startswith("checkpoint_round_") and entry.endswith(".json")
        if not is_checkpoint:
            continue
        try:
            # Third "_"-separated token, up to its first ".", is the round number.
            number_token = entry.split("_")[2].split(".")[0]
            rounds.append(int(number_token))
        except (IndexError, ValueError):
            # Name matched the pattern but carries no parseable round number.
            pass

    if not rounds:
        return {}

    return {
        "latest_round": max(rounds),
        "total_rounds": len(rounds)
    }


class _Tee:
    """Write-through wrapper that mirrors a text stream into an append-mode log file."""

    def __init__(self, stream, file_path):
        self.stream = stream
        # Append mode ("a") preserves existing log data. UTF-8 is forced so
        # the emoji status markers printed by this script can always be
        # encoded, whatever the locale's default encoding is.
        self.file = open(file_path, "a", buffering=1, encoding="utf-8")

    def write(self, data):
        self.stream.write(data)
        self.file.write(data)

    def flush(self):
        self.stream.flush()
        self.file.flush()

    def close(self):
        """Close the log file only; the wrapped stream is left open."""
        try:
            self.file.close()
        except OSError:
            # Best effort: a failed final flush must not mask the real exit path.
            pass

    def __getattr__(self, name):
        # Only reached for attributes not defined on the wrapper itself.
        # Delegate them (isatty, encoding, fileno, ...) to the wrapped stream
        # so code that inspects sys.stdout keeps working while redirected.
        if name in ("stream", "file"):
            # Not yet set (e.g. __init__ failed): avoid infinite recursion.
            raise AttributeError(name)
        return getattr(self.stream, name)


def _run_resumed_experiment(args, cfg, output_dir, problem_adapter):
    """Create the controller, restore its state and run the remaining rounds.

    Returns the process exit code: 0 on success (including dry run and the
    "already completed" case), 1 if any exception was raised. Exceptions are
    reported via print/traceback so they also reach the console log.
    """
    try:
        # Create controller and resume
        print(f"🔄 Creating controller and resuming from {args.experiment_dir}")
        controller = HeuPSROController(cfg, output_dir, problem_adapter=problem_adapter)

        controller.resume(args.experiment_dir, args.resume_from_round)
        print("✅ Successfully resumed experiment")
        print(f"📊 Current state: iteration={controller.iteration}")
        print(f"📊 Pools: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")
        print(f"📊 EOH generations: solver={controller._solver_eoh_generation}, generator={controller._generator_eoh_generation}")

        if args.dry_run:
            print(f"🔍 Dry run mode - would resume from iteration {controller.iteration} to {cfg.max_rounds}")
            return 0

        # Nothing left to do once the restored iteration has reached max_rounds
        # (the equality case matters: zero remaining rounds is "completed").
        remaining_rounds = cfg.max_rounds - controller.iteration
        if remaining_rounds <= 0:
            print(f"✅ Experiment already completed (iteration {controller.iteration} >= max_rounds {cfg.max_rounds})")
            return 0

        print(f"🚀 Continuing experiment for {remaining_rounds} more rounds...")
        print(f"📊 Target: iterations {controller.iteration + 1} to {cfg.max_rounds}")

        # Run remaining iterations
        # Note: iterate_one_round() increments iteration first, so if controller.iteration = N,
        # the first call will execute round N+1. We iterate from controller.iteration to cfg.max_rounds
        # (exclusive) so that the last executed round is cfg.max_rounds.
        for round_num in range(controller.iteration, cfg.max_rounds):
            print(f"\n{'='*60}")
            print(f"🔄 Starting PSRO Round {round_num + 1}")
            print(f"{'='*60}")

            controller.iterate_one_round()
            # Persist after every round so a later failure loses at most one round.
            controller.save()

            print(f"✅ Completed round {controller.iteration}")

        print("\n🎉 Experiment completed successfully!")
        print(f"📊 Final state: {controller.pools.n_solvers} solvers, {controller.pools.n_generators} generators")
        print(f"📊 Results saved to: {output_dir}")

        return 0

    except Exception as e:
        print(f"❌ Error during experiment: {e}")
        import traceback
        traceback.print_exc()
        return 1


def main():
    """Command-line entry point: resume a PSRO experiment from a checkpoint.

    Parses the CLI arguments, loads ``config.json`` from the experiment
    directory, auto-detects the problem adapter, mirrors stdout/stderr into
    ``<output_dir>/console.log`` (append mode) and runs the remaining rounds.
    The original ``sys.stdout``/``sys.stderr`` are restored and the log
    handles closed before returning.

    Returns:
        0 on success, 1 on any error (missing directory, unreadable config,
        or a failure while resuming/running).
    """
    parser = argparse.ArgumentParser(description="Resume PSRO experiment from checkpoint")
    parser.add_argument("--experiment_dir", required=True,
                       help="Path to the experiment directory to resume from")
    parser.add_argument("--resume_from_round", type=int, default=None,
                       help="Specific round to resume from (default: latest checkpoint)")
    parser.add_argument("--output_dir", default=None,
                       help="Output directory for resumed experiment (default: same as experiment_dir)")
    parser.add_argument("--max_rounds", type=int, default=None,
                       help="Maximum rounds to run (default: use original config)")
    parser.add_argument("--dry_run", action="store_true",
                       help="Show what would be resumed without actually running")

    args = parser.parse_args()

    # Validate experiment directory
    if not os.path.exists(args.experiment_dir):
        print(f"❌ Error: Experiment directory not found: {args.experiment_dir}")
        return 1

    # Load original configuration
    try:
        config_path = os.path.join(args.experiment_dir, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")

        with open(config_path, "r") as f:
            config_data = json.load(f)

        # Auto-detect problem type and create adapter
        problem_adapter, cfg = detect_problem_adapter(config_data)
        print(f"✅ Loaded configuration from {args.experiment_dir}")
        print(f"✅ Detected problem type: {type(problem_adapter).__name__}")
    except Exception as e:
        print(f"❌ Error loading configuration: {e}")
        return 1

    # Override max_rounds if specified
    if args.max_rounds is not None:
        cfg.max_rounds = args.max_rounds
        print(f"📊 Updated max_rounds to {args.max_rounds}")

    # Get experiment summary
    summary = get_experiment_summary(args.experiment_dir)
    print(f"📊 Experiment summary: {summary}")

    # Determine output directory; create it so a fresh --output_dir does not
    # make the log file open below fail.
    output_dir = args.output_dir if args.output_dir else args.experiment_dir
    os.makedirs(output_dir, exist_ok=True)

    # Setup logging: capture all stdout/stderr to file in append mode, suppress warnings
    import warnings
    warnings.filterwarnings("ignore")

    log_path = os.path.join(output_dir, "console.log")
    # Add a separator line to mark resume session
    with open(log_path, "a", buffering=1, encoding="utf-8") as f:
        f.write(f"\n{'='*60}\n")
        f.write(f"🔄 RESUMING EXPERIMENT at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"{'='*60}\n")

    # Redirect stdout and stderr to console.log (append mode), remembering the
    # originals so they can be put back when the run ends.
    original_stdout, original_stderr = sys.stdout, sys.stderr
    stdout_tee = _Tee(original_stdout, log_path)
    stderr_tee = _Tee(original_stderr, log_path)
    sys.stdout, sys.stderr = stdout_tee, stderr_tee
    try:
        return _run_resumed_experiment(args, cfg, output_dir, problem_adapter)
    finally:
        # Restore the real streams first, then release the log file handles.
        sys.stdout, sys.stderr = original_stdout, original_stderr
        stdout_tee.close()
        stderr_tee.close()


if __name__ == "__main__":
    # Use sys.exit rather than the builtin exit(): the latter is injected by
    # the site module for interactive use and is absent when site is disabled.
    # main()'s return value becomes the process exit status.
    sys.exit(main())
