import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from textwrap import fill
from collections import defaultdict
from scipy.stats import spearmanr, pearsonr

# Project paths. PROJECT_ROOT is resolved two directory levels above the
# directory that contains this file.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# CSV read by generate_scatter_plots().
INPUT_PATH = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "base_analysis", "filtered_depth.csv")
# NOTE(review): OUTPUT_ROOT is not referenced by any code visible in this file.
OUTPUT_ROOT = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "base_analysis")
# Output path for scatter plots; the directory is created at import time.
PLOT_OUTPUT_ROOT = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "plots")
os.makedirs(PLOT_OUTPUT_ROOT, exist_ok=True)

def get_filename_from_axes(x: str, y: str, domain_prefix: str = None) -> str:
    """Build the SVG file name for a scatter plot of column ``y`` against ``x``.

    Both column names are lower-cased and made file-system friendly, then
    combined as ``scatter_<y>_vs_<x>.svg``. A non-empty ``domain_prefix`` is
    prepended as ``<domain_prefix>_``.
    """
    # Applied strictly in this order: the final "__" -> "_" pass cleans up the
    # double underscores produced by the earlier substitutions.
    substitutions = (
        (" ", "_"),
        ("-", "_minus_"),
        ("|", "abs_"),
        ("__", "_"),
    )

    def slugify(column):
        slug = column.lower()
        for old, new in substitutions:
            slug = slug.replace(old, new)
        return slug

    x_slug, y_slug = slugify(x), slugify(y)
    stem = "scatter_" + y_slug + "_vs_" + x_slug + ".svg"
    if domain_prefix:
        return f"{domain_prefix}_{stem}"
    return stem

# Model identifier: matched against the CSV's "model_name" column and used as
# the "source" label (and palette key) for LLM rows.
MODEL_NAME = "gpt-4o-mini-2024-07-18"

# Plot manifest for stance values. generate_scatter_plots() renders one normal
# and one publication scatter per entry, per topic and aggregated. The y and x
# entries are column names of the frame produced by build_plot_df().
STANCE_SPECS = [
    # y, x, title, draw_diagonal, (y_label, x_label)
    ("Tweet3",           "Tweet1",                  "Tweet 3 vs Tweet 1 (Stance)",                 True,  ("Tweet 3", "Tweet 1")),
    ("Tweet3-Tweet1",    "Tweet1",                  "Stance Change vs Tweet 1",                     False, ("Tweet 3 − Tweet 1", "Tweet 1")),
    ("Tweet3-Tweet1",    "|avg_partner12 - Tweet1|","Stance Change vs |Avg(Partner1 T1, Partner2 T2) − Tweet 1|", False, ("Tweet 3 − Tweet 1", "|avg_partner12 − Tweet 1|")),
    ("Tweet3-Tweet1",    "avg_partner12 - Tweet1",  "Stance Change vs (Avg(Partner1 T1, Partner2 T2) − Tweet 1)", False, ("Tweet 3 − Tweet 1", "avg_partner12 − Tweet 1")),
    ("Tweet3-Tweet1",    "|partner - Tweet1|",      "Stance Change vs |Partner − Tweet 1|",         False, ("Tweet 3 − Tweet 1", "|partner − Tweet 1|")),
    ("Tweet3-Tweet1",    "partner - Tweet1",        "Stance Change vs (Partner − Tweet 1)",         False, ("Tweet 3 − Tweet 1", "partner − Tweet 1")),
    ("Tweet3-Tweet1",    "partner2_t2 - Tweet1",    "Stance Change vs (Partner 2 T2 − Tweet 1)",    False, ("Tweet 3 − Tweet 1", "partner2_t2 − Tweet 1")),
    ("Tweet3-Tweet1",    "partner3_t3 - Tweet1",    "Stance Change vs (Partner 3 T3 − Tweet 1)",    False, ("Tweet 3 − Tweet 1", "partner3_t3 − Tweet 1")),

    ("Post",             "Initial",                 "Post Stance vs Initial Stance",                True,  ("Post Stance", "Initial Stance")),
    ("Post-Initial",     "Initial",                 "(Post−Initial) Stance vs Initial Stance",      False, ("Post Stance − Initial Stance", "Initial Stance")),
    ("Post-Initial",     "avg_others - Initial",    "(Post−Initial) Stance vs (Avg Others − Initial Stance)", False, ("Post Stance − Initial Stance", "avg_others − Initial Stance")),
    ("Post-Initial",     "partner - Initial",       "(Post−Initial) Stance vs (Partner − Initial Stance)",    False, ("Post Stance − Initial Stance", "partner − Initial Stance")),
    ("Post-Initial",     "|partner - Initial|",     "(Post−Initial) Stance vs |Partner − Initial| Stance",    False, ("Post Stance − Initial Stance", "|partner − Initial| Stance")),
]

