import os

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd

# ==========================================
# 1. Data
# ==========================================

# --- Left panel data: outcome training rewards (drawn on ax1, whose y label is
# 'Outcome Training Rewards'). ---
# Each variant is registered by appending its legend label to labels_left and
# its values to all_rewards_left, so the two lists stay index-aligned:
# plot_curve draws all_rewards_left[i] with label labels_left[i] and color
# colors[i]. Every series below has 95 points; the x coordinate is the list index.
labels_left = []
all_rewards_left = []

labels_left.append("w/o. update")
all_rewards_left.append([
    0.04150, 0.04345, 0.04541, 0.05517, 0.05322,    0.05517, 0.05566, 0.07519, 0.05712, 0.07421,
    0.06396, 0.05126, 0.05078, 0.07714, 0.05712,    0.06054, 0.07568, 0.04785, 0.06982, 0.05908,
    0.05957, 0.05566, 0.07031, 0.07373, 0.08105,    0.06835, 0.07812, 0.07470, 0.07470, 0.06396,
    0.07226, 0.06152, 0.06298, 0.05957, 0.06347,    0.06005, 0.07128, 0.07275, 0.07910, 0.06250,
    0.08007, 0.08447, 0.07275, 0.06494, 0.05810,    0.08496, 0.07519, 0.07128, 0.07861, 0.07031,
    0.08935, 0.07128, 0.05566, 0.08300, 0.06640,    0.06933, 0.07421, 0.07519, 0.07910, 0.07421,
    0.07373, 0.06250, 0.06345, 0.06689, 0.06689,    0.06494, 0.07080, 0.07226, 0.08496, 0.06787,
    0.07763, 0.08740, 0.06982, 0.07177, 0.07763,    0.05859, 0.07226, 0.07568, 0.05787, 0.06884,
    0.06396, 0.06787, 0.07519, 0.07714, 0.07031,    0.06640, 0.08593, 0.08593, 0.06835, 0.08154,
    0.06298, 0.06787, 0.07275, 0.06982, 0.07617,    
])

labels_left.append("w/o. adb w/o. dlw")
all_rewards_left.append([
    0.04394, 0.05175, 0.04101, 0.05761, 0.05761,    0.05712, 0.05712, 0.07275, 0.05908, 0.07275,
    0.05957, 0.04931, 0.05908, 0.08984, 0.06298,    0.05615, 0.07861, 0.05566, 0.06347, 0.06445,
    0.05810, 0.05517, 0.05957, 0.07177, 0.07226,    0.07226, 0.07958, 0.07031, 0.07519, 0.06902,
    0.06201, 0.05810, 0.04736, 0.05566, 0.05810,    0.06933, 0.05957, 0.05712, 0.06933, 0.05761,
    0.06640, 0.07226, 0.05908, 0.05712, 0.05419,    0.06152, 0.06591, 0.06054, 0.06440, 0.05957,
    0.07031, 0.04687, 0.05078, 0.06640, 0.05029,    0.06005, 0.05126, 0.06201, 0.07128, 0.06250,
    0.04785, 0.04394, 0.04785, 0.04931, 0.03515,    0.04736, 0.03759, 0.03710, 0.05810, 0.05517,
    0.03906, 0.05517, 0.03125, 0.04882, 0.06298,    0.02587, 0.03369, 0.04736, 0.05664, 0.03369,
    0.03857, 0.04296, 0.04150, 0.02783, 0.02587,    0.03076, 0.03222, 0.03271, 0.03466, 0.03808,
    0.03222, 0.04003, 0.02294, 0.02685, 0.04589,

])

labels_left.append("w/o. adb")
all_rewards_left.append([
    0.04150, 0.04882, 0.04980, 0.05712, 0.05322,    0.05810, 0.05517, 0.06445, 0.06152, 0.07666,
    0.06591, 0.04248, 0.06250, 0.06738, 0.05859,    0.05664, 0.07128, 0.05810, 0.06640, 0.06542,
    0.06396, 0.05712, 0.06787, 0.06738, 0.08203,    0.07128, 0.08691, 0.06542, 0.07080, 0.06982,
    0.06787, 0.07421, 0.06201, 0.06787, 0.06054,    0.06152, 0.06396, 0.07470, 0.08341, 0.06738,
    0.07812, 0.07910, 0.07275, 0.06835, 0.06347,    0.08837, 0.07031, 0.07275, 0.07812, 0.06494,
    0.08154, 0.07617, 0.05664, 0.08154, 0.07177,    0.07421, 0.07421, 0.08544, 0.08105, 0.07324,
    0.06892, 0.06689, 0.05957, 0.06347, 0.07519,    0.07910, 0.06542, 0.07568, 0.06891, 0.07714,
    0.07861, 0.09423, 0.06884, 0.08007, 0.08593,    0.06591, 0.08007, 0.07226, 0.07666, 0.07324,
    0.06982, 0.07080, 0.06494, 0.07421, 0.07275,    0.06787, 0.08740, 0.07958, 0.07568, 0.09277,
    0.07910, 0.07714, 0.07421, 0.06982, 0.08203,
])

