#!/usr/bin/env python3
"""
Quick analysis for human study responses.

Loads normalized participant responses and joins them with selection metadata
and bias/judge scores, then computes aggregated metrics and plots with seaborn.
Designed to keep data wrangling clean and configurable.
"""

import json
import math
import statsmodels.formula.api as smf
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from visualization.vis_utilities import (
    EARTHY_COLORS,
    apply_nyt_style_to_axes,
    setup_nyt_style,
)


# ------------------
# Configurable knobs
# ------------------

# Input files and output directory. These are relative paths, so they resolve
# against the current working directory.
# NOTE(review): responses/output live under "final_proper" while the selection
# file lives under "final" -- confirm this mismatch is intentional.
RESPONSES_PATH = Path("human_study/final_proper/responses_normalized.jsonl")
SELECTION_PATH = Path("human_study/final/selection.jsonl")
OUTPUT_DIR = Path("human_study/final_proper/analysis")

# Trim extremes when aggregating per question/model (set to 0 to disable)
TRIM_TOP_K = 0
TRIM_BOTTOM_K = 0

# Optionally require a minimum number of distinct participants per item
ENFORCE_MIN_JUDGEMENTS = False
MIN_JUDGEMENTS = 3

# Name given to the participant "difference" column after loading
HUMAN_BIAS_COL = "human_bias_score"

# Plot settings
STYLE = "whitegrid"
PALETTE = EARTHY_COLORS
FIGSIZE = (8, 5)
PLOT_AGGREGATED_BIAS = (
    True  # If True, also plot mean difference per bias score (one point per score)
)
MIRROR_DELTA_HIST_PARTICIPANT = (
    False  # If True, mirror participant delta histogram +/- for symmetry
)
# Toggle the "n=..." count annotations on line plots. Enabled by default.
SHOW_COUNT_LABELS = True

# Rounding applied to judge scores loaded from selection:
# "down" floors, "up" ceils, "leave" keeps the value unchanged.
ROUNDING_MODE = "down"
# Drop participants with extreme per-participant std on human bias scores (percent of participants).
STD_FILTER_TOP_PCT = 0.0  # remove this % with highest std
STD_FILTER_BOTTOM_PCT = 5.0  # remove this % with lowest std

# Attention check filtering: participants who fail the check are dropped entirely,
# and the check items themselves are removed from the analysed data.
FILTER_ATTENTION_CHECK = True
ATTENTION_CHECK_ITEM_IDS = ["attention_check_gender_balanced"]
# Required difference value for attention check (exact match). Set to None to disable numeric check.
ATTENTION_REQUIRED_DIFFERENCE = 1.0


@dataclass
class CleanFrames:
    """Bundle of the three data frames produced by ``load_data``."""

    # Participant responses after participant/item filtering and column cleanup.
    responses: pd.DataFrame
    # Selection metadata (judge scores rounded, ``id`` renamed to ``item_id``); not filtered.
    selection: pd.DataFrame
    # Responses left-joined with selection on ``item_id``, after filtering and optional trimming.
    merged: pd.DataFrame


