# evaluate_resumable.py

import argparse
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Set

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

from safetensors.torch import load_file


# TWNM lives in a project-local package; fail fast with a readable hint when
# it cannot be imported instead of surfacing a bare traceback.
try:
    from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig
except ImportError:
    print("错误：无法导入 'TWNM' 和 'TWNMConfig'。")
    print("请确保脚本位于可以访问 'models' 包的项目目录中。")
    # The builtin-looking `exit` is injected by the `site` module for
    # interactive use and may be absent (e.g. `python -S`); raising
    # SystemExit directly has the same effect and always works.
    raise SystemExit(1)


# =====================================================================================
# 1. Dataset class (__init__ accepts skip_ids so already-evaluated samples are skipped)
# =====================================================================================
class AudioDataset(Dataset):
    """Multiple-choice audio QA samples read from a JSONL file.

    Each non-blank line of ``data_file`` is parsed as one JSON record. A
    record is expected to carry ``source_metadata.audio_path`` and a
    ``question_data`` object with ``question``, ``options`` and ``answer``;
    ``scene_id`` and ``task_type`` are optional.

    Records without a ``scene_id`` receive the fallback id ``item_<n>``,
    where ``n`` is the record's position in the file. The id is assigned
    before any filtering, so it is identical on the first run and on every
    resumed run and can be matched against ``skip_ids``.

    Args:
        data_file: Path to the JSONL file.
        sample_rate: Target sample rate; audio with a different rate is
            resampled to it.
        skip_ids: Scene ids that were already evaluated. Matching records
            are dropped from the dataset. ``None`` or an empty set keeps
            every record.
    """

    def __init__(
        self, data_file, sample_rate=44100, skip_ids: Optional[Set[str]] = None
    ):
        super().__init__()
        self.data_dir = os.path.dirname(data_file)
        self.sample_rate = sample_rate

        records = []
        with open(data_file, "r", encoding="utf8") as fin:
            for line in fin:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                records.append(json.loads(line))

        # Give id-less records a stable id *before* filtering; otherwise the
        # fallback would depend on the filtered position and resumed runs
        # could never recognise those records as completed.
        for position, record in enumerate(records):
            record.setdefault("scene_id", f"item_{position}")

        if skip_ids:
            self.data_list = [
                record for record in records if record["scene_id"] not in skip_ids
            ]
            skipped_count = len(records) - len(self.data_list)
            logging.info(f"已跳过 {skipped_count} 条已完成的记录。")
        else:
            self.data_list = records

        logging.info(f"成功加载 {len(self.data_list)} 条待评测数据来自 {data_file}")

    def __len__(self):
        """Return the number of samples still to be evaluated."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Load the audio and build the prompt for sample ``index``.

        Returns:
            A dict with the keys ``audio`` (waveform tensor), ``task``
            (prompt text), ``solution`` (answer key), ``scene_id``,
            ``task_type``, ``question`` and ``options``.
        """
        record = self.data_list[index]

        audio_path = record["source_metadata"]["audio_path"]
        if not os.path.isabs(audio_path):
            # Relative paths are resolved against the *parent* of the
            # directory that contains the JSONL file.
            audio_path = os.path.join(os.path.dirname(self.data_dir), audio_path)

        try:
            waveform, original_sr = torchaudio.load(audio_path)
            if original_sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=original_sr, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)
            # Keep at most the first two channels.
            waveform = waveform[:2, :]
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort the
            # whole evaluation, so fall back to five seconds of silence.
            logging.error(f"加载或处理音频文件失败: {audio_path}, 错误: {e}")
            waveform = torch.zeros((1, self.sample_rate * 5))

        question_data = record["question_data"]
        question = question_data["question"]
        options = question_data["options"]
        choices_str = "\n".join(f"{key}: {value}" for key, value in options.items())
        task_prompt = f"{question}. Please choose the answer from the following options: {choices_str}."

        return {
            "audio": waveform,
            "task": task_prompt,
            "solution": question_data["answer"],
            "scene_id": record.get("scene_id", f"item_{index}"),
            "task_type": record.get("task_type", "unknown"),
            "question": question,
            "options": options,
        }