labels_left.append("w/o. dlw")
all_rewards_left.append([
    0.04150, 0.05419, 0.05322, 0.05078, 0.05419,    0.06933, 0.05908, 0.05859, 0.06152, 0.07910,
    0.06152, 0.05468, 0.05566, 0.07470, 0.05761,    0.06835, 0.07568, 0.06933, 0.07128, 0.05468,
    0.05371, 0.05224, 0.06440, 0.07031, 0.08154,    0.07177, 0.07421, 0.07763, 0.07373, 0.07031,
    0.07080, 0.07128, 0.06396, 0.05517, 0.07031,    0.07275, 0.07080, 0.06933, 0.07763, 0.06298,
    0.06494, 0.07226, 0.07958, 0.07421, 0.06396,    0.08496, 0.07177, 0.06054, 0.06982, 0.07324,
    0.06591, 0.07373, 0.06445, 0.07275, 0.06347,    0.06835, 0.06640, 0.06738, 0.07275, 0.07714,
    0.06054, 0.05566, 0.06396, 0.06640, 0.05126,    0.06738, 0.05029, 0.06494, 0.07080, 0.06445,
    0.05810, 0.05273, 0.05664, 0.06787, 0.06054,    0.07617, 0.06494, 0.05078, 0.06689, 0.08056,
    0.06103, 0.06738, 0.07031, 0.06445, 0.05419,    0.05566, 0.06787, 0.06005, 0.07958, 0.06347,
    0.05273, 0.06250, 0.06054, 0.07714, 0.06201,
])

labels_left.append("w/.adb dlw")
all_rewards_left.append([
    0.04150, 0.04785, 0.04882, 0.05615, 0.05224,    0.06738, 0.05517, 0.06835, 0.06054, 0.07080,
    0.06494, 0.04931, 0.06347, 0.08644, 0.06933,    0.06250, 0.07128, 0.06494, 0.07519, 0.06445,
    0.06298, 0.05566, 0.06640, 0.08349, 0.08837,    0.06933, 0.07714, 0.07324, 0.08251, 0.07958,
    0.06933, 0.07519, 0.07177, 0.06591, 0.06835,    0.08641, 0.09130, 0.08251, 0.09130, 0.07861,
    0.08789, 0.08105, 0.08056, 0.07812, 0.07128,    0.07812, 0.08837, 0.07958, 0.09179, 0.07324,
    0.08984, 0.08789, 0.06738, 0.09179, 0.07714,    0.08349, 0.08007, 0.09130, 0.07812, 0.08740,
    0.09228, 0.08544, 0.07275, 0.07424, 0.08349,    0.08789, 0.07421, 0.08740, 0.10640, 0.09765,
    0.08544, 0.10830, 0.08251, 0.08691, 0.09619,    0.07861, 0.09179, 0.07861, 0.08593, 0.09667,
    0.08154, 0.07666, 0.08886, 0.08251, 0.07519,    0.07031, 0.09716, 0.10050, 0.08691, 0.09716,
    0.09033, 0.08447, 0.08544, 0.08398, 0.07666,
])

# --- Right panel data: reward-model accuracy (drawn on ax2, whose y label is
# 'RM Accuracy'). ---
# Same layout as the left panel: labels_right[i] names all_rewards_right[i].
# The variants are appended in the same order as on the left, so colors[i]
# marks the same variant in both panels. Every series below has 95 points.
labels_right = []
all_rewards_right = [] 

