# evaluate_resumable.py

import argparse
import json
import logging
import os
import re
from typing import List, Dict, Any, Set

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model

try:
    from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig
except ImportError:
    print("错误：无法导入 'TWNM' 和 'TWNMConfig'。")
    print("请确保脚本位于可以访问 'models' 包的项目目录中。")
    exit(1)


# =====================================================================================
# 1. 数据集类 (修改 __init__ 以支持跳过已完成的ID)
# =====================================================================================
class AudioDataset(Dataset):
    """Multiple-choice audio QA samples read from a JSONL file.

    Each non-blank line is a JSON object with ``source_metadata.audio_path``
    and ``question_data.{question, options, answer}``; ``scene_id`` and
    ``task_type`` are optional.

    Supports resumable evaluation: records whose ``scene_id`` is in
    ``skip_ids`` are dropped when the dataset is built.
    """

    def __init__(
        self, data_file, sample_rate=44100, skip_ids: Set[str] = None
    ):  # [Resume] skip_ids: scene_ids that already have results
        """
        Args:
            data_file: Path to the JSONL file.
            sample_rate: Target sample rate; audio at any other rate is resampled.
            skip_ids: scene_ids already evaluated. Records without a
                ``scene_id`` are identified as ``item_<n>``, where ``n`` is the
                record's position in the file. This keeps the id stable across
                runs, so the skip check works for them too.
        """
        super().__init__()
        self.data_dir = os.path.dirname(data_file)
        self.sample_rate = sample_rate
        # Resample transforms cached per source sample rate (building one per item is wasteful).
        self._resamplers = {}

        all_data = []
        with open(data_file, "r", encoding="utf8") as fin:
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                all_data.append(json.loads(line))

        # Assign fallback ids from the position in the FULL file, before any
        # filtering. A filtered-list index would shift between runs and break resume.
        for position, item in enumerate(all_data):
            if item.get("scene_id") is None:
                item["scene_id"] = f"item_{position}"

        # NOTE(review): resume keys on scene_id alone; if several questions can
        # share one scene_id, completing one would skip the others — confirm ids are unique.
        if skip_ids:
            self.data_list = [
                item for item in all_data if item["scene_id"] not in skip_ids
            ]
            skipped_count = len(all_data) - len(self.data_list)
            logging.info(f"已跳过 {skipped_count} 条已完成的记录。")
        else:
            self.data_list = all_data

        logging.info(f"成功加载 {len(self.data_list)} 条待评测数据来自 {data_file}")

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Load one sample's audio and build its multiple-choice prompt."""
        json_obj = self.data_list[index]
        audio_path = json_obj["source_metadata"]["audio_path"]
        if not os.path.isabs(audio_path):
            # Relative paths resolve against the PARENT of the data file's directory.
            audio_path = os.path.join(os.path.dirname(self.data_dir), audio_path)
        try:
            waveform, original_sr = torchaudio.load(audio_path)
            if original_sr != self.sample_rate:
                resampler = self._resamplers.get(original_sr)
                if resampler is None:
                    resampler = torchaudio.transforms.Resample(
                        orig_freq=original_sr, new_freq=self.sample_rate
                    )
                    self._resamplers[original_sr] = resampler
                waveform = resampler(waveform)
            # Keep at most the first two channels.
            waveform = waveform[:2, :]
        except Exception as e:
            # Best effort: log and substitute 5 s of single-channel silence so
            # the run continues. The sample is still scored, with no audio signal.
            logging.error(f"加载或处理音频文件失败: {audio_path}, 错误: {e}")
            waveform = torch.zeros((1, self.sample_rate * 5))

        question_data = json_obj["question_data"]
        question = question_data["question"]
        options = question_data["options"]
        answer_key = question_data["answer"]
        choices_str = "\n".join([f"{key}: {value}" for key, value in options.items()])
        task_prompt = f"{question}. Please choose the answer from the following options: {choices_str}"
        return {
            "audio": waveform,
            "task": task_prompt,
            "solution": answer_key,
            # Always present: __init__ assigns a stable fallback id when missing.
            "scene_id": json_obj["scene_id"],
            "task_type": json_obj.get("task_type", "unknown"),
            "question": question,
            "options": options,
        }


