import os, argparse
from collections import defaultdict
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
import zlib
from accelerate import Accelerator
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset


def convert_huggingface_data_to_list_dic(dataset):
    """Materialize an indexable dataset into a plain Python list.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.

    Returns:
        A new list holding ``dataset[0] ... dataset[len(dataset) - 1]``.
    """
    return [dataset[idx] for idx in range(len(dataset))]

# generate weights
def generate_weights(length, shape="linear", alpha=1):
    """Build one weight per position for a sequence of ``length`` items.

    With ``p`` the position normalized to ``[0, 1]`` (0 for the first item,
    1 for the last):

    * ``"linear"``:      ``1 - alpha * p``
    * ``"exponential"``: ``exp(-alpha * (t - 1))`` for ``t = 1..length``
    * ``"poly"``:        ``(1 - p) ** alpha``

    Args:
        length: Number of weights to generate.
        shape: One of ``"linear"``, ``"exponential"`` or ``"poly"``.
        alpha: Shape parameter, used as shown in the formulas above.

    Returns:
        A 1-D NumPy array with ``length`` entries.

    Raises:
        ValueError: If ``shape`` is not one of the supported names.
    """
    t = np.arange(1, length + 1, dtype=np.float32)
    # Guard the denominator: for a single-item sequence ``length - 1`` is 0 and
    # the unguarded division yields 0/0 = NaN. With the guard, the only
    # position is 0, so the single weight is 1.
    position = (t - 1) / max(length - 1, 1)
    if shape == "linear":
        weights = 1 - alpha * position
    elif shape == "exponential":
        weights = np.exp(-alpha * (t - 1))
    elif shape == "poly":
        weights = (1 - position) ** alpha
    else:
        raise ValueError(f"Unknown shape {shape}")
    return weights


# arguments
# Dataset names accepted by --dataset, grouped by benchmark family.
_wikimia_choices = [
    'WikiMIA_length32',
    'WikiMIA_length64',
    'WikiMIA_length128',
    'WikiMIA_length32_paraphrased',
    'WikiMIA_length64_paraphrased',
    'WikiMIA_length128_paraphrased',
]
_mimir_choices = [
    'Mimir-github',
    "Mimir-arxiv",
    "Mimir-pile_cc",
    "Mimir-hackernews",
    "Mimir-wikipedia_(en)",
    "Mimir-dm_mathematics",
    "Mimir-pubmed_central",
]

parser = argparse.ArgumentParser()
# Target model to evaluate.
parser.add_argument('--model', type=str, default="EleutherAI/pythia-6.9b")
parser.add_argument(
    '--dataset',
    type=str,
    default='WikiMIA_length128',
    choices=_wikimia_choices + _mimir_choices,
)
# Shape and strength of the per-position weights (see generate_weights).
parser.add_argument(
    '--weights',
    type=str,
    default='linear',
    choices=['linear', 'exponential', 'poly'],
)
parser.add_argument('--alpha', type=float, default=1.0, help='for weights generation')
# Precision switches consumed when loading the target model.
parser.add_argument('--half', action='store_true')
parser.add_argument('--int8', action='store_true')
args = parser.parse_args()

