from argparse import ArgumentParser, Namespace
import sys
import os
import json
# Add the submodule path to the system path
sys.path.append(os.path.join(os.getcwd(), 'tofu'))

from typing import List, Dict
import torch
from tqdm import tqdm
import zlib
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
from peft import LoraConfig, get_peft_model

from sklearn.metrics import auc as get_auc, roc_curve as get_roc_curve
import datasets
from scr.train import *


def compute_ppl(text, model, tokenizer, device='cuda'):
    """Score one tokenized sample with a causal language model.

    Args:
        text: Indexable pair ``(input_ids, labels)`` of tensors. ``input_ids``
            is indexed as ``input_ids[0]``, so a leading batch dimension is
            expected and only the first sequence is scored token by token.
        model: Object exposing ``eval()`` and callable as
            ``model(input_ids, labels=labels)``; the first two items of its
            output are taken as ``(loss, logits)``.
        tokenizer: Unused here; kept so the call signature stays unchanged.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A tuple ``(ppl, all_prob, loss)`` where ``ppl`` is ``exp(loss)``,
        ``all_prob`` is a list with the log-probability the model assigns to
        each token of the first sequence except the first one, and ``loss``
        is the model loss as a Python float.
    """
    model.eval()
    input_ids = text[0].to(device)
    labels = text[1].to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # Token i + 1 is scored with the distribution produced at position i.
        targets = input_ids[0][1:]
        # One gather instead of a per-token ``.item()`` call: avoids a
        # device-to-host synchronisation for every token in the sequence.
        token_log_probs = (
            log_probs[0, :targets.shape[0]]
            .gather(-1, targets.long().unsqueeze(-1))
            .squeeze(-1)
        )

    all_prob = token_log_probs.tolist()
    ppl = torch.exp(loss).item()
    return ppl, all_prob, loss.item()


def inference(text, model, tokenizer) -> Dict:
    """Compute the membership-inference scores for one sample.

    Args:
        text: Sample forwarded unchanged to ``compute_ppl``.
        model: Model to score with; its ``device`` attribute selects where
            the inputs are placed.
        tokenizer: Forwarded unchanged to ``compute_ppl``.

    Returns:
        A dict with two float entries:
        ``"PPL"`` holds the third value returned by ``compute_ppl`` (the
        loss, not its exponential), and ``"Min-40%"`` holds the negated mean
        of the lowest 40% of the per-token log-probabilities.
    """
    pred = {}

    _, all_prob, p1_likelihood = compute_ppl(text, model, tokenizer, device=model.device)

    pred["PPL"] = float(p1_likelihood)

    # Ascending sort, so the leading entries are the least likely tokens.
    sorted_prob = np.sort(all_prob)
    for ratio in [0.4]:
        # Keep at least one token: for very short sequences the truncated
        # count is 0, and the mean of an empty slice is NaN, which would
        # invalidate the downstream ROC computation.
        k_length = max(1, int(len(all_prob) * ratio))
        topk_prob = sorted_prob[:k_length]
        pred[f"Min-{int(ratio*100)}%"] = float(-np.mean(topk_prob).item())

    return pred


def eval_data(data, model, tokenizer, data_collator=None):
    """Run ``inference`` on every sample and collect the results column-wise.

    Args:
        data: Iterable of samples.
        model: Forwarded to ``inference``.
        tokenizer: Forwarded to ``inference``.
        data_collator: Optional callable; when given it is applied to a
            one-element list holding the sample, and its output is what gets
            scored. Otherwise the sample is scored as is.

    Returns:
        A dict mapping ``'text'`` and every key produced by ``inference`` to
        a list with one entry per sample, or an empty dict when ``data``
        yields nothing.
    """
    collected = {}
    for sample in tqdm(data):
        qa = sample if data_collator is None else data_collator([sample])

        record = {'text': qa}
        record.update(inference(qa, model, tokenizer))

        # The columns are fixed by the first record seen.
        if not collected:
            collected = {name: [] for name in record}
        for name, value in record.items():
            collected[name].append(value)
    return collected


def sweep(ppl, y):
    """Compute the ROC curve, its area and the best balanced accuracy.

    The negated ``ppl`` values are used as the scores passed to the ROC
    routine together with the labels ``y``.

    Returns:
        A tuple ``(fpr, tpr, auc, acc)`` where ``acc`` is the maximum over
        all ROC points of ``1 - (fpr + (1 - tpr)) / 2``.
    """
    false_pos, true_pos, _thresholds = get_roc_curve(y, -ppl)

    # Mean of the false-positive and false-negative rates at each threshold.
    balanced_error = (false_pos + (1 - true_pos)) / 2
    best_acc = np.max(1 - balanced_error)

    area = get_auc(false_pos, true_pos)
    return false_pos, true_pos, area, best_acc


