import pandas as pd
import numpy as np
import glob
import re
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
import traceback
from joblib import Parallel, delayed
import argparse
import torch

from src.dataset_inference.preprocessing import create_df, create_grouped_split, create_split
from src.dataset_inference.attacks import get_mia
from src.dataset_inference.metrics import get_p_values, rank_candidates

# Command-line interface. --path is required: without it the glob below would
# search the literal directory "None" and silently find no score files.
parser = argparse.ArgumentParser(description='Evaluate MIA features per dataset group and print LaTeX tables.')
parser.add_argument('--model_name', type=str, default='EleutherAI/pythia-12b-deduped',
                    help='Model whose score files are evaluated; its last path component must occur in the file names.')
parser.add_argument('--path', type=str, required=True,
                    help='Directory containing the *.pkl score files.')
args = parser.parse_args()

model_name = args.model_name


def filter_features(x):
    """Keep every attack feature (hook for restricting which attacks are used)."""
    return True

def valid_filename(s):
    """Return whether path ``s`` mentions the selected model.

    Only the part of ``model_name`` after the last '/' (e.g. the repo name
    without the organisation) is looked for, as a plain substring of ``s``.
    """
    short_name = model_name.rsplit('/', 1)[-1]
    return short_name in s
    
# Score files for the selected model. glob's result order depends on the file
# system, so sort the paths to hand them to create_df in a stable order.
files = sorted(f for f in glob.glob(f'{args.path}/*.pkl') if valid_filename(f))

def get_dataset(filter_files, files):
    """Load the score files accepted by ``filter_files`` and split them.

    Args:
        filter_files: predicate applied (via ``filter``) to each path in ``files``.
        files: candidate score-file paths.

    Returns:
        ``(x, y, set_id, features_names)``: the first three come from
        ``create_split``; ``features_names`` holds ``f'{model_name}_{attack}'``
        for every attack key returned by ``get_mia`` on the first member entry
        that passes ``filter_features``.
    """
    selected = list(filter(filter_files, files))
    print('files:', len(selected))
    member_scores, nonmember_scores = create_df(selected)
    grouped = create_grouped_split(
        [model_name],
        member_scores,
        nonmember_scores,
        filter_features,
        number_of_nonmembers=None,
    )
    x, y, set_id = create_split(grouped)
    # Attack names are read off a single (the first) member entry.
    sample_entry = next(iter(member_scores.values()))
    attack_keys = get_mia(sample_entry).keys()
    features_names = [f'{model_name}_{attack}' for attack in attack_keys if filter_features(attack)]
    return x, y, set_id, features_names

# Dataset groups to evaluate: (predicate over a score-file path, group label).
# run_eval is called once per entry and the label becomes its 'dataset' column.
if 'pythia' in model_name:
    groups = [
        # Whole Pile test/validation vs. train splits.
        (lambda x: 'pile-val' in x or 'pile-test' in x, 'pile-test'),
        (lambda x: 'pile-train' in x, 'pile-train'),

        # Subsets evaluated on both the train and the test/validation side.
        (lambda x: 'pile-train' in x and 'Github' in x, 'pile-train-Github'),
        (lambda x: 'pile-train' in x and 'StackExchange' in x, 'pile-train-StackExchange'),
        (lambda x: ('pile-val' in x or 'pile-test' in x) and 'Github' in x, 'pile-test-Github'),
        (lambda x: ('pile-val' in x or 'pile-test' in x) and 'StackExchange' in x, 'pile-test-StackExchange'),

        # Train-only subsets.
        (lambda x: 'pile-train' in x and 'Wikipediaen' in x, 'pile-train-Wikipediaen'),
        (lambda x: 'pile-train' in x and 'UbuntuIRC' in x, 'pile-train-UbuntuIRC'),
        (lambda x: 'pile-train' in x and 'PubMedCentral' in x, 'pile-train-PubMedCentral'),
        (lambda x: 'pile-train' in x and 'HackerNews' in x, 'pile-train-HackerNews'),
        (lambda x: 'pile-train' in x and 'Pile-CC' in x, 'pile-train-PileCC'),
        (lambda x: 'pile-train' in x and 'ArXiv' in x, 'pile-train-ArXiv'),
    ]
else:
    groups = [
        (lambda x: 'dolma_c4' in x, 'dolma_c4'),
        (lambda x: 'dolma_pes2o' in x, 'dolma_pes2o'),
        (lambda x: 'dolma_megawika' in x, 'dolma_megawika'),
        (lambda x: 'dolma_arxiv' in x, 'dolma_arxiv'),
        (lambda x: 'dolma_falcon' in x, 'dolma_falcon'),
        (lambda x: 'dolma_algebraic-stack-train' in x, 'dolma_algebraic-stack-train'),
        (lambda x: 'dolma_open-web-math-train' in x, 'dolma_open-web-math-train'),
        # NOTE(review): only the algebraic-stack test/validation splits are matched.
        # If other proof-pile-2 subsets (e.g. open-web-math) were meant to be part
        # of 'proof-test', add their tags here.
        (lambda x: 'proof-pile-2_test_algebraic-stack' in x or 'proof-pile-2_validation_algebraic-stack' in x, 'proof-test'),
    ]

