"""
Hessian Eigenvalue Analysis for Validating Theorem 1.

This experiment empirically measures the maximum Hessian eigenvalue λ_max(H_d)
at each difficulty level to validate the theoretical claim that:

    λ_max(H_d) = O(e^(αd))

We use power iteration to approximate the top eigenvalue efficiently, as
computing the full Hessian is computationally prohibitive for neural networks.

This addresses reviewer concern Q3/W1: "No empirical measurement of Hessian
eigenvalues to validate exponential growth claim."
"""

import json
import logging
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from scipy import stats
import time

logger = logging.getLogger(__name__)


@dataclass
class HessianAnalysisConfig:
    """Settings for the λ_max-versus-difficulty experiment.

    Field order is part of the positional constructor signature; keep it stable.
    """

    # Difficulty levels d at which λ_max(H_d) is probed.
    difficulty_levels: List[int] = field(default_factory=lambda: list(range(1, 6)))

    # Power-method budget and sampling.
    # NOTE(review): num_power_iterations is not forwarded by run_hessian_analysis
    # in this module; the per-difficulty routine uses its own iteration count.
    num_power_iterations: int = 50  # intended maximum number of power-method steps
    num_samples_per_difficulty: int = 100  # examples the data loader yields per estimate
    num_independent_estimates: int = 10  # repeated estimates per level, used for error bars

    # Model selection (may be overridden from a trainer config).
    model_name: str = "Salesforce/codegen-350M-mono"
    # NOTE(review): use_lora is not read anywhere in this module.
    use_lora: bool = True

    # Where results are written, and whether to checkpoint after every level.
    output_dir: str = "experiments/hessian_analysis"
    save_intermediate: bool = True


@dataclass
class HessianAnalysisResult:
    """λ_max estimates for a single difficulty level plus their summary statistics."""

    difficulty: int
    lambda_max_estimates: List[float]
    lambda_max_mean: float = 0.0
    lambda_max_std: float = 0.0  # population std (numpy default ddof=0)
    log_lambda_max_mean: float = 0.0  # log of the mean; input to the exponential fit
    power_iteration_converged: bool = True  # False if any estimate failed to converge

    def aggregate(self):
        """Fill in mean, std and log-mean from ``lambda_max_estimates``.

        Leaves the statistics untouched when there are no estimates.
        """
        samples = self.lambda_max_estimates
        if not samples:
            return
        centre = float(np.mean(samples))
        self.lambda_max_mean = centre
        self.lambda_max_std = float(np.std(samples))
        # log of the mean (not mean of logs); -inf (with a numpy warning) if the mean is 0.
        self.log_lambda_max_mean = float(np.log(centre))


def hvp_finite_diff(
    model: torch.nn.Module,
    loss_fn,
    inputs: Dict[str, torch.Tensor],
    vector: torch.Tensor,
    epsilon: float = 1e-4,
) -> torch.Tensor:
    """
    Approximate a Hessian-vector product with central finite differences.

        H @ v ≈ (∇L(θ + εv) - ∇L(θ - εv)) / (2ε)

    Only first-order gradients are needed (no double backward). The cost is two
    forward/backward passes. The result carries an O(ε²) truncation error plus
    floating-point cancellation error, which grows as ε shrinks.

    The trainable parameters are perturbed in place. They are always restored to
    their exact original values, including when ``loss_fn`` or the backward pass
    raises, so a failure never leaves the model in a perturbed state.

    Args:
        model: Model whose ``requires_grad`` parameters define θ.
        loss_fn: Callable ``loss_fn(model, inputs)`` returning a scalar loss.
        inputs: Batch forwarded unchanged to ``loss_fn``.
        vector: Flat direction v with one entry per trainable parameter element,
            ordered like ``model.parameters()``.
        epsilon: Finite-difference step size.

    Returns:
        Flat tensor approximating H @ vector.
    """
    params = [p for p in model.parameters() if p.requires_grad]

    # Exact copies so restoration does not accumulate rounding error.
    original_params = [p.data.clone() for p in params]
    vector_list = _reshape_flat_to_params(vector, params)

    def _grad_at(step: float) -> torch.Tensor:
        """Flat gradient of the loss at θ + step * v."""
        for p, orig, v in zip(params, original_params, vector_list):
            p.data.copy_(orig)
            p.data.add_(v, alpha=step)
        loss = loss_fn(model, inputs)
        grads = torch.autograd.grad(loss, params, create_graph=False)
        return _flatten_params(grads)

    try:
        grad_plus_flat = _grad_at(epsilon)
        grad_minus_flat = _grad_at(-epsilon)
    finally:
        # Restore θ even if a forward/backward pass raised.
        for p, orig in zip(params, original_params):
            p.data.copy_(orig)

    return (grad_plus_flat - grad_minus_flat) / (2 * epsilon)


