#!/usr/bin/env python3
"""
benchmark_overton_pipeline.py
=======================================

One-stop pipeline to:
1) Load per-question KMEANS (or any other) clusters and compute Overton scores + OLS.
2) Write a single combined CSV and Markdown summary containing raw and adjusted scores.

Args & Features
---------------
- Allows passing a custom data file with --data (CSV; must contain columns: user, question_id, model, representation_rating, and a clustering column).
- Allows specifying the cluster assignment column with --cluster_col (defaults to 'cluster_kmeans').

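Outputs (in --outdir)
---------------------
- overton_scores_and_ols_tau{tau}.csv   (combined per-model table; default tau=4.0)
- overton_scores_and_ols_tau{tau}.md    (compact markdown summary, one section per scoring variant)
- oc_per_question_kmeans_tau{tau}.csv and, with --weighted,
  oc_per_question_kmeans_weighted_tau{tau}.csv   (only with --emit-oc-per-question)
- mixed_model_upper_bound_kmeans_per_question_tau{tau}.csv and
  mixed_model_upper_bound_kmeans_summary_tau{tau}.csv   (only with --upper-bound)

With --source modelslant or --source prism, every output file name gets a
_modelslant / _prism suffix before its extension.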

CLI
---
# Default: load data from Hugging Face
python src/benchmark_overton_pipeline.py --weighted

# Use your own data file (or set DATASET in .env):
python src/benchmark_overton_pipeline.py --data path/to/your.csv --cluster_col cluster_kmeans

"""

import argparse
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

from load_dataset import get_overtonbench_data


# -------------------------------
# Helpers
# -------------------------------

def cluster_sizes_by_question(df: pd.DataFrame, cluster_col: str) -> pd.DataFrame:
    """
    Count the distinct users in every (question, cluster) pair.

    Rows whose <cluster_col> is missing are ignored. The count does not depend
    on the model column: a user rated under several models is counted once.

    Returns a DataFrame with columns [question_id, <cluster_col>, cluster_size].
    """
    key_cols = ["question_id", cluster_col]
    labelled = df.dropna(subset=[cluster_col])
    # One row per (user, question, cluster) membership
    membership = labelled.drop_duplicates(subset=["user", *key_cols])
    users_per_cluster = membership.groupby(key_cols)["user"].nunique()
    return users_per_cluster.reset_index(name="cluster_size")

# -------------------------------
# Overton scoring + OLS (generic)
# -------------------------------

def compute_unadjusted_overton(
    df: pd.DataFrame,
    cluster_col: str,
    tau: float,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compute the unadjusted (unweighted) OvertonScore for the given cluster_col.

    A (question, model, cluster) cell counts as covered when the mean
    representation_rating of its rows is >= tau. OC(model, question) is the
    share of covered cells among that pair's cells, and a model's OvertonScore
    is the plain mean of its OC values over the questions it appears in.

    Args:
      df: long-format ratings with columns question_id, model,
          representation_rating and <cluster_col>.
      cluster_col: cluster-label column; rows with a missing label are ignored.
      tau: inclusive coverage threshold on the cell mean rating.

    Returns:
      - per_model_scores: columns [model, OvertonScore, total_clusters, avg_clusters_per_q]
      - per_q_oc: columns [question_id, model, covered_clusters, K, OC]
        (for diagnostics/optional export)
      Both frames keep these columns, with zero rows, when no usable cell exists
      (e.g. every cluster label is missing), so callers can still merge on "model".
    """
    score_cols = ["model", "OvertonScore", "total_clusters", "avg_clusters_per_q"]
    oc_cols = ["question_id", "model", "covered_clusters", "K", "OC"]

    # (q, m, cluster) -> mean rating, ignoring rows without a cluster label
    labelled = df.dropna(subset=[cluster_col])
    cell_means = (labelled.groupby(["question_id", "model", cluster_col])["representation_rating"]
                    .mean()
                    .reset_index(name="mean_rating"))
    # A cell whose ratings are all missing has a NaN mean and cannot be judged
    cell_means = cell_means.dropna(subset=["mean_rating"])

    if cell_means.empty:
        # Explicit schema: pd.DataFrame([]) would otherwise come back column-less
        return pd.DataFrame(columns=score_cols), pd.DataFrame(columns=oc_cols)

    cell_means = cell_means.assign(covered=(cell_means["mean_rating"] >= tau).astype(int))

    # Per (q, m): number of cells K and how many of them are covered
    oc = (cell_means.groupby(["question_id", "model"])["covered"]
            .agg(["sum", "count"])
            .reset_index()
            .rename(columns={"sum": "covered_clusters", "count": "K"}))
    oc["OC"] = oc["covered_clusters"] / oc["K"]

    # Per-model aggregate; every question counts once
    rows = [
        {
            "model": model_name,
            "OvertonScore": float(per_q["OC"].mean()),
            "total_clusters": int(per_q["K"].sum()),
            "avg_clusters_per_q": float(per_q["K"].mean()),
        }
        for model_name, per_q in oc.groupby("model")
    ]
    per_model_scores = pd.DataFrame(rows, columns=score_cols)

    return per_model_scores, oc

def compute_weighted_overton(
    df: pd.DataFrame,
    cluster_col: str,
    tau: float,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compute size-weighted Overton using cluster sizes based on unique users per
    (q, cluster), independent of model.

    A (question, model, cluster) cell is covered when its mean
    representation_rating is >= tau. Within each (question, model) the covered
    flags are averaged with weight = cluster_size; a model's score is the plain
    mean of those weighted values over its questions.

    Cells whose ratings are all missing (NaN mean) are dropped, as in
    compute_unadjusted_overton, instead of being scored as uncovered at full weight.

    Args:
      df: long-format ratings with columns user, question_id, model,
          representation_rating and <cluster_col>.
      cluster_col: cluster-label column; rows with a missing label are ignored.
      tau: inclusive coverage threshold on the cell mean rating.

    Returns:
      - per_model_scores_w: columns [model, OvertonScore_w]
      - per_q_oc_w: columns [question_id, model, weighted_cov, weight_sum, K, OC_w]
      Both frames keep these columns, with zero rows, when no usable cell exists.
    """
    score_cols = ["model", "OvertonScore_w"]
    ocw_cols = ["question_id", "model", "weighted_cov", "weight_sum", "K", "OC_w"]

    # 1) Cell means per (q, m, cluster)
    labelled = df.dropna(subset=[cluster_col])
    cells = (labelled.groupby(["question_id", "model", cluster_col])["representation_rating"]
               .mean()
               .reset_index(name="mean_rating"))
    # NaN >= tau is False, so an unrated cell would silently count as "uncovered"
    cells = cells.dropna(subset=["mean_rating"]).copy()
    cells["covered"] = (cells["mean_rating"] >= tau).astype(int)

    # 2) Cluster sizes per (q, cluster), model-agnostic
    sizes = cluster_sizes_by_question(df, cluster_col=cluster_col)

    # 3) Attach sizes; keep only cells with a positive weight
    weighted = cells.merge(sizes, on=["question_id", cluster_col], how="left")
    weighted = weighted.dropna(subset=["cluster_size"]).copy()
    weighted["cluster_size"] = pd.to_numeric(weighted["cluster_size"], errors="coerce")
    weighted = weighted[weighted["cluster_size"] > 0]

    if weighted.empty:
        # Explicit schema: pd.DataFrame([]) would otherwise come back column-less
        return pd.DataFrame(columns=score_cols), pd.DataFrame(columns=ocw_cols)

    # 4) Weighted OC per (q, m) + number of contributing cells K
    ocw = (weighted.assign(weighted_cov=lambda x: x["covered"] * x["cluster_size"])
             .groupby(["question_id", "model"])
             .agg(weighted_cov=("weighted_cov", "sum"),
                  weight_sum=("cluster_size", "sum"),
                  K=("cluster_size", "count"))
             .reset_index())
    ocw["OC_w"] = ocw["weighted_cov"] / ocw["weight_sum"]

    # 5) Per-model weighted Overton; every question counts once
    rows = [
        {"model": model_name, "OvertonScore_w": float(per_q["OC_w"].mean())}
        for model_name, per_q in ocw.groupby("model")
    ]
    per_model_scores_w = pd.DataFrame(rows, columns=score_cols)

    return per_model_scores_w, ocw

def compute_mixed_model_upper_bound(
    df: pd.DataFrame,
    cluster_col: str,
    tau: float,
) -> pd.DataFrame:
    """
    Per-question 'upper bound' obtained by pooling all models.

    For every (question_id, cluster) the mean rating is computed per model and
    the cluster is marked covered when the best of those means is >= tau.
    Each question then gets
      - OC_unweighted = covered clusters / K
      - OC_weighted   = covered cluster sizes / total cluster size,
    where cluster size is the number of unique users in the cluster for that
    question.

    Returns a per-question DataFrame with columns
      [question_id, covered_clusters, K, weight_sum, weighted_cov,
       OC_unweighted, OC_weighted]
    """
    q_cluster = ["question_id", cluster_col]
    labelled = df.dropna(subset=[cluster_col])

    # Mean rating of each model inside each (q, cluster)
    model_means = (labelled.groupby(["question_id", "model", cluster_col])["representation_rating"]
                     .mean()
                     .reset_index(name="mean_rating"))

    # Best model per (q, cluster) decides coverage
    best = model_means.groupby(q_cluster)["mean_rating"].max().reset_index(name="max_mean_rating")
    best["covered_any"] = (best["max_mean_rating"] >= tau).astype(int)

    # Attach model-agnostic cluster sizes and keep clusters with positive size
    sizes = cluster_sizes_by_question(df, cluster_col=cluster_col)
    cov = best.merge(sizes, on=q_cluster, how="left")
    cov = cov.dropna(subset=["cluster_size"]).copy()
    cov["cluster_size"] = pd.to_numeric(cov["cluster_size"], errors="coerce")
    cov = cov[cov["cluster_size"] > 0].copy()
    cov["wcov"] = cov["covered_any"] * cov["cluster_size"]

    # Roll everything up to one row per question in a single aggregation
    per_q = (cov.groupby("question_id")
               .agg(
                   covered_clusters=("covered_any", "sum"),
                   K=("covered_any", "count"),
                   weight_sum=("cluster_size", "sum"),
                   weighted_cov=("wcov", "sum"),
               )
               .reset_index())

    per_q["OC_unweighted"] = per_q["covered_clusters"] / per_q["K"]
    per_q["OC_weighted"] = per_q["weighted_cov"] / per_q["weight_sum"]
    return per_q

def run_ols_adjusted_coverage(
    df: pd.DataFrame,
    cluster_col: str,
    tau: float,
) -> pd.DataFrame:
    """
    Fit a linear probability model of cluster coverage with question fixed
    effects and standard errors clustered by question.

    Spec: covered ~ 0 + C(model) + C(question_id), one observation per
    (question, model, cluster) with covered = 1 when the mean rating is >= tau.

    Returns one row per model with
      - adj_coverage: fitted coverage averaged within each question, then
        averaged equally across questions;
      - coef_vs_grand_mean, se, t, p, ci_low, ci_high: t-test of the model's
        coefficient against the mean of all model coefficients;
      - n_rows: number of observations for that model.

    Raises RuntimeError when the fit contains no C(model) coefficients.
    """
    labelled = df.dropna(subset=[cluster_col])
    gm = (labelled.groupby(["question_id", "model", cluster_col])["representation_rating"]
            .mean()
            .reset_index(name="mean_rating"))
    gm["covered"] = (gm["mean_rating"] >= tau).astype(int)

    # OLS with cluster-robust covariance, grouping on question
    res = smf.ols("covered ~ 0 + C(model) + C(question_id)", data=gm).fit(
        cov_type="cluster", cov_kwds={"groups": gm["question_id"]}
    )

    # Fitted values -> mean per (question, model) -> mean per model
    gm["pred"] = res.predict(gm)
    per_question = (gm.groupby(["question_id", "model"], as_index=False)["pred"]
                      .mean()
                      .rename(columns={"pred": "pred_q"}))
    adj_cov = (per_question.groupby("model", as_index=False)["pred_q"]
                 .mean()
                 .rename(columns={"pred_q": "adj_coverage"}))

    # Locate the model coefficients inside the parameter vector
    param_names = list(res.params.index)
    model_slots = [(pos, nm) for pos, nm in enumerate(param_names) if nm.startswith("C(model)")]
    if not model_slots:
        raise RuntimeError("No model effects found in the fitted regression.")

    # Shared part of every contrast: minus the average of all model coefficients
    share = 1.0 / len(model_slots)
    minus_grand_mean = np.zeros((len(param_names),), dtype=float)
    for pos, _ in model_slots:
        minus_grand_mean[pos] -= share

    rows = []
    for pos, nm in model_slots:
        contrast = minus_grand_mean.copy()
        contrast[pos] += 1.0  # this model's coefficient minus the grand mean
        tt = res.t_test(contrast)
        # Strip the "C(model)[...]" wrapper to recover the model label
        m_name = nm.split("[T.", 1)[1][:-1] if "[T." in nm else nm.split("[", 1)[1][:-1]
        ci = tt.conf_int().ravel()
        rows.append({
            "model": m_name,
            "coef_vs_grand_mean": float(tt.effect.item()),
            "se": float(tt.sd.item()),
            "t": float(tt.tvalue.item()),
            "p": float(tt.pvalue.item()),
            "ci_low": float(ci[0]),
            "ci_high": float(ci[1]),
            "n_rows": int((gm["model"] == m_name).sum()),
        })
    infer = pd.DataFrame(rows)

    return adj_cov.merge(infer, on="model", how="outer")


def run_ols_adjusted_coverage_weighted(
    df: pd.DataFrame,
    cluster_col: str,
    tau: float,
) -> pd.DataFrame:
    """
    Weighted counterpart of the coverage regression: WLS of
    covered ~ 0 + C(model) + C(question_id) with standard errors clustered by
    question and weights equal to cluster size (unique users in the cluster
    for that question).

    Adjusted coverage is the cluster-size-weighted mean of the fitted values
    within each question, averaged equally across questions.

    Returns columns: model, adj_coverage_w, coef_vs_grand_mean_w, se_w, t_w,
    p_w, ci_low_w, ci_high_w, n_rows_w

    Raises RuntimeError when the fit contains no C(model) coefficients.
    """
    q_cluster = ["question_id", cluster_col]
    labelled = df.dropna(subset=[cluster_col])

    # One observation per (question, model, cluster)
    gm = (labelled.groupby(["question_id", "model", cluster_col])["representation_rating"]
            .mean()
            .reset_index(name="mean_rating"))
    gm["covered"] = (gm["mean_rating"] >= tau).astype(int)

    # Weights: unique users per (question, cluster), regardless of model
    sizes = (labelled.drop_duplicates(["user", *q_cluster])
               .groupby(q_cluster)["user"]
               .nunique()
               .reset_index(name="cluster_size"))
    gm = gm.merge(sizes, on=q_cluster, how="left")
    gm = gm.dropna(subset=["cluster_size"]).copy()
    gm["cluster_size"] = pd.to_numeric(gm["cluster_size"], errors="coerce")
    gm = gm[gm["cluster_size"] > 0].copy()

    # WLS with cluster-robust covariance, grouping on question
    res = smf.wls(
        "covered ~ 0 + C(model) + C(question_id)", data=gm, weights=gm["cluster_size"]
    ).fit(cov_type="cluster", cov_kwds={"groups": gm["question_id"]})

    gm["pred_w"] = res.predict(gm)

    # Weighted mean of fitted values inside each (question, model), then mean per model
    per_question = (gm.assign(wpred=gm["pred_w"] * gm["cluster_size"])
                      .groupby(["question_id", "model"])
                      .agg(wpred_sum=("wpred", "sum"), wsum=("cluster_size", "sum"))
                      .reset_index())
    per_question["pred_q_w"] = per_question["wpred_sum"] / per_question["wsum"]
    adj_cov_w = (per_question.groupby("model", as_index=False)["pred_q_w"]
                   .mean()
                   .rename(columns={"pred_q_w": "adj_coverage_w"}))

    # Locate the model coefficients inside the parameter vector
    param_names = list(res.params.index)
    model_slots = [(pos, nm) for pos, nm in enumerate(param_names) if nm.startswith("C(model)")]
    if not model_slots:
        raise RuntimeError("No model effects found in the weighted regression.")

    # Shared part of every contrast: minus the average of all model coefficients
    share = 1.0 / len(model_slots)
    minus_grand_mean = np.zeros((len(param_names),), dtype=float)
    for pos, _ in model_slots:
        minus_grand_mean[pos] -= share

    rows = []
    for pos, nm in model_slots:
        contrast = minus_grand_mean.copy()
        contrast[pos] += 1.0  # this model's coefficient minus the grand mean
        tt = res.t_test(contrast)
        # Strip the "C(model)[...]" wrapper to recover the model label
        m_name = nm.split("[T.", 1)[1][:-1] if "[T." in nm else nm.split("[", 1)[1][:-1]
        ci = tt.conf_int().ravel()
        rows.append({
            "model": m_name,
            "coef_vs_grand_mean_w": float(tt.effect.item()),
            "se_w": float(tt.sd.item()),
            "t_w": float(tt.tvalue.item()),
            "p_w": float(tt.pvalue.item()),
            "ci_low_w": float(ci[0]),
            "ci_high_w": float(ci[1]),
            "n_rows_w": int((gm["model"] == m_name).sum()),
        })
    infer_w = pd.DataFrame(rows)

    return adj_cov_w.merge(infer_w, on="model", how="outer")

# -------------------------------
# Pretty printing / diagnostics
# -------------------------------

def print_kmeans_diagnostics(
    merged_df: pd.DataFrame,
    cluster_col: str = "cluster_kmeans",
    user_col: str = "user",
):
    """
    Print one line per question showing how many users carry a cluster label.

    For each question_id the line reports the number of distinct cluster
    labels and "users with a non-null label / all users".
    """
    def users_per_question(frame: pd.DataFrame) -> pd.Series:
        # One row per (user, question) before counting
        unique_pairs = frame.drop_duplicates([user_col, "question_id"])
        return unique_pairs.groupby("question_id")[user_col].count()

    has_label = merged_df.dropna(subset=[cluster_col])

    total_users = users_per_question(merged_df)
    joined_users = users_per_question(has_label)
    clusters_per_q = has_label.groupby("question_id")[cluster_col].nunique()

    for q in sorted(set(merged_df["question_id"])):
        # Questions without any labelled row are absent from the lookups -> 0
        tot = int(total_users.get(q, 0))
        jnd = int(joined_users.get(q, 0))
        K = int(clusters_per_q.get(q, 0))
        print(f"[kmeans] q={q}  clusters={K:>2}  joined_users={jnd}/{tot}  K={K}")


# -------------------------------
# Main
# -------------------------------

def main():
    """
    CLI entry point.

    Loads the ratings table, prints join diagnostics, computes the raw Overton
    score and the OLS-adjusted coverage per model (plus the cluster-size
    weighted variants with --weighted), optionally writes per-question OC and
    mixed-model upper-bound CSVs, and finally writes the combined CSV and the
    markdown summary into --outdir.
    """
    ap = argparse.ArgumentParser(description="Benchmark Overton with unadjusted/weighted scores and OLS.")
    ap.add_argument("--tau", type=float, default=4.0)
    ap.add_argument("--weighted", action="store_true",
                    help="Also compute cluster-size–weighted KMEANS Overton and weighted adjusted coverage (separate MD section).")
    ap.add_argument("--outdir", default="outputs")
    ap.add_argument("--data", default=None,
                    help="Path to your data CSV (same schema as OvertonBench). Default: load from Hugging Face. Overridden by DATASET in .env if set.")
    ap.add_argument("--cluster_col", default="cluster_kmeans",
                    help="Column name for the cluster labels in --data.")
    ap.add_argument("--emit-oc-per-question", action="store_true",
                    help="If set, also emit per-question OC CSVs for each method.")
    ap.add_argument("--upper-bound", dest="upper_bound", action="store_true",
                    help="Compute mixed-model upper bound per question (covered if ANY model meets τ for that cluster) and save summary.")
    ap.add_argument("--source", default=None,
                    help="Question source split when loading from HF: full (default), modelslant, or prism. Output filenames get _modelslant/_prism suffix when set.")
    args = ap.parse_args()

    # Normalise once so the loader and the file-name suffix agree on the split
    source = (args.source or "").strip().lower() or "full"

    def dataset_suffix(filename: str) -> str:
        """Insert _modelslant/_prism before the extension for those splits."""
        if source not in ("modelslant", "prism"):
            return filename
        base, ext = filename.rsplit(".", 1) if "." in filename else (filename, "")
        return f"{base}_{source}.{ext}" if ext else f"{filename}_{source}"

    def fmt_ci(est, lo, hi):
        """Format 'estimate [low, high]', or 'NA' when any part is missing."""
        if pd.isna(est) or pd.isna(lo) or pd.isna(hi):
            return "NA"
        return f"{est:.3f} [{lo:.3f}, {hi:.3f}]"

    def markdown_table(table: pd.DataFrame, tag: str) -> str:
        """
        Render the per-model summary for one variant (tag is 'kmeans' or
        'kmeans_w'), sorted by adjusted coverage, best first.
        """
        ordered = (table.sort_values(f"adj_coverage_{tag}", ascending=False, na_position="last")
                        .reset_index(drop=True))
        # NOTE(review): ci_low/ci_high come from the t-test of each model effect
        # against the grand mean of model effects, so the bracket printed next to
        # adj. coverage is the CI of that contrast — confirm this is the intended display.
        ci_text = ordered.apply(
            lambda r: fmt_ci(r.get(f"adj_coverage_{tag}"), r.get(f"ci_low_{tag}"), r.get(f"ci_high_{tag}")),
            axis=1,
        )
        view = pd.DataFrame({
            "model": ordered["model"],
            "OvertonScore (raw)": ordered[f"OvertonScore_{tag}"].map(
                lambda x: f"{x:.3f}" if pd.notna(x) else "NA"),
            "adj. coverage (95% CI)": ci_text,
            "p (vs. grand mean)": ordered[f"p_{tag}"].map(
                lambda x: f"{x:.3g}" if pd.notna(x) else "NA"),
        })
        return view.to_markdown(index=False)

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    md_sections: List[str] = [f"# Overton Scores & OLS (τ = {args.tau:.2f})", ""]

    # -----------------------
    # KMEANS method
    # -----------------------
    df_k = get_overtonbench_data(args.data, source_split=source)

    # Diagnostics
    print_kmeans_diagnostics(df_k, cluster_col=args.cluster_col, user_col="user")

    # Unadjusted Overton
    per_model_kmeans, oc_per_q_kmeans = compute_unadjusted_overton(
        df_k, cluster_col=args.cluster_col, tau=args.tau
    )
    per_model_kmeans = per_model_kmeans.rename(columns={
        "OvertonScore": "OvertonScore_kmeans",
        "total_clusters": "total_clusters_kmeans",
        "avg_clusters_per_q": "avg_clusters_per_q_kmeans",
    })

    # OLS adjusted coverage (unweighted)
    ols_kmeans = run_ols_adjusted_coverage(
        df_k, cluster_col=args.cluster_col, tau=args.tau
    ).rename(columns={
        "adj_coverage": "adj_coverage_kmeans",
        "coef_vs_grand_mean": "coef_vs_grand_mean_kmeans",
        "se": "se_kmeans",
        "t": "t_kmeans",
        "p": "p_kmeans",
        "ci_low": "ci_low_kmeans",
        "ci_high": "ci_high_kmeans",
        "n_rows": "n_rows_kmeans",
    })

    # Base per-model table (unweighted + OLS); it seeds the combined CSV
    kmeans_out = per_model_kmeans.merge(ols_kmeans, on="model", how="outer")
    combined = kmeans_out

    md_sections += [
        "## KMEANS",
        "",
        markdown_table(kmeans_out, "kmeans"),
        ""
    ]

    # -----------------------
    # KMEANS weighted (separate section + merged into CSV)
    # -----------------------
    if args.weighted:
        # Weighted Overton (cluster-size–weighted coverage per question)
        per_model_kmeans_w, oc_per_q_kmeans_w = compute_weighted_overton(
            df_k, cluster_col=args.cluster_col, tau=args.tau
        )
        per_model_kmeans_w = per_model_kmeans_w.rename(columns={
            "OvertonScore_w": "OvertonScore_kmeans_w",
        })

        # Weighted OLS (cluster-size weights within question)
        ols_kmeans_w = run_ols_adjusted_coverage_weighted(
            df_k, cluster_col=args.cluster_col, tau=args.tau
        ).rename(columns={
            "adj_coverage_w": "adj_coverage_kmeans_w",
            "coef_vs_grand_mean_w": "coef_vs_grand_mean_kmeans_w",
            "se_w": "se_kmeans_w",
            "t_w": "t_kmeans_w",
            "p_w": "p_kmeans_w",
            "ci_low_w": "ci_low_kmeans_w",
            "ci_high_w": "ci_high_kmeans_w",
            "n_rows_w": "n_rows_kmeans_w",
        })

        # Per-model weighted table, also merged into the combined CSV
        kmeans_out_w = per_model_kmeans_w.merge(ols_kmeans_w, on="model", how="outer")
        combined = combined.merge(kmeans_out_w, on="model", how="outer")

        md_sections += [
            "## KMEANS weighted",
            "",
            markdown_table(kmeans_out_w, "kmeans_w"),
            ""
        ]

    # Mixed-model upper bound (per-question + summary)
    if args.upper_bound:
        ub_per_q_kmeans = compute_mixed_model_upper_bound(
            df_k, cluster_col=args.cluster_col, tau=args.tau
        ).sort_values("question_id").reset_index(drop=True)

        ub_path_q = str(outdir / dataset_suffix(f"mixed_model_upper_bound_kmeans_per_question_tau{args.tau}.csv"))
        ub_per_q_kmeans.to_csv(ub_path_q, index=False)

        # Overall UB across questions (mean of per-question OCs)
        ub_summary = pd.DataFrame({
            "UB_unweighted": [ub_per_q_kmeans["OC_unweighted"].mean()],
            "UB_weighted":   [ub_per_q_kmeans["OC_weighted"].mean()],
            "questions_count":[len(ub_per_q_kmeans)],
            "tau":           [args.tau],
        })
        ub_path_sum = str(outdir / dataset_suffix(f"mixed_model_upper_bound_kmeans_summary_tau{args.tau}.csv"))
        ub_summary.to_csv(ub_path_sum, index=False)

        print(f"[kmeans] wrote mixed-model UB per-question → {ub_path_q}")
        print(f"[kmeans] wrote mixed-model UB summary      → {ub_path_sum}")

    # Optional per-question OC emits (unweighted + weighted if requested)
    if args.emit_oc_per_question:
        oc_path_kmeans = str(outdir / dataset_suffix(f"oc_per_question_kmeans_tau{args.tau}.csv"))
        oc_per_q_kmeans.to_csv(oc_path_kmeans, index=False)
        print(f"[kmeans] wrote per-question OC → {oc_path_kmeans}")

        # The weighted table only exists when --weighted was given
        if args.weighted:
            ocw_path = str(outdir / dataset_suffix(f"oc_per_question_kmeans_weighted_tau{args.tau}.csv"))
            oc_per_q_kmeans_w[["question_id", "model", "K", "weighted_cov", "weight_sum", "OC_w"]].to_csv(ocw_path, index=False)
            print(f"[kmeans] wrote per-question weighted OC → {ocw_path}")

    # -----------------------
    # Write outputs
    # -----------------------
    sort_keys = [c for c in combined.columns if c.startswith("adj_coverage")]
    sort_key = sort_keys[0] if sort_keys else "model"
    combined = combined.sort_values(by=sort_key, ascending=False, ignore_index=True)

    out_csv = str(outdir / dataset_suffix(f"overton_scores_and_ols_tau{args.tau}.csv"))
    out_md  = str(outdir / dataset_suffix(f"overton_scores_and_ols_tau{args.tau}.md"))
    combined.to_csv(out_csv, index=False)
    # The summary contains non-ASCII text ("τ"); do not rely on the locale default encoding
    with open(out_md, "w", encoding="utf-8") as f:
        f.write("\n".join(md_sections))

    print(f"[OK] wrote combined CSV  → {out_csv}")
    print(f"[OK] wrote markdown     → {out_md}")



# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
