import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import argparse

# The project root is taken as two directory levels above the *current working
# directory* (it does not rely on __file__), so every result path below
# depends on where the script is launched from.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

# Split name -> name of that split's sub-directory under result/group_level_eval.
SPLITS = {
    split_name: f"{split_name}_depth"
    for split_name in ("group_split", "topic_split", "round_split")
}

# Column of stats_summary.csv that is averaged and plotted on the y axis.
DIVERSITY_COLUMN = "Post Opinion (Slider) Diversity - Initial Opinion (Slider) Diversity"

# Fine-tuned GPT model whose rows are plotted for each split.
SPLIT_MODEL_IDS = dict(
    group_split="ft:gpt-4o-mini-2024-07-18:camer:group-split-all:BOvZjzvU",
    topic_split="ft:gpt-4o-mini-2024-07-18:camer:topic-split-all:BOqtcdMB",
    round_split="ft:gpt-4o-mini-2024-07-18:camer:round-split-all:BOvS862Y",
)

# Grey shades used for the GPT bars, from darkest (base) to lightest.
GPT_COLOR_BASE, GPT_COLOR_FADE1, GPT_COLOR_FADE2 = "#404144", "#5e5f60", "#7a7a7a"

def generate_topic_index_map(df):
    """Return a ``{topic: 1-based index}`` mapping for the "Topic" column of *df*.

    Topics are numbered in sorted order; the special topic "All", when it
    occurs in the column, always receives the highest index.
    """
    topic_column = df["Topic"]
    ordered_topics = sorted(
        name for name in topic_column.unique() if name != "All"
    )
    # "All" is excluded from the sort above so that it can be placed last.
    if "All" in topic_column.values:
        ordered_topics.append("All")
    return {name: position for position, name in enumerate(ordered_topics, start=1)}

