import os
import numpy as np
import pandas as pd
from scipy.stats import ttest_rel

# --- Paths ---
# The project root is resolved two directory levels above this script's folder;
# both the input CSV and the output folder live under result/group_level_eval.
_SCRIPT_DIR = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, "..", ".."))
_EVAL_DIR = os.path.join(PROJECT_ROOT, "result", "group_level_eval")
PREPROCESSED_PATH = os.path.join(_EVAL_DIR, "preprocessed_depth_valid.csv")
OUTPUT_ROOT = os.path.join(_EVAL_DIR, "base_analysis")
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# --- Load + Filter ---
df = pd.read_csv(PREPROCESSED_PATH)
# Restrict the analysis to rows produced with a single model.
df = df.loc[df["model_name"] == "gpt-4o-mini-2024-07-18"]

# The only (event_type, chat_order) combinations that are analysed.
EVENT_ORDER = [
    ("Initial Opinion", 0),
    ("tweet", 1), ("tweet", 2), ("tweet", 3),
    ("Post Opinion", 4),
]
# Build one (event_type, chat_order) tuple per row and keep the listed ones.
_event_keys = df[["event_type", "chat_order"]].apply(tuple, axis=1)
df = df[_event_keys.isin(EVENT_ORDER)].copy()

# Detect topic column name ("Topic" is preferred over "topic" when both exist)
TOPIC_COL = next((name for name in ("Topic", "topic") if name in df.columns), None)
if TOPIC_COL is None:
    raise KeyError("Neither 'Topic' nor 'topic' column found in preprocessed_depth_valid.csv")

# The topic we reverse-code (keep only the reversed version)
REVERSE_TOPIC = "Everything_that_happens_can_eventually_be_explained_by_science"
REVERSED_SUFFIX = "_reversed"

# Deduplicate (choose row with longest llm_text per (time_stamp, event_type, chat_order, human_id))
def deduplicate_by_llm_text_length(d):
    """
    Keep one row per (time_stamp, event_type, chat_order, human_id).

    Within each key combination the row with the longest ``llm_text`` is kept;
    missing text is treated as an empty string (length 0). The winner is picked
    with ``idxmax`` on the row labels, so the frame's index is used to look the
    rows up again. The caller's frame is not modified.

    Returns a new DataFrame with the original columns, sorted by the four key
    columns.
    """
    key_cols = ["time_stamp", "event_type", "chat_order", "human_id"]

    # Temporary helper column holding the character count of each llm_text.
    text_len = d["llm_text"].fillna("").astype(str).str.len()
    scored = d.assign(_len=text_len)

    # Row label of the longest text inside every key combination.
    winners = scored.groupby(key_cols)["_len"].idxmax()

    kept = scored.loc[winners.values].drop(columns=["_len"])
    return kept.sort_values(key_cols)

# Collapse duplicate rows so that each (time_stamp, event_type, chat_order,
# human_id) combination appears only once before any aggregation is done.
df = deduplicate_by_llm_text_length(df)

# ------------------------------
# Helpers for group-level pairing
# ------------------------------
# Columns that identify one group: make_group_agg groups by these (plus the
# event columns) and build_group_pairs joins event A to event B on them.
KEYS = [TOPIC_COL, "time_stamp"]  # group identity

def make_group_agg(sub_df, value_col, center=3.5):
    """
    Centre a rating column and aggregate it per (topic, time_stamp, event).

    Parameters
    ----------
    sub_df : DataFrame containing the KEYS columns, "event_type",
        "chat_order" and ``value_col``. It is not modified.
    value_col : name of the column to aggregate; its values are cast to float.
    center : value subtracted from every rating before aggregation. Defaults
        to 3.5, the constant this script has always used.
        NOTE(review): 3.5 looks like the midpoint of a 1-6 scale; confirm that
        every column passed here (including the slider columns) shares that
        scale, because the sign flip applied to the reverse-coded topic
        depends on the centre being correct.

    Returns
    -------
    DataFrame with columns KEYS + event_type + chat_order + mean + std + n,
    where mean/std/n are the mean, pandas' groupby "std" and the non-null
    count of the centred values in each group.
    """
    work = sub_df.copy()
    work["_value_norm"] = work[value_col].astype(float) - center  # normalization
    grouped = work.groupby(KEYS + ["event_type", "chat_order"], as_index=False)
    return grouped["_value_norm"].agg(mean="mean", std="std", n="count")