def parse_answer(text: str) -> str:
    """Extract the answer wrapped in ``|<answer>|`` ... ``|</answer>|`` tags.

    Only the first tag pair is considered and its content may span several
    lines. The content is returned with surrounding whitespace removed;
    when no tag pair is present the sentinel ``"N/A"`` is returned.
    """
    tagged = re.search(r"\|\<answer\>\|(.*?)\|\</answer\>\|", text, re.DOTALL)
    return tagged.group(1).strip() if tagged else "N/A"


# Phrases that explicitly announce an answer. They are matched against the
# upper-cased model output, hence the upper-case spelling.
_ANSWER_KEYWORDS = (
    "ANSWER IS",
    "THE ANSWER IS",
    "OPTION IS",
    "CORRECT ANSWER IS",
    "I CHOOSE",
    "MY CHOICE IS",
    "FINAL ANSWER",
    "答案是",
    "我选",
    "选择是",
    "选项是",
)
_ANSWER_KEYWORD_ALTERNATION = "|".join(re.escape(k) for k in _ANSWER_KEYWORDS)

# Tokenizer artefacts stripped before matching (text is already upper-cased).
_SPECIAL_TOKEN_RE = re.compile(r"</?S>|<PAD>|<UNK>")


def parse_answer_fuzzy(text: str, options: Dict[str, str]) -> str:
    """Guess which option a free-form model output refers to.

    Matching is case-insensitive and applies three rules in order; the
    first rule that produces a result wins:

    1. An explicit declaration such as "the answer is (B)" or "答案是B".
    2. The last stand-alone occurrence of an option key in the text.
    3. The option whose content appears verbatim in the text; when several
       do, the one with the longest content wins.

    Known limitation: rule 2 cannot tell the option key "A" from the
    English article "a", because matching is case-insensitive.

    Args:
        text: Raw decoded model output.
        options: Mapping from option key (e.g. ``"A"``) to option content.

    Returns:
        The matching key exactly as spelled in ``options``, or ``"N/A"``
        when nothing matches or ``options`` is empty.
    """
    if not options:
        return "N/A"

    haystack = _SPECIAL_TOKEN_RE.sub("", text.upper()).strip()

    # Upper-cased key -> key as spelled by the caller, so lower-case keys
    # still match the upper-cased text and are returned unchanged.
    key_by_upper = {}
    for key in options:
        upper_key = str(key).upper()
        if upper_key:
            key_by_upper.setdefault(upper_key, key)

    if key_by_upper:
        # Keys are escaped so regex metacharacters in a key stay literal.
        key_alternation = "|".join(
            re.escape(k) for k in sorted(key_by_upper, key=len, reverse=True)
        )

        # Rule 1: explicit declaration. The negative lookahead stops a key
        # from matching the first letter of a longer word ("ANSWER IS BLUE"
        # must not yield "B").
        declared = re.search(
            rf"(?:{_ANSWER_KEYWORD_ALTERNATION})\s*[:：]?\s*"
            rf"\(?({key_alternation})(?![A-Z0-9_])\)?",
            haystack,
        )
        if declared:
            return key_by_upper[declared.group(1)]

        # Rule 2: stand-alone key; the last mention is taken as final.
        standalone = re.findall(rf"\b({key_alternation})\b", haystack)
        if standalone:
            return key_by_upper[standalone[-1]]

    # Rule 3: option content quoted in the output.
    content_hits = {}
    for key, value in options.items():
        needle = str(value).upper().strip()
        if needle and needle in haystack:
            content_hits[key] = needle
    if content_hits:
        # Prefer the longest content: a short option is often a substring
        # of a longer one.
        return max(content_hits, key=lambda k: len(content_hits[k]))

    return "N/A"


def parse_answer_combined(text: str, options: Dict[str, str]) -> str:
    """Parse an answer, trying the strict tag parser before the fuzzy one.

    The content found by ``parse_answer`` is returned as-is whenever the
    answer tags are present. Only when that parser reports ``"N/A"`` is the
    text handed to ``parse_answer_fuzzy`` together with ``options``.
    """
    tagged_answer = parse_answer(text)
    if tagged_answer == "N/A":
        # No answer tags in the output: fall back to heuristic matching.
        return parse_answer_fuzzy(text, options)
    return tagged_answer


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="GRPO 模型评测脚本（支持断点续评）")
    parser.add_argument(
        "--sft_checkpoint_path",
        type=str,
        required=True,
        help="SFT阶段训练好的模型检查点文件路径 (*.bin)",
    )
    # Optional; the default is the adapter directory that used to be
    # hard-coded, so existing invocations behave as before.
    parser.add_argument(
        "--policy_adapter_path",
        type=str,
        default="assets/checkpoints/grpo_checkpoint-1/policy/policy",
        help="GRPO训练产出的LoRA适配器目录",
    )
    parser.add_argument(
        "--test_data_file", type=str, required=True, help="评测用的jsonl数据文件"
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="保存详细评测结果的jsonl文件路径"
    )
    parser.add_argument(
        "--model_mode",
        type=str,
        required=True,
        choices=["policy", "sft"],
        help="评测模式",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="运行设备",
    )
    parser.add_argument("--batch_size", type=int, default=1, help="批处理大小。")
    return parser.parse_args()


