import os

import matplotlib.pyplot as plt
import numpy as np

# Configure academic-style plotting defaults: serif font, thicker axes frame.
plt.rcParams.update({
    'font.family': 'serif',
    'axes.linewidth': 1.2,
})

# Negative log-sigmoid loss used for the curves in the left panel.
def log_sigmoid_loss(s, v, gamma=1.0):
    """Return ``-log(sigmoid(s + v - gamma))`` element-wise.

    Mirrors the formula from the reference figure:
    ``-log(sigma(s + V(x) - gamma))``.

    Args:
        s: Log-ratio term; scalar or NumPy array.
        v: Bias term ``V(x)`` added to ``s``; scalar or array broadcastable
            against ``s``.
        gamma: Margin subtracted from ``s + v``.

    Returns:
        The loss value(s), with the broadcast shape of the inputs.
    """
    margin = s + v - gamma
    # -log(sigmoid(z)) == log(1 + exp(-z)) == logaddexp(0, -z).  The naive
    # form overflows in exp() for very negative z (yielding inf) and rounds
    # to exactly zero for large positive z; logaddexp avoids both.
    return np.logaddexp(0.0, -margin)

# Create the canvas: one row, two side-by-side panels.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))

# --- Left panel: ADB (Adaptive Bias) mechanism ---
gamma = 1.0
s_range = np.linspace(-10, 10, 200)

# One loss curve per V(x) value: (V(x), legend label, per-curve line styling).
_adb_curves = (
    (-3, 'Hard Task ($V(x)=-3$)', {'color': '#ff7f0e'}),
    (0, 'Neutral ($V(x)=0$)', {'color': '#1f77b4', 'linestyle': '--'}),
    (3, 'Easy Task ($V(x)=3$)', {'color': '#2ca02c'}),
)
for _v, _label, _line_style in _adb_curves:
    ax1.plot(
        s_range,
        log_sigmoid_loss(s_range, _v, gamma),
        label=_label,
        lw=2.5,
        **_line_style,
    )

# Auxiliary annotation: a double-headed arrow between the text and the target.
_arrow_props = dict(arrowstyle='<->', color='blue', lw=1.5)
ax1.annotate('Boundary Shift', xy=(2.5, 4), xytext=(-5, 2), arrowprops=_arrow_props)

# Titles, axis labels, grid and legend.
ax1.set_title('(A) ADB: Adaptive Decision Boundary\nShifting Loss Curves based on $V(x)$', fontweight='bold')
ax1.set(xlabel='Log-ratio term ($s$)', ylabel='Loss Value')
ax1.grid(True, linestyle=':', alpha=0.6)
ax1.legend()

# --- Right panel: DLW (Dynamic Loss Weighting) mechanism ---
# Simulated scenario: 1 correct response (check mark), 3 incorrect (cross).
# Raw strings keep the mathtext backslashes literal; the previous
# '\checkmark' relied on '\c' being an invalid escape sequence, which newer
# Python versions warn about.  The rendered text is unchanged.
labels = [
    r'Resp 1 ($\checkmark$)',
    r'Resp 2 ($\times$)',
    r'Resp 3 ($\times$)',
    r'Resp 4 ($\times$)',
]
standard_weights = [1.0, 1.0, 1.0, 1.0]
# Values taken from the reference figure (the figure number is missing in
# the original note).
dlw_weights = [2.00, 0.67, 0.67, 0.67]

x = np.arange(len(labels))
width = 0.35

# Grouped bars: standard weights on the left of each tick, DLW on the right.
ax2.bar(x - width/2, standard_weights, width, label='Standard (Equal Weight)', color='lightgray', edgecolor='gray')
ax2.bar(x + width/2, dlw_weights, width, label='DLW (Ours)', color=['#2ca02c', '#d62728', '#d62728', '#d62728'])

# Annotate the amplified (correct) and suppressed (incorrect) weights.
ax2.text(0 + width/2, 2.1, 'Amplify Scarce Signal\n($w_i=2.0$)', ha='center', color='#2ca02c', fontsize=10, fontweight='bold')
ax2.text(2 + width/2, 0.8, 'Suppress Frequent Class\n($w_i=0.67$)', ha='center', color='#d62728', fontsize=10)

ax2.set_title('(B) DLW: Dynamic Loss Weighting\nRescuing Weak Models from Sparse Rewards', fontweight='bold')
ax2.set_ylabel('Loss Weight ($w_i$)')
ax2.set_xticks(x)
ax2.set_xticklabels(labels)
# Headroom above the tallest bar (2.0) so the annotation text fits.
ax2.set_ylim(0, 2.8)
ax2.legend()
ax2.grid(axis='y', linestyle=':', alpha=0.6)

plt.tight_layout()

# A leading '~' is only expanded by a shell, not when the path is opened from
# Python, so expand it explicitly; otherwise the figure would be written
# under a literal './~/' directory (or fail if that does not exist).
output_path = os.path.expanduser("~/verl_cs/fig/dynamic_beta1.pdf")
# Make sure the target directory exists before writing the PDF.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches='tight')
