import numpy as np


def get_rate_above_threshold(ind_test_scores, threshold):
    """
    Get the rate of the samples that are above the threshold.

    :param ind_test_scores: scores of the in-distribution samples; any
    array-like (numpy array, list, tuple) of any shape
    :param threshold: the threshold; a score counts only when it is strictly
    greater than this value
    :return: rate of samples above threshold, a float between 0 and 1
    :raises ZeroDivisionError: if ind_test_scores contains no elements
    """
    # Accept plain Python sequences as well as numpy arrays.
    scores = np.asarray(ind_test_scores)
    # Count the matches directly instead of materialising a filtered copy.
    above_count = int(np.count_nonzero(scores > threshold))
    # Divide by the total number of elements (not the length of the first
    # axis) so that multi-dimensional input still yields a rate in [0, 1].
    return above_count / scores.size


def find_threshold_ind_data(ind_test_scores, percentile=0.01):
    """
    Find a threshold such that the given fraction of the in-distribution
    scores lies strictly above it.

    The threshold is located by bisection between the smallest and the
    largest score, starting from the median. The search stops as soon as the
    rate of scores above the candidate equals ``percentile`` exactly, or
    after at most 100 bisection steps; in the latter case the candidate has
    converged to the point where the rate crosses ``percentile``.

    :param ind_test_scores: scores for in-distribution test or validation
    data; any array-like (numpy array, list, tuple)
    :param percentile: the target fraction (e.g. 0.01 for 1%) of the data
    points that should have scores above the threshold

    :return: the threshold
    :raises ValueError: raised by numpy when ind_test_scores is empty
    """
    # Convert once so plain Python sequences work too.
    scores = np.asarray(ind_test_scores)

    # Bisection bounds; named so that the builtins min/max stay usable.
    upper = np.max(scores)
    lower = np.min(scores)
    threshold = np.median(scores)

    max_trials = 100
    for _ in range(max_trials):
        rate = get_rate_above_threshold(
            ind_test_scores=scores, threshold=threshold)
        # Rates are exact ratios count / size, so an exact comparison is
        # enough to detect a hit; a miss simply keeps narrowing the interval.
        if rate == percentile:
            break
        if rate < percentile:
            # Too few scores above: the candidate threshold is too high.
            upper = threshold
        else:
            # Too many scores above: the candidate threshold is too low.
            lower = threshold
        midpoint = (lower + upper) / 2
        if midpoint == threshold:
            # The interval has collapsed to floating-point precision, so
            # further steps would recompute the same value; stop early.
            break
        threshold = midpoint
    return threshold