def _flatten_params(params) -> torch.Tensor:
    """Flatten list of parameter tensors to single vector."""
    return torch.cat([p.flatten() for p in params])


def _reshape_flat_to_params(
    flat_vector: torch.Tensor,
    reference_params: List[torch.Tensor],
) -> List[torch.Tensor]:
    """Reshape flat vector back to list matching parameter shapes."""
    result = []
    offset = 0
    for p in reference_params:
        numel = p.numel()
        result.append(flat_vector[offset:offset + numel].reshape(p.shape))
        offset += numel
    return result


def power_iteration(
    model: torch.nn.Module,
    loss_fn,
    data_loader,
    num_iterations: int = 50,
    tolerance: float = 1e-6,
    device: str = "cuda",
    max_batches: int = 5,
) -> Tuple[float, torch.Tensor, bool]:
    """
    Estimate the dominant Hessian eigenvalue using power iteration.

    The power method iteratively applies

        v_{k+1} = H @ v_k / ||H @ v_k||

    and reports the Rayleigh quotient v_k^T H v_k. H is the Hessian of the loss
    averaged over a *fixed* set of batches. The iteration converges toward the
    eigenvalue of largest magnitude.

    The first ``max_batches`` batches are drawn from ``data_loader`` once and
    reused in every iteration, because power iteration only converges when the
    operator stays the same between steps. Re-iterating a generator each step
    instead applies a different H every time and exhausts the loader after a
    few iterations.

    Args:
        model: The neural network model
        loss_fn: Loss function that takes (model, inputs) and returns scalar loss
        data_loader: Iterable yielding dicts of tensors
        num_iterations: Maximum power iterations
        tolerance: Relative convergence tolerance on the eigenvalue estimate
        device: Computing device
        max_batches: Number of batches averaged to form the HVP operator

    Returns:
        (|lambda_max|, eigenvector, converged). If the loader yields no
        batches or ``num_iterations`` is 0, returns (0.0, random unit vector,
        False).
    """
    import itertools

    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    total_params = sum(p.numel() for p in params)

    # Materialize the batches once so every iteration applies the same Hessian.
    batches = [
        {key: tensor.to(device) for key, tensor in batch.items()}
        for batch in itertools.islice(data_loader, max_batches)
    ]

    # Initialize random unit vector
    v = torch.randn(total_params, device=device)
    v = v / v.norm()

    if not batches:
        logger.warning("power_iteration: data_loader yielded no batches")
        return 0.0, v, False

    lambda_prev = 0.0
    lambda_curr = 0.0  # defined even if the loop body never runs
    converged = False

    for iteration in range(num_iterations):
        # Average HVP over the fixed batches for stability
        hvp_sum = torch.zeros(total_params, device=device)
        for inputs in batches:
            hvp_sum += hvp_finite_diff(model, loss_fn, inputs, v)
        hvp_avg = hvp_sum / len(batches)

        # Rayleigh quotient: λ ≈ v^T @ H @ v (v has unit norm)
        lambda_curr = float(torch.dot(v, hvp_avg))

        # Normalize to get new eigenvector estimate
        hvp_norm = hvp_avg.norm()
        if hvp_norm > 1e-10:
            v = hvp_avg / hvp_norm
        else:
            break

        # Check relative convergence of the eigenvalue estimate
        if abs(lambda_curr - lambda_prev) < tolerance * abs(lambda_curr):
            converged = True
            logger.debug(f"Power iteration converged at iteration {iteration}")
            break

        lambda_prev = lambda_curr

    return abs(lambda_curr), v, converged


