#!/usr/bin/env python3
"""
ICML 2026 Experiment Runner for ACEAS Paper.

This script runs all experiments needed for the ICML 2026 submission:
1. Main experiments with scaled model (Qwen3-0.6B)
2. Ablation studies
3. Staleness-difficulty grid search
4. Gradient variance analysis
5. Hyperparameter sensitivity sweeps

Usage:
    python run_icml2026_experiments.py --phase main
    python run_icml2026_experiments.py --phase all --output-dir ./icml_results
"""

import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Dict, Any, List, Optional

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

import numpy as np

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# ==============================================================================
# Configuration for ICML 2026 Experiments
# ==============================================================================

# Base configuration shared by every experiment phase in this script.
# The runner functions below read individual keys from the nested "model",
# "training" and "compute" sections and forward them to ACEASConfig as flat
# keyword arguments; the dict itself is never passed to the trainer.
ICML_CONFIG = {
    # Model configuration (scaled for ICML)
    "model": {
        "name": "Qwen/Qwen3-0.6B",  # forwarded to the trainer config as "model_name"
        "use_lora": True,  # Enable LoRA to reduce memory
        "lora_r": 32,
        "lora_alpha": 64,
        "gradient_checkpointing": True,  # Enable to reduce memory
        # NOTE(review): "torch_dtype" is not read by any code visible in this
        # file — confirm whether it should be forwarded to ACEASConfig.
        "torch_dtype": "float16",
    },

    # Training configuration (optimized for faster convergence)
    "training": {
        "total_timesteps": 3000,   # Reduced from 10000 (early stopping will kick in if needed)
        "batch_size": 16,          # Reduced to avoid OOM with async methods
        "learning_rate": 2e-5,     # Slightly higher for faster initial learning
        "eval_interval": 200,      # More frequent early evaluation
        "log_interval": 1,         # Log every step for visibility
        "early_stopping_patience": 5,  # Stop if no improvement for 5 evals
    },
    # In quick mode the runners replace total_timesteps with 500 and
    # batch_size with 8 instead of using the values above.

    # Compute configuration (1x L40S)
    "compute": {
        "num_workers": 1,
        "gpu_per_worker": 1.0,  # Use full GPU
    },

    # Experiment seeds
    # main() overwrites this entry with the value of the --seeds CLI option.
    "seeds": [42],  # Single seed for faster experiments

    # Methods to compare
    # Names recognised by get_method_config: "sync", "sync_curriculum",
    # "async_staleness", "aceas".
    "methods": ["aceas"],  # Only rerun aceas (others completed)

    # Datasets
    # NOTE(review): "datasets" is not read by any code visible in this file;
    # the runners call load_humaneval_tasks() (or create_synthetic_tasks in
    # quick mode) directly — confirm whether "mbpp_plus" is meant to be used.
    "datasets": ["humaneval", "mbpp_plus"],
}