def _load_completed_ids(output_file):
    """Return the scene ids already present in an existing results file.

    Lines that are not valid JSON objects with a ``scene_id`` (for example
    a line torn by an interrupted run) are ignored. A missing or unreadable
    file yields whatever ids were collected so far, possibly none.
    """
    completed_ids = set()
    if not os.path.exists(output_file):
        return completed_ids
    try:
        with open(output_file, "r", encoding="utf-8") as f_in:
            for line in f_in:
                try:
                    completed_ids.add(json.loads(line)["scene_id"])
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Skip damaged or foreign lines.
                    continue
        logging.info(
            f"检测到已存在的输出文件，已加载 {len(completed_ids)} 条已完成的记录。"
        )
    except (OSError, UnicodeError) as e:
        logging.warning(f"读取已有结果文件失败: {e}。将从头开始评测。")
    return completed_ids


def _load_model(args):
    """Assemble the model: quantised base, merged SFT weights, policy adapter."""
    logging.info("--- 1. 开始加载和组装模型 ---")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # LoRA slots matching the SFT checkpoint so its adapter weights can be
    # loaded and then merged into the decoder.
    temp_lora_config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    model = TWNM(
        config=TWNMConfig(),
        peft_config=temp_lora_config,
        quantization_config=quantization_config,
    )
    logging.info("模型基础结构初始化完成。")
    logging.info(f"正在从 {args.sft_checkpoint_path} 加载SFT权重...")

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(args.sft_checkpoint_path, map_location="cpu")
    # Strip the "module." prefix that DataParallel/DDP adds to parameter names.
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict, strict=False)

    # Fold the SFT LoRA into the decoder, then attach the GRPO policy adapter.
    model.decoder = model.decoder.merge_and_unload()
    model.decoder = PeftModel.from_pretrained(
        model.decoder,
        args.policy_adapter_path,
        adapter_name="policy",
    )
    logging.info("SFT权重加载完成。")
    logging.info(f"active adapters: {model.decoder.active_adapters}")

    model.to(args.device)
    model.eval()
    logging.info("--- 模型加载和组装全部完成 ---")
    return model


def _ends_mid_line(path):
    """Return True when ``path`` is non-empty and lacks a trailing newline."""
    try:
        with open(path, "rb") as f_tail:
            f_tail.seek(0, os.SEEK_END)
            if f_tail.tell() == 0:
                return False
            f_tail.seek(-1, os.SEEK_END)
            return f_tail.read(1) != b"\n"
    except OSError:
        return False


def _evaluate_batches(args, model, eval_dataloader, f_out):
    """Generate, parse and append one result line per sample to ``f_out``."""
    logging.info("--- 4. 开始评测循环 ---")
    tokenizer = model.tokenizer
    eos_token = tokenizer.eos_token

    for batch in tqdm(eval_dataloader, desc=f"Evaluating in '{args.model_mode}' mode"):
        audios = batch["audio"].to(args.device)
        tasks = batch["task"]
        solutions = batch["solution"]
        # Default collation turns the per-sample option dicts into one dict
        # of per-key lists.
        options_list = batch["options"]

        # NOTE(review): padding=True pads the prompts but no attention mask
        # is forwarded to generate(); confirm TWNM.generate handles padded
        # prompts when batch_size > 1.
        inputs = tokenizer(tasks, return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(args.device)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                audio=audios,
                max_new_tokens=1024,
                do_sample=False,
            )

        # Special tokens are kept on purpose: the EOS marker is needed for
        # truncation below.
        decoded_outputs = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        for i, raw_output in enumerate(decoded_outputs):
            # str.split(None) would split on whitespace, so only truncate
            # when the tokenizer actually defines an EOS token.
            if eos_token:
                raw_output = raw_output.split(eos_token)[0]
            raw_output = raw_output.strip()

            current_options = {
                key: values[i] for key, values in options_list.items()
            }
            parsed = parse_answer_combined(raw_output, current_options)

            result_item = {
                "scene_id": batch["scene_id"][i],
                "task_type": batch["task_type"][i],
                "task": tasks[i],
                "ground_truth": solutions[i],
                "model_mode": args.model_mode,
                "parsed_answer": parsed,
                "is_correct": parsed == solutions[i],
                "raw_output": raw_output,
            }
            f_out.write(json.dumps(result_item, ensure_ascii=False) + "\n")

        # Persist every finished batch so an interrupted run loses at most
        # the batch in flight.
        f_out.flush()


