"""Correlation and regression analysis.

Relates exploration metrics (primitives, diversity, novelty) to
math pass@k gains across checkpoints.
"""
from __future__ import annotations

from typing import Optional

import numpy as np
import pandas as pd
from scipy import stats


def spearman_correlation_table(
    df: pd.DataFrame,
    metric_cols: list[str],
    outcome_col: str,
) -> pd.DataFrame:
    """Compute Spearman rank correlation between each metric and outcome.

    Missing values are dropped pairwise, separately for every metric, so
    ``n`` may differ from one row of the result to the next.

    Args:
        df: Table containing the metric columns and the outcome column.
        metric_cols: Names of the columns to correlate with the outcome.
        outcome_col: Name of the outcome column.

    Returns:
        DataFrame with columns: metric, rho, p_value, n. There is one row
        per entry of ``metric_cols``, in the same order. ``rho`` and
        ``p_value`` are NaN when fewer than 3 complete pairs exist. The
        four columns are present even when ``metric_cols`` is empty.
    """
    columns = ["metric", "rho", "p_value", "n"]
    rows = []
    for col in metric_cols:
        mask = df[col].notna() & df[outcome_col].notna()
        x = df.loc[mask, col].values
        y = df.loc[mask, outcome_col].values
        n = len(x)
        if n < 3:
            # Too few complete pairs for a meaningful rank correlation.
            rows.append({"metric": col, "rho": np.nan, "p_value": np.nan, "n": n})
            continue
        rho, p = stats.spearmanr(x, y)
        rows.append({"metric": col, "rho": rho, "p_value": p, "n": n})
    # Passing ``columns`` pins the schema; without it an empty ``rows`` list
    # yields a DataFrame with no columns at all.
    return pd.DataFrame(rows, columns=columns)


def pearson_correlation_table(
    df: pd.DataFrame,
    metric_cols: list[str],
    outcome_col: str,
) -> pd.DataFrame:
    """Compute Pearson correlation between each metric and outcome.

    Missing values are dropped pairwise, separately for every metric, so
    ``n`` may differ from one row of the result to the next.

    Args:
        df: Table containing the metric columns and the outcome column.
        metric_cols: Names of the columns to correlate with the outcome.
        outcome_col: Name of the outcome column.

    Returns:
        DataFrame with columns: metric, r, p_value, n. There is one row per
        entry of ``metric_cols``, in the same order. ``r`` and ``p_value``
        are NaN when fewer than 3 complete pairs exist. The four columns
        are present even when ``metric_cols`` is empty.
    """
    columns = ["metric", "r", "p_value", "n"]
    rows = []
    for col in metric_cols:
        mask = df[col].notna() & df[outcome_col].notna()
        x = df.loc[mask, col].values
        y = df.loc[mask, outcome_col].values
        n = len(x)
        if n < 3:
            # Too few complete pairs for a meaningful correlation.
            rows.append({"metric": col, "r": np.nan, "p_value": np.nan, "n": n})
            continue
        r, p = stats.pearsonr(x, y)
        rows.append({"metric": col, "r": r, "p_value": p, "n": n})
    # Passing ``columns`` pins the schema; without it an empty ``rows`` list
    # yields a DataFrame with no columns at all.
    return pd.DataFrame(rows, columns=columns)


def ols_regression(
    df: pd.DataFrame,
    predictor: str,
    outcome: str,
) -> dict:
    """Simple OLS regression: outcome ~ predictor.

    Rows where either column is missing are dropped before fitting.

    Args:
        df: Table containing the predictor and outcome columns.
        predictor: Name of the independent-variable column.
        outcome: Name of the dependent-variable column.

    Returns:
        Dict with predictor, slope, intercept, r_squared, std_err, p_value
        and n (number of complete pairs). All statistics are NaN when
        fewer than 3 complete pairs exist or when the predictor takes a
        single constant value, since no line can be fitted in either case.
    """
    mask = df[predictor].notna() & df[outcome].notna()
    x = df.loc[mask, predictor].values
    y = df.loc[mask, outcome].values
    n = len(x)

    # ``linregress`` raises ValueError when all x values are identical, so a
    # constant predictor is reported as an undefined fit instead of crashing.
    if n < 3 or np.amax(x) == np.amin(x):
        slope = intercept = r_squared = std_err = p_value = np.nan
    else:
        result = stats.linregress(x, y)
        slope = result.slope
        intercept = result.intercept
        r_squared = result.rvalue ** 2
        std_err = result.stderr
        p_value = result.pvalue

    return {
        "predictor": predictor,
        "slope": slope,
        "intercept": intercept,
        "r_squared": r_squared,
        "std_err": std_err,
        "p_value": p_value,
        "n": n,
    }


