import os
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import roc_auc_score
import argparse
import random
from datasets import load_dataset
from MIA_algorithm.eval import fig_fpr_tpr

# Command-line options. They select which unlearned / relearned checkpoints are
# evaluated and which forget-set scenario (Text or Math) is used.
parser = argparse.ArgumentParser(
    description="Min-K membership-inference evaluation of models after unlearning and after relearning"
)
parser.add_argument('--lr', type=float, default=3e-5, help="Learning rate, e.g., 3e-5")
parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")
parser.add_argument('--unlearning_algorithm', type=str, default="GA", help="Unlearning algorithm name")
parser.add_argument('--type', choices=['Text', 'Math'], default="Text", help="Data type: Text or Math")
parser.add_argument('--phase', type=str, default="N1_Request100", help="Unlearning phase identifier")
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')

def calculatePerplexity(sentence, model, tokenizer, device, max_length=1048):
    """
    Compute sentence perplexity and per-token log-probabilities.

    The sentence is tokenized WITHOUT padding. Padding to ``max_length`` (with
    no attention mask and ``labels=input_ids``) would append pad tokens that
    are then scored as if they were real text, polluting both the loss and
    the per-token log-probs.

    Args:
        sentence: Text to score.
        model: Causal LM called as ``model(input_ids, labels=input_ids)``; its
            output must expose ``.loss`` and ``.logits``.
        tokenizer: Tokenizer whose ``encode(..., return_tensors="pt")`` returns
            a ``(1, seq_len)`` tensor of token ids.
        device: Device the input ids are moved to (should match the model).
        max_length: Truncation length in tokens. NOTE(review): 1048 may have
            been intended as 1024; kept unchanged for reproducibility.

    Returns:
        ``(perplexity, token_log_probs, loss)``:
        ``token_log_probs[i]`` is the log-probability of token ``i + 1`` given
        tokens ``0..i`` (``seq_len - 1`` entries). ``loss`` is
        ``outputs.loss``; for Hugging Face causal LMs this is the mean
        next-token NLL. ``perplexity`` is ``exp(loss)``.
    """
    input_ids = tokenizer.encode(
        sentence, return_tensors="pt", truncation=True, max_length=max_length
    ).to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss

    # Logits at position i predict token i + 1: drop the last position and pair
    # each remaining position with the next token id. Upcast to float32 so the
    # log-probs are not quantized to bfloat16 precision.
    log_probs = torch.nn.functional.log_softmax(outputs.logits[0, :-1].float(), dim=-1)
    targets = input_ids[0, 1:]
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # A single device->host transfer instead of one .item() call per token.
    return torch.exp(loss).item(), token_log_probs.tolist(), loss.item()

def compute_baseline_score(text, model, tokenizer, device, ratio=0.2):
    """
    Compute a Min-K% Prob style membership score for ``text``.

    The score is the negative mean of the lowest ``ratio`` fraction of the
    text's per-token log-probabilities. Higher values mean the model assigns
    lower probability to the text's least-likely tokens.

    Args:
        text: Text to score.
        model, tokenizer, device: Forwarded to ``calculatePerplexity``.
        ratio: Fraction of tokens (those with the lowest log-probs) to
            average. At least one token is always used.

    Returns:
        The score as a Python float, or ``nan`` if the text yields no scorable
        tokens (fewer than two tokens after tokenization).
    """
    _, all_prob, _ = calculatePerplexity(text, model, tokenizer, device)
    if len(all_prob) == 0:
        return float("nan")
    # Clamp to >= 1: for short texts int(len * ratio) can be 0, which would
    # average an empty slice (NaN plus a RuntimeWarning).
    k_length = max(1, int(len(all_prob) * ratio))
    lowest_log_probs = np.sort(all_prob)[:k_length]
    return -float(np.mean(lowest_log_probs))

# === Script entry point: Min-K membership-inference evaluation ===


def _load_retain_texts(n=100):
    """Return the first ``n`` texts of the ``general`` retain split (label 0)."""
    retain_dataset_general = load_dataset("llmunlearn/unlearn_dataset", name="general", split="retain")
    return [data['text'] for data in retain_dataset_general][:n]


def _load_forget_texts(data_type, n=100, seed=42):
    """Return ``n`` shuffled forget-set texts (label 1) for ``data_type``.

    ``"Text"`` merges the arXiv and GitHub forget splits. ``"Math"`` joins each
    NuminaMath problem, solution and answer into one string, then groups every
    10 such strings into one long text.
    """
    random.seed(seed)
    if data_type == "Text":
        forget_arxiv = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="forget")
        forget_github = load_dataset("llmunlearn/unlearn_dataset", name="github", split="forget")
        forget_list = [d["text"] for d in forget_arxiv] + [d["text"] for d in forget_github]
    elif data_type == "Math":
        train = load_dataset("AI-MO/NuminaMath-1.5")["train"]
        samples = [
            f"{p} {s} {a}"
            for p, s, a in zip(train["problem"], train["solution"], train["answer"])
        ]
        group_size = 10
        forget_list = [
            "  ".join(samples[i: i + group_size])
            for i in range(0, len(samples), group_size)
        ]
    else:
        raise ValueError(f"Unsupported data type: {data_type!r}")
    random.shuffle(forget_list)
    return forget_list[:n]


def _score_checkpoint(checkpoint, texts, labels, tokenizer, device):
    """Load ``checkpoint``, score every text, then release the model.

    Loading one model at a time (instead of both up front) halves peak
    device memory.

    Returns:
        One ``{'label': ..., 'pred': {'min': score}}`` dict per text, in the
        format passed to ``fig_fpr_tpr``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        # Replaces the deprecated `use_flash_attention_2=True` flag.
        attn_implementation="flash_attention_2",
    ).to(device)

    results = []
    for text, label in tqdm(zip(texts, labels), total=len(texts), desc=checkpoint):
        score = compute_baseline_score(text, model, tokenizer, device)
        results.append({'label': label, 'pred': {'min': score}})

    del model
    torch.cuda.empty_cache()  # no-op when CUDA was never initialized
    return results


if __name__ == "__main__":
    args = parser.parse_args()
    lr = args.lr
    model_name = args.model_name
    unlearning_algorithm = args.unlearning_algorithm
    data_type = args.type
    phase = args.phase
    device = args.device

    # === Load data: 100 forget texts (label 1) + 100 retain texts (label 0) ===
    retain_text = _load_retain_texts()
    texts = _load_forget_texts(data_type)
    texts_all = texts + retain_text
    labels = [1] * len(texts) + [0] * len(retain_text)

    base = model_name.split('/')[-1]
    checkpoint_after_un = f"Model/{data_type}/all_layer/{base}/lr{lr}_{unlearning_algorithm}_{phase}"
    checkpoint_after_re = f"Model/recovery/{data_type}/{base}/{unlearning_algorithm}/lr{lr}_all_layers_forget_{phase}"

    # Both checkpoints are scored with the base model's tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    after_un_result = _score_checkpoint(checkpoint_after_un, texts_all, labels, tokenizer, device)
    after_re_result = _score_checkpoint(checkpoint_after_re, texts_all, labels, tokenizer, device)

    # NOTE(review): both calls pass the same output directory "./"; if
    # fig_fpr_tpr writes a fixed filename there, the second call overwrites
    # the first's output — confirm.
    print(f"after_un_model {checkpoint_after_un}:")
    fig_fpr_tpr(after_un_result, "./")

    print(f"after_re_model: {checkpoint_after_re}")
    fig_fpr_tpr(after_re_result, "./")
