"""
Gradient Variance Analysis Experiment.

This experiment validates the theoretical predictions about:
1. Gradient variance grows exponentially with task difficulty
2. ACB curriculum reduces overall gradient variance
3. Signal-to-noise ratio is maximized at moderate difficulties

These results support Theorem 1 (Difficulty-Dependent Staleness Error) and
Proposition 1 (Gradient Signal Quality) from the paper.
"""

import numpy as np
import torch
import json
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import time
from collections import defaultdict

logger = logging.getLogger(__name__)


@dataclass
class GradientVarianceConfig:
    """Settings for the gradient variance analysis experiment.

    Attributes:
        difficulty_levels: Difficulty values that are analysed and sampled
            from by the experiment functions in this module.
        samples_per_difficulty: Number of gradient samples collected for
            each difficulty level in the per-difficulty analysis.
        batch_size: Number of gradients collected per simulated step when
            comparing curriculum strategies.
        curriculum_strategies: Strategy names to compare. "uniform" and
            "fixed" are handled explicitly; any other name takes the ACB
            branch of the experiment functions.
        model_name: Model identifier; the script entry point fills it from
            the ``--model`` argument.
        use_lora: LoRA flag. NOTE(review): no function in this file reads
            it — confirm whether a caller elsewhere does.
        num_training_steps: Number of simulated training steps.
        output_dir: Directory that receives the JSON results file.
    """

    # --- experiment size -------------------------------------------------
    difficulty_levels: List[int] = field(default_factory=lambda: list(range(1, 6)))
    samples_per_difficulty: int = 100
    batch_size: int = 8

    # --- strategies under comparison -------------------------------------
    curriculum_strategies: List[str] = field(
        default_factory=lambda: ["uniform", "fixed", "acb"]
    )

    # --- model -----------------------------------------------------------
    model_name: str = "Salesforce/codegen-350M-mono"
    use_lora: bool = True

    # --- simulated training length ---------------------------------------
    num_training_steps: int = 100

    # --- results location ------------------------------------------------
    output_dir: str = "./experiments/gradient_variance"


def compute_gradient_variance(
    gradients: List[Dict[str, torch.Tensor]],
) -> Dict[str, float]:
    """
    Compute variance statistics across a set of gradient samples.

    Each sample's tensors are flattened and concatenated (in the dict's
    iteration order) into one vector, so every sample must contain tensors
    with the same total number of elements in the same order.

    Args:
        gradients: List of gradient dictionaries (one per sample)

    Returns:
        Dictionary that always has the same four keys:
            "variance":  per-parameter unbiased variance across samples,
                         averaged over all parameters
            "norm_mean": mean L2 norm of the per-sample gradient vectors
            "norm_std":  unbiased standard deviation of those norms
            "snr":       norm_mean / (norm_std + 1e-8)

        With no samples every value is 0.0. With exactly one sample the
        spread cannot be estimated, so "variance", "norm_std" and "snr" are
        reported as 0.0 instead of NaN (or an artificially huge ratio).
    """
    if not gradients:
        # Same key set as the regular path so callers can index "snr" safely.
        return {"variance": 0.0, "norm_mean": 0.0, "norm_std": 0.0, "snr": 0.0}

    # One flat vector per sample -> [num_samples, num_params]
    stacked = torch.stack(
        [
            torch.cat([g.flatten() for g in grad_dict.values()])
            for grad_dict in gradients
        ]
    )

    norms = torch.norm(stacked, dim=1)
    norm_mean = norms.mean().item()

    if stacked.shape[0] < 2:
        # torch.var / Tensor.std use Bessel's correction by default and
        # return NaN for a single sample; NaN would then leak into the JSON
        # output and the downstream summary.
        return {
            "variance": 0.0,
            "norm_mean": norm_mean,
            "norm_std": 0.0,
            "snr": 0.0,
        }

    # Per-parameter variance across samples, then averaged over parameters
    variance = torch.var(stacked, dim=0).mean().item()
    norm_std = norms.std().item()

    return {
        "variance": variance,
        "norm_mean": norm_mean,
        "norm_std": norm_std,
        "snr": norm_mean / (norm_std + 1e-8),  # Signal-to-noise ratio
    }


