import pandas as pd
import numpy as np

# Benchmark results for the five Qwen3-4B variants, written one model per row
# and then transposed into the column-oriented mapping the DataFrame is built from.
_columns = (
    'Models', 'Game Avg', 'Benchmark Avg',
    'MATH500', 'AIME24', 'AIME25', 'OlympiadBench', 'AMC-23', 'Minerva Math',
    'GPQA', 'MMLU-Pro',
)
_rows = (
    ('Qwen3-4B-Base', 18.65, 34.0, 73.4, 9.6, 6.2, 33.3, 42.4, 29.4, 30.6, 47.2),
    ('Qwen3-4B-SFT', 46.52, 39.7, 74.2, 13.7, 11.7, 37.6, 51.1, 40.1, 37.8, 51.3),
    ('Qwen3-4B-Mistral', 36.65, 29.6, 64.0, 4.3, 2.1, 29.8, 31.6, 26.1, 35.6, 43.6),
    ('Qwen3-4B-Gemini', 39.31, 33.4, 69.2, 5.2, 4.7, 33.8, 29.8, 33.8, 35.3, 55.5),
    ('Qwen3-4B (SPIRAL)', 59.5, 44.5, 78.2, 19.7, 13.3, 41.8, 61.6, 42.6, 40.1, 58.5),
)
data = {
    column: list(values)
    for column, values in zip(_columns, zip(*_rows))
}

df = pd.DataFrame(data)

# Quick sanity dump: first rows, dimensions and column names.
print("Data structure:")
print(df.head())
print(f"\nShape: {df.shape}")
print(f"Columns: {list(df.columns)}")

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

# Apply the seaborn theme and the font sizes in one call.
# sns.set_theme() resets the font-related rcParams (font.size, axes.labelsize,
# axes.titlesize, tick label sizes, legend.fontsize) from its plotting context,
# so values written to plt.rcParams *before* it are silently discarded.
# Entries passed through `rc=` are applied after the theme and therefore stick.
sns.set_theme(
    style="whitegrid",
    rc={
        'font.size': 18,
        'axes.labelsize': 18,
        'axes.titlesize': 18,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
        'legend.fontsize': 18,
    },
)

# One hatch pattern per model, in order of appearance: diagonal stripes for
# Qwen3-4B (SPIRAL), an empty pattern (no hatching) for every other model.
models = df["Models"].unique()
hatches = ["///" if name == "Qwen3-4B (SPIRAL)" else "" for name in models]

# Columns to plot, in x-axis order: the two averages, the math benchmarks,
# then the remaining two benchmarks.
average_columns = ["Game Avg", "Benchmark Avg"]
math_columns = ["MATH500", "AIME24", "AIME25", "OlympiadBench", "AMC-23", "Minerva Math"]
other_columns = ["GPQA", "MMLU-Pro"]
plot_columns = average_columns + math_columns + other_columns

# Reshape wide -> long: one (Models, Benchmark, Score) row per plotted bar.
df_long = pd.melt(
    df,
    id_vars="Models",
    value_vars=plot_columns,
    var_name="Benchmark",
    value_name="Score",
)

# Build the bar colours: baseline models get progressively darker shades of a
# light-orange ramp, while Qwen3-4B (SPIRAL) gets a fixed, darker orange.
models = df["Models"].unique()
base_color = (1.0, 0.6, 0.2)  # light orange RGB
spiral_model = "Qwen3-4B (SPIRAL)"
spiral_color = (0.8, 0.4, 0.1)  # darker orange

# Shade the baselines independently of where SPIRAL sits in `models`.
# Indexing the ramp with the model's overall position + 1 only works while
# SPIRAL is the last row; any other order overruns the ramp or shifts shades.
baseline_models = [name for name in models if name != spiral_model]

# One extra colour is requested because the first ramp entry is skipped below,
# so every baseline still receives a shade.
palette_orange = sns.light_palette(
    base_color, n_colors=len(baseline_models) + 1, input="rgb"
)
baseline_colors = dict(zip(baseline_models, palette_orange[1:]))

