import os
import sys
sys.path.append("NL-Augmenter")
from mia_tests_library import aggregate_metrics, generate_perturbations
import numpy as np
from sklearn.metrics import roc_auc_score
from argparse import ArgumentParser
from p3 import ngram, embedding, exact_match, indel_similarity
from misc import utils 


if __name__ == "__main__":
    # Script entry point: build member / non-member sets with the selected
    # definition, score both sets with every test listed in `tests`, and report
    # one ROC AUC per test (printed, and optionally written to a report file).
    parser = ArgumentParser()
    parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-6.9b")
    parser.add_argument("--dataset_name", type=str, required=True, choices=["ai4privacy", "ag_news", "mimir"])
    parser.add_argument("--definition", type=str, required=True, choices=["ngram", "embedding", "exact_match", "indel_similarity"])
    parser.add_argument("--definition_hyperparameter", type=float)
    parser.add_argument("--output_result_report_path", type=str)
    args = parser.parse_args()

    # The report starts as a copy of the parsed CLI arguments; per-test AUCs are
    # collected under the "tests" key.
    result_report = dict(vars(args))
    result_report["tests"] = {}

    model, tokenizer = utils.load_model(args.model_name)

    # NOTE(review): --definition_hyperparameter is parsed as a float and is None
    # when omitted; it is forwarded unchanged as k / cs / sim below -- confirm
    # that each definition class accepts those values.
    if args.definition == "ngram":
        definition = ngram.NgramP3(k=args.definition_hyperparameter, dataset_name=args.dataset_name)
    elif args.definition == "embedding":
        definition = embedding.EmbeddingP3(cs=args.definition_hyperparameter, dataset_name=args.dataset_name)
    elif args.definition == "exact_match":
        definition = exact_match.ExactMatchP3(dataset_name=args.dataset_name)
    elif args.definition == "indel_similarity":
        definition = indel_similarity.IndelSimilarityP3(sim=args.definition_hyperparameter, dataset_name=args.dataset_name)
    else:
        # Defensive only: argparse `choices` already rejects unknown definitions.
        raise ValueError("Definition not found. Please check the available definitions.")

    members, non_members = definition.generate_members_and_nonmembers_labels()

    # Attach every perturbation returned for the 'text' column as an extra
    # column, keyed by the name generate_perturbations() gives it.
    perturbed_members = generate_perturbations(members['text'])
    for key in perturbed_members.keys():
        members = members.add_column(key, perturbed_members[key])
    perturbed_non_members = generate_perturbations(non_members['text'])
    for key in perturbed_non_members.keys():
        non_members = non_members.add_column(key, perturbed_non_members[key])

    tests = ["k_min_probs", "ppl", "zlib_ratio", "perturbation", "reference_model"]
    # NOTE(review): the meaning of the positional `None` argument is defined by
    # aggregate_metrics and is not visible from this script.
    member_scores = aggregate_metrics(model, tokenizer, members, tests, None, batch_size=32)
    non_member_scores = aggregate_metrics(model, tokenizer, non_members, tests, None, batch_size=32)

    for test in member_scores.keys():
        # Scores are negated before ranking, so a lower raw score counts as
        # stronger evidence for the positive class (label 1 = member).
        member_scores_test = -1 * np.array(member_scores[test])
        non_member_scores_test = -1 * np.array(non_member_scores[test])
        labels = np.concatenate(
            (np.ones(len(member_scores_test)), np.zeros(len(non_member_scores_test)))
        )
        scores = np.concatenate((member_scores_test, non_member_scores_test))
        # Store a plain Python float: the report is serialised with str(), which
        # uses repr() of each value, and NumPy 2.x reprs its scalars as
        # "np.float64(...)". Converting keeps the report text independent of
        # whether roc_auc_score returns a NumPy scalar or a built-in float.
        auc_score = float(roc_auc_score(labels, scores))
        print(f"AUC score for {test} is {auc_score:.3f}")
        result_report["tests"][test] = auc_score

    # --output_result_report_path is optional; without it open(None, "w") would
    # raise TypeError only after all the scoring above had already been done.
    if args.output_result_report_path is None:
        print("No --output_result_report_path given; skipping the report file.")
    else:
        with open(args.output_result_report_path, "w", encoding="utf-8") as f:
            f.write(str(result_report))

    
        