import os
import numpy as np
import pandas as pd

from audit_one_run.src.meta import get_target_noise, gaussianDP_blow_up_inverse
from audit_one_run.src.meta import get_gaussian_emp_eps_bs as audit_one_run_fdp

def create_canary_sets(df_scores):
    """Pair included canaries with non-included canaries, one pair per row.

    The rows of `df_scores` are split on the truthiness of the `include`
    column. Each half is shuffled with a fixed seed (so the pairing is
    reproducible), truncated to the size of the smaller half, and the two
    halves are then joined row by row.

    Args:
        df_scores: DataFrame with at least the columns `include` (cast to
            bool) and `score`.

    Returns:
        DataFrame with one row per pair. Every input column except `include`
        appears twice, suffixed `_in` (included canary) and `_not_in`
        (non-included canary). The original row labels are kept as columns by
        `reset_index` (`index_in` / `index_not_in` for an unnamed index).
        Derived columns:
            abs_score_diff: |score_in - score_not_in|
            max_score: larger of the two scores
            min_score: smaller of the two scores
            in_score_is_higher: True where score_in > score_not_in

    Raises:
        AssertionError: if `include` or `score` is missing from `df_scores`.
    """
    assert 'include' in df_scores.columns, "df_scores needs an 'include' column"
    assert 'score' in df_scores.columns, "df_scores needs a 'score' column"

    mask = df_scores['include'].astype(bool)
    # Both halves must have the same length so they can be paired one-to-one.
    m = min(int(mask.sum()), int((~mask).sum()))

    def _shuffle_and_trim(df_half):
        # Shuffle with a fixed seed, keep the first m rows, and renumber them
        # 0..m-1 so the two halves line up on their index in the merge below.
        shuffled = df_half.sample(frac=1, replace=False, random_state=0)
        return shuffled.iloc[:m].reset_index().drop(columns=['include'])

    df_include = _shuffle_and_trim(df_scores.loc[mask])
    df_not_include = _shuffle_and_trim(df_scores.loc[~mask])

    df_canary_sets = pd.merge(
        df_include, df_not_include,
        left_index=True, right_index=True,
        suffixes=('_in', '_not_in'),
    )

    score_in = df_canary_sets['score_in']
    score_not_in = df_canary_sets['score_not_in']
    df_canary_sets['abs_score_diff'] = (score_in - score_not_in).abs()
    df_canary_sets['max_score'] = np.maximum(score_in, score_not_in)
    # Fixed: this column was previously computed with np.maximum, which made
    # it an exact copy of `max_score`.
    df_canary_sets['min_score'] = np.minimum(score_in, score_not_in)
    df_canary_sets['in_score_is_higher'] = score_in > score_not_in

    return df_canary_sets

def get_possible_guesses(df_canary_sets, guess_interval=1, max_guesses=100):
    """Tabulate how many guesses would be correct for each possible guess count.

    Canary pairs are ranked by `abs_score_diff` (largest first). Guessing on
    the top k pairs counts as correct wherever `in_score_is_higher` is True,
    so the running sum of that column gives the number of correct guesses
    for every k.

    Args:
        df_canary_sets: DataFrame with columns `abs_score_diff` and
            `in_score_is_higher`.
        guess_interval: keep only every `guess_interval`-th guess count
            (k = guess_interval, 2 * guess_interval, ...).
        max_guesses: maximum number of rows to return.

    Returns:
        DataFrame with columns `num_guesses`, `num_correct` and
        `frac_correct`, sorted by `frac_correct` in descending order with a
        fresh 0-based index.
    """
    ranked_sets = df_canary_sets.sort_values('abs_score_diff', ascending=False)

    # Row k-1 describes the outcome of guessing on the k top-ranked pairs.
    guess_counts = np.arange(len(df_canary_sets)) + 1
    table = pd.DataFrame(guess_counts, columns=['num_guesses'])
    table['num_correct'] = ranked_sets['in_score_is_higher'].cumsum().values
    table['frac_correct'] = table['num_correct'] / table['num_guesses']

    # Thin the table out to every `guess_interval`-th guess count.
    candidates = table.iloc[guess_interval - 1::guess_interval, :]

    # Guess sets that do no better than random guessing are not audited.
    better_than_chance = candidates['frac_correct'] > 0.5
    candidates = candidates[better_than_chance]

    # Among guess sets with an identical fraction correct, keep only the one
    # with the most guesses.
    by_size = candidates.sort_values('num_guesses', ascending=False)
    candidates = by_size.groupby('frac_correct').head(1)

    # Best fraction first, capped at `max_guesses` rows.
    candidates = candidates.sort_values('frac_correct', ascending=False)
    candidates = candidates.reset_index(drop=True)
    return candidates.iloc[:max_guesses]