def analyze_gradient_by_difficulty(
    model: torch.nn.Module,
    tokenizer,
    tasks: List[Dict],
    config: GradientVarianceConfig,
    device: str = "cuda",
) -> Dict[int, Dict[str, float]]:
    """
    Collect gradient statistics separately for every configured difficulty.

    For each entry of ``config.difficulty_levels`` the function draws
    ``config.samples_per_difficulty`` random task indices, asks the
    curriculum generator for a task at that difficulty, computes one
    gradient per task and summarises the samples with
    ``compute_gradient_variance``.

    Args:
        model: Model whose gradients are measured; it is moved to ``device``.
        tokenizer: Tokenizer forwarded to the gradient computation.
        tasks: Task pool handed to the curriculum task generator.
        config: Experiment configuration.
        device: Device the model and inputs are placed on.

    Returns:
        Dictionary mapping difficulty -> gradient statistics (each statistics
        dict also carries the difficulty under the "difficulty" key)
    """
    from ..curriculum.difficulty_levels import DifficultyLevel, CurriculumTaskGenerator

    model = model.to(device)
    task_source = CurriculumTaskGenerator(tasks)
    stats_by_level = {}

    for level_value in config.difficulty_levels:
        logger.info(f"Analyzing gradients at difficulty {level_value}")
        level = DifficultyLevel(level_value)

        # One gradient per randomly chosen task at this difficulty.
        samples = [
            _compute_real_gradient(
                model,
                tokenizer,
                task_source.generate_task(np.random.randint(len(tasks)), level),
                device,
            )
            for _ in range(config.samples_per_difficulty)
        ]

        level_stats = compute_gradient_variance(samples)
        level_stats["difficulty"] = level_value
        stats_by_level[level_value] = level_stats

    return stats_by_level


def _compute_real_gradient(
    model: torch.nn.Module,
    tokenizer,
    task: Dict,
    device: str = "cuda",
) -> Dict[str, torch.Tensor]:
    """Compute real gradient from a task."""
    model.train()
    model.zero_grad()
    
    text = task.get("prompt", "") + task.get("completion", task.get("canonical_solution", ""))
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    ).to(device)
    
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()
    
    gradients = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            gradients[name] = param.grad.clone()
    
    model.zero_grad()
    return gradients


