# import os
# import json
# import math
# import argparse
# import time

# import torch
# import torch.distributed as dist
# from torch.utils.data import DataLoader, DistributedSampler
# from torch.nn import CrossEntropyLoss

# import deepspeed
# from transformers import AutoModelForCausalLM, AutoTokenizer

# from data_utils import get_sft_dataset, collate_sft  # 复用你训练用的接口

# torch.set_num_threads(4)


# # ============================= 小工具：分布式 & 时间格式 =============================

# def setup_distributed():
#     if not dist.is_initialized():
#         deepspeed.init_distributed()


# def set_seed(seed: int):
#     import random
#     import numpy as np
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)


# def format_seconds(seconds: float) -> str:
#     """把秒数格式化成 Xd XXh XXm XXs"""
#     seconds = int(seconds)
#     days, rem = divmod(seconds, 86400)
#     hours, rem = divmod(rem, 3600)
#     minutes, seconds = divmod(rem, 60)
#     return f"{days}d {hours:02d}h {minutes:02d}m {seconds:02d}s"


# # ============================= 参数 =============================

# def parse_args():
#     parser = argparse.ArgumentParser(description="Evaluate pruned HF LLM with DeepSpeed (TALE-style)")

#     # 模型与路径
#     parser.add_argument("--model_name_or_path", type=str, required=True,
#                         help="要加载的 HF 格式模型路径（可以是 finetune 后的 checkpoint 或 final）")
#     parser.add_argument("--output_dir", type=str, required=True,
#                         help="评估日志与结果保存目录")
#     parser.add_argument("--deepspeed_config", type=str, required=True,
#                         help="DeepSpeed json 配置文件路径")

#     # 数据相关
#     parser.add_argument("--sft_dataset", type=str, default="mmlu",
#                         help="评估使用的数据集类型，沿用训练时的 data_utils.get_sft_dataset")
#     parser.add_argument("--max_length", type=int, default=512,
#                         help="评估时的最大序列长度（需要与构造 SFT 数据集时一致）")
#     parser.add_argument("--num_eval_samples", type=int, default=None,
#                         help="可选：限制使用多少条样本进行评估（不填则用完整数据集）")
#     parser.add_argument("--seed", type=int, default=42,
#                         help="构造数据集时的随机种子（与训练脚本保持一致即可）")
#     parser.add_argument("--num_workers", type=int, default=2,
#                         help="DataLoader 的 num_workers")
#     parser.add_argument("--eval_split", type=str, default="validation",
#                         choices=["train", "validation", "valid", "test"],
#                         help="Which split of parquet to use for evaluation")

#     # batch & 分布式
#     parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
#                         help="每张卡上的 eval batch size")
#     parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
#                         help="评估时一般用 1 即可，这里仍保留参数，方便后面修改")
#     # 日志
#     parser.add_argument("--logging_steps", type=int, default=50,
#                         help="每多少个 step 打一次日志（仅 rank 0）")

#     return parser.parse_args()


# # ============================= 模型 & 数据 =============================

# def load_model_and_tokenizer(model_path: str):
#     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token
#         tokenizer.pad_token_id = tokenizer.eos_token_id
#     tokenizer.padding_side = "right"

#     # 这里默认 bf16，你可以按需改成 float16 或 float32
#     model = AutoModelForCausalLM.from_pretrained(
#         model_path,
#         torch_dtype=torch.bfloat16,
#         device_map=None  # 交给 DeepSpeed 放到 GPU
#     )
#     return model, tokenizer


# def build_eval_dataloader(args, tokenizer):
#     """
#     现在用 split 参数控制读取 train / validation / test 的 parquet
#     """
#     dataset = get_sft_dataset(
#         name=args.sft_dataset,
#         tokenizer=tokenizer,
#         max_length=args.max_length,
#         seed=args.seed,
#         num_samples=args.num_eval_samples,
#         split=args.eval_split,
#     )

#     world_size = dist.get_world_size()
#     rank = dist.get_rank()

#     sampler = DistributedSampler(
#         dataset,
#         num_replicas=world_size,
#         rank=rank,
#         shuffle=False,
#         drop_last=False,
#     )

#     dataloader = DataLoader(
#         dataset,
#         batch_size=args.per_device_eval_batch_size,
#         sampler=sampler,
#         collate_fn=collate_sft,
#         num_workers=args.num_workers,
#         pin_memory=True,
#     )

#     return dataloader, len(dataset)



# # ============================= 评估核心：loss + token-accuracy =============================
# def run_evaluation(args, model_engine, eval_dataloader, tokenizer):
#     """
#     针对 MMLU 这类「prompt + 一个答案 token」的任务：

