import os
import re
import glob
import argparse
import textwrap
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from util import extract_topic_versions

# Matches one of the six agreement labels at the very start of a string,
# ignoring letter case. Group 1 is the label exactly as it was written.
AGREEMENT_PATTERN = re.compile(
    "^(" + "|".join((
        "Certainly disagree",
        "Probably disagree",
        "Lean disagree",
        "Lean agree",
        "Probably agree",
        "Certainly agree",
    )) + ")",
    flags=re.IGNORECASE,
)

def extract_agreement_label(text):
    """Turn a free-text agreement statement into a 1-6 Likert score.

    The text (after stripping surrounding whitespace) must start with one
    of the labels recognised by ``AGREEMENT_PATTERN``; anything following
    the label is ignored.

    Args:
        text: The raw answer. Non-string values (e.g. NaN from an empty
            CSV cell) are accepted and scored as 0.

    Returns:
        The score produced by ``map_agreement_to_likert`` for the leading
        label, or 0 when ``text`` is not a string or has no leading label.
    """
    if not isinstance(text, str):
        return 0

    match = AGREEMENT_PATTERN.match(text.strip())
    if match is None:
        return 0

    # AGREEMENT_PATTERN is case-insensitive, but the lookup table in
    # map_agreement_to_likert is keyed by the exact "Xxxx yyyy" spelling.
    # Without normalising, an answer such as "certainly AGREE" would match
    # the pattern and then be silently scored as 0 (treated as missing).
    label = match.group(1).capitalize()
    return map_agreement_to_likert(label)

# --- Project Paths ---
# Absolute path two directory levels above the folder containing this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "..")
)

# Directory scanned for human-study CSV files.
HUMAN_INPUT_DIR = os.path.join(PROJECT_ROOT, "data", "processed_data")

# Directory scanned for simulation experiment folders.
SIMULATION_INPUT_DIR = os.path.join(PROJECT_ROOT, "result", "simulation")

# Root under which all CSV and SVG outputs of this script are written.
OUTPUT_ROOT = os.path.join(PROJECT_ROOT, "result", "eval", "group_level")

# --- Utility Functions ---
def stats_with_sd_and_se(values):
    """Summarise the non-zero entries of ``values``.

    Entries equal to 0 are excluded before anything is computed (callers
    in this file use 0 as the fill value for missing answers).

    Args:
        values: A container supporting boolean-mask indexing plus
            ``.mean()`` and ``.std(ddof=...)``; callers in this file pass
            a pandas Series.

    Returns:
        A ``(mean, std, se)`` tuple where ``std`` is the population
        standard deviation (``ddof=0``) and ``se`` is ``std / sqrt(n)``.
        All three are 0.0 when no non-zero entry exists, and ``se`` is 0.0
        when exactly one remains.
    """
    kept = values[values != 0]
    count = len(kept)
    if count == 0:
        return 0.0, 0.0, 0.0

    mean = kept.mean()
    spread = kept.std(ddof=0)
    if count > 1:
        std_err = spread / np.sqrt(count)
    else:
        std_err = 0.0
    return mean, spread, std_err

def map_agreement_to_likert(val):
    """Return the 1-6 Likert score for an agreement label.

    The lookup is exact (case-sensitive). Labels run from
    "Certainly disagree" (1) to "Certainly agree" (6); any other value
    yields 0.
    """
    ordered_labels = (
        "Certainly disagree",
        "Probably disagree",
        "Lean disagree",
        "Lean agree",
        "Probably agree",
        "Certainly agree",
    )
    scores = {label: rank for rank, label in enumerate(ordered_labels, start=1)}
    if val in scores:
        return scores[val]
    return 0

def likert_to_signed(val):
    """Centre a 1-6 Likert score on zero.

    Scores 1..6 become -2.5, -1.5, -0.5, 0.5, 1.5, 2.5 respectively.
    A score of 0, and any value that is not one of 1..6, becomes 0.0.
    """
    centred = {level: level - 3.5 for level in range(1, 7)}
    centred[0] = 0.0
    try:
        return centred[val]
    except KeyError:
        return 0.0

def get_label_from_exp_dir(exp_dir):
    """Return the leading ``<8 digits>_<6 digits>`` prefix of ``exp_dir``.

    When the name does not start with such a prefix, ``exp_dir`` itself is
    returned unchanged.
    """
    found = re.match(r"(\d{8}_\d{6})", exp_dir)
    if found is None:
        return exp_dir
    return found.group(1)