def plot_split_restricted(df, split_key, save_path):
    """Plot the mean diversity change per topic for one split's fine-tuned model.

    Rows of ``df`` are restricted to the model registered for ``split_key`` in
    ``SPLIT_MODEL_IDS``. The mean of ``DIVERSITY_COLUMN`` is computed per
    topic and data type and drawn as a grouped bar chart with one hue level
    per "<model>-<type>" (train, test, train+test). The figure is written to
    ``save_path``.

    Args:
        df: DataFrame providing at least the columns "Topic", "Human/LLM",
            "Type" and ``DIVERSITY_COLUMN``. It is left unmodified.
        split_key: Key into ``SPLIT_MODEL_IDS``; also shown in the plot title.
        save_path: Output path handed to ``plt.savefig``.

    Raises:
        KeyError: If ``split_key`` is not a key of ``SPLIT_MODEL_IDS``.
    """
    gpt_model = SPLIT_MODEL_IDS[split_key]
    # Topic indices are derived from the full frame (before filtering) so the
    # numbering does not depend on which topics the selected model covers.
    topic_index_map = generate_topic_index_map(df)

    ordered_types = ["train", "test", "train+test"]
    # Explicit copy: adding columns to a filtered slice would raise
    # SettingWithCopyWarning and could leak changes into the caller's frame.
    model_df = df[df["Human/LLM"] == gpt_model].copy()
    if model_df.empty:
        print(f"[WARNING] No rows for model {gpt_model}; skipping {save_path}")
        return

    model_df["Model+Type"] = pd.Categorical(
        model_df["Human/LLM"] + "-" + model_df["Type"],
        categories=[f"{gpt_model}-{t}" for t in ordered_types],
        ordered=True,
    )

    color_map = {
        f"{gpt_model}-train": GPT_COLOR_BASE,
        f"{gpt_model}-test": GPT_COLOR_FADE1,
        f"{gpt_model}-train+test": GPT_COLOR_FADE2,
    }

    # observed=True keeps only key combinations that occur in the data; with a
    # categorical key the alternative emits all-NaN rows for every unused
    # category combination, and the default differs between pandas versions.
    grouped = (
        model_df.groupby(
            ["Topic", "Human/LLM", "Type", "Model+Type"], observed=True
        )[DIVERSITY_COLUMN]
        .mean()
        .reset_index()
    )
    grouped["Topic Index"] = grouped["Topic"].map(topic_index_map)

    plt.figure(figsize=(14, 6))
    ax = sns.barplot(
        data=grouped,
        x="Topic Index",
        y=DIVERSITY_COLUMN,
        hue="Model+Type",
        palette=color_map,
    )

    bar_width = 0.2
    for bar in ax.patches:
        # The width must be read *before* it is overwritten; otherwise the
        # centring offset below is always zero and the bar stays left-aligned.
        original_width = bar.get_width()
        # Bars already narrower than bar_width keep their width so that
        # neighbouring hue bars never overlap.
        new_width = min(bar_width, original_width)
        bar.set_width(new_width)
        bar.set_x(bar.get_x() + (original_width - new_width) / 2)

    ax.set_ylim(-2, 2)
    ax.set_title(f"Diversity Change ({split_key})")
    ax.set_xlabel("Topic Index")
    ax.set_ylabel(DIVERSITY_COLUMN)
    ax.legend(title="Model+Type", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()
    print(f"[INFO] Saved plot: {save_path}")

def plot_compare_only_gpt():
    """Plot train+test diversity change per topic for humans and the GPT models.

    Reads ``stats_summary.csv`` from every directory listed in ``SPLITS``
    (missing files are reported and skipped), keeps the "train+test" rows of
    the target models, averages ``DIVERSITY_COLUMN`` per topic index and
    model across all files read, and saves a grouped bar chart to
    ``comparison_gpt_opinion_diversity.svg`` under
    ``<PROJECT_ROOT>/result/group_level_eval``.

    If no input file exists, or no row matches the filter, a warning is
    printed and no figure is produced.
    """
    round_split_valid_model = "ft:gpt-4o-mini-2024-07-18:camer:round-split-valid:BRTJQLtG"
    # The list order is also the hue (bar) order inside every topic group.
    target_models = [
        "human",
        "gpt-4o-mini-2024-07-18",
        SPLIT_MODEL_IDS["topic_split"],
        SPLIT_MODEL_IDS["group_split"],
        SPLIT_MODEL_IDS["round_split"],
        round_split_valid_model,
    ]

    all_rows = []
    for split_dir in SPLITS.values():
        path = os.path.join(PROJECT_ROOT, "result", "group_level_eval", split_dir, "stats_summary.csv")
        if not os.path.exists(path):
            print(f"[WARNING] Missing {path}")
            continue
        df = pd.read_csv(path)
        all_rows.append(
            df[(df["Type"] == "train+test") & (df["Human/LLM"].isin(target_models))]
        )

    # pd.concat raises ValueError on an empty list, so bail out explicitly
    # when none of the summary files could be read.
    if not all_rows:
        print("[WARNING] No stats_summary.csv found; skipping comparison plot")
        return

    df_all = pd.concat(all_rows, ignore_index=True)
    if df_all.empty:
        print("[WARNING] No train+test rows for the target models; skipping comparison plot")
        return

    topic_index_map = generate_topic_index_map(df_all)
    df_all["Topic Index"] = df_all["Topic"].map(topic_index_map)

    grouped = df_all.groupby(["Topic Index", "Human/LLM"])[DIVERSITY_COLUMN].mean().reset_index()

    color_map = {
        "human": "#e63946",  # red
        "gpt-4o-mini-2024-07-18": "#7a7a7a",
        SPLIT_MODEL_IDS["topic_split"]: "#5e5f60",
        SPLIT_MODEL_IDS["group_split"]: "#404144",
        SPLIT_MODEL_IDS["round_split"]: "#202123",
        round_split_valid_model: "#0d0d0d",
    }

    plt.figure(figsize=(14, 6))
    ax = sns.barplot(
        data=grouped,
        x="Topic Index",
        y=DIVERSITY_COLUMN,
        hue="Human/LLM",
        hue_order=target_models,
        palette=color_map,
    )

    bar_width = 0.2
    for bar in ax.patches:
        # The width must be read *before* it is overwritten; otherwise the
        # centring offset below is always zero and the bar stays left-aligned.
        original_width = bar.get_width()
        # Never widen a bar beyond the slot it was drawn in: with six hue
        # levels a forced 0.2 width would make neighbouring bars overlap.
        new_width = min(bar_width, original_width)
        bar.set_width(new_width)
        bar.set_x(bar.get_x() + (original_width - new_width) / 2)

    ax.set_ylim(-2, 2)
    ax.set_ylabel(DIVERSITY_COLUMN)
    ax.set_xlabel("Topic Index")
    ax.set_title("GPT Model Diversity Comparison (Train+Test)")
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()

    save_path = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "comparison_gpt_opinion_diversity.svg")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()
    print(f"[INFO] Saved comparison plot: {save_path}")

def main(mode):
    """Produce the plot selected by *mode*.

    ``"compare"`` draws the cross-model comparison chart. Any other value
    must be a key of ``SPLITS``: that split's ``stats_summary.csv`` is loaded
    and plotted next to it as ``<mode>_opinion_diversity.svg``. When the CSV
    is missing, a warning is printed and nothing is plotted.
    """
    if mode == "compare":
        plot_compare_only_gpt()
        return

    split_dir = SPLITS[mode]
    # Input CSV and output SVG live in the same per-split directory.
    eval_dir = os.path.join(PROJECT_ROOT, "result", "group_level_eval", split_dir)
    df_path = os.path.join(eval_dir, "stats_summary.csv")

    if os.path.exists(df_path):
        stats = pd.read_csv(df_path)
        save_path = os.path.join(eval_dir, f"{mode}_opinion_diversity.svg")
        plot_split_restricted(stats, mode, save_path)
    else:
        print(f"[WARNING] Missing {df_path}")

if __name__ == "__main__":
    # Command-line entry point: --mode picks a single split, the cross-model
    # comparison, or "all" (every split in SPLITS, one after another).
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--mode",
        required=True,
        choices=["all", "group_split", "topic_split", "round_split", "compare"],
    )
    selected_mode = cli.parse_args().mode
    modes_to_run = list(SPLITS) if selected_mode == "all" else [selected_mode]
    for m in modes_to_run:
        main(m)