#     - avg_loss：只在 labels != -100 的 token（答案部分）上算 loss
#                 （直接用 HF 内置 loss，忽略 -100）
#     - classification_accuracy：按题目算，
#         用「prompt 最后一个 token 的 logits」预测「第一个答案 token」，
#         即：
#             ctx_pos = ans_pos - 1
#             gold_id = labels[ans_pos]
#             pred_id = argmax(logits[ctx_pos])
#     - 额外：
#       * rank0 打印前 5 个样本的题干（prompt）+ 预测答案 + 真实答案
#       * 每个 rank 把自己看到的所有样本的 (pred, gold, correct, prompt) 写到
#         args.output_dir/mmlu_predictions_rank{rank}.txt
#     """
#     device = torch.device("cuda", model_engine.local_rank)
#     model_engine.eval()

#     # ==== 准备 MMLU 4 个选项的 token id（假定你用的是 " A"/" B"/" C"/" D" 格式）====
#     choice_texts = [" A", " B", " C", " D"]
#     # choice_texts = [" A", " B"]
#     choice_token_ids = []
#     for txt in choice_texts:
#         ids = tokenizer.encode(txt, add_special_tokens=False)
#         if len(ids) != 1:
#             # 如果这里抛异常，说明该 tokenizer 把 " A" 拆成了多个 token，
#             # 那就需要改成用“多 token 序列匹配”的策略，这里先假定是单 token。
#             raise ValueError(f"Choice text {txt!r} is not a single token: {ids}")
#         choice_token_ids.append(ids[0])
#     choice_token_ids = torch.tensor(choice_token_ids, device=device)  # [4]

#     rank = dist.get_rank()
#     world_size = dist.get_world_size()
#     is_main = (rank == 0)

#     # 全局统计
#     total_loss = 0.0              # 累积「答案 token」的总 loss 之和（sum）
#     total_answer_tokens = 0       # 全部样本中 label != -100 的 token 数
#     total_class_correct = 0       # 整题预测正确的样本数
#     total_samples = 0             # 整体样本数

#     # debug：只在 rank0 打印前几个题目
#     debug_print_limit = 5
#     debug_print_count = 0

#     # 把每个 rank 的预测写到自己一个文件里
#     os.makedirs(args.output_dir, exist_ok=True)
#     pred_log_path = os.path.join(args.output_dir, f"mmlu_predictions_rank{rank}.txt")
#     pred_f = open(pred_log_path, "w", encoding="utf-8")

#     if is_main:
#         eval_start = time.time()

#     with torch.no_grad():
#         for step, batch in enumerate(eval_dataloader):
#             batch = {k: v.to(device) for k, v in batch.items()}

#             input_ids = batch["input_ids"]       # [B, T]
#             attention_mask = batch["attention_mask"]  # [B, T]
#             labels = batch["labels"]             # [B, T]，prompt=-100, answer/eos=token_id

#             # HF 自带的 loss：已经按照 ignore_index=-100 做好平均
#             outputs = model_engine(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 labels=labels,
#                 use_cache=False,
#             )
#             logits = outputs.logits              # [B, T, V]
#             batch_loss_mean = outputs.loss       # scalar，答案 token 的平均 CE

#             # 统计答案 token 数（用于还原 sum_loss）
#             valid_mask = labels.ne(-100) & attention_mask.eq(1)
#             num_answer_tokens = valid_mask.sum().item()

#             batch_loss_sum = batch_loss_mean.item() * max(num_answer_tokens, 1)

#             total_loss += batch_loss_sum
#             total_answer_tokens += num_answer_tokens

#             batch_size = input_ids.size(0)
#             total_samples += batch_size

#             # ===== 整题分类正确率：ctx_pos = ans_pos - 1 =====
#             for i in range(batch_size):
#                 # 有效答案位置：label != -100 且在有效长度内
#                 valid_pos = (labels[i] != -100) & (attention_mask[i] == 1)
#                 if not valid_pos.any():
#                     continue

#                 # 第一个答案 token 的 index
#                 ans_pos = valid_pos.nonzero(as_tuple=False)[0].item()
#                 if ans_pos == 0:
#                     # 理论上不会发生（前面有 prompt），但还是防御一下
#                     continue
#                 ctx_pos = ans_pos - 1

#                 gold_id = labels[i, ans_pos].item()
#                 # 用「prompt 最后一个位置」的 logits 预测答案 token
#                 # 只在 {A, B, C, D} 这几个候选 token 上做 argmax
#                 logits_ctx = logits[i, ctx_pos]                     # [V]
#                 choice_logits = logits_ctx[choice_token_ids]        # [4]
#                 choice_idx = choice_logits.argmax(dim=-1).item()    # 0/1/2/3
#                 pred_id = choice_token_ids[choice_idx].item()

#                 is_correct = int(pred_id == gold_id)
#                 total_class_correct += is_correct

#                 # 解码方便查看
#                 pred_str = tokenizer.decode([pred_id], skip_special_tokens=False)
#                 gold_str = tokenizer.decode([gold_id], skip_special_tokens=False)

#                 # 题干 = labels==-100 & attention_mask==1 的部分
#                 prompt_token_mask = (labels[i] == -100) & (attention_mask[i] == 1)
#                 prompt_ids = input_ids[i][prompt_token_mask]
#                 prompt_text = tokenizer.decode(
#                     prompt_ids.tolist(), skip_special_tokens=False
#                 )