def parse_answer(text: str) -> str:
    """Return the content between the first ``|<answer>|`` / ``|</answer>|`` pair.

    The captured span may cross line breaks and is stripped of surrounding
    whitespace. Returns ``"N/A"`` when no complete tag pair is present.
    """
    tag_pattern = re.compile(r"\|\<answer\>\|(.*?)\|\</answer\>\|", re.DOTALL)
    found = tag_pattern.search(text)
    if found is None:
        return "N/A"
    return found.group(1).strip()


def parse_answer_fuzzy(text: str, options: Dict[str, str]) -> str:
    """Extract the chosen option key from free-form model output.

    Rules are tried in priority order and the first hit wins:

    1. An explicit declaration such as "THE ANSWER IS A", "答案是：B" or
       "I CHOOSE (C)".
    2. The LAST standalone option key in the text. A key embedded in a longer
       ASCII word or number is ignored (e.g. the "A" in "AN"). CJK characters
       and punctuation may surround it ("答案为B" matches).
    3. An option whose text appears verbatim in the output; if several do,
       the one with the longest text wins (handles "猫" vs "一只黑猫").

    Matching is case-insensitive; the returned key keeps the casing used in
    ``options``.

    Args:
        text: Raw text generated by the model.
        options: Mapping from option key to option text,
            e.g. ``{"A": "猫叫", "B": "狗叫"}``.

    Returns:
        The matched option key (e.g. ``"A"``), or ``"N/A"`` when nothing
        matches or ``options`` has no usable keys.
    """
    # Map normalized (uppercase) key -> caller's original key. Lowercase keys
    # then still match the uppercased text and are returned unchanged.
    key_lookup = {str(key).upper(): key for key in options if str(key).strip()}
    if not key_lookup:
        # Without keys, the alternations below would be empty and match "" everywhere.
        return "N/A"

    # Rule 0: normalize — uppercase for case-insensitive matching and remove
    # common special tokens (the lowercase forms are covered after upper()).
    processed_text = re.sub(r"<S>|<\/S>|<PAD>|<UNK>", "", text.upper()).strip()

    # Keys are data, not regex: escape them. Try longer keys first so a key
    # that is a prefix of another cannot shadow it.
    keys_alt = "|".join(
        re.escape(k) for k in sorted(key_lookup, key=len, reverse=True)
    )
    # ASCII letters/digits on either side mean the key is part of a longer
    # token ("AN", "B2"). Plain \b would also reject CJK neighbours ("选B").
    key_group = rf"(?<![A-Z0-9])({keys_alt})(?![A-Z0-9])"

    # --- Rule 1: explicit answer declaration ---
    keywords = [
        "ANSWER IS",
        "THE ANSWER IS",
        "OPTION IS",
        "CORRECT ANSWER IS",
        "I CHOOSE",
        "MY CHOICE IS",
        "FINAL ANSWER",
        "答案是",
        "我选",
        "选择是",
        "选项是",
    ]
    keywords_alt = "|".join(keywords)
    # Accept ASCII or full-width colon / opening parenthesis between keyword and key.
    explicit = re.search(
        rf"(?:{keywords_alt})\s*[:：]?\s*[(（]?{key_group}", processed_text
    )
    if explicit:
        return key_lookup[explicit.group(1)]

    # --- Rule 2: standalone option key; the last mention is usually the final decision ---
    # NOTE: an English article can still be taken as option "A" ("IT IS A DOG") — known limitation.
    standalone = re.findall(key_group, processed_text)
    if standalone:
        return key_lookup[standalone[-1]]

    # --- Rule 3: the model restated an option's text ---
    found_matches = {}
    for key, value in options.items():
        option_text = str(value).upper().strip()
        if option_text and option_text in processed_text:
            found_matches[key] = option_text
    if found_matches:
        # Ambiguity: prefer the longest matched text (it contains the shorter ones).
        return max(found_matches, key=lambda k: len(found_matches[k]))

    # --- All rules failed ---
    return "N/A"