def load_noises(
    theoretical_eps=8.0,
    num_candidate_noises=1000,
    noise_cache_dir=None,
):
    """Return the array of candidate noise values, using an on-disk cache.

    The cache file is named
    `candidate_noises_eps{theoretical_eps}_{num_candidate_noises}.npy` and
    lives in `noise_cache_dir`. If it exists it is loaded and returned as-is.
    Otherwise one noise value is computed with `get_target_noise` for each of
    `num_candidate_noises` evenly spaced epsilons in [0, theoretical_eps],
    the array is reversed, written to the cache file, and returned.

    Args:
        theoretical_eps: largest epsilon of the grid; also part of the cache
            file name.
        num_candidate_noises: number of grid points; also part of the cache
            file name.
        noise_cache_dir: directory of the cache file. `None` means the
            current working directory. The directory is created if missing.

    Returns:
        numpy array of candidate noise values.
    """
    if noise_cache_dir is None:
        noise_cache_dir = '.'

    noise_path = os.path.join(
        noise_cache_dir,
        f'candidate_noises_eps{theoretical_eps}_{num_candidate_noises}.npy',
    )
    if os.path.exists(noise_path):
        return np.load(noise_path)

    epsilons = np.linspace(0, theoretical_eps, num_candidate_noises)
    candidate_noises = np.array([get_target_noise(eps) for eps in epsilons])
    # Reverse so the first entry belongs to `theoretical_eps` and the last
    # entry to epsilon = 0.
    candidate_noises = candidate_noises[::-1]

    # np.save does not create missing directories; make sure the target
    # exists so a fresh cache location does not discard the computed values.
    # (An empty string means "current directory" and needs no creation.)
    if noise_cache_dir:
        os.makedirs(noise_cache_dir, exist_ok=True)
    np.save(noise_path, candidate_noises)
    return candidate_noises

def audit(
    df,
    m,
    p=0.05,
    delta=1e-5,
    theoretical_eps=8.0,
    num_candidate_noises=1000,
    noise_cache_dir=None,
):
    """Compute an empirical epsilon for every row of a guess table.

    Args:
        df: DataFrame with columns `num_guesses` and `num_correct`.
        m: forwarded unchanged to the audit helper (`audit_scores` passes the
            number of canary pairs here).
        p: forwarded to the audit helper as `threshold`.
        delta: forwarded to the audit helper as `delta`.
        theoretical_eps: forwarded to `load_noises`.
        num_candidate_noises: forwarded to `load_noises`.
        noise_cache_dir: forwarded to `load_noises`.

    Returns:
        DataFrame with columns `num_guesses`, `num_correct` (both int) and
        `eps`, sorted by `eps` in descending order with a fresh 0-based
        index.
    """
    noises = load_noises(theoretical_eps, num_candidate_noises, noise_cache_dir)
    blow_up_inverses = [gaussianDP_blow_up_inverse(noise) for noise in noises]

    def empirical_eps(num_correct, num_guesses):
        # Everything except the (correct, guesses) pair is shared by all rows.
        return audit_one_run_fdp(
            noises, blow_up_inverses, m, num_correct, num_guesses,
            threshold=p, delta=delta,
        )

    guesses = df['num_guesses'].values
    correct = df['num_correct'].values
    eps_values = [empirical_eps(v, r) for v, r in zip(correct, guesses)]

    results = pd.DataFrame(
        np.array([guesses, correct, eps_values]).T,
        columns=['num_guesses', 'num_correct', 'eps'],
    )
    # Stacking the three sequences into one array gives them a common dtype;
    # restore integer counts afterwards.
    count_cols = ['num_guesses', 'num_correct']
    results[count_cols] = results[count_cols].astype(int)

    # An eps of -1 from the helper is reported as 0.
    is_sentinel = results['eps'] == -1
    results.loc[is_sentinel, 'eps'] = 0

    return results.sort_values('eps', ascending=False).reset_index(drop=True)
    
def audit_scores(
    df_scores,
    guess_interval=1,
    max_guesses=100,
    p=0.05,
    delta=1e-5,
    theoretical_eps=8.0,
    num_candidate_noises=1000,
    noise_cache_dir=None,
):
    """Run the full pipeline: pair canaries, enumerate guesses, audit them.

    Args:
        df_scores: DataFrame with columns `include` and `score`, passed to
            `create_canary_sets`.
        guess_interval: forwarded to `get_possible_guesses`.
        max_guesses: forwarded to `get_possible_guesses`.
        p: forwarded to `audit`.
        delta: forwarded to `audit`.
        theoretical_eps: forwarded to `audit`.
        num_candidate_noises: forwarded to `audit`.
        noise_cache_dir: forwarded to `audit`.

    Returns:
        The DataFrame produced by `audit`. If it has no rows (no guess set
        survived the filtering), a single all-zero row is added so that the
        reported epsilon is 0.
    """
    canary_sets = create_canary_sets(df_scores)
    guesses = get_possible_guesses(
        canary_sets,
        guess_interval=guess_interval,
        max_guesses=max_guesses,
    )

    # The number of canary pairs is what `audit` receives as `m`.
    df_results = audit(
        guesses,
        len(canary_sets),
        p=p,
        delta=delta,
        theoretical_eps=theoretical_eps,
        num_candidate_noises=num_candidate_noises,
        noise_cache_dir=noise_cache_dir,
    )

    # With no valid guesses, report eps = 0 via a single row of zeros.
    if len(df_results) == 0:
        df_results.loc[0] = 0

    return df_results

# example usage
if __name__ == "__main__":
    # Directory passed to audit_scores as the noise cache location.
    NOISE_CACHE_DIR = './src'

    example_scores = pd.read_csv('examples/example_scores.csv')
    results = audit_scores(example_scores, noise_cache_dir=NOISE_CACHE_DIR)

    # Show the first rows of the result table, then the largest epsilon in it.
    print(results.head())
    print('max epsilon:', results['eps'].max())