# --- Data Processing ---
def scan_and_group_by_topic(input_dir, mode):
    """Group input files or experiment folders by topic.

    Args:
        input_dir: Directory to scan.
        mode: ``"human"`` collects every ``*.csv`` directly inside
            ``input_dir``; ``"simulation"`` collects every sub-directory
            of ``input_dir``. Any other value yields an empty mapping.

    Returns:
        A ``defaultdict(list)`` mapping the topic (first item returned by
        ``extract_topic_versions`` for the entry's base name) to the list
        of matching paths, each joined onto ``input_dir``. Entries for
        which ``extract_topic_versions`` raises ``ValueError`` are reported
        once and skipped.
    """
    grouped_files = defaultdict(list)

    if mode == "human":
        # Sorted so the result does not depend on filesystem listing order.
        candidates = sorted(glob.glob(os.path.join(input_dir, "*.csv")))
        skipped_kind = "file"
    elif mode == "simulation":
        # Scan the directory the caller asked for rather than the
        # SIMULATION_INPUT_DIR global, so the argument is actually honoured.
        candidates = [
            path
            for path in (
                os.path.join(input_dir, name) for name in sorted(os.listdir(input_dir))
            )
            if os.path.isdir(path)
        ]
        skipped_kind = "exp dir"
    else:
        return grouped_files

    for path in candidates:
        try:
            underlined_topic, _ = extract_topic_versions(os.path.basename(path))
        except ValueError as e:
            print(f"[WARNING] Skipping {skipped_kind}: {e}")
            continue
        grouped_files[underlined_topic].append(path)

    return grouped_files

def compute_human_opinion_stats(filepath):
    """Compute initial/post opinion statistics for one human-study CSV.

    The CSV must contain ``event_type`` and ``sliderValue`` columns. For
    the rows of each phase ("Initial Opinion", "Post Opinion") the slider
    values are coerced to integers (missing, empty or non-numeric cells
    become 0), converted with ``likert_to_signed`` and summarised with
    ``stats_with_sd_and_se``.

    Returns:
        A dict with the file's base name, a label (the leading
        ``<8 digits>_<6 digits>`` prefix of the base name, or the base name
        itself when there is none), the mean and std of both phases, and
        the post-minus-initial differences of mean and std.
    """
    responses = pd.read_csv(filepath)

    phase_stats = {}
    for phase, event_type in (("initial", "Initial Opinion"), ("post", "Post Opinion")):
        raw = responses.loc[responses["event_type"] == event_type, "sliderValue"]
        cleaned = raw.fillna(0).replace("", 0)
        likert = pd.to_numeric(cleaned, errors='coerce').fillna(0).astype(int)
        mean, spread, _ = stats_with_sd_and_se(likert.apply(likert_to_signed))
        phase_stats[phase] = (mean, spread)

    avg_init, std_init = phase_stats["initial"]
    avg_post, std_post = phase_stats["post"]

    base = os.path.basename(filepath)
    stamp = re.match(r"^(\d{8}_\d{6})", base)
    if stamp:
        label = stamp.group(1)
    else:
        label = base

    return {
        "filename": base,
        "label": label,
        "avg_initial": avg_init,
        "std_initial": std_init,
        "avg_post": avg_post,
        "std_post": std_post,
        "diff_avg": avg_post - avg_init,
        "diff_std": std_post - std_init,
    }

