import numpy as np
import pandas as pd


def worst_case_model_at_fpr_01(s):
    """Evaluate the worst-case log-linear model at ``s`` (``_fpr_01`` variant).

    ``log10(result) = -0.2695 * log10(s) + 1.0784``.  Works on scalars and
    element-wise on NumPy arrays.
    """
    return 10 ** ((np.log10(s) * -0.2695) + 1.0784)


def average_model_at_fpr_01(s, c):
    """Evaluate the average-case log-linear model at ``s``, ``c`` (``_fpr_01`` variant).

    ``log10(result) = -0.506 * log10(s) + 0.09 * log10(c) + 0.314``.
    Works on scalars and element-wise on NumPy arrays.
    """
    return 10 ** ((np.log10(s) * -0.506) + (np.log10(c) * 0.09) + 0.314)


def average_model_at_fpr_001(s, c):
    """Evaluate the average-case log-linear model at ``s``, ``c`` (``_fpr_001`` variant).

    ``log10(result) = -0.555 * log10(s) + 0.182 * log10(c) + 0.083``.
    Works on scalars and element-wise on NumPy arrays.
    """
    return 10 ** ((np.log10(s) * -0.555) + (np.log10(c) * 0.182) + 0.083)


def average_model_at_fpr_0001(s, c):
    """Evaluate the average-case log-linear model at ``s``, ``c`` (``_fpr_0001`` variant).

    ``log10(result) = -0.627 * log10(s) + 0.300 * log10(c) - 0.173``.
    Works on scalars and element-wise on NumPy arrays.
    """
    return 10 ** ((np.log10(s) * -0.627) + (np.log10(c) * 0.300) - 0.173)


def dp_bound(epsilon, delta, fpr):
    """Return ``min(e^eps * fpr + delta, 1 - e^-eps * (1 - delta - fpr)) - fpr``.

    The ``min(...)`` term has the form of the (epsilon, delta)-DP
    hypothesis-testing limit on the true-positive rate at false-positive rate
    ``fpr``.  NOTE(review): the trailing ``- fpr`` makes the returned value a
    limit on TPR - FPR; confirm this is the quantity the models predict.

    ``np.minimum`` (rather than builtin ``min``) lets any argument be a NumPy
    array; scalar inputs give the same value as before.
    """
    upper = np.minimum(
        np.exp(epsilon) * fpr + delta,
        1 - np.exp(-epsilon) * (1 - delta - fpr),
    )
    return upper - fpr


def round_up(n):
    """Round a non-negative integer *up* to at most two significant digits.

    Examples: 1234 -> 1300, 1200 -> 1200, 9950 -> 10000.  Values with one or
    two digits are already exact and are returned unchanged (the previous
    string-slicing version crashed on them).

    Args:
        n: a non-negative integer (Python ``int`` or NumPy integer).

    Returns:
        The smallest integer >= ``n`` with at most two significant digits, as
        a Python ``int``.

    Raises:
        ValueError: if ``n`` is not an integer or is negative.
    """
    import operator

    try:
        value = operator.index(n)  # accepts int / NumPy ints, rejects floats
    except TypeError as err:
        raise ValueError(f"round_up expects an integer, got {n!r}") from err
    if value < 0:
        raise ValueError(f"round_up expects a non-negative integer, got {n!r}")

    num_digits = len(str(value))
    if num_digits <= 2:
        return value
    scale = 10 ** (num_digits - 2)
    # Integer ceiling division: exact even for values beyond float precision.
    return -(-value // scale) * scale


if __name__ == "__main__":

    c = 2
    data = []
    for epsilon in [0.25, 0.5, 0.75, 1.0]:
        for delta in [1e-5]:

            # fpr = 0.1
            fpr = 0.1
            tpr_bound = dp_bound(epsilon, delta, fpr)
            for s in np.logspace(1, 64, int(1e7), base=2, dtype=int):
                if worst_case_model_at_fpr_01(s) <= tpr_bound:
                    s_worst_case = s
                    break
            for s in np.logspace(1, 64, int(1e7), base=2, dtype=int):
                if average_model_at_fpr_01(s, c) <= tpr_bound:
                    s_average01 = s
                    break

            # fpr = 0.01
            fpr = 0.01
            tpr_bound = dp_bound(epsilon, delta, fpr)
            for s in np.logspace(1, 64, int(1e7), base=2, dtype=int):
                if average_model_at_fpr_001(s, c) <= tpr_bound:
                    s_average001 = s
                    break

            # fpr = 0.001
            fpr = 0.001
            tpr_bound = dp_bound(epsilon, delta, fpr)
            for s in np.logspace(1, 64, int(1e7), base=2, dtype=int):
                if average_model_at_fpr_0001(s, c) <= tpr_bound:
                    s_average0001 = s
                    break

            data.append((epsilon, delta, s_average01, s_average001, s_average0001, s_worst_case))

    data_df = pd.DataFrame(
        data,
        columns=[
            "epsilon",
            "delta",
            "S (average) at FPR = 0.1",
            "S (average) at FPR = 0.01",
            "S (average) at FPR = 0.001",
            "S (worst-case) at FPR=0.1",
        ],
    )

    data_table = data_df.copy()

    for c in [
        "S (average) at FPR = 0.1",
        "S (average) at FPR = 0.01",
        "S (average) at FPR = 0.001",
        "S (worst-case) at FPR=0.1",
    ]:
        data_table[c] = data_table[c].apply(lambda x: round_up(x))

    data_table["S (worst-case) at FPR=0.1"] = data_table["S (worst-case) at FPR=0.1"].apply(
        lambda x: "{:.2e}".format(x).replace("e+", r"\times 10^")
    )
    data_table["S (worst-case) at FPR=0.1"] = data_table["S (worst-case) at FPR=0.1"].apply(
        lambda x: str(np.ceil(float(x[:4]) * 10) / 10) + x[4:]
    )

    data_table = data_table.drop(columns=["delta"])

    table_string = data_table.to_latex(index=False, float_format="%.2f")
    for r in table_string.split("\n"):
        print(r.replace("0.0", "$10^{-5}$"))
