import numpy as np
import pandas as pd

from audit_one_run.src.google import get_eps_audit as audit_one_run

# evaluate k+ guesses (top scores guessed "included") and k- guesses (bottom scores guessed "excluded")
def get_possible_guesses(df_scores, guess_interval=1):
    """Count correct guesses for every k+ and every k- guess size.

    Canaries are ranked once by ``score`` (highest first). Guessing the top
    ``k`` canaries as included yields ``num_correct_+``; guessing the bottom
    ``k`` canaries as excluded yields ``num_correct_-``. The bottom-``k``
    order is the exact reverse of the top-``k`` order over canaries that have
    a score. So for any ``k+ + k-`` up to the number of scored canaries, the
    two guess sets are disjoint, even when scores are tied.

    Canaries with a NaN score are placed after all scored canaries in both
    directions (as an ascending or descending sort with NaNs last would).

    Args:
        df_scores: DataFrame with a ``score`` column and an ``include``
            column. Only ``cumsum`` and ``1 - include`` are applied to
            ``include``, so it presumably holds 0/1 or bool membership flags.
        guess_interval: keep only every ``guess_interval``-th guess size
            (``guess_interval, 2 * guess_interval, ...``).

    Returns:
        DataFrame with columns ``num_guesses``, ``num_correct_+`` and
        ``num_correct_-``. Index labels are the row positions of the unsliced
        table (i.e. ``num_guesses - 1``).
    """
    assert 'include' in df_scores.columns
    assert 'score' in df_scores.columns

    df = pd.DataFrame(
        (np.arange(len(df_scores)) + 1),
        columns=['num_guesses'],
    )

    # A single ranking, highest score first, with a deterministic tie order.
    # NaN scores sort last.
    ranked = df_scores.sort_values('score', ascending=False, kind='mergesort')
    included = ranked['include']
    scored = ranked['score'].notna().to_numpy()

    # k+: guess the top-k canaries as "included".
    df['num_correct_+'] = included.cumsum().values

    # k-: guess the bottom-k canaries as "excluded". Reversing the scored part
    # of the same ranking (rather than re-sorting ascending) stops tied
    # canaries from being picked by both k+ and k-. NaN scores stay last,
    # exactly as an ascending sort would leave them.
    excluded_order = pd.concat([included[scored].iloc[::-1], included[~scored]])
    df['num_correct_-'] = (1 - excluded_order).cumsum().values

    # Thin out the guess sizes; the original index labels are kept on purpose.
    df = df.iloc[guess_interval - 1::guess_interval, :]

    return df

# combine k+ and k- guesses into candidate (num_guesses, num_correct) pairs
def combine_guesses(df, max_guesses=100):
    """Combine every k+ guess size with every k- guess size.

    Each pair (k+, k-) becomes one candidate with
    ``num_guesses = k+ + k-`` and
    ``num_correct = num_correct_+(k+) + num_correct_-(k-)``.
    Either side may also guess nothing (k = 0).

    Combinations whose ``num_guesses`` exceeds the number of canaries are
    not removed here, because this function does not know that number.
    ``audit`` drops them later, so they can still take slots among the top
    ``max_guesses``.

    Args:
        df: table with columns ``num_guesses``, ``num_correct_+`` and
            ``num_correct_-`` (as produced by ``get_possible_guesses``).
        max_guesses: maximum number of candidate rows to return.

    Returns:
        DataFrame with columns ``num_guesses``, ``num_correct`` and
        ``frac_correct``. It holds one row per distinct ``frac_correct``
        above 0.5, sorted by ``frac_correct`` descending, with a fresh
        RangeIndex.
    """
    def _with_zero_row(col):
        # Prepend a "guess nothing" row with a fresh index. Assigning to label
        # 0 instead would overwrite the num_guesses == 1 row whenever that
        # label already exists (guess_interval == 1).
        zero = pd.DataFrame({'num_guesses': [0], col: [0]})
        return pd.concat([zero, df[['num_guesses', col]]], ignore_index=True)

    df_plus = _with_zero_row('num_correct_+')
    df_minus = _with_zero_row('num_correct_-')

    # Only 'num_guesses' overlaps, so it becomes num_guesses_+ / num_guesses_-.
    df_cross = pd.merge(df_plus, df_minus, how='cross', suffixes=['_+', '_-'])
    combos = pd.DataFrame({
        'num_guesses': df_cross['num_guesses_+'] + df_cross['num_guesses_-'],
        'num_correct': df_cross['num_correct_+'] + df_cross['num_correct_-'],
    }).drop_duplicates()

    # Drop the empty guess set (k+ = k- = 0) before dividing by num_guesses.
    combos = combos[combos['num_guesses'] > 0].copy()
    combos['frac_correct'] = combos['num_correct'] / combos['num_guesses']

    # we don't bother auditing guesses that are worse than random guessing
    combos = combos[combos['frac_correct'] > 0.5]

    # if multiple sets of guesses get the same proportion correct, we keep the one with the highest number of guesses
    combos = combos.sort_values('num_guesses', ascending=False).groupby('frac_correct').head(1)

    # only keep the top `max_guesses`, sorted by the proportion of correct guesses
    combos = combos.sort_values('frac_correct', ascending=False).reset_index(drop=True)
    return combos.iloc[:max_guesses]