def compute_simulation_opinion_stats(exp_path, model_dir, version):
    """Compute initial/post opinion statistics for one simulated run.

    Reads ``<exp_path>/<model_dir>/simulation-<version>.csv``. Initial
    opinions come from the numeric ``sliderValue`` column; post opinions
    come from the free-text ``agreement_level`` column, scored with
    ``extract_agreement_label``. Both are converted with
    ``likert_to_signed`` and summarised with ``stats_with_sd_and_se``
    (which ignores zeros, i.e. missing or unrecognised answers).

    Args:
        exp_path: Path of the experiment folder.
        model_dir: Name of the model sub-folder inside ``exp_path``.
        version: Version tag used in the CSV file name.

    Returns:
        A dict with the same keys as ``compute_human_opinion_stats``, or
        ``None`` (after printing a warning) when the CSV is missing or
        holds no rows. ``label`` is the part of the experiment folder name
        before the first underscore.
    """
    filepath = os.path.join(exp_path, model_dir, f"simulation-{version}.csv")
    if not os.path.exists(filepath):
        print(f"[WARNING] File does not exist for model {model_dir} at {filepath}. Skipping.")
        return None

    try:
        df = pd.read_csv(filepath)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header at all; treat it like any other
        # empty file instead of aborting the whole run.
        df = pd.DataFrame()

    if df.empty:
        print(f"[WARNING] Empty file: {filepath}")
        return None

    # Initial opinion: numeric slider value. Coerce with to_numeric first
    # (as the human pipeline does) so stray non-numeric cells become 0
    # rather than making astype(int) raise.
    initial_raw = df.loc[df["event_type"] == "Initial Opinion", "sliderValue"]
    initial_values = (
        pd.to_numeric(initial_raw.fillna(0).replace("", 0), errors="coerce")
        .fillna(0)
        .astype(int)
        .apply(likert_to_signed)
    )

    # Post opinion: free-text agreement statement -> 1-6 score -> signed.
    post_raw = df.loc[df["event_type"] == "Post Opinion", "agreement_level"]
    post_values = (
        post_raw.apply(extract_agreement_label)
        .apply(likert_to_signed)
        .astype(float)
    )

    # stats_with_sd_and_se drops zeros itself, so no pre-filtering is needed.
    avg_init, std_init, _ = stats_with_sd_and_se(initial_values)
    avg_post, std_post, _ = stats_with_sd_and_se(post_values)

    exp_folder = os.path.basename(exp_path)
    label = exp_folder.split("_")[0]  # leading date part, e.g. 20241015

    return {
        "filename": exp_folder,
        "label": label,
        "avg_initial": avg_init,
        "std_initial": std_init,
        "avg_post": avg_post,
        "std_post": std_post,
        "diff_avg": avg_post - avg_init,
        "diff_std": std_post - std_init,
    }

# --- Plotting ---
# def plot_trajectory(df, save_path, topic, y_col1, y_col2, ylabel, title_suffix):
#     plt.figure(figsize=(12, 6))
#     for _, row in df.iterrows():
#         plt.plot([0, 1], [row[y_col1], row[y_col2]], marker='o', label=row["label"])
#     plt.xticks([0, 1], ["Initial", "Post"])
#     if "Bias" in title_suffix:
#         plt.ylim(-3, 3)
#         plt.axhline(0, color='gray', linestyle='--', linewidth=1)
#     elif "Standard Deviation" in title_suffix:
#         plt.ylim(0, 3)
#     plt.ylabel(ylabel)
#     plt.grid(True, linestyle='--', alpha=0.5)
#     wrapped_title = "\n".join(textwrap.wrap(f"{title_suffix} – {topic.replace('_', ' ')}", width=70))
#     plt.title(wrapped_title)
#     plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
#     plt.tight_layout()
#     plt.savefig(save_path, bbox_inches="tight")
#     plt.close()

def signed_likert_to_ordinal_float(val):
    """Shift a signed score onto a gap-free ordinal axis.

    ``val`` is converted with ``float``. Negative values in [-3, 0) are
    shifted by +3 (-3 -> 0, -2 -> 1, -1 -> 2) and positive values in
    (0, 3] by +2 (1 -> 3, 2 -> 4, 3 -> 5). Zero, values outside [-3, 3]
    and anything ``float`` cannot convert yield NaN.

    Note that fractional inputs on either side of zero overlap after the
    shift: both -0.5 and 0.5 map to 2.5.
    """
    try:
        number = float(val)
    except Exception:
        return np.nan

    if number == 0 or not (-3 <= number <= 3):
        return np.nan
    return number + (3 if number < 0 else 2)

