import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# -------------------
# Paths (match yours)
# -------------------
# The project root is resolved two directory levels above this script's folder;
# both the input CSV and the output folder live under result/group_level_eval.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
_GROUP_EVAL_DIR = os.path.join(PROJECT_ROOT, "result", "group_level_eval")
PREPROCESSED_PATH = os.path.join(_GROUP_EVAL_DIR, "preprocessed_depth_valid.csv")
OUTPUT_ROOT = os.path.join(_GROUP_EVAL_DIR, "histograms")
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# -------------------
# Load & Filter (same as your code)
# -------------------
df = pd.read_csv(PREPROCESSED_PATH)

# Keep only the rows belonging to this one model.
_is_target_model = df["model_name"] == "gpt-4o-mini-2024-07-18"
df = df[_is_target_model]

# (event_type, chat_order) pairs that are retained; every other row is dropped.
EVENT_ORDER = [
    ("Initial Opinion", 0),
    ("tweet", 1),
    ("tweet", 2),
    ("tweet", 3),
    ("Post Opinion", 4),
]
_event_keys = df[["event_type", "chat_order"]].apply(tuple, axis=1)
df = df[_event_keys.isin(EVENT_ORDER)].copy()

# The topic column may be spelled with either capitalisation; prefer "Topic".
TOPIC_COL = next((name for name in ("Topic", "topic") if name in df.columns), None)
if TOPIC_COL is None:
    raise KeyError("Neither 'Topic' nor 'topic' column found in preprocessed_depth_valid.csv")

# Topic whose values are sign-flipped by get_values() and get_pair_diff().
REVERSE_TOPIC = "Everything_that_happens_can_eventually_be_explained_by_science"

def deduplicate_by_llm_text_length(d):
    """Keep one row per (time_stamp, event_type, chat_order, human_id).

    Within each key group the row whose ``llm_text`` is longest is kept
    (missing text counts as length 0; the first row wins a tie). The input
    frame is not modified.

    Returns:
        A new DataFrame holding the surviving rows, sorted by the key columns.
    """
    key_cols = ["time_stamp", "event_type", "chat_order", "human_id"]

    # Temporary helper column with the character length of each llm_text.
    text_len = d["llm_text"].fillna("").astype(str).str.len()
    work = d.assign(_len=text_len)

    # Index label of the longest-text row in every key group.
    winners = work.groupby(key_cols)["_len"].idxmax()

    deduped = work.loc[winners.values].drop(columns=["_len"])
    return deduped.sort_values(key_cols)

# Collapse rows sharing (time_stamp, event_type, chat_order, human_id) down to
# the single row with the longest llm_text.
df = deduplicate_by_llm_text_length(df)

# -------------------
# Helpers
# -------------------
def _norm_series(s):
    """Coerce ``s`` to numbers, drop unparseable/missing entries, subtract 3.5.

    The surviving entries keep their original index labels.
    """
    numeric = pd.to_numeric(s, errors="coerce")
    valid = numeric.dropna().astype(float)
    return valid - 3.5

def get_values(event_type, chat_order, value_col):
    """Return centred values of ``value_col`` for one (event_type, chat_order).

    Rows are taken from the module-level ``df``. Values are normalised by
    ``_norm_series`` and, for rows whose topic equals ``REVERSE_TOPIC``,
    multiplied by -1.

    Returns:
        A float Series. When reverse-topic rows are present the result is the
        non-reversed values followed by the flipped ones, with a fresh
        0..n-1 index; otherwise the original row index is kept.
    """
    in_event = (df["event_type"] == event_type) & (df["chat_order"] == chat_order)
    rows = df[in_event]

    is_reversed = rows[TOPIC_COL] == REVERSE_TOPIC
    if not is_reversed.any():
        # Nothing to flip: plain normalisation of the whole selection.
        return _norm_series(rows[value_col])

    unflipped = _norm_series(rows.loc[~is_reversed, value_col])
    flipped = _norm_series(rows.loc[is_reversed, value_col]) * -1.0
    return pd.concat([unflipped, flipped], ignore_index=True)

