import numpy as np
from scipy.stats import entropy
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import numpy as np
from scipy.integrate import simpson

def kde_kl_divergence(x1, x2, bandwidth='scott', xmin=None, xmax=None, num_points=1000):
    """Estimate KL(P || Q) between two samples via Gaussian kernel density estimates.

    Each sample is smoothed with ``scipy.stats.gaussian_kde``. Both densities
    are evaluated on one shared, evenly spaced grid and ``p * log(p / q)`` is
    integrated over that grid with Simpson's rule.

    Args:
        x1: Sample for P, the distribution the expectation is taken over.
        x2: Sample for Q, the reference distribution.
        bandwidth: Forwarded to ``gaussian_kde`` as ``bw_method``.
        xmin: Lower edge of the integration grid. When ``None``, the smallest
            value found in either sample is used.
        xmax: Upper edge of the integration grid. When ``None``, the largest
            value found in either sample is used.
        num_points: Number of grid points used for the integration.

    Returns:
        The numerically integrated divergence estimate. The densities are not
        renormalised on ``[xmin, xmax]`` and carry a small epsilon offset, so
        the value is an approximation and may be marginally negative.
    """
    kde_p = gaussian_kde(x1, bw_method=bandwidth)
    kde_q = gaussian_kde(x2, bw_method=bandwidth)

    # Resolve each bound independently so that a caller-supplied limit is not
    # discarded just because the other limit was left at None.
    if xmin is None:
        xmin = min(np.min(x1), np.min(x2))
    if xmax is None:
        xmax = max(np.max(x1), np.max(x2))

    grid = np.linspace(xmin, xmax, num_points)

    # Shift both densities away from zero so the log ratio stays finite.
    epsilon = 1e-8
    p = kde_p(grid) + epsilon
    q = kde_q(grid) + epsilon

    # The sample points are passed by keyword, which is accepted across SciPy
    # releases regardless of how positional arguments are treated.
    return simpson(p * np.log(p / q), x=grid)
# Load the V_T distributions of the two models.
vt2 = np.load("/home/minchan.kwon/ADPA/token_visualizations/V-vis/dpo/vt_values_chosen.npy")
vt1 = np.load("/home/minchan.kwon/ADPA/token_visualizations/V-vis/dpo_student/vt_values_chosen.npy")
# Histogram-based KL computation (normalization is required, hence density=True).
# NOTE(review): p and q are not used by any later statement visible here; the
# value printed below comes from the KDE-based estimate only.
p, _ = np.histogram(vt1, bins=50, range=(-10, 10), density=True)
q, _ = np.histogram(vt2, bins=50, range=(-10, 10), density=True)

# vt1 (loaded from the dpo_student directory) is passed as the first sample and
# vt2 (loaded from the dpo directory) as the second.
kl = kde_kl_divergence(vt1, vt2)
print(f"KDE-based KL divergence: {kl:.4f}")

def plot_distributions(vt1, vt2, label1="Model1", label2="Model2", bins=50,
                       save_path="./V_T_distribution.png"):
    """Save overlaid, density-normalised histograms of two samples to an image.

    Args:
        vt1: Values plotted as the first histogram.
        vt2: Values plotted as the second histogram.
        label1: Legend label for ``vt1``.
        label2: Legend label for ``vt2``.
        bins: Forwarded to ``hist`` as ``bins`` for both histograms.
        save_path: Destination passed to ``savefig``. Defaults to the path
            that was previously hard-coded.
    """
    # Draw on a dedicated figure so repeated calls do not pile histograms onto
    # whatever figure pyplot currently treats as active.
    fig, ax = plt.subplots()
    try:
        ax.hist(vt1, bins=bins, alpha=0.5, label=label1, density=True)
        ax.hist(vt2, bins=bins, alpha=0.5, label=label2, density=True)
        ax.set_xlabel("V_T (last token soft value)")
        ax.set_ylabel("Density")
        ax.legend()
        ax.set_title("Distribution of V_T (chosen)")
        fig.savefig(save_path)
    finally:
        # Release the figure even if drawing or saving fails, so figures do not
        # accumulate in pyplot's registry.
        plt.close(fig)
    
# Plot both loaded samples with the default labels and bin count.
plot_distributions(vt1, vt2)
