"""Results aggregation and summary statistics for optimization experiments.

This module provides utilities for aggregating results across multiple seeds
and computing bootstrap confidence intervals.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Sequence

import numpy as np
import pandas as pd

from moltenflow.utils.logging import get_logger

from .logger import load_experiment_logs

if TYPE_CHECKING:
    import optuna

logger = get_logger(__name__)


@dataclass
class BootstrapCI:
    """Point estimate together with a bootstrap confidence interval.

    Instances built by ``bootstrap_ci`` in this module carry the sample mean
    and standard deviation of the input values, and percentile bounds taken
    from the distribution of resampled means. All four statistics are NaN
    when the input was empty.

    Attributes:
        mean: Mean value
        std: Standard deviation
        ci_lower: Lower CI bound
        ci_upper: Upper CI bound
        confidence: Confidence level (e.g. 0.95 for a 95% interval)
        n_bootstrap: Number of bootstrap samples
    """

    mean: float
    std: float
    ci_lower: float
    ci_upper: float
    # Defaults mirror the defaults of the bootstrap helpers in this module.
    confidence: float = 0.95
    n_bootstrap: int = 1000


@dataclass
class MethodSummary:
    """Summary statistics for a single method.

    The three ``hv_curve_*`` arrays and ``steps`` are parallel: element ``k``
    of each describes the same optimization step.

    Attributes:
        method: Method name
        init: Initialization method
        n_seeds: Number of seeds
        final_hvi: Bootstrap CI for final HVI
        final_validity: Bootstrap CI for final validity
        hv_curve_mean: Mean HV at each step
        hv_curve_ci_lower: Lower CI for HV curve
        hv_curve_ci_upper: Upper CI for HV curve
        steps: Step index for each position of the HV curve arrays
    """

    method: str
    init: str
    n_seeds: int
    final_hvi: BootstrapCI
    final_validity: BootstrapCI
    hv_curve_mean: np.ndarray
    hv_curve_ci_lower: np.ndarray
    hv_curve_ci_upper: np.ndarray
    # x-axis for the curve arrays above
    steps: np.ndarray


def bootstrap_ci(
    values: np.ndarray,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    seed: int = 42,
) -> BootstrapCI:
    """Compute a percentile bootstrap confidence interval for the mean.

    The input is resampled with replacement ``n_bootstrap`` times; the CI
    bounds are the ``alpha / 2`` and ``1 - alpha / 2`` percentiles of the
    resampled means, where ``alpha = 1 - confidence``.

    Args:
        values: Values to bootstrap. Any array-like is accepted (ndarray,
            list, pandas Series); it is converted with ``np.asarray`` so that
            resampling is always positional.
        n_bootstrap: Number of bootstrap samples
        confidence: Confidence level
        seed: Random seed

    Returns:
        BootstrapCI with mean, std, and CI bounds. ``mean`` and ``std`` are
        computed on the original values (``np.std`` default, i.e. ddof=0).
        All statistics are NaN when ``values`` is empty.
    """
    # Coerce up front: indexing a list with an index array raises, and
    # indexing a pandas Series with one is label-based rather than positional.
    values = np.asarray(values)
    rng = np.random.default_rng(seed)
    n = len(values)

    if n == 0:
        return BootstrapCI(
            mean=np.nan,
            std=np.nan,
            ci_lower=np.nan,
            ci_upper=np.nan,
            confidence=confidence,
            n_bootstrap=n_bootstrap,
        )

    # One rng.choice call per resample, in order, so results for a given seed
    # stay identical to earlier runs.
    bootstrap_means = np.array(
        [
            np.mean(values[rng.choice(n, size=n, replace=True)])
            for _ in range(n_bootstrap)
        ]
    )

    alpha = 1 - confidence
    ci_lower = float(np.percentile(bootstrap_means, 100 * alpha / 2))
    ci_upper = float(np.percentile(bootstrap_means, 100 * (1 - alpha / 2)))

    return BootstrapCI(
        mean=float(np.mean(values)),
        std=float(np.std(values)),
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        confidence=confidence,
        n_bootstrap=n_bootstrap,
    )


def bootstrap_curve_ci(
    curves: np.ndarray,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute a pointwise percentile bootstrap CI for the mean of a set of curves.

    Whole curves (rows) are resampled with replacement, so the values of one
    seed always stay together across steps.

    Args:
        curves: Array-like of shape (n_seeds, n_steps)
        n_bootstrap: Number of bootstrap samples
        confidence: Confidence level
        seed: Random seed

    Returns:
        Tuple of (mean_curve, ci_lower, ci_upper), each of length n_steps.
        When there are no seeds, three independent all-NaN arrays are returned.
    """
    curves = np.asarray(curves)
    rng = np.random.default_rng(seed)
    n_seeds, n_steps = curves.shape

    if n_seeds == 0:
        # Three distinct arrays: returning one shared object would let an
        # in-place edit of one result silently change the other two.
        return (
            np.full(n_steps, np.nan),
            np.full(n_steps, np.nan),
            np.full(n_steps, np.nan),
        )

    # One rng.choice call per resample, in order, to keep seeded results stable.
    bootstrap_means = np.array(
        [
            np.mean(curves[rng.choice(n_seeds, size=n_seeds, replace=True)], axis=0)
            for _ in range(n_bootstrap)
        ]
    )

    alpha = 1 - confidence
    ci_lower = np.percentile(bootstrap_means, 100 * alpha / 2, axis=0)
    ci_upper = np.percentile(bootstrap_means, 100 * (1 - alpha / 2), axis=0)
    mean_curve = np.mean(curves, axis=0)

    return mean_curve, ci_lower, ci_upper


