import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import textwrap

# --- Config ---
# Project root: two directory levels above the folder that contains this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
# All input CSVs are read from, and all figures written under, this directory.
OUTPUT_ROOT = os.path.join(
    PROJECT_ROOT, "result", "group_level_eval", "base_analysis"
)

# Tick labels for the five events, in plotting order (x = 0..4).
EVENT_LABELS = ["Initial", "Tweet 1", "Tweet 2", "Tweet 3", "Post"]

# Bias columns, one per event, ordered like EVENT_LABELS.
BIAS_COLS = [
    "Initial Opinion - Slider Bias",
    "Tweet 1 Bias",
    "Tweet 2 Bias",
    "Tweet 3 Bias",
    "Post Opinion - Slider Bias",
]
# Matching SEM columns: same names with " SEM" appended.
BIAS_SEM_COLS = [f"{name} SEM" for name in BIAS_COLS]

# Diversity columns, one per event, ordered like EVENT_LABELS.
DIVERSITY_COLS = [
    "Initial Opinion - Slider Diversity",
    "Tweet 1 Diversity",
    "Tweet 2 Diversity",
    "Tweet 3 Diversity",
    "Post Opinion - Slider Diversity",
]
# Matching SEM columns: same names with " SEM" appended.
DIVERSITY_SEM_COLS = [f"{name} SEM" for name in DIVERSITY_COLS]

# ---- Group-level paired t-test results (means & stds) ----
TTEST_PATH = os.path.join(OUTPUT_ROOT, "paired_ttest_results.csv")

def _load_group_ttests(ttest_path=TTEST_PATH):
    """
    Read the paired t-test CSV into a lookup table.

    Returns a dict keyed by (domain, pair, who, stat_on) -> {'n','t','p'}:
      domain: 'stance' | 'slider'
      pair: 'tweet1_vs_tweet3' | 'initial_vs_post'
      who: 'human' | 'llm'
      stat_on: 'group_mean' | 'group_std'

    Each key component is the stringified, whitespace-stripped cell value.
    'n' comes from the 'n_groups' column (falling back to 'n'), 't' from
    't_stat' and 'p' from 'p_value'; absent columns yield ''.
    An empty dict is returned when the file does not exist.
    """
    if not os.path.exists(ttest_path):
        return {}

    key_fields = ("domain", "pair", "who", "stat_on")
    results = {}
    for _, record in pd.read_csv(ttest_path).iterrows():
        lookup_key = tuple(str(record.get(field)).strip() for field in key_fields)
        n_value = record.get("n_groups", record.get("n", ""))
        results[lookup_key] = {
            "n": n_value,
            "t": record.get("t_stat", ""),
            "p": record.get("p_value", ""),
        }
    return results

def _fmt_line(res, domain, pair, stat_on, label):
    """
    Build a one-liner combining Human & LLM results:
      '<label>: Human t=..., p=... • LLM t=..., p=... (n=...)'
    If either side missing, returns '' (no annotation).
    """
    h = res.get((domain, pair, "human", stat_on))
    l = res.get((domain, pair, "llm",   stat_on))
    if not h or not l:
        return ""
    def f(x):
        try:
            return f"{float(x):.4f}"
        except:
            return str(x)
    n = h.get("n", "")
    return f"{label}: Human t={f(h['t'])}, p={f(h['p'])} • LLM t={f(l['t'])}, p={f(l['p'])} (n={n})"

def _tp_for(ttests, domain, pair, stat_on, who_tag):
    """
    who_tag: 'human' | 'llm'
    Returns '(t=..., p=...)' or '' if not found.
    """
    r = ttests.get((domain, pair, who_tag, stat_on))
    if not r:
        return ""
    def f(v):
        try: return f"{float(v):.4f}"
        except: return str(v)
    return f"(t={f(r.get('t'))}, p={f(r.get('p'))})"

def _who_tag_from_label(label):
    """Map legend label used in plots to t-test 'who' key."""
    return "human" if str(label).lower() == "human" else "llm"

