import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Input: the semicolon-separated per-round results file.
# Output: directory that receives the SVG figures. It is created up front so the
# later savefig calls do not fail on a missing directory. Both paths are
# relative, so they resolve against the current working directory.
CSV_FILE_PATH = "../../data/mp/mp_claude35.csv"
FIGURE_DIR = "../../figures/mp"
os.makedirs(name=FIGURE_DIR, exist_ok=True)

# Load the per-round results and keep only rows that recorded an outcome.
df = pd.read_csv(CSV_FILE_PATH, sep=";")
# .copy() so the column assignments below modify an independent frame rather
# than a filtered view of the original (avoids pandas' SettingWithCopyWarning).
df = df[df["outcomeRound"].notnull()].copy()
df["idRound"] = df["idRound"].astype(int)
df["outcomeRound"] = df["outcomeRound"].astype(float)
# Missing prediction scores count as 0. If the column is absent entirely, fall
# back to a constant 0.0: `df.get(col, 0)` would return the int 0, which has
# no `.fillna`, so the column must be checked explicitly.
if "predictionRound" in df.columns:
    df["predictionRound"] = df["predictionRound"].fillna(0).astype(float)
else:
    df["predictionRound"] = 0.0

# Restrict the analysis to games against the two constant opponents (always
# heads / always tails). Other opponent strategies present in the data, such as
# "H-T", are deliberately left out of these "Constant Behaviour" figures.
opponent_strategies = ["always_head", "always_tail"]
_is_constant_opponent = df["opponentStrategy"].isin(opponent_strategies)
df_filtered = df.loc[_is_constant_opponent].copy()

# Plot styling. Each open model shares a single colour with its
# "<model> strategy" variant. Models absent from these tables fall back to the
# defaults applied inside plot_metric.
# NOTE(review): there is no entry for a Claude model even though the input file
# is mp_claude35.csv — confirm the model names that actually appear in the data.
_MODEL_COLOURS = (
    ("qwen3", "#c02942"),
    ("llama3", "#32a68c"),
    ("mistral-small", "#ff6941"),
    ("deepseek-r1", "#5862ed"),
)
color_palette = {
    key: hex_code
    for model_name, hex_code in _MODEL_COLOURS
    for key in (model_name, f"{model_name} strategy")
}
color_palette["gpt-4.5-preview-2025-02-27"] = "#7abaff"

linestyle_dict = dict(
    zip(
        ("qwen3", "llama3", "mistral-small", "deepseek-r1"),
        ("dotted", "dashed", "solid", "dashdot"),
    )
)

# Function to plot
def plot_metric(metric: str, ylabel: str, title: str, filename: str, ylim: tuple,
                xlim: tuple = (1, 10), data: "pd.DataFrame | None" = None):
    """Plot the per-round mean of ``metric`` for each model with a 95% CI band.

    For every (model, round) pair the mean and standard error of ``metric``
    are computed. The shaded band spans ``mean ± 1.96 * SEM`` (normal
    approximation). The figure is written as SVG to ``FIGURE_DIR/filename``,
    shown, and then closed.

    Args:
        metric: Column of ``data`` to aggregate.
        ylabel: Y-axis label.
        title: Figure title.
        filename: Output file name inside ``FIGURE_DIR`` (saved as SVG).
        ylim: ``(bottom, top)`` y-axis limits.
        xlim: ``(left, right)`` x-axis limits; defaults to rounds 1-10.
        data: Frame with ``model``, ``idRound`` and ``metric`` columns.
            Defaults to the module-level ``df_filtered``.
    """
    if data is None:
        data = df_filtered

    # pandas' "sem" is the NaN-aware standard error (ddof=1). It divides by the
    # count of non-null values, unlike std / sqrt(len(x)), which counted NaNs.
    # A single-sample group yields NaN quietly instead of a numpy RuntimeWarning.
    agg = data.groupby(["model", "idRound"]).agg(
        mean_val=(metric, "mean"),
        sem_val=(metric, "sem"),
    ).reset_index()
    agg["ci95"] = 1.96 * agg["sem_val"]

    fig, ax = plt.subplots(figsize=(12, 7))
    for model, group in agg.groupby("model"):
        # Unknown models get a neutral grey, solid line.
        color = color_palette.get(model, '#63656a')
        linestyle = linestyle_dict.get(model, 'solid')
        ax.plot(group["idRound"], group["mean_val"], label=model,
                color=color, linestyle=linestyle)
        ax.fill_between(group["idRound"],
                        group["mean_val"] - group["ci95"],
                        group["mean_val"] + group["ci95"],
                        color=color, alpha=0.2)

    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xlabel("Round Number")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc="upper right")
    ax.grid(True)
    fig.savefig(os.path.join(FIGURE_DIR, filename), format="svg")
    plt.show()
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)


# Render the two figures in order: points earned per round (payoff) first,
# then per-round prediction accuracy.
_FIGURE_SPECS = (
    {
        "metric": "outcomeRound",
        "ylabel": "Average Points Earned",
        "title": "MP: Average Points Earned per Round against Constant Behaviour (95% CI)",
        "filename": "mp_payoff_claude35.svg",
        "ylim": (-1, 1),
    },
    {
        "metric": "predictionRound",
        "ylabel": "Prediction Accuracy",
        "title": "MP: Prediction Accuracy per Round against Constant Behaviour (95% CI)",
        "filename": "mp_prediction_claude35.svg",
        "ylim": (0, 1.05),
    },
)
for _figure_kwargs in _FIGURE_SPECS:
    plot_metric(**_figure_kwargs)
