import numpy as np
from sklearn.mixture import GaussianMixture as GMM
def search_threshold(weight: np.ndarray, bins=10, random_state=None):
    """Pick a threshold at a valley of the log10-histogram of ``weight``.

    The values are histogrammed in log10 space over their full range
    ``[log10(min), log10(max)]``. The threshold is the left edge of a chosen
    bin, mapped back to the linear scale with ``10**``:

    * If no bin is empty, the middle one of the local minima of the bin counts
      (plateaus included) is used.
    * Otherwise an interior empty bin is used. If there is exactly one, it is
      chosen; if there are several, one is drawn at random, excluding the
      lowest one.

    Args:
        weight: Values to threshold; assumed strictly positive (log10 is
            taken). Any shape is accepted; it is flattened for the histogram.
        bins: Number of histogram bins (forwarded to ``np.histogram``).
        random_state: Seed or ``np.random.Generator`` used for the random draw
            among several empty-bin candidates. ``None`` (default) draws from
            the global ``np.random`` state.

    Returns:
        The threshold on the original (linear) scale.

    Raises:
        ValueError: If the histogram offers no usable valley bin.
    """
    hist_y, hist_x = np.histogram(
        np.log10(weight),
        bins=bins,
        range=(np.log10(np.min(weight)), np.log10(np.max(weight))),
    )
    hist_y_diff = np.diff(hist_y)
    # Entry j describes interior bin j + 1: its count does not rise from the
    # previous bin and does not fall into the next one (a local minimum).
    is_local_min = (hist_y_diff[:-1] <= 0) & (hist_y_diff[1:] >= 0)
    is_empty = hist_y == 0

    if not np.any(is_empty):
        candidates = np.flatnonzero(is_local_min) + 1
        if candidates.size == 0:
            raise ValueError(
                'search_threshold: histogram has no empty bin and no local '
                'minimum; try a different number of bins'
            )
        # flatnonzero yields ascending indices, so this is the middle valley.
        chosen_idx = candidates[candidates.size // 2]
        return 10 ** hist_x[chosen_idx]

    # Only interior bins are considered: the first and last bins hold
    # min(weight) and max(weight), so they are non-empty for non-constant input.
    candidates = np.flatnonzero(is_empty[1:-1] & is_local_min) + 1
    if candidates.size == 0:
        raise ValueError(
            'search_threshold: no interior empty bin to use as a threshold'
        )
    if candidates.size == 1:
        chosen_idx = candidates[0]
    # NOTE(review): the lowest empty-bin candidate is excluded from the random
    # draw below; the intent is not visible here -- confirm.
    elif random_state is None:
        chosen_idx = np.random.choice(candidates[1:])
    else:
        chosen_idx = np.random.default_rng(random_state).choice(candidates[1:])
    return 10 ** hist_x[chosen_idx]

def search_threshold_ratio(weight:np.ndarray,preserve_ratio:float=0.7):
    """Return a threshold at the ``preserve_ratio`` position of the sorted values.

    ``weight`` is flattened and sorted ascending. With ``n`` values and
    ``k = int(n * preserve_ratio)``, the result is the mean of the sorted
    values at indices ``k`` and ``k + 1`` (0-based). This places the threshold
    between them. When ``k`` is the last index (including ``preserve_ratio == 1``),
    the result is the largest value.

    Args:
        weight: Values to threshold; any shape (flattened).
        preserve_ratio: Fraction in ``[0, 1]`` selecting the split position.

    Returns:
        The threshold value.

    Raises:
        ValueError: If ``weight`` is empty or ``preserve_ratio`` is outside [0, 1].
    """
    if not 0.0 <= preserve_ratio <= 1.0:
        raise ValueError(f'preserve_ratio must be in [0, 1], got {preserve_ratio}')
    # axis=None flattens, so n-D input is split over all elements, not per row.
    sorted_values = np.sort(weight, axis=None)
    n = sorted_values.size
    if n == 0:
        raise ValueError('weight must not be empty')
    # Clamp so preserve_ratio == 1.0 yields the maximum instead of the mean of
    # an empty slice (nan).
    boundary_idx = min(int(n * preserve_ratio), n - 1)
    return np.mean(sorted_values[boundary_idx:boundary_idx + 2])

def gaussian_intersection(m1,m2,cov1,cov2,w1,w2,cnt=0):
    """Find where two weighted 1-D Gaussians cross strictly between their means.

    Taking logs of ``w1*N(x; m1, cov1) == w2*N(x; m2, cov2)`` gives the
    quadratic ``a*x**2 + b*x + c == 0`` solved below. The real roots lying
    strictly between ``m1`` and ``m2`` are returned.

    If there is no such root, ``w1`` is replaced by ``(w1 + w2) / 2`` (pulling
    the two weights together) and the equation is solved again. Once the
    attempt made with ``cnt > 10`` also fails, 'GMM failure' is printed. The
    weight-averaged mean of ``m1`` and ``m2`` (using the current ``w1``) is
    then returned.

    Args:
        m1, m2: Component means.
        cov1, cov2: Component variances; assumed positive (they enter a log ratio).
        w1, w2: Component weights; assumed positive (they enter a log ratio).
        cnt: Number of weight adjustments already made; normally left at 0.

    Returns:
        A non-empty list of intersection points (or the single fallback value).
    """
    while True:
        a = 1/(2*cov1) - 1/(2*cov2)
        b = m2/cov2 - m1/cov1
        c = m1**2/(2*cov1) - m2**2/(2*cov2) - np.log(cov2/cov1)/2 - np.log(w1/w2)
        roots = np.roots([a, b, c])
        # A complex-conjugate pair means the weighted curves never cross;
        # drop such roots instead of ordering-comparing complex numbers.
        real_roots = roots[np.isreal(roots)].real
        ret = [x for x in real_roots if m2 < x < m1 or m1 < x < m2]
        if ret:
            return ret
        if cnt > 10:
            print('GMM failure')
            return [(m1*w1+m2*w2)/(w1+w2)]
        w1 = w1/2 + w2/2
        cnt += 1

def search_threshold_GMM(ref_vector:np.ndarray,n_components=3,random_state:int=0,covariance_type='full',**kwargs):
    """Threshold where the two heaviest GMM components cross in log10 space.

    A Gaussian mixture (sklearn ``GaussianMixture``) is fitted to
    ``log10(ref_vector)`` treated as a single feature. The two components with
    the largest mixture weights are intersected with ``gaussian_intersection``,
    and the first returned point is mapped back with ``10**``.

    Args:
        ref_vector: Values to threshold; assumed strictly positive (log10 is
            taken). Reshaped to a single column, so any shape is accepted.
        n_components: Number of mixture components; must be at least 2.
        random_state: Seed passed to the mixture model.
        covariance_type: Passed to the mixture model ('full', 'tied', 'diag'
            or 'spherical').
        **kwargs: Accepted and ignored.

    Returns:
        The threshold on the original (linear) scale.

    Raises:
        ValueError: If ``n_components < 2``.
    """
    if n_components < 2:
        raise ValueError(f'search_threshold_GMM needs n_components >= 2, got {n_components}')
    log_values = np.log10(np.reshape(ref_vector, (-1, 1)))
    gmm = GMM(n_components=n_components, random_state=random_state,
              covariance_type=covariance_type).fit(log_values)

    # Indices of the two components with the largest mixture weights
    # (sorted is stable, so ties keep component order).
    best_two_ind = sorted(range(n_components), key=lambda i: gmm.weights_[i], reverse=True)[:2]
    means = gmm.means_[best_two_ind].flatten()
    weights = gmm.weights_[best_two_ind].flatten()
    if covariance_type == 'tied':
        # 'tied' stores one covariance shared by all components (shape (1, 1)
        # for a single feature), so it cannot be indexed per component.
        covs = np.repeat(gmm.covariances_.flatten(), 2)
    else:
        covs = gmm.covariances_[best_two_ind].flatten()
    try:
        eps = gaussian_intersection(*means, *covs, *weights)[0]
    except RecursionError:
        # Defensive fallback: weight-averaged mean of the two component means.
        print(f'GMM failure with {[*means,*covs,*weights]}. return mean of mean:')
        eps = (means[0]*weights[0]+means[1]*weights[1])/(weights[0]+weights[1])

    return 10**eps