"""
Statistical Analysis Utilities for ACEAS Experiments.

This module provides statistical methods for:
1. Bootstrap confidence intervals for metric estimates
2. Welch's t-test and paired t-test for comparing methods
3. Effect size calculations (Cohen's d)
4. Multiple comparison corrections (Bonferroni, Holm)

These utilities address reviewer concerns about statistical methodology (W5)
by providing rigorous significance testing with proper test specification.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from scipy import stats


@dataclass
class ConfidenceInterval:
    """Point estimate together with the bounds of a confidence interval."""
    # Central estimate the interval is reported for.
    point_estimate: float
    # Lower bound of the interval.
    lower: float
    # Upper bound of the interval.
    upper: float
    # Nominal coverage expressed as a fraction (e.g. 0.95 for a 95% interval).
    confidence_level: float

    def __str__(self) -> str:
        """Format as ``estimate [lower, upper] (NN% CI)`` with 3 decimals."""
        # Use ``:g`` rather than ``int()``: truncation turns levels such as
        # 0.58 (0.58 * 100 == 57.99999999999999) into "57%" and drops the
        # fractional part of levels such as 0.995.
        level_pct = f"{self.confidence_level * 100:g}"
        return (
            f"{self.point_estimate:.3f} "
            f"[{self.lower:.3f}, {self.upper:.3f}] "
            f"({level_pct}% CI)"
        )


@dataclass
class HypothesisTestResult:
    """Outcome of a single hypothesis test."""
    # Test statistic as reported by the underlying test.
    statistic: float
    # p-value of the test.
    p_value: float
    # Standardised effect size accompanying the test.
    effect_size: float
    # Optional interval attached to the result.
    confidence_interval: Optional[ConfidenceInterval] = None
    # Human-readable name of the test; used as the prefix in ``__str__``.
    test_name: str = ""
    # The t-test helpers in this module set this to ``p_value < 0.01``.
    significant_at_001: bool = False
    # The t-test helpers in this module set this to ``p_value < 0.05``.
    significant_at_005: bool = False

    def __str__(self) -> str:
        """Return a one-line summary with significance markers after p."""
        # The stricter flag wins when both are set.
        if self.significant_at_001:
            stars = "***"
        elif self.significant_at_005:
            stars = "**"
        else:
            stars = ""
        head = f"{self.test_name}: statistic={self.statistic:.3f}, p={self.p_value:.4f}{stars}"
        return f"{head}, effect_size={self.effect_size:.3f}"


def bootstrap_confidence_interval(
    data: Union[List[float], np.ndarray],
    confidence_level: float = 0.95,
    n_bootstrap: int = 10000,
    statistic: str = "mean",
    random_state: Optional[int] = None,
) -> ConfidenceInterval:
    """
    Compute a percentile-bootstrap confidence interval for a statistic.

    The data are resampled with replacement ``n_bootstrap`` times; the
    interval bounds are the ``alpha/2`` and ``1 - alpha/2`` percentiles of
    the resampled statistics, where ``alpha = 1 - confidence_level``.

    Args:
        data: Sample data
        confidence_level: Confidence level as a fraction in [0, 1]
            (default 0.95 for 95% CI)
        n_bootstrap: Number of bootstrap resamples (at least 1)
        statistic: Statistic to compute ("mean", "median", "std"); "std"
            uses ``np.std`` with its default arguments
        random_state: Random seed for reproducibility

    Returns:
        ConfidenceInterval with point estimate and bounds

    Raises:
        ValueError: If ``data`` is empty, ``confidence_level`` is outside
            [0, 1], ``n_bootstrap`` is smaller than 1, or ``statistic`` is
            not one of the supported names.
    """
    data = np.asarray(data)

    # Fail fast with explicit messages instead of letting NumPy fail later
    # (or silently produce NaN bounds) after the resampling work is done.
    if data.size == 0:
        raise ValueError("Cannot compute a bootstrap confidence interval for empty data")
    if not 0.0 <= confidence_level <= 1.0:
        raise ValueError(f"confidence_level must be within [0, 1], got {confidence_level}")
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap}")

    # Select the statistic function
    if statistic == "mean":
        stat_func = np.mean
    elif statistic == "median":
        stat_func = np.median
    elif statistic == "std":
        stat_func = np.std
    else:
        raise ValueError(f"Unknown statistic: {statistic}")

    rng = np.random.RandomState(random_state)

    # Point estimate on the observed data
    point_estimate = float(stat_func(data))

    # Bootstrap resampling: one draw of n indices per resample. The draws are
    # kept as separate per-resample calls so the random stream for a given
    # seed stays the same as before.
    n = len(data)
    bootstrap_stats = np.zeros(n_bootstrap)
    for i in range(n_bootstrap):
        resample = data[rng.randint(0, n, size=n)]
        bootstrap_stats[i] = stat_func(resample)

    # Percentile method for CI
    alpha = 1 - confidence_level
    lower_percentile = alpha / 2 * 100
    upper_percentile = (1 - alpha / 2) * 100

    return ConfidenceInterval(
        point_estimate=point_estimate,
        lower=float(np.percentile(bootstrap_stats, lower_percentile)),
        upper=float(np.percentile(bootstrap_stats, upper_percentile)),
        confidence_level=confidence_level,
    )


def welch_t_test(
    sample1: Union[List[float], np.ndarray],
    sample2: Union[List[float], np.ndarray],
    alternative: str = "two-sided",
) -> HypothesisTestResult:
    """
    Perform Welch's t-test for comparing two independent samples.

    Welch's t-test does not assume equal variances, making it more
    robust than Student's t-test for comparing methods.

    Sign conventions: ``statistic`` and ``alternative`` are passed through
    ``scipy.stats.ttest_ind(sample1, sample2)`` unchanged, i.e. they are
    oriented as ``sample1 - sample2``. The effect size and the confidence
    interval are computed for ``mean(sample2) - mean(sample1)``, so a
    positive effect size goes with a negative t-statistic.

    Args:
        sample1: First sample (e.g., baseline method results)
        sample2: Second sample (e.g., ACEAS results)
        alternative: "two-sided", "less", or "greater"

    Returns:
        HypothesisTestResult with the t-statistic, p-value, Cohen's d
        (standardised by the root of the average of the two sample
        variances) and a two-sided 95% CI for the difference in means.
        The CI is always two-sided regardless of ``alternative``.
    """
    sample1 = np.asarray(sample1)
    sample2 = np.asarray(sample2)
    n1, n2 = len(sample1), len(sample2)

    # Perform Welch's t-test
    statistic, p_value = stats.ttest_ind(
        sample1, sample2,
        equal_var=False,  # Welch's test
        alternative=alternative,
    )

    # Sample variances are needed several times below; compute them once.
    var1 = np.var(sample1, ddof=1)
    var2 = np.var(sample2, ddof=1)
    mean_diff = np.mean(sample2) - np.mean(sample1)

    # Cohen's d effect size
    pooled_std = np.sqrt((var1 + var2) / 2)
    effect_size = mean_diff / pooled_std if pooled_std > 0 else 0.0

    # Standard error of the difference in means
    v1 = var1 / n1
    v2 = var2 / n2
    se_diff = np.sqrt(v1 + v2)

    if se_diff == 0:
        # Both samples are constant: the Welch-Satterthwaite df would be 0/0
        # and turn the bounds into NaN. With a zero standard error the
        # interval collapses onto the point estimate.
        margin = 0.0
    else:
        # Approximate (Welch-Satterthwaite) df for Welch's t-test
        df = (v1 + v2)**2 / (v1**2 / (n1 - 1) + v2**2 / (n2 - 1))
        margin = stats.t.ppf(0.975, df) * se_diff

    ci = ConfidenceInterval(
        point_estimate=float(mean_diff),
        lower=float(mean_diff - margin),
        upper=float(mean_diff + margin),
        confidence_level=0.95,
    )

    return HypothesisTestResult(
        statistic=float(statistic),
        p_value=float(p_value),
        effect_size=float(effect_size),
        confidence_interval=ci,
        test_name="Welch's t-test",
        # bool(): comparing a NumPy scalar yields np.bool_, which is not a
        # ``bool`` instance and is not JSON-serialisable.
        significant_at_001=bool(p_value < 0.01),
        significant_at_005=bool(p_value < 0.05),
    )


def paired_t_test(
    sample1: Union[List[float], np.ndarray],
    sample2: Union[List[float], np.ndarray],
    alternative: str = "two-sided",
) -> HypothesisTestResult:
    """
    Perform paired t-test for comparing matched samples.

    Use this when comparing methods on the same tasks (paired design).

    Sign conventions: ``statistic`` and ``alternative`` are passed through
    ``scipy.stats.ttest_rel(sample1, sample2)`` unchanged, i.e. they are
    oriented as ``sample1 - sample2``. The effect size and the confidence
    interval are computed on ``sample2 - sample1``.

    Args:
        sample1: First sample (baseline results per seed/task)
        sample2: Second sample (ACEAS results per seed/task)
        alternative: "two-sided", "less", or "greater"

    Returns:
        HypothesisTestResult with the t-statistic, p-value, Cohen's d for
        paired samples (mean difference over the standard deviation of the
        differences) and a two-sided 95% CI for the mean difference. The CI
        is always two-sided regardless of ``alternative``.

    Raises:
        ValueError: If the two samples differ in length.
    """
    sample1 = np.asarray(sample1)
    sample2 = np.asarray(sample2)

    if len(sample1) != len(sample2):
        raise ValueError("Paired samples must have same length")

    # Per-pair differences (second minus first)
    differences = sample2 - sample1
    n_pairs = len(differences)

    # Perform paired t-test
    statistic, p_value = stats.ttest_rel(sample1, sample2, alternative=alternative)

    # Mean and sample standard deviation of the differences, computed once
    # and reused for both the effect size and the standard error.
    mean_diff = np.mean(differences)
    std_diff = np.std(differences, ddof=1)

    # Effect size (Cohen's d for paired samples); 0.0 when the differences
    # have no spread, to avoid dividing by zero.
    effect_size = mean_diff / std_diff if std_diff > 0 else 0.0

    # CI for mean difference
    se_diff = std_diff / np.sqrt(n_pairs)
    t_crit = stats.t.ppf(0.975, n_pairs - 1)

    ci = ConfidenceInterval(
        point_estimate=float(mean_diff),
        lower=float(mean_diff - t_crit * se_diff),
        upper=float(mean_diff + t_crit * se_diff),
        confidence_level=0.95,
    )

    return HypothesisTestResult(
        statistic=float(statistic),
        p_value=float(p_value),
        effect_size=float(effect_size),
        confidence_interval=ci,
        test_name="Paired t-test",
        # bool(): comparing a NumPy scalar yields np.bool_, which is not a
        # ``bool`` instance and is not JSON-serialisable.
        significant_at_001=bool(p_value < 0.01),
        significant_at_005=bool(p_value < 0.05),
    )


def bonferroni_correction(
    p_values: List[float],
    alpha: float = 0.05,
) -> Dict[str, Union[List[float], float, int, List[bool]]]:
    """
    Apply Bonferroni correction for multiple comparisons.

    Each p-value is compared (strictly less than) against ``alpha / n_tests``.

    Args:
        p_values: List of p-values from multiple tests
        alpha: Family-wise error rate (default 0.05)

    Returns:
        Dictionary with keys:
            "original_p_values": the input, passed through unchanged
            "corrected_alpha": per-test threshold ``alpha / n_tests``
                (``alpha`` itself when there are no tests)
            "n_tests": number of p-values
            "significant": one ``bool`` per p-value
            "n_significant": count of significant tests
    """
    n_tests = len(p_values)
    # With no tests there is nothing to correct; keep alpha instead of
    # dividing by zero.
    corrected_alpha = alpha / n_tests if n_tests else alpha

    # bool(): p-values that are NumPy scalars would otherwise yield np.bool_
    # flags (and a NumPy integer count below).
    significant = [bool(p < corrected_alpha) for p in p_values]

    return {
        "original_p_values": p_values,
        "corrected_alpha": corrected_alpha,
        "n_tests": n_tests,
        "significant": significant,
        "n_significant": sum(significant),
    }


def holm_bonferroni_correction(
    p_values: List[float],
    alpha: float = 0.05,
) -> Dict[str, Union[List[float], List[bool], int]]:
    """
    Apply Holm-Bonferroni step-down correction for multiple comparisons.

    This is less conservative than Bonferroni while still controlling FWER.

    Args:
        p_values: List of p-values from multiple tests
        alpha: Family-wise error rate (default 0.05)

    Returns:
        Dictionary with keys:
            "original_p_values": the input, passed through unchanged
            "adjusted_p_values": Holm-adjusted p-values in input order
            "significant": one ``bool`` per test (adjusted p < alpha)
            "n_significant": count of significant tests
    """
    n_tests = len(p_values)

    # Sort p-values ascending and remember where each one came from
    order = np.argsort(p_values)
    sorted_p_values = np.asarray(p_values)[order]

    # Compute adjusted p-values, written back at their original positions
    adjusted_p_values = np.zeros(n_tests)
    running_max = float("-inf")
    for rank, (idx, p) in enumerate(zip(order, sorted_p_values)):
        # Holm's adjustment: the smallest p-value is multiplied by n, the
        # next by n - 1, and so on; the result is capped at 1.
        candidate = min(1.0, p * (n_tests - rank))
        # Ensure monotonicity: never drop below the previous adjusted value
        running_max = max(candidate, running_max)
        adjusted_p_values[idx] = running_max

    # .tolist() yields plain Python bools/floats, so the result stays
    # JSON-serialisable and the count below is a plain int.
    significant = (adjusted_p_values < alpha).tolist()

    return {
        "original_p_values": p_values,
        "adjusted_p_values": adjusted_p_values.tolist(),
        "significant": significant,
        "n_significant": sum(significant),
    }


def compare_methods_comprehensive(
    baseline_results: Dict[str, List[float]],
    aceas_results: List[float],
    method_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Comprehensive statistical comparison of ACEAS against baselines.

    Runs Welch's t-test of every baseline against ACEAS, applies the
    Holm-Bonferroni correction across the resulting p-values, and adds a
    bootstrap confidence interval for the ACEAS results.

    Args:
        baseline_results: Dictionary mapping baseline name to results list
        aceas_results: ACEAS results list
        method_names: Optional custom names for baselines. NOTE: currently
            not used by the implementation; the keys of
            ``baseline_results`` are always used as names.

    Returns:
        Dictionary with keys:
            "comparisons": per-baseline dict holding the test result
                ("test"), means and standard deviations (``np.std``
                defaults) of both methods, the mean "improvement"
                (ACEAS minus baseline), and the Holm-adjusted p-value and
                significance flag
            "holm_correction": full output of the Holm-Bonferroni step
            "aceas_bootstrap_ci": bootstrap CI for the ACEAS results
                (unseeded, so bounds vary between calls)
    """
    # ACEAS summary statistics do not depend on the baseline; compute once.
    aceas_mean = np.mean(aceas_results)
    aceas_std = np.std(aceas_results)

    comparisons = {}
    p_values = []
    for baseline_name, baseline_data in baseline_results.items():
        test_result = welch_t_test(baseline_data, aceas_results)
        baseline_mean = np.mean(baseline_data)
        comparisons[baseline_name] = {
            "test": test_result,
            "aceas_mean": aceas_mean,
            "aceas_std": aceas_std,
            "baseline_mean": baseline_mean,
            "baseline_std": np.std(baseline_data),
            "improvement": aceas_mean - baseline_mean,
        }
        p_values.append(test_result.p_value)

    # Multiple comparison correction
    holm_correction = holm_bonferroni_correction(p_values)

    # Attach corrected results; ``comparisons`` and ``p_values`` were filled
    # in the same iteration order, so positions line up.
    for comp, is_significant, adjusted_p in zip(
        comparisons.values(),
        holm_correction["significant"],
        holm_correction["adjusted_p_values"],
    ):
        comp["significant_after_correction"] = is_significant
        comp["adjusted_p_value"] = adjusted_p

    return {
        "comparisons": comparisons,
        "holm_correction": holm_correction,
        "aceas_bootstrap_ci": bootstrap_confidence_interval(aceas_results),
    }