# Plot manifest for slider ratings. Same tuple layout as STANCE_SPECS:
# (y, x, title, draw_diagonal, (y_label, x_label)).
SLIDER_SPECS = [
    # Agreement
    ("Post",         "Initial",
     "Post Slider vs Initial Slider",
     True,  ("Post Slider", "Initial Slider")),

    # Δ vs baseline
    ("Post-Initial", "Initial",
     "(Post−Initial) Slider vs Initial Slider",
     False, ("Post Slider − Initial Slider", "Initial Slider")),

    # Δ vs peers/partner
    ("Post-Initial", "avg_others - Initial",
     "(Post−Initial) Slider vs (Avg Others − Initial Slider)",
     False, ("Post Slider − Initial Slider", "avg_others − Initial Slider")),

    ("Post-Initial", "partner - Initial",
     "(Post−Initial) Slider vs (Partner − Initial Slider)",
     False, ("Post Slider − Initial Slider", "partner − Initial Slider")),

    ("Post-Initial", "|partner - Initial|",
     "(Post−Initial) Slider vs |Partner − Initial| Slider",
     False, ("Post Slider − Initial Slider", "|partner − Initial| Slider")),
]

# Legacy plot registry: name -> (x column, y column, title).
# No longer used: nothing visible in this file reads PLOTS; the manifests
# STANCE_SPECS and SLIDER_SPECS drive plotting instead.
PLOTS = {
    # Existing
    "evolution_of_stance": ("Tweet1", "Tweet3-Tweet1", "Evolution of Stance (Tweet 1 ➜ Tweet 3)"),
    "evolution_of_opinion": ("Initial", "Post-Initial", "Evolution of Opinion (Initial ➜ Post)"),
    "stance_influence": ("|avg_partner12 - Tweet1|", "Tweet3-Tweet1", "Stance Influence vs. Stance Change"),
    "opinion_influence": ("|avg_others - Initial|", "Tweet3-Tweet1", "Opinion Influence vs. Stance Change"),

    # New
    "diff_t1_avg": ("avg_partner12 - Tweet1", "Tweet3-Tweet1", "Others vs. Tweet1 vs. Stance Change"),
    "diff_t1_partner": ("partner - Tweet1", "Tweet3-Tweet1", "Partner vs. Tweet1 vs. Stance Change"),
    "abs_diff_t1_partner": ("|partner - Tweet1|", "Tweet3-Tweet1", "Absolute Tweet1 Diff from Partner vs. Stance Change"),
    "diff_init_avg": ("avg_others - Initial", "Post-Initial", "Others vs. Initial vs. Opinion Change"),
    "diff_init_partner": ("partner - Initial", "Post-Initial", "Partner vs. Initial vs. Opinion Change"),
    "abs_diff_init_partner": ("|partner - Initial|", "Post-Initial", "Absolute Initial Diff from Partner vs. Opinion Change"),

    "tweet1_vs_tweet3": ("Tweet1", "Tweet3", "Tweet 1 vs. Tweet 3 (Agreement)"),
    "partner2_t2_diff": ("partner2_t2 - Tweet1", "Tweet3-Tweet1", "Partner 2's Tweet2 Influence"),
    "partner3_t3_diff": ("partner3_t3 - Tweet1", "Tweet3-Tweet1", "Partner 3's Tweet3 Influence"),
    "final_vs_initial": ("Initial", "Post", "Final vs Initial Slider (Agreement)"),
    "stance_vs_opinion_change": ("Post-Initial", "Tweet3-Tweet1", "Stance Change vs. Opinion Change"),
}