labels_right.append("w/o. update")
all_rewards_right.append([
    0.6078, 0.6956, 0.5078, 0.5052, 0.5000,     0.5338, 0.5026, 0.5416, 0.4895, 0.5416, 
    0.5052, 0.5000, 0.5494, 0.5885, 0.5052,     0.5260, 0.5312, 0.4739, 0.4635, 0.5208,
    0.5052, 0.5208, 0.4427, 0.4868, 0.4791,     0.5260, 0.4895, 0.4973, 0.5000, 0.5390,
    0.5000, 0.5104, 0.4947, 0.5364, 0.4843,     0.5000, 0.5052, 0.5208, 0.5208, 0.5208,
    0.4843, 0.4791, 0.4843, 0.5104, 0.4895,     0.5416, 0.5364, 0.5260, 0.5052, 0.5026,
    0.4947, 0.5156, 0.5390, 0.4947, 0.4947,     0.5000, 0.5442, 0.5026, 0.5052, 0.5156,
    0.5208, 0.5260, 0.5312, 0.5104, 0.4921,     0.4895, 0.5364, 0.4869, 0.5260, 0.5104,
    0.4739, 0.5104, 0.4557, 0.5078, 0.5156,     0.4817, 0.4921, 0.4843, 0.4375, 0.5312,
    0.4531, 0.4739, 0.5000, 0.4895, 0.4661,     0.4843, 0.5546, 0.5234, 0.4661, 0.4843,
    0.4661, 0.5000, 0.4609, 0.5260, 0.4869,
])

labels_right.append("w/o. adb w/o. dlw")
all_rewards_right.append([
    0.6080, 0.6822, 0.5494, 0.6236, 0.6770,     0.6223, 0.5312, 0.6015, 0.5820, 0.5833,
    0.5234, 0.4544, 0.5911, 0.5481, 0.5364,     0.5559, 0.5000, 0.5677, 0.5208, 0.5481,
    0.5520, 0.4466, 0.5729, 0.4544, 0.5130,     0.5169, 0.5520, 0.5000, 0.5651, 0.5716,
    0.4986, 0.4882, 0.5195, 0.4986, 0.4791,     0.5130, 0.4361, 0.5026, 0.3971, 0.4257,
    0.5625, 0.4687, 0.4934, 0.4335, 0.5065,     0.4895, 0.5299, 0.4713, 0.5221, 0.5195,
    0.4557, 0.4583, 0.5117, 0.3997, 0.4635,     0.5924, 0.4596, 0.5130, 0.5182, 0.5000,
    0.4348, 0.4943, 0.5208, 0.4687, 0.4444,     0.4791, 0.4492, 0.4934, 0.4453, 0.5338,
    0.4843, 0.4882, 0.4804, 0.5143, 0.5065,     0.5390, 0.5143, 0.5026, 0.5000, 0.4921,
    0.5611, 0.5013, 0.4830, 0.5156, 0.5091,     0.4674, 0.4986, 0.5312, 0.4622, 0.5130,
    0.5117, 0.5377, 0.5143, 0.4648, 0.5143,
])

labels_right.append("w/o. adb")
all_rewards_right.append([
    0.6078, 0.6778, 0.5000, 0.5364, 0.5000,     0.5130, 0.5156, 0.5520, 0.5260, 0.5260,
    0.5000, 0.4895, 0.5416, 0.5130, 0.4765,     0.4869, 0.5520, 0.5052, 0.5234, 0.5234,
    0.4244, 0.5234, 0.4843, 0.5104, 0.5364,     0.5052, 0.5078, 0.5703, 0.5312, 0.5026,
    0.5078, 0.4973, 0.5000, 0.4973, 0.5625,     0.4713, 0.4739, 0.5520, 0.5416, 0.4869,
    0.4713, 0.5156, 0.5546, 0.5000, 0.5234,     0.5260, 0.5000, 0.5364, 0.5520, 0.5208,
    0.5312, 0.5078, 0.4557, 0.4739, 0.5520,     0.5000, 0.5208, 0.5546, 0.5260, 0.5703,
    0.5156, 0.5000, 0.5104, 0.4791, 0.5312,     0.5260, 0.5078, 0.5364, 0.5546, 0.5208,
    0.5208, 0.5104, 0.5208, 0.4713, 0.5208,     0.4479, 0.5078, 0.5026, 0.5546, 0.5312,
    0.5182, 0.5625, 0.5572, 0.5260, 0.4921,     0.5156, 0.5260, 0.5520, 0.5338, 0.5390,
    0.5625, 0.4739, 0.4947, 0.5364, 0.5312,
])