def multi_predictor_ols(
    df: pd.DataFrame,
    predictors: list[str],
    outcome: str,
) -> dict:
    """Multi-predictor OLS regression using sklearn.

    Rows with a missing value in any predictor or in the outcome are
    dropped before fitting.

    Args:
        df: Table containing the predictor columns and the outcome column.
        predictors: Names of the independent-variable columns.
        outcome: Name of the dependent-variable column.

    Returns:
        Dict with keys predictors, outcome, n, r_squared, intercept and
        coefficients (mapping predictor name -> fitted coefficient). When
        fewer than ``len(predictors) + 2`` complete rows remain, no model
        is fitted: r_squared and intercept are NaN and coefficients is an
        empty dict.
    """
    # Drop rows with any NaN in predictors or outcome
    cols = list(predictors) + [outcome]
    clean = df[cols].dropna()
    n = len(clean)

    if n < len(predictors) + 2:
        # Same keys as the fitted result so callers can read "intercept"
        # without special-casing the too-few-rows case.
        return {
            "predictors": predictors,
            "outcome": outcome,
            "n": n,
            "r_squared": np.nan,
            "intercept": np.nan,
            "coefficients": {},
        }

    # Imported only when a fit is needed, so the too-few-rows path works
    # without sklearn being importable.
    from sklearn.linear_model import LinearRegression

    X = clean[predictors].values
    y = clean[outcome].values

    model = LinearRegression()
    model.fit(X, y)

    return {
        "predictors": predictors,
        "outcome": outcome,
        "n": n,
        "r_squared": model.score(X, y),
        "intercept": model.intercept_,
        "coefficients": dict(zip(predictors, model.coef_)),
    }


def build_checkpoint_table(
    checkpoint_metrics: dict[str, dict],
) -> pd.DataFrame:
    """Assemble per-checkpoint metrics into a single table.

    Args:
        checkpoint_metrics: Mapping from checkpoint_id to a dict of that
            checkpoint's metrics. Each metric key becomes a column; a
            metric missing for a checkpoint is left as NaN in its row.

    Returns:
        DataFrame with one row per checkpoint and a ``checkpoint_id``
        column, ordered by ``checkpoint_id`` with a fresh 0..n-1 index.
        An empty mapping yields an empty DataFrame.
    """
    # The id goes first so it is the leading column; a metric that is itself
    # named "checkpoint_id" overrides it.
    records = [
        {"checkpoint_id": ckpt_id, **metrics}
        for ckpt_id, metrics in checkpoint_metrics.items()
    ]
    table = pd.DataFrame(records)

    # With no records there is no id column to sort on.
    if "checkpoint_id" not in table.columns:
        return table
    return table.sort_values("checkpoint_id").reset_index(drop=True)


def length_correlation_check(
    df: pd.DataFrame,
    metric_cols: list[str],
    length_col: str = "avg_trace_length_mean",
    threshold: float = 0.8,
) -> list[str]:
    """Flag metrics that are potentially confounded by trace length.

    Args:
        df: Table containing the metric columns and the length column.
        metric_cols: Names of the metric columns to test.
        length_col: Name of the trace-length column.
        threshold: Absolute Pearson r above which a metric is flagged.

    Returns:
        Names from ``metric_cols``, in their original order, for which
        |Pearson r| against ``length_col`` is strictly greater than
        ``threshold``. Metrics with fewer than 3 rows where both values
        are present are never flagged.
    """

    def _exceeds_threshold(metric: str) -> bool:
        # Correlate only over rows where both values are present.
        both_present = df[metric].notna() & df[length_col].notna()
        metric_vals = df.loc[both_present, metric].values
        length_vals = df.loc[both_present, length_col].values
        if len(metric_vals) < 3:
            return False
        corr, _p_value = stats.pearsonr(metric_vals, length_vals)
        # A NaN correlation compares False and is therefore not flagged.
        return bool(abs(corr) > threshold)

    return [metric for metric in metric_cols if _exceeds_threshold(metric)]
