import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# 设置 seaborn 风格以获得更美观的图形
sns.set_style("whitegrid", {'grid.linestyle': '--'})

# 创建一个 figure 和 axes 对象
fig, ax = plt.subplots(figsize=(12, 9))

# 数据
# 我们需要估算每个点的位置 (Training Epochs, FID50K) 和大小
# 注意：点的大小 (s) 需要手动调整以匹配图像效果
data = {
    'Model': ['DUPA-XL/2', 'DUPA-XL/2', 'MDTv2', 'SiT-XL/2', 'DiT-XL/2','DDT-XL/2'],
    'Training Epochs': [80, 400,4600//5 , 1400, 1400,400],
    'FID50K': [1.95, 1.48, 1.58, 2.06, 2.27,1.87],
    'Size': [675, 675, 675, 675, 675,675],
    'Color': ['#1a9898', '#1a9898', '#e0c8f7', '#ff99c8', '#fcf6c7','#A3D8F4']
}

# 额外的数据点（DDT-L/2 和 DDT-XL/2 旁边的小点）
extra_points = {
    'Training Epochs': [80+64, 80+128, 80+64+128,80+256],
    'FID50K': [1.86, 1.77,1.67, 1.58],
    'Size': [100, 100, 100,100],
    'Color': ['#1a9898', '#1a9898', '#1a9898', '#1a9898']
}


df = pd.DataFrame(data)
df_extra = pd.DataFrame(extra_points)

repa_outer_size = 8000 # 估算的大小，比内部圆稍大
repa_inner_size = 450# 内部绿色圆圈的大小，与原来的 REPA-XL/2 大小相近
repa_epochs = 800
repa_fid = 1.42
#, '#c9f6d7'
ax.scatter(
    repa_epochs,
    repa_fid,
    s=repa_outer_size,
    c='gray', # 外层环形改为灰色
    alpha=0.8,
    edgecolors='white',
    linewidth=2,
    zorder=2 # 确保它在其他元素之下，但内部圆之上
)
# 然后绘制内部的绿色圆圈
ax.scatter(
    repa_epochs,
    repa_fid,
    s=repa_inner_size*4, # 内部圆圈大小
    c='#c9f6d7', # 保持原来的浅绿色
    alpha=0.8,
    # edgecolors='white',
    linewidth=2,
    zorder=3 # 确保它在灰色圆圈之上
)

# 绘制主要的散点图（气泡图）
scatter = ax.scatter(
    df['Training Epochs'],
    df['FID50K'],
    s=df['Size']*4,
    c=df['Color'],
    alpha=0.8,
    edgecolors='white', # 给气泡添加白色描边
    linewidth=2
)

# 绘制额外的小点
ax.scatter(
    df_extra['Training Epochs'],
    df_extra['FID50K'],
    s=df_extra['Size'],
    c=df_extra['Color'],
    alpha=0.8,
    edgecolors='white',
    linewidth=1.5,
    marker='D' # 使用菱形标记
)


# 添加每个数据点的标签
for i, row in df.iterrows():
    ax.text(
        row['Training Epochs'] + 65, # x 轴方向的偏移量
        row['FID50K']+0.03,            # y 轴方向的偏移量
        row['Model'],
        fontsize=20,
        fontweight='regular',
        verticalalignment='center' # 垂直居中对齐
    )

ax.text(
    repa_epochs +120,
    repa_fid,
    'REPA-XL/2',
    fontsize=20,
    fontweight='regular',
    verticalalignment='center',
    zorder=4
)

# 设置坐标轴标签和标题
ax.set_xlabel('Training Epochs', fontsize=22, fontweight='bold')
ax.set_ylabel('FID-50K on ImageNet256x256', fontsize=22, fontweight='bold')

# 设置坐标轴的范围
ax.set_xlim(0, 1750)
ax.set_ylim(1.3, 2.35)

# 设置坐标轴刻度的字体大小
ax.tick_params(axis='both', which='major', labelsize=18)

# 添加箭头和旁边的文字
# ax.text(
#     '3x Training Acc',
#     # xy=(340, 1.58), # 箭头指向的位置
#     xytext=(400, 1.58+0.05), # 文字的起始位置
#     fontsize=15,
#     fontweight='bold',
#     verticalalignment='center',
#     color='#004f4f'
# )
ax.text(
    s='3x Training Acc',
    x=410,
    y=1.58+0.02,
    fontsize=20,
    color='#004f4f'
)
from matplotlib.patches import FancyArrowPatch
arrow = FancyArrowPatch((355, 1.58),(870, 1.58), arrowstyle="<->", mutation_scale=20, fc="k", ec="k",linewidth=2)
plt.gca().add_patch(arrow)
# ax.grid(alpha=0.6)
for spine in ax.spines.values():
    spine.set_alpha(1)
    spine.set_color("black")      # 设置透明度为1 (完全不透明)
    spine.set_linewidth(2) # 你可以调整这个数值来改变粗细
# 调整网格线的透明度
# ax.grid(alpha=0.6)

# 显示图形
plt.savefig('head.svg', format='svg', bbox_inches='tight')