import h5py
import numpy as np
import math
from scipy.integrate import quad
from scipy.stats import beta, binom 
from scipy.optimize import minimize
import time

def get_beta_params_ratio(file_name, num_pairs=20000, seed=42, dataset_key='train'):
    """Fit a Beta distribution to cosine similarities of random vector pairs.

    Loads dataset ``dataset_key`` (``'train'`` by default) from the HDF5 file
    ``file_name`` and draws ``num_pairs`` index pairs uniformly with
    replacement. It computes the cosine similarity of each pair and maps it
    from [-1, 1] onto [0, 1] via ``(cos + 1) / 2``. It then fits
    ``Beta(a, b)`` to the mapped values by maximum likelihood (Nelder-Mead,
    both shapes bounded below by 0.1, started from the method-of-moments
    estimate).

    Args:
        file_name: Path to the HDF5 file.
        num_pairs: Number of random pairs to sample.
        seed: Seed for the pair sampler. A private ``RandomState`` is used, so
            the global NumPy RNG is left untouched. The drawn indices are the
            same as with ``np.random.seed(seed)`` + ``np.random.randint``.
        dataset_key: Name of the dataset inside the file; it is indexed as
            rows of vectors (norms are taken along axis 1).

    Returns:
        Tuple ``(a, b)`` of fitted Beta shape parameters.

    Raises:
        ValueError: If ``num_pairs`` < 1 or the dataset has no rows.
    """
    if num_pairs < 1:
        raise ValueError(f"num_pairs must be >= 1, got {num_pairs}")

    with h5py.File(file_name, 'r') as f:
        dataset_train = f[dataset_key][:]

    n_samples = len(dataset_train)
    if n_samples == 0:
        raise ValueError(f"dataset '{dataset_key}' in {file_name!r} is empty")

    # Integer/bool vectors would make np.divide fail when writing float
    # quotients into an integer ``out`` buffer, so promote them to float64.
    # Float (and complex) data is left untouched.
    if not np.issubdtype(dataset_train.dtype, np.inexact):
        dataset_train = dataset_train.astype(np.float64)

    # Private generator: identical draws to seeding the global RNG, without
    # mutating global state for the rest of the program.
    rng = np.random.RandomState(seed)
    pairs_i = rng.randint(0, n_samples, size=num_pairs)
    pairs_j = rng.randint(0, n_samples, size=num_pairs)

    vecs1 = dataset_train[pairs_i]
    vecs2 = dataset_train[pairs_j]

    vecs1_norm = np.linalg.norm(vecs1, axis=1, keepdims=True)
    vecs2_norm = np.linalg.norm(vecs2, axis=1, keepdims=True)
    # Zero-norm rows stay all-zero, so their cosine with anything is 0.
    vecs1_normalized = np.divide(vecs1, vecs1_norm, out=np.zeros_like(vecs1), where=vecs1_norm != 0)
    vecs2_normalized = np.divide(vecs2, vecs2_norm, out=np.zeros_like(vecs2), where=vecs2_norm != 0)
    cosine = np.sum(vecs1_normalized * vecs2_normalized, axis=1)

    # Round-off can push |cos| slightly past 1. Map onto [0, 1], then keep
    # values strictly inside (0, 1) so beta.logpdf stays finite for shapes < 1.
    cosine = np.clip(cosine, -1.0, 1.0)
    scaled_data = (cosine + 1) / 2
    scaled_data = np.clip(scaled_data, 1e-6, 1 - 1e-6)

    def neg_log_likelihood(params, data):
        a, b = params
        return -np.sum(beta.logpdf(data, a, b, loc=0, scale=1))

    mean_data = np.mean(scaled_data)
    var_data = np.var(scaled_data)
    if var_data > 0:
        # Method-of-moments estimates, floored at the optimizer's lower bound.
        a_init = max(0.1, mean_data * (mean_data * (1 - mean_data) / var_data - 1))
        b_init = max(0.1, a_init * (1 - mean_data) / mean_data)
    else:
        # Degenerate sample (all values identical): the moment formula would
        # divide by zero, so start from the uniform distribution Beta(1, 1).
        a_init, b_init = 1.0, 1.0

    result = minimize(neg_log_likelihood, [a_init, b_init], args=(scaled_data,),
                      method='Nelder-Mead', bounds=[(0.1, None), (0.1, None)])
    a, b = result.x
    return a, b