def _load_completed_ids(output_file: str, model_mode: str) -> Set[str]:
    """Return the scene_ids already recorded for ``model_mode`` in ``output_file``.

    Used for resumable evaluation. The following are ignored:
    - blank or unparsable lines (e.g. a truncated line left by an interrupted run);
    - records without a ``scene_id``;
    - records written for a different ``model_mode``.
    If the file cannot be read, an empty set is returned so that evaluation
    really does restart from scratch, as the warning says.
    """
    completed_ids: Set[str] = set()
    if not os.path.exists(output_file):
        return completed_ids
    try:
        with open(output_file, "r", encoding="utf-8") as f_in:
            for line in f_in:
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue  # corrupt / blank line
                if not isinstance(record, dict) or "scene_id" not in record:
                    continue
                if record.get("model_mode", model_mode) != model_mode:
                    continue  # result from another evaluation mode
                completed_ids.add(record["scene_id"])
    except Exception as e:
        # Best-effort boundary: an unreadable file must not abort the run.
        logging.warning(f"读取已有结果文件失败: {e}。将从头开始评测。")
        return set()
    logging.info(
        f"检测到已存在的输出文件，已加载 {len(completed_ids)} 条已完成的记录。"
    )
    return completed_ids


def _ensure_trailing_newline(path: str) -> None:
    """Terminate the last line of ``path`` with a newline if it lacks one.

    An interrupted run can leave a partially written last line. Appending new
    records directly after it would glue the first new record onto the broken
    line, and that record would be lost as corrupt on the next read.
    """
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return
    with open(path, "rb+") as f:
        f.seek(-1, os.SEEK_END)
        if f.read(1) != b"\n":
            f.seek(0, os.SEEK_END)
            f.write(b"\n")


def _tally_results(output_file: str, model_mode: str):
    """Return ``(correct, total)`` over the result records for ``model_mode``.

    Only parsable records that carry an ``is_correct`` field are counted.
    Corrupt or blank lines are therefore excluded from the denominator.
    """
    correct = 0
    total = 0
    if not os.path.exists(output_file):
        return correct, total
    with open(output_file, "r", encoding="utf-8") as f_final:
        for line in f_final:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(record, dict) or "is_correct" not in record:
                continue
            if record.get("model_mode", model_mode) != model_mode:
                continue
            total += 1
            if record["is_correct"]:
                correct += 1
    return correct, total


