from argparse import ArgumentParser, Namespace
import sys
import os
import json
# Add the submodule path to the system path
sys.path.append(os.path.join(os.getcwd(), 'tofu'))

from typing import List, Dict
import torch
from tqdm import tqdm
import zlib
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
from peft import LoraConfig, get_peft_model

from sklearn.metrics import auc as get_auc, roc_curve as get_roc_curve
import datasets
from train import *


def compute_ppl(text, model, tokenizer, device='cuda'):
    """Run ``model`` on one example and return loss-based statistics.

    Args:
        text: indexable object with the input ids at ``text[0]`` and the
            labels at ``text[1]``; both are moved to ``device``. Only batch
            row 0 is used for the per-token log-probabilities (assumes a
            batch of size 1 -- TODO confirm against the data collator).
        model: causal LM called as ``model(input_ids, labels=labels)``; its
            output must yield ``(loss, logits, ...)`` positionally.
        tokenizer: unused; kept for interface compatibility.
        device: device the input tensors are moved to.

    Returns:
        ``(ppl, all_prob, loss)``: ``exp(loss)`` as a float, the list of
        log-probabilities the model assigns to each actual next token
        (``len(input_ids[0]) - 1`` Python floats), and the loss as a float.
    """
    model.eval()
    input_ids = text[0].to(device)
    labels = text[1].to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
    loss, logits = outputs[:2]

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # Position i predicts token i + 1, so pair log_probs[:-1] with ids[1:].
    # One gather + tolist() replaces a per-token ``.item()`` loop, which
    # forced a device->host synchronisation for every token.
    # NOTE(review): every position is scored, including positions whose
    # label the loss ignores (e.g. prompt or padding tokens, if the collator
    # masks them) -- confirm this is intended for the Min-k% score.
    next_ids = input_ids[0, 1:].long().unsqueeze(-1)
    all_prob = log_probs[0, :-1].gather(-1, next_ids).squeeze(-1).tolist()

    ppl = torch.exp(loss).item()
    return ppl, all_prob, loss.item()


def inference(text, model, tokenizer, ratios=(0.4,)) -> Dict:
    """Score one example with loss-based membership-inference signals.

    Args:
        text: the example forwarded to ``compute_ppl`` (input ids at
            ``text[0]``, labels at ``text[1]``).
        model: causal LM; ``model.device`` selects where tensors are moved.
        tokenizer: passed through to ``compute_ppl``.
        ratios: fractions k for the Min-k% Prob scores. The default
            ``(0.4,)`` reproduces the previously hard-coded single ratio.

    Returns:
        Dict mapping score name to float:

        * ``"PPL"`` -- the model loss on the example (the third value
          returned by ``compute_ppl``), *not* ``exp(loss)`` despite the key
          name; the key is kept for compatibility with existing logs.
        * ``"Min-{k}%"`` -- the negated mean of the k-fraction lowest
          per-token log-probabilities; NaN if the example has no predicted
          tokens (single-token input).

        For every score, a larger value means the model fits the text worse.
    """
    _, all_prob, loss = compute_ppl(text, model, tokenizer, device=model.device)

    pred = {"PPL": float(loss)}

    # Min-k% Prob: average log-probability of the least likely tokens.
    sorted_prob = np.sort(all_prob)
    for ratio in ratios:
        # Keep at least one token: int(len * ratio) is 0 for very short
        # sequences, and the mean of an empty slice is NaN, which later makes
        # the ROC computation fail.
        k_length = max(1, int(len(sorted_prob) * ratio))
        lowest = sorted_prob[:k_length]
        score = float(-lowest.mean()) if lowest.size else float("nan")
        # round() avoids float truncation (e.g. int(0.29 * 100) == 28).
        pred[f"Min-{round(ratio * 100)}%"] = score

    return pred

def eval_data(data, model, tokenizer, data_collator=None):
    """Score every example in ``data`` and collect the results column-wise.

    When ``data_collator`` is given, each example is wrapped as a
    one-element list and collated; otherwise it is used unchanged. The
    (collated) example is then scored with ``inference``.

    Returns:
        Dict mapping ``'text'`` (the example as passed to ``inference``) and
        each score name produced by ``inference`` to a list with one entry
        per example, in iteration order. The set of columns is fixed by the
        first example. Returns ``{}`` when ``data`` yields nothing.
    """
    columns = {}
    for example in tqdm(data):
        batch = example if data_collator is None else data_collator([example])
        row = {'text': batch}
        row.update(inference(batch, model, tokenizer))
        if not columns:
            columns = {name: [] for name in row}
        for name in row:
            columns[name].append(row[name])
    return columns


