import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Paths ---------------------------------------------------------------
# Both paths are relative, so they resolve against the current working
# directory (not against this file's location).
CSV_FILE_PATH = "../../data/nbos/nbos.csv"  # adjust if the data lives elsewhere
FIGURE_DIR = "../../figures/nbos"
os.makedirs(FIGURE_DIR, exist_ok=True)  # every figure below is saved here

# --- Data ----------------------------------------------------------------
# Read the raw results, add the per-pair averages of payoff and accuracy
# (simple mean of the man/woman columns), then cast the round and
# message-count columns to integers so they group and key lookups cleanly.
df = (
    pd.read_csv(CSV_FILE_PATH)
    .assign(
        avgPayoff=lambda frame: (frame["manPayoff"] + frame["womanPayoff"]) / 2,
        avgAccuracy=lambda frame: (frame["manAccuracy"] + frame["womanAccuracy"]) / 2,
    )
    .astype({"idRound": int, "total_messages": int})
)

# --- Plot settings -------------------------------------------------------
# Line colour per model, keyed by the model name found in the ``man_model``
# column of the data.
color_palette = dict([
    ("qwen3", "#c02942"),
    ("llama3", "#32a68c"),
    ("mistral-small", "#ff6941"),
    ("deepseek-r1", "#5862ed"),
    ("gpt-4.5-preview-2025-02-27", "#7abaff"),
])

# Line style per message count, keyed by the integer ``total_messages``
# value: 0 -> solid, 1 -> dashed, 2 -> dotted, 3 -> dashdot.
linestyles = dict(enumerate(("solid", "dashed", "dotted", "dashdot")))

# Plotting function: One figure per model
def plot_metric_by_model(metric, ylabel, title_suffix, filename_suffix, ylim):
    """Draw, save and show one per-round line chart of ``metric`` per model.

    Rows of the module-level ``df`` are grouped by ``man_model``,
    ``total_messages`` and ``idRound``.  For each group the mean of
    ``metric`` and its standard error of the mean (SEM) are computed.  One
    figure is produced per model, with one line per message count and a
    shaded band of mean +/- 1.96 * SEM.  Each figure is saved as
    ``<model>_<filename_suffix>.svg`` in ``FIGURE_DIR``, shown, then closed.

    Args:
        metric: Name of the numeric column of ``df`` to plot.
        ylabel: Y-axis label.
        title_suffix: Text appended to the model name in each figure title.
        filename_suffix: Suffix used in the saved SVG file name.
        ylim: ``(bottom, top)`` limits for the y-axis.
    """
    # "sem" counts only non-null values, consistent with "mean".  The previous
    # np.std(...) / sqrt(len(x)) skipped NaNs in the std but counted them in n.
    grouped = (
        df.groupby(["man_model", "total_messages", "idRound"])
        .agg(mean_val=(metric, "mean"), sem_val=(metric, "sem"))
        .reset_index()
    )
    # 1.96 is the two-sided 95% quantile of the standard normal distribution.
    grouped["ci95"] = 1.96 * grouped["sem_val"]

    x_max = df["idRound"].max()

    # groupby drops rows whose model is missing, so skip NaN here too rather
    # than drawing an empty figure titled "nan".
    for model in df["man_model"].dropna().unique():
        fig, ax = plt.subplots(figsize=(10, 6))
        color = color_palette.get(model, '#555')

        model_group = grouped[grouped["man_model"] == model]
        for total_messages, group in model_group.groupby("total_messages"):
            ax.plot(group["idRound"], group["mean_val"],
                    label=f"Messages = {total_messages}",
                    linestyle=linestyles.get(total_messages, 'solid'),
                    color=color)
            ax.fill_between(group["idRound"],
                            group["mean_val"] - group["ci95"],
                            group["mean_val"] + group["ci95"],
                            color=color, alpha=0.2)

        ax.set_xlim(1, x_max)
        ax.set_ylim(*ylim)
        ax.set_xlabel("Round")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{model} – {title_suffix}")
        ax.grid(True)
        ax.legend(title="Total Messages", loc="best")
        fig.tight_layout()

        # A path separator in the model id (e.g. "org/name") would otherwise
        # be treated as a sub-directory that does not exist.
        safe_model = str(model).replace("/", "_").replace("\\", "_")
        fig.savefig(os.path.join(FIGURE_DIR, f"{safe_model}_{filename_suffix}.svg"), format="svg")
        plt.show()
        # Release the figure; on non-interactive backends show() is a no-op
        # and figures would otherwise accumulate across models and metrics.
        plt.close(fig)

# Render both metrics for every model, payoff first and accuracy second.
# Each spec is passed as keyword arguments to plot_metric_by_model.
for _plot_spec in (
    {
        "metric": "avgPayoff",
        "ylabel": "Average Payoff per Round",
        "title_suffix": "Mean Payoff by Message Count (95% CI)",
        "filename_suffix": "payoff",
        "ylim": (0, 3),
    },
    {
        "metric": "avgAccuracy",
        "ylabel": "Prediction Accuracy per Round",
        "title_suffix": "Prediction Accuracy by Message Count (95% CI)",
        "filename_suffix": "accuracy",
        "ylim": (0, 1.05),
    },
):
    plot_metric_by_model(**_plot_spec)
