import os
import argparse
import random
import math
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

class Unlearning_Evaluator:
    """Score a causal language model on "forget" and "retain" text sets.

    For every text the evaluator reports perplexity and greedy next-token
    accuracy. Texts are processed in consecutive chunks of at most
    ``chunk_size`` tokens so that long documents fit in one forward pass
    each.
    """

    def __init__(self, model, tokenizer, chunk_size=4096):
        """Keep references to the model, tokenizer and chunk length.

        Args:
            model: Causal LM. It is switched to eval mode and later called
                with ``input_ids``, ``attention_mask`` and ``labels``; the
                returned object must expose ``loss`` and ``logits``.
            tokenizer: Callable that turns a string into an encoding with
                ``input_ids`` and ``attention_mask`` tensors.
            chunk_size: Maximum number of tokens scored per forward pass.
        """
        self.model = model.eval()
        self.tokenizer = tokenizer
        # Tokenized inputs are moved to the device the model reports.
        self.device = model.device
        self.chunk_size = chunk_size

    @torch.no_grad()
    def compute_ppl_and_acc_for_text(self, text: str):
        """Return ``(perplexity, accuracy)`` for a single text.

        The text is tokenized without truncation and cut into consecutive
        chunks of at most ``chunk_size`` tokens. Each chunk is scored on its
        own: perplexity is ``exp`` of the loss returned by the model, and
        accuracy is the fraction of unmasked positions whose arg-max logit
        equals the following token.

        NOTE: the result is the *unweighted* mean of the per-chunk values,
        so a short trailing chunk counts as much as a full-length one.

        Args:
            text: Raw text to score.

        Returns:
            Tuple ``(mean_perplexity, mean_accuracy)``; ``(nan, nan)`` when
            no chunk offers at least one next-token target.
        """
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=False
        ).to(self.device)
        ids = encoded.input_ids.squeeze(0)
        mask = encoded.attention_mask.squeeze(0)

        ppl_sum = 0.0
        acc_sum = 0.0
        n_chunks = 0
        for start in range(0, ids.size(0), self.chunk_size):
            stop = start + self.chunk_size
            chunk_ids = ids[start:stop].unsqueeze(0)
            chunk_mask = mask[start:stop].unsqueeze(0)
            # A single token has no "next token" to predict.
            if chunk_ids.size(1) < 2:
                continue

            # Position t predicts token t+1, so targets start at index 1.
            valid = chunk_mask[:, 1:].bool()
            n_valid = int(valid.sum().item())
            if n_valid == 0:
                # Every target is masked out; skipping avoids a division
                # by zero in the accuracy computation below.
                continue

            out = self.model(
                input_ids=chunk_ids,
                attention_mask=chunk_mask,
                labels=chunk_ids,
            )
            ppl_sum += math.exp(out.loss.item())

            preds = out.logits.argmax(-1)
            hits = (preds[:, :-1][valid] == chunk_ids[:, 1:][valid]).sum().item()
            acc_sum += hits / n_valid
            n_chunks += 1

        if n_chunks == 0:
            return float("nan"), float("nan")
        return ppl_sum / n_chunks, acc_sum / n_chunks

    @staticmethod
    def _nan_safe_mean(values, ndigits=4):
        """Return the rounded mean of *values* ignoring NaNs, or NaN if none remain."""
        arr = np.asarray(list(values), dtype=float)
        # np.nanmean warns on an empty or all-NaN input; handle that case
        # explicitly and quietly instead.
        if arr.size == 0 or np.isnan(arr).all():
            return float("nan")
        return round(float(np.nanmean(arr)), ndigits)

    def _score_texts(self, texts):
        """Return parallel lists of per-text perplexities and accuracies."""
        ppls, accs = [], []
        for text in texts:
            ppl, acc = self.compute_ppl_and_acc_for_text(text)
            ppls.append(ppl)
            accs.append(acc)
        return ppls, accs

    def evaluate_forget_and_retain(self, forget_texts, retain_texts, retain_samples=100):
        """Average perplexity and accuracy over the forget and retain sets.

        Args:
            forget_texts: Texts the model is supposed to have forgotten; all
                of them are scored.
            retain_texts: Texts the model should still handle; only the
                first ``retain_samples`` entries are scored.
            retain_samples: Cap on the number of retain texts evaluated.

        Returns:
            Dict with keys ``forget_ppl``, ``forget_acc``, ``retain_ppl``
            and ``retain_acc``. Each value is a float rounded to four
            decimals; per-text NaNs are ignored, and a set with no usable
            text (including an empty one) yields NaN instead of raising.
        """
        forget_ppls, forget_accs = self._score_texts(forget_texts)
        retain_ppls, retain_accs = self._score_texts(retain_texts[:retain_samples])
        return {
            "forget_ppl": self._nan_safe_mean(forget_ppls),
            "forget_acc": self._nan_safe_mean(forget_accs),
            "retain_ppl": self._nan_safe_mean(retain_ppls),
            "retain_acc": self._nan_safe_mean(retain_accs),
        }

