import math
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.integrate import quad
from scipy.special import comb
from scipy.stats import beta
from scipy.stats import binom
from tqdm import tqdm

def get_beta_params(file_name, k, sample_size=2000):
    """Fit a Beta(a, b) distribution to each of the first ``k`` 'distances' columns.

    For every column index ``j < k`` the sample that is fitted consists of:

    * column ``j`` of the HDF5 dataset ``'distances'`` (all rows), and
    * the angles (``arccos`` of the cosine similarity) between
      ``sample_size`` randomly paired rows of ``'train'`` and ``'test'``;
      a fresh random pairing is drawn for every column.

    All values are divided by pi and clipped to ``[1e-6, 1 - 1e-6]`` before
    fitting with location 0 and scale 1 held fixed.

    NOTE(review): the 'distances' values are used as if they were angles in
    radians within [0, pi] — confirm this matches how the file was produced.

    Args:
        file_name: Path of the HDF5 file containing the 'distances', 'train'
            and 'test' datasets.
        k: Number of leading 'distances' columns to fit.
        sample_size: Number of random train/test pairs drawn per column.
            Capped at the size of the smaller of the two datasets, because
            sampling is done without replacement.

    Returns:
        A list of ``k`` ``(a, b)`` tuples, one per column, in column order.

    Raises:
        IndexError: If 'distances' is not at least 2-D or has fewer than
            ``k`` columns.
    """
    with h5py.File(file_name, 'r') as f:
        dataset = f['distances']
        dataset_train = f['train']
        dataset_test = f['test']

        if dataset.ndim < 2 or dataset.shape[1] < k:
            raise IndexError(
                f"'distances' must have at least {k} columns, "
                f"got shape {dataset.shape}"
            )

        # One bulk HDF5 read instead of one read per (row, column) element.
        columns = dataset[:, :k]
        angular = [[columns[:, j]] for j in range(k)]

        # Sampling without replacement cannot draw more rows than exist.
        n_pairs = min(sample_size, len(dataset_train), len(dataset_test))
        if n_pairs > 0:
            for chunks in angular:
                # Draw order (test first, then train) is kept so that seeded
                # runs consume the random stream exactly as before.
                test_idx = np.random.choice(
                    len(dataset_test), size=n_pairs, replace=False)
                train_idx = np.random.choice(
                    len(dataset_train), size=n_pairs, replace=False)
                # h5py fancy indexing requires increasing indices.
                train_idx.sort()
                test_idx.sort()
                train_rows = dataset_train[train_idx]
                test_rows = dataset_test[test_idx]

                dot = np.sum(train_rows * test_rows, axis=1)
                norms = (np.linalg.norm(test_rows, axis=1)
                         * np.linalg.norm(train_rows, axis=1))
                # The angle to a zero vector is undefined; skipping such
                # pairs avoids NaNs that would make the fit fail.
                valid = norms > 0
                cos_sim = np.clip(dot[valid] / norms[valid], -1.0, 1.0)
                chunks.append(np.arccos(cos_sim))

    beta_params = []
    for chunks in angular:
        data = np.concatenate(chunks)
        # Map [0, pi] onto the open unit interval; the Beta likelihood is
        # not finite at exactly 0 or 1.
        scaled_data = np.clip(data / math.pi, 1e-6, 1 - 1e-6)

        a, b, _loc, _scale = beta.fit(scaled_data, floc=0, fscale=1)
        beta_params.append((a, b))
        print(f" a={a:.4f}, b={b:.4f}")

    return beta_params

def integrand(theta, M, d, beta_params):
    """Return P(X <= d) weighted by the Beta density at ``theta``.

    With ``p = theta / pi`` and ``X ~ Binomial(M, p)``, the value is
    ``P(X <= d) * beta.pdf(p, a, b) / pi``.  The ``1 / pi`` factor is the
    change-of-variables term that turns the Beta density over ``p`` in
    [0, 1] into a density over ``theta`` in [0, pi].

    Args:
        theta: Angle in radians; expected to lie in [0, pi].
        M: Number of binomial trials.
        d: Upper limit (inclusive) of the binomial count.
        beta_params: ``(a, b)`` shape parameters of the Beta distribution.

    Returns:
        The integrand value at ``theta`` as a float.
    """
    a, b = beta_params

    p = theta / math.pi

    # binom.cdf equals sum_{j=0..d} C(M, j) p^j (1-p)^(M-j) but stays finite
    # for large M, where the explicit terms overflow/underflow to inf * 0.
    prob = binom.cdf(d, M, p)

    pdf_value = beta.pdf(p, a, b) / math.pi

    return prob * pdf_value