def augment_with_reversed_topic(agg_df):
    """
    Swap REVERSE_TOPIC for its reverse-coded counterpart.

    Rows whose topic equals REVERSE_TOPIC get their "mean" multiplied by -1 and
    their topic renamed to REVERSE_TOPIC + REVERSED_SUFFIX; the original rows
    are dropped. All other columns of those rows (std, n, ...) are carried over
    unchanged, since dispersion is not affected by the sign flip.

    When the topic does not occur, the input frame itself is returned as is.
    Otherwise a new frame is returned: the non-reversed rows first, the
    reversed rows after them, with a fresh 0..n-1 index.
    """
    topic = agg_df[TOPIC_COL]
    is_target = topic == REVERSE_TOPIC
    if not is_target.any():
        # Nothing to reverse: hand back the caller's frame untouched.
        return agg_df

    flipped = agg_df[is_target].copy()
    flipped["mean"] = flipped["mean"] * -1.0
    flipped[TOPIC_COL] = REVERSE_TOPIC + REVERSED_SUFFIX

    # Only the reversed version of the topic is kept.
    untouched = agg_df[topic != REVERSE_TOPIC]
    return pd.concat([untouched, flipped], ignore_index=True)

def build_group_pairs(agg_df, eventA, eventB):
    """
    Pair the aggregated statistics of two events group by group.

    eventA / eventB are (event_type, chat_order) tuples such as ('tweet', 1)
    or ('Initial Opinion', 0). For each event the matching rows of ``agg_df``
    are selected, and the two selections are inner-joined on KEYS, so only
    groups that have both events survive. Pairs with a missing mean on either
    side are dropped.

    Returns a DataFrame with KEYS, mean_A, std_A, mean_B and std_B.
    """
    def _side(event, tag):
        # Rows of one event, with mean/std renamed to carry the side tag.
        ev_type, ev_order = event
        hit = (agg_df["event_type"] == ev_type) & (agg_df["chat_order"] == ev_order)
        stats = agg_df.loc[hit, KEYS + ["mean", "std"]]
        return stats.rename(columns={"mean": f"mean_{tag}", "std": f"std_{tag}"})

    left = _side(eventA, "A")
    right = _side(eventB, "B")
    paired = left.merge(right, on=KEYS, how="inner")
    return paired.dropna(subset=["mean_A", "mean_B"])

def ttest_pair(a, b):
    """
    Run a paired t-test of ``a`` against ``b``.

    Returns the tuple (n, df, t, p) where n is ``len(a)`` and df is n - 1.
    With fewer than two observations no test is run and df, t and p are the
    string 'NA'.
    """
    n = len(a)
    if n <= 1:
        # A paired test needs at least two pairs.
        return n, "NA", "NA", "NA"

    t_stat, p_val = ttest_rel(a, b)
    return n, n - 1, float(t_stat), float(p_val)

def summarize_array(x):
    """
    Summarise a sequence of numbers as (mean, sd).

    The sd is the sample standard deviation (ddof=1); it is NaN when there is
    only one value. For an empty input both entries are the string 'NA'.
    """
    size = len(x)
    if size == 0:
        return "NA", "NA"

    mean_val = float(np.mean(x))
    if size > 1:
        sd_val = float(np.std(x, ddof=1))
    else:
        # A single value has no sample standard deviation.
        sd_val = float("nan")
    return mean_val, sd_val

