import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Patch

# Paths (relative paths: resolved against the current working directory)
CSV_FILE_PATH = "../../data/nbos/nbos.csv"
FIGURE_DIR = "../../figures/nbos"
os.makedirs(FIGURE_DIR, exist_ok=True)

# Load data
df = pd.read_csv(CSV_FILE_PATH)

# Both players in a game are assumed to use the same model, so the man's model
# labels the whole game. Warn instead of silently misattributing rows when the
# CSV contains mixed pairings (the output is unchanged either way).
if "woman_model" in df.columns:
    _n_mixed = int((df["man_model"] != df["woman_model"]).sum())
    if _n_mixed:
        warnings.warn(
            f"{_n_mixed} row(s) have man_model != woman_model; "
            "they are attributed to man_model."
        )
df["model"] = df["man_model"]

# Re-express payoffs by role (initiator / responder) instead of by player.
# NOTE(review): assumes is_man_initiator is parsed as a boolean column; if it
# were read as strings, np.where would treat every non-empty value as True.
df["initiatorPayoff"] = np.where(df["is_man_initiator"], df["manPayoff"], df["womanPayoff"])
df["responderPayoff"] = np.where(df["is_man_initiator"], df["womanPayoff"], df["manPayoff"])

def _mean_ci95(frame, column):
    """Return per-model mean, SEM and normal-approximation 95% CI half-width.

    Groups *frame* by its ``model`` column and summarises *column*. Uses the
    pandas ``sem`` aggregation, which skips NaN in both the standard deviation
    (ddof=1) and the sample count. A model with fewer than two valid values
    therefore gets a NaN SEM/CI.

    Returns a DataFrame with columns ``model``, ``mean``, ``sem``, ``ci95``.
    """
    stats = frame.groupby("model")[column].agg(mean="mean", sem="sem").reset_index()
    # 1.96 = two-sided 95% quantile of the standard normal distribution
    stats["ci95"] = 1.96 * stats["sem"]
    return stats


# Mean and 95% CI for each role
initiator_stats = _mean_ci95(df, "initiatorPayoff")
responder_stats = _mean_ci95(df, "responderPayoff")

# Merge both roles into one row per model. Missing values (e.g. the NaN CI of a
# model with a single valid game) are replaced by 0 so they plot as empty bars
# or zero-length error bars.
merged = pd.merge(
    initiator_stats[["model", "mean", "ci95"]].rename(
        columns={"mean": "avgInitiator", "ci95": "ci95_initiator"}
    ),
    responder_stats[["model", "mean", "ci95"]].rename(
        columns={"mean": "avgResponder", "ci95": "ci95_responder"}
    ),
    on="model",
    how="outer",
).fillna(0)

# Display names for the model identifiers found in the CSV. Models missing from
# this mapping keep their raw identifier instead of becoming a "nan" tick label.
label_mapping = {
    'qwen3': 'Qwen3',
    'llama3': 'Llama3',
    'mistral-small': 'Mistral-Small',
    'deepseek-r1': 'Deepseek-R1',
    'gpt-4.5-preview-2025-02-27': 'GPT-4.5'
}
merged["label"] = merged["model"].map(label_mapping).fillna(merged["model"])

# Bar colour per display label; labels not listed here need a fallback colour
# wherever this palette is looked up.
color_palette = {
    'Qwen3': '#c02942',
    'Llama3': '#32a68c',
    'Mistral-Small': '#ff6941',
    'Deepseek-R1': '#5862ed',
    'GPT-4.5': '#7abaff',
}

# Plot: one pair of bars per model (initiator on the left, responder on the right)
plt.figure(figsize=(10, 6))
bar_width = 0.35
x = np.arange(len(merged))

# Color per model; labels without a palette entry fall back to "#1f77b4"
colors = [color_palette.get(lbl, "#1f77b4") for lbl in merged["label"]]

# Initiator bars: opaque with a black edge; error bars show the 95% CI
bars1 = plt.bar(x - bar_width / 2, merged["avgInitiator"],
                width=bar_width, label="Initiator",
                color=colors, yerr=merged["ci95_initiator"], capsize=5,
                edgecolor="black", alpha=1.0)

# Responder bars: same per-model colours at 50% transparency
bars2 = plt.bar(x + bar_width / 2, merged["avgResponder"],
                width=bar_width, label="Responder",
                color=colors, yerr=merged["ci95_responder"], capsize=5,
                alpha=0.5)

plt.xticks(x, merged["label"], rotation=30, ha="right")
plt.ylabel("Average Payoff per Round")
plt.title("Average Payoff per Round by Model (Initiator vs Responder, 95% CI)")
# The bars are coloured per model, so the automatic legend would show the first
# model's colour for each role. Use neutral proxy patches that encode only the
# role styling (opaque with black edge vs. translucent).
plt.legend(handles=[
    Patch(facecolor="grey", edgecolor="black", alpha=1.0, label="Initiator"),
    Patch(facecolor="grey", alpha=0.5, label="Responder"),
])
plt.tight_layout()

# Save & Show
plt.savefig(os.path.join(FIGURE_DIR, "nbos_barchart.svg"), format="svg")
plt.show()