import os
import json
import subprocess
from peft import PeftModel
import torch
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
# evaluate_cyber_qa_accuracy (imported below from utils.test_qa) is called as (model, loader, tokenizer, device, qa_save_path=...) and returns the accuracy as a float.
from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import LanguageModelingDataset, preprocess_test
from utils.test_qa import evaluate_cyber_qa_accuracy
from datasets import load_from_disk
# Hugging Face access token for gated base models (e.g. Mistral-7B).
# Read from the HF_TOKEN environment variable rather than hard-coding a secret
# in source. When unset (or empty) this is None, so the Hugging Face libraries
# fall back to the locally stored login instead of sending a bogus placeholder.
access_token = os.environ.get("HF_TOKEN") or None

def evaluate_cyber_accuracy_from_checkpoint(
    save_path: str,
    model_name: str,
    cyber_eval_loader: torch.utils.data.DataLoader,
    tokenizer,
    device: torch.device,
    save_dir: str
) -> float:
    """Merge a PEFT adapter into its base model, save it, and score it on the QA loader.

    Args:
        save_path: Path of the PEFT adapter checkpoint loaded on top of ``model_name``.
        model_name: Hugging Face id (or local path) of the fp16 base causal LM.
        cyber_eval_loader: DataLoader handed unchanged to ``evaluate_cyber_qa_accuracy``.
        tokenizer: Tokenizer saved next to the merged model and passed to the evaluator.
        device: Device the merged model is moved to before evaluation; it is also
            forwarded to the evaluator.
        save_dir: Output directory for the merged model, its tokenizer, and the file
            ``qa_res.txt`` passed to the evaluator as ``qa_save_path``.

    Returns:
        The accuracy returned by ``evaluate_cyber_qa_accuracy``.
    """
    base = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, token=access_token)
    model = PeftModel.from_pretrained(base, save_path).merge_and_unload()
    print("✅  已加载合并模型")

    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print(f"✅  已保存合并模型到 {save_dir}")

    # from_pretrained above loads the weights without any device placement, so
    # move them to the requested device explicitly (``device`` used to be ignored).
    model = model.to(device)
    # DataParallel expects the wrapped module on the first visible GPU, so when
    # CUDA is used ``device`` should be that GPU (e.g. "cuda:0").
    model = torch.nn.DataParallel(model)
    model.eval()

    txt_path = os.path.join(save_dir, "qa_res.txt")
    with torch.no_grad():
        accuracy = evaluate_cyber_qa_accuracy(
            model, cyber_eval_loader, tokenizer, device, qa_save_path=txt_path
        )
    print(f"QA Accuracy: {accuracy:.4f}")
    return accuracy

def _build_lm_eval_command(
    save_path: str,
    llmeval_tasks: str,
    output_json: str,
    batch_size,
    parallelize: bool,
) -> list:
    """Return the lm-eval CLI invocation as an argv list (run without a shell)."""
    # Extra model arguments can be appended to this comma-separated string.
    model_args_str = f"pretrained={save_path},parallelize={bool(parallelize)}"
    return [
        "lm-eval",
        "--model", "hf",
        "--model_args", model_args_str,
        "--tasks", llmeval_tasks,
        "--batch_size", str(batch_size),
        "--output_path", output_json,
    ]