custom_palette = {
    name: spiral_color if name == spiral_model else baseline_colors[name]
    for name in models
}

# Dashed separators sit halfway between neighbouring x-axis categories:
# one after "Benchmark Avg" (boundary 2) and one after "Minerva Math" (boundary 8).
vertical_lines = [boundary - 0.5 for boundary in (2, 8)]

# Grouped bar chart: benchmarks along x, one coloured bar per model.
plt.figure(figsize=(12, 5))
ax = sns.barplot(
    data=df_long,
    x="Benchmark",
    y="Score",
    hue="Models",
    palette=custom_palette,
)

# Hatch the bars that belong to Qwen3-4B (SPIRAL).
# The patch index divided by the number of benchmarks is used as the position
# of the bar's model in `models`; patches that map past the last model are skipped.
print(f"len(ax.patches): {len(ax.patches)}")
print(f"len(plot_columns): {len(plot_columns)}")
print(f"len(models): {len(models)}")
bars_per_model = len(plot_columns)
for position, patch in enumerate(ax.patches):
    model_slot = position // bars_per_model
    if model_slot >= len(models):
        continue
    if models[model_slot] == "Qwen3-4B (SPIRAL)":
        patch.set_hatch('///')

# Rebuild the legend handles so the Qwen3-4B (SPIRAL) entry is a hatched patch
# with the same face colour; every other handle is reused unchanged.
handles, labels = ax.get_legend_handles_labels()
new_handles = [
    Patch(facecolor=artist.get_facecolor(), hatch='///', label=caption)
    if caption == "Qwen3-4B (SPIRAL)"
    else artist
    for artist, caption in zip(handles, labels)
]
handles = new_handles

# Separate the benchmark groups with thin gray dashed lines.
for x_position in vertical_lines:
    ax.axvline(x=x_position, color='gray', linestyle='--', linewidth=1)

# Add values on bars
# for p in ax.patches:
#     height = p.get_height()
#     ax.annotate(f'{height:.1f}',
#                 (p.get_x() + p.get_width() / 2., height),
#                 ha='center', va='bottom',
#                 fontsize=9, color='black', rotation=0)

# Draw the legend with the rebuilt handles, then embolden the SPIRAL entry.
legend = ax.legend(
    handles,
    labels,
    loc='best',
    bbox_to_anchor=(0.55, 0.53),
    frameon=True,
    fontsize=14,
)
for entry in legend.get_texts():
    caption = entry.get_text()
    print(caption)
    if caption != "Qwen3-4B (SPIRAL)":
        continue
    entry.set_weight("bold")  # only this legend entry is rendered bold
    print('set bold')

# Axis labels: a y-axis title only; the x-axis title is left empty.
ax.set_ylabel("Score (%)", fontsize=14)
ax.set_xlabel("")

# Slant the tick labels, anchored on their right edge.
plt.xticks(rotation=30, ha="right", fontsize=14)

# Keep both axis titles at the same size.
for axis in (ax.xaxis, ax.yaxis):
    axis.label.set_size(14)

# Fix the score range to 0-80.
ax.set_ylim(0, 80)

# Improve layout
plt.tight_layout()

# Save the plot as a PDF in the working directory and as a 300-dpi PNG under assets/.
# savefig() does not create missing directories, so the PNG's parent folder is
# created first; otherwise a fresh checkout fails with FileNotFoundError.
import os

pdf_path = "fig1.pdf"
png_path = "assets/fig1-1.png"
os.makedirs(os.path.dirname(png_path), exist_ok=True)
plt.savefig(pdf_path, bbox_inches="tight")
plt.savefig(png_path, dpi=300, bbox_inches="tight")

# Display the plot
# plt.show()

# Report the paths that were actually written, not hard-coded names.
print(f"\nPlot saved as '{png_path}' and '{pdf_path}'")
print("\nKey findings:")
print("- SPIRAL (Multi-Game) achieves the best overall performance")
print("- SPIRAL methods consistently outperform traditional SFT approaches")
print("- Game-specific SPIRAL training shows strong improvements in math benchmarks")