#                 # rank0 打前几个样本
#                 if is_main and debug_print_count < debug_print_limit:
#                     print("=" * 80)
#                     print(f"[Sample #{debug_print_count}]")
#                     print("Prompt / Question (含选项和 'Answer:'):")
#                     print(prompt_text)
#                     print(f"Gold answer token : {gold_str!r} (id={gold_id})")
#                     print(f"Pred answer token : {pred_str!r} (id={pred_id})")
#                     print(f"Correct? {bool(is_correct)}")
#                     debug_print_count += 1

#                 # 写入 rank 的预测日志
#                 safe_prompt = prompt_text.replace("\n", "\\n")
#                 safe_pred = pred_str.replace("\n", "\\n")
#                 safe_gold = gold_str.replace("\n", "\\n")
#                 pred_f.write(
#                     f"{is_correct}\t{safe_gold}\t{safe_pred}\t{safe_prompt}\n"
#                 )

#             # rank0 打 log
#             if is_main and (step + 1) % args.logging_steps == 0:
#                 elapsed = time.time() - eval_start
#                 avg_step_time = elapsed / (step + 1)
#                 est_total_steps = len(eval_dataloader)
#                 remaining_steps = max(est_total_steps - (step + 1), 0)
#                 eta_seconds = remaining_steps * avg_step_time

#                 print(
#                     f"[Eval] step {step+1}/{est_total_steps} | "
#                     f"batch_loss(mean): {batch_loss_mean.item():.4f} | "
#                     f"answer_tokens_in_batch: {num_answer_tokens} | "
#                     f"avg_step_time {avg_step_time:.3f}s | "
#                     f"ETA {format_seconds(eta_seconds)}"
#                 )

#     pred_f.close()

#     # ===== 多卡汇总 =====
#     stats = torch.tensor(
#         [total_loss, total_answer_tokens, total_class_correct, total_samples],
#         dtype=torch.float64,
#         device=device,
#     )
#     dist.all_reduce(stats, op=dist.ReduceOp.SUM)
#     total_loss, total_answer_tokens, total_class_correct, total_samples = stats.tolist()

#     total_answer_tokens = max(total_answer_tokens, 1.0)
#     total_samples = max(total_samples, 1.0)

#     avg_loss = total_loss / total_answer_tokens
#     classification_accuracy = total_class_correct / total_samples

#     return avg_loss, classification_accuracy

# # ============================= 主函数 =============================

# def main():
#     torch.backends.cudnn.enabled = False

#     args = parse_args()
#     setup_distributed()
#     rank = dist.get_rank()
#     world_size = dist.get_world_size()
#     is_main = (rank == 0)

#     if is_main:
#         os.makedirs(args.output_dir, exist_ok=True)
#         print("====== Eval args ======")
#         for k, v in vars(args).items():
#             print(f"{k}: {v}")
#         print("=======================")

#     set_seed(args.seed + rank)

#     # 1. 加载模型 + tokenizer
#     if is_main:
#         print(f"[Rank 0] Loading model from {args.model_name_or_path}")
#     model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)

#     # 2. 加载 DeepSpeed 配置
#     with open(args.deepspeed_config, "r") as f:
#         ds_config = json.load(f)

#     # 修正 batch 相关字段
#     per_device_bs = args.per_device_eval_batch_size
#     grad_accum = args.gradient_accumulation_steps

#     ds_config["train_micro_batch_size_per_gpu"] = int(per_device_bs)
#     ds_config["gradient_accumulation_steps"] = int(grad_accum)
#     ds_config["train_batch_size"] = int(per_device_bs * grad_accum * world_size)

#     # 减少 DeepSpeed 自己的打印
#     ds_config["steps_per_print"] = ds_config.get("steps_per_print", 10_000_000)

#     if "zero_optimization" in ds_config:
#         ds_config["zero_optimization"]["stage"] = 0


#     # ⚠ 与训练脚本一样：去掉 optimizer / scheduler，防止 DeepSpeed 构建 FusedAdam / CPUAdam
#     if "optimizer" in ds_config:
#         if is_main:
#             print("[Rank 0] Remove 'optimizer' from ds_config for evaluation")
#         ds_config.pop("optimizer")
#     if "scheduler" in ds_config:
#         if is_main:
#             print("[Rank 0] Remove 'scheduler' from ds_config for evaluation")
#         ds_config.pop("scheduler")

#     # 避免 DeepSpeed 同时从 args 里再读一遍 config
#     if hasattr(args, "deepspeed"):
#         args.deepspeed = None
#     if hasattr(args, "deepspeed_config"):
#         args.deepspeed_config = None

#     # 3. 构造 eval DataLoader
#     if is_main:
#         print(f"[Rank 0] Loading Eval Datasets.")
#     eval_dataloader, dataset_size = build_eval_dataloader(args, tokenizer)
#     if is_main:
#         print(f"World size: {world_size}")
#         print(f"Eval dataset size: {dataset_size}")
#         total_steps = len(eval_dataloader)
#         print(f"Eval steps per epoch: {total_steps}")

