from top_k_recall import get_beta_params_recall,get_recall_parallel
from selective_ratio import get_beta_params_ratio,get_selective_ratio

def find_parameter(file_name, recall, M_low, M_high, K, step=16):
    """Find, for each candidate M, the smallest d whose predicted recall meets the target.

    For every M in ``range(M_low, M_high + 1, step)`` a binary search over
    ``d`` in ``[1, M]`` looks for the smallest ``d`` for which
    ``get_recall_parallel`` returns a value ``>= recall``. The selective
    ratio is then computed for that ``(d, M)`` pair.

    NOTE(review): the binary search is only correct if the predicted recall
    is non-decreasing in ``d`` for a fixed M -- confirm against
    ``get_recall_parallel``.

    Args:
        file_name: Forwarded to ``get_beta_params_recall`` and
            ``get_beta_params_ratio``.
        recall: Target recall threshold that the prediction must reach.
        M_low: First candidate value of M (inclusive).
        M_high: Last candidate value of M (inclusive, if hit by ``step``).
        K: Forwarded as ``k`` / ``K`` to the recall helpers.
        step: Increment between successive candidate values of M.

    Returns:
        A list of ``(M, d_optimal, recall_predict, ratio)`` tuples sorted by
        ``ratio`` in ascending order. Values of M for which no ``d`` in
        ``[1, M]`` reaches the target recall are omitted, so the list may be
        empty.
    """
    ans = []
    recall_params = get_beta_params_recall(file_name=file_name, k=K)
    alpha, beta = get_beta_params_ratio(file_name=file_name)
    for M in range(M_low, M_high + 1, step):
        d_left = 1
        d_right = M
        # None means "no feasible d found yet"; 0 used to double as that
        # sentinel and was then evaluated as if it were a real answer.
        d_optimal = None
        recall_optimal = None
        while d_left <= d_right:
            d_mid = d_left + (d_right - d_left) // 2
            recall_predict = get_recall_parallel(d=d_mid, m=M, K=K, beta_params=recall_params)
            print(f"recall_predict:{recall_predict}")
            print(f"M:{M}")
            print(f"d:{d_mid}")
            if recall_predict >= recall:
                # Feasible: remember it and keep looking for a smaller d.
                d_optimal = d_mid
                recall_optimal = recall_predict
                d_right = d_mid - 1
            else:
                d_left = d_mid + 1
        if d_optimal is None:
            # Target recall is unreachable for this M; do not record an
            # entry for an invalid d.
            continue
        # recall_optimal is the value measured at d_optimal during the
        # search, so no second call to get_recall_parallel is needed.
        ratio = get_selective_ratio(alpha=alpha, beta=beta, d=d_optimal, M=M)
        print(f"ratio:{ratio}")
        ans.append((M, d_optimal, recall_optimal, ratio))
    ans.sort(key=lambda x: x[3])
    return ans
    
if __name__ == '__main__':
    # NOTE(review): file_name is an empty placeholder; it is passed straight
    # through to find_parameter -- set it to the intended input before running.
    file_name = ''
    recall = 0.9
    M_low = 128
    M_high = 320
    K = 10
    ans = find_parameter(file_name=file_name, recall=recall, M_low=M_low, M_high=M_high, K=K)
    # Results are sorted by ratio ascending, so the first entry has the
    # smallest ratio. Guard against an empty list to avoid an IndexError.
    if ans:
        print(ans[0])
    else:
        print("no (M, d) combination found")
    
            