# Fixed (min, max) axis ranges keyed by plotted column name.
# make_scatter_plot() applies an entry to whichever axis shows that column;
# make_agreement_scatter() looks up one range by its axis_key and uses it for
# both axes, falling back to (-5.5, 5.5) when the key is missing.
AXIS_LIMITS = {
    "Tweet1": (-3, 3),
    "Tweet3": (-3, 3),
    "Initial": (-3, 3),
    "Post": (-3, 3),
    "Tweet3-Tweet1": (-5.5, 5.5),
    "Post-Initial": (-5.5, 5.5),
    "|avg_others - Initial|": (-0.5, 5.5),
    "|avg_partner12 - Tweet1|": (-0.5, 5.5),
    "|partner - Tweet1|": (-0.5, 5.5),
    "|partner - Initial|": (-0.5, 5.5),
    "avg_others - Initial": (-5.5, 5.5),
    "avg_others - Tweet1": (-5.5, 5.5),
    "avg_partner12 - Tweet1": (-5.5, 5.5),
    "partner - Tweet1": (-5.5, 5.5),
    "partner - Initial": (-5.5, 5.5),
    "partner2_t2 - Tweet1": (-5.5, 5.5),
    "partner3_t3 - Tweet1": (-5.5, 5.5),
}

def normalize(value, reverse=False):
    """Centre a rating on 3.5, optionally flipping its sign.

    Missing values (NaN/None) and a raw value of exactly 0 both map to 0.
    Any other value becomes ``value - 3.5``, negated when ``reverse`` is
    truthy.
    NOTE(review): the 3.5 midpoint suggests a 1-6 rating scale - confirm.
    """
    missing_or_zero = pd.isna(value) or value == 0
    if missing_or_zero:
        return 0

    centered = value - 3.5
    if reverse:
        return -1 * centered
    return centered

def deduplicate_by_llm_text_length(df):
    """Collapse duplicate event rows, keeping the one with the longest ``llm_text``.

    Rows count as duplicates when they share ``time_stamp``, ``event_type``,
    ``chat_order`` and ``human_id``. Groups with a single row pass through
    unchanged; otherwise only the row whose ``llm_text`` has the greatest
    string length survives (the first such row on ties).
    """
    identity_cols = ["time_stamp", "event_type", "chat_order", "human_id"]

    def _longest_row(rows):
        if len(rows) == 1:
            return rows
        winner = rows["llm_text"].str.len().idxmax()
        # Double brackets keep the result a one-row DataFrame rather than a Series.
        return rows.loc[[winner]]

    per_event = df.groupby(identity_cols, group_keys=False)
    return per_event.apply(_longest_row)

def collect_agent_data(df, source_col):
    """Gather each participant's normalized values per (topic, time_stamp) group.

    For every participant in a group, the first row matching each
    (event_type, chat_order) pair below is read from ``source_col`` and passed
    through ``normalize``; values are sign-flipped when the topic name ends in
    ``"_reversed"``. The partner ids are taken from the participant's first
    tweet row. Participants missing any of the five events are dropped.

    Returns a ``defaultdict(list)`` mapping ``(topic, time_stamp)`` to a list
    of ``(human_id, values_dict)`` tuples.
    """
    slot_by_event = {
        ("Initial Opinion", 0): "Initial",
        ("tweet", 1): "Tweet1",
        ("tweet", 2): "Tweet2",
        ("tweet", 3): "Tweet3",
        ("Post Opinion", 4): "Post",
    }
    required_slots = set(slot_by_event.values())
    sessions = defaultdict(list)

    for (topic, timestamp), session in df.groupby(["topic", "time_stamp"]):
        for pid in session["human_id"].unique():
            own_rows = session[session["human_id"] == pid]
            record = {}

            for (etype, order), slot in slot_by_event.items():
                hits = own_rows[(own_rows["event_type"] == etype) &
                                (own_rows["chat_order"] == order)]
                if hits.empty:
                    continue

                first = hits.iloc[0]
                flipped = topic.endswith("_reversed")
                record[slot] = normalize(first[source_col], reverse=flipped)

                # Partner ids are only read off the first tweet's row.
                if slot == "Tweet1":
                    record["partner_id"] = first["partner1_id"]
                    record["partner2_id"] = first.get("partner2_id", None)
                    record["partner3_id"] = first.get("partner3_id", None)

            # Keep only participants that have all five events.
            if required_slots.issubset(record):
                sessions[(topic, timestamp)].append((pid, record))

    return sessions