def compute_hessian_eigenvalue_for_difficulty(
    model: torch.nn.Module,
    tokenizer,
    tasks: List[Dict],
    difficulty: int,
    num_samples: int,
    num_estimates: int,
    device: str = "cuda",
) -> HessianAnalysisResult:
    """
    Compute Hessian eigenvalue estimates for a specific difficulty level.

    Args:
        model: Language model
        tokenizer: Tokenizer
        tasks: List of task dictionaries
        difficulty: Difficulty level (1-5)
        num_samples: Samples per estimate
        num_estimates: Number of independent estimates
        device: Computing device

    Returns:
        HessianAnalysisResult with eigenvalue statistics
    """
    logger.info(f"Computing Hessian eigenvalue for difficulty {difficulty}")

    # Filter tasks by difficulty (or create curriculum versions)
    difficulty_tasks = _prepare_tasks_for_difficulty(tasks, difficulty)

    result = HessianAnalysisResult(
        difficulty=difficulty,
        lambda_max_estimates=[],
    )

    def loss_fn(model, inputs):
        """Compute policy gradient loss for HVP."""
        outputs = model(**inputs, labels=inputs["input_ids"])
        return outputs.loss

    for estimate_idx in range(num_estimates):
        logger.debug(f"  Estimate {estimate_idx + 1}/{num_estimates}")

        # Create data loader for this estimate
        data_loader = _create_data_loader(
            difficulty_tasks, tokenizer, num_samples, device
        )

        # Run power iteration
        lambda_max, _, converged = power_iteration(
            model, loss_fn, data_loader,
            num_iterations=50,
            device=device,
        )

        result.lambda_max_estimates.append(lambda_max)
        if not converged:
            result.power_iteration_converged = False

    result.aggregate()
    logger.info(f"  Difficulty {difficulty}: λ_max = {result.lambda_max_mean:.4f} ± {result.lambda_max_std:.4f}")

    return result


def _prepare_tasks_for_difficulty(
    tasks: List[Dict],
    difficulty: int,
) -> List[Dict]:
    """
    Prepare tasks at a specific difficulty level.

    For curriculum learning, this reveals a fraction of the solution:
    - Difficulty 1: 90% revealed, complete last 10%
    - Difficulty 5: 0% revealed, generate from scratch
    """
    reveal_fractions = {1: 0.9, 2: 0.7, 3: 0.5, 4: 0.3, 5: 0.0}
    reveal = reveal_fractions.get(difficulty, 0.5)

    difficulty_tasks = []
    for task in tasks:
        solution = task.get("canonical_solution", task.get("solution", ""))
        if not solution:
            continue

        # Reveal prefix
        reveal_chars = int(len(solution) * reveal)
        revealed = solution[:reveal_chars]
        to_complete = solution[reveal_chars:]

        difficulty_tasks.append({
            "prompt": task.get("prompt", "") + revealed,
            "completion": to_complete,
            "full_solution": solution,
            "difficulty": difficulty,
        })

    return difficulty_tasks


def _create_data_loader(
    tasks: List[Dict],
    tokenizer,
    num_samples: int,
    device: str,
):
    """Create a simple data loader yielding tokenized batches."""
    import random

    for _ in range(num_samples):
        task = random.choice(tasks)
        text = task["prompt"] + task["completion"]

        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding="max_length",
        )

        yield inputs