labels_right.append("w/o. dlw")
all_rewards_right.append([
    0.6171, 0.6901, 0.5468, 0.6614, 0.5104,     0.6171, 0.5677, 0.5286, 0.6250, 0.5104,
    0.5572, 0.6901, 0.5625, 0.5390, 0.5364,     0.5286, 0.4583, 0.4427, 0.5885, 0.5338,
    0.5546, 0.5286, 0.5520, 0.5182, 0.6510,     0.4036, 0.6119, 0.5260, 0.4661, 0.4036,
    0.6250, 0.5468, 0.6197, 0.3854, 0.4427,     0.5000, 0.6250, 0.5755, 0.5755, 0.6328,
    0.6015, 0.4140, 0.4739, 0.6250, 0.4869,     0.6041, 0.5468, 0.5182, 0.5364, 0.5468,
    0.5677, 0.6250, 0.5833, 0.6250, 0.6145,     0.5937, 0.4843, 0.5000, 0.6458, 0.5703,
    0.5026, 0.5468, 0.6015, 0.4817, 0.5546,     0.5390, 0.4973, 0.6380, 0.5625, 0.5781,
    0.5182, 0.6250, 0.5234, 0.5026, 0.5546,     0.5416, 0.6015, 0.5546, 0.4739, 0.5104,
    0.6302, 0.6223, 0.5651, 0.4687, 0.3958,     0.5026, 0.5807, 0.4895, 0.6171, 0.7005,
    0.6718, 0.5260, 0.5520, 0.6640, 0.5833 
])

labels_right.append("w/.adb dlw")
all_rewards_right.append([
    0.6171, 0.7031, 0.5859, 0.6484, 0.6614,     0.6380, 0.6484, 0.4531, 0.6562, 0.6458,
    0.4791, 0.6093, 0.5781, 0.6510, 0.5911,     0.4921, 0.6302, 0.4869, 0.5390, 0.5468,
    0.5859, 0.6119, 0.5755, 0.5937, 0.6093,     0.6015, 0.6770, 0.5286, 0.5937, 0.5000,
    0.5000, 0.6223, 0.6510, 0.5859, 0.4843,     0.5494, 0.5755, 0.4140, 0.4739, 0.5260,
    0.6276, 0.5598, 0.6744, 0.5260, 0.6223,     0.6640, 0.5963, 0.5338, 0.5468, 0.6041,
    0.5625, 0.4921, 0.6484, 0.6223, 0.6562,     0.5156, 0.5859, 0.6093, 0.5937, 0.5234,
    0.5807, 0.5156, 0.5859, 0.6302, 0.5963,     0.6380, 0.6302, 0.6067, 0.5625, 0.6093,
    0.5338, 0.5520, 0.4947, 0.5755, 0.5052,     0.5442, 0.5494, 0.6380, 0.6093, 0.5911,
    0.6536, 0.6119, 0.5546, 0.6093, 0.6885,     0.6328, 0.6197, 0.6223, 0.5546, 0.5546,
    # NOTE(review): three further values after the line below (0.6510, 0.6692,
    # 0.4531) are left commented out; dropping them keeps this series at 95
    # points like every other series. Confirm that this truncation is intended.
    0.4270, 0.6171, 0.5546, 0.6250, 0.5104,
])

# ==========================================
# 2. Plot setup
# ==========================================

# One fixed palette shared by both panels: series i is always drawn with
# colors[i], so a given variant keeps the same color on the left and right.
colors = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
]

# Number of points in each trailing rolling window used for smoothing.
window_size = 3

# A 1 x 2 grid of subplots: ax1 is the left panel, ax2 the right panel.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))

# --- Shared plotting helpers (used for both panels) ---
def _rolling_mean_std(raw_data, window):
    """Return the trailing rolling mean and standard deviation of *raw_data*.

    Both statistics use ``min_periods=1``, so every position gets a value even
    before the window is full. The sample standard deviation (pandas' default
    ``ddof=1``) is undefined for a single observation and comes back as NaN.
    Those NaNs are replaced with 0.0 so the shaded band starts at the first
    point instead of leaving a gap there.

    Args:
        raw_data: Sequence of numeric values (may be empty).
        window: Rolling window length, in points.

    Returns:
        Tuple ``(mean, std)`` of float pandas Series aligned with *raw_data*.
    """
    # dtype=float keeps an empty input numeric, so rolling() still works.
    series = pd.Series(raw_data, dtype=float)
    rolling = series.rolling(window=window, min_periods=1)
    return rolling.mean(), rolling.std().fillna(0.0)