#     # 4. 初始化 DeepSpeed（无 optimizer / scheduler）
#     model_engine, _, _, _ = deepspeed.initialize(
#         model=model,
#         model_parameters=model.parameters(),
#         args=args,
#         config_params=ds_config,
#     )

#     # 5. 运行评估
#     if is_main:
#         print("[Rank 0] Start evaluation...")
#     eval_start = time.time()
#     avg_loss, cls_acc = run_evaluation(args, model_engine, eval_dataloader, tokenizer)
#     total_time = time.time() - eval_start

#     # 6. 打印 & 保存结果（只在 rank 0）
#     if is_main:
#         log_str = (
#             f"[Eval Finished] avg_loss: {avg_loss:.4f} | "
#             f"classification_accuracy: {cls_acc:.4%} | "
#             f"total_time: {format_seconds(total_time)}"
#         )
#         print(log_str)

#         os.makedirs(args.output_dir, exist_ok=True)
#         with open(os.path.join(args.output_dir, "eval_log.txt"), "a", encoding="utf-8") as f:
#             f.write(log_str + "\n")

#         result_json = {
#             "avg_loss": float(avg_loss),
#             "classification_accuracy": float(cls_acc),
#             "total_time_seconds": float(total_time),
#         }
#         with open(os.path.join(args.output_dir, "eval_metrics.json"), "w", encoding="utf-8") as f:
#             json.dump(result_json, f, indent=2, ensure_ascii=False)

#     dist.barrier()



# if __name__ == "__main__":
#     main()

import os
import json
import math
import argparse
import time
import re

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F

import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

from data_utils import get_sft_dataset, collate_sft  # 复用你训练用的接口

# Cap torch's intra-op CPU thread pool at 4 threads. This runs at import time,
# so it applies to every process that loads this module.
torch.set_num_threads(4)


# ============================= 小工具：分布式 & 时间格式 =============================

def setup_distributed():
    """Initialise the torch.distributed process group through DeepSpeed.

    Does nothing when a process group already exists, so it is safe to call twice.
    """
    if dist.is_initialized():
        return
    deepspeed.init_distributed()


