import matplotlib.pyplot as plt
import json
from utils import moving_average, increase_line

# Per-series configuration; both lists share one index order:
# 0 = DAPO, 1 = DAPO-MAX, 2 = DAPO-HAMMER.
leg = ['DAPO', 'DAPO-MAX', 'DAPO-HAMMER']
colors = ['deepskyblue', 'forestgreen', 'tomato']

# Moving-average window (in data points), shared by every series.
window_size = 20
# Smoothed value i is plotted at x[i + offset], i.e. near the centre of the
# window it averages (9 for a window of 20).
offset = (window_size - 1) // 2

# Styling: raw curves are drawn thin and translucent; axis labels share one size.
linewidth = 0.5
alpha = 0.3
fontsize = 16


# Single-panel figure: 4 in wide by 2.2 in tall.
fig = plt.figure(
    figsize=(4, 2.2),
)

# Load the per-run curves. Each run entry is read below via its 'x' and 'y'
# sequences (steps and metric values).
# JSON is UTF-8 by spec; pass the encoding explicitly instead of relying on
# the platform's locale default.
with open('AIME24.json', 'r', encoding='utf-8') as f:
    results = json.load(f)

dapo_data = results['DAPO']
hammer_data = results['DAPO-HAMMER']
dapo_mh_data = results['DAPO-MAX']


def _smooth_with_ramp(series):
    """Return ``(x, y)`` for the smoothed curve of one run.

    Only ``series['x']`` and ``series['y']`` are read. The moving average
    supplies the smoothed values, which are aligned to x starting at index
    ``offset``. ``increase_line`` (from utils) builds the head segment over
    the first ``offset`` raw points, ending at the first smoothed value, so
    the curve starts at the first step.

    NOTE(review): assumes ``moving_average`` returns only full-window values
    (``len(y) - window_size + 1`` points). Also assumes it and
    ``increase_line`` return lists, so ``+`` concatenates. Confirm in utils.
    """
    xs, ys = series['x'], series['y']
    smooth_y = moving_average(ys, window_size)
    # Use ``len(xs) - offset`` rather than ``-offset``: when offset == 0
    # (window_size <= 2), ``xs[0:-0]`` is an empty slice and would drop every x.
    smooth_x = xs[offset : len(xs) - offset][:len(smooth_y)]
    head_x, head_y = increase_line(xs[:offset], ys[:offset], smooth_y[0])
    return head_x + smooth_x, head_y + smooth_y


smoothed_dapo_x, smoothed_dapo_y = _smooth_with_ramp(dapo_data)
smoothed_hammer_x, smoothed_hammer_y = _smooth_with_ramp(hammer_data)
smoothed_dapo_mh_x, smoothed_dapo_mh_y = _smooth_with_ramp(dapo_mh_data)

def _draw_run(raw, smooth_x, smooth_y, idx):
    """Plot one run: the raw curve faintly, then its smoothed curve on top.

    ``idx`` selects the entry of ``colors`` and ``leg`` for this run.
    Returns the smoothed curve's Line2D, used later as the legend handle.
    """
    plt.plot(raw['x'], raw['y'], color=colors[idx], alpha=alpha, linewidth=linewidth)
    (smooth_line,) = plt.plot(smooth_x, smooth_y, color=colors[idx], label=leg[idx])
    return smooth_line


# Runs are drawn in this order; later artists are painted over earlier ones.
line1 = _draw_run(dapo_data, smoothed_dapo_x, smoothed_dapo_y, 0)
line2 = _draw_run(dapo_mh_data, smoothed_dapo_mh_x, smoothed_dapo_mh_y, 1)
line3 = _draw_run(hammer_data, smoothed_hammer_x, smoothed_hammer_y, 2)

plt.xlabel('Steps', fontsize=fontsize)
plt.ylabel("pass@8", fontsize=fontsize)
plt.grid(True)

handles = [line1, line2, line3]

# One legend row centred on the figure's top edge (figure coordinates);
# bbox_inches='tight' below grows the saved area so the part sticking out
# above the figure is kept.
fig.legend(handles, leg, loc='center', bbox_to_anchor=(0.51, 1.00), ncol=3, frameon=False, fontsize=10, columnspacing=0.5)
fig.tight_layout()
fig.savefig('qwen3-1.7b-dapo-with-mh.png', dpi=300, bbox_inches='tight')