def run_eval(dataset_f, dataset_name):
    """Evaluate every MIA feature on one dataset group.

    Args:
        dataset_f: predicate over file paths selecting this group's score files.
        dataset_name: label stored in the ``dataset`` column of the result.

    Returns:
        A ``pd.DataFrame`` with one row per feature holding the AUC (oriented so
        it is >= 0.5), the TPR at the ROC point closest to 1% FPR, the per-group
        candidate ranks with their mean/median, and one ``pvalue-<k>`` column per
        key returned by ``get_p_values``. Returns ``None`` if loading the group
        raised.
    """
    df = {
        'attack': [],
        'AUC': [],
        'TPR@ 1% FPR': [],
        'dataset': [],
        'ranks': [],
        'average rank': [],
        'median rank': [],
    }
    print(dataset_name)
    try:
        x_test, y_test, set_id_test, features_names = get_dataset(dataset_f, files)
    except Exception as e:
        # Best effort: one broken group must not abort the whole parallel sweep;
        # the caller drops groups that return None.
        print(f"Error with {dataset_name}. {e}")
        traceback.print_exc()
        return None
    unique_groups = np.unique(set_id_test)
    print(x_test.shape, y_test.shape, len(unique_groups), len(features_names))

    # Reference rows for get_p_values: all member rows, and from each group as
    # many non-member rows as that group has members.
    nonmember_feats = []
    for group_id in unique_groups:
        cnt_1 = int(y_test[(set_id_test == group_id) & (y_test == 1)].sum())
        nonmember_feats.append(x_test[(y_test == 0) & (set_id_test == group_id)][:cnt_1])
    nonmember_feats = torch.cat(nonmember_feats, dim=0).numpy()
    member_feats = x_test[y_test == 1].numpy()

    for i, feat in enumerate(tqdm(features_names, disable=True)):
        scores = x_test[:, i]
        # sklearn's roc_curve returns (fpr, tpr, thresholds), in that order.
        fpr, tpr, _ = roc_curve(y_test, scores)
        score = auc(fpr, tpr)
        if score < 0.5:
            # The feature ranks members lower. Flip it so the ROC curve, and hence
            # TPR@1%FPR, is read in the same orientation as the reported AUC.
            fpr, tpr, _ = roc_curve(y_test, -scores)
            score = auc(fpr, tpr)

        rankings_list = []
        for g in unique_groups:
            test_mask = np.where(set_id_test == g)[0]
            y_group = y_test[test_mask].numpy()
            # Negated feature is used as the membership score for ranking.
            y_group_prob = -x_test[test_mask][:, i].numpy()
            ranks, max_rank = rank_candidates(y_group, y_group_prob)
            rankings_list += ranks

        # NOTE(review): max_rank from the last group is passed on; presumably it
        # is the same for every group -- confirm.
        p_values = get_p_values([1000], heldout_train=member_feats[:, i], heldout_val=nonmember_feats[:, i], ranks=rankings_list, max_rank=max_rank, seed=42, repeatitions=1)
        df['attack'].append(feat)
        df['AUC'].append(max(score, 1 - score))
        # TPR at the ROC point whose FPR is closest to 1%.
        df['TPR@ 1% FPR'].append(tpr[np.argmin(np.abs(fpr - 0.01))])
        df['dataset'].append(dataset_name)
        df['ranks'].append(rankings_list)
        df['average rank'].append(np.mean(rankings_list))
        df['median rank'].append(np.median(rankings_list))

        for k, v in p_values.items():
            df[f'pvalue-{k}'] = df.get(f'pvalue-{k}', []) + v

    return pd.DataFrame(df)
# Evaluate all dataset groups in parallel (one joblib job per group); groups
# whose evaluation failed return None and are skipped.
_results = [
    a
    for a in Parallel(n_jobs=-1)(delayed(run_eval)(dataset_f, dataset_name) for dataset_f, dataset_name in groups)
    if a is not None
]
if not _results:
    # pd.concat([]) would only raise an opaque "No objects to concatenate".
    raise SystemExit('No dataset group could be evaluated; check --path and --model_name.')
df = pd.concat(_results)

print(df['dataset'].unique())

def get_attack_type(x):
    """Reduce a feature name to its attack family.

    Keeps the text after the last '-', drops a 'deduped_' marker, removes all
    digits and dots, then collapses '__' into '_'.
    """
    tail = x.rsplit('-', 1)[-1]
    tail = tail.replace('deduped_', '')
    without_numbers = re.sub(r'[\d.]', '', tail)
    return without_numbers.replace('__', '_')
# Derive metadata from feature names of the form "<model>_<attack>": the model
# is the text before the first underscore, the attack name everything after it.
df['attack_type'] = df['attack'].map(get_attack_type)
df['model'] = df['attack'].map(lambda name: name.partition('_')[0])
df['attack_name'] = df['attack'].map(lambda name: name.partition('_')[2])


