import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import textwrap
from collections import defaultdict

# Repository root: two directory levels above the folder that holds this script.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
# Both the input and the output trees live under <root>/result/eval.
_EVAL_DIR = os.path.join(PROJECT_ROOT, "result", "eval")
# Per-experiment folders that are scanned for opinion CSV files.
INPUT_ROOT = os.path.join(_EVAL_DIR, "human_llm")
# Destination for the per-topic statistics and figures.
OUTPUT_ROOT = os.path.join(_EVAL_DIR, "group_level")

def stats_with_sd_and_se(values):
    """Summarise the non-zero entries of ``values``.

    Args:
        values: A pandas Series or NumPy array; it must support boolean-mask
            indexing (``values[values != 0]``).

    Returns:
        A ``(mean, sd, se)`` tuple computed over the non-zero entries only.
        ``sd`` is the population standard deviation (``ddof=0``) and ``se`` is
        ``sd / sqrt(n)``. ``se`` is ``0.0`` when a single entry remains, and
        all three are ``0.0`` when none remain.
    """
    kept = values[values != 0]
    count = len(kept)
    if count == 0:
        return 0.0, 0.0, 0.0

    spread = kept.std(ddof=0)
    if count > 1:
        std_err = spread / np.sqrt(count)
    else:
        std_err = 0.0
    return kept.mean(), spread, std_err

def extract_label(filename):
    """Return the leading ``<8 digits>_<6 digits>`` prefix of ``filename``.

    The prefix is presumably a ``YYYYMMDD_HHMMSS`` timestamp. When the name
    does not start with that pattern, the name is returned unchanged.
    """
    found = re.match(r"(\d{8}_\d{6})", filename)
    if found is None:
        return filename
    return found.group(1)

def extract_topic(exp_dir):
    """Extract the topic portion of an experiment directory name.

    The name is split on ``"_"``; the first two fields and the last field are
    dropped and the remaining fields are re-joined with ``"_"``.

    Args:
        exp_dir: Directory name, e.g. ``"20240101_123456_climate_change_v1"``.

    Returns:
        The topic string (``"climate_change"`` for the example above), or
        ``None`` when ``exp_dir`` is not a string or has no fields left
        between the two-field prefix and the trailing field.
    """
    if not isinstance(exp_dir, str):
        return None

    fields = exp_dir.split("_")
    topic = "_".join(fields[2:-1])
    # An empty topic means the name does not follow the expected layout.
    # Report it as None so callers that check ``is None`` skip the entry
    # instead of writing results into a topic-less output directory.
    return topic or None

def likert_to_signed(val):
    """Map a Likert level (1-6) onto a signed half-step scale centred on zero.

    Levels 1..6 become -2.5, -1.5, -0.5, 0.5, 1.5, 2.5. Level 0 and any value
    that is not a known level map to ``0.0``.
    """
    # Each level sits half a step either side of the midpoint: level - 3.5.
    lookup = {level: level - 3.5 for level in range(1, 7)}
    lookup[0] = 0.0
    try:
        return lookup[val]
    except KeyError:
        return 0.0

def signed_likert_to_ordinal_float(val):
    """Shift a signed score in [-3, 3] onto a gap-free 0..5 axis.

    Negative scores are shifted by +3 (-3 -> 0, -2 -> 1, -1 -> 2) and positive
    scores by +2 (1 -> 3, 2 -> 4, 3 -> 5). Exact zero, values outside
    [-3, 3] and anything ``float()`` cannot convert yield NaN.

    NOTE: the two shifts overlap for |val| < 1. Inputs in (-1, 0) and (0, 1)
    both land in (2, 3), so -0.5 and 0.5 both give 2.5. The mapping only
    preserves order for scores with |val| >= 1.
    """
    try:
        score = float(val)
    except Exception:
        return np.nan

    if score == 0 or not (-3 <= score <= 3):
        return np.nan

    offset = 3 if score < 0 else 2
    return score + offset