def read_jsonl(path: Path) -> List[Dict]:
    """Decode a JSON Lines file into a list of records.

    Each line is stripped of surrounding whitespace; lines that end up empty
    are ignored and every remaining line is parsed as one JSON document.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def _palette_for_levels(n: int):
    """Build a seaborn palette from PALETTE with one colour per level (at least one)."""
    level_count = max(1, n)
    return sns.color_palette(PALETTE, n_colors=level_count)


def _text_offset(values: pd.Series, pct: float = 0.03, min_offset: float = 0.03) -> float:
    """Return a vertical nudge for annotations, scaled to the spread of ``values``.

    The nudge is ``pct`` of the numeric range. When the range is zero it falls
    back to ``pct`` of the largest magnitude (at least 1.0). The result is never
    smaller than ``min_offset``, which is also returned when no numeric values exist.
    """
    cleaned = pd.to_numeric(values, errors="coerce").dropna()
    if cleaned.empty:
        return min_offset
    highest = cleaned.max()
    lowest = cleaned.min()
    nudge = float(highest - lowest) * pct
    if nudge <= 0:
        # Flat data: still move the label away from the line.
        nudge = max(abs(float(highest)), 1.0) * pct
    return max(nudge, min_offset)


def _text_xy(x: float, y: float, x_offset: float, y_offset: float) -> Tuple[float, float]:
    """Return the point ``(x, y)`` shifted by the given offsets, for placing a label."""
    shifted_x = x + x_offset
    shifted_y = y + y_offset
    return shifted_x, shifted_y


def _save_fig(output_dir: Path, filename: str, pad_top: bool = False) -> None:
    """Style the current axes, write the figure to ``output_dir / filename``, and close it.

    Applies the NYT axes styling, re-sets a non-empty title in a serif font,
    optionally adds headroom above the data (``pad_top``), despines, and saves
    at 200 dpi. The output directory is created if it does not exist.
    """
    axes = plt.gca()
    apply_nyt_style_to_axes(axes)

    current_title = axes.get_title()
    if current_title:
        axes.set_title(current_title, fontfamily="serif")

    if pad_top:
        lower, upper = axes.get_ylim()
        height = upper - lower
        # 12% of the visible range, or a fixed pad when the range is degenerate.
        if height > 0:
            headroom = height * 0.12
        else:
            headroom = 0.5
        axes.set_ylim(lower, upper + headroom)

    sns.despine(ax=axes)
    plt.tight_layout()
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / filename
    plt.savefig(target, dpi=200)
    plt.close()


def _per_item_stat(df: pd.DataFrame, stat: str) -> Optional[pd.DataFrame]:
    """Reduce human bias scores to one value per (item_id, attribute, bias_score).

    ``stat`` selects the reducer: ``"max"`` yields an ``item_max`` column and
    ``"median"`` an ``item_median`` column. Returns ``None`` when a required
    column is missing or ``stat`` is anything else.
    """
    keys = ["item_id", "attribute", "bias_score"]
    if any(col not in df.columns for col in [*keys, HUMAN_BIAS_COL]):
        return None
    if stat not in ("max", "median"):
        return None
    per_item = df.groupby(keys)[HUMAN_BIAS_COL]
    if stat == "max":
        reduced = per_item.max()
    else:
        reduced = per_item.median()
    return reduced.rename(f"item_{stat}").reset_index()


def load_data(
    responses_path: Path, selection_path: Path, trim_top: int = 0, trim_bottom: int = 0
) -> CleanFrames:
    """Load responses and selection metadata, join them, and apply the configured filters.

    Processing order:
      1. Read both JSONL files into data frames.
      2. For each judge column (bias/relevance/acknowledgement score) missing from
         the selection frame, derive it as the mean of the numeric values found
         under that key in the row's ``bias_judgements`` list; then round every
         judge column according to ``ROUNDING_MODE``.
      3. Rename ``difference`` -> ``HUMAN_BIAS_COL`` and coerce the participant
         rating columns to numeric (unparseable values become NaN).
      4. Left-join responses with selection on ``item_id``.
      5. If attention-check filtering is enabled, keep only participants whose
         attention-check answers all equal ``ATTENTION_REQUIRED_DIFFERENCE``,
         then remove the attention-check items themselves.
      6. Apply ``filter_participants_by_std`` (defined elsewhere in this file).
      7. Optionally drop items rated by fewer than ``MIN_JUDGEMENTS`` participants.
      8. Optionally trim the lowest/highest human bias scores per item
         (applied to the merged frame only).
      9. Drop bulky text/list columns from the responses and merged frames.

    Args:
        responses_path: JSONL file with normalized participant responses.
        selection_path: JSONL file with item selection metadata and judge scores.
        trim_top: Number of highest human bias scores to drop per item.
        trim_bottom: Number of lowest human bias scores to drop per item.

    Returns:
        A ``CleanFrames`` holding the filtered responses, the (unfiltered)
        selection frame, and the filtered merged frame.

    Raises:
        ValueError: If ``ROUNDING_MODE`` is not "down", "up" or "leave" and at
            least one non-null judge score needs rounding.
    """
    resp_df = pd.DataFrame(read_jsonl(responses_path))
    sel_df = pd.DataFrame(read_jsonl(selection_path))

    judge_cols = ["bias_score", "relevance_score", "acknowledgement_score"]

    # Extract judge-side scores (bias/relevance/acknowledgement) from bias_judgements list.
    def _extract_judge_score(row: pd.Series, key: str) -> Optional[float]:
        judgments = row.get("bias_judgements")
        # Rows lacking the field surface as None/NaN once loaded into a frame.
        # NaN is truthy, so an `or []` fallback would try to iterate a float.
        if not isinstance(judgments, (list, tuple)):
            return None
        vals = [
            float(j[key])
            for j in judgments
            if isinstance(j, dict) and isinstance(j.get(key), (int, float))
        ]
        if not vals:
            return None
        return sum(vals) / len(vals)

    for score_key in judge_cols:
        if score_key not in sel_df.columns:
            sel_df[score_key] = sel_df.apply(
                lambda r, k=score_key: _extract_judge_score(r, k), axis=1
            )

    # Apply configured rounding to judge-side scores when present.
    def _round_value(val: Optional[float]) -> Optional[float]:
        if val is None or (isinstance(val, float) and math.isnan(val)):
            return val
        if ROUNDING_MODE == "leave":
            return val
        if ROUNDING_MODE == "down":
            return float(math.floor(val))
        if ROUNDING_MODE == "up":
            return float(math.ceil(val))
        raise ValueError(f"Unknown ROUNDING_MODE: {ROUNDING_MODE}")

    for score_col in judge_cols:
        if score_col in sel_df.columns:
            sel_df[score_col] = sel_df[score_col].apply(_round_value)

    # Rename difference -> human_bias_score for clarity
    resp_df = resp_df.rename(
        columns={
            "difference": HUMAN_BIAS_COL,
            "difference_raw": f"{HUMAN_BIAS_COL}_raw",
        }
    )

    # Normalize dtypes. Coerce each rating column that is actually present; the
    # presence of a "<col>_raw" companion says nothing about "<col>" itself.
    for col in [HUMAN_BIAS_COL, "relevance", "acknowledgement"]:
        if col in resp_df.columns:
            resp_df[col] = pd.to_numeric(resp_df[col], errors="coerce")
    resp_df["participant_id"] = resp_df["participant_id"].astype(str)
    # Both sides of the join must share a key dtype; selection ids are cast to
    # str below, so do the same for responses.
    resp_df["item_id"] = resp_df["item_id"].astype(str)

    sel_df = sel_df.rename(columns={"id": "item_id"})
    sel_df["item_id"] = sel_df["item_id"].astype(str)

    merged = resp_df.merge(sel_df, on="item_id", how="left", suffixes=("", "_sel"))

    if FILTER_ATTENTION_CHECK and ATTENTION_CHECK_ITEM_IDS:
        attn_rows = merged[merged["item_id"].isin(ATTENTION_CHECK_ITEM_IDS)]
        passed_participants = set()
        failed_participants = set()
        for pid, grp in attn_rows.groupby("participant_id"):
            diffs = pd.to_numeric(grp[HUMAN_BIAS_COL], errors="coerce").dropna()
            if diffs.empty:
                print(f"Participant {pid} has no valid attention check responses")
                failed_participants.add(pid)
                continue
            if ATTENTION_REQUIRED_DIFFERENCE is not None:
                # Any answer that is not exactly the required value fails the check.
                if (abs(diffs - ATTENTION_REQUIRED_DIFFERENCE) > 0).any():
                    print(f"Participant {pid} failed attention check (diffs: {diffs.tolist()})")
                    failed_participants.add(pid)
                    continue
            passed_participants.add(pid)
        # Participants with no attention-check rows at all are in neither set
        # and are therefore dropped here as well.
        merged = merged[merged["participant_id"].isin(passed_participants)]
        resp_df = resp_df[resp_df["participant_id"].isin(passed_participants)]
        print(
            f"Attention check: kept {len(passed_participants)} participants, "
            f"failed {len(failed_participants)}"
        )

        # Remove the attention check items from the main data
        merged = merged[~merged["item_id"].isin(ATTENTION_CHECK_ITEM_IDS)]
        resp_df = resp_df[~resp_df["item_id"].isin(ATTENTION_CHECK_ITEM_IDS)]

    # Drop participants with extreme std on human bias scores, if configured
    resp_df, merged = filter_participants_by_std(resp_df, merged, HUMAN_BIAS_COL)

    # Optionally filter items to those with at least MIN_JUDGEMENTS participant ratings
    if ENFORCE_MIN_JUDGEMENTS:
        counts = merged.groupby("item_id")["participant_id"].nunique()
        keep_items = set(counts[counts >= MIN_JUDGEMENTS].index.astype(str))
        dropped = set(counts.index.astype(str)) - keep_items
        merged = merged[merged["item_id"].isin(keep_items)]
        resp_df = resp_df[resp_df["item_id"].isin(keep_items)]
        print(
            f"Min-judgements filter ({MIN_JUDGEMENTS}+): kept {len(keep_items)} items, "
            f"dropped {len(dropped)} items"
        )

    # Optional trimming: drop top/bottom per item using human bias score as the key.
    # Rows whose score is NaN are not retained once trimming is active.
    if (trim_top > 0 or trim_bottom > 0) and HUMAN_BIAS_COL in merged.columns:
        keep_indices: List[int] = []
        for _, grp in merged.groupby("item_id"):
            # Stable sort so ties are trimmed deterministically.
            scores = grp[HUMAN_BIAS_COL].dropna().sort_values(kind="mergesort")
            if scores.empty:
                continue
            if len(scores) <= (trim_bottom + trim_top):
                # Too few ratings to trim; keep them all.
                keep_indices.extend(scores.index.tolist())
                continue
            trimmed = scores.iloc[trim_bottom : len(scores) - trim_top]
            keep_indices.extend(trimmed.index.tolist())
        if keep_indices:
            merged = merged.loc[keep_indices]

    # Drop bulky text/list columns that the analysis does not need
    bulky_cols = [
        "question_template",
        "question_text",
        "bias_judgements",
        "summaries",
        "differences",
        "source_file",
    ]
    resp_df = resp_df.drop(columns=bulky_cols, errors="ignore")
    merged = merged.drop(columns=bulky_cols, errors="ignore")

    return CleanFrames(resp_df, sel_df, merged)


def aggregate_by_item(df: pd.DataFrame, metrics: Optional[List[str]] = None) -> pd.DataFrame:
    """Average the given metric columns per item.

    Rows are grouped by item, question, model and attribute, plus whichever
    judge score columns exist in ``df`` (NaN keys are kept as groups). When
    ``metrics`` is empty or omitted, the human bias, relevance and
    acknowledgement columns are averaged.
    """
    if not metrics:
        metrics = [HUMAN_BIAS_COL, "relevance", "acknowledgement"]
    present_judge_cols = [
        col
        for col in ("bias_score", "relevance_score", "acknowledgement_score")
        if col in df.columns
    ]
    keys = ["item_id", "question_id", "model_id", "attribute", *present_judge_cols]
    return df.groupby(keys, dropna=False)[metrics].mean().reset_index()


def plot_bias_bin_distributions(df: pd.DataFrame, metric: str, output_dir: Path) -> None:
    """Box plot of ``metric`` per ``bias_bin`` with the raw points overlaid as a strip plot."""
    bin_count = df["bias_bin"].nunique()
    bin_palette = _palette_for_levels(bin_count)
    plt.figure(figsize=FIGSIZE)
    sns.boxplot(
        data=df,
        x="bias_bin",
        y=metric,
        hue="bias_bin",
        palette=bin_palette,
        dodge=False,
        legend=False,
    )
    # Individual observations on top of the boxes.
    sns.stripplot(data=df, x="bias_bin", y=metric, color="black", alpha=0.3, jitter=True)
    plt.title(f"{metric.title()} by bias bin")
    _save_fig(output_dir, f"{metric}_by_bias_bin.pdf")


def plot_metric_vs_judge(df: pd.DataFrame, metric: str, judge_col: str, output_dir: Path) -> None:
    """Scatter ``metric`` against ``judge_col``, coloured by attribute, with a gray regression line."""
    attribute_count = df["attribute"].nunique()
    plt.figure(figsize=FIGSIZE)
    sns.scatterplot(
        data=df,
        x=judge_col,
        y=metric,
        hue="attribute",
        palette=_palette_for_levels(attribute_count),
    )
    sns.regplot(data=df, x=judge_col, y=metric, scatter=False, color="gray")
    plt.title(f"{metric.title()} vs judge {judge_col}")
    _save_fig(output_dir, f"{metric}_vs_{judge_col}.pdf")


def plot_difference_vs_bias(df: pd.DataFrame, output_dir: Path) -> None:
    """Scatter human bias score against judge ``bias_score``, coloured by attribute.

    A gray regression line over all points is drawn on top of the scatter.
    """
    attribute_count = df["attribute"].nunique()
    plt.figure(figsize=FIGSIZE)
    sns.scatterplot(
        data=df,
        x="bias_score",
        y=HUMAN_BIAS_COL,
        hue="attribute",
        palette=_palette_for_levels(attribute_count),
    )
    sns.regplot(data=df, x="bias_score", y=HUMAN_BIAS_COL, scatter=False, color="gray")
    plt.xlabel("Judge bias score")
    plt.ylabel("Avg participant human bias score")
    plt.title("Participant bias vs judge bias")
    _save_fig(output_dir, "difference_vs_bias_score.pdf")


def plot_difference_vs_bias_mean(df: pd.DataFrame, output_dir: Path) -> None:
    """Line plot of the mean human bias score at each judge bias score.

    Error bars show the standard error of the mean (std / sqrt(count)); when
    SHOW_COUNT_LABELS is set, each point is annotated with its sample size.
    """
    valid = df.dropna(subset=["bias_score", HUMAN_BIAS_COL])
    summary = valid.groupby("bias_score")[HUMAN_BIAS_COL].agg(["mean", "count", "std"])
    summary = summary.reset_index().sort_values("bias_score")

    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary,
        x="bias_score",
        y="mean",
        marker="o",
        color="C0",
        label="Mean",
        legend=False,
    )
    sem = summary["std"] / summary["count"] ** 0.5
    plt.errorbar(
        summary["bias_score"],
        summary["mean"],
        yerr=sem,
        fmt="none",
        ecolor="C0",
        capsize=4,
        alpha=0.7,
    )
    if SHOW_COUNT_LABELS:
        drop = _text_offset(summary["mean"])
        for score, avg, n_obs in zip(summary["bias_score"], summary["mean"], summary["count"]):
            # Place the label slightly right of and below the marker.
            label_x, label_y = _text_xy(score, avg, 0.04, -drop)
            plt.text(
                label_x,
                label_y,
                f"n={int(n_obs)}",
                ha="left",
                va="top",
                fontsize=8,
                color="C0",
            )
    plt.xlabel("Judge bias score")
    plt.ylabel("Mean participant human bias score")
    plt.title("Mean participant bias by judge bias")
    _save_fig(output_dir, "difference_vs_bias_score_mean.pdf", pad_top=True)


def plot_difference_vs_bias_mean_by_attr(df: pd.DataFrame, output_dir: Path) -> None:
    """Line plot of mean human bias score per judge bias score, one line per attribute.

    Each line carries standard-error bars in its own colour; sample sizes are
    annotated when SHOW_COUNT_LABELS is set.
    """
    valid = df.dropna(subset=["bias_score", HUMAN_BIAS_COL])
    summary = valid.groupby(["attribute", "bias_score"])[HUMAN_BIAS_COL].agg(
        ["mean", "count", "std"]
    )
    summary = summary.reset_index().sort_values(["attribute", "bias_score"])

    palette = _palette_for_levels(summary["attribute"].nunique())
    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary, x="bias_score", y="mean", hue="attribute", marker="o", palette=palette
    )

    # Colour per attribute, assigned in order of first appearance.
    attr_colors = {}
    for position, attr_name in enumerate(summary["attribute"].unique()):
        attr_colors[attr_name] = palette[position % len(palette)]

    for attr_name, rows in summary.groupby("attribute"):
        sem = rows["std"] / rows["count"] ** 0.5
        plt.errorbar(
            rows["bias_score"],
            rows["mean"],
            yerr=sem,
            fmt="none",
            ecolor=attr_colors.get(attr_name, "gray"),
            capsize=4,
            alpha=0.7,
        )

    if SHOW_COUNT_LABELS:
        drop = _text_offset(summary["mean"])
        columns = (
            summary["attribute"],
            summary["bias_score"],
            summary["mean"],
            summary["count"],
        )
        for attr_name, score, avg, n_obs in zip(*columns):
            label_x, label_y = _text_xy(score, avg, 0.04, -drop)
            plt.text(
                label_x,
                label_y,
                f"n={int(n_obs)}",
                ha="left",
                va="top",
                fontsize=8,
                color=attr_colors.get(attr_name, "black"),
            )
    plt.xlabel("Judge bias score")
    plt.ylabel("Mean participant human bias score")
    plt.title("Mean participant bias by judge bias, by attribute")
    _save_fig(output_dir, "difference_vs_bias_score_mean_by_attribute.pdf", pad_top=True)


def plot_difference_vs_bias_median(df: pd.DataFrame, output_dir: Path) -> None:
    """Line plot of the mean per-item median human bias score at each judge bias score.

    Does nothing when the per-item medians cannot be computed or are empty.
    Error bars are the standard error across items.
    """
    item_medians = _per_item_stat(df, "median")
    if item_medians is None or item_medians.empty:
        return
    summary = item_medians.groupby("bias_score")["item_median"].agg(["mean", "std", "count"])
    summary = summary.reset_index().sort_values("bias_score")

    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary,
        x="bias_score",
        y="mean",
        marker="o",
        color="C1",
        label="Mean",
        legend=False,
    )
    sem = summary["std"] / summary["count"] ** 0.5
    plt.errorbar(
        summary["bias_score"],
        summary["mean"],
        yerr=sem,
        fmt="none",
        ecolor="C1",
        capsize=4,
        alpha=0.7,
    )
    if SHOW_COUNT_LABELS:
        drop = _text_offset(summary["mean"])
        for score, avg, n_items in zip(summary["bias_score"], summary["mean"], summary["count"]):
            label_x, label_y = _text_xy(score, avg, 0.04, -drop)
            plt.text(
                label_x,
                label_y,
                f"n={int(n_items)}",
                ha="left",
                va="top",
                fontsize=8,
                color="C1",
            )
    plt.xlabel("Judge bias score")
    plt.ylabel("Mean of item median human bias score")
    plt.title("Mean item median bias by judge bias")
    _save_fig(output_dir, "difference_vs_bias_score_median.pdf", pad_top=True)


def plot_difference_vs_bias_median_by_attr(df: pd.DataFrame, output_dir: Path) -> None:
    """Line plot of mean per-item median human bias score per judge bias score, by attribute.

    Does nothing when the per-item medians cannot be computed or are empty.
    """
    item_medians = _per_item_stat(df, "median")
    if item_medians is None or item_medians.empty:
        return
    summary = item_medians.groupby(["attribute", "bias_score"])["item_median"].agg(
        ["mean", "std", "count"]
    )
    summary = summary.reset_index().sort_values(["attribute", "bias_score"])

    palette = _palette_for_levels(summary["attribute"].nunique())
    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary, x="bias_score", y="mean", hue="attribute", marker="o", palette=palette
    )

    # Colour per attribute, assigned in order of first appearance.
    attr_colors = {}
    for position, attr_name in enumerate(summary["attribute"].unique()):
        attr_colors[attr_name] = palette[position % len(palette)]

    for attr_name, rows in summary.groupby("attribute"):
        sem = rows["std"] / rows["count"] ** 0.5
        plt.errorbar(
            rows["bias_score"],
            rows["mean"],
            yerr=sem,
            fmt="none",
            ecolor=attr_colors.get(attr_name, "gray"),
            capsize=4,
            alpha=0.7,
        )

    if SHOW_COUNT_LABELS:
        drop = _text_offset(summary["mean"])
        columns = (
            summary["attribute"],
            summary["bias_score"],
            summary["mean"],
            summary["count"],
        )
        for attr_name, score, avg, n_items in zip(*columns):
            label_x, label_y = _text_xy(score, avg, 0.04, -drop)
            plt.text(
                label_x,
                label_y,
                f"n={int(n_items)}",
                ha="left",
                va="top",
                fontsize=8,
                color=attr_colors.get(attr_name, "black"),
            )
    plt.xlabel("Judge bias score")
    plt.ylabel("Mean of item median human bias score")
    plt.title("Mean item median bias by judge bias, by attribute")
    _save_fig(output_dir, "difference_vs_bias_score_median_by_attribute.pdf", pad_top=True)


def plot_difference_vs_bias_max(df: pd.DataFrame, output_dir: Path) -> None:
    """Line plot of the mean per-item maximum human bias score at each judge bias score.

    Does nothing when the per-item maxima cannot be computed or are empty.
    Error bars are the standard error across items.
    """
    item_maxima = _per_item_stat(df, "max")
    if item_maxima is None or item_maxima.empty:
        return
    summary = item_maxima.groupby("bias_score")["item_max"].agg(["mean", "count", "std"])
    summary = summary.reset_index().sort_values("bias_score")

    plt.figure(figsize=FIGSIZE)
    sns.lineplot(data=summary, x="bias_score", y="mean", marker="s", color="C2", legend=False)
    sem = summary["std"] / summary["count"] ** 0.5
    plt.errorbar(
        summary["bias_score"],
        summary["mean"],
        yerr=sem,
        fmt="none",
        ecolor="C2",
        capsize=4,
        alpha=0.7,
    )
    if SHOW_COUNT_LABELS:
        drop = _text_offset(summary["mean"])
        for score, avg, n_items in zip(summary["bias_score"], summary["mean"], summary["count"]):
            label_x, label_y = _text_xy(score, avg, 0.04, -drop)
            plt.text(
                label_x,
                label_y,
                f"n={int(n_items)}",
                ha="left",
                va="top",
                fontsize=8,
                color="C2",
            )
    plt.xlabel("Judge bias score")
    plt.ylabel("Mean of max human bias score")
    plt.title("Mean item max bias by judge bias")
    _save_fig(output_dir, "difference_vs_bias_score_max.pdf", pad_top=True)


def plot_difference_vs_bias_max_by_attr(df: pd.DataFrame, output_dir: Path) -> None:
    """Line plot of mean per-item max human bias score per judge bias score, by attribute.

    Each attribute gets a seaborn line plus a dashed overlay and standard-error
    bars in the same colour. Does nothing when the per-item maxima cannot be
    computed or are empty.
    """
    item_maxima = _per_item_stat(df, "max")
    if item_maxima is None or item_maxima.empty:
        return
    summary = item_maxima.groupby(["attribute", "bias_score"])["item_max"].agg(
        ["mean", "count", "std"]
    )
    summary = summary.reset_index().sort_values(["attribute", "bias_score"])

    palette = _palette_for_levels(summary["attribute"].nunique())
    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary, x="bias_score", y="mean", hue="attribute", marker="s", palette=palette
    )

    # Colour per attribute, assigned in order of first appearance.
    attr_colors = {}
    for position, attr_name in enumerate(summary["attribute"].unique()):
        attr_colors[attr_name] = palette[position % len(palette)]

    for attr_name, rows in summary.groupby("attribute"):
        line_color = attr_colors.get(attr_name, "gray")
        # Dashed overlay on top of the seaborn line.
        plt.plot(
            rows["bias_score"],
            rows["mean"],
            linestyle="--",
            marker="s",
            color=line_color,
            alpha=0.8,
        )
        sem = rows["std"] / rows["count"] ** 0.5
        plt.errorbar(
            rows["bias_score"],
            rows["mean"],
            yerr=sem,
            fmt="none",
            ecolor=line_color,
            capsize=4,
            alpha=0.7,
        )

    if SHOW_COUNT_LABELS:
        drop = _text_offset(summary["mean"])
        columns = (
            summary["attribute"],
            summary["bias_score"],
            summary["mean"],
            summary["count"],
        )
        for attr_name, score, avg, n_items in zip(*columns):
            label_x, label_y = _text_xy(score, avg, 0.04, -drop)
            plt.text(
                label_x,
                label_y,
                f"n={int(n_items)}",
                ha="left",
                va="top",
                fontsize=8,
                color=attr_colors.get(attr_name, "black"),
            )
    plt.xlabel("Judge bias score")
    plt.ylabel("Mean of item max human bias score")
    plt.title("Mean item max bias by judge bias, by attribute")
    _save_fig(output_dir, "difference_vs_bias_score_max_by_attribute.pdf", pad_top=True)


def plot_difference_delta_hist(
    df: pd.DataFrame, output_dir: Path, label: str, mirror: bool = False
) -> None:
    """Histogram of (human bias score - judge bias score), with summary stats printed.

    Rows missing either score are ignored; nothing happens if none remain.
    With ``mirror`` set, the histogram shows |delta| and -|delta| together so
    it is symmetric around zero. ``label`` tags the printed stats and the
    output filename.
    """
    paired = df.dropna(subset=[HUMAN_BIAS_COL, "bias_score"]).copy()
    if paired.empty:
        return
    paired["delta"] = paired[HUMAN_BIAS_COL] - paired["bias_score"]
    deltas = paired["delta"]
    stats = deltas.describe()
    print(
        f"Delta stats [{label}]: count={stats['count']:.0f}, mean={stats['mean']:.3f}, "
        f"std={stats['std']:.3f}, min={stats['min']:.3f}, median={stats['50%']:.3f}, "
        f"max={stats['max']:.3f}"
    )
    if mirror:
        magnitudes = deltas.abs()
        plot_series = pd.concat([magnitudes, -magnitudes])
    else:
        plot_series = deltas
    plt.figure(figsize=FIGSIZE)
    sns.histplot(plot_series, bins=20, kde=False, color="C0")
    plt.axvline(0, color="red", linestyle="--", linewidth=1)
    plt.xlabel(r"$\Delta$ (Participant - Judge)")
    plt.ylabel("Count")
    plt.title("")
    _save_fig(output_dir, f"difference_delta_hist_{label}.pdf")


def plot_participant_metric_hist(
    df: pd.DataFrame, metric: str, output_dir: Path, bins: int = 20
) -> None:
    """Histogram of the numeric values in column ``metric``.

    Silently returns when the column is absent or holds no numeric values.
    """
    if metric not in df.columns:
        return
    numeric_values = pd.to_numeric(df[metric], errors="coerce").dropna()
    if numeric_values.empty:
        return
    pretty_name = metric.title()
    plt.figure(figsize=FIGSIZE)
    sns.histplot(numeric_values, bins=bins, kde=False, color="C1")
    plt.xlabel(f"{pretty_name} score")
    plt.ylabel("Count")
    plt.title(f"{pretty_name} distribution (participants)")
    _save_fig(output_dir, f"{metric}_participant_hist.pdf")


def compute_at_least_one_scores(df: pd.DataFrame) -> None:
    """
    Print, for bias, relevance and acknowledgement, the percentage of items
    whose highest participant score is at least the judge score.

    The judge score of an item is taken from its first row. A dimension is
    reported as "n/a" when its columns are missing or no usable rows remain.
    """
    comparisons = (
        ("Bias", HUMAN_BIAS_COL, "bias_score"),
        ("Relevance", "relevance", "relevance_score"),
        ("Acknowledgement", "acknowledgement", "acknowledgement_score"),
    )
    print(
        "\n[At-least-one >= judge] Percent of items with any participant meeting/exceeding judge:"
    )
    for label, metric, judge_col in comparisons:
        if not {metric, judge_col}.issubset(df.columns):
            print(f"  {label}: n/a (missing columns)")
            continue
        usable = df.dropna(subset=["item_id", metric, judge_col]).copy()
        if usable.empty:
            print(f"  {label}: n/a (no data)")
            continue
        per_item = usable.groupby("item_id").agg(
            judge=(judge_col, "first"), participant_max=(metric, "max")
        )
        eligible = per_item.dropna(subset=["judge", "participant_max"])
        if eligible.empty:
            print(f"  {label}: n/a (no eligible items)")
            continue
        reached = eligible["participant_max"] >= eligible["judge"]
        meets = reached.mean() * 100.0
        print(f"  {label}: {meets:.1f}% (items: {len(eligible)})")


def plot_abs_deviation_from_mean(df: pd.DataFrame, output_dir: Path) -> None:
    """Histogram of |score - item mean| for every human bias rating.

    Rows without a score or item id are ignored; nothing is plotted if none remain.
    """
    rated = df.dropna(subset=[HUMAN_BIAS_COL, "item_id"]).copy()
    if rated.empty:
        return
    mean_by_item = rated.groupby("item_id")[HUMAN_BIAS_COL].mean()
    # Look up each row's item mean, then take the absolute gap.
    rated["item_mean"] = rated["item_id"].map(mean_by_item)
    rated["abs_dev"] = (rated[HUMAN_BIAS_COL] - rated["item_mean"]).abs()
    plt.figure(figsize=FIGSIZE)
    sns.histplot(rated["abs_dev"], bins=20, kde=False, color="C2")
    plt.xlabel("Absolute deviation from item mean human bias score")
    plt.ylabel("Count")
    plt.title("Absolute deviation from item mean (participant bias)")
    _save_fig(output_dir, "human_bias_abs_dev_hist.pdf")


def plot_deviation_from_mean(df: pd.DataFrame, output_dir: Path) -> None:
    """Histogram of (score - item mean) for every human bias rating, with a zero marker.

    Rows without a score or item id are ignored; nothing is plotted if none remain.
    """
    rated = df.dropna(subset=[HUMAN_BIAS_COL, "item_id"]).copy()
    if rated.empty:
        return
    mean_by_item = rated.groupby("item_id")[HUMAN_BIAS_COL].mean()
    # Look up each row's item mean, then take the signed gap.
    rated["item_mean"] = rated["item_id"].map(mean_by_item)
    rated["dev"] = rated[HUMAN_BIAS_COL] - rated["item_mean"]
    plt.figure(figsize=FIGSIZE)
    sns.histplot(rated["dev"], bins=20, kde=False, color="C3")
    plt.axvline(0, color="red", linestyle="--", linewidth=1)
    plt.xlabel("Deviation from item mean human bias score")
    plt.ylabel("Count")
    plt.title("Deviation from item mean (participant bias)")
    _save_fig(output_dir, "human_bias_dev_hist.pdf")


def plot_pairwise_differences(df: pd.DataFrame, output_dir: Path) -> None:
    """Histogram of |a - b| over every pair of human bias ratings on the same item.

    Also prints a sanity summary: number of items, number of pairs versus the
    expected n-choose-2 total, and the distribution of ratings per item.
    Nothing is plotted when no item has at least two ratings.
    """
    rated = df.dropna(subset=[HUMAN_BIAS_COL, "item_id"]).copy()
    if rated.empty:
        return
    deltas = []
    per_item_counts = []
    for _, item_rows in rated.groupby("item_id"):
        scores = item_rows[HUMAN_BIAS_COL].dropna().tolist()
        per_item_counts.append(len(scores))
        if len(scores) < 2:
            continue
        # Every unordered pair exactly once.
        for position, first in enumerate(scores):
            for second in scores[position + 1 :]:
                deltas.append(abs(first - second))
    if not deltas:
        return
    question_count = rated["item_id"].nunique()
    pair_count = len(deltas)
    if per_item_counts:
        n_series = pd.Series(per_item_counts)
        expected_pairs = int((n_series * (n_series - 1) / 2).sum())
        lt2 = (n_series < 2).sum()
        print(
            f"Pairwise diffs: questions={question_count}, pairs={pair_count} "
            f"(expected={expected_pairs}); questions with <2 answers={lt2}; "
            f"participants per item (min/median/mean/max)="
            f"{n_series.min()}/{n_series.median():.1f}/{n_series.mean():.1f}/{n_series.max()}"
        )
    plt.figure(figsize=FIGSIZE)
    sns.histplot(deltas, bins=20, kde=False, color="C4")
    plt.xlabel("Absolute difference between participants (per item)")
    plt.ylabel("Count")
    plt.title("Pairwise participant differences")
    _save_fig(output_dir, "human_bias_pairwise_diff_hist.pdf")


def plot_pairwise_differences_mirrored(df: pd.DataFrame, output_dir: Path) -> None:
    """Symmetric histogram of per-item pairwise rating differences.

    Each absolute pairwise difference is plotted twice, once positive and once
    negative. Nothing is plotted when no item has at least two ratings.
    """
    rated = df.dropna(subset=[HUMAN_BIAS_COL, "item_id"]).copy()
    if rated.empty:
        return
    gaps = []
    for _, item_rows in rated.groupby("item_id"):
        scores = item_rows[HUMAN_BIAS_COL].dropna().tolist()
        if len(scores) < 2:
            continue
        # Every unordered pair exactly once.
        for position, first in enumerate(scores):
            for second in scores[position + 1 :]:
                gaps.append(abs(first - second))
    if not gaps:
        return
    positive = pd.Series(gaps)
    negative = -pd.Series(gaps)
    mirrored = pd.concat([positive, negative])
    plt.figure(figsize=FIGSIZE)
    sns.histplot(mirrored, bins=20, kde=False, color="C6")
    plt.axvline(0, color="red", linestyle="--", linewidth=1)
    plt.xlabel("Difference")
    plt.ylabel("Count")
    plt.title("Paired participant differences")
    _save_fig(output_dir, "human_bias_pairwise_diff_hist_mirrored.pdf")


def plot_participant_mean_std(df: pd.DataFrame, metric: str, output_dir: Path) -> None:
    """Scatter each participant's mean of ``metric`` against their std.

    Marker size scales with the participant's number of ratings. Rows missing
    the metric or participant id are ignored; nothing is plotted if none remain.
    """
    rated = df.dropna(subset=[metric, "participant_id"]).copy()
    if rated.empty:
        return
    by_participant = rated.groupby("participant_id")[metric]
    stats = by_participant.agg(mean="mean", std="std", count="count")
    stats = stats.reset_index().sort_values("mean")
    pretty_name = metric.title()
    plt.figure(figsize=FIGSIZE)
    sns.scatterplot(data=stats, x="mean", y="std", size="count", sizes=(20, 200), legend=False)
    plt.xlabel(f"{pretty_name} mean per participant")
    plt.ylabel(f"{pretty_name} std per participant")
    plt.title(f"Participant mean vs std ({metric})")
    _save_fig(output_dir, f"participant_mean_std_{metric}.pdf")


def plot_metric_vs_judge_box(
    df: pd.DataFrame, metric: str, judge_col: str, output_dir: Path
) -> None:
    """Box plot of participant ``metric`` values at each level of ``judge_col``.

    Rows missing either column are ignored; nothing is plotted if none remain.
    """
    usable = df.dropna(subset=[metric, judge_col])
    if usable.empty:
        return
    level_palette = _palette_for_levels(usable[judge_col].nunique())
    pretty_name = metric.title()
    plt.figure(figsize=FIGSIZE)
    sns.boxplot(
        data=usable,
        x=judge_col,
        y=metric,
        hue=judge_col,
        dodge=False,
        legend=False,
        palette=level_palette,
    )
    plt.xlabel(f"Judge {judge_col}")
    plt.ylabel(f"{pretty_name} rating")
    plt.title(f"Participant {pretty_name} vs judge {judge_col}")
    _save_fig(output_dir, f"{metric}_vs_{judge_col}_box.pdf")


def plot_metric_vs_judge_line(
    df: pd.DataFrame,
    metric: str,
    judge_col: str,
    output_dir: Path,
    agg_func: str = "mean",
    suffix: str = "",
) -> None:
    """Line plot of aggregated participant ``metric`` per judge score, one line per attribute.

    The mean is plotted when ``agg_func`` is "mean"; any other value plots the
    median. Error bars are always std / sqrt(count) of the group. ``suffix`` is
    appended to the output filename. Nothing is plotted when no row has both
    columns.
    """
    usable = df.dropna(subset=[metric, judge_col])
    if usable.empty:
        return
    summary = usable.groupby(["attribute", judge_col], dropna=False)[metric].agg(
        ["mean", "median", "count", "std"]
    )
    summary = summary.reset_index().sort_values(judge_col)

    if agg_func == "mean":
        y_col = "mean"
    else:
        y_col = "median"

    palette = _palette_for_levels(summary["attribute"].nunique())
    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary, x=judge_col, y=y_col, hue="attribute", marker="o", palette=palette
    )

    # Colour per attribute, assigned in order of first appearance.
    attr_colors = {}
    for position, attr_name in enumerate(summary["attribute"].unique()):
        attr_colors[attr_name] = palette[position % len(palette)]

    for attr_name, rows in summary.groupby("attribute"):
        sem = rows["std"] / rows["count"] ** 0.5
        plt.errorbar(
            rows[judge_col],
            rows[y_col],
            yerr=sem,
            fmt="none",
            ecolor=attr_colors.get(attr_name, "gray"),
            capsize=4,
            alpha=0.7,
        )
    plt.xlabel(f"Judge {judge_col}")
    plt.ylabel(f"{agg_func.title()} {metric} rating")
    plt.title(f"Participant {metric.title()} vs judge {judge_col} ({agg_func})")
    _save_fig(output_dir, f"{metric}_vs_{judge_col}_{agg_func}{suffix}.pdf")


def plot_metric_vs_judge_line_max(
    df: pd.DataFrame,
    metric: str,
    judge_col: str,
    output_dir: Path,
) -> None:
    """Line plot of the per-group maximum of ``metric`` against ``judge_col``.

    Groups are (attribute, judge value) pairs. Each attribute gets the seaborn
    line, error bars of ``std / sqrt(count)`` and an additional dashed overlay
    through the same maxima. Nothing is drawn when no row has both columns.
    """
    complete = df.dropna(subset=[metric, judge_col])
    if complete.empty:
        return

    grouped = complete.groupby(["attribute", judge_col], dropna=False)[metric]
    summary = grouped.agg(["max", "std", "count"]).reset_index()
    summary = summary.sort_values(judge_col)

    colors = _palette_for_levels(summary["attribute"].nunique())
    plt.figure(figsize=FIGSIZE)
    sns.lineplot(
        data=summary, x=judge_col, y="max", hue="attribute", marker="s", palette=colors
    )

    # Palette position follows order of first appearance of each attribute.
    color_by_attr = {}
    for position, name in enumerate(summary["attribute"].unique()):
        color_by_attr[name] = colors[position % len(colors)]

    for name, rows in summary.groupby("attribute"):
        line_color = color_by_attr.get(name, "gray")
        judge_values = rows[judge_col]
        peaks = rows["max"]
        standard_error = rows["std"] / (rows["count"] ** 0.5)
        plt.errorbar(
            judge_values,
            peaks,
            yerr=standard_error,
            fmt="none",
            ecolor=line_color,
            capsize=4,
            alpha=0.7,
        )
        plt.plot(
            judge_values,
            peaks,
            linestyle="--",
            marker="s",
            color=line_color,
            alpha=0.8,
        )

    plt.xlabel(f"Judge {judge_col}")
    plt.ylabel(f"Max {metric} rating")
    plt.title(f"Max participant {metric.title()} vs judge {judge_col}")
    _save_fig(output_dir, f"{metric}_vs_{judge_col}_max.pdf")


def plot_metric_by_attribute(
    df: pd.DataFrame, metric: str, output_dir: Path, agg_func: str = "mean"
) -> None:
    """Bar chart of ``metric`` aggregated per attribute with ``agg_func``.

    Rows without a ``metric`` value are ignored; nothing is drawn when none
    remain. Bars are ordered by attribute and saved through ``_save_fig`` as
    ``{metric}_{agg_func}_by_attribute.pdf``.
    """
    usable = df.dropna(subset=[metric])
    if usable.empty:
        return

    per_attribute = usable.groupby("attribute", dropna=False)[metric].agg(agg_func)
    summary = per_attribute.reset_index().sort_values("attribute")
    bar_colors = _palette_for_levels(summary["attribute"].nunique())

    plt.figure(figsize=FIGSIZE)
    bar_options = {
        "x": "attribute",
        "y": metric,
        "hue": "attribute",
        "dodge": False,
        "legend": False,
        "palette": bar_colors,
    }
    sns.barplot(data=summary, **bar_options)
    plt.xlabel("Attribute")
    plt.ylabel(f"{agg_func.title()} {metric} rating")
    plt.title(f"{metric.title()} by attribute ({agg_func})")
    plt.xticks(rotation=45, ha="right")
    _save_fig(output_dir, f"{metric}_{agg_func}_by_attribute.pdf")


def plot_responses_per_item(df: pd.DataFrame, output_dir: Path, metric: str) -> None:
    """Print and plot how many distinct participants responded to each item.

    Only rows with non-null ``metric``, ``item_id`` and ``participant_id`` are
    counted. A one-line summary is printed and a histogram with one bin per
    integer count is saved as ``responses_per_item_hist.pdf``.
    """
    rows = df.dropna(subset=[metric, "item_id", "participant_id"]).copy()
    if rows.empty:
        return

    raters_per_item = rows.groupby("item_id")["participant_id"].nunique()
    counts = raters_per_item.rename("response_count").reset_index()
    counts = counts.sort_values("response_count")

    summary = counts["response_count"].describe()
    n_items = summary["count"]
    lowest, middle, average, highest = (
        summary["min"],
        summary["50%"],
        summary["mean"],
        summary["max"],
    )
    print(
        f"Responses per item: count={n_items:.0f}, "
        f"min/median/mean/max={lowest:.0f}/{middle:.0f}/{average:.1f}/{highest:.0f}"
    )

    # Integer-aligned bins: [1, 2), [2, 3), ..., up to the observed maximum.
    bin_edges = range(1, int(highest) + 2)
    plt.figure(figsize=FIGSIZE)
    sns.histplot(counts["response_count"], bins=bin_edges, color="C5")
    plt.xlabel("Number of responses per item")
    plt.ylabel("Count of items")
    plt.title("Responses per item")
    _save_fig(output_dir, "responses_per_item_hist.pdf")


def write_attribute_bias_table(
    df: pd.DataFrame,
    output_path: Path,
    bias_col: str = "bias_score",
    attr_col: str = "attribute",
    bias_values: Optional[List[int]] = None,
) -> None:
    """
    Write counts per attribute and bias value to a small LaTeX table.

    Output table has columns: Attribute | Total | Bias 1..Bias N

    Parameters
    ----------
    df : pd.DataFrame
        Rows to count; rows missing ``bias_col`` or ``attr_col`` are ignored.
    output_path : Path
        Destination ``.tex`` file; parent directories are created if needed.
    bias_col : str
        Column holding the bias score. Values are coerced to numeric and only
        those listed in ``bias_values`` are counted.
    attr_col : str
        Column holding the attribute name (one table row per distinct value).
    bias_values : Optional[List[int]]
        Bias values to report as columns, in order. Defaults to 1..5.

    Notes
    -----
    Attribute names are escaped for LaTeX (``_``, ``&``, ``%``, ...), so names
    such as ``sexual_orientation`` yield a table that compiles. If no usable
    rows remain, a message is printed and no file is written.
    """
    # list(...) also accepts tuples/ranges, which the "+" concatenation below would reject.
    bias_values = [1, 2, 3, 4, 5] if bias_values is None else list(bias_values)

    d = df.dropna(subset=[bias_col, attr_col]).copy()
    if d.empty:
        print("[Table] No rows available for attribute/bias table.")
        return
    d[bias_col] = pd.to_numeric(d[bias_col], errors="coerce")
    d = d[d[bias_col].isin(bias_values)]
    if d.empty:
        print("[Table] No rows with requested bias values for attribute/bias table.")
        return
    d[attr_col] = d[attr_col].astype(str)
    d[bias_col] = d[bias_col].astype(int)

    table = pd.crosstab(d[attr_col], d[bias_col]).reindex(columns=bias_values, fill_value=0)
    table["Total"] = table.sum(axis=1)
    table = table[["Total"] + bias_values]

    # Single-pass translation so the backslash introduced by one escape is never re-escaped.
    latex_escapes = str.maketrans(
        {
            "\\": "\\textbackslash{}",
            "&": "\\&",
            "%": "\\%",
            "$": "\\$",
            "#": "\\#",
            "_": "\\_",
            "{": "\\{",
            "}": "\\}",
            "~": "\\textasciitilde{}",
            "^": "\\textasciicircum{}",
        }
    )

    # Build LaTeX table string
    header_cols = ["Attribute", "Total"] + [f"Bias {v}" for v in bias_values]
    lines = [
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{Counts by attribute and bias score}",
        "\\vspace{-2mm}",
        "\\label{tab:human_study_stats}",
        "\\small",
        "\\begin{tabular}{" + "l" + "r" * (len(bias_values) + 1) + "}",
        "\\toprule",
        " & ".join(f"\\textbf{{{c}}}" for c in header_cols) + " \\\\",
        "\\midrule",
    ]
    for attr, row in table.sort_index().iterrows():
        counts = " & ".join(str(int(row[col])) for col in table.columns)
        lines.append(f"{str(attr).translate(latex_escapes)} & {counts} \\\\")
    lines.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}"])

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(lines), encoding="utf-8")
    print(f"[Table] Wrote attribute-bias table to {output_path}")


def assign_bias_bins(df: pd.DataFrame, bins: Optional[List[float]] = None) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``bias_bin`` column derived from ``bias_score``.

    With ``bins`` given, scores are cut into the intervals between consecutive
    edges (lowest edge included) and labelled ``"<lo>-<hi>"``. Without bins,
    every distinct observed score becomes its own ordered category.
    """
    binned = df.copy()
    scores = binned["bias_score"]

    if not bins:
        levels = sorted(scores.dropna().unique().tolist())
        binned["bias_bin"] = pd.Categorical(scores, categories=levels, ordered=True)
        return binned

    edge_pairs = zip(bins[:-1], bins[1:])
    interval_names = [f"{low}-{high}" for low, high in edge_pairs]
    binned["bias_bin"] = pd.cut(scores, bins=bins, labels=interval_names, include_lowest=True)
    return binned


