"""Statistical utilities: bootstrap CIs, permutation tests, effect sizes, ROPE."""

import numpy as np
from scipy.stats import bootstrap, permutation_test


def bca_bootstrap_ci(x, confidence_level=0.95, n_resamples=9999, seed=42):
    """Bootstrap confidence interval for the mean of ``x``.

    The bias-corrected and accelerated (BCa) interval is tried first.  The
    plain percentile interval is used instead when the BCa computation
    raises *or* yields a non-finite bound (SciPy signals a degenerate BCa
    interval, e.g. for constant data, by returning NaN rather than raising).

    Parameters
    ----------
    x : array_like
        Observations; converted to a float array.
    confidence_level : float, optional
        Confidence level forwarded to ``scipy.stats.bootstrap``.
    n_resamples : int, optional
        Number of bootstrap resamples.
    seed : int, optional
        Seed for the NumPy generator, so repeated calls give equal results.

    Returns
    -------
    tuple of float
        ``(low, high)``.  Empty input gives ``(nan, nan)``; a single
        observation gives ``(value, value)``.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        # No data: there is nothing to index or resample.
        nan = float("nan")
        return (nan, nan)
    if x.size == 1:
        # A single observation cannot be resampled meaningfully.
        value = float(x.item())
        return (value, value)

    def _interval(method):
        # Each attempt gets a fresh generator so both methods resample from
        # the same seeded stream.
        res = bootstrap(
            (x,), np.mean, n_resamples=n_resamples,
            method=method, confidence_level=confidence_level,
            random_state=np.random.default_rng(seed),
        )
        bounds = res.confidence_interval
        return (float(bounds.low), float(bounds.high))

    try:
        low, high = _interval("BCa")
    except Exception:
        # Deliberate best-effort: any BCa failure downgrades to percentile.
        return _interval("percentile")

    if np.isfinite(low) and np.isfinite(high):
        return (low, high)
    # BCa ran but produced NaN/inf bounds; use the percentile interval.
    return _interval("percentile")


def bca_bootstrap_ci_paired(x, y, confidence_level=0.95, n_resamples=9999, seed=42):
    """Confidence interval for the mean of the element-wise difference x - y.

    Both inputs are converted to float arrays, subtracted, and the
    differences are handed to ``bca_bootstrap_ci`` together with the
    resampling settings.  Returns whatever ``(low, high)`` tuple that
    function produces.
    """
    first = np.asarray(x, dtype=float)
    second = np.asarray(y, dtype=float)
    paired_differences = first - second
    return bca_bootstrap_ci(
        paired_differences,
        confidence_level=confidence_level,
        n_resamples=n_resamples,
        seed=seed,
    )


def sign_flip_test(x, y, n_resamples=9999, seed=42):
    """Return the two-sided permutation p-value for H0: mean(x - y) = 0.

    The test statistic is the mean of the paired differences, and the null
    distribution comes from ``scipy.stats.permutation_test`` with
    ``permutation_type="samples"``.  A generator seeded with ``seed`` is
    passed in so repeated calls give the same p-value.
    """
    first = np.asarray(x, dtype=float)
    second = np.asarray(y, dtype=float)
    generator = np.random.default_rng(seed)

    def mean_paired_difference(a, b, axis):
        # SciPy supplies the axis along which the statistic is reduced.
        return np.mean(a - b, axis=axis)

    outcome = permutation_test(
        (first, second),
        mean_paired_difference,
        permutation_type="samples",
        n_resamples=n_resamples,
        alternative="two-sided",
        random_state=generator,
    )
    return float(outcome.pvalue)


def rope_analysis(samples, lo, hi):
    """Summarise how samples fall relative to the closed interval [lo, hi].

    Returns a dict with the percentage of samples strictly below ``lo``
    (``pct_below``), inside ``[lo, hi]`` inclusive (``pct_in``) and strictly
    above ``hi`` (``pct_above``), plus a ``decision``:

    * ``"equivalent"`` when at least 95% of samples are inside,
    * ``"different"`` when at least 95% are outside (below and above combined),
    * ``"undecided"`` otherwise.
    """
    values = np.asarray(samples, dtype=float)
    total = len(values)

    def share(mask):
        # Percentage of entries flagged by a boolean mask.
        return float(np.sum(mask) / total * 100)

    pct_below = share(values < lo)
    pct_in = share((values >= lo) & (values <= hi))
    pct_above = share(values > hi)

    decision = "undecided"
    if pct_in >= 95:
        decision = "equivalent"
    elif (pct_below + pct_above) >= 95:
        decision = "different"

    return {
        "decision": decision,
        "pct_below": pct_below,
        "pct_in": pct_in,
        "pct_above": pct_above,
    }


def rope_from_bootstrap(x, y, lo, hi, n_resamples=9999, seed=42):
    """Run a ROPE analysis on bootstrap means of the paired differences x - y.

    ``n_resamples`` resamples (with replacement, same size as the data) are
    drawn from the differences using a generator seeded with ``seed``; the
    mean of each resample is classified against ``[lo, hi]`` by
    ``rope_analysis``.  The interval from ``bca_bootstrap_ci_paired`` is
    attached to the result under the ``"ci"`` key.
    """
    first = np.asarray(x, dtype=float)
    second = np.asarray(y, dtype=float)
    delta = first - second
    n_pairs = len(delta)

    generator = np.random.default_rng(seed)
    resampled_means = np.empty(n_resamples)
    for i in range(n_resamples):
        # One draw per resample keeps the random stream in a fixed order.
        resample = generator.choice(delta, size=n_pairs, replace=True)
        resampled_means[i] = np.mean(resample)

    interval = bca_bootstrap_ci_paired(x, y, n_resamples=n_resamples, seed=seed)
    summary = rope_analysis(resampled_means, lo, hi)
    summary["ci"] = interval
    return summary


def cohens_d_paired(x, y):
    """Paired-sample Cohen's d: mean of x - y over its sample std (ddof=1).

    Returns 0.0 whenever the standard deviation is not strictly positive
    (zero or NaN), so the ratio is never formed with an unusable divisor.
    """
    delta = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    spread = np.std(delta, ddof=1)
    if not spread > 0:
        return 0.0
    return float(np.mean(delta) / spread)


def hedges_g_paired(x, y):
    """Paired Cohen's d scaled by a small-sample correction factor.

    The factor is ``1 - 3 / (4 * (n - 1) - 1)`` with ``n = len(x)``; when
    ``n`` is 1 or less the factor is 1.0, i.e. no correction is applied.
    """
    effect = cohens_d_paired(x, y)
    n_pairs = len(np.asarray(x))
    if n_pairs > 1:
        shrink = 1 - 3 / (4 * (n_pairs - 1) - 1)
    else:
        shrink = 1.0
    return float(effect * shrink)


def interpret_effect_size(d):
    """Label the magnitude of an effect size using Cohen's benchmarks.

    ``|d|`` below 0.2 is "negligible", below 0.5 "small", below 0.8
    "medium"; anything else is "large".
    """
    magnitude = abs(d)
    # Exclusive upper bounds, scanned in ascending order.
    cutoffs = ((0.2, "negligible"), (0.5, "small"), (0.8, "medium"))
    for upper_bound, label in cutoffs:
        if magnitude < upper_bound:
            return label
    return "large"


def full_comparison_report(x, y, rope_lo=None, rope_hi=None, label="",
                           n_resamples=9999, seed=42):
    """Build a dict summarising the paired comparison of x against y.

    The report holds the label, ``n`` (length of x), the means of x, y and
    x - y, the standard error of the differences (0.0 when there are fewer
    than two), the bootstrap CI, the sign-flip p-value, Cohen's d, Hedges' g
    and a verbal interpretation of Hedges' g.  A ``"rope"`` entry is added
    only when both ``rope_lo`` and ``rope_hi`` are given.
    """
    first = np.asarray(x, dtype=float)
    second = np.asarray(y, dtype=float)
    delta = first - second
    n_diffs = len(delta)

    if n_diffs > 1:
        std_err = float(np.std(delta, ddof=1) / np.sqrt(n_diffs))
    else:
        std_err = 0.0

    # The same resampling settings are shared by every resampling-based step.
    resampling = {"n_resamples": n_resamples, "seed": seed}

    report = {}
    report["label"] = label
    report["n"] = int(len(first))
    report["mean_x"] = float(np.mean(first))
    report["mean_y"] = float(np.mean(second))
    report["mean_diff"] = float(np.mean(delta))
    report["se_diff"] = std_err
    report["ci"] = bca_bootstrap_ci_paired(first, second, **resampling)
    report["p_value"] = sign_flip_test(first, second, **resampling)
    report["cohens_d"] = cohens_d_paired(first, second)
    report["hedges_g"] = hedges_g_paired(first, second)
    report["effect_interp"] = interpret_effect_size(report["hedges_g"])

    if rope_lo is None or rope_hi is None:
        return report

    report["rope"] = rope_from_bootstrap(
        first, second, rope_lo, rope_hi, **resampling)
    return report


def format_report(report):
    """Render a comparison report dict as an indented, newline-joined string.

    Reads the ``label``, ``n``, ``mean_x``, ``mean_y``, ``mean_diff``,
    ``se_diff``, ``ci``, ``p_value``, ``hedges_g`` and ``effect_interp``
    keys; a ROPE line is appended only when the dict has a ``rope`` key.
    """
    rows = [
        f"  {report['label']}",
        f"    n = {report['n']}",
        f"    mean_x = {report['mean_x']:.4f},  mean_y = {report['mean_y']:.4f}",
        f"    diff   = {report['mean_diff']:+.4f}  (SE = {report['se_diff']:.4f})",
        f"    BCa 95% CI: [{report['ci'][0]:+.4f}, {report['ci'][1]:+.4f}]",
        f"    permutation p = {report['p_value']:.4f}",
        f"    Hedges' g = {report['hedges_g']:+.3f}  ({report['effect_interp']})",
    ]

    if "rope" in report:
        rope = report["rope"]
        rope_row = (
            f"    ROPE: {rope['decision']}  "
            f"(below {rope['pct_below']:.1f}% | "
            f"in {rope['pct_in']:.1f}% | "
            f"above {rope['pct_above']:.1f}%)"
        )
        rows.append(rope_row)

    return "\n".join(rows)