# def plot_stance_trajectory(df, save_path, topic, ylabel, title_suffix, signed=False):
#     plt.figure(figsize=(10, 6))
#     for _, row in df.iterrows():
#         plt.plot([1, 2, 3], [row['chat1'], row['chat2'], row['chat3']],
#                  marker='o', label=row['label'])
#     plt.xticks([1, 2, 3], ["Tweet 1", "Tweet 2", "Tweet 3"])
#     if signed:
#         plt.ylim(-3, 3)
#         plt.axhline(0, color='gray', linestyle='--', linewidth=1)
#     elif "Std" in title_suffix:
#         plt.ylim(0, 3)
#     plt.ylabel(ylabel)
#     plt.grid(True, linestyle='--', alpha=0.5)
#     title = f"{title_suffix} – {topic.replace('_', ' ')}"
#     plt.title("\n".join(textwrap.wrap(title, width=60)))
#     plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
#     plt.tight_layout()
#     plt.savefig(save_path, bbox_inches="tight")
#     plt.close()

# ================== CURRENT: active plotting functions (the commented-out block above is an inactive variant) ==================

def plot_stance_trajectory(df, save_path, topic, ylabel, title_suffix, signed=False):
    """Draw one line per row of ``df`` across the three tweets and save it.

    Args:
        df: DataFrame with a ``label`` column (legend entry) and numeric
            ``chat1``, ``chat2``, ``chat3`` columns (the y values).
        save_path: Destination passed to ``plt.savefig``.
        topic: Topic name; underscores are shown as spaces in the title.
        ylabel: Y-axis label.
        title_suffix: Text placed before the topic in the title.
        signed: When True the values are treated as signed Likert scores on
            the half-step scale (-2.5 .. 2.5) and are plotted on that native
            axis with a dashed zero line, matching ``plot_summary_errorbar``.
            Exact zeros, values outside [-3, 3] and non-numeric entries are
            masked out and appear as gaps in the line.
    """
    chat_cols = ["chat1", "chat2", "chat3"]
    plt.figure(figsize=(10, 6))

    if signed:
        # Work on a copy so the caller's frame is not modified.
        df = df.copy()
        for col in chat_cols:
            scores = pd.to_numeric(df[col], errors="coerce")
            # Plot on the native signed scale. Shifting negatives by +3 and
            # positives by +2 (the integer-scale ordinal mapping) would place
            # -0.5 and +0.5 at the same height, hiding the sign of near-neutral
            # scores. Zero / out-of-range values are still masked to NaN.
            in_range = (scores >= -3) & (scores <= 3)
            df[col] = scores.where(in_range & (scores != 0))

    for _, row in df.iterrows():
        plt.plot([1, 2, 3], [row[col] for col in chat_cols],
                 marker='o', label=row["label"])

    plt.xticks([1, 2, 3], ["Tweet 1", "Tweet 2", "Tweet 3"])

    if signed:
        # Ticks sit on the six attainable per-response scores.
        plt.yticks([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
        plt.ylim(-3, 3)
        plt.axhline(0, color='gray', linestyle='--', linewidth=1)

    plt.ylabel(ylabel)
    plt.grid(True, linestyle='--', alpha=0.5)
    title = f"{title_suffix} – {topic.replace('_', ' ')}"
    plt.title("\n".join(textwrap.wrap(title, width=60)))
    # Legend is placed outside the axes so it never covers the lines.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_summary_errorbar(y, yerr, save_path, topic, ylabel, title_prefix, signed=False):
    """Plot three summary points with error bars and save the figure.

    Args:
        y: Three y values, one per tweet.
        yerr: Error-bar half-lengths matching ``y``.
        save_path: Destination passed to ``plt.savefig``.
        topic: Topic name; underscores are shown as spaces in the title.
        ylabel: Y-axis label.
        title_prefix: Text placed before the topic in the title.
        signed: When True the axis spans -3..3 with half-step ticks and a
            dashed zero line; otherwise the axis spans 0..2.
    """
    positions = np.array([1, 2, 3])
    _, axes = plt.subplots(figsize=(6, 5))

    # The data is drawn the same way in both modes; only the axis differs.
    axes.errorbar(positions, y, yerr=yerr, fmt='o-', capsize=8, linewidth=2)
    if not signed:
        axes.set_ylim(0, 2)
    else:
        # Signed scores stay on their native half-step scale.
        axes.set_yticks([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
        axes.set_ylim(-3, 3)
        axes.axhline(0, color='gray', linestyle='--', linewidth=1)

    axes.set_xticks(positions)
    axes.set_xticklabels(["Tweet 1", "Tweet 2", "Tweet 3"])
    axes.set_ylabel(ylabel)
    heading = f"{title_prefix}: {topic.replace('_', ' ')}"
    axes.set_title("\n".join(textwrap.wrap(heading, width=60)))
    axes.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

# ================== END CURRENT: the commented-out plot_summary_errorbar below is an inactive variant kept for reference ==================

# def plot_summary_errorbar(y, yerr, save_path, topic, ylabel, title_prefix, signed=False):
#     x = np.array([1, 2, 3])
#     fig, ax = plt.subplots(figsize=(6, 5))

#     if signed:
#         level_map = {-3: 0, -2: 1, -1: 2, 1: 3, 2: 4, 3: 5}
#         ordinal_to_likert = {v: k for k, v in level_map.items()}
#         y_mapped = [level_map.get(round(v), np.nan) for v in y]
#         ax.errorbar(x, y_mapped, yerr=yerr, fmt='o-', capsize=8, linewidth=2)
#         yticks = sorted(ordinal_to_likert.keys())
#         yticklabels = [ordinal_to_likert[t] for t in yticks]
#         ax.set_yticks(yticks)
#         ax.set_yticklabels(yticklabels)
#         ax.axhline(2.5, color='gray', linestyle='--', linewidth=1)
#     else:
#         ax.errorbar(x, y, yerr=yerr, fmt='o-', capsize=8, linewidth=2)
#         ax.set_ylim(0, 3)

#     ax.set_xticks(x)
#     ax.set_xticklabels(["Tweet 1", "Tweet 2", "Tweet 3"])
#     ax.set_ylabel(ylabel)
#     ax.set_title("\n".join(textwrap.wrap(f"{title_prefix}: {topic.replace('_', ' ')}", width=60)))
#     ax.grid(True, linestyle='--', alpha=0.5)
#     plt.tight_layout()
#     plt.savefig(save_path, bbox_inches="tight")
#     plt.close()

def _collect_stance_row(df, key, exp_dir):
    """Build one record with the mean and SD of signed stance per chat order."""
    record = {"filename": exp_dir, "label": extract_label(exp_dir)}
    spreads = {}
    for order in (1, 2, 3):
        raw = df[df["chat_order"] == order][key].fillna(0).replace("", 0)
        levels = pd.to_numeric(raw, errors="coerce").fillna(0).astype(int)
        # stats_with_sd_and_se drops zeros itself, so no extra filter is needed.
        avg, std, _ = stats_with_sd_and_se(levels.apply(likert_to_signed))
        record[f"chat{order}"] = avg
        spreads[f"std{order}"] = std
    # Append the std columns last so the column order is chat1..3, std1..3.
    record.update(spreads)
    return record


def _write_group_outputs(df, out_dir, topic, trajectory_tag, summary_tag):
    """Write the stats CSV, the four figures and the summary CSV for ``df``."""
    mean_cols = ["chat1", "chat2", "chat3"]
    std_cols = ["std1", "std2", "std3"]
    root_n = np.sqrt(len(df))

    df.to_csv(os.path.join(out_dir, "stance_stats.csv"), index=False)

    plot_stance_trajectory(
        df, os.path.join(out_dir, "stance_bias_trajectory.svg"),
        topic, "Signed Likert Score",
        f"Tweet Stance Bias Trajectory {trajectory_tag}", signed=True)

    # The trajectory plot reads chat1..3, so present the std columns under those names.
    std_view = df[["label"] + std_cols].rename(columns=dict(zip(std_cols, mean_cols)))
    plot_stance_trajectory(
        std_view, os.path.join(out_dir, "stance_std_trajectory.svg"),
        topic, "Standard Deviation",
        f"Tweet Stance Std Trajectory {trajectory_tag}")

    # Error bars: sample SD across rows (pandas default) divided by sqrt(row count).
    plot_summary_errorbar(
        [df[col].mean() for col in mean_cols],
        [df[col].std() / root_n for col in mean_cols],
        os.path.join(out_dir, "summary_stance_bias_errorbar.svg"),
        topic, "Signed Likert Score", f"Bias Summary – {summary_tag}", signed=True
    )
    plot_summary_errorbar(
        [df[col].mean() for col in std_cols],
        [df[col].std() / root_n for col in std_cols],
        os.path.join(out_dir, "summary_stance_std_errorbar.svg"),
        topic, "Standard Deviation", f"Std Summary – {summary_tag}"
    )

    summary = {f"overall_{col}": df[col].mean() for col in mean_cols + std_cols}
    pd.DataFrame([summary]).to_csv(os.path.join(out_dir, "summary_stance.csv"), index=False)


def process_stance_data(version, source):
    """Aggregate per-experiment stance scores by topic and write CSVs and plots.

    Scans every experiment directory under ``INPUT_ROOT`` and, inside it, every
    model sub-directory for ``opinion_memory_gpt-4o-mini-2024-07-18_<version>.csv``.
    For each file found, the Likert column selected by ``source`` is converted
    to signed scores and summarised (mean and SD of the non-zero scores) for
    chat orders 1, 2 and 3. Results are grouped by topic and model and written
    under ``OUTPUT_ROOT/<topic>/<human|simulation>/tweet``, once per model and
    once for all models combined.

    Directory listings are processed in sorted order so that CSV row order and
    the legend/colour assignment in the figures are reproducible across runs
    and file systems.

    Args:
        version: Version tag embedded in the input file name (e.g. ``"v0"``).
        source: ``"human"`` reads ``human_likert_pred``; any other value reads
            ``llm_likert_pred`` and writes to the ``simulation`` output folder.
    """
    key = "human_likert_pred" if source == "human" else "llm_likert_pred"
    output_mode = "human" if source == "human" else "simulation"
    file_name = f"opinion_memory_gpt-4o-mini-2024-07-18_{version}.csv"
    grouped_by_topic = defaultdict(lambda: defaultdict(list))

    # os.listdir returns entries in arbitrary order; sort for determinism.
    for exp_dir in sorted(os.listdir(INPUT_ROOT)):
        exp_path = os.path.join(INPUT_ROOT, exp_dir)
        if not os.path.isdir(exp_path):
            continue

        topic = extract_topic(exp_dir)
        if topic is None:
            print(f"[WARNING] Cannot extract topic from {exp_dir}, skipping...")
            continue

        for model_name in sorted(os.listdir(exp_path)):
            file_path = os.path.join(exp_path, model_name, file_name)
            if not os.path.exists(file_path):
                continue

            df = pd.read_csv(file_path)
            if df.empty:
                continue

            grouped_by_topic[topic][model_name].append(
                _collect_stance_row(df, key, exp_dir))

    for topic, model_dict in grouped_by_topic.items():
        topic_dir = os.path.join(OUTPUT_ROOT, topic, output_mode, "tweet")
        os.makedirs(topic_dir, exist_ok=True)
        all_dfs = []

        for model_name, data in model_dict.items():
            df = pd.DataFrame(data)
            all_dfs.append(df)
            model_dir = os.path.join(topic_dir, model_name)
            os.makedirs(model_dir, exist_ok=True)
            _write_group_outputs(df, model_dir, topic, f"– {model_name}", model_name)

        all_df = pd.concat(all_dfs, ignore_index=True)
        if not all_df.empty:
            _write_group_outputs(all_df, topic_dir, topic, "(All Models)", "All Models")

if __name__ == "__main__":
    import argparse

    # Command-line entry point: both options are mandatory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--version", required=True, help="Version like v0, v1...")
    cli.add_argument("--source", required=True, choices=["human", "simulation"])
    options = cli.parse_args()

    process_stance_data(options.version, options.source)
