import os
import pandas as pd
import numpy as np
import logging
import glob

# Setup logging: configure the root logger once at import time (INFO and above,
# "LEVEL: message" format).
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# Project root: two directory levels above the directory containing this file.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# Input CSV that process_stats() reads.
PREPROCESSED_PATH = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "preprocessed_depth_valid.csv")
# Directory under which process_stats() writes all of its CSV outputs.
OUTPUT_ROOT = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "base_analysis")

# (event_type, chat_order) pairs identifying the five events that are analysed.
# process_stats() discards every row whose pair is not listed here.
EVENT_ORDER = [
    ("Initial Opinion", 0),
    ("tweet", 1),
    ("tweet", 2),
    ("tweet", 3),
    ("Post Opinion", 4),
]
# Short labels positionally paired with EVENT_ORDER (extract_stats() zips the
# two lists); they become the prefixes of the statistic keys, e.g. "Tweet1_Bias".
EVENT_LABELS = ["Initial", "Tweet1", "Tweet2", "Tweet3", "Post"]
# Output columns that process_stats() copies from the aggregated human row into
# the aggregated model row.
TWEET1_KEYS = [
    "Tweet 1 Bias", "Tweet 1 Diversity", "Tweet 1 Bias SEM", "Tweet 1 Diversity SEM"
]

def normalize_column(series):
    """Coerce *series* to numbers and centre every non-zero value on 3.5.

    Values that cannot be parsed as numbers become 0, and zeros are passed
    through unchanged (they are not shifted to -3.5).

    Args:
        series: A pandas Series of raw values (numbers, numeric strings, ...).

    Returns:
        A Series with the same index holding ``value - 3.5`` for each non-zero
        entry and 0 for everything else.
    """

    def _centre(value):
        # Zero doubles as the "no value" marker, so it must stay zero.
        if value == 0:
            return 0
        return value - 3.5

    numeric = pd.to_numeric(series, errors="coerce")
    filled = numeric.fillna(0)
    return filled.apply(_centre)

def extract_stats(df, value_col, normalize=True, event_order=None, event_labels=None):
    """Compute per-event summary statistics of ``value_col``.

    For every (label, (event_type, chat_order)) pair the rows of *df* matching
    that event are selected, ``value_col`` is coerced to numbers, and missing
    responses are discarded. A response counts as missing when it is NaN /
    unparseable or exactly 0 in the raw data.

    Missing values are filtered *before* centring. Filtering afterwards would
    also throw away genuine responses of exactly 3.5, which centre to 0.

    Args:
        df: DataFrame with ``event_type``, ``chat_order`` and ``value_col``.
        value_col: Name of the column holding the scores.
        normalize: When true, subtract 3.5 from every retained score.
        event_order: Optional list of ``(event_type, chat_order)`` pairs;
            defaults to the module-level ``EVENT_ORDER``.
        event_labels: Optional list of labels paired positionally with
            ``event_order``; defaults to the module-level ``EVENT_LABELS``.

    Returns:
        dict with, for each label, the keys ``<label>_Bias`` (mean),
        ``<label>_Std`` (population standard deviation, ddof=0),
        ``<label>_N`` (number of retained scores), ``<label>_SEM``
        (std / sqrt(N)) and ``<label>_SE_Std`` (std / sqrt(2 * (N - 1))).
        Mean and std are NaN when no score is retained; both standard errors
        are 0.0 when N <= 1.
    """
    if event_order is None:
        event_order = EVENT_ORDER
    if event_labels is None:
        event_labels = EVENT_LABELS

    # Coerce once for the whole frame instead of once per event.
    numeric = pd.to_numeric(df[value_col], errors="coerce").astype(float)
    answered = numeric.notna() & (numeric != 0)

    results = {}
    for label, (etype, order) in zip(event_labels, event_order):
        at_event = (df["event_type"] == etype) & (df["chat_order"] == order)
        values = numeric[at_event & answered]
        if normalize:
            values = values - 3.5

        n = len(values)
        if n:
            mean = values.mean()
            std = values.std(ddof=0)
        else:
            mean = std = np.nan
        # With fewer than two observations the standard errors are reported as 0.
        sem = std / np.sqrt(n) if n > 1 else 0.0
        se_std = std / np.sqrt(2 * (n - 1)) if n > 1 else 0.0

        results[f"{label}_Bias"] = mean
        results[f"{label}_Std"] = std
        results[f"{label}_N"] = n
        results[f"{label}_SEM"] = sem
        results[f"{label}_SE_Std"] = se_std
    return results

