import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Location of the per-round results CSV (relative to the working directory the
# script is launched from).
CSV_FILE_PATH = "../../data/guess/guess.csv"

# Read the results and coerce the round index and the round outcome to numeric
# dtypes in a single astype call.
df = pd.read_csv(CSV_FILE_PATH).astype({"idRound": int, "outcomeRound": float})

# Opponent strategies whose games are included in this analysis.
opponent_strategies = ["always_rock", "always_paper", "always_scissor"]

# Rows played against one of the strategies above; .copy() makes df_filtered
# an independent frame rather than a view onto df.
df_filtered = df.loc[df["opponentStrategy"].isin(opponent_strategies)].copy()

# Color per model label. Each "<model> strategy" label reuses the color of its
# base model; 'random' and 'qwen3' have no strategy variant in this table.
color_palette = {
    label: hex_color
    for base_model, hex_color in (
        ('gpt-4.5-preview-2025-02-27', '#7abaff'),
        ('llama3', '#32a68c'),
        ('llama3.3:latest', '#4b9f7d'),
        ('mistral-small', '#ff6941'),
        ('mixtral:8x7b', '#f1a61a'),
        ('deepseek-r1', '#5862ed'),
        ('deepseek-r1:7b', '#9a7bff'),
    )
    for label in (base_model, base_model + ' strategy')
}
color_palette['random'] = '#000000'
color_palette['qwen3'] = '#c02942'

# Line style per model label. Each "<model> strategy" label reuses the style of
# its base model; 'random' and 'qwen3' have no strategy variant in this table.
linestyle_dict = {
    label: style
    for base_model, style in (
        ('gpt-4.5-preview-2025-02-27', 'solid'),
        ('llama3', 'dotted'),
        ('llama3.3:latest', 'dotted'),
        ('mistral-small', 'dashed'),
        ('mixtral:8x7b', 'dashed'),
        ('deepseek-r1', 'dashdot'),
        ('deepseek-r1:7b', 'dashdot'),
    )
    for label in (base_model, base_model + ' strategy')
}
linestyle_dict['random'] = 'solid'
linestyle_dict['qwen3'] = 'dotted'

# Aggregate per (model, round): mean points earned and the standard error of
# that mean, then a normal-approximation 95% confidence half-width.
#
# pandas' "sem" aggregation (ddof=1) divides the standard deviation by the
# number of NON-missing observations, matching how "mean" and the std skip
# NaNs. A manual np.std(x, ddof=1) / np.sqrt(len(x)) would skip NaNs in the
# std but still count them in len(x), understating the error when
# outcomeRound has missing values.
# Groups with a single observation get a NaN SEM (and therefore NaN ci95).
agg_data = df_filtered.groupby(["model", "idRound"]).agg(
    mean_outcome=("outcomeRound", "mean"),
    sem_outcome=("outcomeRound", "sem"),
).reset_index()

# 1.96 is the two-sided 95% critical value of the standard normal distribution.
agg_data["ci95"] = 1.96 * agg_data["sem_outcome"]

### --- Figures: per-round performance against constant strategies ---


def plot_round_performance(data, title, output_path, default_linestyle='solid',
                           xlim=(1, 10), ylim=(0, 1)):
    """Plot mean points per round with a 95% CI band per model and save as SVG.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows with columns ``model``, ``idRound``, ``mean_outcome`` and
        ``ci95``, as produced by the ``agg_data`` aggregation.
    title : str
        Figure title.
    output_path : str
        Destination of the SVG file; missing parent directories are created.
    default_linestyle : str
        Line style for models absent from ``linestyle_dict``.
    xlim, ylim : tuple of two numbers
        Axis limits (defaults match the original fixed limits).

    Returns
    -------
    matplotlib.figure.Figure
        The drawn figure. It is left open, as the original script did.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for model in data["model"].unique():
        df_model = data[data["model"] == model]
        color = color_palette.get(model, '#63656a')
        linestyle = linestyle_dict.get(model, default_linestyle)

        ax.plot(df_model["idRound"], df_model["mean_outcome"],
                label=model, color=color, linestyle=linestyle)
        # A NaN ci95 (single observation in that round) leaves a gap in the
        # band rather than a zero-width band.
        ax.fill_between(df_model["idRound"],
                        df_model["mean_outcome"] - df_model["ci95"],
                        df_model["mean_outcome"] + df_model["ci95"],
                        color=color, alpha=0.2)

    ax.set_xlim(*xlim)
    ax.set_xlabel("Round Number")
    ax.set_ylabel("Average Points Earned")
    ax.set_title(title)
    # With an empty subset there is nothing to label; legend() would only warn.
    if ax.get_lines():
        ax.legend()
    ax.grid(True)
    ax.set_ylim(*ylim)

    # savefig does not create missing directories; do it here so a fresh
    # checkout does not fail with FileNotFoundError.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, format='svg')
    return fig


# First figure: base models (no 'strategy' in the label).
model_only = agg_data[~agg_data["model"].str.contains("strategy", regex=False)]
plot_round_performance(
    model_only,
    title="Model Performance Against Constant Strategies",
    output_path='../../figures/guess/guess_constant_models.svg',
)

# Second figure: model strategy variants ('strategy' in the label). Labels
# missing from linestyle_dict fall back to a dashed line here.
strategy_only = agg_data[agg_data["model"].str.contains("strategy", regex=False)]
plot_round_performance(
    strategy_only,
    title="Model Strategies vs Constant Behaviour",
    output_path='../../figures/guess/guess_constant_strategies.svg',
    default_linestyle='dashed',
)