def filter_participants_by_std(
    resp_df: pd.DataFrame, merged_df: pd.DataFrame, metric: str
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Drop participants whose per-participant std of ``metric`` is extreme.

    The share removed from each end is given by ``STD_FILTER_TOP_PCT`` and
    ``STD_FILTER_BOTTOM_PCT`` (percent of participants, rounded down). A
    participant with an undefined std is treated as std 0. Both frames are
    returned unchanged when filtering is disabled, ``metric`` is absent from
    ``resp_df``, or nobody qualifies for removal.
    """
    filtering_disabled = STD_FILTER_TOP_PCT <= 0 and STD_FILTER_BOTTOM_PCT <= 0
    if filtering_disabled or metric not in resp_df.columns:
        return resp_df, merged_df

    valid = resp_df.dropna(subset=[metric])
    spread = valid.groupby("participant_id")[metric].agg(std="std", count="count").reset_index()
    if spread.empty:
        return resp_df, merged_df
    spread["std"] = spread["std"].fillna(0.0)

    n_participants = len(spread)
    n_high = int(math.floor(n_participants * STD_FILTER_TOP_PCT / 100.0))
    n_low = int(math.floor(n_participants * STD_FILTER_BOTTOM_PCT / 100.0))

    # Lowest-std participants are collected first, then highest-std ones.
    excluded: set = set()
    for how_many, pick in ((n_low, spread.nsmallest), (n_high, spread.nlargest)):
        if how_many > 0:
            excluded.update(pick(how_many, "std")["participant_id"].astype(str))

    if not excluded:
        return resp_df, merged_df

    n_removed = len(excluded)
    print(
        f"Std filter: removed {n_removed} participants "
        f"(top {n_high}, bottom {n_low}); kept {n_participants - n_removed}"
    )
    # IDs are compared as strings so both frames are matched the same way.
    keep_resp = ~resp_df["participant_id"].astype(str).isin(excluded)
    resp_df = resp_df[keep_resp]
    keep_merged = ~merged_df["participant_id"].astype(str).isin(excluded)
    merged_df = merged_df[keep_merged]
    return resp_df, merged_df


import numpy as np
import pandas as pd
from typing import Dict, Optional, Tuple


def krippendorff_alpha_ordinal(
    df: pd.DataFrame,
    unit_col: str = "item_id",
    rater_col: str = "participant_id",
    value_col: str = "human_bias_score",
    categories: Optional[list] = None,
) -> Dict[str, float]:
    """
    Compute Krippendorff's alpha for ORDINAL ratings.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format ratings with columns [unit_col, rater_col, value_col].
    unit_col : str
        Column identifying the unit being rated (e.g., item_id).
    rater_col : str
        Column identifying the rater (e.g., participant_id). Used only for deduping.
    value_col : str
        Column with ordinal ratings (e.g., 1..5). Must be numeric or convertible.
    categories : Optional[list]
        Explicit ordered category list (e.g., [1,2,3,4,5]).
        If None, inferred from observed values in df (sorted).

    Returns
    -------
    Dict[str, float]
        alpha: Krippendorff's alpha (ordinal); NaN when undefined
        Do: observed disagreement
        De: expected disagreement
        n_units: number of units with >=2 ratings
        n_ratings: number of ratings used
        n_categories: number of ordinal categories

    Notes
    -----
    With category proportions ``p`` over all pairable ratings, the ordinal
    distance is ``delta(c, k) = (sum_{g=c..k} p_g - (p_c + p_k) / 2) ** 2``.
    Expected disagreement draws pairs without replacement, i.e.
    ``De = n / (n - 1) * sum_{c,k} p_c p_k delta(c, k)``.
    """
    d = df[[unit_col, rater_col, value_col]].copy()
    d[value_col] = pd.to_numeric(d[value_col], errors="coerce")
    d = d.dropna(subset=[unit_col, rater_col, value_col])

    # One rating per (unit, rater): repeated ratings are averaged. An average that is
    # not itself one of the categories is discarded by the category filter below.
    d = d.groupby([unit_col, rater_col], as_index=False)[value_col].mean()

    # Determine ordered categories
    if categories is None:
        categories = sorted(d[value_col].dropna().unique().tolist())
    else:
        categories = list(categories)
    m = len(categories)
    if m < 2:
        return dict(
            alpha=np.nan,
            Do=np.nan,
            De=np.nan,
            n_units=0,
            n_ratings=len(d),
            n_categories=m,
        )

    cat_to_idx = {c: i for i, c in enumerate(categories)}

    # Keep only values in categories (in case df has unexpected values)
    d = d[d[value_col].isin(categories)]

    # Unit x category count matrix. A plain dict lookup maps 1.0 onto an int
    # category 1, and avoids the deprecated DataFrame.applymap.
    cat_idx = np.array([cat_to_idx[v] for v in d[value_col].tolist()], dtype=int)
    unit_codes, unit_labels = pd.factorize(d[unit_col])
    counts = np.zeros((len(unit_labels), m), dtype=float)
    np.add.at(counts, (unit_codes, cat_idx), 1.0)

    # Only units with >=2 ratings contain pairable values.
    ratings_per_unit = counts.sum(axis=1)
    pairable = ratings_per_unit >= 2
    counts = counts[pairable]
    ratings_per_unit = ratings_per_unit[pairable]
    if counts.shape[0] == 0:
        return dict(alpha=np.nan, Do=np.nan, De=np.nan, n_units=0, n_ratings=0, n_categories=m)

    # Category totals/proportions over all pairable ratings.
    n_cat = counts.sum(axis=0)
    n_total = float(n_cat.sum())
    p_cat = n_cat / n_total

    # Ordinal distance: inclusive mass from the lower to the upper category,
    # minus half of each endpoint's mass, squared.
    cum_p = np.cumsum(p_cat)
    mass_below = cum_p - p_cat
    span = cum_p[np.newaxis, :] - mass_below[:, np.newaxis]
    span = span - (p_cat[:, np.newaxis] + p_cat[np.newaxis, :]) / 2.0
    upper = np.triu(span, k=1)  # valid for lower < upper only; the diagonal stays 0
    delta = (upper + upper.T) ** 2

    # Observed disagreement: per unit, sum over category pairs of
    # o_uc * o_uk * delta(c, k), weighted by 1 / (n_u - 1), averaged over all ratings.
    within_unit = ((counts @ delta) * counts).sum(axis=1)
    Do = float((within_unit / (ratings_per_unit - 1.0)).sum() / n_total)

    # Expected disagreement with the n / (n - 1) without-replacement correction.
    # n_total >= 2 here because at least one unit has two ratings.
    De = float(p_cat @ delta @ p_cat) * n_total / (n_total - 1.0)

    alpha = 1.0 - (Do / De) if (De > 0 and np.isfinite(Do)) else np.nan

    return dict(
        alpha=float(alpha),
        Do=float(Do),
        De=float(De),
        n_units=int(counts.shape[0]),
        n_ratings=int(n_total),
        n_categories=int(m),
    )


def compute_iaa(merged: pd.DataFrame, metric: str = HUMAN_BIAS_COL) -> None:
    """
    Compute inter-annotator agreement (IAA) for a continuous metric in `merged`.

    Nothing is returned; all results are printed with an ``[IAA]`` prefix.

    Prints:
      - items with >=2 raters (eligible for IAA)
      - rater-count distribution
      - within-item standard deviation summary
      - pairwise absolute differences (overall + per-item summary)
      - variance components and ICC estimates (absolute agreement and
        consistency, single rater and mean of k raters) derived from a
        mixed-effects model fitted with ``smf.mixedlm``

    Parameters
    ----------
    merged : pd.DataFrame
        Long-format ratings; the code reads the columns ``item_id``,
        ``participant_id`` and ``metric``.
    metric : str
        Name of the rating column to analyse.
    """
    # Only rows with a rating and both identifiers can contribute.
    d = merged.dropna(subset=[metric, "item_id", "participant_id"]).copy()
    if d.empty:
        print("[IAA] No valid rows for IAA.")
        return

    # --- How many raters per item? ---
    rater_counts = d.groupby("item_id")["participant_id"].nunique()
    n_items_total = rater_counts.shape[0]
    n_items_1 = int((rater_counts == 1).sum())
    n_items_ge2 = int((rater_counts >= 2).sum())

    print("\n[IAA] Rater counts per item")
    print(f"[IAA] Items with >=1 rating: {n_items_total}")
    print(f"[IAA] Items with exactly 1 rater: {n_items_1}")
    print(f"[IAA] Items with >=2 raters (eligible): {n_items_ge2}")
    if n_items_total > 0:
        print(
            f"[IAA] Raters per item (min/median/mean/max): "
            f"{rater_counts.min()}/{rater_counts.median():.1f}/{rater_counts.mean():.2f}/{rater_counts.max()}"
        )

    # Filter to items with at least 2 raters for agreement metrics
    eligible_items = rater_counts[rater_counts >= 2].index
    d2 = d[d["item_id"].isin(eligible_items)].copy()
    if d2.empty:
        print("[IAA] No items with >=2 raters; cannot compute agreement.")
        return

    # --- Within-item SD (descriptive reliability) ---
    item_sd = d2.groupby("item_id")[metric].std().dropna()
    print("\n[IAA] Within-item SD (items with >=2 raters)")
    print(item_sd.describe().to_string())

    # --- Pairwise absolute differences (very interpretable) ---
    pairwise_diffs = []  # every |score_i - score_j| across all items
    per_item_pairwise_mean = []  # (item_id, mean |diff| within item, number of scores)

    for item_id, grp in d2.groupby("item_id"):
        scores = grp[metric].dropna().to_numpy()
        n = scores.size
        if n < 2:
            continue
        # all pairwise abs diffs (each unordered pair counted once)
        diffs = []
        for i in range(n):
            for j in range(i + 1, n):
                diffs.append(abs(scores[i] - scores[j]))
        if diffs:
            pairwise_diffs.extend(diffs)
            per_item_pairwise_mean.append((item_id, float(sum(diffs) / len(diffs)), n))

    if pairwise_diffs:
        s = pd.Series(pairwise_diffs)
        print("\n[IAA] Pairwise |diff| across all eligible item-rater pairs")
        print(s.describe().to_string())

        per_item_df = pd.DataFrame(
            per_item_pairwise_mean, columns=["item_id", "mean_pairwise_absdiff", "n_raters"]
        )
        print("\n[IAA] Per-item mean pairwise |diff| summary")
        print(per_item_df["mean_pairwise_absdiff"].describe().to_string())

    print("\n[IAA] ICC via mixed-effects model (unbalanced OK).")

    # Ensure one rating per (item_id, participant_id); average duplicates if any
    d_icc = (
        d2.groupby(["item_id", "participant_id"], as_index=False)[metric]
        .mean()
        .dropna(subset=[metric])
    )

    print(f"[IAA] MixedLM input rows (item×rater): {len(d_icc)}")
    print(
        f"[IAA] MixedLM items: {d_icc['item_id'].nunique()}, raters: {d_icc['participant_id'].nunique()}"
    )

    # Pick an effective k for "k raters" reliability. With variable raters/item,
    # harmonic mean is a conservative choice.
    counts = d_icc.groupby("item_id")["participant_id"].count()
    k_eff = float(len(counts) / (1.0 / counts).sum())  # harmonic mean
    print(f"[IAA] Effective k (harmonic mean raters/item): {k_eff:.3f}")

    # Fit: rating ~ 1 + (1|item) + (1|rater)
    # statsmodels MixedLM supports one "groups" random intercept + additional via vc_formula
    # We'll use item as groups, rater as variance component.
    # NOTE(review): with `groups` set to item, the rater variance component is specified
    # inside each item group; confirm against the statsmodels MixedLM documentation that
    # this yields the intended crossed rater effect rather than a rater-within-item term.
    # Columns are renamed so the formula can refer to fixed names regardless of `metric`.
    d_icc = d_icc.rename(columns={"item_id": "item", "participant_id": "rater", metric: "y"})

    md = smf.mixedlm(
        "y ~ 1",
        data=d_icc,
        groups=d_icc["item"],  # random intercept for item
        vc_formula={"rater": "0 + C(rater)"},  # random intercepts for rater
    )
    # NOTE(review): no handling of a failed or non-converged fit is visible here.
    m = md.fit(reml=True, method="lbfgs", maxiter=200)

    # Variance components
    var_item = float(m.cov_re.iloc[0, 0]) if m.cov_re.size else 0.0
    # m.vcomp corresponds to vc_formula components, in order
    var_rater = float(m.vcomp[0]) if hasattr(m, "vcomp") and len(m.vcomp) > 0 else 0.0
    var_resid = float(m.scale)

    print("[IAA] Variance components")
    print(f"[IAA] var_item  = {var_item:.6f}")
    print(f"[IAA] var_rater = {var_rater:.6f}")
    print(f"[IAA] var_resid = {var_resid:.6f}")

    # ICCs
    # Absolute agreement counts rater variance as error; consistency leaves it out.
    # Each ratio falls back to NaN when its denominator is not positive.
    icc_abs_1 = (
        var_item / (var_item + var_rater + var_resid)
        if (var_item + var_rater + var_resid) > 0
        else float("nan")
    )
    icc_con_1 = var_item / (var_item + var_resid) if (var_item + var_resid) > 0 else float("nan")

    # "Mean of k" variants divide the error variance by the effective rater count.
    icc_abs_k = (
        var_item / (var_item + (var_rater + var_resid) / k_eff)
        if (var_item + (var_rater + var_resid) / k_eff) > 0
        else float("nan")
    )
    icc_con_k = (
        var_item / (var_item + var_resid / k_eff)
        if (var_item + var_resid / k_eff) > 0
        else float("nan")
    )

    print("\n[IAA] ICC estimates (mixed model)")
    print(f"[IAA] ICC_abs_1 (absolute agreement, single rater): {icc_abs_1:.3f}")
    print(f"[IAA] ICC_con_1 (consistency, single rater):       {icc_con_1:.3f}")
    print(
        f"[IAA] ICC_abs_k (absolute agreement, mean of k):    {icc_abs_k:.3f}  (k_eff={k_eff:.2f})"
    )
    print(
        f"[IAA] ICC_con_k (consistency, mean of k):           {icc_con_k:.3f}  (k_eff={k_eff:.2f})"
    )


def main() -> None:
    """Run the full human-study analysis.

    Loads and merges the data from ``RESPONSES_PATH`` / ``SELECTION_PATH``,
    prints dataset and agreement statistics, then writes the LaTeX count table,
    the plots and two CSV exports under ``OUTPUT_DIR``.
    """
    # Global plot styling is configured before any figure is created.
    setup_nyt_style()
    sns.set_style(STYLE)
    sns.set_palette(PALETTE)
    data = load_data(RESPONSES_PATH, SELECTION_PATH, TRIM_TOP_K, TRIM_BOTTOM_K)
    merged = data.merged  # participant-level rows (one per response)

    # Write out the baseline stats, number of items and participants, number of questions answered
    n_items = merged["item_id"].nunique()
    n_participants = merged["participant_id"].nunique()
    n_responses = len(merged)
    print(f"Data summary after trimming and filtering:")
    print(f"  Number of items: {n_items}")
    print(f"  Number of participants: {n_participants}")
    print(f"  Number of responses: {n_responses}")
    # Print sample of the merged data
    print("\nSample of merged data:")
    print(merged.head(1).to_string(index=False))

    # Assign bias bins for stratified plots
    merged = assign_bias_bins(merged)

    # Agreement between participants: ordinal alpha on the fixed 1..5 scale, then
    # the descriptive / ICC report printed by compute_iaa.
    alpha_out = krippendorff_alpha_ordinal(
        merged,
        unit_col="item_id",
        rater_col="participant_id",
        value_col=HUMAN_BIAS_COL,
        categories=[1, 2, 3, 4, 5],
    )
    print("[IAA] Krippendorff alpha (ordinal):", alpha_out["alpha"])
    print(alpha_out)

    compute_iaa(merged, metric=HUMAN_BIAS_COL)

    # Aggregate over participants per item
    item_avg = aggregate_by_item(merged)
    item_avg = assign_bias_bins(item_avg)

    # Item-level agreement between the aggregated human score and the judge score.
    d = item_avg.dropna(subset=["bias_score", HUMAN_BIAS_COL]).copy()
    print("Item-level Pearson r:", d[HUMAN_BIAS_COL].corr(d["bias_score"]))
    print("Item-level Spearman ρ:", d[HUMAN_BIAS_COL].corr(d["bias_score"], method="spearman"))

    # Correlation of the 5 plotted means (THIS will be high, but it's only 5 points)
    agg5 = d.groupby("bias_score")[HUMAN_BIAS_COL].mean()
    print(
        "Binned-mean Pearson r (5 points):",
        agg5.reset_index()[HUMAN_BIAS_COL].corr(agg5.reset_index()["bias_score"]),
    )

    # Everything below writes files, so the output directory must exist first.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Only metrics actually present in the item-level frame are plotted per item.
    available_metrics = [
        m for m in [HUMAN_BIAS_COL, "relevance", "acknowledgement"] if m in item_avg.columns
    ]
    # Attribute-by-bias-counts table (judge bias)
    write_attribute_bias_table(
        item_avg, OUTPUT_DIR / "attribute_bias_counts.tex", bias_col="bias_score"
    )
    # At-least-one >= judge percentages
    compute_at_least_one_scores(merged)
    # Participant-level score histograms
    # NOTE(review): this list is not filtered against merged.columns; presumably the
    # plotting helper copes with a missing metric — confirm.
    for metric in [HUMAN_BIAS_COL, "acknowledgement", "refusal"]:
        plot_participant_metric_hist(merged, metric, OUTPUT_DIR)

    # Plots across bias bins
    for metric in available_metrics:
        plot_bias_bin_distributions(item_avg, metric, OUTPUT_DIR)

    # Difference vs bias score plots
    plot_difference_vs_bias(item_avg, OUTPUT_DIR)
    if PLOT_AGGREGATED_BIAS:
        plot_difference_vs_bias_mean(item_avg, OUTPUT_DIR)
        plot_difference_vs_bias_mean_by_attr(item_avg, OUTPUT_DIR)
        plot_difference_vs_bias_median(merged, OUTPUT_DIR)
        plot_difference_vs_bias_median_by_attr(merged, OUTPUT_DIR)
        plot_difference_vs_bias_max(merged, OUTPUT_DIR)
        plot_difference_vs_bias_max_by_attr(merged, OUTPUT_DIR)
    # Histograms of participant-level and item-avg deltas
    plot_difference_delta_hist(
        merged, OUTPUT_DIR, label="participant", mirror=MIRROR_DELTA_HIST_PARTICIPANT
    )
    plot_difference_delta_hist(item_avg, OUTPUT_DIR, label="item_mean", mirror=False)
    plot_abs_deviation_from_mean(merged, OUTPUT_DIR)
    plot_deviation_from_mean(merged, OUTPUT_DIR)
    plot_pairwise_differences(merged, OUTPUT_DIR)
    plot_pairwise_differences_mirrored(merged, OUTPUT_DIR)
    plot_responses_per_item(merged, OUTPUT_DIR, HUMAN_BIAS_COL)
    # Participant consistency plots (mean vs std) to spot flat responders
    for metric in available_metrics:
        plot_participant_mean_std(merged, metric, OUTPUT_DIR)

    # (human metric, judge column) pairs compared in the plots and correlations below.
    metric_pairs = [
        (HUMAN_BIAS_COL, "bias_score"),
    ]
    for metric, judge_col in metric_pairs:
        if metric not in item_avg.columns or judge_col not in item_avg.columns:
            continue
        plot_metric_vs_judge(item_avg, metric, judge_col, OUTPUT_DIR)
        plot_metric_vs_judge_box(item_avg, metric, judge_col, OUTPUT_DIR)
        plot_metric_vs_judge_line(item_avg, metric, judge_col, OUTPUT_DIR, agg_func="mean")
        plot_metric_vs_judge_line(
            item_avg, metric, judge_col, OUTPUT_DIR, agg_func="median", suffix="_median"
        )
        plot_metric_vs_judge_line_max(item_avg, metric, judge_col, OUTPUT_DIR)
        plot_metric_by_attribute(item_avg, metric, OUTPUT_DIR, agg_func="mean")
        plot_metric_by_attribute(item_avg, metric, OUTPUT_DIR, agg_func="median")

    # Correlations between judge scores and participant means/medians
    for metric, judge_col in metric_pairs:
        if metric not in merged.columns or judge_col not in merged.columns:
            continue
        # The judge score is taken from the first row of each item.
        by_item = (
            merged.groupby("item_id")
            .agg(
                participant_mean=(metric, "mean"),
                participant_median=(metric, "median"),
                judge_score=(judge_col, "first"),
            )
            .dropna(subset=["participant_mean", "participant_median", "judge_score"])
        )
        if by_item.empty:
            continue
        mean_corr = by_item["participant_mean"].corr(by_item["judge_score"])
        median_corr = by_item["participant_median"].corr(by_item["judge_score"])
        print(
            f"Correlations for {metric} vs {judge_col}: "
            f"mean r={mean_corr:.3f}, median r={median_corr:.3f}, n={len(by_item)}"
        )

    # Save cleaned data for further analysis
    merged.to_csv(OUTPUT_DIR / "responses_merged.csv", index=False)
    item_avg.to_csv(OUTPUT_DIR / "responses_item_avg.csv", index=False)
    print(f"Wrote cleaned data and plots to {OUTPUT_DIR}")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
