"""
Lambda Sweep Experiment for Validating Theoretical Optimal Staleness Budget.

This experiment validates the theoretical prediction that the optimal coupling
parameter lambda = alpha/2 (half the Hessian growth rate) sits at the Pareto
frontier of throughput vs. performance.

Key Hypothesis:
- Theoretical derivation (Theorem 2) predicts lambda = alpha/2
- This experiment sweeps lambda values and measures:
  1. Pass@1 (final performance)
  2. Throughput (samples/second)
  3. Sample efficiency (samples to reach target)
- The optimal lambda should balance throughput and gradient quality

This addresses reviewer concern Q4: "No validation that λ=0.5 is actually optimal
as theory predicts."

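Usage:
    The trainer is pulled in through a relative import
    (``..training.aceas_trainer``), so run this file as a module of its
    package rather than by path (a direct ``python lambda_sweep.py`` fails at
    that import). Replace ``<pkg>.<subpkg>`` with the package path that
    contains this module:

    # Run with training (requires GPU)
    python -m <pkg>.<subpkg>.lambda_sweep --seeds 42 123 456 789 2024

    # Quick test with reduced steps (fixed 4-value sweep, single seed 42;
    # --seeds is ignored)
    python -m <pkg>.<subpkg>.lambda_sweep --quick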
"""

import json
import logging
import numpy as np
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from collections import defaultdict

logger = logging.getLogger(__name__)


@dataclass
class LambdaSweepConfig:
    """Settings that control one lambda sweep.

    Bundles the grid of coupling values to try, the theory-predicted optimum
    they are compared against, the per-run training budget, the early-stopping
    knobs and the output location.
    """

    # Grid of lambda values to evaluate: 0.1, 0.2, ..., 1.0 (ten points).
    lambda_values: List[float] = field(
        default_factory=lambda: [tenths / 10 for tenths in range(1, 11)]
    )

    # Value predicted by theory: lambda = alpha/2, with alpha the Hessian
    # growth rate. Empirically alpha ~ 1.0 for code tasks (pass rates drop by
    # ~e per difficulty level), which gives 0.5. This matches the
    # lambda_coupling in ACEASConfig.
    theoretical_optimal: float = 0.5

    # Training budget for each (lambda, seed) run.
    training_steps_per_lambda: int = 10000
    eval_interval: int = 500
    seeds: List[int] = field(default_factory=lambda: [42, 123, 456])

    # Early-stopping knobs.
    early_stop_threshold: float = 0.01  # Stop if improvement < 1%
    early_stop_patience: int = 5  # Number of evals without improvement

    # Where experiment artifacts are written.
    output_dir: str = "experiments/lambda_sweep"


@dataclass
class LambdaSweepResult:
    """Metrics gathered for one lambda value, per seed and summarised."""

    lambda_value: float
    seeds: List[int]

    # Raw measurements, one entry appended per recorded run.
    final_pass_at_1: List[float] = field(default_factory=list)
    avg_throughput: List[float] = field(default_factory=list)
    total_time: List[float] = field(default_factory=list)
    # Samples to reach 30% pass@1 (None when a run never got there).
    samples_to_30: List[Optional[int]] = field(default_factory=list)

    # Summary statistics; they stay at 0.0 until aggregate() has data to use.
    mean_pass_at_1: float = 0.0
    std_pass_at_1: float = 0.0
    mean_throughput: float = 0.0
    std_throughput: float = 0.0
    mean_time: float = 0.0

    # Discard rate keyed by difficulty level.
    discard_rates: Dict[int, float] = field(default_factory=dict)

    def aggregate(self):
        """Refresh the summary fields from the per-seed lists.

        A summary field is only overwritten when its source list is non-empty;
        otherwise it keeps whatever value it already had.
        """
        pass_rates = self.final_pass_at_1
        if pass_rates:
            self.mean_pass_at_1, self.std_pass_at_1 = (
                np.mean(pass_rates),
                np.std(pass_rates),
            )

        throughputs = self.avg_throughput
        if throughputs:
            self.mean_throughput, self.std_throughput = (
                np.mean(throughputs),
                np.std(throughputs),
            )

        durations = self.total_time
        if durations:
            self.mean_time = np.mean(durations)


def run_single_lambda_experiment(
    lambda_value: float,
    seed: int,
    config: LambdaSweepConfig,
    trainer_class,
    trainer_config_class,
    tasks: List[Any],
) -> Dict[str, Any]:
    """
    Run a single training experiment with specified lambda.

    Seeds the NumPy, ``random`` and torch RNGs, builds a trainer whose config
    has ``lambda_coupling=lambda_value``, calls ``trainer.train()`` and
    condenses the dictionary it returns into a few summary metrics. The
    result dictionary is read through the keys ``"eval_metrics"``,
    ``"timing_metrics"`` and ``"scheduler_stats"``; a key that is missing,
    ``None`` or empty falls back to a neutral value (0.0 / ``None`` / ``{}``)
    instead of failing the run.

    Args:
        lambda_value: The lambda coupling value to test
        seed: Random seed
        config: Sweep configuration (supplies the step budget and eval interval)
        trainer_class: ACEASTrainer class, called as ``trainer_class(tasks, cfg)``
        trainer_config_class: ACEASConfig class, called with keyword arguments
        tasks: Task list

    Returns:
        On success, a dictionary with ``final_pass_at_1``, ``avg_throughput``,
        ``total_time``, ``samples_to_30``, ``discard_rates``, ``eval_history``
        and ``success=True``. If training or metric extraction raises, a
        dictionary with ``success=False`` and ``error`` (the exception text).

    Raises:
        Whatever ``trainer_config_class(...)`` or ``trainer_class(...)``
        raise: construction happens outside the guarded section.
    """
    import random

    import torch

    # Seed every RNG this function knows about so a (lambda, seed) pair is
    # repeatable as far as these generators are concerned.
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create trainer config with this lambda
    trainer_config = trainer_config_class(
        lambda_coupling=lambda_value,
        total_timesteps=config.training_steps_per_lambda,
        eval_interval=config.eval_interval,
        use_csc=True,  # Must enable CSC to test lambda
        use_async=True,
        curriculum_strategy="adaptive",
    )

    # Initialize trainer
    trainer = trainer_class(tasks, trainer_config)

    start_time = time.time()
    try:
        result = trainer.train()
        # Stop the clock before post-processing so only training is timed.
        total_time = time.time() - start_time

        # `or` also covers a key that is present but None/empty, which the
        # plain `.get(key, default)` form does not.
        eval_metrics = result.get("eval_metrics") or []
        timing_metrics = result.get("timing_metrics") or []
        scheduler_stats = result.get("scheduler_stats") or {}

        # Last evaluation is taken as the final score; 0.0 when nothing was
        # evaluated (indexing an empty list used to abort the whole run).
        final_pass_at_1 = (
            eval_metrics[-1].get("pass_at_1", 0.0) if eval_metrics else 0.0
        )

        # np.mean([]) would give NaN plus a RuntimeWarning, hence the guard.
        throughputs = [m.get("throughput", 0) for m in timing_metrics]
        avg_throughput = float(np.mean(throughputs)) if throughputs else 0.0

        # "timestep" of the first evaluation at or above 30% pass@1, if any.
        samples_to_30 = next(
            (m.get("timestep") for m in eval_metrics if m.get("pass_at_1", 0) >= 0.3),
            None,
        )

        return {
            "final_pass_at_1": final_pass_at_1,
            "avg_throughput": avg_throughput,
            "total_time": total_time,
            "samples_to_30": samples_to_30,
            "discard_rates": scheduler_stats.get("discard_rates") or {},
            "eval_history": eval_metrics,
            "success": True,
        }

    except Exception as e:
        # Deliberately broad: one crashed run must not abort the whole sweep.
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(
            "Experiment failed for lambda=%s, seed=%s: %s", lambda_value, seed, e
        )
        return {
            "success": False,
            "error": str(e),
        }


def run_lambda_sweep(
    config: LambdaSweepConfig,
    trainer_class,
    trainer_config_class,
    tasks: List[Any],
    progress_callback=None,
) -> Dict[float, LambdaSweepResult]:
    """
    Run full lambda sweep experiment.

    For every value in ``config.lambda_values`` one run is executed per entry
    of ``config.seeds`` via ``run_single_lambda_experiment``. Successful runs
    contribute their metrics to that lambda's ``LambdaSweepResult``; failed
    runs are skipped and reported with a warning. After all seeds of a lambda
    are done, the per-seed lists are aggregated and the per-difficulty discard
    rates are averaged over the successful runs.

    Args:
        config: Sweep configuration
        trainer_class: ACEASTrainer class
        trainer_config_class: ACEASConfig class
        tasks: Task list
        progress_callback: Optional callback(lambda_idx, seed_idx, result),
            invoked after every run, successful or not

    Returns:
        Dictionary mapping lambda value to LambdaSweepResult. A lambda whose
        runs all failed is still present, with its summary fields left at
        their 0.0 defaults.
    """
    results = {}
    num_seeds = len(config.seeds)

    for lambda_idx, lambda_value in enumerate(config.lambda_values):
        logger.info(f"Testing lambda = {lambda_value}")

        result = LambdaSweepResult(
            lambda_value=lambda_value,
            # Copy so later edits to config.seeds cannot rewrite this record.
            seeds=list(config.seeds),
        )

        # Collect per-difficulty rates in a scratch mapping instead of
        # temporarily storing lists in result.discard_rates, which is
        # declared as Dict[int, float].
        rates_by_difficulty = defaultdict(list)
        failed_seeds = []

        for seed_idx, seed in enumerate(config.seeds):
            logger.info(f"  Seed {seed} ({seed_idx + 1}/{num_seeds})")

            exp_result = run_single_lambda_experiment(
                lambda_value=lambda_value,
                seed=seed,
                config=config,
                trainer_class=trainer_class,
                trainer_config_class=trainer_config_class,
                tasks=tasks,
            )

            if exp_result.get("success", False):
                result.final_pass_at_1.append(exp_result["final_pass_at_1"])
                result.avg_throughput.append(exp_result["avg_throughput"])
                result.total_time.append(exp_result["total_time"])
                result.samples_to_30.append(exp_result.get("samples_to_30"))

                for difficulty, rate in exp_result.get("discard_rates", {}).items():
                    rates_by_difficulty[difficulty].append(rate)
            else:
                failed_seeds.append(seed)

            if progress_callback is not None:
                progress_callback(lambda_idx, seed_idx, exp_result)

        # Aggregate results for this lambda
        result.aggregate()
        result.discard_rates = {
            difficulty: float(np.mean(rates))
            for difficulty, rates in rates_by_difficulty.items()
        }

        if failed_seeds:
            # Without this, a lambda whose runs all crashed is indistinguishable
            # from one that genuinely measured 0.0 pass@1 / 0.0 throughput.
            logger.warning(
                "  Lambda %s: %d/%d seeds failed %s; statistics cover successful "
                "seeds only%s",
                lambda_value,
                len(failed_seeds),
                num_seeds,
                failed_seeds,
                " (none succeeded, so they remain at 0.0)"
                if len(failed_seeds) == num_seeds
                else "",
            )

        results[lambda_value] = result
        logger.info(f"  Lambda {lambda_value}: Pass@1 = {result.mean_pass_at_1:.3f} +/- {result.std_pass_at_1:.3f}, "
                   f"Throughput = {result.mean_throughput:.1f}")

    return results


def find_pareto_frontier(
    results: Dict[float, LambdaSweepResult]
) -> List[Tuple[float, float, float]]:
    """
    Select the non-dominated (throughput, pass@1) points of a sweep.

    Each lambda contributes one point built from its ``mean_throughput`` and
    ``mean_pass_at_1``. A point is dropped when some other point is at least
    as good on both metrics and strictly better on at least one of them;
    points that tie exactly on both metrics do not eliminate each other.

    Args:
        results: Lambda sweep results

    Returns:
        List of (lambda, throughput, pass@1) tuples on the frontier, ordered
        by ascending lambda
    """
    candidates = [
        (lam, res.mean_throughput, res.mean_pass_at_1)
        for lam, res in results.items()
    ]

    def beats(rival, point):
        # `rival` dominates `point`: no worse on either axis, better on one.
        _, rival_tp, rival_acc = rival
        _, tp, acc = point
        no_worse = rival_tp >= tp and rival_acc >= acc
        strictly_better = rival_tp > tp or rival_acc > acc
        return no_worse and strictly_better

    frontier = [
        point
        for idx, point in enumerate(candidates)
        if not any(
            beats(rival, point)
            for rival_idx, rival in enumerate(candidates)
            if rival_idx != idx
        )
    ]

    return sorted(frontier, key=lambda point: point[0])


def save_lambda_sweep_results(
    results: Dict[float, LambdaSweepResult],
    config: LambdaSweepConfig,
    output_path: str,
):
    """
    Save lambda sweep results to JSON.

    The file has a ``config`` section (sweep settings), a ``results`` section
    keyed by ``str(lambda)`` and a ``pareto_frontier`` list. When the
    theoretical lambda was part of the sweep, ``theoretical_on_pareto`` and
    ``theoretical_result`` are added as well.

    Args:
        results: Lambda sweep results
        config: Sweep configuration the results were produced with
        output_path: Destination file; missing parent directories are created

    Raises:
        TypeError: If a value cannot be serialised to JSON. This is raised
            before the destination file is opened, so an existing file is
            not truncated by a failed save.
    """
    data = {
        "config": {
            "lambda_values": config.lambda_values,
            "theoretical_optimal": config.theoretical_optimal,
            "training_steps": config.training_steps_per_lambda,
            "seeds": config.seeds,
        },
        "results": {},
        "pareto_frontier": [],
    }

    for lam, result in results.items():
        data["results"][str(lam)] = {
            "lambda": result.lambda_value,
            # float()/int() turn numpy scalars into plain JSON-friendly numbers.
            "mean_pass_at_1": float(result.mean_pass_at_1),
            "std_pass_at_1": float(result.std_pass_at_1),
            "mean_throughput": float(result.mean_throughput),
            "std_throughput": float(result.std_throughput),
            "mean_time": float(result.mean_time),
            "discard_rates": result.discard_rates,
            "raw_pass_at_1": [float(v) for v in result.final_pass_at_1],
            "raw_throughput": [float(v) for v in result.avg_throughput],
            # Sample-efficiency metric; it was collected but never written out.
            "raw_samples_to_30": [
                None if s is None else int(s) for s in result.samples_to_30
            ],
        }

    # Add Pareto frontier
    pareto = find_pareto_frontier(results)
    data["pareto_frontier"] = [
        {"lambda": lam, "throughput": t, "pass_at_1": p}
        for lam, t, p in pareto
    ]

    # Locate the swept lambda that corresponds to the theoretical optimum.
    # An exact key wins; otherwise accept a float-tolerant match so a value
    # such as 0.5000000000000001 produced by arithmetic is still recognised.
    theoretical_lam = config.theoretical_optimal
    if theoretical_lam in results:
        theo_key = theoretical_lam
    else:
        theo_key = next(
            (lam for lam in results if np.isclose(lam, theoretical_lam)), None
        )

    if theo_key is not None:
        theo_result = results[theo_key]
        # Frontier lambdas are keys of `results`, so exact comparison is safe.
        data["theoretical_on_pareto"] = any(p[0] == theo_key for p in pareto)
        data["theoretical_result"] = {
            "lambda": theoretical_lam,
            "pass_at_1": float(theo_result.mean_pass_at_1),
            "throughput": float(theo_result.mean_throughput),
        }

    # Serialise first: if this raises, nothing has been written yet.
    payload = json.dumps(data, indent=2)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        f.write(payload)

    logger.info(f"Saved lambda sweep results to {output_path}")


def run_lambda_sweep_with_training(
    config: LambdaSweepConfig,
    model_name: str = "Salesforce/codegen-350M-mono",
    use_lora: bool = True,
    quick_mode: bool = False,
) -> Dict[float, LambdaSweepResult]:
    """
    Run lambda sweep with actual training.

    Loads the task set, imports the ACEAS trainer and hands both to
    ``run_lambda_sweep``, logging a line after every completed run.

    Args:
        config: Sweep configuration. It is never modified; quick mode works
            on a copy, so the caller's object still describes the full sweep.
        model_name: Base model name.
            NOTE(review): accepted but not used anywhere in this function.
        use_lora: Whether to use LoRA.
            NOTE(review): accepted but not used anywhere in this function.
        quick_mode: Use a reduced sweep for testing: four lambda values,
            1000 steps, eval every 200 steps and the single seed 42. All
            other settings of ``config`` are kept.

    Returns:
        Dictionary mapping lambda to results
    """
    from dataclasses import replace

    # Not referenced below. Presumably kept so that a missing torch install
    # fails here, before the dataset is loaded (TODO confirm).
    import torch  # noqa: F401

    # Adjust config for quick mode. replace() overrides only the sweep-size
    # fields; building a fresh LambdaSweepConfig would silently reset the
    # caller's output_dir, theoretical_optimal and early-stopping settings.
    if quick_mode:
        config = replace(
            config,
            lambda_values=[0.25, 0.5, 0.75, 1.0],  # Fewer lambdas
            training_steps_per_lambda=1000,  # Fewer steps
            eval_interval=200,
            seeds=[42],  # Single seed
        )

    logger.info("Starting lambda sweep experiment")
    logger.info(f"  Lambda values: {config.lambda_values}")
    logger.info(f"  Seeds: {config.seeds}")
    logger.info(f"  Steps per lambda: {config.training_steps_per_lambda}")

    # Load tasks
    tasks = _load_tasks()
    logger.info(f"Loaded {len(tasks)} tasks")

    # Import trainer (lazy import to avoid circular deps)
    from ..training.aceas_trainer import ACEASTrainer, ACEASConfig

    def _log_progress(lambda_idx, seed_idx, _run_result):
        # Indices refer to the (possibly quick-mode) config bound above.
        logger.info(
            f"Completed lambda={config.lambda_values[lambda_idx]}, seed={config.seeds[seed_idx]}"
        )

    # Run sweep
    return run_lambda_sweep(
        config=config,
        trainer_class=ACEASTrainer,
        trainer_config_class=ACEASConfig,
        tasks=tasks,
        progress_callback=_log_progress,
    )


def _load_tasks() -> List[Any]:
    """Build the task list from the HumanEval ``test`` split.

    Every dataset row becomes a plain dict holding the five fields copied
    below plus a constant ``"source"`` tag.
    """
    from datasets import load_dataset

    copied_fields = ("prompt", "canonical_solution", "test", "entry_point", "task_id")

    tasks = [
        {**{name: row[name] for name in copied_fields}, "source": "humaneval"}
        for row in load_dataset("openai_humaneval", split="test")
    ]

    logger.info(f"Loaded {len(tasks)} HumanEval tasks")
    return tasks


def analyze_lambda_sweep_results(
    results: Dict[float, LambdaSweepResult],
    theoretical_optimal: float = 0.5,
) -> Dict[str, Any]:
    """
    Analyze lambda sweep results to validate theoretical predictions.

    Compares the mean pass@1 measured at the theoretical lambda with the best
    mean pass@1 of the sweep and reports whether the theoretical lambda lies
    on the throughput/pass@1 Pareto frontier.

    Args:
        results: Sweep results
        theoretical_optimal: Theoretically predicted optimal lambda. It is
            matched against the swept values exactly first and with float
            tolerance (``np.isclose``) otherwise.

    Returns:
        Analysis summary. ``theoretical_pass_at_1`` is 0.0 when the
        theoretical lambda was not part of the sweep, and in that case
        ``validates_theory`` is always False. Note that a swept lambda whose
        runs all failed keeps the default mean of 0.0.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        # max() below would raise the same exception type with an opaque
        # "arg is an empty sequence" message.
        raise ValueError("Cannot analyze an empty lambda sweep: 'results' has no entries")

    pareto = find_pareto_frontier(results)

    # Resolve which swept lambda stands for the theoretical optimum.
    if theoretical_optimal in results:
        theo_key = theoretical_optimal
    else:
        theo_key = next(
            (lam for lam in results if np.isclose(lam, theoretical_optimal)), None
        )

    # Frontier lambdas are keys of `results`, so exact comparison is safe here.
    is_on_pareto = theo_key is not None and any(p[0] == theo_key for p in pareto)

    # Find lambda with best Pass@1
    best_lambda = max(results, key=lambda lam: results[lam].mean_pass_at_1)
    best_pass_at_1 = results[best_lambda].mean_pass_at_1

    # Compute distance from theoretical optimal to best
    theoretical_pass_at_1 = (
        results[theo_key].mean_pass_at_1 if theo_key is not None else 0.0
    )
    gap = best_pass_at_1 - theoretical_pass_at_1

    # Mean discard rate over difficulty levels, for lambdas that report any.
    avg_discard_by_lambda = {
        lam: float(np.mean(list(res.discard_rates.values())))
        for lam, res in results.items()
        if res.discard_rates
    }

    return {
        "pareto_frontier": pareto,
        "theoretical_optimal": theoretical_optimal,
        "is_theoretical_on_pareto": is_on_pareto,
        "best_lambda_for_pass_at_1": best_lambda,
        "best_pass_at_1": best_pass_at_1,
        "theoretical_pass_at_1": theoretical_pass_at_1,
        "gap_from_best": gap,
        # Within 2% of best, and only if the theoretical lambda was actually
        # measured: a 0.0 placeholder must not count as validation.
        # bool() keeps a numpy bool out of the summary.
        "validates_theory": bool(theo_key is not None and gap < 0.02),
        "avg_discard_rates": avg_discard_by_lambda,
    }


# Command-line entry point: run the sweep with real training, print the
# Pareto frontier and the theory check, then write everything to JSON.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Lambda sweep experiment")
    parser.add_argument("--quick", action="store_true", help="Quick test mode")
    parser.add_argument("--seeds", nargs="+", type=int, default=[42, 123, 456],
                        help="Random seeds")
    parser.add_argument("--output", type=str, default="experiments/lambda_sweep",
                        help="Output directory")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    config = LambdaSweepConfig(
        seeds=args.seeds,
        output_dir=args.output,
    )
    if args.quick:
        # run_lambda_sweep_with_training applies its quick-mode overrides to
        # its own copy of the config only. Mirror them here so the "config"
        # section of the saved JSON describes the sweep that actually ran.
        # Keep these values in sync with that function.
        config.lambda_values = [0.25, 0.5, 0.75, 1.0]
        config.training_steps_per_lambda = 1000
        config.eval_interval = 200
        config.seeds = [42]
        logger.info("Quick mode: reduced sweep with a single fixed seed; --seeds is ignored")

    logger.info("Running lambda sweep with training")
    results = run_lambda_sweep_with_training(
        config,
        quick_mode=args.quick,
    )

    # Find Pareto frontier
    pareto = find_pareto_frontier(results)
    print("\nPareto frontier points:")
    for lam, throughput, pass_at_1 in pareto:
        # Print lambda unrounded: ".1f" showed 0.25 as 0.2 and 0.75 as 0.8.
        print(f"  lambda={lam}: throughput={throughput:.1f}, pass@1={pass_at_1:.3f}")

    # Analyze against the configured optimum rather than the function default.
    analysis = analyze_lambda_sweep_results(
        results,
        theoretical_optimal=config.theoretical_optimal,
    )
    print(f"\nTheoretical optimal (λ={config.theoretical_optimal}) on Pareto frontier: "
          f"{analysis['is_theoretical_on_pareto']}")
    print(f"Best λ for Pass@1: {analysis['best_lambda_for_pass_at_1']}")
    print(f"Gap from theoretical: {analysis['gap_from_best']:.3f}")
    print(f"Theory validated: {analysis['validates_theory']}")

    # Save results
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = str(output_dir / "lambda_sweep_results.json")
    save_lambda_sweep_results(results, config, output_file)
    print(f"\nResults saved to {output_file}")