# Attacks (values of df['attack_name']) shown as table rows, in display order.
attacks = ['prob', 'mink_0.1', 'mink++_1.0', 'real_recall_prob', 'real_recall_neg_hinge', 'neg_hinge']
# Metric columns of df to tabulate; one table is printed per metric and model.
metrics = ['AUC', 'TPR@ 1% FPR']

# Table layout: header group -> {column label: value of df['dataset']}.
# Note: this rebinds `groups`, which earlier held the list of file filters.
has_test = 'pythia' in model_name
if has_test:
    groups = {
        'Pile': {'Train': 'pile-train', 'Test': 'pile-test'},
        'Github': {'Train': 'pile-train-Github', 'Test': 'pile-test-Github'},
        'StackExchange': {'Train': 'pile-train-StackExchange', 'Test': 'pile-test-StackExchange'},
        # Train-only subsets; dataset keys drop the '-' of the display label.
        'Train': {
            subset: 'pile-train-' + subset.replace('-', '')
            for subset in ('UbuntuIRC', 'Wikipediaen', 'PubMedCentral', 'HackerNews', 'Pile-CC', 'ArXiv')
        },
    }
else:
    groups = {
        'Dolma': {
            'C4': 'dolma_c4',
            'pes2o': 'dolma_pes2o',
            'megawika': 'dolma_megawika',
            'arxiv': 'dolma_arxiv',
            'falcon': 'dolma_falcon',
            'algebraic stack': 'dolma_algebraic-stack-train',
            'open web math': 'dolma_open-web-math-train',
            'Proof Test': 'proof-test',
        },
    }

# Total number of per-dataset columns across all header groups.
num_cols = sum(map(len, groups.values()))
# LaTeX display names for the attacks listed above.
map_names = {
    'real_recall_prob': 'ReCALL',
    'mink_0.1': 'Min-K\\%',
    'prob': 'Loss',
    'mink++_1.0': 'Min-K\\%++',
    'real_recall_neg_hinge': 'ReCALL (Hinge)',
    'neg_hinge': 'Hinge',
}
# Print one LaTeX table per (metric, model): one row per attack in `attacks`,
# one column per dataset in `groups`, then the Train (and, if has_test, Test)
# averages. The best attack per dataset is set in bold.
# Display-name helper; loop-invariant, so defined once.
normalize_name = lambda x: map_names.get(x, x.replace('_', ' ').replace('-', ' ').replace('Pile', '').title().strip())
for metric in metrics:
    for model, model_df in df.groupby('model'):
        print(metric, model)
        s = '''\\begin{tabular}{l''' + 'r' * (num_cols) + '''|''' + ('rr' if has_test else 'r') + '''}\n\\toprule\n'''

        # Header row 1: one multicolumn per dataset group, then the averages.
        for k, v in groups.items():
            s += f' & \\multicolumn{{{len(v)}}}{{c}}{{\\textbf{{{k}}}}}'
        s += ' & \\multicolumn{'+('2' if has_test else '1')+'}{c}{\\textbf{Average}}  \\\\\n'
        # Header row 2: individual dataset labels.
        s += '\\textbf{MIA}'
        for v in groups.values():
            for kk in v:
                s += f' & {kk}'
        if has_test:
            s += ' & Train & Test \\\\\n\\midrule\n'
        else:
            s += ' & Train \\\\\n\\midrule\n'

        # dataset -> attack (among `attacks`) with the highest value of `metric`.
        best_attack = (
            model_df[model_df['attack_name'].isin(attacks)]
            .sort_values(metric, ascending=False)
            .groupby('dataset')
            .first()
            .to_dict()['attack_name']
        )
        for attack in attacks:
            s += f"{normalize_name(attack)}"
            avg_train = []
            avg_test = []
            for v in groups.values():
                for dataset_key in v.values():
                    curr_df = model_df[(model_df['dataset'] == dataset_key) & (model_df['attack_name'] == attack)]
                    if len(curr_df) == 0:
                        s += ' & -'
                        continue
                    value = curr_df[metric].iloc[0]
                    cell = f"{value*100:.1f}"
                    if best_attack[dataset_key] == attack:
                        s += f" & \\textbf{{{cell}}}"
                    else:
                        s += f" & {cell}"
                    # Keys containing 'test' count toward the Test average, all
                    # others toward the Train average.
                    if 'test' in dataset_key:
                        avg_test.append(value)
                    else:
                        avg_train.append(value)

            # An average over no datasets renders as '-' (like a missing cell)
            # instead of 'nan' plus a numpy RuntimeWarning.
            train_avg = f"{np.mean(avg_train)*100:.1f}" if avg_train else '-'
            if has_test:
                test_avg = f"{np.mean(avg_test)*100:.1f}" if avg_test else '-'
                s += f" & {train_avg} & {test_avg} \\\\\n"
            else:
                s += f" & {train_avg} \\\\\n"

        s += '''\\bottomrule\n\\end{tabular}\n'''
        print(s)
        print()
        print()