def plot_trajectory(df, save_path, topic, y_col1, y_col2, ylabel, title_suffix):
    """Draw one Initial -> Post line per row of ``df`` and save the figure.

    Args:
        df: DataFrame with a ``label`` column (used for the legend) and
            the two value columns named by ``y_col1`` / ``y_col2``.
        save_path: Destination file for the figure.
        topic: Topic identifier; underscores are shown as spaces in the
            title.
        y_col1: Column plotted at the "Initial" position.
        y_col2: Column plotted at the "Post" position.
        ylabel: Y-axis label.
        title_suffix: Leading part of the title. It also selects the axis
            setup: if it contains "Bias" the signed Likert axis is used,
            otherwise if it contains "Standard Deviation" the y-range is
            fixed to [0, 3].
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    for _, row in df.iterrows():
        ax.plot([0, 1], [row[y_col1], row[y_col2]], marker='o', label=row["label"])

    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Initial", "Post"])

    if "Bias" in title_suffix:
        # The averages are on the half-step signed scale produced by
        # likert_to_signed (-2.5 .. 2.5), which has no gap at zero. Plot
        # them as they are, on the same axis plot_summary_errorbar uses.
        # The previous integer-scale ordinal remap (negative + 3,
        # positive + 2) folded (-1, 0) and (0, 1) onto the same band and
        # dropped averages that were exactly 0.
        ax.set_yticks([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
        ax.set_ylim(-3, 3)
        ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    elif "Standard Deviation" in title_suffix:
        ax.set_ylim(0, 3)

    ax.set_ylabel(ylabel)
    ax.grid(True, linestyle='--', alpha=0.5)
    wrapped_title = "\n".join(textwrap.wrap(f"{title_suffix} – {topic.replace('_', ' ')}", width=70))
    ax.set_title(wrapped_title)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

def plot_summary_errorbar(y, yerr, save_path, topic, ylabel, title_prefix):
    """Plot an Initial -> Post pair of values with error bars and save it.

    Args:
        y: Two values, for the "Initial" and "Post" positions.
        yerr: Error-bar sizes matching ``y``.
        save_path: Destination file for the figure.
        topic: Topic identifier; underscores are shown as spaces in the
            title.
        ylabel: Y-axis label.
        title_prefix: Leading part of the title. It also selects the axis
            setup: if it contains "Bias" the signed Likert axis (ticks at
            the six half-step levels, range [-3, 3], dashed line at 0) is
            used; otherwise if it contains "Std" the y-range is fixed to
            [0, 2].
    """
    x = np.array([0, 1])
    fig, ax = plt.subplots(figsize=(6, 5))

    ax.errorbar(x, y, yerr=yerr, fmt='o-', capsize=8, linewidth=2)

    if "Bias" in title_prefix:
        ax.set_yticks([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
        ax.set_ylim(-3, 3)
        ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    elif "Std" in title_prefix:
        ax.set_ylim(0, 2)

    ax.set_xticks(x)
    ax.set_xticklabels(["Initial", "Post"])
    ax.set_ylabel(ylabel)

    wrapped_title = "\n".join(textwrap.wrap(f"{title_prefix}: {topic.replace('_', ' ')}", width=70))
    ax.set_title(wrapped_title)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

# def plot_summary_errorbar(y, yerr, save_path, topic, ylabel, title_prefix):
#     x = np.array([0, 1])
#     x_labels = ["Initial", "Post"]
#     fig, ax = plt.subplots(figsize=(6, 5))
#     ax.errorbar(x, y, yerr=yerr, fmt='o-', capsize=8, linewidth=2)
#     ax.set_xticks(x)
#     ax.set_xticklabels(x_labels)
#     ax.set_ylabel(ylabel)
#     wrapped_title = "\n".join(textwrap.wrap(f"{title_prefix}: {topic.replace('_', ' ')}", width=70))
#     ax.set_title(wrapped_title)
#     ax.grid(True, linestyle='--', alpha=0.5)
#     if "Std" in title_prefix:
#         ax.set_ylim(0, 3)
#     plt.tight_layout()
#     plt.savefig(save_path, bbox_inches="tight")
#     plt.close()

# def plot_summary_errorbar(y, yerr, save_path, topic, ylabel, title_prefix):
#     x = np.array([0, 1])
#     x_labels = ["Initial", "Post"]
#     fig, ax = plt.subplots(figsize=(6, 5))

#     if "Bias" in title_prefix:
#         level_map = {-3: 0, -2: 1, -1: 2, 1: 3, 2: 4, 3: 5}
#         ordinal_to_likert = {v: k for k, v in level_map.items()}

#         y_mapped = [level_map.get(round(v), np.nan) for v in y]
#         ax.errorbar(x, y_mapped, yerr=yerr, fmt='o-', capsize=8, linewidth=2)

#         yticks = sorted(ordinal_to_likert.keys())
#         yticklabels = [ordinal_to_likert[t] for t in yticks]
#         ax.set_yticks(yticks)
#         ax.set_yticklabels(yticklabels)
#         ax.axhline(2.5, color='gray', linestyle='--', linewidth=1)  # midpoint between -1 and 1
#     else:
#         ax.errorbar(x, y, yerr=yerr, fmt='o-', capsize=8, linewidth=2)
#         if "Std" in title_prefix:
#             ax.set_ylim(0, 3)

#     ax.set_xticks(x)
#     ax.set_xticklabels(x_labels)
#     ax.set_ylabel(ylabel)

#     wrapped_title = "\n".join(textwrap.wrap(f"{title_prefix}: {topic.replace('_', ' ')}", width=70))
#     ax.set_title(wrapped_title)
#     ax.grid(True, linestyle='--', alpha=0.5)
#     plt.tight_layout()
#     plt.savefig(save_path, bbox_inches="tight")
#     plt.close()

# --- Main Logic ---
def _standard_error(series):
    """Return the standard error of the mean of ``series`` (sample std / sqrt(n))."""
    return series.std() / np.sqrt(len(series))


def _write_opinion_outputs(df, output_dir, topic, titles, bias_ylabel="Likert Score", csv_drop_columns=()):
    """Write the stats CSV, the four plots and the summary CSV for one group.

    Args:
        df: One row per run with ``label``, ``avg_initial``, ``avg_post``,
            ``std_initial`` and ``std_post`` columns.
        output_dir: Directory to write into (created when missing).
        topic: Topic identifier passed on to the plotting functions.
        titles: Dict with the keys ``bias_trajectory``, ``std_trajectory``,
            ``bias_summary`` and ``std_summary`` giving the plot titles.
        bias_ylabel: Y-axis label of the bias trajectory plot.
        csv_drop_columns: Columns left out of ``opinion_stats.csv``.
    """
    os.makedirs(output_dir, exist_ok=True)

    df.drop(columns=list(csv_drop_columns), errors="ignore").to_csv(
        os.path.join(output_dir, "opinion_stats.csv"), index=False
    )

    plot_trajectory(
        df, os.path.join(output_dir, "opinion_bias_trajectory.svg"), topic,
        "avg_initial", "avg_post", bias_ylabel, titles["bias_trajectory"],
    )
    plot_trajectory(
        df, os.path.join(output_dir, "opinion_std_trajectory.svg"), topic,
        "std_initial", "std_post", "Opinion Std", titles["std_trajectory"],
    )

    plot_summary_errorbar(
        [df["avg_initial"].mean(), df["avg_post"].mean()],
        [_standard_error(df["avg_initial"]), _standard_error(df["avg_post"])],
        os.path.join(output_dir, "summary_opinion_bias_errorbar.svg"),
        topic,
        ylabel="Likert Score",
        title_prefix=titles["bias_summary"],
    )
    plot_summary_errorbar(
        [df["std_initial"].mean(), df["std_post"].mean()],
        [_standard_error(df["std_initial"]), _standard_error(df["std_post"])],
        os.path.join(output_dir, "summary_opinion_std_errorbar.svg"),
        topic,
        ylabel="Opinion Std",
        title_prefix=titles["std_summary"],
    )

    summary_stats = {
        "overall_avg_initial": df["avg_initial"].mean(),
        "overall_avg_post": df["avg_post"].mean(),
        "overall_std_initial": df["std_initial"].mean(),
        "overall_std_post": df["std_post"].mean(),
    }
    pd.DataFrame([summary_stats]).to_csv(os.path.join(output_dir, "summary_opinion.csv"), index=False)


def ensure_and_save_output(grouped, mode, version, model_name=None):
    """Compute opinion statistics per topic and write CSVs and plots.

    Args:
        grouped: Mapping of topic -> list of paths, as returned by
            ``scan_and_group_by_topic``.
        mode: ``"human"`` treats each path as a CSV file; any other value
            is handled as simulation mode, where each path is an
            experiment folder containing one sub-folder per model.
        version: Simulation version tag (only used in simulation mode).
        model_name: When given, only this model sub-folder is processed;
            otherwise every entry of each experiment folder is tried.

    Outputs go under ``OUTPUT_ROOT/<topic>/human/opinion`` or
    ``OUTPUT_ROOT/<topic>/simulation/opinion`` (one sub-folder per model
    plus an all-models summary in the folder itself). Topics that yield
    no statistics are reported and skipped.
    """
    for topic, folders in grouped.items():
        if mode == "human":
            print(topic)
            results = [compute_human_opinion_stats(f) for f in folders]
            # Check before building/sorting the frame: sorting an empty,
            # column-less frame by name would raise.
            if not results:
                print(f"[WARNING] No valid data for topic {topic}, skipping...")
                continue

            df = pd.DataFrame(results).sort_values(by=["label", "filename"])
            _write_opinion_outputs(
                df,
                os.path.join(OUTPUT_ROOT, topic, "human", "opinion"),
                topic,
                titles={
                    "bias_trajectory": "Bias Trajectory",
                    "std_trajectory": "Standard Deviation Trajectory",
                    "bias_summary": "Bias Summary",
                    "std_summary": "Std Summary",
                },
            )
            continue

        # --- simulation mode ---
        model_stats_map = defaultdict(list)

        for exp_dir in folders:
            # Entries produced by scan_and_group_by_topic are already full
            # paths; os.path.join returns an absolute second argument
            # unchanged, so bare folder names and full paths both work.
            exp_path = os.path.join(SIMULATION_INPUT_DIR, exp_dir)
            model_dirs = [model_name] if model_name else os.listdir(exp_path)

            for mdl in model_dirs:
                stat = compute_simulation_opinion_stats(exp_path, mdl, version)
                if stat is not None:
                    stat["model_name"] = mdl
                    model_stats_map[mdl].append(stat)

        # pd.concat raises on an empty list, so bail out before combining.
        if not model_stats_map:
            print(f"[WARNING] No valid data for topic {topic}, skipping...")
            continue

        overall_dir = os.path.join(OUTPUT_ROOT, topic, "simulation", "opinion")

        # Results per model
        for mdl, stats in model_stats_map.items():
            df = pd.DataFrame(stats).sort_values(by=["label", "filename"])
            _write_opinion_outputs(
                df,
                os.path.join(overall_dir, mdl),
                topic,
                titles={
                    "bias_trajectory": f"Bias Trajectory - {mdl}",
                    "std_trajectory": f"Standard Deviation Trajectory - {mdl}",
                    "bias_summary": f"Bias – {mdl}",
                    "std_summary": f"Std – {mdl}",
                },
                csv_drop_columns=("model_name",),
            )

        # Final summary over all models
        combined_df = pd.concat(
            [pd.DataFrame(v) for v in model_stats_map.values()], ignore_index=True
        )
        # The averages are already on the signed scale
        # (compute_simulation_opinion_stats applies likert_to_signed before
        # averaging). Re-applying the 1-6 lookup to them would turn almost
        # every average into 0.0, so the "*_signed" columns, kept so the
        # CSV layout is unchanged, simply mirror the averages.
        combined_df["avg_initial_signed"] = combined_df["avg_initial"]
        combined_df["avg_post_signed"] = combined_df["avg_post"]

        _write_opinion_outputs(
            combined_df,
            overall_dir,
            topic,
            titles={
                "bias_trajectory": "Bias Trajectory (All Models)",
                "std_trajectory": "Standard Deviation Trajectory (All Models)",
                "bias_summary": "Bias Over All Models",
                "std_summary": "Std Over All Models",
            },
            bias_ylabel="Signed Likert Score",
        )


if __name__ == "__main__":
    # Command-line entry point: parse options, pick the input directory
    # for the requested mode, then group the inputs and write all outputs.
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", required=True, choices=["human", "simulation"], help="Mode: human or simulation")
    cli.add_argument("--version", required=False, help="Version name like v0, v1, v2 (only needed for simulation)")
    cli.add_argument("--model_name", required=False, help="Model name (optional, if you only want to process one model)")

    options = cli.parse_args()
    is_human = options.mode == "human"

    # The version selects which simulation CSV is read, so it is mandatory
    # in simulation mode.
    if not is_human and not options.version:
        cli.error("--version is required when mode is 'simulation'.")

    source_dir = HUMAN_INPUT_DIR if is_human else SIMULATION_INPUT_DIR

    topics = scan_and_group_by_topic(source_dir, options.mode)
    ensure_and_save_output(topics, options.mode, options.version, options.model_name)