def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU + all CUDA devices) with `seed`."""
    import random
    import numpy as np

    # Applied in a fixed order: random -> numpy -> torch CPU -> torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def format_seconds(seconds: float) -> str:
    """Format a duration as ``"<d>d <hh>h <mm>m <ss>s"``.

    Fractional seconds are truncated toward zero by ``int()`` first.
    """
    total = int(seconds)
    total_minutes, secs = divmod(total, 60)
    total_hours, mins = divmod(total_minutes, 60)
    days, hrs = divmod(total_hours, 24)
    return f"{days}d {hrs:02d}h {mins:02d}m {secs:02d}s"


# ============================= 参数 =============================

def parse_args(argv=None):
    """Build the evaluation CLI and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        argparse.Namespace with model/path, data, batch and logging options,
        plus ``local_rank`` (-1 unless supplied by a launcher).
    """
    parser = argparse.ArgumentParser(description="Evaluate pruned HF LLM with DeepSpeed (TALE-style)")

    # Model and paths
    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="要加载的 HF 格式模型路径（可以是 finetune 后的 checkpoint 或 final）")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="评估日志与结果保存目录")
    parser.add_argument("--deepspeed_config", type=str, required=True,
                        help="DeepSpeed json 配置文件路径")

    # Data
    parser.add_argument("--sft_dataset", type=str, default="mmlu",
                        help="评估使用的数据集类型，沿用训练时的 data_utils.get_sft_dataset")
    parser.add_argument("--max_length", type=int, default=512,
                        help="评估时的最大序列长度（需要与构造 SFT 数据集时一致）")
    parser.add_argument("--num_eval_samples", type=int, default=None,
                        help="可选：限制使用多少条样本进行评估（不填则用完整数据集）")
    parser.add_argument("--seed", type=int, default=42,
                        help="构造数据集时的随机种子（与训练脚本保持一致即可）")
    parser.add_argument("--num_workers", type=int, default=2,
                        help="DataLoader 的 num_workers")
    parser.add_argument("--eval_split", type=str, default="validation",
                        choices=["train", "validation", "valid", "test"],
                        help="Which split of parquet to use for evaluation")

    # Batch size & distributed
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4,
                        help="每张卡上的 eval batch size")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="评估时一般用 1 即可，这里仍保留参数，方便后面修改")
    # The `deepspeed` launcher appends `--local_rank=<n>` to the script's argv
    # (unless started with --no_local_rank); without this option argparse
    # would exit with "unrecognized arguments".
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank injected by the DeepSpeed launcher (-1 when not set)")

    # Logging
    parser.add_argument("--logging_steps", type=int, default=50,
                        help="每多少个 step 打一次日志（仅 rank 0）")

    return parser.parse_args(argv)


# ============================= 模型 & 数据 =============================

def load_model_and_tokenizer(model_path: str):
    """Load a fast tokenizer and a bf16 causal-LM from `model_path`.

    If the tokenizer defines no pad token, EOS is reused as padding. Padding is
    applied on the right. ``device_map=None`` means transformers does no device
    dispatch; the model is placed later by ``deepspeed.initialize``.

    Returns:
        (model, tokenizer)
    """
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    needs_pad = tok.pad_token is None
    if needs_pad:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "right"

    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=None,
    )
    return lm, tok


def build_eval_dataloader(args, tokenizer):
    """Build this rank's sharded, un-shuffled evaluation DataLoader.

    The dataset comes from ``get_sft_dataset`` with the split selected by
    ``args.eval_split``. Batches are produced by ``collate_sft``.

    Note: with ``drop_last=False``, ``DistributedSampler`` pads the index list
    by repeating samples so every rank gets the same count. When the dataset
    size is not divisible by the world size, a few samples are therefore
    evaluated twice.

    Returns:
        (dataloader, len(dataset)); the length is the full, un-sharded size.
    """
    eval_set = get_sft_dataset(
        name=args.sft_dataset,
        tokenizer=tokenizer,
        max_length=args.max_length,
        seed=args.seed,
        num_samples=args.num_eval_samples,
        split=args.eval_split,
    )

    n_replicas = dist.get_world_size()
    my_rank = dist.get_rank()
    shard_sampler = DistributedSampler(
        eval_set,
        num_replicas=n_replicas,
        rank=my_rank,
        shuffle=False,
        drop_last=False,
    )

    loader = DataLoader(
        eval_set,
        batch_size=args.per_device_eval_batch_size,
        sampler=shard_sampler,
        collate_fn=collate_sft,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return loader, len(eval_set)


# ============================= 解析与打分工具 =============================

# Lettered option line such as "A. text": letter A-D, a period, then whitespace.
# Case-insensitive, so "b. text" also matches (extract_option_keys_from_prompt upper-cases it).
_OPTION_LINE_RE = re.compile(r"^\s*([A-D])\.\s", re.IGNORECASE)
# Numbered option line such as "1. text": single digit 1-9, a period, then whitespace.
_OPTION_LINE_NUM_RE = re.compile(r"^\s*([1-9])\.\s")
# An answer string that is exactly one key (A-D in any case, or 1-9), optionally surrounded by whitespace.
_ANSWER_KEY_RE = re.compile(r"^\s*([A-D]|[1-9])\s*$", re.IGNORECASE)


def extract_option_keys_from_prompt(prompt_text: str):
    """Return the option keys listed in a multiple-choice prompt, in first-seen order, de-duplicated.

    Recognised option lines start with ``A. `` .. ``D. `` (returned upper-cased)
    or ``1. `` .. ``9. ``. If the prompt contains a ``### Options:`` header, only
    the text from that header up to the next ``### Answer`` is scanned. Otherwise
    the whole prompt is scanned. Returns an empty list when no line matches.
    """
    start = prompt_text.find("### Options:")
    region = prompt_text if start < 0 else prompt_text[start:].split("### Answer", 1)[0]

    found = []
    for line in region.splitlines():
        letter = _OPTION_LINE_RE.match(line)
        if letter:
            found.append(letter.group(1).upper())
        else:
            number = _OPTION_LINE_NUM_RE.match(line)
            if number:
                found.append(number.group(1))

    # dict keeps insertion order: drops repeats while preserving first occurrence.
    return list(dict.fromkeys(found))


def extract_gold_key_from_labels(tokenizer, input_ids_1d, labels_1d, attn_1d):
    """Decode the answer span of `labels_1d` and pull out the gold option key.

    The answer span is every position with ``label != -100`` and
    ``attention_mask == 1``. Special tokens are not stripped when decoding.
    Newlines are turned into spaces and the text is stripped. The key is then
    found in three passes:
      1. the whole text is a single key (A-D, case-insensitive, or 1-9);
      2. otherwise, the first standalone letter A-D (case-insensitive);
      3. otherwise, the first standalone digit 1-9.
    Letter keys are returned upper-cased.

    `input_ids_1d` is unused; it is kept for call-site compatibility.

    Returns:
        (key or None, cleaned decoded text). Returns (None, "") when the answer span is empty.
    """
    answer_mask = (labels_1d != -100) & (attn_1d == 1)
    if not answer_mask.any():
        return None, ""

    decoded = tokenizer.decode(labels_1d[answer_mask].tolist(), skip_special_tokens=False)
    cleaned = decoded.replace("\n", " ").strip()

    whole = _ANSWER_KEY_RE.match(cleaned)
    if whole:
        raw_key = whole.group(1)
        return (raw_key.upper() if raw_key.isalpha() else raw_key), cleaned

    letter = re.search(r"\b([A-D])\b", cleaned, re.IGNORECASE)
    if letter:
        return letter.group(1).upper(), cleaned

    digit = re.search(r"\b([1-9])\b", cleaned)
    return (digit.group(1) if digit else None), cleaned


def score_candidate_sequence_logprob(model_engine, input_ids_prompt, cand_ids, attention_mask_prompt=None):
    """Compute ``log p(cand | prompt)`` with one teacher-forced forward pass.

    Args:
        model_engine: callable causal LM (e.g. a DeepSpeed engine). It is called
            with ``input_ids``, ``attention_mask`` and ``use_cache`` keyword
            arguments and must return an object whose ``.logits`` is [1, L, V].
        input_ids_prompt: 1-D LongTensor of prompt token ids; must be non-empty.
        cand_ids: list of candidate token ids appended after the prompt.
        attention_mask_prompt: optional 1-D mask for the prompt; all ones if None.

    Returns:
        (sum_logprob, mean_logprob) as Python floats over the candidate tokens.
        An empty candidate returns (-1e9, -1e9) without running the model.

    Raises:
        ValueError: if the prompt is empty (no position predicts cand[0]).
    """
    Lc = len(cand_ids)
    if Lc == 0:
        # Nothing to score: return the sentinel before paying for a forward pass.
        return -1e9, -1e9

    Lp = input_ids_prompt.size(0)
    if Lp == 0:
        raise ValueError("score_candidate_sequence_logprob requires a non-empty prompt")

    device = input_ids_prompt.device
    cand = torch.tensor(cand_ids, device=device, dtype=torch.long)
    all_ids = torch.cat([input_ids_prompt, cand], dim=0)  # [Lp + Lc]

    if attention_mask_prompt is None:
        attn = torch.ones_like(all_ids, dtype=torch.long, device=device)
    else:
        attn = torch.cat([attention_mask_prompt, torch.ones_like(cand, dtype=torch.long)], dim=0)

    out = model_engine(input_ids=all_ids.unsqueeze(0), attention_mask=attn.unsqueeze(0), use_cache=False)
    logits = out.logits.squeeze(0)  # [L, V]

    # Token cand[j] is predicted by the logits at position (Lp - 1 + j).
    # Upcast before log_softmax: bf16 output keeps only ~3 significant digits,
    # which is too coarse to rank close candidates reliably.
    pred_logits = logits[Lp - 1:Lp - 1 + Lc].float()  # [Lc, V]
    token_logprobs = F.log_softmax(pred_logits, dim=-1).gather(-1, cand.unsqueeze(-1)).squeeze(-1)  # [Lc]

    return token_logprobs.sum().item(), token_logprobs.mean().item()


def predict_choice_key(model_engine, tokenizer, prompt_ids_1d, option_keys):
    """Pick the option key the model rates most likely right after the prompt.

    Each key is tried in two surface forms, ``k`` and ``" " + k``, to absorb
    tokenizer whitespace conventions. A key's score is the better of its two
    mean per-token log-probs. Comparisons are strict, so ties keep the earlier
    key / surface.

    Args:
        prompt_ids_1d: 1-D tensor holding only the prompt tokens.
        option_keys: iterable of key strings (e.g. ["A", "B", "C", "D"]).

    Returns:
        (best_key, best_surface, best_mean_logprob); (None, None, -1e18) for no keys.
    """
    winner = (None, None, -1e18)
    for key in option_keys:
        top_surface, top_score = None, -1e18
        for surface in (key, " " + key):
            token_ids = tokenizer.encode(surface, add_special_tokens=False)
            _, mean_lp = score_candidate_sequence_logprob(model_engine, prompt_ids_1d, token_ids)
            if mean_lp > top_score:
                top_surface, top_score = surface, mean_lp
        if top_score > winner[2]:
            winner = (key, top_surface, top_score)
    return winner


# ============================= 评估核心：loss + 多选 accuracy =============================

def run_evaluation(args, model_engine, eval_dataloader, tokenizer):
    """Evaluate answer-token loss and multiple-choice accuracy, then all-reduce across ranks.

    - avg_loss: mean cross-entropy over answer tokens (labels != -100 and
      attention_mask == 1). It is rebuilt as sum(batch_mean * n_answer_tokens)
      divided by the total answer tokens. Batches with zero answer tokens are
      left out of the loss, because a mean over zero tokens is NaN (0/0) and
      would poison the running total.
    - classification_accuracy, per sample:
        1) option keys are parsed from the decoded prompt (A-D or 1-9; falls
           back to A-D when none are found);
        2) the gold key is parsed from the decoded answer labels;
        3) every key is scored by sequence log-prob and the best one is predicted.
      Samples whose gold key cannot be parsed, or is not among the option keys,
      are logged and excluded from the accuracy denominator.

    Each rank writes one tab-separated line per sample to
    ``{args.output_dir}/mmlu_predictions_rank{rank}.txt``. Rank 0 also prints
    the first 5 scored samples and, every ``args.logging_steps`` steps, a
    progress line (disabled when logging_steps <= 0).

    Returns:
        (avg_loss, classification_accuracy, extra), where ``extra`` holds the
        globally summed sample and skip counters.
    """
    device = torch.device("cuda", model_engine.local_rank)
    model_engine.eval()

    rank = dist.get_rank()
    is_main = (rank == 0)
    log_every = args.logging_steps

    total_loss = 0.0
    total_answer_tokens = 0

    total_class_correct = 0
    total_class_count = 0   # only samples whose gold key could be determined
    total_samples = 0

    total_skipped_no_gold = 0
    total_skipped_no_options = 0

    debug_print_limit = 5
    debug_print_count = 0

    os.makedirs(args.output_dir, exist_ok=True)
    pred_log_path = os.path.join(args.output_dir, f"mmlu_predictions_rank{rank}.txt")

    eval_start = time.time()

    # `with` guarantees the per-rank log is closed even if a forward pass raises.
    with open(pred_log_path, "w", encoding="utf-8") as pred_f, torch.no_grad():
        for step, batch in enumerate(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}

            input_ids = batch["input_ids"]            # [B, T]
            attention_mask = batch["attention_mask"]  # [B, T]
            labels = batch["labels"]                  # [B, T]

            # ===== answer-token loss =====
            outputs = model_engine(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                use_cache=False,
            )
            batch_loss_value = outputs.loss.item()
            # Release the [B, T, V] logits now. They are not used below and would
            # otherwise stay alive during all per-sample scoring forwards.
            del outputs

            valid_mask = labels.ne(-100) & attention_mask.eq(1)
            num_answer_tokens = valid_mask.sum().item()
            if num_answer_tokens > 0:
                total_loss += batch_loss_value * num_answer_tokens
                total_answer_tokens += num_answer_tokens

            B = input_ids.size(0)
            total_samples += B

            # ===== per-sample multiple-choice scoring =====
            for i in range(B):
                # Prompt = attended positions whose label is -100.
                prompt_token_mask = (labels[i] == -100) & (attention_mask[i] == 1)
                prompt_ids = input_ids[i][prompt_token_mask]
                if prompt_ids.numel() == 0:
                    continue

                prompt_text = tokenizer.decode(prompt_ids.tolist(), skip_special_tokens=False)
                # Escaped so each sample stays on one line of the TSV log.
                safe_prompt = prompt_text.replace("\n", "\\n")

                option_keys = extract_option_keys_from_prompt(prompt_text)
                if not option_keys:
                    option_keys = ["A", "B", "C", "D"]
                    total_skipped_no_options += 1

                gold_key, gold_raw = extract_gold_key_from_labels(
                    tokenizer, input_ids[i], labels[i], attention_mask[i]
                )

                if gold_key is None:
                    total_skipped_no_gold += 1
                    pred_f.write(f"SKIP_NO_GOLD\t{gold_raw}\t\t{safe_prompt}\n")
                    continue

                # A gold key outside the detected keys (e.g. "C" vs options 1-4)
                # cannot be judged fairly, so the sample is skipped.
                if gold_key not in option_keys:
                    total_skipped_no_gold += 1
                    pred_f.write(f"SKIP_GOLD_NOT_IN_OPTIONS\t{gold_key}\t\t{safe_prompt}\n")
                    continue

                pred_key, pred_surface, pred_score = predict_choice_key(
                    model_engine, tokenizer, prompt_ids, option_keys
                )

                is_correct = int(pred_key == gold_key)
                total_class_correct += is_correct
                total_class_count += 1

                if is_main and debug_print_count < debug_print_limit:
                    print("=" * 80)
                    print(f"[Sample #{debug_print_count}]")
                    print("Prompt / Question (含选项和 'Answer:'):")
                    print(prompt_text)
                    print(f"Option keys detected: {option_keys}")
                    print(f"Gold key: {gold_key!r} (raw decoded answer={gold_raw!r})")
                    print(f"Pred key: {pred_key!r} (best surface={pred_surface!r}, mean_logprob={pred_score:.6f})")
                    print(f"Correct? {bool(is_correct)}")
                    debug_print_count += 1

                pred_f.write(f"{is_correct}\t{gold_key}\t{pred_key}\t{safe_prompt}\n")

            # rank-0 progress log (guarded so logging_steps <= 0 cannot divide by zero)
            if is_main and log_every > 0 and (step + 1) % log_every == 0:
                elapsed = time.time() - eval_start
                avg_step_time = elapsed / (step + 1)
                est_total_steps = len(eval_dataloader)
                remaining_steps = max(est_total_steps - (step + 1), 0)
                eta_seconds = remaining_steps * avg_step_time

                acc_so_far = (total_class_correct / max(total_class_count, 1))
                print(
                    f"[Eval] step {step+1}/{est_total_steps} | "
                    f"batch_loss(mean): {batch_loss_value:.4f} | "
                    f"answer_tokens_in_batch: {num_answer_tokens} | "
                    f"cls_acc_so_far: {acc_so_far:.4%} (count={total_class_count}) | "
                    f"avg_step_time {avg_step_time:.3f}s | ETA {format_seconds(eta_seconds)}"
                )

    # ===== multi-GPU reduction =====
    stats = torch.tensor(
        [
            total_loss,
            total_answer_tokens,
            total_class_correct,
            total_class_count,
            total_samples,
            total_skipped_no_gold,
            total_skipped_no_options,
        ],
        dtype=torch.float64,
        device=device,
    )
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    (
        total_loss,
        total_answer_tokens,
        total_class_correct,
        total_class_count,
        total_samples,
        total_skipped_no_gold,
        total_skipped_no_options,
    ) = stats.tolist()

    avg_loss = total_loss / max(total_answer_tokens, 1.0)
    classification_accuracy = total_class_correct / max(total_class_count, 1.0)

    extra = {
        "total_samples_seen": int(total_samples),
        "total_samples_used_for_accuracy": int(total_class_count),
        "skipped_no_gold_or_gold_not_in_options": int(total_skipped_no_gold),
        "fallback_no_options_detected": int(total_skipped_no_options),
    }

    return avg_loss, classification_accuracy, extra


# ============================= 主函数 =============================

def _prepare_eval_ds_config(ds_config, per_device_bs, grad_accum, world_size, is_main):
    """Adapt a (training) DeepSpeed config dict in place for optimizer-free evaluation.

    Overrides the batch-size fields, keeps an existing ``steps_per_print``
    (default 10_000_000), forces ZeRO stage 0 when a zero section exists, and
    removes any optimizer / scheduler section. Returns the same dict.
    """
    ds_config.update(
        train_micro_batch_size_per_gpu=int(per_device_bs),
        gradient_accumulation_steps=int(grad_accum),
        train_batch_size=int(per_device_bs * grad_accum * world_size),
    )
    ds_config.setdefault("steps_per_print", 10_000_000)

    if "zero_optimization" in ds_config:
        ds_config["zero_optimization"]["stage"] = 0

    # Evaluation needs no optimizer/scheduler; dropping them avoids building fused optimizers.
    for section, notice in (
        ("optimizer", "[Rank 0] Remove 'optimizer' from ds_config for evaluation"),
        ("scheduler", "[Rank 0] Remove 'scheduler' from ds_config for evaluation"),
    ):
        if section in ds_config:
            if is_main:
                print(notice)
            del ds_config[section]

    return ds_config


def main():
    """Entry point: evaluate a HF causal LM under DeepSpeed and save metrics on rank 0.

    Rank 0 appends a summary line to ``eval_log.txt`` and writes
    ``eval_metrics.json`` in ``--output_dir``. All ranks meet at a final barrier.
    """
    torch.backends.cudnn.enabled = False

    args = parse_args()
    setup_distributed()
    rank, world_size = dist.get_rank(), dist.get_world_size()
    is_main = rank == 0

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print("====== Eval args ======")
        for name, value in vars(args).items():
            print(f"{name}: {value}")
        print("=======================")

    set_seed(args.seed + rank)

    # 1. model + tokenizer
    if is_main:
        print(f"[Rank 0] Loading model from {args.model_name_or_path}")
    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)

    # 2. DeepSpeed config, stripped down for inference-only use
    with open(args.deepspeed_config, "r") as cfg_file:
        raw_config = json.load(cfg_file)
    ds_config = _prepare_eval_ds_config(
        raw_config,
        args.per_device_eval_batch_size,
        args.gradient_accumulation_steps,
        world_size,
        is_main,
    )

    # The config is passed explicitly below; clear the CLI copies so DeepSpeed
    # is not handed two config sources.
    for attr in ("deepspeed", "deepspeed_config"):
        if hasattr(args, attr):
            setattr(args, attr, None)

    # 3. data
    if is_main:
        print("[Rank 0] Loading Eval Datasets.")
    eval_dataloader, dataset_size = build_eval_dataloader(args, tokenizer)
    if is_main:
        print(f"World size: {world_size}")
        print(f"Eval dataset size: {dataset_size}")
        print(f"Eval steps: {len(eval_dataloader)}")

    # 4. engine (no optimizer / scheduler)
    model_engine, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        args=args,
        config_params=ds_config,
    )

    # 5. evaluate
    if is_main:
        print("[Rank 0] Start evaluation...")
    started = time.time()
    avg_loss, cls_acc, extra = run_evaluation(args, model_engine, eval_dataloader, tokenizer)
    total_time = time.time() - started

    # 6. persist results (rank 0 only)
    if is_main:
        log_str = (
            f"[Eval Finished] avg_loss: {avg_loss:.4f} | "
            f"classification_accuracy: {cls_acc:.4%} | "
            f"used_samples_for_acc: {extra['total_samples_used_for_accuracy']} | "
            f"skipped: {extra['skipped_no_gold_or_gold_not_in_options']} | "
            f"total_time: {format_seconds(total_time)}"
        )
        print(log_str)

        out_dir = args.output_dir
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, "eval_log.txt"), "a", encoding="utf-8") as log_file:
            log_file.write(log_str + "\n")

        metrics = {
            "avg_loss": float(avg_loss),
            "classification_accuracy": float(cls_acc),
            "total_time_seconds": float(total_time),
            **extra,
        }
        with open(os.path.join(out_dir, "eval_metrics.json"), "w", encoding="utf-8") as json_file:
            json.dump(metrics, json_file, indent=2, ensure_ascii=False)

    dist.barrier()


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