def generate_statistical_report(
    results: Dict[str, List[float]],
    metric_name: str = "Pass@1",
    random_state: Optional[int] = None,
) -> str:
    """
    Generate a formatted statistical report for paper inclusion.

    The report always contains descriptive statistics per method. If
    ``results`` has a key "ACEAS" (preferred) or "aceas" and at least one
    other method, it also contains Welch's t-test comparisons of ACEAS
    against every other method and a Holm-Bonferroni correction of the
    resulting p-values.

    Args:
        results: Dictionary mapping method name to results list
        metric_name: Name of the metric being compared
        random_state: Seed forwarded to every bootstrap confidence interval
            so the report is reproducible. The default ``None`` leaves the
            bootstrap unseeded, so CI bounds may differ between calls.

    Returns:
        Formatted string report
    """
    lines = [
        f"Statistical Analysis Report: {metric_name}",
        "=" * 50,
        "",
        "1. Descriptive Statistics",
        "-" * 30,
    ]

    for method, data in results.items():
        ci = bootstrap_confidence_interval(data, random_state=random_state)
        lines.extend([
            f"  {method}:",
            f"    Mean: {np.mean(data):.3f}",
            f"    Std:  {np.std(data):.3f}",
            f"    95% Bootstrap CI: [{ci.lower:.3f}, {ci.upper:.3f}]",
            "",
        ])

    # Pairwise comparisons only if ACEAS is present ("ACEAS" takes
    # precedence over "aceas" when both keys exist).
    aceas_key = next((key for key in ("ACEAS", "aceas") if key in results), None)
    if aceas_key is None:
        return "\n".join(lines)

    aceas_data = results[aceas_key]
    baselines = {k: v for k, v in results.items() if k != aceas_key}
    if not baselines:
        return "\n".join(lines)

    lines.extend([
        "2. Pairwise Comparisons (ACEAS vs. Baselines)",
        "-" * 30,
    ])

    p_values = []
    for baseline_name, baseline_data in baselines.items():
        test = welch_t_test(baseline_data, aceas_data)
        p_values.append(test.p_value)
        diff_ci = test.confidence_interval
        if test.significant_at_001:
            sig_str = "Yes (p < 0.01)"
        elif test.significant_at_005:
            sig_str = "Yes (p < 0.05)"
        else:
            sig_str = "No"
        lines.extend([
            f"  ACEAS vs. {baseline_name}:",
            f"    Welch's t-statistic: {test.statistic:.3f}",
            f"    p-value: {test.p_value:.4f}",
            f"    Cohen's d: {test.effect_size:.3f}",
            f"    Mean difference: {diff_ci.point_estimate:.3f}",
            f"    95% CI: [{diff_ci.lower:.3f}, {diff_ci.upper:.3f}]",
            f"    Significant: {sig_str}",
            "",
        ])

    # Multiple comparison correction; p_values were collected in the
    # iteration order of ``baselines``, so positions line up below.
    holm = holm_bonferroni_correction(p_values)
    lines.extend([
        "3. Multiple Comparison Correction (Holm-Bonferroni)",
        "-" * 30,
    ])
    for baseline_name, adj_p, is_significant in zip(
        baselines, holm["adjusted_p_values"], holm["significant"]
    ):
        sig = "Yes" if is_significant else "No"
        lines.append(f"  ACEAS vs. {baseline_name}: adjusted p = {adj_p:.4f}, significant: {sig}")

    return "\n".join(lines)