def aggregate_by_method(
    logs: dict[str, list[dict]],
) -> dict[str, pd.DataFrame]:
    """Group the records of all runs by optimization method.

    A run's method is read from the ``"method"`` key of its first record;
    runs without any records are ignored.

    Args:
        logs: Dictionary mapping run_id to list of records

    Returns:
        Dictionary mapping method name to a DataFrame holding the records of
        every run of that method, in the order the runs appear in ``logs``
    """
    grouped: dict[str, list[dict]] = {}

    for run_records in logs.values():
        if run_records:
            # The first record decides which method the whole run belongs to.
            grouped.setdefault(run_records[0]["method"], []).extend(run_records)

    return {name: pd.DataFrame(rows) for name, rows in grouped.items()}


def compute_method_summary(
    df: pd.DataFrame,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    seed: int = 42,
) -> MethodSummary:
    """Compute summary statistics for a method.

    Reads the columns ``method``, ``init``, ``seed``, ``step``, ``hv``,
    ``hvi`` and ``cumulative_validity``. The method and init labels are taken
    from the first row. The "final" HVI and validity of a seed are the values
    in that seed's last row (in DataFrame order).

    Args:
        df: DataFrame with records for one method
        n_bootstrap: Bootstrap samples
        confidence: Confidence level
        seed: Random seed

    Returns:
        MethodSummary with aggregated statistics. HV curves cover steps
        ``0 .. max(step)``; a step missing for a seed takes that seed's most
        recent earlier HV value (and stays NaN if there is none).
    """
    method = df["method"].iloc[0]
    init = df["init"].iloc[0]
    seeds = df["seed"].unique()
    n_seeds = len(seeds)

    # int(): the step column may be float-typed, and a float is neither a
    # valid array dimension for np.full nor a sensible np.arange bound.
    max_steps = int(df["step"].max()) + 1
    hv_curves = np.full((n_seeds, max_steps), np.nan)

    final_hvis = []
    final_validities = []

    for i, s in enumerate(seeds):
        seed_df = df[df["seed"] == s]
        # A NaN seed value never compares equal to itself, so its selection
        # is empty; such a seed contributes no final metrics.
        if len(seed_df) == 0:
            continue

        final_hvis.append(seed_df["hvi"].iloc[-1])
        final_validities.append(seed_df["cumulative_validity"].iloc[-1])

        # Write HV values in step order; iterating two columns avoids the
        # per-row Series construction of DataFrame.iterrows().
        ordered = seed_df.sort_values("step")
        for step, hv in zip(ordered["step"], ordered["hv"]):
            hv_curves[i, int(step)] = hv

    final_hvis = np.array(final_hvis)
    final_validities = np.array(final_validities)

    # Bootstrap CI for final metrics
    final_hvi_ci = bootstrap_ci(final_hvis, n_bootstrap, confidence, seed)
    final_validity_ci = bootstrap_ci(final_validities, n_bootstrap, confidence, seed)

    # Forward fill gaps along the step axis; leading NaNs have nothing to
    # copy from and are left as NaN.
    hv_curves = pd.DataFrame(hv_curves).ffill(axis=1).to_numpy()

    # Bootstrap CI for curves
    mean_curve, ci_lower, ci_upper = bootstrap_curve_ci(hv_curves, n_bootstrap, confidence, seed)

    return MethodSummary(
        method=method,
        init=init,
        n_seeds=n_seeds,
        final_hvi=final_hvi_ci,
        final_validity=final_validity_ci,
        hv_curve_mean=mean_curve,
        hv_curve_ci_lower=ci_lower,
        hv_curve_ci_upper=ci_upper,
        steps=np.arange(max_steps),
    )