def make_agreement_scatter(paired_df, human_delta_col, llm_delta_col,
                           title, save_path, axis_key="Post-Initial", publication=False):
    """Save an SVG scatter of paired LLM deltas (y) against human deltas (x).

    Args:
        paired_df: Frame holding both delta columns; rows with a missing value
            in either column are dropped before plotting.
        human_delta_col: Column plotted on the x axis.
        llm_delta_col: Column plotted on the y axis.
        title: Plot title (wrapped at 40 characters); correlation statistics
            and the pair count are appended. Ignored in publication mode.
        save_path: Destination path of the SVG file.
        axis_key: Key into ``AXIS_LIMITS`` for the shared x/y range; falls
            back to ``(-5.5, 5.5)`` when the key is missing.
        publication: When True, strip title, axis labels, grid, legend and
            tick labels.

    With no valid pairs an empty figure is still written so the output file
    exists. Correlations and the regression line use the un-jittered values;
    jitter is applied to the drawn points only.
    """
    df = paired_df[[human_delta_col, llm_delta_col]].dropna()
    n = len(df)
    if n == 0:
        plt.figure(figsize=(8, 6))
        plt.savefig(save_path, format="svg")
        plt.close()
        return

    x = df[human_delta_col].to_numpy()  # Human Δ
    y = df[llm_delta_col].to_numpy()    # LLM Δ

    # Jitter only for display; the fixed seed keeps the output reproducible.
    jitter = 0.08
    rng = np.random.default_rng(42)
    xj = x + rng.uniform(-jitter, jitter, size=n)
    yj = y + rng.uniform(-jitter, jitter, size=n)

    # Stats on clean pairs (both tests need at least two observations).
    rho = r = p_rho = p_r = np.nan
    if n >= 2:
        rho, p_rho = spearmanr(x, y)
        r, p_r = pearsonr(x, y)

    lim = AXIS_LIMITS.get(axis_key, (-5.5, 5.5))

    plt.figure(figsize=(8, 6))
    plt.scatter(xj, yj, c="red", alpha=0.6, edgecolors="black", label="Paired Δ (LLM vs Human)")

    # Fix the axis range BEFORE drawing the trend: with truncate=False seaborn
    # extends the regression line to the x-limits in effect at call time, so
    # setting them afterwards would leave the line cut at the data range.
    plt.xlim(*lim)
    plt.ylim(*lim)

    # Regression on clean pairs; skipped when x has no variance.
    if n >= 2 and np.unique(x).size > 1:
        sns.regplot(x=x, y=y, scatter=False, color="red",
                    line_kws={"linewidth": 2, "linestyle": "--"}, ci=95, truncate=False,
                    label="LLM trend")

    # y = x reference
    plt.plot(lim, lim, linestyle="--", color="black", linewidth=1, label="y = x")
    plt.xlim(*lim)
    plt.ylim(*lim)

    if not publication:
        plt.title(f"{fill(title, 40)} (ρ={rho:.2f}, pρ={p_rho:.3f}, r={r:.2f}, pr={p_r:.3f}) (N = {n})")
        plt.xlabel("Human Δ")
        plt.ylabel("LLM Δ")
        plt.grid(True)
        plt.legend(title="Sources / Lines")
    else:
        plt.title("")
        plt.xlabel("")
        plt.ylabel("")
        plt.grid(False)
        leg = plt.gca().get_legend()
        if leg is not None:
            leg.remove()
        # Keep the tick marks but blank out their labels.
        xt, yt = plt.xticks()[0], plt.yticks()[0]
        plt.xticks(xt, ['' for _ in xt])
        plt.yticks(yt, ['' for _ in yt])

    plt.tight_layout()
    plt.savefig(save_path, format="svg")
    plt.close()

