from transformers import DefaultDataCollator
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import auc, roc_curve
import torch


def create_mia_dataset(member_features, non_member_features):
    """Merge member and non-member samples into a single labeled MIA dataset.

    Members come first and are labeled 1; non-members follow and are labeled 0.
    ``member_features`` is copied via ``.copy()`` before being extended, so the
    caller's member sequence itself is not modified (assumes list-like inputs).

    Returns:
        (features, labels): the concatenated samples and a parallel list of
        0/1 membership labels.
    """
    membership = [1] * len(member_features) + [0] * len(non_member_features)
    combined = member_features.copy()
    combined += non_member_features
    return combined, membership

def _mink_plus_plus_batch_scores(logits, raw_labels, k_pct):
    """Compute one Min-K%++ score per sequence for a single batch.

    Args:
        logits: next-token logits, presumably (batch, seq_len, vocab) — TODO confirm.
        raw_labels: target token ids, (batch, seq_len); ``-100`` marks ignored positions.
        k_pct: fraction in (0, 1] of the ``seq_len - 1`` scoreable positions to keep.

    Returns:
        Float32 tensor of shape (batch,). Rows with no valid target are NaN.
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    # Work in float32 so softmax/variance are not computed in bf16/fp16 precision.
    logits = logits[:, :-1].float()
    raw_labels = raw_labels[:, 1:]
    num_positions = raw_labels.shape[1]
    if num_positions == 0:
        # A length-1 sequence has nothing to score.
        return torch.full((raw_labels.shape[0],), float('nan'))

    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log_softmax(logits, dim=-1)
    mu = (probs * log_probs).sum(-1)
    # Var[X] = E[(X - mu)^2]. Unlike E[X^2] - E[X]^2, this is a sum of
    # non-negative terms and cannot cancel to a negative value (-> NaN sqrt).
    var = (probs * torch.square(log_probs - mu.unsqueeze(-1))).sum(-1)

    ignore_mask = raw_labels.eq(-100)
    # -100 is not a valid gather index; use 0 there, those positions are masked below.
    safe_labels = raw_labels.masked_fill(ignore_mask, 0)
    # squeeze(-1) only: a bare squeeze() would also drop a size-1 batch/sequence axis.
    target_log_probs = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    token_scores = (target_log_probs - mu) / var.sqrt()

    # k counts shifted positions (seq_len - 1), clamped so topk is always valid
    # and at least one position is kept.
    k = min(max(int(k_pct * num_positions), 1), num_positions)

    # NOTE(review): positions are *selected* by lowest raw target log-probability
    # (Min-K% style) and their standardized scores are then averaged. The
    # Min-K%++ reference code instead keeps the lowest k% of the standardized
    # scores themselves — confirm which variant is intended.
    # Ignored positions get +inf so they are picked only when a row has < k valid tokens.
    selection_keys = target_log_probs.masked_fill(ignore_mask, float('inf'))
    mink_ids = torch.topk(selection_keys, k=k, dim=-1, largest=False).indices

    # Ignored positions contribute 0 to the sum and are excluded from the count.
    token_scores = token_scores.masked_fill(ignore_mask, 0.0)
    selected_sum = token_scores.gather(-1, mink_ids).sum(-1)
    num_selected_valid = (~ignore_mask).gather(-1, mink_ids).sum(-1)
    # A row whose labels are all -100 gives 0 / 0 = NaN.
    return selected_sum / num_selected_valid


def mink_plus_plus_attack(model, target_dataset, batch_size: int = 64, k_pct: float = 0.1):
    """Compute a Min-K%++ membership score for every sample in ``target_dataset``.

    Each target token's log-probability is standardized by the mean and standard
    deviation of the log-probability under the model's own next-token
    distribution: (log p(x_t) - mu_t) / sigma_t. Per sequence, the k positions
    with the lowest target log-probability are kept, and their standardized
    scores are averaged.

    Args:
        model: causal LM whose output supports ``outputs['logits']`` and which
            exposes ``.device`` (presumably a Hugging Face model).
        target_dataset: items collated by ``DefaultDataCollator`` into
            ``input_ids``, ``attention_mask`` and ``labels``. ``-100`` in
            ``labels`` marks ignored positions. The collator stacks without
            padding, so sequences presumably already share one length — confirm.
        batch_size: number of sequences per forward pass.
        k_pct: fraction in (0, 1] of the ``seq_len - 1`` scoreable positions to
            keep. At least one position is always kept.

    Returns:
        1-D float32 tensor with one score per sample, in dataset order (the
        loader does not shuffle). Samples with no valid target token score NaN.
        An empty dataset yields an empty tensor.

    Raises:
        ValueError: if ``k_pct`` is outside (0, 1].
    """
    if not 0.0 < k_pct <= 1.0:
        raise ValueError(f"k_pct must be in (0, 1], got {k_pct}")

    collate_fn = DefaultDataCollator(return_tensors="pt")
    loader = DataLoader(target_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

    model.eval()
    device = model.device
    scores = []
    for batch in tqdm(loader):
        with torch.no_grad():
            # Keyword arguments: positional order differs across architectures
            # (e.g. GPT-2's second positional parameter is past_key_values).
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
        logits = outputs['logits'].cpu()
        scores.append(_mink_plus_plus_batch_scores(logits, batch['labels'], k_pct))

    if not scores:
        return torch.empty(0)
    return torch.cat(scores)

def sweep(y_score, y_true):
    """Sweep every decision threshold and summarize the resulting ROC curve.

    Label 1 is treated as the positive class.

    Returns:
        (fpr, tpr, auc_score, thresholds) — the false/true positive rates at
        each threshold, the area under that curve, and the thresholds used.
    """
    false_pos_rate, true_pos_rate, thresholds = roc_curve(
        y_true=y_true, y_score=y_score, pos_label=1
    )
    area = auc(false_pos_rate, true_pos_rate)
    return false_pos_rate, true_pos_rate, area, thresholds