def compare_curriculum_strategies(
    model: torch.nn.Module,
    tokenizer,
    tasks: List[Dict],
    config: GradientVarianceConfig,
    device: str = "cuda",
) -> Dict[str, Dict[str, float]]:
    """
    Compare gradient variance across different curriculum strategies.

    Strategies:
    - uniform: Sample uniformly from all difficulties
    - fixed: Follow fixed curriculum schedule (linear walk through
      ``config.difficulty_levels`` in order)
    - acb: Adaptive curriculum via bandit (our method); here represented by
      a static sampling distribution that favours the middle difficulties.
      Any strategy name other than "uniform"/"fixed" takes this branch.

    For every strategy, ``config.num_training_steps`` difficulties are
    selected and ``config.batch_size`` gradients are computed per step; all
    of them are summarised together with ``compute_gradient_variance``.

    Args:
        model: Model whose gradients are measured; it is moved to ``device``.
        tokenizer: Tokenizer forwarded to the gradient computation.
        tasks: Task pool handed to the curriculum task generator.
        config: Experiment configuration.
        device: Device the model and inputs are placed on.

    Returns:
        Strategy name -> statistics dict, extended with "strategy" and
        "difficulty_distribution" (fraction of steps spent at each level).

    Raises:
        ValueError: If ``config.difficulty_levels`` is empty, or if the ACB
            branch is needed and the number of levels differs from the
            number of ACB weights.
    """
    from collections import Counter

    from ..curriculum.difficulty_levels import DifficultyLevel, CurriculumTaskGenerator

    levels = list(config.difficulty_levels)
    if not levels:
        raise ValueError("config.difficulty_levels must not be empty")

    # ACB: favor difficulties with good learning signal.
    # This tends to select moderate difficulties (2-4).
    acb_weights = [0.1, 0.25, 0.3, 0.25, 0.1]
    needs_acb = any(
        name not in ("uniform", "fixed") for name in config.curriculum_strategies
    )
    if needs_acb and len(acb_weights) != len(levels):
        # np.random.choice would raise the same error type, but only after
        # other strategies have already spent their compute.
        raise ValueError(
            f"ACB sampling weights are defined for {len(acb_weights)} difficulty "
            f"levels, but config.difficulty_levels has {len(levels)}"
        )

    model = model.to(device)
    curriculum_generator = CurriculumTaskGenerator(tasks)
    results = {}
    total_steps = config.num_training_steps

    for strategy in config.curriculum_strategies:
        logger.info(f"Analyzing strategy: {strategy}")

        # NOTE(review): every gradient is kept until the statistics are
        # computed, so memory grows with total_steps * batch_size * #params.
        gradients = []
        difficulties_sampled = []

        for step in range(total_steps):
            # Select difficulty based on strategy
            if strategy == "uniform":
                difficulty = int(np.random.choice(levels))
            elif strategy == "fixed":
                # Linear progression through the configured levels; for the
                # default [1..5] this equals min(5, 1 + int(progress * 5)).
                progress = step / total_steps
                position = min(len(levels) - 1, int(progress * len(levels)))
                difficulty = int(levels[position])
            else:  # acb
                difficulty = int(np.random.choice(levels, p=acb_weights))

            difficulties_sampled.append(difficulty)
            difficulty_level = DifficultyLevel(difficulty)

            # Collect batch of gradients
            for _ in range(config.batch_size):
                task_idx = np.random.randint(len(tasks))
                curriculum_task = curriculum_generator.generate_task(task_idx, difficulty_level)
                grad = _compute_real_gradient(model, tokenizer, curriculum_task, device)
                gradients.append(grad)

        # Compute overall variance statistics
        variance_stats = compute_gradient_variance(gradients)

        # Add strategy-specific metrics
        variance_stats["strategy"] = strategy
        counts = Counter(difficulties_sampled)
        num_sampled = len(difficulties_sampled)
        variance_stats["difficulty_distribution"] = {
            # Guard the zero-step case instead of dividing by zero.
            d: (counts[d] / num_sampled if num_sampled else 0.0)
            for d in levels
        }

        results[strategy] = variance_stats

    return results