def mean_median_sd(x):
    """Return ``(mean, median, sample standard deviation)`` of ``x`` as floats.

    All three are NaN for an empty input; the standard deviation (ddof=1) is
    NaN when there is only one observation.
    """
    n_obs = len(x)
    if n_obs == 0:
        return np.nan, np.nan, np.nan

    if n_obs > 1:
        spread = float(np.std(x, ddof=1))
    else:
        spread = np.nan
    return float(np.mean(x)), float(np.median(x)), spread

def _discrete_edges(vals_a, vals_b, max_levels=12):
    """Choose shared histogram bin edges for two samples.

    If the pooled values sit on a regular grid with at most ``max_levels``
    levels (Likert-like data), the edges are placed half a step either side of
    every grid level so each level gets its own bin. Levels that happen to be
    unused in the data do not break the detection: the grid step is the
    smallest gap between observed levels, provided every other gap is a whole
    multiple of it.

    Args:
        vals_a, vals_b: Array-likes of numbers; either may be empty.
        max_levels: Largest number of grid levels still treated as discrete.

    Returns:
        1-D array of bin edges:
        * ``[-3, -2, ..., 3]`` when there are no finite values at all;
        * a single unit-wide bin when all values are identical;
        * half-step edges for grid-like data;
        * otherwise 20 automatic bins from ``np.histogram_bin_edges``.
    """
    pooled = np.concatenate([
        np.asarray(vals_a, dtype=float).ravel(),
        np.asarray(vals_b, dtype=float).ravel(),
    ])
    # Non-finite entries cannot be binned and would make the automatic
    # fallback below raise, so they are ignored when choosing edges.
    pooled = pooled[np.isfinite(pooled)]
    if pooled.size == 0:
        return np.linspace(-3, 3, 7)  # fallback

    # Round before de-duplicating so float noise does not create extra levels.
    levels = np.unique(np.round(pooled, 6))

    if levels.size == 1:
        # No spacing can be inferred from one level; use one unit-wide bin.
        return np.array([levels[0] - 0.5, levels[0] + 0.5])

    if levels.size <= max_levels:
        gaps = np.diff(levels)
        step = gaps.min()
        multiples = gaps / step
        # Whole-number multiples mean the observed levels lie on one grid,
        # possibly with some grid positions unused.
        if np.allclose(multiples, np.round(multiples), atol=1e-6):
            n_levels = int(round((levels[-1] - levels[0]) / step)) + 1
            if n_levels <= max_levels:
                lo = levels[0] - step / 2
                # Build edges from an integer count rather than a float-stop
                # arange, so rounding can never add or drop the last edge.
                return lo + step * np.arange(n_levels + 1)

    return np.histogram_bin_edges(pooled, bins=20)

def _darker_color(color, factor=0.6):
    """Scale each RGB channel of ``color`` by ``factor`` (floored at 0).

    ``color`` is anything ``matplotlib.colors.to_rgb`` accepts; the result is
    an ``(r, g, b)`` tuple. A factor below 1 yields a darker shade of the same
    hue family.
    """
    channels = mcolors.to_rgb(color)
    return tuple(max(channel * factor, 0) for channel in channels)