def load_runtime_data(log_dir: str | Path) -> dict[str, list[float]]:
    """Load runtime data from timing.json files in run directories.

    Every immediate subdirectory of ``log_dir`` that contains a
    ``timing.json`` is read. The file is expected to hold a JSON object with
    a ``"method"`` key (default ``"unknown"``) and a ``"runtime_seconds"``
    key (default ``0.0``). Files that cannot be read or parsed, or whose
    content is unusable, are skipped with a warning.

    Args:
        log_dir: Directory containing run subdirectories

    Returns:
        Dictionary mapping method names to lists of runtime values (in
        seconds), ordered by run directory name
    """
    import json

    log_dir = Path(log_dir)
    runtime_data: dict[str, list[float]] = {}

    # sorted(): iterdir() yields entries in filesystem-dependent order, which
    # would make the order of each runtime list (and any seeded resampling of
    # it) differ between machines.
    for subdir in sorted(log_dir.iterdir()):
        if not subdir.is_dir():
            continue

        timing_path = subdir / "timing.json"
        if not timing_path.exists():
            continue

        try:
            with open(timing_path, encoding="utf-8") as f:
                timing = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Failed to load timing from {timing_path}: {e}")
            continue

        if not isinstance(timing, dict):
            logger.warning(
                f"Failed to load timing from {timing_path}: "
                f"expected a JSON object, got {type(timing).__name__}"
            )
            continue

        method = timing.get("method", "unknown")
        runtime = timing.get("runtime_seconds", 0.0)

        # Reject null/strings/booleans so downstream statistics only ever see
        # real numbers.
        if isinstance(runtime, bool) or not isinstance(runtime, (int, float)):
            logger.warning(
                f"Failed to load timing from {timing_path}: "
                f"non-numeric runtime_seconds {runtime!r}"
            )
            continue

        runtime_data.setdefault(method, []).append(float(runtime))

    return runtime_data


def generate_summary_table(
    log_dir: str | Path,
    methods: Sequence[str] | None = None,
    n_bootstrap: int = 1000,
    confidence: float = 0.95,
    seed: int = 42,
    include_runtime: bool = True,
) -> pd.DataFrame:
    """Build a one-row-per-method comparison table from experiment logs.

    Args:
        log_dir: Directory containing optimization logs
        methods: Methods to include (default: all)
        n_bootstrap: Bootstrap samples
        confidence: Confidence level
        seed: Random seed
        include_runtime: Whether to include runtime statistics

    Returns:
        DataFrame with summary statistics per method, sorted by mean final
        HVI in descending order. The runtime columns are always present and
        hold NaN for methods without runtime data. An empty DataFrame is
        returned when no logs are found.
    """
    logs = load_experiment_logs(log_dir, methods=methods)
    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return pd.DataFrame()

    per_method = aggregate_by_method(logs)
    runtimes_by_method = load_runtime_data(log_dir) if include_runtime else {}

    runtime_columns = (
        "runtime_mean_seconds",
        "runtime_std_seconds",
        "runtime_ci_lower_seconds",
        "runtime_ci_upper_seconds",
    )

    table_rows = []
    for name, method_df in per_method.items():
        stats = compute_method_summary(method_df, n_bootstrap, confidence, seed)
        hvi = stats.final_hvi
        validity = stats.final_validity

        record = {
            "method": stats.method,
            "init": stats.init,
            "n_seeds": stats.n_seeds,
            "final_hvi_mean": hvi.mean,
            "final_hvi_std": hvi.std,
            "final_hvi_ci_lower": hvi.ci_lower,
            "final_hvi_ci_upper": hvi.ci_upper,
            "final_validity_mean": validity.mean,
            "final_validity_ci_lower": validity.ci_lower,
            "final_validity_ci_upper": validity.ci_upper,
        }

        # Runtime statistics when this method has samples, NaN placeholders
        # otherwise, so every row exposes the same columns.
        samples = runtimes_by_method.get(name)
        if samples:
            runtime_ci = bootstrap_ci(np.array(samples), n_bootstrap, confidence, seed)
            runtime_stats = (
                runtime_ci.mean,
                runtime_ci.std,
                runtime_ci.ci_lower,
                runtime_ci.ci_upper,
            )
        else:
            runtime_stats = (np.nan,) * len(runtime_columns)
        record.update(zip(runtime_columns, runtime_stats))

        table_rows.append(record)

    table = pd.DataFrame(table_rows)
    if len(table) == 0:
        return table

    # Best method (highest mean HVI) first.
    return table.sort_values("final_hvi_mean", ascending=False)


