import matplotlib.pyplot as plt
import json
from utils import moving_average, increase_line

# --- Per-dataset settings (parallel lists, indexed by subplot position) ---

# Benchmark keys; each one is also the stem of its '<key>.json' results file.
datasets = [
    'AIME24',
    'AIME25',
    'AMC',
    'Olympiad',
]

# y-axis label shown on each dataset's subplot.
ylabels = [
    'pass@8',
    'pass@8',
    'pass@8',
    'pass@1',
]

# Moving-average window used to smooth each dataset's curves.
window_sizes = [
    15,
    10,
    20,
    20,
]

# --- Per-method settings (first entry: GRPO, second entry: GRPO-HAMMER) ---

# Line colour for each method.
colors = ['deepskyblue', 'tomato']

# Legend labels for each method.
leg = ['GRPO', 'GRPO-HAMMER']

# --- Shared styling ---

# Width and opacity of the raw (unsmoothed) curves drawn behind the smoothed
# ones.
linewidth = 0.5
alpha = 0.3

# Font size for titles, axis labels and the figure legend.
fontsize = 16


# Human-readable subplot titles, keyed by dataset name. Built once here
# instead of being re-created on every loop iteration.
dataset_name = {
    'AIME24': 'AIME 2024',
    'AIME25': 'AIME 2025',
    'AMC': 'AMC 2023',
    'Olympiad': 'Olympiad',
}


def _smoothed_curve(series, window_size):
    """Return ``(x, y)`` for the smoothed version of one training curve.

    ``series`` is a mapping holding parallel ``'x'`` and ``'y'`` sequences.
    The main part of the curve is ``moving_average(y, window_size)``, paired
    with the x values that remain after trimming ``(window_size - 1) // 2``
    points from each end. The trimmed leading points are handed to
    ``increase_line`` together with the first smoothed value, and whatever it
    returns is prepended to the curve.
    """
    offset = (window_size - 1) // 2
    xs, ys = series['x'], series['y']

    smoothed_y = moving_average(ys, window_size)
    # Use an explicit end index: ``xs[offset:-offset]`` collapses to the empty
    # slice ``xs[0:0]`` when offset is 0 (window_size of 1 or 2), which would
    # silently drop every x value. For offset > 0 both forms are identical.
    # The extra ``[:len(smoothed_y)]`` guards against an off-by-one when the
    # window size is even.
    smoothed_x = xs[offset:len(xs) - offset][:len(smoothed_y)]

    lead_x, lead_y = increase_line(xs[:offset], ys[:offset], smoothed_y[0])
    return lead_x + smoothed_x, lead_y + smoothed_y


fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(14, 2.5))

for i, dataset in enumerate(datasets):
    ax = axes[i]

    # Each results file holds one {'x': [...], 'y': [...]} series per method.
    with open(f'{dataset}.json', 'r', encoding='utf-8') as f:
        results = json.load(f)
    grpo_data = results['GRPO']
    hammer_data = results['GRPO-HAMMER']

    window_size = window_sizes[i]
    smoothed_grpo_x, smoothed_grpo_y = _smoothed_curve(grpo_data, window_size)
    smoothed_hammer_x, smoothed_hammer_y = _smoothed_curve(hammer_data, window_size)

    # Faint raw curve first, then the smoothed curve on top, for each method.
    ax.plot(grpo_data['x'], grpo_data['y'], color=colors[0], alpha=alpha, linewidth=linewidth)
    line1, = ax.plot(smoothed_grpo_x, smoothed_grpo_y, color=colors[0], label=leg[0])
    ax.plot(hammer_data['x'], hammer_data['y'], color=colors[1], alpha=alpha, linewidth=linewidth)
    line2, = ax.plot(smoothed_hammer_x, smoothed_hammer_y, color=colors[1], label=leg[1])

    ax.set_title(dataset_name[dataset], fontsize=fontsize)
    ax.set_xlabel('Steps', fontsize=fontsize)
    ax.set_ylabel(ylabels[i], fontsize=fontsize)
    ax.grid(True)

    # The shared figure legend only needs one handle per method; take them
    # from the first subplot.
    if i == 0:
        handles = [line1, line2]

fig.subplots_adjust(wspace=0.5, hspace=0.2)
# Legend is anchored just above the figure area; bbox_inches='tight' on save
# keeps it inside the exported image.
fig.legend(handles, leg, loc='center', bbox_to_anchor=(0.51, 1.05), ncol=2, frameon=False, fontsize=fontsize)
# NOTE(review): tight_layout() recomputes subplot spacing, so the wspace set
# by subplots_adjust() above is likely superseded. The call order is kept
# as-is to preserve the existing figure output -- confirm which is intended.
fig.tight_layout()
fig.savefig('qwen3-1.7b-grpo-training-dynamic.png', dpi=300, bbox_inches='tight')