def build_plot_df(grouped_data, source_name):
    """Flatten grouped participant data into one plot row per participant.

    Args:
        grouped_data: Mapping ``(topic, time_stamp) -> [(pid, pdata), ...]`` as
            produced by ``collect_agent_data``; each ``pdata`` holds the
            ``Initial``, ``Tweet1``, ``Tweet2``, ``Tweet3`` and ``Post`` values
            plus ``partner_id`` / ``partner2_id`` / ``partner3_id``.
        source_name: Value written to the ``source`` column of every row.

    Returns:
        A DataFrame with the raw values, the two change columns
        (``Tweet3-Tweet1``, ``Post-Initial``), differences to the average of
        the other group members, and - when the referenced partner belongs to
        the same group - the partner-based difference columns. Participants
        who are alone in their group are skipped.

    Partner ids are resolved by plain dictionary lookup, so any id type that
    matches a participant id works (not only strings); missing ids (None/NaN)
    simply find no partner.
    """
    rows = []

    for (topic, timestamp), persons in grouped_data.items():
        # Quick lookup from participant id to that participant's values.
        pid_to_data = {pid: pdata for pid, pdata in persons}

        for pid, pdata in persons:
            others = [other for other_pid, other in persons if other_pid != pid]
            if not others:
                continue

            own_tweet1 = pdata["Tweet1"]
            own_initial = pdata["Initial"]
            avg_tweet1 = sum(o["Tweet1"] for o in others) / len(others)
            avg_initial = sum(o["Initial"] for o in others) / len(others)

            row = {
                "topic": topic,
                "time_stamp": timestamp,
                "human_id": pid,
                "source": source_name,
                "Tweet1": own_tweet1,
                "Tweet3": pdata["Tweet3"],
                "Initial": own_initial,
                "Post": pdata["Post"],
                "Tweet3-Tweet1": pdata["Tweet3"] - own_tweet1,
                "Post-Initial": pdata["Post"] - own_initial,
                "|avg_others - Tweet1|": abs(own_tweet1 - avg_tweet1),
                "|avg_others - Initial|": abs(own_initial - avg_initial),
                "avg_others - Tweet1": avg_tweet1 - own_tweet1,
                "avg_others - Initial": avg_initial - own_initial,
            }

            # Resolve all three partners the same way; .get() yields None for
            # ids that are missing or not part of this group.
            partner2_id = pdata.get("partner2_id")
            partner3_id = pdata.get("partner3_id")
            partner1_data = pid_to_data.get(pdata.get("partner_id"))
            partner2_data = pid_to_data.get(partner2_id)
            partner3_data = pid_to_data.get(partner3_id)

            if partner1_data is not None:
                row["partner - Tweet1"] = partner1_data["Tweet1"] - own_tweet1
                row["|partner - Tweet1|"] = abs(row["partner - Tweet1"])
                row["partner - Initial"] = partner1_data["Initial"] - own_initial
                row["|partner - Initial|"] = abs(row["partner - Initial"])

            if partner2_data is not None:
                row["partner2_t2 - Tweet1"] = partner2_data.get("Tweet2", 0) - own_tweet1

            if partner3_data is not None:
                t3 = partner3_data.get("Tweet3")
                if t3 is not None and not pd.isna(t3):
                    row["partner3_t3 - Tweet1"] = t3 - own_tweet1
                else:
                    print(f"⚠️ Missing Tweet3 for partner3_id={partner3_id} (human_id={pid}, topic={topic}, timestamp={timestamp})")
            else:
                print(f"⚠️ partner3_id={partner3_id} not found in group (human_id={pid}, topic={topic}, timestamp={timestamp})")

            # Average of partner 1's Tweet1 and partner 2's Tweet2, using
            # whichever of the two values is available.
            partner_vals = []
            if partner1_data is not None and partner1_data.get("Tweet1") is not None:
                partner_vals.append(partner1_data.get("Tweet1"))
            if partner2_data is not None and partner2_data.get("Tweet2") is not None:
                partner_vals.append(partner2_data.get("Tweet2"))

            if partner_vals:
                avg_partner12 = np.mean(partner_vals)
                row["avg_partner12 - Tweet1"] = avg_partner12 - own_tweet1
                row["|avg_partner12 - Tweet1|"] = abs(row["avg_partner12 - Tweet1"])

            rows.append(row)

    return pd.DataFrame(rows)