def format_summary_table(df: pd.DataFrame, confidence: float = 0.95) -> str:
    """Format summary table as string for display.

    A "Runtime (s)" column is shown when the table has at least one non-NaN
    ``runtime_mean_seconds`` value; in that layout the Init column is 8
    characters wide and the rules are 100 wide, otherwise 12 and 80. Rows
    without a runtime in a runtime table show ``N/A`` in that column so that
    every row lines up with the header.

    Args:
        df: Summary DataFrame
        confidence: Confidence level used

    Returns:
        Formatted table string
    """
    if len(df) == 0:
        return "No results to display."

    # Check if runtime data is available
    has_runtime = (
        "runtime_mean_seconds" in df.columns and not df["runtime_mean_seconds"].isna().all()
    )

    # A single layout per table: header and every row share these widths.
    init_width = 8 if has_runtime else 12
    rule_width = 100 if has_runtime else 80

    header = (
        f"{'Method':<15} {'Init':<{init_width}} {'N':>3} "
        f"{'HVI Mean':>10} {'HVI CI':>20} {'Validity':>10}"
    )
    if has_runtime:
        header += f" {'Runtime (s)':>12}"

    lines = [
        f"Method Comparison ({confidence * 100:.0f}% CI)",
        "=" * rule_width,
        header,
        "-" * rule_width,
    ]

    for _, row in df.iterrows():
        hvi_ci = f"[{row['final_hvi_ci_lower']:.4f}, {row['final_hvi_ci_upper']:.4f}]"

        line = (
            f"{row['method']:<15} {row['init']:<{init_width}} {row['n_seeds']:>3} "
            f"{row['final_hvi_mean']:>10.4f} {hvi_ci:>20} "
            f"{row['final_validity_mean']:>10.2%}"
        )

        if has_runtime:
            runtime = row["runtime_mean_seconds"]
            # pd.isna also copes with None, which np.isnan would reject.
            runtime_str = "N/A" if pd.isna(runtime) else f"{runtime:.1f}"
            line += f" {runtime_str:>12}"

        lines.append(line)

    lines.append("=" * rule_width)

    return "\n".join(lines)


def compute_pairwise_tests(
    log_dir: str | Path,
    methods: Sequence[str] | None = None,
    metric: str = "hvi",
) -> dict[str, dict]:
    """Compute pairwise statistical tests between methods.

    Uses Mann-Whitney U test (non-parametric, two-sided) to compare final
    metrics across seeds for different methods. The final value of a seed is
    the metric in that seed's last record (in log order).

    Args:
        log_dir: Directory containing optimization logs
        methods: Methods to compare (default: all)
        metric: Metric to compare ("hvi" or "validity")

    Returns:
        Dictionary mapping ``"<method_a>_vs_<method_b>"`` to a dict with the
        test statistic, p-value, a ``significant`` flag (p < 0.05), and the
        mean and sample size of each side. Empty when scipy is missing, no
        logs are found, or fewer than two methods are available. Pairs in
        which either method has no values are skipped.

    Raises:
        ValueError: If ``metric`` is not "hvi" or "validity".
    """
    # Fail fast on an unknown metric: otherwise every method would silently
    # end up with an empty sample.
    metric_columns = {"hvi": "hvi", "validity": "cumulative_validity"}
    if metric not in metric_columns:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(metric_columns)}"
        )
    column = metric_columns[metric]

    try:
        from scipy.stats import mannwhitneyu
    except ImportError:
        logger.error("scipy required for statistical tests. Install with: pip install scipy")
        return {}

    # Load logs
    logs = load_experiment_logs(log_dir, methods=methods)

    if not logs:
        logger.warning(f"No logs found in {log_dir}")
        return {}

    # Aggregate by method
    method_dfs = aggregate_by_method(logs)

    if len(method_dfs) < 2:
        logger.warning("Need at least 2 methods for comparison")
        return {}

    # Extract final metric values per seed for each method
    method_values = {}
    for method, df in method_dfs.items():
        values = []
        for s in df["seed"].unique():
            seed_values = df.loc[df["seed"] == s, column]
            # Empty for a NaN seed, which never compares equal to itself.
            if len(seed_values) > 0:
                values.append(seed_values.iloc[-1])

        method_values[method] = np.array(values)

    # Perform pairwise tests on every unordered pair of methods
    results = {}
    method_list = list(method_values)

    for i, method_a in enumerate(method_list):
        for method_b in method_list[i + 1 :]:
            values_a = method_values[method_a]
            values_b = method_values[method_b]

            if len(values_a) == 0 or len(values_b) == 0:
                logger.warning(
                    f"Skipping {method_a} vs {method_b}: no values for at least one method"
                )
                continue

            # Mann-Whitney U test
            statistic, p_value = mannwhitneyu(values_a, values_b, alternative="two-sided")

            comparison_key = f"{method_a}_vs_{method_b}"
            results[comparison_key] = {
                "method_a": method_a,
                "method_b": method_b,
                "statistic": float(statistic),
                "p_value": float(p_value),
                # bool(): a numpy bool is not JSON-serializable.
                "significant": bool(p_value < 0.05),
                "mean_a": float(np.mean(values_a)),
                "mean_b": float(np.mean(values_b)),
                "n_a": len(values_a),
                "n_b": len(values_b),
            }

    return results