def integrand_ratio_optimized(theta, M, d, beta_params):
    """Evaluate the selective-ratio integrand at angle ``theta`` (radians).

    The value is ``P(Binom(M, theta/pi) <= d)`` multiplied by the density of
    ``theta`` induced by a Beta law on ``x = (cos(theta) + 1) / 2``. The change
    of variables contributes the Jacobian ``|dx/dtheta| = sin(theta) / 2``.

    Args:
        theta: Angle in radians; integrated over [0, pi] by the caller.
        M: Number of binomial trials.
        d: Threshold passed to the binomial CDF.
        beta_params: ``(a, b)`` Beta shape parameters.

    Returns:
        The product of the binomial CDF and the transformed Beta density.
    """
    shape_a, shape_b = beta_params
    trial_prob = theta / math.pi
    x = (math.cos(theta) + 1) / 2
    angle_density = beta.pdf(x, shape_a, shape_b) / 2 * math.sin(theta)
    return binom.cdf(d, M, trial_prob) * angle_density

def get_selective_ratio(alpha, beta, d, M, *, epsabs=1e-12, epsrel=1e-12, limit=5000):
    """Integrate ``integrand_ratio_optimized`` over theta in [0, pi].

    The result is the integral of ``P(Binom(M, theta/pi) <= d)`` against the
    angle density implied by ``(cos(theta) + 1) / 2 ~ Beta(alpha, beta)``. In
    other words, it is the expected binomial CDF under the fitted angle
    distribution.
    NOTE(review): presumably this is the fraction of vector pairs whose M-bit
    codes differ in at most ``d`` bits — confirm against the caller.

    Args:
        alpha: First Beta shape parameter.
        beta: Second Beta shape parameter. Inside this function the name
            shadows ``scipy.stats.beta``; the distribution object is only
            used in ``integrand_ratio_optimized``.
        d: Threshold for the binomial CDF.
        M: Number of binomial trials.
        epsabs: Absolute tolerance forwarded to ``scipy.integrate.quad``.
        epsrel: Relative tolerance forwarded to ``scipy.integrate.quad``.
        limit: Maximum number of subintervals for ``scipy.integrate.quad``.
            The defaults reproduce the original (very tight) settings.

    Returns:
        The integral estimate; quad's error estimate is discarded.
    """
    shape_params = (alpha, beta)
    integral, _abserr = quad(
        integrand_ratio_optimized, 0, math.pi,
        args=(M, d, shape_params),
        epsabs=epsabs, epsrel=epsrel, limit=limit,
    )
    return integral

if __name__ == '__main__':
    import sys

    # HDF5 file to analyse: first command-line argument, or edit the default.
    # An empty path would only fail deep inside h5py, so exit with usage.
    file_name = sys.argv[1] if len(sys.argv) > 1 else ""
    if not file_name:
        sys.exit(f"usage: {sys.argv[0]} <dataset.hdf5>")

    # Sweep configuration. Alternatives used previously:
    #   M = [192, 208, 224, 240, 256, 272, 288]
    #   D = [86]
    M = [256]
    D = [62, 63, 64, 65, 66, 67, 68]

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    a, b = get_beta_params_ratio(file_name=file_name)
    for m in M:
        for d in D:
            ratio = get_selective_ratio(alpha=a, beta=b, d=d, M=m)
            print(f"threshold: {d}")
            print(f"M: {m}")
            print(f"selective ratio: {ratio:.6f}")

    end_time = time.perf_counter()
    print(f"\ntime cost: {end_time - start_time:.2f} second")