def eval_mia(
    forget_data,
    retain_data,
    holdout_data,
    model, tokenizer,
    data_collator=None
):
    """Score the three splits and compute forget-vs-holdout AUCs.

    Every split is evaluated with ``eval_data``. For each score type other
    than ``'text'``, the forget scores (labelled 0) and the holdout scores
    (labelled 1) are balanced to the same size by random subsampling without
    replacement and passed to ``sweep``. The retain split is evaluated and
    logged but does not enter any AUC.

    Returns:
        A tuple ``(auc, log)``: ``auc`` maps
        ``"forget_holdout_<score type>"`` to the AUC value, and ``log`` maps
        each split name to the output of ``eval_data``.
    """
    splits = (
        ("Evaluating on the forget set...", 'forget', forget_data),
        ("Evaluating on the retain set...", 'retain', retain_data),
        ("Evaluating on the holdout set...", 'holdout', holdout_data),
    )
    log = {}
    for message, split_name, split_data in splits:
        print(message)
        log[split_name] = eval_data(split_data, model, tokenizer, data_collator)

    # Every logged column except the raw sample is a score type.
    score_types = list(log['forget'].keys())
    score_types.remove('text')

    auc = {}
    for split0, split1 in [('forget', 'holdout')]:
        scores0 = log[split0]
        scores1 = log[split1]
        for ppl_type in score_types:
            first = scores0[ppl_type]
            second = scores1[ppl_type]

            n_common = min(len(first), len(second))
            # Draw from the larger side so both sides end up with n_common
            # entries; on a tie the second side is drawn (i.e. shuffled).
            if len(first) <= n_common:
                second = np.random.choice(second, n_common, replace=False).tolist()
            else:
                first = np.random.choice(first, n_common, replace=False).tolist()

            scores = np.array(first + second)
            labels = np.array([0] * len(first) + [1] * len(second))

            auc_score = sweep(scores, labels)[2]
            auc[f"{split0}_{split1}_{ppl_type}"] = np.mean(auc_score)
    return auc, log

if __name__ == "__main__":
    # Compute the privacy-leakage score of an unlearned model relative to the
    # retrained reference, from MIA AUC files that already exist on disk.
    UNLEARN_METHODS = [
        "original",
        "retraining",
        "finetune",
        "ga",
        "gdiff",
        "KL",
        "dpo",
        "scr_newton",
        "scrub",
    ]
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--unlearned_model_path", type=str, default="main_results/model.pt",
                        help="Path of the unlearned model")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help="Name or path to the model locally or on HuggingFace Hub")
    parser.add_argument("--unlearn_method", type=str, default="retraining",
                        choices=UNLEARN_METHODS,
                        help="Unlearning method")
    parser.add_argument("--remove_pct", type=str, default="1",
                        help="Removal percentage")
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # NOTE(review): set_seed is not defined in this file; presumably it is
    # provided by the star import from scr.train -- confirm.
    set_seed(args.seed)

    privleak_auc_key = 'forget_holdout_Min-40%'

    # The evaluation files live in an "eval" directory next to the model
    # file. dirname() also handles a bare file name, where splitting on "/"
    # would return the file name itself as the directory.
    out_dir = os.path.join(os.path.dirname(args.unlearned_model_path), 'eval')
    mia_path = os.path.join(out_dir, f"mia_{args.unlearn_method}.json")
    retrain_path = os.path.join(out_dir, "mia_retraining.json")

    with open(mia_path, "r", encoding="utf-8") as f:
        auc = json.load(f)
    with open(retrain_path, "r", encoding="utf-8") as f:
        AUC_RETRAIN = json.load(f)

    retrain_auc = AUC_RETRAIN[privleak_auc_key]
    if retrain_auc == 0:
        raise ZeroDivisionError(
            f"{privleak_auc_key} is 0 in {retrain_path}; "
            "the relative privacy leakage is undefined"
        )

    # Relative deviation from the retrained model's AUC, in percent.
    privleak = (auc[privleak_auc_key] - retrain_auc) / retrain_auc * 100
    privleak = {'privleak': privleak}

    # Overwrite rather than append: appending a second JSON object to an
    # existing file produces a document that json.load can no longer parse.
    privleak_path = os.path.join(out_dir, f"privleak_{args.unlearn_method}.json")
    with open(privleak_path, "w", encoding="utf-8") as f:
        json.dump(privleak, f)
