import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patheffects as path_effects
import matplotlib.patches as mpatches

# 1. 创建示例数据
data = {
    'method': ['SiT-L/2', 'DDT-L/2','DDT-L/2','DDT-L/2', 'DUPA-L/2', 'DUPA-L/2','DUPA-L/2'],
    'iter': ["400K", "400K", "400K", "800K", "400K","400K","400K"],
    'name': ["SiT", "DDT", "DDT", "DDT", "DUPA","DUPA","DUPA"],
    'py': [0.015, 0.02, 0.02, 0.02, 0.025,0.025,0.025],
    'batch_size': [256, 256, 512, 256, 256,256,256],
    'sample_nums': [1, 1, 1, 1, 2,3,4],
    'training_speed': [0.97, 1.01, 1.87, 1.01, 1.12,1.26,1.43],
    'FID': [18.8, 15.2, 13.9, 13.4, 11.1,10.8,10.7],
    'gpu_usage': [23017, 23842,36348,23842,28530,33265,39072] # e.g., in GiB
}
df = pd.DataFrame(data)

# 配色方案
colors = ['#69b3a2', '#c1e5c0', '#d3b7e3', '#f4a5c3', '#fef0c2', "#6b6677", "#448F91"]
colors = ['#9ED2C6', '#FFCBCB', '#A3D8F4', '#FFD3B6', '#D9B8FF', '#B5EAD7', '#FF9AA2']
colors = ['#FF6B6B', '#A4FFB9', '#6BFF6B', '#47D974', '#E6A4FF', '#D96BFF', '#B84DD9']
df['color'] = colors[:len(df)]


# 2. 开始绘图
fig, ax = plt.subplots(figsize=(13,6))

# 绘制气泡图
scatter = ax.scatter(
    df['training_speed'],
    df['FID'],
    s=(df['gpu_usage']/2000)**3 * 1.5,
    c=df['color'],
    alpha=0.85,
    edgecolors=df['color'],
    linewidth=0,
    zorder=5
)

df = pd.DataFrame(data)
for i, row in df.iterrows():
    ax.text(
        row['training_speed']-row['py'], # x 轴方向的偏移量
        row['FID'],            # y 轴方向的偏移量
        row['name'], 
        fontsize=15,
        fontweight='bold',
        verticalalignment='center', # 垂直居中对齐
        color="white",
        zorder=10
        # fontdict = {'fontfamily': 'Comic', 'fontstyle': 'italic', 'fontweight': 'bold', 'color': 'white'}
    )

# 3. 设置坐标轴和标题
ax.set_xlabel('Training Speed (seconds/step)', fontsize=20)
ax.set_ylabel('FID-50K', fontsize=20)
# ax.set_title('Performance vs. Speed Analysis', fontsize=16, pad=20)
ax.grid(False)
from matplotlib.patches import FancyArrowPatch
curve = FancyArrowPatch(
    (1.06, 15.8),
    (1.8, 15),
    connectionstyle="arc3,rad=-0.2", # <--- 关键：将直线变为弧线
    arrowstyle="->",                  # <--- 关键：去掉箭头                 # <--- 增加单向箭头
    mutation_scale=30,
    color="#6BFF6B",               # 使用更柔和的颜色
    linewidth=4,                     # 线条可以稍细一些
    # linestyle='-'                   # 使用虚线增加设计感
)
plt.gca().add_patch(curve)
ax.text(
    1.35,
    18,
    'Batch Size *= 2',
    fontsize=20,
    color="#6BFF6B",
    fontweight='bold',
    verticalalignment='center',
    zorder=4
)

curve = FancyArrowPatch(
    (1.05, 15.2),
    (1.05, 13.5),
    connectionstyle="arc3,rad=-0", # <--- 关键：将直线变为弧线
    arrowstyle="->",                  # <--- 关键：去掉箭头                 # <--- 增加单向箭头
    mutation_scale=30,
    color="#6BFF6B",               # 使用更柔和的颜色
    linewidth=4,                     # 线条可以稍细一些
    # linestyle='-'                   # 使用虚线增加设计感
)
plt.gca().add_patch(curve)
ax.text(
    1.07,
    14.5,
    'Iteration *= 2',
    fontsize=20,
    color="#6BFF6B",
    fontweight='bold',
    verticalalignment='center',
    zorder=4
)

curve = FancyArrowPatch(
    (1.12, 12.4),
    (1.25, 12.2),
    connectionstyle="arc3,rad=-0.2", # <--- 关键：将直线变为弧线
    arrowstyle="->",                  # <--- 关键：去掉箭头                 # <--- 增加单向箭头
    mutation_scale=30,
    color="#D96BFF",               # 使用更柔和的颜色
    linewidth=4,                     # 线条可以稍细一些
    # linestyle='-'                   # 使用虚线增加设计感
)
plt.gca().add_patch(curve)
ax.text(
    1.15,
    13.1,
    'Noising Time += 1',
    fontsize=20,
    color="#D96BFF",
    fontweight='bold',
    verticalalignment='center',
    zorder=4
)

curve = FancyArrowPatch(
    (1.27, 12.2),
    (1.41, 12.2),
    connectionstyle="arc3,rad=-0.2", # <--- 关键：将直线变为弧线
    arrowstyle="->",                  # <--- 关键：去掉箭头                 # <--- 增加单向箭头
    mutation_scale=30,
    color="#D96BFF",               # 使用更柔和的颜色
    linewidth=4,                     # 线条可以稍细一些
    # linestyle='-'                   # 使用虚线增加设计感
)
plt.gca().add_patch(curve)

ax.set_ylim(9, 19.7)
ax.set_xlim(0.93, 1.94)
for spine in ax.spines.values():
    spine.set_alpha(1)
    spine.set_color("black")      # 设置透明度为1 (完全不透明)
    spine.set_linewidth(2) # 你可以调整这个数值来改变粗细
plt.tight_layout(rect=[0, 0, 1, 0.96])
# plt.show()
plt.savefig('paopao.svg', format='svg', bbox_inches='tight')