def _paired_stat_row(domain, pair_label, who, stat_on, a, b):
    """
    Build one result row for a paired comparison of two equal-length arrays.

    Runs ttest_pair(a, b), summarises each array with summarize_array and
    describes the paired differences a - b (mean, sample sd, standard error).
    Values that cannot be computed are reported as 'NA' (no pairs at all) or
    NaN (a single pair, for the sd-based quantities).
    """
    n, df_val, t_val, p_val = ttest_pair(a, b)

    delta = a - b
    diff_mean = float(np.mean(delta)) if n > 0 else "NA"
    sd_diff = float(np.std(delta, ddof=1)) if n > 1 else float("nan")
    se_diff = float(sd_diff / np.sqrt(n)) if n > 1 else float("nan")

    m_a, s_a = summarize_array(a)
    m_b, s_b = summarize_array(b)

    def _rounded(value, digits):
        # 'NA' placeholders pass through untouched; NaN survives round().
        return round(value, digits) if isinstance(value, float) else value

    return {
        "domain": domain,
        "pair": pair_label,
        "who": who,
        "stat_on": stat_on,
        "n_groups": n,
        "df": df_val,
        "t_stat": _rounded(t_val, 4),
        "p_value": _rounded(p_val, 4),
        "A_mean_of_group_values": _rounded(m_a, 4),
        "A_sd_of_group_values": _rounded(s_a, 4),
        "B_mean_of_group_values": _rounded(m_b, 4),
        "B_sd_of_group_values": _rounded(s_b, 4),
        "diff_mean": _rounded(diff_mean, 6),
        "sd_diff": _rounded(sd_diff, 6),
        "se_diff": _rounded(se_diff, 6),
    }


def add_rows_for_mean_and_std(results, domain, pair_label, who, merged):
    """
    Add two rows to results: one for paired t-test on group means, one for group stds.

    Parameters
    ----------
    results : list that the two row dicts are appended to (mutated in place).
    domain, pair_label, who : labels copied verbatim into both rows.
    merged : DataFrame with mean_A, mean_B, std_A and std_B columns, one row
        per group (the output of build_group_pairs).

    The first row (stat_on="group_mean") compares mean_A with mean_B, the
    second (stat_on="group_std") compares std_A with std_B. An empty
    ``merged`` yields rows filled with 'NA'/NaN instead of raising.
    """
    # Means
    mean_a = merged["mean_A"].to_numpy(dtype=float)
    mean_b = merged["mean_B"].to_numpy(dtype=float)
    results.append(
        _paired_stat_row(domain, pair_label, who, "group_mean", mean_a, mean_b)
    )

    # STDs (diversity)
    std_a = merged["std_A"].to_numpy(dtype=float)
    std_b = merged["std_B"].to_numpy(dtype=float)
    # A group's std can be NaN even though its mean is defined. A single NaN
    # would turn the whole t-test and every summary into NaN, so only pairs
    # with both stds present are compared (add_diff_of_diff_std does the same).
    usable = ~(np.isnan(std_a) | np.isnan(std_b))
    results.append(
        _paired_stat_row(
            domain, pair_label, who, "group_std", std_a[usable], std_b[usable]
        )
    )