def sweep(ppl, y):
    """Compute ROC curve, AUC and best balanced accuracy for a score.

    The score is negated before the ROC computation, so *lower* ``ppl``
    values are treated as evidence for the positive class (``y == 1``).

    Args:
        ppl: per-example scores supporting unary negation (``eval_mia``
            passes a NumPy array).
        y: binary labels aligned with ``ppl``.

    Returns:
        ``(fpr, tpr, auc, acc)`` where ``acc`` is the maximum over ROC
        thresholds of ``1 - (fpr + (1 - tpr)) / 2``.
    """
    member_score = -ppl
    fpr, tpr, _thresholds = get_roc_curve(y, member_score)
    false_negative_rate = 1 - tpr
    balanced = 1 - (fpr + false_negative_rate) / 2
    area = get_auc(fpr, tpr)
    return fpr, tpr, area, np.max(balanced)


def eval_mia(
    forget_data,
    retain_data,
    holdout_data,
    model, tokenizer,
    data_collator=None,
    seed=None,
):
    """Run a loss-based membership-inference evaluation (holdout vs. forget).

    All three splits are scored with ``eval_data``. For each score type,
    the holdout split (non-members, label 0) and the forget split (members,
    label 1) are balanced by randomly subsampling the larger one without
    replacement. The ROC AUC of the score as a membership predictor is then
    computed with ``sweep``. The retain split is scored and returned in the
    log but does not enter any AUC.

    Args:
        forget_data, retain_data, holdout_data: iterables of examples
            accepted by ``eval_data``.
        model, tokenizer: forwarded to ``eval_data``.
        data_collator: optional collator forwarded to ``eval_data``.
        seed: optional seed for the balancing subsample. ``None`` (default)
            draws from NumPy's global RNG as before; pass an int to make the
            reported AUCs reproducible.

    Returns:
        ``(auc, log)``: ``auc`` maps ``"holdout_forget_<score>"`` to a float
        AUC, and ``log`` maps each split name to the per-example columns
        produced by ``eval_data``.

    Raises:
        ValueError: if the forget or holdout split yields no examples.
    """
    log = {}
    print("Evaluating on the forget set...")
    log['forget'] = eval_data(forget_data, model, tokenizer, data_collator)
    print("Evaluating on the retain set...")
    log['retain'] = eval_data(retain_data, model, tokenizer, data_collator)
    print("Evaluating on the holdout set...")
    log['holdout'] = eval_data(holdout_data, model, tokenizer, data_collator)

    # eval_data returns {} for an empty split; fail with a clear message
    # instead of an opaque list.remove / KeyError further down.
    for split in ('forget', 'holdout'):
        if not log[split]:
            raise ValueError(f"eval_mia: the {split} split produced no examples")

    # Global RNG by default (previous behaviour), seeded Generator otherwise;
    # both expose choice(a, size, replace=...).
    rng = np.random if seed is None else np.random.default_rng(seed)

    ppl_types = [key for key in log['forget'] if key != 'text']

    auc = {}
    for split0 in ['holdout']:  # non-member split(s), label 0
        for split1 in ['forget']:  # member split(s), label 1
            log0, log1 = log[split0], log[split1]
            for ppl_type in ppl_types:
                ppl_nonmember = log0[ppl_type]
                ppl_member = log1[ppl_type]

                # Balance the classes by subsampling only the larger split;
                # equal-sized splits are used as they are.
                min_size = min(len(ppl_nonmember), len(ppl_member))
                if len(ppl_nonmember) > min_size:
                    ppl_nonmember = rng.choice(ppl_nonmember, min_size, replace=False).tolist()
                elif len(ppl_member) > min_size:
                    ppl_member = rng.choice(ppl_member, min_size, replace=False).tolist()

                ppl = np.array(ppl_nonmember + ppl_member)
                y = np.array([0] * len(ppl_nonmember) + [1] * len(ppl_member))

                _, _, auc_score, _ = sweep(ppl, y)
                auc[f"{split0}_{split1}_{ppl_type}"] = float(auc_score)
    return auc, log
