#!/usr/bin/env python3
"""
Real Training Script for ACEAS Paper.

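This script runs actual training experiments (rather than producing synthetic
result curves) to generate results for the paper figures. Experiments are kept
small enough to be feasible to run: without CUDA (or with --small-model) it
trains a small model in local mode on synthetic tasks; with CUDA it tries to
load HumanEval tasks (falling back to synthetic tasks) and runs with local
mode disabled.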

Usage:
    python run_real_training.py --method aceas --steps 2000
    python run_real_training.py --all --steps 1000
"""

import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Dict, Any, List, Optional

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

import numpy as np
import torch

# Script-wide logging: INFO level with timestamped, logger-named records.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


# Feature toggles for each supported training method, forwarded verbatim to
# ``ACEASConfig`` as keyword arguments.
_METHOD_CONFIGS: Dict[str, Dict[str, Any]] = {
    "sync": {
        "use_async": False,
        "curriculum_strategy": "uniform",
        "use_csc": False,
        "use_eaas": False,
    },
    "sync_curriculum": {
        "use_async": False,
        "curriculum_strategy": "fixed",
        "use_csc": False,
        "use_eaas": False,
    },
    "async": {
        "use_async": True,
        "curriculum_strategy": "uniform",
        "use_csc": False,
        "use_eaas": False,
    },
    "async_staleness": {
        "use_async": True,
        "curriculum_strategy": "uniform",
        "use_csc": True,
        "use_eaas": False,
    },
    "aceas": {
        "use_async": True,
        "curriculum_strategy": "adaptive",
        "use_csc": True,
        "use_eaas": True,
    },
}


def _to_difficulty_tasks(raw_tasks: List[Any]) -> List[Any]:
    """Re-wrap tasks as ``src.curriculum.difficulty_levels.CodeTask`` objects.

    Copies ``task_id``, ``prompt``, ``canonical_solution``, ``test_cases`` and
    ``entry_point`` from each input task.
    """
    from src.curriculum.difficulty_levels import CodeTask as DifficultyCodeTask

    return [
        DifficultyCodeTask(
            task_id=t.task_id,
            prompt=t.prompt,
            canonical_solution=t.canonical_solution,
            test_cases=t.test_cases,
            entry_point=t.entry_point,
        )
        for t in raw_tasks
    ]


def _load_synthetic_tasks(max_tasks: int) -> List[Any]:
    """Create ``max_tasks`` synthetic tasks in the trainer's task format."""
    from src.code_environment.code_env import create_synthetic_tasks

    tasks = _to_difficulty_tasks(create_synthetic_tasks(max_tasks))
    logger.info(f"Created {len(tasks)} synthetic tasks")
    return tasks


def _load_tasks(max_tasks: int, use_small_model: bool) -> List[Any]:
    """Load the task set: synthetic for small-model runs, else HumanEval.

    HumanEval loading failures (including an empty task list) fall back to
    synthetic tasks with a warning.
    """
    if use_small_model:
        # Synthetic tasks execute faster, which matters for CPU training.
        logger.info("Using synthetic tasks for CPU training")
        return _load_synthetic_tasks(max_tasks)

    from src.code_environment.code_env import load_humaneval_tasks

    try:
        tasks = load_humaneval_tasks()
        if len(tasks) == 0:
            raise ValueError("No tasks loaded")
        logger.info(f"Loaded {len(tasks)} HumanEval tasks")
    except Exception as e:
        logger.warning(f"Failed to load HumanEval: {e}, using synthetic tasks")
        tasks = _load_synthetic_tasks(max_tasks)
    return tasks


def _write_results(path: Path, payload: Dict[str, Any]) -> None:
    """Serialize ``payload`` as indented JSON, stringifying unknown types."""
    with open(path, "w") as f:
        json.dump(payload, f, indent=2, default=str)


def run_training_experiment(
    method: str,
    output_dir: Path,
    total_steps: int = 2000,
    batch_size: int = 8,
    seed: int = 42,
    max_tasks: int = 50,
    use_small_model: bool = False,
) -> Dict[str, Any]:
    """
    Run a single training experiment with real training.

    The outcome is always written to
    ``<output_dir>/<method>_seed<seed>/results.json``. On success it holds the
    trainer's results plus ``method``, ``seed`` and a stringified ``config``.
    On failure it holds ``{"error", "method", "seed"}``, so any stale results
    from an earlier run in the same directory are replaced.

    Args:
        method: Training method (sync, sync_curriculum, async, async_staleness, aceas)
        output_dir: Directory to save results
        total_steps: Total training steps
        batch_size: Batch size
        seed: Random seed
        max_tasks: Maximum number of tasks to use
        use_small_model: Use smaller model for CPU training (forced on when
            CUDA is unavailable)

    Returns:
        Dictionary with training results, or an error record if training or
        saving failed.

    Raises:
        ValueError: If ``method`` is not a known training method.
    """
    # Validate before any heavy imports so that a bad method fails fast.
    # Results are then never saved under a label that doesn't match the
    # configuration actually trained.
    if method not in _METHOD_CONFIGS:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(_METHOD_CONFIGS)}"
        )
    method_config = _METHOD_CONFIGS[method]

    import random

    from src.training.aceas_trainer import ACEASTrainer, ACEASConfig

    # Seed every RNG the run may touch: Python, NumPy and torch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    logger.info(f"Running {method} with seed {seed}, steps {total_steps}")

    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        logger.warning("No CUDA available, using CPU-optimized settings")
        use_small_model = True

    tasks = _load_tasks(max_tasks, use_small_model)
    if max_tasks and len(tasks) > max_tasks:
        tasks = tasks[:max_tasks]

    # Select model based on hardware
    if use_small_model:
        model_name = "gpt2"  # Small 124M model, works well on CPU
        lora_r = 4
        lora_alpha = 8
        max_new_tokens = 64
    else:
        model_name = "Salesforce/codegen-350M-mono"
        lora_r = 8
        lora_alpha = 16
        max_new_tokens = 128

    # Local mode for CPU/small-model runs; otherwise local mode is disabled so
    # the trainer can run asynchronously on GPU.
    use_local = use_small_model or not has_cuda

    config = ACEASConfig(
        model_name=model_name,
        use_lora=True,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        total_timesteps=total_steps,
        batch_size=batch_size,
        eval_interval=max(50, total_steps // 10),
        save_interval=total_steps,  # Only save at end
        log_interval=max(5, total_steps // 50),
        num_workers=4 if has_cuda else 2,
        use_local_mode=use_local,
        learning_rate=2e-5,
        max_new_tokens=max_new_tokens,
        temperature=0.8,
        **method_config,
    )

    # Log configuration for diagnostics
    logger.info(f"Config: use_local_mode={config.use_local_mode}, "
                f"use_async={config.use_async}, num_workers={config.num_workers}")
    logger.info(f"Curriculum: {config.curriculum_strategy}, "
                f"use_csc={config.use_csc}, use_eaas={config.use_eaas}")

    exp_output_dir = output_dir / f"{method}_seed{seed}"
    exp_output_dir.mkdir(parents=True, exist_ok=True)
    results_path = exp_output_dir / "results.json"

    try:
        trainer = ACEASTrainer(
            tasks=tasks,
            config=config,
            output_dir=str(exp_output_dir),
        )
        results = trainer.train()

        results["method"] = method
        results["seed"] = seed
        results["config"] = {k: str(v) for k, v in config.__dict__.items()}

        _write_results(results_path, results)
        logger.info(f"Results saved to {results_path}")
        return results

    except Exception as e:
        logger.exception(f"Training failed for {method}: {e}")
        failure = {"error": str(e), "method": method, "seed": seed}
        # Record the failure on disk. Otherwise a results.json left by an
        # earlier run in this directory would be aggregated as if it came
        # from this run.
        try:
            _write_results(results_path, failure)
        except OSError:
            logger.exception(f"Could not record failure at {results_path}")
        return failure


def aggregate_results(
    results_dir: Path,
    methods: List[str],
    seeds: List[int],
) -> Dict[str, Any]:
    """Aggregate per-seed results into one summary entry per method.

    Reads ``<results_dir>/<method>_seed<seed>/results.json`` for every
    (method, seed) pair. Missing or unparseable files are skipped, as are
    runs whose record contains an ``"error"`` key.

    Args:
        results_dir: Root directory holding the per-run subdirectories.
        methods: Method names to aggregate.
        seeds: Seeds to look for under each method.

    Returns:
        Mapping from method name to a summary dict:
        - ``final_pass_at_1`` / ``final_pass_at_1_std``: mean and population
          std (ddof=0) of the last ``pass_at_1`` in each run's
          ``eval_metrics``. Runs without eval metrics are excluded.
        - ``avg_throughput`` / ``avg_throughput_std``: the same statistics
          over runs that report ``avg_throughput``.
        - ``num_runs``: number of successful runs found.
        - ``train_metrics``, ``timing_metrics``, ``eval_metrics``,
          ``scheduler_stats``: copied from the first successful run. Curves
          are not averaged across seeds.
        Methods with no successful runs are omitted.
    """
    aggregated: Dict[str, Any] = {}

    for method in methods:
        method_results = []
        for seed in seeds:
            results_path = results_dir / f"{method}_seed{seed}" / "results.json"
            if not results_path.exists():
                continue
            try:
                with open(results_path) as f:
                    method_results.append(json.load(f))
            except (OSError, json.JSONDecodeError) as e:
                # One corrupt/partial file must not abort the whole aggregation.
                logger.warning(f"Skipping unreadable results file {results_path}: {e}")

        if not method_results:
            logger.warning(f"No results found for {method}")
            continue

        # Filter out failed runs
        valid_results = [r for r in method_results if "error" not in r]
        if not valid_results:
            logger.warning(f"All runs failed for {method}")
            continue

        final_pass_rates = [
            r["eval_metrics"][-1].get("pass_at_1", 0)
            for r in valid_results
            if r.get("eval_metrics")
        ]
        # Only count runs that actually report throughput. Treating a missing
        # value as 0 would drag the mean (and inflate the std) artificially.
        throughputs = [
            r["avg_throughput"]
            for r in valid_results
            if r.get("avg_throughput") is not None
        ]

        first = valid_results[0]
        aggregated[method] = {
            "final_pass_at_1": float(np.mean(final_pass_rates)) if final_pass_rates else 0.0,
            "final_pass_at_1_std": float(np.std(final_pass_rates)) if len(final_pass_rates) > 1 else 0.0,
            "avg_throughput": float(np.mean(throughputs)) if throughputs else 0.0,
            "avg_throughput_std": float(np.std(throughputs)) if len(throughputs) > 1 else 0.0,
            "num_runs": len(valid_results),
            "train_metrics": first.get("train_metrics", []),
            "timing_metrics": first.get("timing_metrics", []),
            "eval_metrics": first.get("eval_metrics", []),
            "scheduler_stats": first.get("scheduler_stats", {}),
        }

    return aggregated


def generate_figures(results_dir: Path) -> None:
    """Render figures for ``results_dir`` into ``results_dir / "figures"``.

    Best effort: any failure, including the visualization module failing to
    import, is logged with its traceback rather than raised. A completed
    training run therefore never crashes at the plotting step.

    Args:
        results_dir: Directory containing the experiment results.
    """
    figures_dir = results_dir / "figures"
    figures_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Imported inside the try so a broken/missing plotting stack is
        # reported like any other figure-generation failure.
        from src.utils.visualization import generate_all_figures

        generate_all_figures(str(results_dir), str(figures_dir))
    except Exception as e:
        logger.exception(f"Figure generation failed: {e}")
    else:
        logger.info(f"Figures generated in {figures_dir}")


def main():
    """Command-line entry point.

    Parses arguments, then does one of two things:

    - With ``--figures-only``, re-plots existing results.
    - Otherwise, trains every requested (method, seed) pair, writes
      ``all_results.json``, generates figures and prints a summary table
      to stdout.
    """
    parser = argparse.ArgumentParser(description="Run real ACEAS training experiments")
    parser.add_argument(
        "--method",
        choices=["sync", "sync_curriculum", "async", "async_staleness", "aceas"],
        help="Training method to run",
    )
    parser.add_argument("--all", action="store_true", help="Run all methods")
    parser.add_argument("--steps", type=int, default=1000, help="Total training steps")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
    parser.add_argument("--seeds", type=int, nargs="+", default=[42], help="Random seeds")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./experiments/real_results",
        help="Output directory",
    )
    parser.add_argument("--max-tasks", type=int, default=50, help="Maximum number of tasks")
    parser.add_argument(
        "--figures-only",
        action="store_true",
        help="Only generate figures from existing results",
    )
    parser.add_argument(
        "--small-model",
        action="store_true",
        help="Use smaller model for CPU training",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    all_results_path = output_dir / "all_results.json"

    if args.figures_only:
        # Re-plot from a previous run without training anything.
        if all_results_path.exists():
            generate_figures(output_dir)
        else:
            logger.error("No results found to generate figures from")
        return

    if not (args.all or args.method):
        parser.error("Either --method or --all must be specified")
        return  # parser.error exits; kept as a defensive guard.

    # --all takes precedence over --method when both are supplied.
    methods = (
        ["sync", "sync_curriculum", "async_staleness", "aceas"]
        if args.all
        else [args.method]
    )

    for line in (
        f"Running methods: {methods}",
        f"Steps: {args.steps}, Batch size: {args.batch_size}",
        f"Seeds: {args.seeds}",
        f"Output: {output_dir}",
    ):
        logger.info(line)

    start_time = time.time()
    banner = "=" * 60

    for method in methods:
        for seed in args.seeds:
            logger.info(f"\n{banner}")
            logger.info(f"Running {method} with seed {seed}")
            logger.info(banner)
            run_training_experiment(
                method=method,
                output_dir=output_dir,
                total_steps=args.steps,
                batch_size=args.batch_size,
                seed=seed,
                max_tasks=args.max_tasks,
                use_small_model=args.small_model,
            )

    logger.info("\nAggregating results...")
    aggregated = aggregate_results(output_dir, methods, args.seeds)
    with open(all_results_path, "w") as f:
        json.dump(aggregated, f, indent=2, default=str)
    logger.info(f"Aggregated results saved to {all_results_path}")

    logger.info("\nGenerating figures...")
    generate_figures(output_dir)

    elapsed_minutes = (time.time() - start_time) / 60
    logger.info(f"\n{banner}")
    logger.info(f"Total time: {elapsed_minutes:.1f} minutes")
    logger.info(f"Results saved to: {output_dir}")

    # Human-readable summary table on stdout.
    print("\n" + banner)
    print("RESULTS SUMMARY")
    print(banner)
    print(f"{'Method':<25} {'Pass@1':>10} {'Throughput':>12}")
    print("-" * 60)
    for method, summary in aggregated.items():
        pass_rate = summary.get("final_pass_at_1", 0)
        throughput = summary.get("avg_throughput", 0)
        print(f"{method:<25} {pass_rate:>10.2%} {throughput:>10.1f}/s")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
