import matplotlib.pyplot as plt

# Data: one entry per method, index-aligned across the three lists.
methods = ['DPO', 'DPKD', 'DCKD', 'Ours']
kls = [3.01255, 2.39395, 2.3597, 2.9537]  # x-axis
margins = [22.56, 9.0452, 9.05, 14.7967]  # y-axis

# Label placement as (dx, dy, ha, va). dx/dy are offsets from the marker in
# points (not data units), so the gap between marker and label stays the
# same regardless of the axis scale.
default_placement = (6, 0, 'left', 'center')
# DCKD and DPKD have nearly identical coordinates; putting the DCKD label
# above its marker keeps it from running into the DPKD marker and label.
placement_overrides = {
    'DCKD': (0, 6, 'center', 'bottom'),
}

# Plot
plt.figure(figsize=(8, 6))
plt.scatter(kls, margins, color='blue')
# Widen the horizontal padding so labels drawn to the right of the
# right-most marker stay inside the axes frame.
plt.margins(x=0.1)

# Attach a text label to every point.
for method, x, y in zip(methods, kls, margins):
    dx, dy, ha, va = placement_overrides.get(method, default_placement)
    plt.annotate(
        method,
        xy=(x, y),
        xytext=(dx, dy),
        textcoords='offset points',
        fontsize=10,
        ha=ha,
        va=va,
    )

# Axis labels and title
plt.xlabel('Mean KL Divergence (lower is better)')
plt.ylabel('Margin (higher is better)')
plt.title('Trade-off Between Margin and KL Divergence')
plt.tight_layout()

# Save to disk, then display interactively.
plt.savefig("margin_vs_kl_scatter.png", dpi=300)
plt.show()