def plot_rho_values(rho_values, rho_mean, K, output_dir=None):
    """Show a bar chart of the rho values with their mean as a dashed line.

    Args:
        rho_values: Sequence of bar heights; bar ``i`` is drawn at x = i + 1.
        rho_mean: Height of the horizontal reference line.
        K: Number of bars / x ticks (positions 1..K).
        output_dir: If truthy, the figure is also written to
            ``<output_dir>/rho_values.png`` before being shown.
    """
    positions = range(1, K + 1)
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.bar(positions, rho_values, color='skyblue', edgecolor='black')

    # Dashed reference line; the legend label is the mean value.
    ax.axhline(y=rho_mean, color='r', linestyle='--', label=f'平均值: {rho_mean:.6f}')

    # Print each bar's value just above its top.
    for position, value in enumerate(rho_values, start=1):
        ax.text(position, value + 0.001, f'{value:.6f}', ha='center', fontsize=9)

    ax.set_xticks(positions)
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    fig.tight_layout()

    if output_dir:
        fig.savefig(Path(output_dir) / "rho_values.png", dpi=300, bbox_inches='tight')

    plt.show()

def plot_beta_distributions(beta_params, K, output_dir=None):
    """Plot the density of the first ``K`` fitted Beta distributions.

    Args:
        beta_params: Sequence of ``(a, b)`` shape-parameter pairs.
        K: Maximum number of distributions to draw (``beta_params[:K]``).
        output_dir: If truthy, the figure is also written to
            ``<output_dir>/beta_distributions.png`` before being shown.
    """
    grid = np.linspace(0, 1, 1000)
    fig, ax = plt.subplots(figsize=(12, 8))

    for a, b in beta_params[:K]:
        ax.plot(grid, beta.pdf(grid, a, b), label=f'a={a:.4f}, b={b:.4f}')

        # Annotate the peak: the Beta mode exists in closed form only when
        # both shape parameters exceed 1; otherwise fall back to x = 0.5.
        if a > 1 and b > 1:
            peak_x = (a - 1) / (a + b - 2)
        else:
            peak_x = 0.5
        peak_y = beta.pdf(peak_x, a, b)
        ax.annotate(
            f'({peak_x:.4f}, {peak_y:.4f})',
            xy=(peak_x, peak_y),
            xytext=(peak_x + 0.1, peak_y + 0.5),
            arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
        )

    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, None)

    fig.tight_layout()

    if output_dir:
        fig.savefig(Path(output_dir) / "beta_distributions.png", dpi=300, bbox_inches='tight')

    plt.show()

def save_results(beta_params, rho_values, rho_mean, output_dir):
    """Write the fitted Beta parameters and the rho values to CSV files.

    Creates ``output_dir`` (including parents) if needed and writes:

    * ``beta_parameters.csv`` — columns ``a`` and ``b``, rows indexed from 1.
    * ``rho_values.csv`` — column ``rho`` with rows ``1..len(rho_values)``
      followed by a final row labelled ``mean`` holding ``rho_mean``.

    Args:
        beta_params: Sequence of ``(a, b)`` pairs.
        rho_values: Sequence of rho values, one per row.
        rho_mean: Mean of ``rho_values`` as computed by the caller.
        output_dir: Target directory (``str`` or ``Path``).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    beta_df = pd.DataFrame(beta_params, columns=['a', 'b'])
    beta_df.index += 1  # 1-based row labels

    beta_path = output_dir / "beta_parameters.csv"
    beta_df.to_csv(beta_path)

    # The rho results were previously accepted but never written to disk.
    rho_list = list(rho_values)
    labels = [*range(1, len(rho_list) + 1), 'mean']
    rho_df = pd.DataFrame(
        {'rho': [*rho_list, rho_mean]},
        index=pd.Index(labels, dtype=object, name='k'),
    )
    rho_df.to_csv(output_dir / "rho_values.csv")



if __name__ == '__main__':
    # Run configuration.
    file_name = ""          # placeholder: path of the input HDF5 file
    M = 256                 # number of binomial trials passed to integrand
    K = 10                  # number of 'distances' columns to fit and plot
    d = 100                 # inclusive upper count passed to integrand
    output_dir = "results"  # where the CSV and PNG outputs are written

    beta_params = get_beta_params(file_name=file_name, k=K)

    # Integrate the integrand over [0, pi] once per fitted distribution;
    # quad returns (value, error estimate) and only the value is kept.
    rho_values = [
        quad(integrand, 0, math.pi, args=(M, d, beta_params[idx]))[0]
        for idx in range(K)
    ]
    rho = np.mean(rho_values)

    save_results(beta_params, rho_values, rho, output_dir)
    plot_rho_values(rho_values, rho, K, output_dir)
    plot_beta_distributions(beta_params, K, output_dir)