def add_diff_of_diff(results, topic_col, df, domain, col_human, col_llm, eventA, eventB, pair_label):
    """
    Diff-of-diff on MEANS: compare the human and LLM change from event A to event B.

    For each of the human and LLM columns the group means are built, the
    reverse-coded topic is swapped in and events A and B are paired per group.
    The per-group change (B - A) of both sides is aligned on
    (topic_col, "time_stamp") and compared with a paired t-test.

    Appends one row to ``results`` with who="human_vs_llm" and
    stat_on="diff_of_diff_mean". In that row the "A_*" fields describe the
    human deltas, the "B_*" fields the LLM deltas, and
    diff_mean / sd_diff / se_diff describe human_delta - llm_delta.
    """
    id_cols = [topic_col, "time_stamp"]

    def _paired_means(value_col, prefix):
        # Group means for events A and B, renamed to "<prefix>_mean_A/B".
        raw = df[id_cols + ["event_type", "chat_order", value_col]].dropna()
        table = augment_with_reversed_topic(make_group_agg(raw, value_col))
        pairs = build_group_pairs(table, eventA, eventB)
        return pairs[id_cols + ["mean_A", "mean_B"]].rename(
            columns={"mean_A": f"{prefix}_mean_A", "mean_B": f"{prefix}_mean_B"}
        )

    human_side = _paired_means(col_human, "human")
    llm_side = _paired_means(col_llm, "llm")
    joined = human_side.merge(llm_side, on=id_cols, how="inner")

    joined["human_delta"] = joined["human_mean_B"] - joined["human_mean_A"]
    joined["llm_delta"] = joined["llm_mean_B"] - joined["llm_mean_A"]

    complete = joined.dropna(subset=["human_delta", "llm_delta"])
    human_change = complete["human_delta"].to_numpy()
    llm_change = complete["llm_delta"].to_numpy()

    n, df_val, t_out, p_out = ttest_pair(human_change, llm_change)

    # Summaries of each series and of the paired differences (human - llm).
    mA, sA = summarize_array(human_change)  # human Δ
    mB, sB = summarize_array(llm_change)    # llm Δ
    gap = human_change - llm_change
    diff_mean = float(np.mean(gap)) if n > 0 else "NA"
    if n > 1:
        sd_diff = float(np.std(gap, ddof=1))
        se_diff = sd_diff / np.sqrt(n)
    else:
        sd_diff = float("nan")
        se_diff = float("nan")

    def _r4(value):
        return round(value, 4) if value != "NA" else "NA"

    def _r6(value):
        return round(value, 6) if isinstance(value, float) else value

    row = {
        "domain": domain,
        "pair": pair_label,
        "who": "human_vs_llm",
        "stat_on": "diff_of_diff_mean",
        "n_groups": n,
        "df": df_val,
        "t_stat": _r4(t_out),
        "p_value": _r4(p_out),
        "A_mean_of_group_values": _r4(mA),
        "A_sd_of_group_values": _r4(sA),
        "B_mean_of_group_values": _r4(mB),
        "B_sd_of_group_values": _r4(sB),
        "diff_mean": _r6(diff_mean),
        "sd_diff": _r6(sd_diff),
        "se_diff": _r6(se_diff),
    }
    results.append(row)

def add_diff_of_diff_std(results, topic_col, df, domain, col_human, col_llm, eventA, eventB, pair_label):
    """
    Diff-of-diff on STDs (diversity): compare how group dispersion changes for humans vs. LLMs.

    For each of the human and LLM columns the per-group std is built for
    events A and B (the reverse-coding step flips means only, so stds are
    preserved). The per-group std change (B - A) of both sides is aligned on
    (topic_col, "time_stamp"); groups where either change is missing are
    dropped, and the remaining pairs go through a paired t-test.

    Appends one row to ``results`` with who="human_vs_llm" and
    stat_on="diff_of_diff_std". In that row the "A_*" fields describe the
    human std deltas, the "B_*" fields the LLM std deltas, and
    diff_mean / sd_diff / se_diff describe human_std_delta - llm_std_delta.
    """
    id_cols = [topic_col, "time_stamp"]

    def _paired_stds(value_col, prefix):
        # Group stds for events A and B, renamed to "<prefix>_std_A/B".
        raw = df[id_cols + ["event_type", "chat_order", value_col]].dropna()
        table = augment_with_reversed_topic(make_group_agg(raw, value_col))
        pairs = build_group_pairs(table, eventA, eventB)
        return pairs[id_cols + ["std_A", "std_B"]].rename(
            columns={"std_A": f"{prefix}_std_A", "std_B": f"{prefix}_std_B"}
        )

    human_side = _paired_stds(col_human, "human")
    llm_side = _paired_stds(col_llm, "llm")
    joined = human_side.merge(llm_side, on=id_cols, how="inner")

    joined["human_std_delta"] = joined["human_std_B"] - joined["human_std_A"]
    joined["llm_std_delta"] = joined["llm_std_B"] - joined["llm_std_A"]

    complete = joined.dropna(subset=["human_std_delta", "llm_std_delta"])
    human_change = complete["human_std_delta"].to_numpy()
    llm_change = complete["llm_std_delta"].to_numpy()

    n, df_val, t_out, p_out = ttest_pair(human_change, llm_change)

    mA, sA = summarize_array(human_change)  # human std Δ
    mB, sB = summarize_array(llm_change)    # llm std Δ
    gap = human_change - llm_change
    diff_mean = float(np.mean(gap)) if n > 0 else "NA"
    if n > 1:
        sd_diff = float(np.std(gap, ddof=1))
        se_diff = sd_diff / np.sqrt(n)
    else:
        sd_diff = float("nan")
        se_diff = float("nan")

    def _r4(value):
        return round(value, 4) if value != "NA" else "NA"

    def _r6(value):
        return round(value, 6) if isinstance(value, float) else value

    row = {
        "domain": domain,
        "pair": pair_label,
        "who": "human_vs_llm",
        "stat_on": "diff_of_diff_std",
        "n_groups": n,
        "df": df_val,
        "t_stat": _r4(t_out),
        "p_value": _r4(p_out),
        "A_mean_of_group_values": _r4(mA),
        "A_sd_of_group_values": _r4(sA),
        "B_mean_of_group_values": _r4(mB),
        "B_sd_of_group_values": _r4(sB),
        "diff_mean": _r6(diff_mean),
        "sd_diff": _r6(sd_diff),
        "se_diff": _r6(se_diff),
    }
    results.append(row)

