import matplotlib.pyplot as plt
import h5py
import scipy.stats as stats
import numpy as np
from tqdm import tqdm
import math


def _load_kth_distances(file_name, k):
    """Return column ``k - 1`` of the ``distances`` dataset in *file_name* as float64."""
    with h5py.File(file_name, 'r') as f:
        # A single column slice is one read; a per-row loop would issue
        # len(dataset) separate reads.
        return np.asarray(f['distances'][:, k - 1], dtype=np.float64)


def _fit_beta(values):
    """Fit a Beta distribution with support fixed to [0, 1]; return its (a, b) shapes."""
    # With floc/fscale fixed, the returned loc/scale are just 0 and 1.
    a, b, _loc, _scale = stats.beta.fit(values, floc=0, fscale=1)
    return a, b


def plot_and_validation(file_name, png_name, k, alpha_=5.0, beta_=5.0,
                        num_samples=10000, rng=None):
    """Fit Beta distributions to stored distances plus synthetic Beta draws and plot them.

    Column ``k - 1`` of the ``distances`` dataset in *file_name* is read
    ("origin"), ``k`` batches of ``num_samples`` Beta(``alpha_``, ``beta_``)
    draws are generated ("sample"), and every value is divided by pi.
    Three Beta distributions with support fixed to [0, 1] are fitted: to
    origin + all sample batches, to origin alone, and to the last sample
    batch alone. They are drawn over histograms of origin and of
    origin + samples, the figure is saved to *png_name*, and *png_name*
    is printed.

    Args:
        file_name: Path to an HDF5 file containing a ``distances`` dataset
            indexed as ``[row, column]``.
        png_name: Output image path; also used in the plot title.
        k: 1-based column index into ``distances``; also the number of
            synthetic batches drawn.
        alpha_: First shape parameter of the synthetic Beta draws.
        beta_: Second shape parameter of the synthetic Beta draws.
        num_samples: Number of draws per synthetic batch.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) for
            reproducible draws; defaults to the global ``np.random`` state.

    Raises:
        ValueError: If ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k!r}")

    origin = _load_kth_distances(file_name, k)

    draw_beta = np.random.beta if rng is None else rng.beta
    # Draw one synthetic batch per neighbour group (k batches in total).
    batches = [draw_beta(alpha_, beta_, size=num_samples) for _ in range(k)]

    # NOTE(review): Beta draws already lie in [0, 1]; dividing them by pi
    # together with the stored distances maps them into [0, 1/pi]. Confirm
    # this matches the intended scale of the distances.
    combined = np.concatenate([origin, *batches]) / math.pi
    origin_scaled = origin / math.pi
    sample_scaled = batches[-1] / math.pi  # only the last batch is fitted alone

    a_all, b_all = _fit_beta(combined)
    a_origin, b_origin = _fit_beta(origin_scaled)
    a_sample, b_sample = _fit_beta(sample_scaled)

    curves = (
        (a_all, b_all, 'r-', 'origin + sample'),
        (a_origin, b_origin, 'b-', 'origin'),
        (a_sample, b_sample, 'g-', 'sample'),
    )

    fig = plt.figure(figsize=(12, 7))
    try:
        plt.hist(origin_scaled, bins=50, density=True, color='lightblue',
                 edgecolor='gray', alpha=0.7, label='origin')
        plt.hist(combined, bins=50, density=True, color='lightyellow',
                 edgecolor='gray', alpha=0.7, label='origin + sample')

        # Evaluate the fitted pdfs over whatever x-range the histograms chose.
        xmin, xmax = plt.xlim()
        x = np.linspace(xmin, xmax, 200)
        for a, b, fmt, name in curves:
            # Default loc=0, scale=1 match the fixed floc/fscale used in the fit.
            plt.plot(x, stats.beta.pdf(x, a, b), fmt, linewidth=2.5,
                     label=f'{name} fit (a={a:.2f}, b={b:.2f})')

        plt.title(f'{png_name} - origin and Fit', fontsize=14)
        plt.xlabel('Transformed Cosine Similarity', fontsize=12)
        plt.ylabel('Probability Density', fontsize=12)
        plt.legend(fontsize=10)
        plt.grid(True, linestyle='--', alpha=0.8)

        plt.savefig(png_name)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

    print(png_name)

if __name__ == '__main__':
    # Local import: argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description='Fit Beta distributions to one column of an HDF5 '
                    '"distances" dataset plus synthetic Beta draws, and plot them.')
    parser.add_argument('file_name',
                        help='HDF5 file containing a "distances" dataset')
    parser.add_argument('--png-name', default='glove-50-angular-top1-sample.png',
                        help='output image path (default: %(default)s)')
    parser.add_argument('-k', type=int, default=1,
                        help='1-based column of "distances" to read (default: %(default)s)')
    args = parser.parse_args()

    plot_and_validation(args.file_name, args.png_name, k=args.k)