import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

# Global look-and-feel: seaborn white-grid style plus a sans-serif font stack
# (Arial preferred, DejaVu Sans as the fallback).
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
})

# Two vertically stacked panels; the top panel is 1.2x the height of the
# bottom one.
fig, (ax1, ax2) = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(10, 12),
    gridspec_kw={'height_ratios': [1.2, 1]},
)
fig.subplots_adjust(hspace=0.3)

# ==========================================
# Figure 1: scatter plot (Jason Wei's "Verifier's Rule")
# ==========================================

# Simulated data points. A fixed seed makes the point cloud identical on every
# run, so the hand-tuned label positions below only need to be checked once
# against a stable layout instead of changing with each execution.
rng = np.random.default_rng(42)

# Easy-to-verify tasks (green) - lower-right corner.
x_easy = rng.normal(7, 1, 15)
y_easy = rng.normal(2, 1, 15)

# Hard-to-verify tasks (orange) - upper-left corner (the target region).
x_hard = rng.normal(2.5, 0.8, 8)
y_hard = rng.normal(7.5, 0.8, 8)

# Intermediate tasks (blue) - centre of the plot.
x_mid = rng.normal(5, 1, 10)
y_mid = rng.normal(4, 1, 10)

# Draw the three clusters. The `label=` values only appear if ax1.legend()
# is called; this script as shown never calls it.
ax1.scatter(x_easy, y_easy, color='#4CAF50', s=100, alpha=0.8, label='Verifiable')  # green
ax1.scatter(x_mid, y_mid, color='#5D99C6', s=80, alpha=0.7)  # blue
ax1.scatter(x_hard, y_hard, color='#FF9800', s=120, alpha=0.9, label='Hard to Verify')  # orange

# Task-name labels as (x, y, text) in data coordinates; positions were
# adjusted by hand to avoid overlapping each other.
labels = [
    (2.2, 8.2, "Factual Essay"),
    (2.3, 7.6, "Creative Writing"),
    (2.1, 7.0, "Open-ended Agent"),
    (7.5, 2.5, "Math"),
    (7.2, 1.8, "Code"),
    (7.8, 1.2, "Short Q&A")
]

for x, y, text in labels:
    # Small rightward offset so the text starts beside the anchor point.
    ax1.text(x + 0.2, y, text, fontsize=18, fontweight='bold', color='#333333')

# Red dashed ellipse highlighting the focus region, centred on the
# hard-to-verify cluster (x spans 0.25-4.75, y spans 5.75-9.25).
ellipse = patches.Ellipse((2.5, 7.5), width=4.5, height=3.5,
                          edgecolor='#C0392B', facecolor='none',
                          linestyle='--', linewidth=3)
ax1.add_patch(ellipse)

# Red arrow + caption "My Target / Our Intersection". The arrow tip (x=4.8)
# sits just outside the ellipse's right edge (2.5 + 4.5/2 = 4.75).
ax1.annotate('My Target /\nOur Intersection',
             xy=(4.8, 7.5), xycoords='data',
             xytext=(6.5, 8.5), textcoords='data',
             arrowprops=dict(facecolor='#C0392B', shrink=0.05, edgecolor='#C0392B'),
             fontsize=24, fontweight='bold', color='#C0392B', ha='left')

# Axis labels and title.
ax1.set_title('Jason Wei’s “Verifier’s Rule”\nEasy to verify vs. Hard to verify', fontsize=27, fontweight='bold', pad=20)
ax1.set_xlabel('Task complexity', fontsize=21)
ax1.set_ylabel('Difficulty of verification', fontsize=21)

# Hide tick values to keep the axes abstract/conceptual; fix a 0-10 canvas so
# the hard-coded label and ellipse coordinates above line up.
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)

# ==========================================
# Figure 2: bar chart (Verification Success Rate)
# ==========================================

# One (category, success rate, bar colour) entry per bar: green for the two
# verifiable domains, dark orange for the two hard-to-verify ones.
bar_specs = [
    ('Math', 0.92, '#008000'),
    ('Code', 0.89, '#008000'),
    ('Factual Essay', 0.12, '#FF8C00'),
    ('Agent Plan', 0.15, '#FF8C00'),
]
categories = [name for name, _, _ in bar_specs]
values = [rate for _, rate, _ in bar_specs]
colors = [shade for _, _, shade in bar_specs]

bars = ax2.bar(categories, values, color=colors, width=0.6)

# Rates are fractions, so the y-axis runs from 0 to 1.
ax2.set_ylim(0, 1.0)

ax2.set_title('Verification Success Rate', fontsize=27, fontweight='bold', pad=15)
ax2.tick_params(axis='x', labelsize=14)
ax2.tick_params(axis='y', labelsize=12)

# Value label centred above each bar, lifted 0.02 data units off its top.
for rect in bars:
    bar_top = rect.get_height()
    centre_x = rect.get_x() + rect.get_width() / 2.
    ax2.text(centre_x, bar_top + 0.02, f'{bar_top:.2f}',
             ha='center', va='bottom', fontsize=18, fontweight='bold')

# Take-away caption centred below both panels (figure coordinates).
fig.text(0.5, 0.02,
         "Current RLHF success is concentrated\nin verifiable domains.",
         ha='center', fontsize=24, fontweight='normal', color='#333333')

# Remove the top/right spines and thicken the remaining ones for a cleaner look.
for ax in [ax1, ax2]:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

# Reserve the bottom 5% of the figure for the caption drawn at y=0.02.
# NOTE: tight_layout recomputes subplot spacing, so it supersedes any earlier
# fig.subplots_adjust(hspace=...) setting.
plt.tight_layout(rect=[0, 0.05, 1, 1])

# Save BEFORE showing: when a blocking plt.show() returns, the figure has been
# closed, and a subsequent plt.savefig would write a new, empty figure.
fig.savefig('research_vision_slide.png')
plt.show()