import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Input CSV and output directory for the figures. Both paths are relative to
# the current working directory; update CSV_FILE_PATH if the data lives elsewhere.
CSV_FILE_PATH = "../../data/nbos/nbos.csv"
FIGURE_DIR = "../../figures/nbos"
os.makedirs(FIGURE_DIR, exist_ok=True)

# Load the raw results and attach the derived per-row columns:
#   avgPayoff / avgAccuracy  -- mean of the man's and the woman's value;
#   idRound / total_messages -- coerced to int (they are used as grouping keys
#                               and as keys into `linestyles`).
df = pd.read_csv(CSV_FILE_PATH).assign(
    avgPayoff=lambda frame: (frame["manPayoff"] + frame["womanPayoff"]) / 2,
    avgAccuracy=lambda frame: (frame["manAccuracy"] + frame["womanAccuracy"]) / 2,
    idRound=lambda frame: frame["idRound"].astype(int),
    total_messages=lambda frame: frame["total_messages"].astype(int),
)

# Line colour per model name (matched against the `man_model` column).
color_palette = {
    "qwen3": "#c02942",
    "llama3": "#32a68c",
    "mistral-small": "#ff6941",
    "deepseek-r1": "#5862ed",
    "gpt-4.5-preview-2025-02-27": "#7abaff",
}
# Line style per `total_messages` value: 0 -> solid, 1 -> dashed,
# 2 -> dotted, 3 -> dashdot.
linestyles = dict(enumerate(("solid", "dashed", "dotted", "dashdot")))

# Restrict the plots to these models; rows are selected by `man_model` only.
selected_models = ["qwen3", "gpt-4.5-preview-2025-02-27"]
df_filtered = df.loc[df["man_model"].isin(selected_models)]

# Shared plot function for both metrics, one figure per metric
def plot_metric_combined(df_input, metric, ylabel, title, filename, ylim):
    """Plot the per-round mean of *metric* with a 95% confidence band.

    One line is drawn per (``man_model``, ``total_messages``) combination.
    The colour comes from ``color_palette``, with ``'#555'`` for unknown
    models, and the line style comes from ``linestyles``, with ``'solid'``
    for unknown counts. The shaded band is mean +/- 1.96 * SEM, a
    normal-approximation 95% CI.

    The figure is saved as ``<filename>.svg`` in ``FIGURE_DIR``, shown, and
    then closed.

    Args:
        df_input: Rows to aggregate. Must contain the ``man_model``,
            ``total_messages`` and ``idRound`` columns and the *metric* column.
        metric: Name of the numeric column to aggregate.
        ylabel: Y-axis label.
        title: Figure title.
        filename: Output file stem (``.svg`` is appended).
        ylim: ``(bottom, top)`` y-axis limits.

    Raises:
        ValueError: If *df_input* has no rows, e.g. when the model filter
            matched nothing.
    """
    if df_input.empty:
        raise ValueError(
            f"No rows to plot for metric {metric!r}; check the model filter."
        )

    # "sem" (ddof=1) skips NaNs and divides by the non-missing count, matching
    # how "mean" treats missing values. Groups with a single observation get
    # a NaN SEM, so the CI band has a gap at those rounds.
    grouped = (
        df_input.groupby(["man_model", "total_messages", "idRound"])[metric]
        .agg(mean_val="mean", sem_val="sem")
        .reset_index()
    )
    # 1.96 is the two-sided 95% z-score (normal approximation).
    grouped["ci95"] = 1.96 * grouped["sem_val"]

    fig, ax = plt.subplots(figsize=(12, 7))

    for (model, total_messages), group in grouped.groupby(["man_model", "total_messages"]):
        color = color_palette.get(model, '#555')
        ax.plot(group["idRound"], group["mean_val"],
                label=f"{model} (Messages = {total_messages})",
                linestyle=linestyles.get(total_messages, 'solid'),
                color=color)
        ax.fill_between(group["idRound"],
                        group["mean_val"] - group["ci95"],
                        group["mean_val"] + group["ci95"],
                        color=color, alpha=0.2)

    ax.set_xlim(1, df_input["idRound"].max())
    ax.set_ylim(*ylim)
    ax.set_xlabel("Round")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    ax.legend(title="Model & Messages", loc="best")
    fig.tight_layout()
    fig.savefig(os.path.join(FIGURE_DIR, f"{filename}.svg"), format="svg")
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures
    # (matplotlib warns once more than `figure.max_open_warning` are open).
    plt.close(fig)

# One figure per metric, rendered in this order: average payoff first, then
# prediction accuracy. Each entry holds the keyword arguments for
# plot_metric_combined, apart from the shared input frame.
_FIGURE_SPECS = (
    {
        "metric": "avgPayoff",
        "ylabel": "Average Payoff per Round",
        "title": "Average Payoff by Round (95% CI)",
        "filename": "nbos_payoff",
        "ylim": (0, 3),
    },
    {
        "metric": "avgAccuracy",
        "ylabel": "Prediction Accuracy per Round",
        "title": "Prediction Accuracy by Round (95% CI)",
        "filename": "nbos_prediction",
        "ylim": (0, 1.05),
    },
)
for _figure_kwargs in _FIGURE_SPECS:
    plot_metric_combined(df_input=df_filtered, **_figure_kwargs)
