import matplotlib.pyplot as plt
import json
from utils import moving_average, increase_line

# Batch sizes compared; one subplot per entry, each read from AMC-{k}.json.
kk = [16, 32, 64]

# Line colors: index 0 for the DAPO baseline, index 1 for DAPO-HAMMER.
colors = ['deepskyblue', 'tomato']


# Labels for every plotted run, in plotting order (used only by the optional
# shared figure legend). Derived from ``kk`` so they always match the JSON
# keys ``DAPO-{k}`` / ``DAPO-{k}-HAMMER`` read by the plotting loop.
leg = [label for k in kk for label in (f'DAPO-{k}', f'DAPO-{k}-HAMMER')]

# Moving-average window. ``offset`` is the number of leading samples that
# come before the first smoothed point; the plotting loop aligns the smoothed
# curve to x[offset:] and fills the first ``offset`` points separately.
window_size = 20
offset = (window_size - 1) // 2

# Raw (unsmoothed) traces are drawn thin and translucent; ``fontsize`` is
# used for the axis labels.
linewidth = 0.5
alpha = 0.3
fontsize = 14


def _smoothed_series(series):
    """Return ``(x, y)`` for the smoothed curve of one run, leading edge filled.

    The moving average from ``utils.moving_average`` is aligned to
    ``series['x']`` starting at index ``offset``. The first ``offset`` points
    come from ``utils.increase_line``, which receives the raw head of the
    series and the first smoothed value.

    NOTE(review): assumes ``moving_average`` returns a "valid"-mode average no
    longer than ``len(y) - 2 * offset`` (as the original truncation implied)
    -- confirm in utils.

    Args:
        series: mapping with parallel sequences under ``'x'`` and ``'y'``.

    Returns:
        Two lists of equal length: smoothed x and smoothed y.
    """
    xs, ys = series['x'], series['y']
    smooth_y = moving_average(ys, window_size)
    # Use an explicit end index: ``xs[offset:-offset]`` is empty when
    # offset == 0 (window_size 1 or 2) because ``-0`` is ``0``.
    smooth_x = xs[offset:offset + len(smooth_y)]
    head_x, head_y = increase_line(xs[:offset], ys[:offset], smooth_y[0])
    # Unpack into fresh lists so this always concatenates; with NumPy arrays
    # ``+`` would add element-wise instead.
    return [*head_x, *smooth_x], [*head_y, *smooth_y]


# One panel per batch size. squeeze=False keeps ``axes`` indexable even when
# ``kk`` has a single entry; ``axes[0]`` is the (only) row of panels.
fig, axes = plt.subplots(nrows=1, ncols=len(kk), figsize=(10, 2), squeeze=False)
axes = axes[0]

for i, k in enumerate(kk):
    # Each file holds both runs keyed by their label; each run has parallel
    # 'x' and 'y' sequences (plotted as Steps vs pass@1 below).
    with open(f'AMC-{k}.json', 'r', encoding='utf-8') as f:
        results = json.load(f)
    dapo_data = results[f'DAPO-{k}']
    hammer_data = results[f'DAPO-{k}-HAMMER']

    ax = axes[i]
    curves = []
    for data, color, label in (
        (dapo_data, colors[0], f'DAPO-{k}'),
        (hammer_data, colors[1], f'DAPO-{k}-HAMMER'),
    ):
        smooth_x, smooth_y = _smoothed_series(data)
        # Faint raw trace underneath, solid smoothed curve on top.
        ax.plot(data['x'], data['y'], color=color, alpha=alpha, linewidth=linewidth)
        line, = ax.plot(smooth_x, smooth_y, color=color, label=label)
        curves.append(line)

    ax.legend(frameon=False)
    ax.set_xlabel('Steps', fontsize=fontsize)
    ax.grid(True)
    ax.set_ylabel("pass@1", fontsize=fontsize)

    if i == 0:
        # Handles for the optional shared figure legend below.
        handles = curves


# fig.tight_layout() recomputes the subplot parameters (including wspace for
# a multi-column grid), so an earlier fig.subplots_adjust(wspace=0.5,
# hspace=0.2) had no effect and was removed. To widen the gaps between
# panels, pass w_pad=... to tight_layout instead.
# Optional shared legend above the panels:
# fig.legend(handles, leg, loc='center', bbox_to_anchor=(0.51, 1.05), ncol=2, frameon=False, fontsize=fontsize)
fig.tight_layout()
fig.savefig('qwen3-1.7b-dapo-batchsize-ablation.png', dpi=300, bbox_inches='tight')