def plot_curve(ax, rewards_list, labels_list, title_text="", *, window=None, palette=None):
    """Draw one smoothed curve with a +/-1 rolling-std band per series on *ax*.

    Args:
        ax: Target axes. Only ``fill_between``, ``plot``, ``set_title``,
            ``set_xlabel``, ``set_ylabel``, ``tick_params`` and ``grid`` are
            called on it.
        rewards_list: Sequence of numeric sequences, one per curve. Each curve
            is plotted against its own index ``0..len-1``, so lengths may differ.
        labels_list: Legend labels indexed in parallel with *rewards_list*
            (``labels_list[i]`` labels ``rewards_list[i]``).
        title_text: Optional subplot title; nothing is set when it is empty.
        window: Rolling window length. Defaults to the module-level
            ``window_size``.
        palette: Colors assigned by curve index and cycled when there are more
            curves than colors. Defaults to the module-level ``colors``.
    """
    if window is None:
        window = window_size
    if palette is None:
        palette = colors

    for i, raw_data in enumerate(rewards_list):
        # Per-curve x axis, so curves of different lengths can share one axes.
        current_x = np.arange(len(raw_data))
        smooth_mean, smooth_std = _rolling_mean_std(raw_data, window)

        # Reuse colors cyclically if there are more curves than palette entries.
        color = palette[i % len(palette)]

        # Shaded band. It has no label, so it stays out of the legend.
        ax.fill_between(current_x,
                        smooth_mean - smooth_std,
                        smooth_mean + smooth_std,
                        color=color,
                        alpha=0.05)

        # Solid mean line. This is the artist the legend picks up.
        ax.plot(current_x,
                smooth_mean,
                color=color,
                linewidth=2,
                label=labels_list[i])

    # Subplot cosmetics. The generic axis labels may be overwritten by the caller.
    if title_text:
        ax.set_title(title_text, fontsize=14)
    ax.set_xlabel('Steps', fontsize=20)
    ax.set_ylabel('Reward', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.grid(True, linestyle='--', alpha=0.1)

# ==========================================
# 3. Render both panels and save
# ==========================================

# Left panel: outcome training rewards.
plot_curve(ax1, all_rewards_left, labels_left)
ax1.set_ylabel('Outcome Training Rewards', fontsize=20)

# Right panel: reward-model accuracy.
plot_curve(ax2, all_rewards_right, labels_right)
ax2.set_ylabel('RM Accuracy', fontsize=20)

# X-axis ticks: one major tick every TICK_EVERY points. Each tick is relabelled
# as x / TICK_EVERY, so one tick interval reads as one unit of 16384 training
# samples.
# NOTE(review): this implies 16384 / TICK_EVERY = 2048 samples per data point.
# Confirm against the training config.
TICK_EVERY = 8
# Derive the tick range from the data instead of hard-coding it, so the ticks
# track the longest series (0, 8, ..., 88 for the current 95-point series).
n_points = max((len(series) for series in all_rewards_left + all_rewards_right), default=0)
tick_positions = np.arange(0, n_points, TICK_EVERY)
# x is the raw tick position; pos is the tick index, which is unused here.
formatter = ticker.FuncFormatter(lambda x, pos: f'{x / TICK_EVERY:g}')
for panel in (ax1, ax2):
    panel.xaxis.set_major_locator(ticker.FixedLocator(tick_positions))
    panel.xaxis.set_major_formatter(formatter)
    panel.set_xlabel('Training Samples (x16384)', fontsize=20)

# --- Shared legend ---
# Both panels use the same labels and colors in the same order, so the left
# panel's handles serve the whole figure. The legend is centred above the panels.
handles, labels = ax1.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05),
           ncol=max(1, len(labels)), fontsize=20)

plt.tight_layout()

# matplotlib does not expand "~" in file names, so expand it explicitly.
# Also create the output directory if it does not exist yet.
output_path = os.path.expanduser("~/verl_cs/fig/dynamic_rm_comparison.pdf")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# bbox_inches='tight' grows the saved area to include the legend above the axes.
plt.savefig(output_path, bbox_inches='tight')