def fit_exponential_growth(
    results: Dict[int, "HessianAnalysisResult"],
) -> Dict[str, float]:
    """
    Fit exponential model: log(λ_max) = α * d + const

    This validates Theorem 1's prediction that λ_max(H_d) = O(e^(αd)). It
    draws an ordinary least-squares line through (d, log λ_max) for every
    difficulty whose mean λ_max is positive (the log is undefined otherwise).

    Args:
        results: Mapping difficulty -> result object. Only ``lambda_max_mean``
            and ``log_lambda_max_mean`` are read.

    Returns:
        Dictionary with:
        - ``alpha``: slope
        - ``intercept``
        - ``r_squared``
        - ``p_value``
        - ``alpha_std_error``
        - ``alpha_95_ci``: a Student-t interval with n - 2 degrees of freedom.
          With only a handful of difficulty levels, the normal 1.96 quantile
          would make it far too narrow.
        - ``theoretical_lambda``: α / 2

        With fewer than two usable levels a line is undefined, so every value
        is NaN instead of raising.
    """
    difficulties = []
    log_lambdas = []

    for d, result in sorted(results.items()):
        if result.lambda_max_mean > 0:
            difficulties.append(d)
            log_lambdas.append(result.log_lambda_max_mean)

    if len(difficulties) < 2:
        nan = float("nan")
        return {
            "alpha": nan,
            "intercept": nan,
            "r_squared": nan,
            "p_value": nan,
            "alpha_std_error": nan,
            "alpha_95_ci": [nan, nan],
            "theoretical_lambda": nan,
        }

    difficulties = np.array(difficulties)
    log_lambdas = np.array(log_lambdas)

    # Linear regression: log(λ) = α * d + c
    slope, intercept, r_value, p_value, std_err = stats.linregress(
        difficulties, log_lambdas
    )

    # Predicted values for R² calculation
    predicted = slope * difficulties + intercept
    ss_res = np.sum((log_lambdas - predicted) ** 2)
    ss_tot = np.sum((log_lambdas - np.mean(log_lambdas)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

    # Two-sided 95% Student-t critical value; NaN when n == 2 (no residual dof).
    t_crit = stats.t.ppf(0.975, len(difficulties) - 2)

    return {
        "alpha": float(slope),  # Hessian growth rate
        "intercept": float(intercept),
        "r_squared": float(r_squared),
        "p_value": float(p_value),
        "alpha_std_error": float(std_err),
        "alpha_95_ci": [float(slope - t_crit * std_err), float(slope + t_crit * std_err)],
        "theoretical_lambda": float(slope / 2),  # λ = α/2 from Theorem 2
    }


def run_hessian_analysis(
    config: HessianAnalysisConfig,
    model=None,
    tokenizer=None,
    tasks=None,
) -> Dict[str, Any]:
    """
    Run the full Hessian eigenvalue analysis experiment.

    Only the missing pieces are loaded. A caller-supplied model or tokenizer is
    never replaced; a missing one is loaded from ``config.model_name``.

    Args:
        config: Experiment configuration
        model: Optional pre-loaded model
        tokenizer: Optional pre-loaded tokenizer
        tasks: Optional task list (defaults to HumanEval)

    Returns:
        Dictionary containing results and exponential fit. It is also written
        to ``<output_dir>/hessian_eigenvalue_results.json``.
    """
    logger.info("Starting Hessian eigenvalue analysis")
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load only what the caller did not provide
    if model is None or tokenizer is None:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        if tokenizer is None:
            logger.info(f"Loading tokenizer: {config.model_name}")
            tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if model is None:
            logger.info(f"Loading model: {config.model_name}")
            model = AutoModelForCausalLM.from_pretrained(config.model_name)

    # Examples are padded to a fixed length, which fails for tokenizers without
    # a pad token. This applies to caller-supplied tokenizers as well.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = model.to(device)

    # Load tasks if not provided
    if tasks is None:
        tasks = _load_default_tasks()

    # Run analysis for each difficulty
    results = {}
    for difficulty in config.difficulty_levels:
        result = compute_hessian_eigenvalue_for_difficulty(
            model=model,
            tokenizer=tokenizer,
            tasks=tasks,
            difficulty=difficulty,
            num_samples=config.num_samples_per_difficulty,
            num_estimates=config.num_independent_estimates,
            device=device,
        )
        results[difficulty] = result

        # Checkpoint after every level so partial runs are not lost
        if config.save_intermediate:
            _save_intermediate(results, output_dir / "intermediate_results.json")

    # Fit exponential model
    fit_results = fit_exponential_growth(results)

    # Compile final results
    final_results = {
        "config": {
            "difficulty_levels": config.difficulty_levels,
            "num_samples_per_difficulty": config.num_samples_per_difficulty,
            "num_independent_estimates": config.num_independent_estimates,
            "model_name": config.model_name,
        },
        "per_difficulty": {
            d: {
                "lambda_max_mean": r.lambda_max_mean,
                "lambda_max_std": r.lambda_max_std,
                "log_lambda_max_mean": r.log_lambda_max_mean,
                "all_estimates": r.lambda_max_estimates,
                "converged": r.power_iteration_converged,
            }
            for d, r in results.items()
        },
        "exponential_fit": fit_results,
        "validation": {
            "exponential_hypothesis_confirmed": fit_results["r_squared"] > 0.9,
            "alpha_estimate": fit_results["alpha"],
            "theoretical_lambda": fit_results["theoretical_lambda"],
        },
    }

    # Save final results (json turns the int difficulty keys into strings)
    results_path = output_dir / "hessian_eigenvalue_results.json"
    with open(results_path, "w") as f:
        json.dump(final_results, f, indent=2)

    logger.info(f"Results saved to {results_path}")
    logger.info(f"Exponential fit: α = {fit_results['alpha']:.3f} (R² = {fit_results['r_squared']:.3f})")

    return final_results


def _save_intermediate(results: Dict, path: Path):
    """Save intermediate results."""
    data = {
        d: {
            "lambda_max_mean": r.lambda_max_mean,
            "lambda_max_std": r.lambda_max_std,
            "all_estimates": r.lambda_max_estimates,
        }
        for d, r in results.items()
    }
    with open(path, "w") as f:
        json.dump(data, f, indent=2)


def _load_default_tasks() -> List[Dict]:
    """Fetch the HumanEval test split via `datasets` and keep only the fields this module uses.

    Each returned dict has ``prompt``, ``canonical_solution`` and ``task_id``.
    """
    from datasets import load_dataset

    wanted_fields = ("prompt", "canonical_solution", "task_id")
    humaneval = load_dataset("openai_humaneval", split="test")
    return [{name: row[name] for name in wanted_fields} for row in humaneval]


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Smaller sample budget than the dataclass defaults (100 samples,
    # 10 estimates), presumably to keep a standalone run short.
    config = HessianAnalysisConfig(
        num_independent_estimates=5,
        num_samples_per_difficulty=20,
    )

    results = run_hessian_analysis(config)

    print("\nHessian Eigenvalue Analysis Results")
    print("=" * 50)

    print("\nPer-difficulty λ_max:")
    per_level = results["per_difficulty"]
    for level in config.difficulty_levels:
        entry = per_level[level]
        mean_text = f"{entry['lambda_max_mean']:.4f}"
        std_text = f"{entry['lambda_max_std']:.4f}"
        print(f"  d={level}: λ_max = {mean_text} ± {std_text}")

    fit = results["exponential_fit"]
    alpha_text = f"{fit['alpha']:.3f}"
    print(f"\nExponential Fit: log(λ_max) = {alpha_text} * d + {fit['intercept']:.3f}")
    print(f"  R² = {fit['r_squared']:.4f}")
    print(f"  α (Hessian growth rate) = {alpha_text} ± {fit['alpha_std_error']:.3f}")

    verdict = "Yes" if results["validation"]["exponential_hypothesis_confirmed"] else "No"
    print(f"\nExponential hypothesis validated: {verdict} (R² > 0.9)")