def build_paired_delta_df(df_plot: pd.DataFrame, delta_col: str) -> pd.DataFrame:
    """
    Pair each human row with the LLM row for the same (topic, time_stamp, human_id).

    Rows whose ``source`` is "human" supply ``human_delta``; rows whose
    ``source`` equals MODEL_NAME supply ``llm_delta``. Only keys present on
    both sides are kept (inner join).
    Returns columns: [topic, time_stamp, human_id, human_delta, llm_delta]
    """
    join_keys = ["topic", "time_stamp", "human_id"]

    def one_side(source_label, renamed_to):
        # Select one source's rows and give its delta column a side-specific name.
        is_source = df_plot["source"] == source_label
        subset = df_plot.loc[is_source, join_keys + [delta_col]]
        return subset.rename(columns={delta_col: renamed_to})

    human_side = one_side("human", "human_delta")
    llm_side = one_side(MODEL_NAME, "llm_delta")
    return pd.merge(human_side, llm_side, on=join_keys, how="inner")


def _save_paired_plots(paired_df: pd.DataFrame, base_name: str, axis_key: str,
                       pretty_name: str, output_root: str,
                       per_topic_title_prefix: str, agg_title_prefix: str):
    """
    Write paired LLM-vs-Human delta scatter plots, per topic and aggregated.

    Every plot is saved twice: ``<base_name>.svg`` and a stripped-down
    ``<base_name>_publication.svg``. Per-topic files go into
    ``<output_root>/<topic>/``; the aggregated files go directly into
    ``output_root``. ``pretty_name`` is accepted but not used by this function.
    """
    file_name = f"{base_name}.svg"

    def render_pair(frame, heading, target):
        # Regular version first, then the publication variant next to it.
        make_agreement_scatter(frame, "human_delta", "llm_delta",
                               heading, target, axis_key=axis_key)
        make_agreement_scatter(frame, "human_delta", "llm_delta",
                               heading, target.replace(".svg", "_publication.svg"),
                               axis_key=axis_key, publication=True)

    # --- Per-topic, in sorted topic order
    for topic in sorted(paired_df["topic"].unique()):
        rows_for_topic = paired_df[paired_df["topic"] == topic]
        folder = os.path.join(output_root, topic)
        os.makedirs(folder, exist_ok=True)
        render_pair(rows_for_topic,
                    f"{per_topic_title_prefix} – {topic}",
                    os.path.join(folder, file_name))

    # --- Aggregated: the unreversed base topic is left out so it is not
    # pooled together with its "_reversed" copy.
    unreversed_topic = "Everything_that_happens_can_eventually_be_explained_by_science"
    pooled = paired_df[paired_df["topic"] != unreversed_topic].copy()
    render_pair(pooled,
                f"{agg_title_prefix} – Aggregated",
                os.path.join(output_root, file_name))