def _total_group_n_in_df(df_all_topics):
    """
    Compute total number of groups across topics from the plotting DataFrame,
    after reverse-coded topics have been appended.

    Assumes each per-topic block contains a 'num_groups' column that counts
    groups for that topic; we take the max per topic and sum across topics.
    Excludes the 'All' pseudo-topic.
    """
    sub = df_all_topics[df_all_topics["Topic"] != "All"]
    if "num_groups" not in sub.columns:
        return "?"
    topic_counts = (
        sub.dropna(subset=["num_groups"])
           .groupby("Topic")["num_groups"]
           .max()
    )
    if topic_counts.empty:
        return "?"
    return int(topic_counts.sum())

def _series_offsets(n, jitter):
    """Balanced offsets around 0: n=1->[0], n=2->[-j,+j], n=3->[-j,0,+j], ..."""
    if n <= 1:
        return [0.0]
    # start centered around 0
    idx = np.arange(n) - (n - 1) / 2.0
    # if even, this creates halves; multiply by jitter to space them
    return (idx * jitter).tolist()

# ---- Plotters ----

def plot_split_trajectories(df_topic, metric_cols, ylabel, base_path, title_base, publication=False):
    """
    Save the opinion and tweet trajectories as two separate 6x6 figures.

    The "opinion" figure shows events 0 and 4 (Initial, Post); the "tweet"
    figure shows events 1-3. One line is drawn per distinct 'Human/LLM'
    value, using the first matching row of df_topic; the series named
    "human" (case-insensitive) is blue, every other series red.

    metric_cols: five column names ordered like EVENT_LABELS.
    ylabel: y-axis label. If it contains "Bias", a dashed zero line is drawn
        and the y-range is fixed to [-2.5, 2.5]; if it contains "Diversity",
        the y-range is [0, 3].
    base_path: output path stem; files are written to
        '<base_path>_<opinion|tweet>.svg', with '_publication' inserted
        before '.svg' when publication is True.
    title_base: title text; the trajectory type and the group count from
        'num_groups' ('?' unless exactly one distinct value) are appended.
    publication: when True, grid, legend, title, axis labels and tick labels
        are removed (tick marks are kept).
    """
    shows_bias = "Bias" in ylabel
    file_suffix = "_publication" if publication else ""
    events_by_type = {"opinion": [0, 4], "tweet": [1, 2, 3]}

    for traj_type, positions in events_by_type.items():
        plt.figure(figsize=(6, 6))
        tick_text = [EVENT_LABELS[pos] for pos in positions]

        for source in df_topic["Human/LLM"].unique():
            source_rows = df_topic.loc[df_topic["Human/LLM"] == source]
            all_values = source_rows[metric_cols].values[0]
            line_color = "#0000FF" if source.lower() == "human" else "#FF0000"
            plt.plot(
                positions,
                [all_values[pos] for pos in positions],
                marker="o",
                linewidth=2,
                label=source,
                color=line_color,
            )

        # Axes decoration
        if shows_bias:
            plt.axhline(0, linestyle="--", color="gray", linewidth=1)
        plt.xticks(positions, tick_text)
        plt.xlabel("Event")
        plt.ylabel(ylabel)

        group_counts = df_topic["num_groups"].dropna().unique()
        n_text = int(group_counts[0]) if len(group_counts) == 1 else "?"
        full_title = f"{title_base} ({traj_type.capitalize()} Trajectory, n = {n_text})"
        plt.title("\n".join(textwrap.wrap(full_title, width=70)))

        if publication:
            # Bare figure: no grid, legend, title or axis labels.
            plt.grid(False)
            plt.legend().remove()
            plt.title("")
            plt.xlabel("")
            plt.ylabel("")
        else:
            plt.grid(True, linestyle="--", alpha=0.5)
            plt.legend()

        if shows_bias:
            plt.ylim(-2.5, 2.5)
            plt.yticks([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
        elif "Diversity" in ylabel:
            plt.ylim(0, 3)

        if publication:
            # Keep the tick marks but blank their labels (done after the
            # limits are final so the tick positions match the final axes).
            x_locs = plt.xticks()[0]
            plt.xticks(x_locs, [''] * len(x_locs))
            y_locs = plt.yticks()[0]
            plt.yticks(y_locs, [''] * len(y_locs))

        plt.tight_layout()
        plt.savefig(f"{base_path}_{traj_type}{file_suffix}.svg")
        plt.close()

def plot_trajectory(df_topic, metric_cols, ylabel, save_path, title, publication=False, jitter=0.06):
    """
    Plot the full five-event trajectory of every series and save the figure.

    For each distinct 'Human/LLM' value (first matching row of df_topic) the
    three tweet events are drawn as one connected line, and the Initial and
    Post opinions as two markers joined by a second line. The series named
    "human" (case-insensitive) is blue, every other series red. Series are
    shifted horizontally by balanced offsets spaced `jitter` apart so that
    overlapping markers stay visible; the x ticks stay at the integer events.

    metric_cols: five column names ordered like EVENT_LABELS.
    ylabel: y-axis label. If it contains "Bias", a dashed zero line is drawn
        and the y-range is fixed to [-2.5, 2.5]; if it contains "Diversity",
        the y-range is [0, 3].
    save_path: file the figure is written to.
    title: title text; the group count from 'num_groups' ('?' unless exactly
        one distinct value) is appended.
    publication: when True, grid, legend, title, axis labels and tick labels
        are removed (tick marks are kept).
    jitter: horizontal spacing between adjacent series.
    """
    plt.figure(figsize=(6, 6))
    base_x = np.arange(5)
    who_list = list(df_topic["Human/LLM"].unique())
    offsets = _series_offsets(len(who_list), jitter)

    for who, dx in zip(who_list, offsets):
        row = df_topic[df_topic["Human/LLM"] == who]
        values = row[metric_cols].values[0]
        color = "#0000FF" if who.lower() == "human" else "#FF0000"

        # Jittered x positions for this series
        x_tweets = base_x[1:4] + dx
        x_init = base_x[0] + dx
        x_post = base_x[4] + dx

        plt.plot(x_tweets, values[1:4], marker="o", linewidth=2, label=who, color=color)
        plt.plot(x_init, values[0], 'o', color=color)
        plt.plot(x_post, values[4], 'o', color=color)
        plt.plot([x_init, x_post], [values[0], values[4]], color=color, linewidth=2)

    if "Bias" in ylabel:
        plt.axhline(0, linestyle="--", color="gray", linewidth=1)

    # Keep ticks at the original (un-jittered) positions
    plt.xticks(base_x, EVENT_LABELS)
    plt.xlabel("Event")
    plt.ylabel(ylabel)

    num_groups = df_topic["num_groups"].dropna().unique()
    num_groups = int(num_groups[0]) if len(num_groups) == 1 else "?"
    title_with_count = f"{title} (n = {num_groups})"
    plt.title("\n".join(textwrap.wrap(title_with_count, width=70)))

    if publication:
        plt.grid(False)
        plt.legend().remove()
        plt.title("")
        plt.xlabel("")
        plt.ylabel("")
    else:
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.legend()

    if "Bias" in ylabel:
        plt.ylim(-2.5, 2.5)
        plt.yticks([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
    elif "Diversity" in ylabel:
        plt.ylim(0, 3)

    if publication:
        # Blank the tick labels only AFTER the limits are final: passing
        # locations to plt.yticks pins the ticks at those positions, so doing
        # it earlier would freeze them at the autoscaled data range instead
        # of the fixed axis range set above.
        x_locs = plt.xticks()[0]
        plt.xticks(x_locs, [''] * len(x_locs))
        y_locs = plt.yticks()[0]
        plt.yticks(y_locs, [''] * len(y_locs))

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_errorbar(df_topic, metric_cols, sem_cols, ylabel, save_path, title, publication=False, note=None, n_override=None):
    """
    Plot the five-event summary with error bars and save the figure.

    For each distinct 'Human/LLM' value (first matching row of df_topic) the
    three tweet events are drawn as one connected error-bar line, and the
    Initial and Post opinions as two error-bar markers joined by a plain
    line. The series named "human" (case-insensitive) is blue, every other
    series red.

    metric_cols / sem_cols: five mean columns and their five error columns,
        ordered like EVENT_LABELS.
    ylabel: y-axis label. If it contains "Bias", a dashed zero line is drawn
        and the y-range is fixed to [-1.5, 1.5]; if it contains "Diversity",
        the y-range is [0, 2.5].
    save_path: file the figure is written to.
    title: title text; "(n = ...)" is appended.
    publication: when True, grid, legend, title, axis labels and tick labels
        are removed (tick marks are kept).
    note: optional text drawn centred just above the axes.
    n_override: group count to show in the title; when None it is read from
        'num_groups' ('?' unless exactly one distinct value).
    """
    event_x = np.arange(5)
    _, ax = plt.subplots(figsize=(6, 6))
    shows_bias = "Bias" in ylabel

    for source in df_topic["Human/LLM"].unique():
        source_rows = df_topic.loc[df_topic["Human/LLM"] == source]
        centre = source_rows[metric_cols].values[0]
        spread = source_rows[sem_cols].values[0]
        series_color = "#0000FF" if source.lower() == "human" else "#FF0000"

        # Tweet events: connected line, carries the legend entry.
        ax.errorbar(event_x[1:4], centre[1:4], yerr=spread[1:4], label=source,
                    marker='o', capsize=6, linewidth=2, color=series_color)
        # Opinion endpoints: stand-alone markers ...
        for endpoint in (0, 4):
            ax.errorbar(event_x[endpoint], centre[endpoint], yerr=spread[endpoint],
                        fmt='o', color=series_color, capsize=6)
        # ... joined by a straight line.
        ax.plot([event_x[0], event_x[4]], [centre[0], centre[4]],
                color=series_color, linewidth=2)

    if shows_bias:
        ax.axhline(0, linestyle="--", color="gray", linewidth=1)

    ax.set_xticks(event_x)
    ax.set_xticklabels(EVENT_LABELS)
    ax.set_ylabel(ylabel)

    if n_override is None:
        group_counts = df_topic["num_groups"].dropna().unique()
        n_text = int(group_counts[0]) if len(group_counts) == 1 else "?"
    else:
        n_text = n_override
    full_title = f"{title} (n = {n_text})"
    ax.set_title("\n".join(textwrap.wrap(full_title, width=70)))

    if publication:
        # Bare figure: no grid, legend, title or axis labels.
        ax.grid(False)
        ax.legend().remove()
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")
    else:
        ax.grid(True, linestyle="--", alpha=0.5)
        ax.legend()

    if shows_bias:
        ax.set_ylim(-1.5, 1.5)
        ax.set_yticks([-1.5, -0.5, 0.5, 1.5])
    elif "Diversity" in ylabel:
        ax.set_ylim(0, 2.5)

    # Optional annotation placed just above the axes
    if note:
        ax.text(0.5, 1.02, note, ha='center', va='bottom', transform=ax.transAxes,
                fontsize=9)

    if publication:
        # Keep the tick marks but blank their labels.
        ax.set_xticklabels([''] * len(ax.get_xticks()))
        ax.set_yticklabels([''] * len(ax.get_yticks()))

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_split_errorbars(df_topic, metric_cols, sem_cols, ylabel, base_path, title_base,
                         publication=False, note_opinion=None, note_tweet=None,
                         legend_labels_opinion=None, legend_labels_tweet=None,
                         n_override=None, jitter=0.06):
    """
    Save the opinion and tweet error-bar trajectories as two separate figures.

    The "opinion" figure shows events 0 and 4 (Initial, Post); the "tweet"
    figure shows events 1-3. One error-bar line is drawn per distinct
    'Human/LLM' value, using the first matching row of df_topic; the series
    named "human" (case-insensitive) is blue, every other series red. Series
    are shifted horizontally by balanced offsets spaced `jitter` apart; the
    x ticks stay at the integer event positions.

    metric_cols / sem_cols: five mean columns and their five error columns,
        ordered like EVENT_LABELS.
    ylabel: y-axis label. If it contains "Bias", a dashed zero line is drawn
        and the y-range is fixed to [-1.5, 1.5]; if it contains "Diversity",
        the y-range is [0, 2.5].
    base_path: output path stem; files are written to
        '<base_path>_<opinion|tweet>.svg', with '_publication' inserted
        before '.svg' when publication is True.
    title_base: title text; the trajectory type and group count are appended.
    publication: when True, grid, legend, title, axis labels and tick labels
        are omitted (tick marks are kept).
    note_opinion / note_tweet: optional text drawn just above the axes of the
        respective figure.
    legend_labels_opinion / legend_labels_tweet: optional dicts mapping a
        series name to the legend text to use in the respective figure;
        series not present (or a falsy dict) fall back to the series name.
    n_override: group count to show in the title; when None it is read from
        'num_groups' ('?' unless exactly one distinct value).
    jitter: horizontal spacing between adjacent series.
    """
    who_list = list(df_topic["Human/LLM"].unique())
    offsets = _series_offsets(len(who_list), jitter)
    file_suffix = "_publication" if publication else ""

    for traj_type in ["opinion", "tweet"]:
        is_tweet = traj_type == "tweet"
        legend_labels = legend_labels_tweet if is_tweet else legend_labels_opinion
        note = note_tweet if is_tweet else note_opinion

        fig, ax = plt.subplots(figsize=(6, 6))
        base_x = np.array([1, 2, 3] if is_tweet else [0, 4])
        x_labels = [EVENT_LABELS[i] for i in base_x]

        for who, dx in zip(who_list, offsets):
            row = df_topic[df_topic["Human/LLM"] == who]
            means = row[metric_cols].values[0]
            errors = row[sem_cols].values[0]
            color = "#0000FF" if who.lower() == "human" else "#FF0000"

            # Legend text is only attached when a legend will be shown.
            legend_label = legend_labels.get(who, who) if legend_labels else who

            ax.errorbar(
                base_x + dx,                      # jittered x positions
                [means[i] for i in base_x],
                yerr=[errors[i] for i in base_x],
                label=(None if publication else legend_label),
                marker='o', capsize=6, linewidth=2, color=color,
            )

        if "Bias" in ylabel:
            ax.axhline(0, linestyle="--", color="gray", linewidth=1)

        # Ticks stay at the base positions (no jitter)
        ax.set_xticks(base_x)
        ax.set_xticklabels(x_labels)
        ax.set_ylabel(ylabel)

        if n_override is not None:
            num_groups = n_override
        else:
            group_counts = df_topic["num_groups"].dropna().unique()
            num_groups = int(group_counts[0]) if len(group_counts) == 1 else "?"
        title = f"{title_base} ({traj_type.capitalize()} Trajectory, n = {num_groups})"
        ax.set_title("\n".join(textwrap.wrap(title, width=70)))

        if publication:
            # No legend is created here: every series is unlabeled in this
            # mode, so ax.legend() would only log a "no artists with labels"
            # warning before being removed again.
            ax.grid(False)
            ax.set_title("")
            ax.set_xlabel("")
            ax.set_ylabel("")
        else:
            ax.grid(True, linestyle="--", alpha=0.5)
            ax.legend()

        if "Bias" in ylabel:
            ax.set_ylim(-1.5, 1.5)
            ax.set_yticks([-1.5, -0.5, 0.5, 1.5])
        elif "Diversity" in ylabel:
            ax.set_ylim(0, 2.5)

        if note:
            ax.text(0.5, 1.02, note, ha='center', va='bottom', transform=ax.transAxes, fontsize=9)

        if publication:
            # Blank the tick labels once limits and ticks are final
            # (same order as plot_errorbar); tick marks are kept.
            ax.set_xticklabels([''] * len(ax.get_xticks()))
            ax.set_yticklabels([''] * len(ax.get_yticks()))

        plt.tight_layout()
        plt.savefig(f"{base_path}_{traj_type}{file_suffix}.svg")
        plt.close(fig)

# ---- Driver ----
def generate_all_plots(input_csv, suffix):
    """
    Generate every per-topic and aggregated figure for one stats CSV.

    input_csv: path of a stats CSV with at least the 'Topic', 'Human/LLM' and
        'num_groups' columns plus the metric/SEM columns listed in BIAS_COLS,
        BIAS_SEM_COLS, DIVERSITY_COLS and DIVERSITY_SEM_COLS.
    suffix: 'slider' or 'stance'. Used in output file names, in the axis
        labels ("Stance ..." for 'stance', "Slider ..." otherwise) and to
        select which t-test results decorate the aggregated legends.

    Steps:
      1. Append a sign-flipped copy of the reverse-coded topic under the name
         '<topic>_reversed' (skipped when bias columns or the topic itself
         are missing).
      2. For every topic and every 'Human/LLM' value, write trajectory and
         error-bar figures (normal and publication variants, full and split)
         under OUTPUT_ROOT/<topic>/<who>/.
      3. If the CSV has an 'All' topic, write aggregated error-bar figures
         directly under OUTPUT_ROOT. Their n is the total group count across
         topics (with the original reverse-coded topic excluded), and the
         split figures show paired t-test t/p values in the legend.
    """
    df = pd.read_csv(input_csv)

    # Domain-aware y-axis labels
    is_stance = (suffix == "stance")
    YLABEL_BIAS = "Stance Bias" if is_stance else "Slider Bias"
    YLABEL_DIV = "Stance Diversity" if is_stance else "Slider Diversity"

    REVERSE_TOPIC = "Everything_that_happens_can_eventually_be_explained_by_science"
    PSEUDO_TOPIC = f"{REVERSE_TOPIC}_reversed"

    # --- Reverse-coded pseudo-topic -------------------------------------
    # The bias column names are shared by the slider and stance files.
    missing_cols = [col for col in BIAS_COLS if col not in df.columns]
    if missing_cols:
        print(f"⚠️ Skipping reversal — missing columns: {missing_cols}")
    else:
        df_reversed = df[df["Topic"] == REVERSE_TOPIC].copy()
        if df_reversed.empty:
            # Nothing to reverse: do not report success or concat an empty frame.
            print(f"⚠️ Skipping reversal — topic '{REVERSE_TOPIC}' not found")
        else:
            df_reversed["Topic"] = PSEUDO_TOPIC
            df_reversed[BIAS_COLS] *= -1
            df = pd.concat([df, df_reversed], ignore_index=True)
            print(f"🔁 Added reversed-code version of '{REVERSE_TOPIC}' as '{PSEUDO_TOPIC}'")

    # --- Per-topic plots ------------------------------------------------
    # (file stem, metric columns, SEM columns or None, y-axis label)
    per_topic_specs = [
        ("bias_trajectory", BIAS_COLS, None, YLABEL_BIAS),
        ("std_trajectory", DIVERSITY_COLS, None, YLABEL_DIV),
        ("summary_bias_errorbar", BIAS_COLS, BIAS_SEM_COLS, YLABEL_BIAS),
        ("summary_std_errorbar", DIVERSITY_COLS, DIVERSITY_SEM_COLS, YLABEL_DIV),
    ]
    for topic in df["Topic"].unique():
        df_topic = df[df["Topic"] == topic]
        topic_text = topic.replace('_', ' ')
        for who in df_topic["Human/LLM"].unique():
            df_source = df_topic[df_topic["Human/LLM"] == who]
            who_dir = os.path.join(OUTPUT_ROOT, topic, who)
            os.makedirs(who_dir, exist_ok=True)

            for fn, metric, sem, ylabel in per_topic_specs:
                is_trajectory = "trajectory" in fn
                metric_name = ylabel.split()[1]  # "Bias" or "Diversity"
                if is_trajectory:
                    title = f"{metric_name} – {topic_text} – {who}"
                else:
                    title = f"{metric_name} Summary – {topic_text} – {who}"

                base_path = os.path.join(who_dir, f"{fn}_{suffix}")
                save_path = base_path + ".svg"
                save_path_pub = base_path + "_publication.svg"

                if is_trajectory:
                    plot_trajectory(df_source, metric, ylabel, save_path, title, publication=False)
                    plot_trajectory(df_source, metric, ylabel, save_path_pub, title, publication=True)
                    plot_split_trajectories(df_source, metric, ylabel, base_path, title, publication=False)
                    plot_split_trajectories(df_source, metric, ylabel, base_path, title, publication=True)
                else:
                    plot_errorbar(df_source, metric, sem, ylabel, save_path, title, publication=False)
                    plot_errorbar(df_source, metric, sem, ylabel, save_path_pub, title, publication=True)
                    plot_split_errorbars(df_source, metric, sem, ylabel, base_path, title, publication=False)
                    plot_split_errorbars(df_source, metric, sem, ylabel, base_path, title, publication=True)

    # --- Aggregated plots -----------------------------------------------
    # Drop the original reverse-coded topic so only its reversed copy counts.
    df_filtered = df[df["Topic"] != REVERSE_TOPIC].copy()
    df_all = df_filtered[df_filtered["Topic"] == "All"]
    if df_all.empty:
        return

    print(f"📊 Generating aggregated 'All' plots for {suffix}...")

    # n shown in the titles: total groups across topics (incl. reversed copy)
    total_n_groups = _total_group_n_in_df(df_filtered)

    # Group-level paired t-test results, shown inside the split-plot legends
    ttests = _load_group_ttests()

    for fn, metric, sem, ylabel in [
        ("summary_bias_errorbar", BIAS_COLS, BIAS_SEM_COLS, YLABEL_BIAS),
        ("summary_std_errorbar", DIVERSITY_COLS, DIVERSITY_SEM_COLS, YLABEL_DIV),
    ]:
        title = f"{ylabel.split()[1]} Summary – Aggregated Across All Topics ({suffix})"
        base_path = os.path.join(OUTPUT_ROOT, f"{fn}_{suffix}")

        # Bias plots use the tests on group means, Diversity plots on group stds
        stat_on = "group_mean" if "Bias" in ylabel else "group_std"

        # Legend labels for the split aggregated plots:
        # - stance: opinion = initial_vs_post, tweet = tweet1_vs_tweet3
        # - slider: opinion = initial_vs_post only (no tweet test)
        legend_labels_opinion = {}
        legend_labels_tweet = {} if suffix == "stance" else None
        if suffix in ("stance", "slider"):
            for who in df_all["Human/LLM"].unique():
                who_tag = _who_tag_from_label(who)

                tp_op = _tp_for(ttests, suffix, "initial_vs_post", stat_on, who_tag)
                legend_labels_opinion[who] = f"{who} {tp_op}".strip()

                if suffix == "stance":
                    tp_tw = _tp_for(ttests, "stance", "tweet1_vs_tweet3", stat_on, who_tag)
                    legend_labels_tweet[who] = f"{who} {tp_tw}".strip()

        # Full aggregated plots (kept uncluttered: no t-test annotation)
        plot_errorbar(df_all, metric, sem, ylabel, base_path + ".svg", title,
                      publication=False, n_override=total_n_groups)
        plot_errorbar(df_all, metric, sem, ylabel, base_path + "_publication.svg", title,
                      publication=True, n_override=total_n_groups)

        # Split aggregated plots — t/p shown inside the legend; the
        # publication variant hides the legend altogether.
        for publication in (False, True):
            plot_split_errorbars(
                df_all, metric, sem, ylabel, base_path, title,
                publication=publication,
                legend_labels_opinion=legend_labels_opinion,
                legend_labels_tweet=legend_labels_tweet,
                n_override=total_n_groups,
            )

def main():
    """Generate all slider-based plots, then all stance-based plots."""
    jobs = (
        ("slider",
         os.path.join(OUTPUT_ROOT, "all_base_slider_stats.csv"),
         "📈 Generating SLIDER-based plots..."),
        ("stance",
         os.path.join(OUTPUT_ROOT, "all_base_stance_stats.csv"),
         "📈 Generating STANCE-based plots..."),
    )
    for suffix, csv_path, banner in jobs:
        print(banner)
        generate_all_plots(csv_path, suffix)

    print("✅ All slider and stance plots generated.")


if __name__ == "__main__":
    main()