def _find_lm_eval_results(output_json: str) -> "str | None":
    """Return the path of the results JSON lm-eval wrote for ``output_json``, or None.

    lm-eval does not always write to ``--output_path`` verbatim. Depending on
    its version the aggregated results end up in exactly ``output_json``, in
    ``<stem>_<timestamp>.json`` next to it, or in
    ``<output_json>/<model>/results_<timestamp>.json``. The exact file wins;
    otherwise the most recently modified candidate is returned.
    """
    import glob

    if os.path.isfile(output_json):
        return output_json
    stem, _ = os.path.splitext(output_json)
    candidates = glob.glob(glob.escape(stem) + "_*.json")
    candidates += glob.glob(
        os.path.join(glob.escape(output_json), "**", "results_*.json"),
        recursive=True,
    )
    candidates = [path for path in candidates if os.path.isfile(path)]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def evaluate_with_llmeval(
    save_path: str,
    llmeval_tasks: str,
    output_json: str,
    batch_size: int = 32,
    parallelize: bool = True,
) -> dict:
    """Run the external ``lm-eval`` CLI on a saved HF model and return its JSON results.

    Args:
        save_path: Directory of a Hugging Face model (e.g. one written by
            ``save_pretrained``), passed as ``pretrained=`` in ``--model_args``.
        llmeval_tasks: Comma-separated lm-eval task list, e.g. "wmdp-cyber,mmlu,wikitext".
        output_json: Value for ``--output_path``, e.g. "res/ga/llmeval_epoch_1.json".
        batch_size: lm-eval ``--batch_size``.
        parallelize: Forwarded as ``parallelize=`` inside ``--model_args``.

    Returns:
        The parsed results dict, or ``{}`` if lm-eval could not be started,
        exited with an error, or no results file could be found.
    """
    import shlex

    eval_command = _build_lm_eval_command(
        save_path, llmeval_tasks, output_json, batch_size, parallelize
    )
    print(f"Running LLM eval command:\n  {shlex.join(eval_command)}")

    # An argv list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and prevents them from being interpreted.
    try:
        completed = subprocess.run(
            eval_command,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
    except FileNotFoundError:
        # Without a shell a missing executable raises instead of exiting with 127.
        print("Error running lm-eval:")
        print("lm-eval executable not found on PATH")
        return {}
    except subprocess.CalledProcessError as e:
        print("Error running lm-eval:")
        print(e.stderr)
        return {}
    print(completed.stdout)

    results_path = _find_lm_eval_results(output_json)
    if results_path is None:
        print(f"Warning: {output_json} not found after lm-eval.")
        return {}
    with open(results_path, "r", encoding="utf-8") as f:
        return json.load(f)

if __name__ == "__main__":
    import gc

    from torch.utils.data import DataLoader

    # ---- configuration ----
    # Restrict this process (and the lm-eval subprocess, which inherits the
    # environment) to physical GPU 3. This must be set before the first CUDA
    # call; afterwards that GPU is the only visible one, at index 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    model_name = "mistralai/Mistral-7B-v0.1"  # or another base model name
    name = model_name.split("/")[-1]
    checkpoint_path = f"/data/wwh/llmUN2/relearn/npo_1e4/{name}/hf_ckpt_4"
    # checkpoint_path = f"/data/wwh/llmUN/relearn/{name}/maml_ga_gd_5"
    # checkpoint_path = f"/data/wwh/llmUN2/base/Mistral-7B-v0.1/0"
    save_dir = "./tmp/model/"
    # save_dir = f"/data/wwh/llmUN/temp_learn/npo_5e4/{name}/hf_ckpt_21"
    # save_dir = f"/data/wwh/llmUN/relearning/{name}/maml_npo_6"
    output_json_path = os.path.join(save_dir, "llmeval_results.json")
    un_task = "wmdp-cyber"  # dataset used by the optional QA evaluation
    llm_tasks = "wmdp,mmlu"  # comma-separated lm-eval task list
    test_qa = False  # also run the multiple-choice QA evaluation on un_task
    # Only one GPU is visible (see CUDA_VISIBLE_DEVICES above), so it is
    # "cuda:0"; "cuda:3" would be an invalid device ordinal.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = TransformerModel(model_name=model_name).get_tokenizer()
    print(checkpoint_path)

    merged_model_saved = False
    if test_qa:
        if un_task not in ("wmdp-cyber", "wmdp-bio"):
            raise ValueError(f"Unsupported task: {un_task}")
        print(f"开始评估任务：{un_task}")
        data_wmdp = load_from_disk(f"data/{un_task}")["test"]
        ids, masks, answers = [], [], []
        for i, ex in enumerate(data_wmdp):
            proc = preprocess_test(ex, tokenizer, i)
            ids.append(proc["input_ids"])
            masks.append(proc["attention_mask"])
            answers.append(proc["answer"])
        eval_ds = LanguageModelingDataset(ids, masks)
        eval_ds.answers = answers
        eval_loader = DataLoader(eval_ds, batch_size=32)
        accuracy = evaluate_cyber_accuracy_from_checkpoint(
            save_path=checkpoint_path,
            model_name=model_name,
            cyber_eval_loader=eval_loader,
            tokenizer=tokenizer,
            save_dir=save_dir,
            device=device,
        )
        print(f"Cyber MCQ Accuracy: {accuracy:.4f}")
        # That call already merged the adapter and saved the model and the
        # tokenizer to save_dir, so repeating the merge below is unnecessary.
        merged_model_saved = True

    if not merged_model_saved:
        base = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, token=access_token)
        model = PeftModel.from_pretrained(base, checkpoint_path).to(device).merge_and_unload()
        print("✅  已加载合并模型")
        os.makedirs(save_dir, exist_ok=True)
        model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)
        print(f"✅  已保存合并模型到 {save_dir}")
        del base, model

    # lm-eval reloads the merged model from save_dir in a separate process;
    # release this process's copy and cached GPU memory so the two copies do
    # not compete for the same GPU.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Multi-task evaluation with the external lm-eval CLI.
    llm_results = evaluate_with_llmeval(
        save_path=save_dir,
        llmeval_tasks=llm_tasks,
        output_json=output_json_path,
        batch_size=4,
        parallelize=True,
    )
    if llm_results:
        print("LLM-Eval results:")
        print(json.dumps(llm_results, indent=2, ensure_ascii=False))
    else:
        print("LLM-Eval did not produce any result.")