def deduplicate_by_llm_text_length(df):
    """Keep one row per (time_stamp, event_type, chat_order, human_id).

    Within each key the row with the longest ``llm_text`` string is kept; on a
    tie the first such row wins. Entries of ``llm_text`` that are not strings
    (e.g. NaN) rank below every string, so a group without any text still
    yields its first row instead of failing.

    The selection is done with a per-group ``idxmax`` on positions rather than
    ``groupby.apply``, so the grouping columns are always retained in the
    result and duplicate index labels cannot select extra rows.

    Args:
        df: DataFrame containing the four key columns and ``llm_text``.

    Returns:
        A new DataFrame with the surviving rows (all original columns and
        index labels), ordered by the sorted group keys. Rows with a missing
        value in any key column are dropped, as pandas ``groupby`` does by
        default.
    """
    keys = ["time_stamp", "event_type", "chat_order", "human_id"]
    if df.empty:
        return df.copy()

    # Work on positions so the result does not depend on the index labels.
    sized = df[keys].reset_index(drop=True)
    sized["_text_len"] = [
        len(text) if isinstance(text, str) else -1 for text in df["llm_text"]
    ]
    positions = sized.groupby(keys)["_text_len"].idxmax().to_numpy()
    return df.iloc[positions]

def create_stat_row(topic, who, timestamp, stats):
    """Turn one statistics dict into a flat output row.

    Args:
        topic: Value stored under ``"Topic"``.
        who: Value stored under ``"Human/LLM"``.
        timestamp: Value stored under ``"time_stamp"``.
        stats: Mapping with ``<label>_Bias``, ``<label>_Std``, ``<label>_SEM``
            and ``<label>_SE_Std`` for the labels Initial, Post, Tweet1,
            Tweet2 and Tweet3.

    Returns:
        dict whose insertion order defines the column order of the output:
        identifiers, Initial block, Post block, Post-minus-Initial
        differences, Tweet 1-3 blocks, Tweet3-minus-Tweet1 differences.
    """
    # (column suffix, stats-key suffix) for the four figures reported per event.
    figures = (
        ("Bias", "Bias"),
        ("Diversity", "Std"),
        ("Bias SEM", "SEM"),
        ("Diversity SEM", "SE_Std"),
    )

    def _event_columns(column_prefix, stat_label):
        return {
            f"{column_prefix} {column}": stats[f"{stat_label}_{key}"]
            for column, key in figures
        }

    row = {"Topic": topic, "Human/LLM": who, "time_stamp": timestamp}
    row.update(_event_columns("Initial Opinion - Slider", "Initial"))
    row.update(_event_columns("Post Opinion - Slider", "Post"))

    row["Post Opinion (Slider) Bias - Initial Opinion (Slider) Bias"] = (
        stats["Post_Bias"] - stats["Initial_Bias"]
    )
    row["Post Opinion (Slider) Diversity - Initial Opinion (Slider) Diversity"] = (
        stats["Post_Std"] - stats["Initial_Std"]
    )

    for number in (1, 2, 3):
        row.update(_event_columns(f"Tweet {number}", f"Tweet{number}"))

    row["Tweet 3 Bias - Tweet 1 Bias"] = stats["Tweet3_Bias"] - stats["Tweet1_Bias"]
    row["Tweet 3 Diversity - Tweet 1 Diversity"] = stats["Tweet3_Std"] - stats["Tweet1_Std"]
    return row

def create_grouped_stat_rows(topic, who, grouped_df, col, normalize):
    """Build one statistics row per group plus one averaged ``"all"`` row.

    Args:
        topic: Topic name written into every row.
        who: Rater label written into the ``"Human/LLM"`` column.
        grouped_df: Iterable of ``(timestamp, DataFrame)`` pairs, e.g. the
            result of ``df.groupby("time_stamp")``.
        col: Column whose values are summarised by ``extract_stats``.
        normalize: Forwarded to ``extract_stats``.

    Returns:
        list of row dicts. Per-group rows carry ``num_groups == -1``. The
        trailing row has ``time_stamp == "all"``, holds the unweighted mean of
        every numeric column across groups, and carries the number of groups
        in ``num_groups``. An empty list is returned when there are no groups.
    """
    group_rows = []
    for timestamp, df_group in grouped_df:
        stats = extract_stats(df_group, col, normalize)
        row = create_stat_row(topic, who, timestamp, stats)
        row["num_groups"] = -1  # placeholder; only the "all" row has a real count
        group_rows.append(row)

    # Without groups there is nothing to average, and the empty frame built
    # below would not even have a "time_stamp" column to drop.
    if not group_rows:
        return []

    per_group = pd.DataFrame(group_rows)
    avg_data = (
        per_group.drop(columns=["time_stamp"])
        .groupby(["Topic", "Human/LLM"])
        .mean()
        .reset_index()
    )
    avg_data["time_stamp"] = "all"
    avg_data["num_groups"] = len(group_rows)
    return group_rows + avg_data.to_dict(orient="records")