def track_variance_over_training(
    model: torch.nn.Module,
    tokenizer,
    tasks: List[Dict],
    config: GradientVarianceConfig,
    device: str = "cuda",
) -> List[Dict[str, Any]]:
    """
    Track how gradient variance evolves during training.

    This shows whether ACB successfully stabilizes variance over time.

    At every simulated step, each strategy in
    ``config.curriculum_strategies`` selects one difficulty, one random task
    is generated at that difficulty and the L2 norm of its full gradient is
    recorded. Model parameters are not updated by this function.

    Args:
        model: Model whose gradients are measured; it is moved to ``device``.
        tokenizer: Tokenizer forwarded to the gradient computation.
        tasks: Task pool handed to the curriculum task generator.
        config: Experiment configuration.
        device: Device the model and inputs are placed on.

    Returns:
        One dict per step with "step" plus, for every strategy,
        "<strategy>_difficulty" (a plain ``int``) and "<strategy>_grad_norm"
        (a ``float``).

    Raises:
        ValueError: If ``config.difficulty_levels`` is empty, or if the ACB
            branch is needed and the number of levels differs from the
            number of ACB weights.
    """
    from ..curriculum.difficulty_levels import DifficultyLevel, CurriculumTaskGenerator

    levels = list(config.difficulty_levels)
    if not levels:
        raise ValueError("config.difficulty_levels must not be empty")

    # ACB adapts: starts easier, progresses based on success. Modelled here
    # as three fixed sampling distributions for the early / middle / late
    # part of the run.
    acb_early = [0.3, 0.3, 0.2, 0.1, 0.1]
    acb_middle = [0.1, 0.3, 0.3, 0.2, 0.1]
    acb_late = [0.1, 0.2, 0.3, 0.25, 0.15]
    needs_acb = any(
        name not in ("uniform", "fixed") for name in config.curriculum_strategies
    )
    if needs_acb and len(acb_early) != len(levels):
        # Fail before any gradient is computed rather than inside
        # np.random.choice part-way through the run.
        raise ValueError(
            f"ACB sampling weights are defined for {len(acb_early)} difficulty "
            f"levels, but config.difficulty_levels has {len(levels)}"
        )

    model = model.to(device)
    curriculum_generator = CurriculumTaskGenerator(tasks)
    trajectory = []
    total_steps = config.num_training_steps

    for step in range(total_steps):
        step_results = {"step": step}

        for strategy in config.curriculum_strategies:
            # Select difficulty. np.random.choice returns a NumPy integer;
            # int() keeps the recorded value JSON-serialisable as a number
            # (json's default=str fallback would otherwise write "3").
            if strategy == "uniform":
                difficulty = int(np.random.choice(levels))
            elif strategy == "fixed":
                # Linear progression through the configured levels; for the
                # default [1..5] this equals min(5, 1 + int(progress * 5)).
                progress = step / total_steps
                position = min(len(levels) - 1, int(progress * len(levels)))
                difficulty = int(levels[position])
            else:  # acb
                if step < 0.2 * total_steps:
                    weights = acb_early
                elif step < 0.5 * total_steps:
                    weights = acb_middle
                else:
                    weights = acb_late
                difficulty = int(np.random.choice(levels, p=weights))

            difficulty_level = DifficultyLevel(difficulty)
            task_idx = np.random.randint(len(tasks))
            curriculum_task = curriculum_generator.generate_task(task_idx, difficulty_level)

            # Compute real gradient and reduce it to a single L2 norm
            grad = _compute_real_gradient(model, tokenizer, curriculum_task, device)
            flat_grad = torch.cat([g.flatten() for g in grad.values()])
            norm = torch.norm(flat_grad).item()

            step_results[f"{strategy}_difficulty"] = difficulty
            step_results[f"{strategy}_grad_norm"] = norm

        trajectory.append(step_results)

    return trajectory