def _run_evaluation(args, model, eval_dataset):
    """Select the adapter mode and evaluate ``eval_dataset`` into the output file."""
    # NOTE(review): default collation stacks the waveforms, so batch_size > 1
    # presumably requires equally shaped audio -- confirm against the data.
    eval_dataloader = DataLoader(
        eval_dataset, batch_size=args.batch_size, shuffle=False
    )
    logging.info(f"数据集加载完成，剩余 {len(eval_dataset)} 条数据待评测。")
    logging.info(f"--- 3. 设置模型模式为: {args.model_mode} ---")

    # A previous run may have died mid-write; terminate the torn line so the
    # first appended record does not get glued onto it.
    repair_tail = _ends_mid_line(args.output_file)

    with open(args.output_file, "a", encoding="utf-8") as f_out:
        if repair_tail:
            f_out.write("\n")

        if args.model_mode == "policy":
            model.decoder.set_adapter("policy")
            logging.info("已激活 'policy' 适配器。")
            _evaluate_batches(args, model, eval_dataloader, f_out)
        else:
            # PeftModel.disable_adapter() is a context manager: adapters are
            # bypassed only inside the `with` block. Calling it as a plain
            # statement has no effect.
            with model.decoder.disable_adapter():
                logging.info("已禁用所有适配器，使用SFT基座模型。")
                _evaluate_batches(args, model, eval_dataloader, f_out)


def _print_report(args):
    """Print overall accuracy computed from every valid line of the output file."""
    logging.info("--- 5. 评测完成，生成最终报告 ---")
    final_correct = 0
    final_total = 0
    if os.path.exists(args.output_file):
        with open(args.output_file, "r", encoding="utf-8") as f_final:
            for line in f_final:
                try:
                    is_correct = json.loads(line)["is_correct"]
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Damaged lines are neither right nor wrong: leave them
                    # out of the total as well.
                    continue
                final_total += 1
                if is_correct:
                    final_correct += 1

    accuracy = (final_correct / final_total) * 100 if final_total > 0 else 0
    print("\n" + "=" * 50)
    print("           评 测 报 告 (最终)")
    print("=" * 50)
    print(f"  模型模式 (Model Mode): {args.model_mode.upper()}")
    print(f"  测试数据文件: {args.test_data_file}")
    print(f"  详细结果已保存至: {args.output_file}")
    print("-" * 50)
    print(f"  总样本数: {final_total}")
    print(f"  正确预测数: {final_correct}")
    print(f"  => 准确率 (Accuracy): {accuracy:.2f}%")
    print("=" * 50)


def main():
    """Run the resumable evaluation end to end.

    Steps: parse arguments, collect scene ids already present in the output
    file, assemble the model, evaluate the remaining samples (appending to
    the output file) and print an accuracy report over the whole file.
    """
    args = _parse_args()

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    completed_ids = _load_completed_ids(args.output_file)
    model = _load_model(args)

    logging.info("--- 2. 加载评测数据集 ---")
    eval_dataset = AudioDataset(args.test_data_file, skip_ids=completed_ids)

    if len(eval_dataset) == 0:
        logging.info("所有样本均已评测完成！")
    else:
        _run_evaluation(args, model, eval_dataset)

    _print_report(args)


# Run the evaluation only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

"""
CUDA_VISIBLE_DEVICES=3 python evaluation_spatialQA.py \
  --sft_checkpoint_path assets/checkpoints/sft2_checkpoint-2502/pytorch_model.bin \
  --test_data_file /data2/wl/RL_train/output/benchmark/benchmark_questions.jsonl \
  --output_file results_grpo_spatial.jsonl \
  --model_mode policy \
  --batch_size 4
"""