# =============================================================================
# Hyperparameter Optimization Summary Functions
# =============================================================================


def generate_hpo_summary(study: "optuna.Study", method: str) -> dict:
    """Generate summary statistics from Optuna study.

    Only trials whose ``value`` is not ``None`` are considered. The best
    trial is the one with the largest value.

    Args:
        study: Optuna study object
        method: Method name

    Returns:
        Dictionary with summary statistics. The same keys are present whether
        or not the study has usable trials; without any, ``n_trials`` is 0,
        ``best_params`` is empty and every statistic is ``None``.
    """
    trials = [t for t in study.trials if t.value is not None]

    if not trials:
        # Same key set as the populated summary so consumers can index the
        # result without checking n_trials first.
        return {
            "method": method,
            "n_trials": 0,
            "best_hvi": None,
            "best_params": {},
            "mean_hvi": None,
            "std_hvi": None,
            "median_hvi": None,
            "min_hvi": None,
            "max_hvi": None,
            "best_trial_number": None,
        }

    values = [t.value for t in trials]
    best_trial = max(trials, key=lambda t: t.value)

    return {
        "method": method,
        "n_trials": len(trials),
        "best_hvi": best_trial.value,
        "best_params": best_trial.params,
        "mean_hvi": float(np.mean(values)),
        "std_hvi": float(np.std(values)),
        "median_hvi": float(np.median(values)),
        "min_hvi": float(np.min(values)),
        "max_hvi": float(np.max(values)),
        "best_trial_number": best_trial.number,
    }


def format_hpo_table(summary: dict) -> str:
    """Format HPO summary as string for display.

    Statistics that are missing or not numeric are rendered as ``N/A``
    instead of raising a formatting error.

    Args:
        summary: HPO summary dictionary

    Returns:
        Formatted table string
    """

    def _fmt(key: str) -> str:
        """Render summary[key] with four decimals, or 'N/A' if unusable."""
        try:
            return f"{summary.get(key):.4f}"
        except (TypeError, ValueError):
            # None raises TypeError and a string raises ValueError under the
            # 'f' format code.
            return "N/A"

    method = summary.get("method", "Unknown")
    n_trials = summary.get("n_trials", 0)

    lines = [
        f"HPO Summary - {method}",
        "=" * 60,
        f"Total trials: {n_trials}",
    ]

    if n_trials > 0:
        best_trial_number = summary.get("best_trial_number")
        if best_trial_number is None:
            best_trial_number = "N/A"

        lines.extend(
            [
                f"Best HVI: {_fmt('best_hvi')}",
                f"Mean HVI: {_fmt('mean_hvi')} +/- {_fmt('std_hvi')}",
                f"Best trial: #{best_trial_number}",
                "",
                "Best Parameters:",
                "-" * 40,
            ]
        )

        # "or {}" also covers an explicit None stored under best_params.
        best_params = summary.get("best_params") or {}
        for name, value in best_params.items():
            if isinstance(value, float):
                lines.append(f"  {name}: {value:.6g}")
            else:
                lines.append(f"  {name}: {value}")

    lines.append("=" * 60)

    return "\n".join(lines)


def get_top_n_trials(study: "optuna.Study", n: int = 5) -> list[dict]:
    """Return the highest-valued trials of a study.

    Trials whose ``value`` is ``None`` are ignored. The remaining trials are
    ranked by value in descending order (maximization); trials with equal
    values keep their original relative order.

    Args:
        study: Optuna study object
        n: Number of top trials to return

    Returns:
        List of dictionaries with keys ``trial_number``, ``value`` and
        ``params``, best trial first; empty if no trial has a value
    """
    scored = [candidate for candidate in study.trials if candidate.value is not None]
    ranked = sorted(scored, key=lambda candidate: candidate.value, reverse=True)

    return [
        {
            "trial_number": entry.number,
            "value": entry.value,
            "params": entry.params,
        }
        for entry in ranked[:n]
    ]