if __name__ == "__main__":
    # Demo run of the utilities on simulated data; it prints results and
    # performs no assertions.
    print("Testing Statistical Analysis Utilities")
    print("=" * 50)

    # NOTE(review): the bootstrap helper builds its own RandomState from its
    # ``random_state`` argument, so this global seed presumably does not make
    # the unseeded bootstrap calls inside the report reproducible - confirm.
    np.random.seed(42)

    # Simulated results (3 seeds each, similar to paper)
    results = {
        "Sync-GRPO": [0.385, 0.402, 0.404],  # Mean ~39.7%
        "Sync-GRPO+CCCS": [0.502, 0.520, 0.523],  # Mean ~51.5%
        "Async-GRPO": [0.305, 0.325, 0.324],  # Mean ~31.8%
        "Async-GRPO+Staleness": [0.392, 0.408, 0.409],  # Mean ~40.3%
        "ACEAS": [0.590, 0.605, 0.608],  # Mean ~60.1%
    }
    aceas = results["ACEAS"]

    # Seeded bootstrap CI for ACEAS
    ci = bootstrap_confidence_interval(aceas, random_state=42)
    print(f"\nACEAS Bootstrap CI: {ci}")

    # Welch comparisons of ACEAS against the two plain GRPO baselines; only
    # the first line is preceded by a blank line.
    for lead, baseline in (("\n", "Sync-GRPO"), ("", "Async-GRPO")):
        test = welch_t_test(results[baseline], aceas)
        print(f"{lead}ACEAS vs {baseline}: {test}")

    # Full report over all methods
    report = generate_statistical_report(results)
    print("\n" + report)

    print("\nAll tests passed!")