# load model
def load_model(name, ref=False):
    """Load a causal language model and its tokenizer.

    Reads the module-level ``args`` to decide whether precision overrides
    (``--int8`` / ``--half``) are passed to ``from_pretrained``.

    Args:
        name: Model identifier passed to ``from_pretrained``.
        ref: When True the model is treated as the reference model and is
            loaded without any precision override, regardless of ``args``.

    Returns:
        A ``(model, tokenizer, accelerator)`` tuple; the model is in eval mode
        and has been passed through ``accelerator.prepare``.

    Raises:
        ImportError: If ``name`` contains ``'mamba'`` but the installed
            ``transformers`` does not provide ``MambaForCausalLM``.
    """
    accelerator = Accelerator()

    # The reference model is small, so precision overrides only apply to the
    # target model. --int8 takes precedence over --half.
    precision_kwargs = {}
    if not ref:
        if args.int8:
            precision_kwargs = dict(load_in_8bit=True, torch_dtype=torch.bfloat16)
        elif args.half:
            precision_kwargs = dict(torch_dtype=torch.bfloat16)

    if 'mamba' in name:
        try:
            from transformers import MambaForCausalLM
        except ImportError as err:
            # Keep the exception type but say what is missing and why.
            raise ImportError(
                f"Loading '{name}' requires a transformers version that "
                "provides MambaForCausalLM"
            ) from err
        model_cls = MambaForCausalLM
    else:
        model_cls = AutoModelForCausalLM

    model = model_cls.from_pretrained(
        name, return_dict=True, device_map='auto', **precision_kwargs
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = accelerator.prepare(model)
    return model, tokenizer, accelerator

# hard-coded ref model
# (substring of the target model name, reference model) pairs.
# They are checked in order and the first matching substring wins.
_REF_MODEL_TABLE = (
    ('pythia', 'EleutherAI/pythia-70m'),
    ('llama', 'huggyllama/llama-7b'),
    ('gpt-neox-20b', 'EleutherAI/gpt-neo-125m'),
    ('mamba', 'state-spaces/mamba-130m-hf'),
    ('opt', 'facebook/opt-350m'),
)
for _family, _ref_name in _REF_MODEL_TABLE:
    if _family in args.model:
        args.ref_model = _ref_name
        break
else:
    # Say which model was rejected and what is supported, instead of raising
    # an empty NotImplementedError.
    raise NotImplementedError(
        f"No reference model configured for '{args.model}'; supported "
        f"families: {', '.join(family for family, _ in _REF_MODEL_TABLE)}"
    )

# Target model honours --half/--int8; the reference model is loaded with ref=True.
model, tokenizer, accelerator = load_model(args.model)
ref_model, ref_tokenizer, ref_accelerate = load_model(args.ref_model, ref=True)

# load dataset
if 'Mimir' in args.dataset:
    # Load the split once and read both fields from it; previously the very
    # same config/split was requested from load_dataset twice.
    mimir_split = load_dataset('iamgroot42/mimir', args.dataset.split('-')[-1], split="ngram_13_0.8")
    member_dataset = mimir_split["member"]
    nonmember_dataset = mimir_split["nonmember"]
    # Interleave non-member (label 0) and member (label 1) samples pairwise.
    data = []
    for nm_data, m_data in zip(nonmember_dataset, member_dataset):
        data.append({"input": nm_data, "label": 0})
        data.append({"input": m_data, "label": 1})
else:
    # Paraphrased variants are loaded from a separate dataset repository.
    if 'paraphrased' not in args.dataset:
        dataset = load_dataset('swj0419/WikiMIA', split=args.dataset)
    else:
        dataset = load_dataset('zjysteven/WikiMIA_paraphrased_perturbed', split=args.dataset)
    data = convert_huggingface_data_to_list_dic(dataset)

# inference - get scores for each input
def inference(text, model, tokenizer ,accelerator):
    """Run ``model`` on ``text`` and return the statistics from ``get_all_prob``.

    Args:
        text: Input string, encoded with ``tokenizer.encode``.
        model: Callable invoked as ``model(input_ids, labels=input_ids)``; the
            first two items of its output are used as loss and logits.
        tokenizer: Tokenizer providing ``encode``.
        accelerator: Object whose ``device`` receives the input ids.

    Returns:
        Whatever ``get_all_prob(input_ids, loss, logits)`` returns.
    """
    token_ids = tokenizer.encode(text)
    # Add a leading batch dimension of size 1 before moving to the device.
    batch = torch.tensor(token_ids).unsqueeze(0)
    input_ids = batch.to(accelerator.device)
    with torch.no_grad():
        # Passing the inputs as labels makes the model return its own loss.
        outputs = model(input_ids, labels=input_ids)
    loss, logits = outputs[:2]
    return get_all_prob(input_ids, loss, logits)

def get_all_prob(input_ids, loss, logits):
    """Derive sequence-level and per-token log-probability statistics.

    Args:
        input_ids: Token ids indexed as ``input_ids[0]`` (first row is used).
        loss: Scalar tensor; its negation is reported as the log-likelihood.
        logits: Tensor indexed as ``[0, position, token]``; position ``i`` is
            scored against the token at ``input_ids[0][i + 1]``.

    Returns:
        ``(prob, ll, ppl, all_prob, logits, input_ids)`` where ``ll = -loss``,
        ``ppl = exp(loss)``, ``prob = exp(-loss)`` and ``all_prob`` is a list
        with the log-probability of every token after the first.
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # Tokens being predicted: everything after the first one. gather needs an
    # int64 index living on the same device as the tensor it indexes.
    targets = input_ids[0][1:].to(log_probs.device).long()
    # A single gather replaces the per-token Python loop, which paid for one
    # ``.item()`` transfer per token.
    picked = log_probs[0, :targets.shape[0]].gather(-1, targets.unsqueeze(-1))
    all_prob = picked.squeeze(-1).tolist()
    ll = -loss.item()  # log-likelihood
    ppl = torch.exp(loss).item()
    prob = torch.exp(-loss).item()
    return prob, ll , ppl, all_prob, logits, input_ids

# method name -> list of per-sample scores, in the order of `data`
scores = defaultdict(list)

# Fraction of lowest-scoring tokens kept by the mink / mink++ scores.
k_ratio = 0.2

for d in tqdm(data, total=len(data), desc='Samples'):
    text = d['input']

    _, ll, _, all_prob, logits, input_ids = inference(text, model, tokenizer, accelerator)
    _, ll_ref, _, all_prob_ref, _, _ = inference(text, ref_model, ref_tokenizer, ref_accelerate)
    ll_lowercase = inference(text.lower(), model, tokenizer, accelerator)[1]

    all_prob = np.array(all_prob)
    all_prob_ref = np.array(all_prob_ref)

    # Per-position weights. The reference tokenizer may split the text into a
    # different number of tokens than the target tokenizer, so it gets weights
    # of its own length; reusing the target weights would fail on a shape
    # mismatch in that case.
    weights = generate_weights(len(all_prob), shape=args.weights, alpha=args.alpha)
    weights_ref = generate_weights(len(all_prob_ref), shape=args.weights, alpha=args.alpha)

    target_score = np.mean(all_prob * weights)
    ref_score = np.mean(all_prob_ref * weights_ref)

    scores['zlib'].append(ll / len(zlib.compress(bytes(text, 'utf-8'))))
    scores['lowercase'].append(ll_lowercase / ll)
    scores['loss'].append(ll)
    scores['loss-PDR'].append(target_score)

    scores['ref'].append(ll - ll_ref)
    scores['ref-PDR'].append(target_score - ref_score)

    # mink and mink++
    # Upcast to float32: with --half/--int8 the logits may be bfloat16, which
    # cannot be converted to a NumPy array for the sorting below.
    next_token_logits = logits[0, :-1].float()
    target_ids = input_ids[0][1:].unsqueeze(-1).to(next_token_logits.device)
    probs = F.softmax(next_token_logits, dim=-1)
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=target_ids).squeeze(-1)
    # Mean and variance of the log-probability under the predicted distribution.
    mu = (probs * log_probs).sum(-1)
    sigma = (probs * torch.square(log_probs)).sum(-1) - torch.square(mu)

    # Keep at least one token so that short texts do not produce the mean of
    # an empty selection (NaN).
    k_length = max(1, int(len(token_log_probs) * k_ratio))

    # mink: average over the k lowest token log-probabilities
    token_log_probs_np = token_log_probs.cpu().numpy()
    topk_indices = np.argsort(token_log_probs_np)[:k_length]
    scores['mink'].append(np.mean(token_log_probs_np[topk_indices]).item())
    scores['mink-PDR'].append(np.mean(token_log_probs_np[topk_indices] * weights[topk_indices]).item())

    # mink++: same selection on the standardized log-probabilities
    mink_plus_np = ((token_log_probs - mu) / sigma.sqrt()).cpu().numpy()
    topk_indices = np.argsort(mink_plus_np)[:k_length]
    scores['mink++'].append(np.mean(mink_plus_np[topk_indices]).item())
    scores['mink++-PDR'].append(np.mean(mink_plus_np[topk_indices] * weights[topk_indices]).item())




# compute metrics
# tpr and fpr thresholds are hard-coded
def get_metrics(scores, labels):
    """Summarize a ROC curve as AUROC, FPR@95%TPR and TPR@5%FPR.

    Args:
        scores: Per-sample scores passed to ``roc_curve``.
        labels: Per-sample binary labels passed to ``roc_curve``.

    Returns:
        ``(auroc, fpr95, tpr05)``: the area under the ROC curve, the FPR at the
        first ROC point whose TPR is at least 0.95, and the TPR at the last
        ROC point whose FPR is at most 0.05.
    """
    fpr_list, tpr_list, _ = roc_curve(labels, scores)
    auroc = auc(fpr_list, tpr_list)
    # Positions of the ROC points that satisfy each hard-coded threshold.
    tpr_ok = np.flatnonzero(tpr_list >= 0.95)
    fpr_ok = np.flatnonzero(fpr_list <= 0.05)
    fpr95 = fpr_list[tpr_ok[0]]
    tpr05 = tpr_list[fpr_ok[-1]]
    return auroc, fpr95, tpr05

labels = [d['label'] for d in data] # 1: training, 0: non-training
results = defaultdict(list)
# A distinct loop variable keeps `scores` bound to the full dict instead of
# rebinding it to each method's list while iterating.
for method, method_scores in scores.items():
    try:
        auroc, fpr95, tpr05 = get_metrics(method_scores, labels)
    except Exception as err:
        # Best effort: a method whose metrics cannot be computed is reported
        # and skipped. Catching Exception rather than using a bare except lets
        # KeyboardInterrupt/SystemExit propagate.
        print(f"{method}: could not compute metrics ({err!r})")
        continue

    results['method'].append(method)
    results['auroc'].append(f"{auroc:.5}")
    results['fpr95'].append(f"{fpr95:.5}")
    results['tpr05'].append(f"{tpr05:.5}")

df = pd.DataFrame(results)
print(df)

# One directory per dataset, one CSV per model.
save_root = f"results/{args.dataset}"
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(save_root, exist_ok=True)

# Use the last path component of the model identifier as the file name.
model_id = args.model.split('/')[-1]
df.to_csv(os.path.join(save_root, f"{model_id}.csv"), index=False)