def make_scatter_plot(df, x, y, title, save_path,
                      publication=False, draw_diagonal=False,
                      x_label=None, y_label=None):
    """Save an SVG scatter of column ``y`` against ``x``, coloured by ``source``.

    Args:
        df: Frame with a ``source`` column ("human" or MODEL_NAME) and the
            columns named by ``x`` and ``y``.
        x, y: Column names to plot.
        title: Plot title (wrapped at 40 characters); unused in publication mode.
        save_path: Destination path of the SVG file.
        publication: When True, strip title, labels, grid, legend and tick labels.
        draw_diagonal: Draw a y = x reference line when True, otherwise a
            horizontal line at y = 0.
        x_label, y_label: Axis labels; default to the column names.

    Correlations (shown in the legend) and the per-source regression lines are
    computed on the original values. Jitter is added to the drawn points only
    and uses a fixed seed so repeated runs give identical figures.
    """
    plt.figure(figsize=(8, 6))

    # Display-only jitter on a copy; `df` itself stays clean for the statistics.
    jitter_strength = 0.08
    rng = np.random.default_rng(42)
    df_j = df.copy()
    df_j[x] = df_j[x] + rng.uniform(-jitter_strength, jitter_strength, size=len(df_j))
    df_j[y] = df_j[y] + rng.uniform(-jitter_strength, jitter_strength, size=len(df_j))

    palette = {"human": "blue", MODEL_NAME: "red"}
    sources = df["source"].unique()

    # Legend labels, annotated with Spearman/Pearson statistics when both are defined.
    labels = []
    for source in sources:
        label = source
        clean = df[df["source"] == source].dropna(subset=[x, y])
        if not publication and len(clean) >= 2:
            rho, p_rho = spearmanr(clean[x], clean[y])
            r, p_r = pearsonr(clean[x], clean[y])
            if not (np.isnan(rho) or np.isnan(r)):
                label += f" (ρ={rho:.2f}, pρ={p_rho:.3f}, r={r:.2f}, pr={p_r:.3f})"
        labels.append(label)

    # Points (jittered) for every source first, then the trend lines on top.
    for label, source in zip(labels, sources):
        sub = df_j[df_j["source"] == source]
        sns.scatterplot(data=sub, x=x, y=y, color=palette[source],
                        label=None if publication else label,
                        alpha=0.5, edgecolor="black")

    for source in sources:
        sub = df[df["source"] == source]
        # A regression needs at least two rows and some variance in x.
        if len(sub) >= 2 and sub[x].nunique() > 1:
            sns.regplot(data=sub, x=x, y=y, scatter=False, color=palette[source],
                        label=None, line_kws={"linewidth": 2, "linestyle": "--"})

    valid = df[[x, y]].dropna()
    n = len(valid)
    # NOTE(review): N is reported as half the row count, which presumes every
    # participant contributes one row per source - confirm against the data.
    plt.title("" if publication else f"{fill(title, 40)}\n(N = {n/2:.0f})")
    plt.xlabel(x if x_label is None else x_label)
    plt.ylabel(y if y_label is None else y_label)

    if x in AXIS_LIMITS:
        plt.xlim(*AXIS_LIMITS[x])
    if y in AXIS_LIMITS:
        plt.ylim(*AXIS_LIMITS[y])

    if draw_diagonal:
        lim = [min(plt.xlim()[0], plt.ylim()[0]), max(plt.xlim()[1], plt.ylim()[1])]
        plt.plot(lim, lim, '--', color='black', linewidth=1)
    else:
        plt.axhline(0, linestyle='--', color='black', linewidth=1)

    if not publication:
        plt.grid(True)
        plt.legend(title="Source")
    else:
        plt.grid(False)
        plt.legend().remove()
        plt.xlabel("")
        plt.ylabel("")
        plt.title("")
        plt.tick_params(axis="both", which="both", labelsize=0)

    plt.tight_layout()
    plt.savefig(save_path, format="svg")
    plt.close()