def get_method_config(method: str, base_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``base_config`` with the flags of ``method`` applied.

    Recognised methods are ``"sync"``, ``"sync_curriculum"``,
    ``"async_staleness"`` and ``"aceas"``. Each one sets the four keys
    ``use_async``, ``curriculum_strategy``, ``use_csc`` and ``use_eaas``.
    Any other method name yields an unmodified shallow copy of
    ``base_config``. The input dictionary is never mutated.
    """
    # Column order: use_async, curriculum_strategy, use_csc, use_eaas.
    presets = {
        "sync": (False, "uniform", False, False),
        "sync_curriculum": (False, "fixed", False, False),
        "async_staleness": (True, "uniform", True, False),
        "aceas": (True, "adaptive", True, True),
    }
    flag_names = ("use_async", "curriculum_strategy", "use_csc", "use_eaas")

    merged = base_config.copy()
    flag_values = presets.get(method)
    if flag_values is not None:
        merged.update(zip(flag_names, flag_values))
    return merged


def reaggregate_results(output_dir: Path, methods: List[str], seeds: List[int]) -> Dict[str, Any]:
    """Rebuild per-method aggregates from ``results.json`` files on disk.

    For every (method, seed) pair the nested layout
    ``<output_dir>/<method>/seed_<seed>/results.json`` is tried first, then
    the flat layout ``<output_dir>/<method>_seed<seed>/results.json``. Pairs
    with neither file are skipped. Methods for which no file was found are
    left out of the returned dictionary; the others map to the output of
    ``aggregate_results`` over their loaded runs.
    """
    aggregated = {}

    for method in methods:
        loaded_runs = []
        for seed in seeds:
            # Candidate locations, in order of preference.
            candidates = (
                output_dir / method / f"seed_{seed}" / "results.json",
                output_dir / f"{method}_seed{seed}" / "results.json",
            )
            found = next((path for path in candidates if path.exists()), None)
            if found is None:
                continue
            with open(found) as handle:
                loaded_runs.append(json.load(handle))

        if loaded_runs:
            aggregated[method] = aggregate_results(loaded_runs)

    return aggregated

def run_main_experiments(
    output_dir: Path,
    seeds: List[int],
    methods: List[str],
    config: Dict[str, Any],
    quick_mode: bool = False,
) -> Dict[str, Any]:
    """
    Run main comparison experiments.

    For every (method, seed) pair an ``ACEASTrainer`` is built from the base
    configuration plus the method flags from ``get_method_config`` and
    trained. Per-method results are combined with ``aggregate_results`` and
    the combined dictionary is written to ``<output_dir>/all_results.json``.

    Args:
        output_dir: Directory to save results (created if missing)
        seeds: List of random seeds
        methods: List of methods to compare
        config: Base configuration (same layout as ``ICML_CONFIG``)
        quick_mode: If True, run abbreviated experiments (500 timesteps,
            batch size 8, 20 synthetic tasks, local mode)

    Returns:
        Dictionary mapping each method name to its aggregated results
    """
    logger.info("=" * 60)
    logger.info("Running MAIN EXPERIMENTS")
    logger.info("=" * 60)

    from src.training.aceas_trainer import ACEASTrainer, ACEASConfig
    from src.code_environment.code_env import load_humaneval_tasks, create_synthetic_tasks

    results = {}

    # Adjust for quick mode
    timesteps = 500 if quick_mode else config["training"]["total_timesteps"]
    batch_size = 8 if quick_mode else config["training"]["batch_size"]

    for method in methods:
        method_results = []

        for seed in seeds:
            logger.info(f"\n--- Running {method} with seed {seed} ---")

            # Set seed
            np.random.seed(seed)

            # Load tasks
            tasks = create_synthetic_tasks(20) if quick_mode else load_humaneval_tasks()

            # Create method-specific config
            method_config = get_method_config(method, {
                "model_name": config["model"]["name"],
                "use_lora": config["model"]["use_lora"],
                "lora_r": config["model"]["lora_r"],
                "lora_alpha": config["model"]["lora_alpha"],
                "gradient_checkpointing": config["model"]["gradient_checkpointing"],
                "total_timesteps": timesteps,
                "batch_size": batch_size,
                "learning_rate": config["training"].get("learning_rate", 2e-5),
                "eval_interval": config["training"].get("eval_interval", 200),
                "num_workers": config["compute"]["num_workers"],
                "gpu_per_worker": config["compute"]["gpu_per_worker"],
                "use_local_mode": quick_mode,  # Use local mode for quick tests
                "log_interval": config["training"].get("log_interval", 1),  # Log progress
                "early_stopping_patience": config["training"].get("early_stopping_patience", 5),
            })

            try:
                aceas_config = ACEASConfig(**method_config)
                trainer = ACEASTrainer(
                    tasks=tasks,
                    config=aceas_config,
                    output_dir=str(output_dir / method / f"seed_{seed}")
                )

                result = trainer.train()

                # "eval_metrics" can be present but empty (no evaluation ran).
                # Indexing [-1] on it would raise after the result had been
                # recorded and add a spurious error entry for a successful
                # run, so fall back to an empty metrics dict instead.
                eval_metrics = result.get("eval_metrics") or [{}]
                final_pass_at_1 = eval_metrics[-1].get("pass_at_1", 0)

                method_results.append(result)

                logger.info(f"Completed {method} seed {seed}: "
                           f"Pass@1={final_pass_at_1:.2%}")

            except Exception as e:
                # Best effort: one failing run must not abort the sweep.
                # logger.exception keeps the traceback in the log.
                logger.exception(f"Failed {method} seed {seed}: {e}")
                method_results.append({"error": str(e)})

        # Aggregate results
        results[method] = aggregate_results(method_results)

    # Save aggregated results (make sure the target directory exists even
    # when no trainer run got far enough to create it).
    output_dir.mkdir(parents=True, exist_ok=True)
    results_path = output_dir / "all_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"\nMain results saved to {results_path}")
    return results


def run_ablation_study(
    output_dir: Path,
    seed: int,
    config: Dict[str, Any],
    quick_mode: bool = False,
) -> Dict[str, Any]:
    """
    Run ablation study removing each component.

    Four variants are trained with asynchronous training enabled: ``full``,
    ``no_csc`` (``use_csc=False``), ``no_eaas`` (``use_eaas=False``) and
    ``no_acb`` (``curriculum_strategy="fixed"``). Results are written to
    ``<output_dir>/ablations/ablation_results.json``.

    Args:
        output_dir: Directory to save results
        seed: Random seed (NumPy is re-seeded before every variant)
        config: Base configuration (same layout as ``ICML_CONFIG``)
        quick_mode: If True, run abbreviated experiments (500 timesteps,
            batch size 8, 20 synthetic tasks, local mode)

    Returns:
        Dictionary mapping each variant name to its final pass@1, average
        throughput and train metrics, or to ``{"error": ...}`` if it failed
    """
    logger.info("=" * 60)
    logger.info("Running ABLATION STUDY")
    logger.info("=" * 60)

    from src.training.aceas_trainer import ACEASTrainer, ACEASConfig
    from src.code_environment.code_env import load_humaneval_tasks, create_synthetic_tasks

    ablation_configs = {
        "full": {"use_csc": True, "use_eaas": True, "curriculum_strategy": "adaptive"},
        "no_csc": {"use_csc": False, "use_eaas": True, "curriculum_strategy": "adaptive"},
        "no_eaas": {"use_csc": True, "use_eaas": False, "curriculum_strategy": "adaptive"},
        "no_acb": {"use_csc": True, "use_eaas": True, "curriculum_strategy": "fixed"},
    }

    # Optional settings that run_main_experiments also passes to ACEASConfig.
    # They are forwarded when the config defines them so that the "full"
    # variant runs with the same LoRA / compute / logging settings as the
    # main ACEAS run rather than with whatever ACEASConfig defaults to.
    optional_settings = {
        "lora_r": config["model"].get("lora_r"),
        "lora_alpha": config["model"].get("lora_alpha"),
        "gpu_per_worker": config["compute"].get("gpu_per_worker"),
        "log_interval": config["training"].get("log_interval"),
    }
    optional_settings = {
        key: value for key, value in optional_settings.items() if value is not None
    }

    results = {}
    timesteps = 500 if quick_mode else config["training"]["total_timesteps"]

    tasks = create_synthetic_tasks(20) if quick_mode else load_humaneval_tasks()

    for ablation_name, ablation_overrides in ablation_configs.items():
        logger.info(f"\n--- Running ablation: {ablation_name} ---")

        np.random.seed(seed)

        method_config = {
            "model_name": config["model"]["name"],
            "use_lora": config["model"]["use_lora"],
            "gradient_checkpointing": config["model"]["gradient_checkpointing"],
            "total_timesteps": timesteps,
            "batch_size": 8 if quick_mode else config["training"]["batch_size"],
            "learning_rate": config["training"].get("learning_rate", 2e-5),
            "eval_interval": config["training"].get("eval_interval", 200),
            "num_workers": config["compute"]["num_workers"],
            "use_async": True,
            "use_local_mode": quick_mode,
            "early_stopping_patience": config["training"].get("early_stopping_patience", 5),
            **optional_settings,
            # The ablation flags come last so they always win.
            **ablation_overrides,
        }

        try:
            aceas_config = ACEASConfig(**method_config)
            trainer = ACEASTrainer(
                tasks=tasks,
                config=aceas_config,
                output_dir=str(output_dir / "ablations" / ablation_name)
            )

            result = trainer.train()

            # "eval_metrics" can be present but empty; indexing [-1] on it
            # would turn a finished run into an error record.
            eval_metrics = result.get("eval_metrics") or [{}]
            results[ablation_name] = {
                "final_pass_at_1": eval_metrics[-1].get("pass_at_1", 0),
                "avg_throughput": result.get("avg_throughput", 0),
                "train_metrics": result.get("train_metrics", []),
            }

        except Exception as e:
            # Best effort: keep going with the remaining variants, but keep
            # the traceback in the log.
            logger.exception(f"Failed ablation {ablation_name}: {e}")
            results[ablation_name] = {"error": str(e)}

    # Save results
    results_path = output_dir / "ablations" / "ablation_results.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"\nAblation results saved to {results_path}")
    return results


def run_grid_search(
    output_dir: Path,
    quick_mode: bool = False,
) -> Dict[str, Any]:
    """
    Run the staleness-difficulty grid search experiment.

    Builds a ``StalenessGridConfig`` that writes to
    ``<output_dir>/grid_search`` with 10 samples per cell in quick mode and
    50 otherwise, then hands it to ``run_staleness_difficulty_grid`` and
    returns whatever that call returns.

    The experiment is intended to validate the theoretical prediction that
    gradient bias grows exponentially with difficulty under staleness.
    """
    rule = "=" * 60
    for banner_line in (rule, "Running STALENESS-DIFFICULTY GRID SEARCH", rule):
        logger.info(banner_line)

    from src.experiments.staleness_difficulty_grid import (
        StalenessGridConfig,
        run_staleness_difficulty_grid,
    )

    if quick_mode:
        cell_samples = 10
    else:
        cell_samples = 50

    grid_config = StalenessGridConfig(
        samples_per_cell=cell_samples,
        output_dir=str(output_dir / "grid_search"),
    )
    grid_results = run_staleness_difficulty_grid(grid_config)

    logger.info(f"\nGrid search results saved to {grid_config.output_dir}")
    return grid_results


def run_variance_analysis(
    output_dir: Path,
    quick_mode: bool = False,
) -> Dict[str, Any]:
    """
    Run the gradient variance analysis experiment.

    Builds a ``GradientVarianceConfig`` that writes to
    ``<output_dir>/variance_analysis`` and uses 20 samples per difficulty and
    20 training steps in quick mode (100 of each otherwise), then returns the
    output of ``run_gradient_variance_analysis`` for that config.

    The experiment is intended to validate that ACB reduces gradient variance
    and that SNR peaks at moderate difficulties.
    """
    rule = "=" * 60
    for banner_line in (rule, "Running GRADIENT VARIANCE ANALYSIS", rule):
        logger.info(banner_line)

    from src.experiments.gradient_variance import (
        GradientVarianceConfig,
        run_gradient_variance_analysis,
    )

    # Quick mode shrinks both the sample count and the number of steps.
    if quick_mode:
        difficulty_samples, training_steps = 20, 20
    else:
        difficulty_samples, training_steps = 100, 100

    variance_config = GradientVarianceConfig(
        samples_per_difficulty=difficulty_samples,
        num_training_steps=training_steps,
        output_dir=str(output_dir / "variance_analysis"),
    )
    variance_results = run_gradient_variance_analysis(variance_config)

    logger.info(f"\nVariance analysis results saved to {variance_config.output_dir}")
    return variance_results


def run_hyperparam_sweep(
    output_dir: Path,
    config: Dict[str, Any],
    quick_mode: bool = False,
    seed: int = 42,
) -> Dict[str, Any]:
    """
    Run hyperparameter sensitivity sweep for lambda (coupling strength).

    Trains the full ACEAS configuration (async, CSC, EAAS, adaptive
    curriculum) once per ``lambda_coupling`` value: 0.3/0.5/0.7 in quick mode
    and 0.1/0.3/0.5/0.7/0.9 otherwise. Results are written to
    ``<output_dir>/hyperparam/lambda_sweep.json``.

    Args:
        output_dir: Directory to save results
        config: Base configuration (same layout as ``ICML_CONFIG``)
        quick_mode: If True, run abbreviated experiments (500 timesteps,
            batch size 8, 20 synthetic tasks, local mode)
        seed: NumPy random seed applied before every run (default 42)

    Returns:
        Dictionary mapping ``str(lambda)`` to the run's final pass@1 and
        average throughput, or to ``{"error": ...}`` if that run failed
    """
    logger.info("=" * 60)
    logger.info("Running HYPERPARAMETER SENSITIVITY SWEEP")
    logger.info("=" * 60)

    from src.training.aceas_trainer import ACEASTrainer, ACEASConfig
    from src.code_environment.code_env import load_humaneval_tasks, create_synthetic_tasks

    lambda_values = [0.3, 0.5, 0.7] if quick_mode else [0.1, 0.3, 0.5, 0.7, 0.9]
    results = {}

    # Optional settings that run_main_experiments also passes to ACEASConfig.
    # They are forwarded when the config defines them so that the sweep runs
    # with the same LoRA / compute / logging settings as the main ACEAS run
    # rather than with whatever ACEASConfig defaults to.
    optional_settings = {
        "lora_r": config["model"].get("lora_r"),
        "lora_alpha": config["model"].get("lora_alpha"),
        "gpu_per_worker": config["compute"].get("gpu_per_worker"),
        "log_interval": config["training"].get("log_interval"),
    }
    optional_settings = {
        key: value for key, value in optional_settings.items() if value is not None
    }

    tasks = create_synthetic_tasks(20) if quick_mode else load_humaneval_tasks()

    for lambda_val in lambda_values:
        logger.info(f"\n--- Running lambda = {lambda_val} ---")

        # Same seed for every lambda so runs differ only in the swept value.
        np.random.seed(seed)

        method_config = {
            "model_name": config["model"]["name"],
            "use_lora": config["model"]["use_lora"],
            "gradient_checkpointing": config["model"]["gradient_checkpointing"],
            "total_timesteps": 500 if quick_mode else config["training"]["total_timesteps"],
            "batch_size": 8 if quick_mode else config["training"]["batch_size"],
            "learning_rate": config["training"].get("learning_rate", 2e-5),
            "eval_interval": config["training"].get("eval_interval", 200),
            "num_workers": config["compute"]["num_workers"],
            **optional_settings,
            "use_async": True,
            "use_csc": True,
            "use_eaas": True,
            "curriculum_strategy": "adaptive",
            "lambda_coupling": lambda_val,
            "use_local_mode": quick_mode,
            "early_stopping_patience": config["training"].get("early_stopping_patience", 5),
        }

        try:
            aceas_config = ACEASConfig(**method_config)
            trainer = ACEASTrainer(
                tasks=tasks,
                config=aceas_config,
                output_dir=str(output_dir / "hyperparam" / f"lambda_{lambda_val}")
            )

            result = trainer.train()

            # "eval_metrics" can be present but empty; indexing [-1] on it
            # would turn a finished run into an error record.
            eval_metrics = result.get("eval_metrics") or [{}]
            results[str(lambda_val)] = {
                "final_pass_at_1": eval_metrics[-1].get("pass_at_1", 0),
                "avg_throughput": result.get("avg_throughput", 0),
            }

        except Exception as e:
            # Best effort: keep sweeping the remaining values, but keep the
            # traceback in the log.
            logger.exception(f"Failed lambda={lambda_val}: {e}")
            results[str(lambda_val)] = {"error": str(e)}

    # Save results
    results_path = output_dir / "hyperparam" / "lambda_sweep.json"
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"\nHyperparameter sweep results saved to {results_path}")
    return results


def generate_figures(output_dir: Path) -> None:
    """Render every figure into ``<output_dir>/figures``.

    Two generators from ``src.utils.visualization`` are invoked in turn:
    ``generate_all_figures`` and ``generate_icml_figures``. A failure in
    either one is logged and does not prevent the other from running.
    """
    rule = "=" * 60
    for banner_line in (rule, "GENERATING FIGURES", rule):
        logger.info(banner_line)

    from src.utils.visualization import generate_icml_figures, generate_all_figures

    figures_dir = output_dir / "figures"
    figures_dir.mkdir(parents=True, exist_ok=True)

    results_root = str(output_dir)
    figures_root = str(figures_dir)

    try:
        generate_all_figures(results_root, figures_root)
    except Exception as exc:
        logger.error(f"Standard figure generation failed: {exc}")

    # Result files the ICML figure generator is pointed at.
    grid_json = output_dir / "grid_search" / "staleness_difficulty_grid.json"
    variance_json = output_dir / "variance_analysis" / "gradient_variance_analysis.json"

    try:
        generate_icml_figures(
            results_dir=results_root,
            grid_results_path=str(grid_json),
            variance_results_path=str(variance_json),
            output_dir=figures_root,
        )
    except Exception as exc:
        logger.error(f"Figure generation failed: {exc}")


def aggregate_results(results_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine per-seed result dictionaries into a single summary.

    Runs containing an ``"error"`` key are ignored; if none remain the
    function returns ``{"error": "All runs failed"}``.

    The summary holds the mean and standard deviation (NumPy defaults) of the
    last ``pass_at_1`` evaluation and of ``avg_throughput``, the number of
    successful runs, and per-timestep mean / ``<name>_std`` values for every
    int or float entry found in the runs' ``train_metrics``. Train-metric
    records without a ``"timestep"`` are skipped. ``timing_metrics`` and
    ``scheduler_stats`` are copied from the first successful run only.
    """
    successful = [run for run in results_list if "error" not in run]
    if not successful:
        return {"error": "All runs failed"}

    # Last evaluation of each run; runs without any evaluation contribute
    # nothing to the pass@1 statistics.
    evals_per_run = [run.get("eval_metrics", []) for run in successful]
    pass_rates = [evals[-1].get("pass_at_1", 0) for evals in evals_per_run if evals]
    speeds = [run.get("avg_throughput", 0) for run in successful]

    # timestep -> metric name -> values collected across runs
    per_step = {}
    for run in successful:
        for record in run.get("train_metrics", []):
            step = record.get("timestep")
            if step is None:
                continue
            bucket = per_step.setdefault(step, {})
            for name, value in record.items():
                if isinstance(value, (int, float)):
                    bucket.setdefault(name, []).append(value)

    # Collapse the collected values into mean / std rows ordered by timestep.
    summarized = []
    for step in sorted(per_step):
        row = {"timestep": step}
        for name, samples in per_step[step].items():
            if name != "timestep":
                row[name] = np.mean(samples)
                row[f"{name}_std"] = np.std(samples)
        summarized.append(row)

    first_run = successful[0]
    summary = {
        "final_pass_at_1": np.mean(pass_rates) if pass_rates else 0,
        "final_pass_at_1_std": np.std(pass_rates) if pass_rates else 0,
        "avg_throughput": np.mean(speeds) if speeds else 0,
        "avg_throughput_std": np.std(speeds) if speeds else 0,
        "num_runs": len(successful),
        "train_metrics": summarized,
        "timing_metrics": first_run.get("timing_metrics", []),
        "scheduler_stats": first_run.get("scheduler_stats", {}),
    }
    return summary


def main():
    """Command-line entry point: parse arguments and run the chosen phase(s).

    ``--phase`` selects one of ``main``, ``ablation``, ``grid``, ``variance``,
    ``hyperparam`` or ``figures``; ``all`` (the default) runs them in that
    order. ``--seeds`` replaces the ``"seeds"`` entry of ``ICML_CONFIG``; the
    ablation study uses only the first seed. A ``KeyboardInterrupt`` stops
    the remaining phases but the timing summary is still logged.
    """
    parser = argparse.ArgumentParser(description="ICML 2026 ACEAS Experiments")
    parser.add_argument(
        "--phase",
        choices=["main", "ablation", "grid", "variance", "hyperparam", "figures", "all"],
        default="all",
        help="Which experiment phase to run"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./experiments/icml2026_results",
        help="Output directory for results"
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Run quick test mode with reduced iterations"
    )
    parser.add_argument(
        "--seeds",
        type=int,
        nargs="+",
        default=[42],
        help="Random seeds to use"
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("ICML 2026 ACEAS Experiments")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Quick mode: {args.quick}")
    logger.info(f"Seeds: {args.seeds}")

    # Update config with command line seeds. A shallow copy is enough because
    # only the top-level "seeds" entry is replaced; nested sections are not
    # modified here.
    config = ICML_CONFIG.copy()
    config["seeds"] = args.seeds

    start_time = time.time()

    try:
        if args.phase in ["main", "all"]:
            run_main_experiments(
                output_dir, config["seeds"], config["methods"],
                config, quick_mode=args.quick
            )

        if args.phase in ["ablation", "all"]:
            run_ablation_study(
                output_dir, config["seeds"][0],
                config, quick_mode=args.quick
            )

        if args.phase in ["grid", "all"]:
            run_grid_search(output_dir, quick_mode=args.quick)

        if args.phase in ["variance", "all"]:
            run_variance_analysis(output_dir, quick_mode=args.quick)

        if args.phase in ["hyperparam", "all"]:
            run_hyperparam_sweep(output_dir, config, quick_mode=args.quick)

        if args.phase in ["figures", "all"]:
            # Re-aggregate first to ensure we have latest data/stats
            logger.info("Re-aggregating results...")
            results = reaggregate_results(output_dir, config["methods"], config["seeds"])

            results_path = output_dir / "all_results.json"
            if results:
                # Save aggregated results
                with open(results_path, "w") as f:
                    json.dump(results, f, indent=2, default=str)
            else:
                # No per-seed results.json was found. Writing the empty dict
                # would wipe an existing all_results.json (for example the
                # one the main phase has just produced), so leave it alone.
                logger.warning(
                    f"No per-seed results found under {output_dir}; "
                    f"leaving {results_path} unchanged"
                )

            generate_figures(output_dir)

    except KeyboardInterrupt:
        logger.info("Interrupted by user")

    elapsed = time.time() - start_time
    logger.info(f"\n{'='*60}")
    logger.info(f"Total time: {elapsed/60:.1f} minutes")
    logger.info(f"Results saved to: {output_dir}")


if __name__ == "__main__":
    main()
