from statsmodels.stats.proportion import proportion_confint
from scipy.stats import norm
import numpy as np


def certify_random(weight_top_class: int, rs_num: int, rs_len: int, atk_len: int, alpha=0.999) -> bool:
    """Certify a prediction under random ablation from its Monte Carlo vote count.

    A one-sided Clopper-Pearson lower bound ``pa`` on the top-class probability is
    compared with ``delta = 1 - C(N - L_atk, Ns) / (2 * C(N, Ns))``. The prediction
    is certified when ``pa > delta``.

    :param weight_top_class: NA, the number of samples voting for the top class
    :param rs_num: N. It is used both as the sample count of the confidence bound and
        as the population size of the binomial-coefficient ratio.
    :param rs_len: Ns, the subset size in the binomial coefficients (presumably the
        number of elements kept per ablation)
    :param atk_len: L_atk, the attack length subtracted from N in the ratio
    :param alpha: failure probability of the lower confidence bound. Values above 0.5,
        such as the default 0.999, are read as a confidence level and converted
        to ``1 - alpha``.
    :return: True if the prediction is certified, False otherwise
    """
    NA = weight_top_class
    N = rs_num
    Ns = rs_len
    L_atk = atk_len
    # proportion_confint(alpha=2 * a, method="beta")[0] is a one-sided lower bound only
    # when `a` is a small failure probability. Passing 2 * 0.999 (the default) makes the
    # "lower" end the 99.9% beta quantile. That is an upper bound, so it over-certifies.
    # A failure probability above 0.5 is never meaningful for a lower bound, so such
    # values are interpreted as confidence levels instead.
    failure_prob = 1 - alpha if alpha > 0.5 else alpha
    pa = proportion_confint(NA, N, alpha=2 * failure_prob, method="beta")[0]
    # comb_N_n_ns(Ns, N - L_atk, N) is the ratio C(N - L_atk, Ns) / C(N, Ns).
    delta = (2 - comb_N_n_ns(Ns, N - L_atk, N)) / 2
    # NOTE(review): delta == 0.5 + (1 - ratio) / 2. A bound of the form
    # pa - (1 - ratio) > 0.5 would instead require pa > 1.5 - ratio. Also, N is used
    # both as the sample count and as the population size. Confirm both against the
    # intended theorem.
    return bool(pa > delta)


def certify_block(weight_top_class: int, weight_second_class: int, rs_len: int, atk_len: int) -> bool:
    """Certify a block-ablation prediction from the weights of its top two classes.

    ``delta = rs_len + atk_len - 1`` is presumably the number of block positions
    (block length ``rs_len``) that can overlap a contiguous attack of length
    ``atk_len``. The prediction is certified when the margin ``NA - NB`` strictly
    exceeds ``delta``.

    :param weight_top_class: NA, the weight (vote count) of the top class
    :param weight_second_class: NB, the weight of the runner-up class
    :param rs_len: Ns, presumably the ablation block length
    :param atk_len: L_atk, the attack length
    :return: True if the prediction is certified, False otherwise
    """
    NA = weight_top_class
    NB = weight_second_class
    delta = rs_len + atk_len - 1
    # NOTE(review): if each affected block can move one vote from the top class to the
    # runner-up, the margin shrinks by 2 per block. A sound check would then need
    # NA - NB > 2 * delta; confirm the intended bound.
    return bool(NA - NB > delta)


def comb_N_n_ns(Ns, n, N):
    """Return the binomial-coefficient ratio C(n, Ns) / C(N, Ns).

    Let ``small``/``big`` be the smaller/larger of ``n`` and ``N``. The value is built
    as a running product of ``k / (k + big - small)`` over ``k`` in
    ``(small - Ns, small]``. This telescopes to C(small, Ns) / C(big, Ns), so no large
    binomial coefficient is ever formed. When ``n > N`` that product is the reciprocal
    of the requested ratio and is inverted at the end.
    """
    small, big = (n, N) if N >= n else (N, n)
    ratio = 1
    for k in range(small - Ns + 1, small + 1):
        ratio = ratio * k / (k + big - small)
    return ratio if N >= n else 1 / ratio


def test_certification_RandomAblation():
    """Print certify_random's verdict for top-class counts 0, 100, ..., 900.

    Fixed setup: rs_num N=1000, rs_len Ns=50, atk_len L_atk=10 (default alpha).
    """
    total = 1000
    attack_len = 10
    ablation_len = 50
    top_counts = range(0, total, 100)
    for top_count in top_counts:
        print(certify_random(top_count, total, ablation_len, attack_len))


def certify(x: np.ndarray, label: np.ndarray, sigma: float, eps: float, alpha=0.001):
    """Flag samples whose Gaussian-smoothed prediction is both correct and certifiably robust.

    Monte Carlo certification in the randomized-smoothing style:
    - For each row of ``x``, the top class is taken as the prediction.
    - A one-sided (1 - alpha) Clopper-Pearson lower bound ``pABar`` on its probability
      is computed.
    - The certified L2 radius is ``sigma * Phi^{-1}(pABar)``.

    With probability at least ``1 - alpha``, a sample reported as robust has a
    prediction that is constant within that radius.

    No abstention is performed. For a positive ``sigma``, ``pABar <= 0.5`` gives a
    radius <= 0, so such a sample is never counted as robust for ``eps >= 0``.

    :param x: [batch_size, num_classes] per-class counts of noisy samples
    :param label: [batch_size] ground-truth class indices
    :param sigma: Gaussian noise level; the certified radius scales linearly with it
    :param eps: L2 norm of the perturbations to certify against
    :param alpha: failure probability of the robustness certification
    :return: boolean array of shape [batch_size]. An entry is True when the prediction
        equals ``label`` and the certified radius strictly exceeds ``eps``.
    """
    n = x.sum(axis=1)  # total number of noisy samples per row
    nA = np.max(x, axis=1)  # votes for the top-1 class
    pred = np.argmax(x, axis=1)  # smoothed predictions
    is_correct = label == pred
    pABar = _lower_confidence_bound(nA, n, alpha)
    # norm.ppf is negative below 0.5, so a weak top class yields a negative radius.
    radius = sigma * norm.ppf(pABar)
    is_robust = radius > eps
    return is_correct & is_robust


def _lower_confidence_bound(NA: int, N: int, alpha: float) -> float:
    """Return a one-sided (1 - alpha) lower confidence bound on a Bernoulli proportion.

    Uses the Clopper-Pearson ("beta") method. Requesting a two-sided interval at level
    ``2 * alpha`` leaves ``alpha`` of probability mass below its lower end, so that lower
    end is a one-sided (1 - alpha) bound. ``certify`` passes per-row arrays of counts
    here, in which case the bound is computed element-wise.

    :param NA: the number of "successes"
    :param N: the number of total draws
    :param alpha: the failure probability. The bound holds with probability at least
        1 - alpha; this is not the confidence level itself.
    :return: a lower bound on the binomial proportion which holds true w.p. at least
        (1 - alpha) over the samples
    """
    two_sided_alpha = 2 * alpha
    return proportion_confint(NA, N, alpha=two_sided_alpha, method="beta")[0]




if __name__ == "__main__":
    # Script entry point: print the random-ablation certification sweep.
    test_certification_RandomAblation()