# ------------------------------
# Build aggregates & run tests
# ------------------------------
results = []
group_tables_to_save = []

# One entry per comparison:
# (domain, pair label, event A, event B, human column, llm column)
_COMPARISONS = [
    # 1) Stance-based Tweet1 vs Tweet3 (likert preds)
    ("stance", "tweet1_vs_tweet3", ("tweet", 1), ("tweet", 3),
     "human_likert_pred", "llm_likert_pred"),
    # 2) Stance-based Initial vs Post (likert preds)
    ("stance", "initial_vs_post", ("Initial Opinion", 0), ("Post Opinion", 4),
     "human_likert_pred", "llm_likert_pred"),
    # 3) Slider-based Initial vs Post (slider)
    ("slider", "initial_vs_post", ("Initial Opinion", 0), ("Post Opinion", 4),
     "human_slider", "llm_slider"),
]

# Within-source tests: group means and group stds, humans first, then LLMs.
for _domain, _pair, _event_a, _event_b, _human_col, _llm_col in _COMPARISONS:
    for who, col in (("human", _human_col), ("llm", _llm_col)):
        sub = df[[TOPIC_COL, "time_stamp", "event_type", "chat_order", col]].dropna()
        # Only the reversed version of REVERSE_TOPIC is kept in the aggregate.
        agg = augment_with_reversed_topic(make_group_agg(sub, col))
        merged = build_group_pairs(agg, _event_a, _event_b).assign(
            _who=who, _domain=_domain, _pair=_pair
        )
        group_tables_to_save.append(merged)
        add_rows_for_mean_and_std(results, _domain, _pair, who, merged)

# Human-vs-LLM tests: diff-of-diff on MEANS for every comparison first,
# then diff-of-diff on STDs (diversity) for every comparison.
for _add_fn in (add_diff_of_diff, add_diff_of_diff_std):
    for _domain, _pair, _event_a, _event_b, _human_col, _llm_col in _COMPARISONS:
        _add_fn(
            results, TOPIC_COL, df,
            domain=_domain,
            col_human=_human_col,
            col_llm=_llm_col,
            eventA=_event_a,
            eventB=_event_b,
            pair_label=_pair,
        )

# --- Save results ---
out_csv = os.path.join(OUTPUT_ROOT, "paired_ttest_results.csv")
results_df = pd.DataFrame(results)
results_df.to_csv(out_csv, index=False)
print(f"✅ Saved group-level (means, stds, and diff-of-diff for both) paired t-test results to:\n{out_csv}")

# Optional: inspect the per-group merged tables used for the tests
if group_tables_to_save:
    debug_out = os.path.join(OUTPUT_ROOT, "paired_ttest_group_stats.csv")
    debug_df = pd.concat(group_tables_to_save, ignore_index=True)
    debug_df.to_csv(debug_out, index=False)
    print(f"🧪 Saved per-group aggregates to:\n{debug_out}")