from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Location of the raw rock-paper-scissors results (relative to the working directory).
CSV_FILE_PATH = "../../data/rps/rps.csv"

# Read the results, coercing the round index to int and the per-round outcome to float.
df = pd.read_csv(CSV_FILE_PATH).astype({"idRound": int, "outcomeRound": float})

# Opponents that play the same move every round (the "constant behaviour" baseline).
opponent_strategies = ["always_rock", "always_paper", "always_scissor"]

# Keep only games against those opponents. The explicit .copy() makes df_filtered an
# independent frame rather than a view of df, so any later column assignment on it
# does not trigger pandas' SettingWithCopyWarning.
df_filtered = df.loc[df["opponentStrategy"].isin(opponent_strategies)].copy()

# Color palette
color_palette = {
    'gpt-4.5-preview-2025-02-27': '#7abaff',
    'gpt-4.5-preview-2025-02-27 strategy': '#7abaff',
    'llama3': '#32a68c',
    'llama3 strategy': '#32a68c',
    'llama3.3:latest': '#4b9f7d',
    'llama3.3:latest strategy': '#4b9f7d',
    'mistral-small': '#ff6941',
    'mistral-small strategy': '#ff6941',
    'mixtral:8x7b': '#f1a61a',
    'mixtral:8x7b strategy': '#f1a61a',
    'deepseek-r1': '#5862ed',
    'deepseek-r1 strategy': '#5862ed',
    'deepseek-r1:7b': '#9a7bff',
    'deepseek-r1:7b strategy': '#9a7bff',
    'random': '#000000',
    'qwen3': '#c02942'
}

# Linestyle palette (cycling through styles)
linestyle_dict = {
    'gpt-4.5-preview-2025-02-27': 'solid',
    'gpt-4.5-preview-2025-02-27 strategy': 'solid',
    'llama3': 'dotted',
    'llama3 strategy': 'dotted',
    'llama3.3:latest': 'dotted',
    'llama3.3:latest strategy': 'dotted',
    'mistral-small': 'dashed',
    'mistral-small strategy': 'dashed',
    'mixtral:8x7b': 'dashed',
    'mixtral:8x7b strategy': 'dashed',
    'deepseek-r1': 'dashdot',
    'deepseek-r1 strategy': 'dashdot',
    'deepseek-r1:7b': 'dashdot',
    'deepseek-r1:7b strategy': 'dashdot',
    'random': 'solid',
    'qwen3': 'dotted'
}



# Per (model, round): mean outcome and standard error of the mean (SEM).
# pandas' built-in "sem" uses ddof=1 and skips NaN outcomes exactly like "mean",
# so both statistics are computed over the same observations. A group with a
# single observation gets a NaN SEM rather than a NumPy degrees-of-freedom warning.
agg_data = df_filtered.groupby(["model", "idRound"]).agg(
    mean_outcome=("outcomeRound", "mean"),
    sem_outcome=("outcomeRound", "sem"),
).reset_index()

# Half-width of the 95% confidence interval using the normal approximation
# (z = 1.96). With only a few games per round, a Student-t critical value would
# give a somewhat wider interval.
agg_data["ci95"] = 1.96 * agg_data["sem_outcome"]

# Output location for the rendered figure (relative to the working directory).
FIGURE_PATH = Path("../../figures/rps/rps_constant.svg")

fig, ax = plt.subplots(figsize=(10, 6))

# One line per model: mean points per round, with the 95% CI as a shaded band.
# groupby("model") iterates models in sorted order, which fixes the legend order.
for model, df_model in agg_data.groupby("model"):
    # Models missing from the palettes fall back to a dark gray, solid line.
    color = color_palette.get(model, '#63656a')
    linestyle = linestyle_dict.get(model, 'solid')

    ax.plot(df_model["idRound"], df_model["mean_outcome"],
            label=model, color=color, linestyle=linestyle)
    ax.fill_between(df_model["idRound"],
                    df_model["mean_outcome"] - df_model["ci95"],  # lower 95% CI bound
                    df_model["mean_outcome"] + df_model["ci95"],  # upper 95% CI bound
                    color=color, alpha=0.2)  # translucent so overlapping bands stay readable

# NOTE(review): the x-range is fixed to rounds 1-10; any later rounds are not shown.
ax.set_xlim(1, 10)
ax.set_ylim(0, 2)  # Points are between 0 and 2
ax.set_xlabel("Round Number")
ax.set_ylabel("Average Points Earned")
ax.set_title("Average Points Earned per Round Against Constant Behaviour (with 95% Confidence Interval)")
ax.legend()
ax.grid(True)

# Make sure the target directory exists; otherwise savefig raises
# FileNotFoundError after all the work above has been done.
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH, format='svg')
