import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
# Plot script: one marker per method, positioned by the `kl_div` (x) and
# `margin` (y) values and coloured by the `mt_bench` value, with a dashed
# line joining the points in order of increasing x. The figure is written
# to 'kl_plt.png'.

# Per-method results; the four lists are index-aligned with `methods`.
methods = ['DCKD', 'DPO', 'DPKD', 'Ours']
kl_div = [2.39, 3.01, 2.35, 2.95]      # x-axis values
margin = [9.05, 22.56, 9.05, 14.80]    # y-axis values
mt_bench = [3.43, 3.73, 3.55, 3.97]    # mapped to marker colour

# Font size shared by point labels, axis labels and the colorbar label.
FONT_SIZE = 18

# Hand-tuned label placement per method:
# (x offset, y offset, horizontal alignment), in data coordinates relative
# to the marker. A method without an entry here is drawn without a label.
LABEL_OFFSETS = {
    'DCKD': (0.12, -0.4, 'right'),
    'DPKD': (0.0, 0.7, 'center'),
    'DPO': (-0.1, 0.0, 'left'),
    'Ours': (-0.1, 0.2, 'left'),
}

# Colour mapping: the smallest score maps to the low end of the colormap
# and the largest score to the high end.
norm = plt.Normalize(min(mt_bench), max(mt_bench))
cmap = cm.Reds

fig, ax = plt.subplots(figsize=(8, 4))

# Zero-size (invisible) scatter that exists only to supply the colorbar
# with a mappable. It is given the same `norm` as the visible markers so
# the colorbar scale and the marker colours cannot drift apart.
scatter = ax.scatter(kl_div, margin, c=mt_bench, cmap=cmap, norm=norm, s=0)

# Draw each point individually so 'Ours' can use a larger star marker.
for method, x, y, score in zip(methods, kl_div, margin, mt_bench):
    is_ours = method == 'Ours'
    ax.scatter(
        x,
        y,
        c=[cmap(norm(score))],
        s=500 if is_ours else 300,
        edgecolor='k',
        marker='*' if is_ours else 'o',
        zorder=3,  # keep markers on top of the dashed connecting line
    )

    # Place the method name next to its marker.
    placement = LABEL_OFFSETS.get(method)
    if placement is not None:
        dx, dy, align = placement
        ax.text(x + dx, y + dy, method, fontsize=FONT_SIZE, ha=align)

# Dashed line connecting the points, ordered by increasing x value.
sorted_idx = np.argsort(kl_div)
sorted_kl = np.array(kl_div)[sorted_idx]
sorted_margin = np.array(margin)[sorted_idx]
ax.plot(sorted_kl, sorted_margin, linestyle='--', color='gray', linewidth=1)

# Axis labels and fixed view limits.
ax.set_xlabel('KL-divergence (Teacher)', fontsize=FONT_SIZE)
ax.set_ylabel('Reward Margin (Test Set)', fontsize=FONT_SIZE)
ax.set_xlim(2.3, 3.1)
ax.set_ylim(7.5, 24)

# Colorbar built from the invisible scatter defined above.
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('MT-bench Score', fontsize=FONT_SIZE)

ax.grid(False)
fig.tight_layout()
fig.savefig('kl_plt.png', dpi=300, bbox_inches='tight')