def main():
    """Evaluate an unlearned or relearned checkpoint on forget/retain data.

    Parses the command line, derives the checkpoint directory from the
    arguments, loads the model and tokenizer from it, draws forget and
    retain texts from the ``llmunlearn/unlearn_dataset`` configurations and
    prints the resulting metrics dictionary.

    Raises:
        ValueError: If ``--process`` is neither ``unlearn`` nor ``relearn``
            (argparse's ``choices`` normally rejects this earlier).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-6, help="Learning rate, e.g. 1e-6")
    parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")
    parser.add_argument('--process', type=str, choices=['unlearn','relearn'], default="unlearn",
                        help="Choose whether to load checkpoint after unlearning or relearning")
    parser.add_argument('--unlearning_algorithm', type=str, default="GA", help="Unlearning algorithm identifier")
    parser.add_argument('--type', choices=['Text','Math'], default="Text", help="Data type")
    parser.add_argument('--phase', type=str, default="N1_Request100", help="Unlearning phase")
    args = parser.parse_args()

    # Only the last path component of the model name appears in checkpoint
    # directories ("Qwen/Qwen2.5-7B" -> "Qwen2.5-7B").
    base = args.model_name.split('/')[-1]

    if args.process == "unlearn":
        # e.g. Model/Text/all_layer/Qwen2.5-7B/lr1e-06_GA_N1_Request100
        ckpt = f"Model/{args.type}/all_layer/{base}/lr{args.lr}_{args.unlearning_algorithm}_{args.phase}"
    elif args.process == "relearn":
        # e.g. Model/recovery/Text/Qwen2.5-7B/GA/lr1e-06_all_layers_forget_N1_Request100
        ckpt = (
            f"Model/recovery/{args.type}/{base}/"
            f"{args.unlearning_algorithm}/"
            f"lr{args.lr}_all_layers_forget_{args.phase}"
        )
    else:
        raise ValueError("`--process` must be either 'unlearn' or 'relearn'")

    # `attn_implementation` is the supported way to request FlashAttention-2;
    # the older `use_flash_attention_2=True` flag is deprecated.
    model = AutoModelForCausalLM.from_pretrained(
        ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True,
        attn_implementation="flash_attention_2", device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    ds_forget  = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="forget")
    ds_github  = load_dataset("llmunlearn/unlearn_dataset", name="github", split="forget")
    ds_ret_arx = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="retain")
    ds_ret_gb  = load_dataset("llmunlearn/unlearn_dataset", name="github", split="retain")
    ds_ret_gen = load_dataset("llmunlearn/unlearn_dataset", name="general", split="retain")

    forget_list = [d["text"] for d in ds_forget] + [d["text"] for d in ds_github]
    retain_list = [d["text"] for d in ds_ret_arx]  + [d["text"] for d in ds_ret_gen] + [d["text"] for d in ds_ret_gb]

    # Fixed seed so the same subset of texts is evaluated on every run.
    random.seed(42)
    random.shuffle(forget_list)
    random.shuffle(retain_list)

    evaluator = Unlearning_Evaluator(model, tokenizer)
    # Score the first 100 shuffled texts of each set.
    results = evaluator.evaluate_forget_and_retain(
        forget_list[:100], retain_list, retain_samples=100
    )
    print(results)

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