def generate_scatter_plots():
    """Read the input CSV and write every stance/slider scatter plot as SVG.

    Steps:
      1. Load INPUT_PATH, drop rows lacking both likert predictions,
         de-duplicate, and keep only the five tracked events.
      2. For stance (likert predictions) and slider values, build one plot
         frame holding the human rows and the MODEL_NAME rows. Rows of the
         base science topic are additionally copied under a "_reversed"
         pseudo-topic.
      3. Render every manifest entry per topic and aggregated, each in a
         normal and a publication version.
      4. Render the three paired LLM-vs-human delta plots.
    """
    base_topic = "Everything_that_happens_can_eventually_be_explained_by_science"
    mirrored_topic = f"{base_topic}_reversed"
    tracked_events = [
        ("Initial Opinion", 0),
        ("tweet", 1),
        ("tweet", 2),
        ("tweet", 3),
        ("Post Opinion", 4),
    ]

    raw = pd.read_csv(INPUT_PATH)
    lacks_both_preds = raw["human_likert_pred"].isna() & raw["llm_likert_pred"].isna()
    raw = deduplicate_by_llm_text_length(raw[~lacks_both_preds])
    event_keys = raw[["event_type", "chat_order"]].apply(tuple, axis=1)
    raw = raw[event_keys.isin(tracked_events)]

    def with_mirrored_topic(frame):
        # Append a copy of the base-topic rows relabelled as the pseudo-topic.
        mirror = frame[frame["topic"] == base_topic].copy()
        mirror["topic"] = mirrored_topic
        return pd.concat([frame, mirror], ignore_index=True)

    def assemble(human_col, llm_col):
        # Build the combined human + LLM plot frame for one value domain.
        humans = raw[raw[human_col].notna()]
        models = raw[(raw[llm_col].notna()) & (raw["model_name"] == MODEL_NAME)]
        humans = with_mirrored_topic(humans)
        models = with_mirrored_topic(models)

        human_groups = collect_agent_data(humans, human_col)
        model_groups = collect_agent_data(models, llm_col)
        human_rows = build_plot_df(human_groups, "human")
        model_rows = build_plot_df(model_groups, MODEL_NAME)
        return pd.concat([human_rows, model_rows], ignore_index=True)

    # ---------------- Stance data (likert preds), then slider data ----------------
    plot_df_st = assemble("human_likert_pred", "llm_likert_pred")
    plot_df_sl = assemble("human_slider", "llm_slider")

    def render_both(frame, spec, heading_suffix, folder, prefix):
        # One manifest entry -> regular SVG plus its publication twin.
        y, x, title, diag, (y_lab, x_lab) = spec
        target = os.path.join(folder, get_filename_from_axes(x, y, domain_prefix=prefix))
        heading = f"{title} – {heading_suffix}"
        make_scatter_plot(frame, x, y, heading, target,
                          draw_diagonal=diag, x_label=x_lab, y_label=y_lab)
        make_scatter_plot(frame, x, y, heading,
                          target.replace(".svg", "_publication.svg"),
                          publication=True, draw_diagonal=diag,
                          x_label=x_lab, y_label=y_lab)

    def render_per_topic(manifest, frame, prefix):
        for topic in frame["topic"].unique():
            rows_for_topic = frame[frame["topic"] == topic]
            folder = os.path.join(PLOT_OUTPUT_ROOT, topic)
            os.makedirs(folder, exist_ok=True)
            for spec in manifest:
                render_both(rows_for_topic, spec, topic, folder, prefix)

    def render_aggregated(manifest, frame, prefix):
        # Drop the unreversed base topic so only its mirrored copy is pooled.
        pooled = frame[frame["topic"] != base_topic]
        for spec in manifest:
            render_both(pooled, spec, "Aggregated", PLOT_OUTPUT_ROOT, prefix)

    # ---------------- Per-topic plots ----------------
    render_per_topic(STANCE_SPECS, plot_df_st, "stance")
    render_per_topic(SLIDER_SPECS, plot_df_sl, "slider")

    # ---------------- Aggregated plots ----------------
    render_aggregated(STANCE_SPECS, plot_df_st, "stance")
    render_aggregated(SLIDER_SPECS, plot_df_sl, "slider")

    # ---------------- Three paired Δ plots (LLM vs Human) ----------------
    # (frame, delta column / axis key, base file name, pretty name, title prefix)
    paired_jobs = [
        (plot_df_sl, "Post-Initial", "paired_delta_slider",
         "Δ Slider", "Δ Slider (LLM vs Human; y vs x)"),
        (plot_df_st, "Post-Initial", "paired_delta_stance",
         "Δ Stance (Post−Initial)", "Δ Stance (Post−Initial) (LLM vs Human; y vs x)"),
        (plot_df_st, "Tweet3-Tweet1", "paired_delta_tweet",
         "Δ Tweet Stance (T3−T1)", "Δ Tweet (T3−T1) (LLM vs Human; y vs x)"),
    ]
    for frame, delta_col, file_stem, pretty, title_prefix in paired_jobs:
        paired = build_paired_delta_df(frame, delta_col)
        _save_paired_plots(
            paired,
            base_name=file_stem,
            axis_key=delta_col,
            pretty_name=pretty,
            output_root=PLOT_OUTPUT_ROOT,
            per_topic_title_prefix=title_prefix,
            agg_title_prefix=title_prefix,
        )

    print("✅ SVG scatter plots generated for stance & slider (with publication versions).")

# Generate all plots when this file is executed as a script (not on import).
if __name__ == "__main__":
    generate_scatter_plots()