def process_stats(label_name, value_col_human, value_col_llm, normalize=True,
                  model_name="gpt-4o-mini-2024-07-18"):
    """Compute and save human-vs-model statistics for one measure.

    Reads ``PREPROCESSED_PATH``, keeps the rows of ``model_name`` whose
    (event_type, chat_order) pair is in ``EVENT_ORDER``, de-duplicates them,
    and writes under ``OUTPUT_ROOT``:

    * ``filtered_depth.csv`` - the filtered rows (all topics);
    * ``<topic>/filtered_depth.csv`` - the topic's rows with a non-missing
      ``value_col_llm``;
    * ``<topic>/base_<label_name>_stats.csv`` - per-group and averaged rows
      for the human and the model;
    * ``all_base_<label_name>_stats.csv`` - the averaged rows of every topic
      plus one ``Topic == "All"`` row per rater (unweighted mean over topics,
      ``num_groups`` summed).

    Args:
        label_name: Tag used in the output file names (e.g. ``"stance"``).
        value_col_human: Column holding the human scores.
        value_col_llm: Column holding the model scores.
        normalize: Forwarded to the statistics computation.
        model_name: Value of the ``model_name`` column to analyse; also used
            as the model's label in the ``"Human/LLM"`` column.
    """
    df = pd.read_csv(PREPROCESSED_PATH)
    df = df[df["model_name"] == model_name]

    # Keep only the analysed events. A plain membership test per row avoids
    # the slow row-wise DataFrame.apply and also behaves on an empty frame.
    allowed_events = set(EVENT_ORDER)
    in_scope = np.array(
        [pair in allowed_events for pair in zip(df["event_type"], df["chat_order"])],
        dtype=bool,
    )
    df = df.loc[in_scope].copy()
    df = deduplicate_by_llm_text_length(df)
    os.makedirs(OUTPUT_ROOT, exist_ok=True)
    df.to_csv(os.path.join(OUTPUT_ROOT, "filtered_depth.csv"), index=False)

    all_rows = []

    # Rows without a topic cannot be assigned to an output directory.
    for topic in df["topic"].dropna().unique():
        df_topic = df[df["topic"] == topic]
        df_topic = df_topic[df_topic[value_col_llm].notna()]
        out_dir = os.path.join(OUTPUT_ROOT, str(topic))
        os.makedirs(out_dir, exist_ok=True)
        df_topic.to_csv(os.path.join(out_dir, "filtered_depth.csv"), index=False)

        if df_topic.empty:
            logging.warning("No rows with %s for topic %s; skipping stats.", value_col_llm, topic)
            continue

        grouped = df_topic.groupby("time_stamp")
        human_rows = create_grouped_stat_rows(topic, "human", grouped, value_col_human, normalize)
        model_rows = create_grouped_stat_rows(topic, model_name, grouped, value_col_llm, normalize)

        # Overwrite the model's aggregated Tweet 1 columns (bias, diversity and
        # both standard errors) with the human values.
        # NOTE(review): the model row's "Tweet 3 ... - Tweet 1 ..." difference
        # columns are not recomputed after this copy - confirm that is intended.
        human_all = next((r for r in human_rows if r["time_stamp"] == "all"), None)
        model_all = next((r for r in model_rows if r["time_stamp"] == "all"), None)
        if human_all is not None and model_all is not None:
            for key in TWEET1_KEYS:
                model_all[key] = human_all[key]

        topic_rows = human_rows + model_rows
        stats_df = pd.DataFrame(topic_rows)
        stats_df.to_csv(os.path.join(out_dir, f"base_{label_name}_stats.csv"), index=False)

        all_rows.extend(r for r in topic_rows if r["time_stamp"] == "all")

    all_df = pd.DataFrame(all_rows)
    all_path = os.path.join(OUTPUT_ROOT, f"all_base_{label_name}_stats.csv")
    all_df.to_csv(all_path, index=False)
    logging.info("✅ all_base_%s_stats.csv saved.", label_name)

    # Add Topic=All row by averaging
    if not all_df.empty:
        num_groups_sum = all_df.groupby("Human/LLM")["num_groups"].sum().reset_index()
        df_agg = all_df.drop(columns=["time_stamp", "Topic", "num_groups"]).groupby("Human/LLM").mean().reset_index()
        df_agg = pd.merge(df_agg, num_groups_sum, on="Human/LLM", how="left")
        df_agg.insert(0, "Topic", "All")
        df_agg.insert(2, "time_stamp", "all")
        final_df = pd.concat([all_df, df_agg], ignore_index=True)
        final_df.to_csv(all_path, index=False)
        logging.info("📊 Aggregated Topic='All' row added to all_base_%s_stats.csv.", label_name)

def base_model_analysis():
    """Run the base analysis for the stance measure, then the slider measure."""
    # (output label, human score column, model score column)
    measures = (
        ("stance", "human_likert_pred", "llm_likert_pred"),
        ("slider", "human_slider", "llm_slider"),
    )
    for label, human_col, llm_col in measures:
        process_stats(label, human_col, llm_col, normalize=True)

# Run the full analysis only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    base_model_analysis()