def grouped_hist_bars(
    ax, a, b, label_a, label_b, color_a, color_b, title,
    density=True, show_labels=True, show_legend=True, mean_line_width=2
):
    """Draw two distributions as side-by-side bars per bin on ``ax``.

    Both samples share one set of bin edges (from ``_discrete_edges``). Each
    bin shows one bar for ``a`` and one for ``b``, x ticks sit at the bin
    centres, and each sample's mean is marked with a dashed vertical line in a
    darker shade of its bar colour.

    Args:
        ax: Matplotlib axes to draw on.
        a, b: Array-likes of numeric values (converted to float arrays).
        label_a, label_b: Legend names for the two samples.
        color_a, color_b: Bar fill colours.
        title: Axes title.
        density: Passed to ``np.histogram``; also selects the y-axis label.
        show_labels: Whether to set the x/y axis labels.
        show_legend: Whether to draw the legend with mean/median/sd per sample.
        mean_line_width: Line width of the mean markers.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    edges = _discrete_edges(a, b)

    def _bin_heights(values):
        # With density=True an empty sample makes np.histogram divide 0 by 0,
        # giving NaN heights and a RuntimeWarning; use zero-height bars instead.
        if values.size == 0:
            return np.zeros(len(edges) - 1)
        heights, _ = np.histogram(values, bins=edges, density=density)
        return heights

    counts_a = _bin_heights(a)
    counts_b = _bin_heights(b)

    centers = (edges[:-1] + edges[1:]) / 2.0
    bin_w = edges[1] - edges[0]
    bar_w = bin_w * 0.45

    # Bars (side-by-side per bin). The containers are kept so the legend can
    # reference them explicitly.
    bars_a = ax.bar(centers - bar_w / 2, counts_a, width=bar_w,
                    color=color_a, edgecolor="black")
    bars_b = ax.bar(centers + bar_w / 2, counts_b, width=bar_w,
                    color=color_b, edgecolor="black")

    # Stats
    mA, medA, sdA = mean_median_sd(a)
    mB, medB, sdB = mean_median_sd(b)

    # Mean lines = darker shade, dashed for both
    dark_a = _darker_color(color_a, factor=0.6)
    dark_b = _darker_color(color_b, factor=0.6)
    if not np.isnan(mA):
        ax.axvline(mA, color=dark_a, linestyle="--", linewidth=mean_line_width)
    if not np.isnan(mB):
        ax.axvline(mB, color=dark_b, linestyle="--", linewidth=mean_line_width)

    # Legend (bars with stats; mean lines implied by color/dash).
    # Handles are passed explicitly: giving only labels lets matplotlib pair
    # them with artists by creation order, which can attach the text to the
    # mean lines (or to the wrong sample when one mean line is absent).
    if show_legend:
        ax.legend(
            [bars_a, bars_b],
            [
                f"{label_a} (mean={mA:.2f}, median={medA:.2f}, sd={sdA:.2f})",
                f"{label_b} (mean={mB:.2f}, median={medB:.2f}, sd={sdB:.2f})",
            ],
            loc="upper right", frameon=True, fontsize=9,
        )

    # X ticks centered between the two bars
    ax.set_xticks(centers)
    ax.set_xticklabels([f"{pos:.1f}" for pos in centers])

    ax.set_title(title, fontsize=12)

    if show_labels:
        ax.set_xlabel("Normalized value")
        ax.set_ylabel("Density" if density else "Count")
    else:
        ax.set_xlabel("")
        ax.set_ylabel("")

def save_fig(fig, filename):
    """Write ``fig`` as an SVG named ``filename`` under ``OUTPUT_ROOT``.

    The layout is tightened before saving, the figure is closed afterwards,
    and the destination path is printed.
    """
    destination = os.path.join(OUTPUT_ROOT, filename)
    fig.tight_layout()
    fig.savefig(destination, format="svg")
    plt.close(fig)
    print(f"Saved: {destination}")

# Colors: a light and a deep shade per family.
# Blue family
HUMAN_LIGHT = "#9ecae1"
HUMAN_DEEP = "#2171b5"
# Red family
LLM_LIGHT = "#fcbba1"
LLM_DEEP = "#cb181d"

# -------------------
# Build & save (normal + publication versions)
# -------------------
def make_pair(a, b, la, lb, ca, cb, title, base_name):
    """Render one comparison twice and save both figures as SVG.

    ``<base_name>.svg`` is the normal version (axis labels and legend);
    ``<base_name>_pub.svg`` is the publication version (neither).

    Args:
        a, b: The two samples to compare.
        la, lb: Legend labels for ``a`` and ``b``.
        ca, cb: Bar colours for ``a`` and ``b``.
        title: Axes title used in both versions.
        base_name: Output file name without extension.
    """
    # (file-name suffix, whether axis labels and legend are drawn)
    variants = (("", True), ("_pub", False))
    for suffix, decorated in variants:
        fig, ax = plt.subplots(figsize=(7, 4))
        grouped_hist_bars(
            ax, a, b, la, lb, ca, cb, title,
            density=True, show_labels=decorated, show_legend=decorated,
        )
        save_fig(fig, f"{base_name}{suffix}.svg")

# -------------------
# Existing four stance-based comparisons
# -------------------
# Human stance: tweet at chat_order 1 vs tweet at chat_order 3.
human_t1 = get_values(event_type="tweet", chat_order=1, value_col="human_likert_pred")
human_t3 = get_values(event_type="tweet", chat_order=3, value_col="human_likert_pred")
make_pair(
    a=human_t1, b=human_t3,
    la="Tweet 1", lb="Tweet 3",
    ca=HUMAN_LIGHT, cb=HUMAN_DEEP,
    title="Human: Tweet 1 vs Tweet 3 (stance)",
    base_name="barhist_human_tweet1_vs_tweet3",
)

# Human stance: Initial Opinion vs Post Opinion.
human_init = get_values(event_type="Initial Opinion", chat_order=0, value_col="human_likert_pred")
human_post = get_values(event_type="Post Opinion", chat_order=4, value_col="human_likert_pred")
make_pair(
    a=human_init, b=human_post,
    la="Initial", lb="Post",
    ca=HUMAN_LIGHT, cb=HUMAN_DEEP,
    title="Human: Initial vs Post (stance)",
    base_name="barhist_human_initial_vs_post_stance",
)

# LLM stance: tweet at chat_order 1 vs tweet at chat_order 3.
llm_t1 = get_values(event_type="tweet", chat_order=1, value_col="llm_likert_pred")
llm_t3 = get_values(event_type="tweet", chat_order=3, value_col="llm_likert_pred")
make_pair(
    a=llm_t1, b=llm_t3,
    la="Tweet 1", lb="Tweet 3",
    ca=LLM_LIGHT, cb=LLM_DEEP,
    title="LLM: Tweet 1 vs Tweet 3 (stance)",
    base_name="barhist_llm_tweet1_vs_tweet3",
)

# LLM stance: Initial Opinion vs Post Opinion.
llm_init = get_values(event_type="Initial Opinion", chat_order=0, value_col="llm_likert_pred")
llm_post = get_values(event_type="Post Opinion", chat_order=4, value_col="llm_likert_pred")
make_pair(
    a=llm_init, b=llm_post,
    la="Initial", lb="Post",
    ca=LLM_LIGHT, cb=LLM_DEEP,
    title="LLM: Initial vs Post (stance)",
    base_name="barhist_llm_initial_vs_post_stance",
)

# -------------------
# NEW: Differences histograms (stance)
# -------------------
# Utility to compute per-unit diffs aligned by keys, then apply reverse-topic flip
def get_pair_diff(eventA, eventB, value_col):
    """Return per-unit differences (B - A) of ``value_col`` between two events.

    Each event is an ``(event_type, chat_order)`` pair. Rows of the two events
    are inner-joined on (topic, time_stamp, human_id), so only units present
    in both contribute. Values are centred by subtracting 3.5 before
    differencing, and the difference is negated for ``REVERSE_TOPIC`` rows.

    Returns:
        1-D NumPy array of differences with missing results removed.
    """
    keys = [k for k in [TOPIC_COL, "time_stamp", "human_id"] if k in df.columns]

    def _event_rows(event, out_name):
        # Join keys plus the value column for one event, value column renamed.
        etype, order = event
        picked = df[(df["event_type"] == etype) & (df["chat_order"] == order)]
        return picked[keys + [value_col]].rename(columns={value_col: out_name})

    merged = pd.merge(
        _event_rows(eventA, "val_A"),
        _event_rows(eventB, "val_B"),
        on=keys,
        how="inner",
    )

    # Centre both sides the same way before differencing.
    for side in ("val_A", "val_B"):
        merged[side] = pd.to_numeric(merged[side], errors="coerce").astype(float) - 3.5

    merged["diff"] = merged["val_B"] - merged["val_A"]

    # Flipping both centred values is the same as negating their difference.
    rmask = merged[TOPIC_COL] == REVERSE_TOPIC
    if rmask.any():
        merged.loc[rmask, "diff"] = -1.0 * merged.loc[rmask, "diff"]

    return merged["diff"].dropna().to_numpy()

# 1) Stance change from tweet 1 to tweet 3, human vs LLM (light blue vs light red).
human_d_t31 = get_pair_diff(eventA=("tweet", 1), eventB=("tweet", 3), value_col="human_likert_pred")
llm_d_t31 = get_pair_diff(eventA=("tweet", 1), eventB=("tweet", 3), value_col="llm_likert_pred")
make_pair(
    a=human_d_t31, b=llm_d_t31,
    la="Human Δ (T3−T1)", lb="LLM Δ (T3−T1)",
    ca=HUMAN_LIGHT, cb=LLM_LIGHT,
    title="Tweet 3 − Tweet 1 (stance, Δ)",
    base_name="barhist_diff_tweet3_minus_tweet1_stance",
)

# 2) Stance change from Initial Opinion to Post Opinion, human vs LLM.
human_d_pi = get_pair_diff(
    eventA=("Initial Opinion", 0), eventB=("Post Opinion", 4), value_col="human_likert_pred"
)
llm_d_pi = get_pair_diff(
    eventA=("Initial Opinion", 0), eventB=("Post Opinion", 4), value_col="llm_likert_pred"
)
make_pair(
    a=human_d_pi, b=llm_d_pi,
    la="Human Δ (Post−Initial)", lb="LLM Δ (Post−Initial)",
    ca=HUMAN_LIGHT, cb=LLM_LIGHT,
    title="Post − Initial (stance, Δ)",
    base_name="barhist_diff_post_minus_initial_stance",
)

# -------------------
# NEW: Slider-based comparisons
# -------------------
# NOTE(review): the slider columns are centred with the same 3.5 offset as the
# Likert columns -- confirm they share that scale.

# Human slider: Initial Opinion vs Post Opinion.
human_init_s = get_values(event_type="Initial Opinion", chat_order=0, value_col="human_slider")
human_post_s = get_values(event_type="Post Opinion", chat_order=4, value_col="human_slider")
make_pair(
    a=human_init_s, b=human_post_s,
    la="Initial", lb="Post",
    ca=HUMAN_LIGHT, cb=HUMAN_DEEP,
    title="Human: Initial vs Post (slider)",
    base_name="barhist_human_initial_vs_post_slider",
)

# LLM slider: Initial Opinion vs Post Opinion.
llm_init_s = get_values(event_type="Initial Opinion", chat_order=0, value_col="llm_slider")
llm_post_s = get_values(event_type="Post Opinion", chat_order=4, value_col="llm_slider")
make_pair(
    a=llm_init_s, b=llm_post_s,
    la="Initial", lb="Post",
    ca=LLM_LIGHT, cb=LLM_DEEP,
    title="LLM: Initial vs Post (slider)",
    base_name="barhist_llm_initial_vs_post_slider",
)

# Slider change from Initial Opinion to Post Opinion, human vs LLM.
human_d_pi_s = get_pair_diff(
    eventA=("Initial Opinion", 0), eventB=("Post Opinion", 4), value_col="human_slider"
)
llm_d_pi_s = get_pair_diff(
    eventA=("Initial Opinion", 0), eventB=("Post Opinion", 4), value_col="llm_slider"
)
make_pair(
    a=human_d_pi_s, b=llm_d_pi_s,
    la="Human Δ (Post−Initial)", lb="LLM Δ (Post−Initial)",
    ca=HUMAN_LIGHT, cb=LLM_LIGHT,
    title="Post − Initial (slider, Δ)",
    base_name="barhist_diff_post_minus_initial_slider",
)


print(f"✅ All SVGs saved under: {OUTPUT_ROOT}")