import os
import argparse
import numpy as np
import pandas as pd

from src.utils import read_pickle
from audit_one_run.google import audit_scores

import pdb

def get_args(argv=None):
    """Parse command-line options for auditing a finished black-box run.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``name``, ``out_dir``, ``score``,
        ``guess_interval`` and ``config_prefix``.
    """
    parser = argparse.ArgumentParser()
    # Required: the run directory is derived from the name, so omitting it
    # can only ever point the script at a wrong path.
    parser.add_argument('--name', type=str, required=True,
                        help='run name; combined with --config_prefix to '
                             'locate the run directory under --out_dir')
    parser.add_argument('--out_dir', type=str, default='out/black_box',
                        help='base directory that contains the run directories')
    parser.add_argument('--score', type=str, default='neg_loss',
                        help='key in out.pkl whose values are audited')
    parser.add_argument('--guess_interval', type=int, default=10,
                        help='forwarded to audit_scores as guess_interval')
    parser.add_argument('--config_prefix', type=str, default='',
                        help='string prepended to --name when building the '
                             'run directory name')

    return parser.parse_args(argv)

def build_score_frame(out, score):
    """Assemble the per-example table that ``audit_scores`` consumes.

    Only the keys that are actually needed are read from ``out``, so a
    results file does not have to contain every score variant in order to
    audit one of them.

    Args:
        out: mapping of column name -> per-example values, as loaded from
            ``out.pkl``. Must contain ``id``, ``include``, ``holdout``,
            ``canary`` and the key named by ``score``.
        score: key in ``out`` holding the score to audit (e.g. ``neg_loss``).

    Returns:
        DataFrame indexed by ``id`` with columns ``include``, ``holdout``,
        ``canary`` and ``score`` (the selected score, renamed).

    Raises:
        KeyError: if ``out`` lacks one of the required keys.
        ValueError: if ``id`` contains duplicate values.
    """
    needed = ['id', 'include', 'holdout', 'canary']
    try:
        data = {k: out[k] for k in needed}
        data['score'] = out[score]
    except KeyError as err:
        raise KeyError(f'results are missing required key {err}') from err

    df = pd.DataFrame(data).set_index('id')

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not df.index.is_unique:
        dupes = df.index[df.index.duplicated()].unique().tolist()
        raise ValueError(f'contains duplicate ids, e.g. {dupes[:10]}')
    return df


if __name__ == '__main__':
    args = get_args()
    print(args)

    # config_prefix is prepended to the run name, not used as a sub-directory.
    save_dir = os.path.join(args.out_dir, f'{args.config_prefix}{args.name}')
    out = read_pickle(os.path.join(save_dir, 'out.pkl'))

    df = build_score_frame(out, args.score)

    # Audit only the canary examples.
    df_canaries = df[df['canary'].astype(bool)]
    df_results = audit_scores(
        df_canaries,
        guess_interval=args.guess_interval,
    )

    print(args.name)
    print(f"Google ({args.score}):")
    print(df_results.head())