def run_gradient_variance_analysis(
    model: torch.nn.Module,
    tokenizer,
    tasks: List[Dict],
    config: GradientVarianceConfig,
    device: str = "cuda",
) -> Dict[str, Any]:
    """
    Execute all stages of the gradient variance experiment and persist them.

    Stages, in order: per-difficulty statistics, curriculum strategy
    comparison, per-step gradient norm trajectory, and a summary derived
    from the first two. The combined results are written as JSON to
    ``<config.output_dir>/gradient_variance_analysis.json`` (the directory
    is created if needed).

    Returns:
        Dictionary with the keys "config", "by_difficulty",
        "strategy_comparison", "variance_trajectory" and "summary"
    """
    logger.info("Starting gradient variance analysis experiment")

    target_dir = Path(config.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Every stage takes the same positional arguments.
    stage_args = (model, tokenizer, tasks, config, device)

    results = {
        "config": {
            "difficulty_levels": config.difficulty_levels,
            "samples_per_difficulty": config.samples_per_difficulty,
            "curriculum_strategies": config.curriculum_strategies,
            "num_training_steps": config.num_training_steps,
        }
    }

    # Stage 1: statistics per difficulty (keys stringified for JSON)
    logger.info("Analyzing gradient variance by difficulty...")
    per_level = analyze_gradient_by_difficulty(*stage_args)
    results["by_difficulty"] = {
        str(level): level_stats for level, level_stats in per_level.items()
    }

    # Stage 2: strategy comparison
    logger.info("Comparing curriculum strategies...")
    results["strategy_comparison"] = compare_curriculum_strategies(*stage_args)

    # Stage 3: trajectory over simulated training
    logger.info("Tracking variance over training...")
    results["variance_trajectory"] = track_variance_over_training(*stage_args)

    # Stage 4: summary built from the stages above
    logger.info("Computing summary statistics...")
    results["summary"] = _compute_summary(results)

    # Persist everything; values json cannot encode are written via str().
    results_path = target_dir / "gradient_variance_analysis.json"
    with results_path.open("w") as handle:
        json.dump(results, handle, indent=2, default=str)

    logger.info(f"Results saved to {results_path}")

    return results


def _compute_summary(results: Dict[str, Any]) -> Dict[str, Any]:
    """Compute summary statistics from the analysis."""
    summary = {}

    # Variance growth rate with difficulty
    by_diff = results["by_difficulty"]
    difficulties = sorted([int(k) for k in by_diff.keys()])
    variances = [by_diff[str(d)]["variance"] for d in difficulties]

    if len(variances) >= 2:
        # Fit exponential: var = a * exp(alpha * d)
        log_var = np.log(np.array(variances) + 1e-10)
        alpha, _ = np.polyfit(difficulties, log_var, 1)
        summary["variance_growth_rate"] = float(alpha)
    else:
        summary["variance_growth_rate"] = 0.0

    # Best strategy by variance
    strat_comp = results["strategy_comparison"]
    best_strategy = min(strat_comp.keys(), key=lambda s: strat_comp[s]["variance"])
    summary["best_strategy"] = best_strategy
    summary["best_strategy_variance"] = strat_comp[best_strategy]["variance"]

    # SNR by difficulty (should peak at moderate difficulties)
    snrs = [by_diff[str(d)]["snr"] for d in difficulties]
    peak_difficulty = difficulties[np.argmax(snrs)]
    summary["peak_snr_difficulty"] = peak_difficulty
    summary["peak_snr"] = max(snrs)

    return summary


def _main() -> None:
    """Command-line entry point.

    Parses ``--model`` and ``--output``, loads the model, tokenizer and the
    HumanEval test split, runs the full analysis with 50 samples per
    difficulty and 50 training steps, and prints the results.
    """
    import argparse
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import load_dataset

    cli = argparse.ArgumentParser(description="Gradient variance analysis")
    cli.add_argument(
        "--model",
        type=str,
        default="Salesforce/codegen-350M-mono",
        help="Model name",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="./experiments/gradient_variance",
        help="Output directory",
    )
    options = cli.parse_args()

    logging.basicConfig(level=logging.INFO)
    print("Running gradient variance analysis...")

    config = GradientVarianceConfig(
        samples_per_difficulty=50,
        num_training_steps=50,
        output_dir=options.output,
        model_name=options.model,
    )

    # Model and tokenizer; reuse EOS as padding token when none is defined.
    print(f"Loading model: {options.model}")
    tokenizer = AutoTokenizer.from_pretrained(options.model)
    model = AutoModelForCausalLM.from_pretrained(options.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Keep only the fields of each HumanEval record that are needed.
    print("Loading HumanEval tasks...")
    dataset = load_dataset("openai_humaneval", split="test")
    wanted_fields = ("prompt", "canonical_solution", "task_id")
    tasks = [{name: item[name] for name in wanted_fields} for item in dataset]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    results = run_gradient_variance_analysis(model, tokenizer, tasks, config, device)

    print("\n=== Gradient Variance Analysis Summary ===")

    print("\nVariance by Difficulty:")
    for d, stats in results["by_difficulty"].items():
        print(f"  Difficulty {d}: variance={stats['variance']:.4f}, SNR={stats['snr']:.2f}")

    print("\nStrategy Comparison:")
    for strategy, stats in results["strategy_comparison"].items():
        print(f"  {strategy}: variance={stats['variance']:.4f}, SNR={stats['snr']:.2f}")

    print("\nSummary:")
    for key, value in results["summary"].items():
        print(f"  {key}: {value}")

    print("\nExperiment complete!")


if __name__ == "__main__":
    _main()