def audit(df, m, p=0.05, delta=1e-5):
    """Run the one-run audit on every candidate (num_guesses, num_correct) pair.

    Args:
        df: candidates with ``num_guesses`` and ``num_correct`` columns.
        m: number of canaries. It is passed to ``audit_one_run``, and
            candidates with ``num_guesses > m`` are skipped.
        p: forwarded to ``audit_one_run`` as ``p`` (presumably the audit's
            failure probability / significance level; confirm against
            ``get_eps_audit``).
        delta: forwarded to ``audit_one_run`` as ``delta``.

    Returns:
        DataFrame with columns ``num_guesses``, ``num_correct`` (both cast
        to int) and ``eps`` (the value returned by ``audit_one_run``),
        sorted by ``eps`` descending with a fresh RangeIndex.
    """
    guesses = df['num_guesses'].values
    corrects = df['num_correct'].values

    # A combination of k+ and k- guesses is only realisable if k+ + k- <= m.
    feasible = guesses <= m
    guesses = guesses[feasible]
    corrects = corrects[feasible]

    eps_values = [
        audit_one_run(m, r, v, p=p, delta=delta)
        for r, v in zip(guesses, corrects)
    ]

    # Stacking promotes everything to a common (float) dtype; the count
    # columns are cast back to int below.
    table = np.column_stack([guesses, corrects, eps_values])
    result = pd.DataFrame(table, columns=['num_guesses', 'num_correct', 'eps'])
    result = result.astype({'num_guesses': int, 'num_correct': int})

    return result.sort_values('eps', ascending=False).reset_index(drop=True)

def audit_scores(
    df_scores,
    guess_interval=1,
    max_guesses=100,
    p=0.05,
    delta=1e-5,
):
    """Audit a table of canary scores end to end.

    The pipeline has three steps:

    1. Build k+ / k- guess counts (``get_possible_guesses``).
    2. Combine them into candidates (``combine_guesses``).
    3. Audit each candidate against ``len(df_scores)`` canaries (``audit``).

    Args:
        df_scores: DataFrame with ``score`` and ``include`` columns.
        guess_interval: step between guess sizes considered.
        max_guesses: maximum number of candidates passed to ``audit``.
        p: forwarded to ``audit``.
        delta: forwarded to ``audit``.

    Returns:
        The DataFrame from ``audit``. If it has no rows, a single all-zero
        row is added so that ``eps`` is 0 when there are no valid guesses.
    """
    candidates = combine_guesses(
        get_possible_guesses(df_scores, guess_interval=guess_interval),
        max_guesses=max_guesses,
    )
    df_results = audit(candidates, len(df_scores), p=p, delta=delta)

    # eps = 0 if there are no valid guesses
    if df_results.empty:
        df_results.loc[0] = 0

    return df_results

# example usage
if __name__ == "__main__":
    # Audit the bundled example scores and report the strongest epsilon found.
    scores = pd.read_csv('examples/example_scores.csv')
    results = audit_scores(scores)
    print(results.head())
    print('max epsilon:', results['eps'].max())
