import json
import os
from peft import LoraConfig, PeftModel
import torch
from models.transformer_model import TransformerModel
from utils.lora import apply_lora
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
def compute_ans(text0, text1, choices="ABCD"):
    """Score a predicted multiple-choice answer against a reference answer.

    An option letter counts as "selected" by a text when the letter occurs
    anywhere in that text (case-sensitive substring test).

    Args:
        text0: Predicted answer text (e.g. the model's generated continuation).
        text1: Reference (gold) answer text.
        choices: Option letters to look for. The default ``"ABCD"`` reproduces
            the original four-option behaviour.

    Returns:
        ``(correct, allans)``: ``correct`` is how many options are selected by
        both texts; ``allans`` is how many options ``text1`` selects.
    """
    # NOTE(review): a bare substring test is lenient on free-form generations:
    # any uppercase option letter anywhere in the text (e.g. the article "A")
    # counts as choosing that option.
    pred_flags = [letter in text0 for letter in choices]
    gold_flags = [letter in text1 for letter in choices]
    # Options the reference selects that the prediction also selects.
    correct = sum(p and g for p, g in zip(pred_flags, gold_flags))
    allans = sum(gold_flags)  # number of options the reference selects
    return correct, allans

# -------- Evaluation ----------
@torch.no_grad()
def evaluate_cyber_qa_accuracy(model, dataloader, tokenizer, device, qa_save_path="temp.csv"):
    """Measure multiple-choice accuracy of ``model`` via free-form generation.

    For each batch, the model generates up to 400 new tokens. Only the newly
    generated tokens are decoded and scored against the gold answers with
    ``compute_ans``. Accuracy is the number of gold options the predictions
    also mention, divided by the total number of gold options.

    The full decoded outputs of the first batch are written to
    ``qa_save_path``, one per line, with newlines flattened to spaces.

    Args:
        model: Model exposing ``generate``. It may be a plain model (e.g. a
            ``PeftModel``) or a wrapper exposing the real model as ``.module``
            (e.g. ``torch.nn.DataParallel``).
        dataloader: Yields dicts with ``"input_ids"`` and ``"attention_mask"``
            tensors. ``dataloader.dataset.answers`` must be a sliceable
            sequence aligned with iteration order. This assumes a sequential
            (non-shuffled) sampler — TODO confirm at call sites.
        tokenizer: Used for decoding; its ``pad_token_id`` is passed to
            ``generate``.
        device: Device that the model and input tensors are moved to.
        qa_save_path: File receiving the first batch's decoded outputs.

    Returns:
        The accuracy as a float, or ``0`` if no gold options were seen.
    """
    model.eval()
    model = model.to(device)
    # Unwrap DataParallel/DDP-style containers; plain models have no
    # ``.module`` and expose ``generate`` directly.
    generator = model.module if hasattr(model, "module") else model
    total = correct = 0
    offset = 0  # position of the current batch's first sample in dataset.answers
    for i, batch in enumerate(dataloader):
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        outs = generator.generate(
            ids,
            attention_mask=mask,
            max_new_tokens=400,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Gold answers for this batch. A running offset also works when
        # dataloader.batch_size is None (custom batch_sampler).
        batch_len = ids.shape[0]
        gt_answers = dataloader.dataset.answers[offset:offset + batch_len]
        offset += batch_len
        if i == 0:
            # Dump the first batch's full outputs (prompt + generation).
            texts = tokenizer.batch_decode(outs, skip_special_tokens=True)
            with open(qa_save_path, "w", encoding="utf-8") as f:
                for text in texts:
                    f.write(text.replace("\n", " ") + "\n")
        # ``outs`` starts with the prompt tokens; keep only the new tokens so
        # option letters in the prompt are not counted as predictions.
        new_token_ids = outs[:, ids.shape[1]:]
        new_texts = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
        for pred, gold in zip(new_texts, gt_answers):
            hit, n_gold = compute_ans(pred, gold)
            correct += hit
            total += n_gold
    # Guard the division here too: an empty loader or empty gold answers
    # would otherwise raise ZeroDivisionError in the report line.
    accuracy = correct / total if total else 0
    print(f"Cyber MCQ Accuracy: {correct}/{total} = {accuracy:.4f}")
    return accuracy

def test_qa(args, lora_save_dir, eval_loader, eval_outfile, save_dir="temp/full_ckpt", access_token=None,
            results_path="res/ga_gd/eval_results.json"):
    """Evaluate a LoRA adapter, merge it into its base model, and run lm-eval.

    Steps:
      1. Load ``args.model_name`` in float16, attach the LoRA adapter from
         ``lora_save_dir``, and print Cyber MCQ accuracy on ``eval_loader``.
      2. Merge the adapter into the base weights, then save the merged model
         and tokenizer to ``save_dir``.
      3. Run the ``lm-eval`` CLI (wmdp, mmlu, wikitext) on ``save_dir`` and
         write ``[<lm-eval JSON or None>]`` to ``results_path``.

    Args:
        args: Namespace providing ``model_name`` and ``device``.
        lora_save_dir: Directory holding the trained LoRA adapter.
        eval_loader: DataLoader passed to ``evaluate_cyber_qa_accuracy``.
        eval_outfile: Value passed to lm-eval's ``--output_path``.
        save_dir: Directory receiving the merged model and tokenizer.
        access_token: Optional Hugging Face token for gated models.
        results_path: JSON file that receives the collected lm-eval results.
    """
    import subprocess

    os.makedirs(save_dir, exist_ok=True)
    # The tokenizer is needed by the accuracy evaluation below, so create it
    # first. NOTE(review): if TransformerModel loads model weights when it is
    # constructed, AutoTokenizer.from_pretrained would be cheaper — confirm.
    tokenizer = TransformerModel(model_name=args.model_name).get_tokenizer()

    # ==== (1) Attach the LoRA adapter and measure MCQ accuracy ====
    base = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, token=access_token)
    model2 = PeftModel.from_pretrained(base, lora_save_dir)
    acc = evaluate_cyber_qa_accuracy(model2, eval_loader, tokenizer, args.device)
    print(f"Cyber MCQ acc={acc:.4f}")

    # ==== (2) Merge LoRA into the base model and save ====
    merged = model2.merge_and_unload()  # fold the adapter weights into the base weights
    merged.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print(f"✅  已保存合并模型到 {save_dir}")

    # lm-eval loads the saved checkpoint in a separate process; drop this
    # process's model references first so cached GPU memory can be released.
    del base, model2, merged
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ==== (3) External benchmark on the checkpoint that was just saved ====
    eval_command = [
        "lm-eval", "--model", "hf",
        "--model_args", f"pretrained={save_dir},parallelize=True,dtype=float16",
        "--tasks", "wmdp,mmlu,wikitext",
        "--batch_size=2",
        "--output_path", str(eval_outfile),
    ]
    print("Running LLM eval:", " ".join(eval_command))
    # Argument list without a shell: paths are passed verbatim (no quoting or
    # injection issues). Failures are reported but not raised, matching the
    # previous best-effort behaviour.
    try:
        proc = subprocess.run(eval_command, check=False)
    except FileNotFoundError:
        print("lm-eval executable not found; skipping LLM eval")
    else:
        if proc.returncode != 0:
            print(f"lm-eval exited with code {proc.returncode}")

    # Load the lm-eval output. Depending on the lm-eval version, --output_path
    # may be treated as a directory, so only read it when it is a regular
    # file — TODO confirm against the installed lm-eval version.
    eval_results = []
    if os.path.isfile(eval_outfile):
        with open(eval_outfile, "r", encoding="utf-8") as f:
            eval_results.append(json.load(f))
    else:
        eval_results.append(None)
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(eval_results, f, indent=2)