def main():
    """Evaluate a TWNM checkpoint on a JSONL multiple-choice benchmark (resumable).

    One JSON result per sample is appended to ``--output_file``. On a rerun,
    samples whose ``scene_id`` already has a result for the same
    ``--model_mode`` are skipped. The printed accuracy covers every recorded
    result for that mode.
    """
    parser = argparse.ArgumentParser(description="GRPO 模型评测脚本（支持断点续评）")
    parser.add_argument(
        "--sft_checkpoint_path",
        type=str,
        required=True,
        help="SFT阶段训练好的模型检查点文件路径 (*.bin)",
    )
    parser.add_argument(
        "--policy_adapter_path",
        type=str,
        required=True,
        help="GRPO训练产出的LoRA适配器目录",
    )
    parser.add_argument(
        "--test_data_file", type=str, required=True, help="评测用的jsonl数据文件"
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="保存详细评测结果的jsonl文件路径"
    )
    parser.add_argument(
        "--model_mode",
        type=str,
        required=True,
        choices=["policy", "sft"],
        help="评测模式",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="运行设备",
    )
    parser.add_argument("--batch_size", type=int, default=1, help="批处理大小。")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # [Resume] scene_ids that already have a result in this mode are skipped below.
    completed_ids = _load_completed_ids(args.output_file, args.model_mode)

    # --- Model loading ---
    logging.info("--- 1. 开始加载和组装模型 ---")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # Presumably mirrors the LoRA layout used during SFT so checkpoint keys line up — confirm.
    temp_lora_config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    twnm_config = TWNMConfig()
    model = TWNM(
        config=twnm_config,
        peft_config=temp_lora_config,
        quantization_config=quantization_config,
    )
    logging.info("模型基础结构初始化完成。")
    logging.info(f"正在从 {args.sft_checkpoint_path} 加载SFT权重...")
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(args.sft_checkpoint_path, map_location="cpu")
    # Strip the "module." prefix added by DataParallel/DDP wrappers.
    new_state_dict = {
        k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()
    }
    load_result = model.load_state_dict(new_state_dict, strict=False)
    # strict=False silently ignores key mismatches; report them so that a wrong
    # checkpoint or key prefix is not mistaken for a successful load.
    missing_keys = list(getattr(load_result, "missing_keys", None) or [])
    unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
    logging.info(
        f"load_state_dict: {len(missing_keys)} missing keys, "
        f"{len(unexpected_keys)} unexpected keys."
    )
    if new_state_dict and len(unexpected_keys) == len(new_state_dict):
        logging.warning(
            "No checkpoint key matched the model; the SFT weights were NOT loaded."
        )
    del state_dict, new_state_dict  # release the CPU copy of the weights

    logging.info("SFT权重加载完成。")

    # NOTE(review): the block below would merge the SFT LoRA into the decoder
    # and load the GRPO "policy" adapter from --policy_adapter_path, but it is
    # commented out. As written, no "policy" adapter is registered here and
    # args.policy_adapter_path is never read, so `--model_mode policy`
    # presumably fails at set_adapter("policy") — confirm before using it.
    # model.decoder = model.decoder.merge_and_unload()
    # logging.info("SFT基座已固化 (merge_and_unload)。")
    # policy_lora_config = LoraConfig(target_modules=["q_proj", "v_proj"], task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, inference_mode=True)
    # model.decoder = get_peft_model(model.decoder, policy_lora_config, adapter_name="policy")
    # logging.info("已为decoder附加新的 'policy' 适配器插槽。")
    # model.decoder.load_adapter(args.policy_adapter_path, adapter_name="policy")
    # logging.info(f"已从 {args.policy_adapter_path} 加载训练好的 'policy' 适配器权重。")

    model.to(args.device)
    model.eval()
    tokenizer = model.tokenizer
    logging.info("--- 模型加载和组装全部完成 ---")

    # --- Data loading ---
    logging.info("--- 2. 加载评测数据集 ---")
    # [Resume] Drop samples that already have results.
    eval_dataset = AudioDataset(args.test_data_file, skip_ids=completed_ids)

    if not eval_dataset:
        # Everything was already evaluated; go straight to the report.
        logging.info("所有样本均已评测完成！")
    else:
        # NOTE(review): the default collate stacks batch["audio"], so
        # batch_size > 1 presumably requires equal-length, equal-channel audio.
        eval_dataloader = DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False
        )
        logging.info(f"数据集加载完成，剩余 {len(eval_dataset)} 条数据待评测。")

        # --- Model mode ---
        logging.info(f"--- 3. 设置模型模式为: {args.model_mode} ---")
        if args.model_mode == "policy":
            model.decoder.set_adapter("policy")
            logging.info("已激活 'policy' 适配器。")
        elif args.model_mode == "sft":
            # NOTE(review): in peft, PeftModel.disable_adapter() is a context
            # manager; called bare like this it may have no effect, leaving the
            # adapter loaded with the SFT checkpoint active. Confirm which model
            # this mode is meant to evaluate before changing it.
            model.decoder.disable_adapter()
            logging.info("已禁用所有适配器，使用SFT基座模型。")

        # --- Evaluation loop ---
        logging.info("--- 4. 开始评测循环 ---")
        # [Resume] Repair a truncated last line, then append.
        _ensure_trailing_newline(args.output_file)
        eos_token = tokenizer.eos_token
        with open(args.output_file, "a", encoding="utf-8") as f_out:
            for batch in tqdm(
                eval_dataloader, desc=f"Evaluating in '{args.model_mode}' mode"
            ):
                audios = batch["audio"].to(args.device)
                tasks = batch["task"]
                solutions = batch["solution"]
                # Collated as {option_key: [value for each sample in the batch]}.
                options_list = batch["options"]
                inputs = tokenizer(
                    tasks, return_tensors="pt", padding=True, truncation=True
                )
                input_ids = inputs.input_ids.to(args.device)

                # NOTE(review): no attention_mask is passed; with padded
                # batches, confirm TWNM.generate handles padding correctly.
                with torch.no_grad():
                    generated_ids = model.generate(
                        input_ids=input_ids,
                        audio=audios,
                        max_new_tokens=1024,
                        do_sample=False,
                        # pad_token_id=tokenizer.eos_token_id,
                    )

                # NOTE(review): `add_special_tokens` is an encoding option; for HF
                # tokenizers it appears to have no effect on decoding. Special tokens
                # are kept here, which the EOS split below relies on.
                decoded_outputs = tokenizer.batch_decode(
                    generated_ids, add_special_tokens=True
                )
                # Keep only the text before the first EOS. Guard against a missing
                # EOS token: str.split(None) would split on whitespace instead.
                decoded_outputs = [
                    (out.split(eos_token)[0] if eos_token else out).strip()
                    for out in decoded_outputs
                ]

                for i, raw_output in enumerate(decoded_outputs):
                    current_options = {
                        key: options_list[key][i] for key in options_list
                    }
                    parsed = parse_answer_fuzzy(raw_output, current_options)

                    result_item = {
                        "scene_id": batch["scene_id"][i],
                        "task_type": batch["task_type"][i],
                        "task": batch["task"][i],
                        "ground_truth": solutions[i],
                        "model_mode": args.model_mode,
                        "parsed_answer": parsed,
                        "is_correct": parsed == solutions[i],
                        "raw_output": raw_output,
                    }
                    f_out.write(json.dumps(result_item, ensure_ascii=False) + "\n")
                # Persist each batch so a crash/kill loses at most one batch on resume.
                f_out.flush()

    # --- Final report ---
    logging.info("--- 5. 评测完成，生成最终报告 ---")
    # [Resume] Accuracy is computed over every result recorded for this mode,
    # including results from earlier runs.
    final_correct, final_total = _tally_results(args.output_file, args.model_mode)

    accuracy = (final_correct / final_total) * 100 if final_total > 0 else 0
    print("\n" + "=" * 50)
    print("           评 测 报 告 (最终)")
    print("=" * 50)
    print(f"  模型模式 (Model Mode): {args.model_mode.upper()}")
    print(f"  测试数据文件: {args.test_data_file}")
    print(f"  详细结果已保存至: {args.output_file}")
    print("-" * 50)
    print(f"  总样本数: {final_total}")
    print(f"  正确预测数: {final_correct}")
    print(f"  => 准确率 (Accuracy): {accuracy:.2f}%")
    print("=" * 50)


# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()

# Example invocation, kept as an inert module-level string for reference.
# NOTE(review): the example runs `evaluation_spatialQA.py`, while the header
# comment at the top of the file names it `evaluate_resumable.py`; use
# whichever matches the actual filename.
"""
CUDA_VISIBLE_DEVICES=0 python evaluation_spatialQA.py \
  --sft_checkpoint_path <PATH_TO_TWNM>/exp/SFT/checkpoint-117534/pytorch_model.bin \
  --policy_adapter_path assets/checkpoints/grpo_checkpoint-1/policy/policy \
  --test_data_file /data2/wl/RL_train/output/benchmark/benchmark_questions.jsonl \
  --output_file results_sft1.jsonl \
  --model_mode sft \
  --batch_size 4
"""
