# 文件名: evaluation_mmau.py (已修改模型加载方式以匹配 spatialQA)

import argparse
import json
import logging
import os
import re
from typing import Set

import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

# Import the project's TWNM model definition. When the package cannot be
# imported, print a hint for the user and terminate with exit status 1.
try:
    from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig
except ImportError as e:
    print(f"错误: 无法导入 TWNM 模型: {e}")
    print("请确保脚本位于可以访问 'models' 包的项目根目录中，或已将项目路径添加到 PYTHONPATH。")
    # `exit()` is a helper injected by the `site` module for interactive use
    # and is not guaranteed to exist (e.g. `python -S`); raising SystemExit
    # gives the same exit status without that dependency.
    raise SystemExit(1)

# =====================================================================================
# 1. MMAU 数据集类 (保持不变)
# =====================================================================================
class MMAUDataset(Dataset):
    """Multiple-choice audio QA samples read from an MMAU-style JSON file.

    Each record of the JSON list is expected to provide the keys
    ``audio_id``, ``question``, ``choices`` and ``answer``; ``id`` and
    ``task`` are optional. Every item yields a two-channel waveform of a
    fixed length together with the formatted prompt and bookkeeping fields.

    Args:
        json_path: Path of the JSON file holding the list of records.
        audio_path: Directory containing the audio files. Only the basename
            of each record's ``audio_id`` is looked up inside it.
        sample_rate: Target sample rate; audio with another rate is resampled.
        skip_ids: Record ids to leave out (used to resume a previous run).
            ``None`` or an empty set keeps every record.
        clip_seconds: Fixed clip duration in seconds. Longer audio is
            truncated and shorter audio is zero-padded at the end.
    """

    def __init__(self, json_path: str, audio_path: str, sample_rate: int = 44100,
                 skip_ids: Set[str] = None, clip_seconds: float = 10.0):
        super().__init__()
        self.audio_base_path = audio_path
        self.sample_rate = sample_rate
        self.data_list = []
        # Every waveform is cut or padded to exactly this many samples
        # (10 seconds with the default arguments).
        self.target_length_samples = int(clip_seconds * self.sample_rate)

        logging.info(f"正在从 {json_path} 加载 MMAU 数据...")
        with open(json_path, 'r', encoding='utf8') as fin:
            all_data = json.load(fin)

        if skip_ids:
            original_count = len(all_data)
            self.data_list = [item for item in all_data if item.get("id") not in skip_ids]
            skipped_count = original_count - len(self.data_list)
            logging.info(f"检测到断点。已跳过 {skipped_count} 条已完成的记录。")
        else:
            self.data_list = all_data

        logging.info(f"成功加载 {len(self.data_list)} 条待评测数据。")

    def __len__(self):
        """Return the number of records left to evaluate."""
        return len(self.data_list)

    def _load_waveform(self, audio_full_path: str):
        """Load one file as a ``(2, target_length_samples)`` waveform.

        Returns:
            A ``(waveform, was_truncated)`` tuple. When loading or processing
            fails for any reason the error is logged and an all-zero
            (silent) waveform is returned with ``was_truncated`` False, so a
            single bad file does not abort the whole evaluation.
        """
        try:
            waveform, original_sr = torchaudio.load(audio_full_path)
            if original_sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=self.sample_rate)
                waveform = resampler(waveform)

            # Force exactly two channels: duplicate mono, drop extras.
            if waveform.shape[0] == 1:
                waveform = waveform.repeat(2, 1)
            elif waveform.shape[0] > 2:
                waveform = waveform[:2, :]

            # Normalise the length and remember whether audio was cut off.
            was_truncated = False
            current_length_samples = waveform.shape[1]
            if current_length_samples > self.target_length_samples:
                waveform = waveform[:, :self.target_length_samples]
                was_truncated = True
            elif current_length_samples < self.target_length_samples:
                padding_needed = self.target_length_samples - current_length_samples
                waveform = F.pad(waveform, (0, padding_needed))
            return waveform, was_truncated
        except Exception as e:
            logging.error(f"加载或处理音频文件失败: {audio_full_path}, 错误: {e}")
            # Fall back to silence of the standard clip length.
            return torch.zeros((2, self.target_length_samples)), False

    def __getitem__(self, index):
        """Build the sample at ``index``.

        Returns:
            A dict with keys ``audio`` (waveform), ``task`` (prompt text),
            ``solution`` (option letter of the answer, or ``"N/A"`` when the
            answer text matches none of the choices), ``scene_id``,
            ``task_type``, ``ground_truth_text`` and ``was_truncated``.
        """
        json_obj = self.data_list[index]

        # 1. Audio: only the basename of `audio_id` is used for the lookup.
        audio_filename = os.path.basename(json_obj["audio_id"])
        audio_full_path = os.path.join(self.audio_base_path, audio_filename)
        waveform, was_truncated = self._load_waveform(audio_full_path)

        # 2. Question, lettered options and ground-truth letter.
        question = json_obj["question"]
        choices = json_obj["choices"]
        ground_truth_text = json_obj["answer"]

        options_map = {chr(ord('A') + i): choice for i, choice in enumerate(choices)}
        choices_str = "\n".join([f"{key}: {value}" for key, value in options_map.items()])
        task_prompt = f"{question}. Please choose the answer from the following options: {choices_str}"

        reverse_options_map = {v: k for k, v in options_map.items()}
        solution_letter = reverse_options_map.get(ground_truth_text, "N/A")
        if solution_letter == "N/A":
            # Use .get so a record without an "id" still only produces a
            # warning here, consistent with the "scene_id" lookup below.
            logging.warning(f"在样本 {json_obj.get('id')} 中未找到答案 '{ground_truth_text}' 对应的选项。")

        return {
            "audio": waveform,
            "task": task_prompt,
            "solution": solution_letter,
            "scene_id": json_obj.get("id"),
            "task_type": json_obj.get("task", "unknown"),
            "ground_truth_text": ground_truth_text,
            "was_truncated": was_truncated
        }

# (辅助函数 parse_answer 保持不变)
def parse_answer(text: str) -> str:
    """Extract the answer wrapped in ``|<answer>|`` ... ``|</answer>|`` tags.

    Only the first tagged span is considered and it may span several lines.
    The captured text is returned with surrounding whitespace stripped; an
    empty string is returned when no tagged span is present.
    """
    found = re.search(r'\|\<answer\>\|(.*?)\|\</answer\>\|', text, re.DOTALL)
    return found.group(1).strip() if found else ""

def _load_completed_ids(output_file: str, model_mode: str) -> Set[str]:
    """Collect the scene ids already written to ``output_file`` for ``model_mode``."""
    completed_ids = set()
    if not os.path.exists(output_file):
        return completed_ids
    try:
        with open(output_file, 'r', encoding='utf-8') as f_in:
            for line in f_in:
                try:
                    record = json.loads(line)
                    # Only results produced in the same mode count as done;
                    # the final report applies the same filter, so a file
                    # shared between modes must not make this run skip work.
                    if record.get('model_mode') == model_mode:
                        completed_ids.add(record['scene_id'])
                except (json.JSONDecodeError, KeyError, AttributeError, TypeError):
                    # Malformed or partially written line: ignore it.
                    continue
        logging.info(f"检测到已存在的输出文件，已加载 {len(completed_ids)} 条已完成的记录。")
    except (OSError, UnicodeError) as e:
        logging.warning(f"读取已有结果文件失败: {e}。将从头开始评测。")
    return completed_ids


def _tally_results(output_file: str, model_mode: str) -> tuple:
    """Return ``(correct, total)`` over the records of ``model_mode`` in ``output_file``."""
    correct = 0
    total = 0
    if not os.path.exists(output_file):
        return correct, total
    with open(output_file, 'r', encoding='utf-8') as f_final:
        for line in f_final:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Skip valid JSON that is not an object (e.g. a stray number).
            if not isinstance(record, dict):
                continue
            if record.get('model_mode') == model_mode:
                total += 1
                if record.get('is_correct'):
                    correct += 1
    return correct, total


def main():
    """Evaluate a TWNM checkpoint on MMAU with resume support.

    Steps: parse CLI arguments, load ids already evaluated in the selected
    mode, build the model (SFT weights merged into the decoder, then the
    'policy' adapter attached), run greedy generation over the remaining
    samples while appending one JSON line per sample to ``--output_file``,
    and finally print an accuracy report for the selected mode.
    """
    # Function-scope import: only needed for the adapter on/off switch below.
    from contextlib import nullcontext

    parser = argparse.ArgumentParser(description="TWNM 模型在 MMAU 上的评测脚本 (支持断点续评)")
    parser.add_argument("--sft_checkpoint_path", type=str, required=True, help="SFT阶段训练好的模型检查点文件路径 (*.bin)")
    parser.add_argument("--mmau_json_path", type=str, required=True, help="评测用的 MMAU json 数据文件")
    parser.add_argument("--mmau_audio_path", type=str, required=True, help="MMAU 音频文件所在的根目录")
    parser.add_argument("--output_file", type=str, required=True, help="保存详细评测结果的jsonl文件路径")
    parser.add_argument("--model_mode", type=str, required=True, choices=['policy', 'sft'], help="评测模式: 'policy' 使用GRPO适配器, 'sft' 使用SFT基座模型")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
    parser.add_argument("--batch_size", type=int, default=4, help="批处理大小。")
    # Optional override; the default is the path that used to be hard-coded.
    parser.add_argument("--policy_adapter_path", type=str,
                        default="assets/checkpoints/grpo_checkpoint-1/policy/policy",
                        help="GRPO 训练得到的 'policy' 适配器目录")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    # Resume support: ids already evaluated in this mode are skipped.
    completed_ids = _load_completed_ids(args.output_file, args.model_mode)

    # --- Model assembly ---
    logging.info("--- 1. 开始加载和组装模型 (采用 spatialQA 方式) ---")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # Passed to the TWNM constructor before the SFT checkpoint is loaded.
    temp_lora_config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    twnm_config = TWNMConfig()
    model = TWNM(
        config=twnm_config,
        peft_config=temp_lora_config,
        quantization_config=quantization_config,
    )
    logging.info("模型基础结构初始化完成。")
    logging.info(f"正在从 {args.sft_checkpoint_path} 加载SFT权重...")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(args.sft_checkpoint_path, map_location="cpu")
    # Strip a leading "module." prefix from parameter names when present.
    new_state_dict = {
        k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict, strict=False)

    # Fold the SFT LoRA weights into the decoder so they become the base.
    model.decoder = model.decoder.merge_and_unload()
    logging.info("SFT基座已固化 (merge_and_unload)。")

    # Attach the trained policy adapter on top of the merged decoder.
    policy_adapter_path = args.policy_adapter_path
    model.decoder = PeftModel.from_pretrained(
        model.decoder,
        policy_adapter_path,
        adapter_name="policy",
    )
    logging.info(f"已从 {policy_adapter_path} 加载训练好的 'policy' 适配器权重。")

    print("当前激活的适配器:", model.decoder.active_adapters)

    model.to(args.device)
    model.eval()
    tokenizer = model.tokenizer
    logging.info("--- 模型加载和组装全部完成 ---")

    # --- Data ---
    logging.info("--- 2. 加载 MMAU 评测数据集 ---")
    eval_dataset = MMAUDataset(json_path=args.mmau_json_path, audio_path=args.mmau_audio_path, skip_ids=completed_ids)

    if len(eval_dataset) == 0:
        logging.info("所有样本均已评测完成！")
    else:
        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False)
        logging.info(f"数据集加载完成，剩余 {len(eval_dataset)} 条数据待评测。")

        logging.info(f"--- 3. 设置模型模式为: {args.model_mode.upper()} ---")
        if args.model_mode == 'policy':
            model.decoder.set_adapter("policy")
            adapter_scope = nullcontext
            logging.info("已激活 'policy' 适配器。")
        else:
            # PeftModel.disable_adapter is a context manager: a bare call has
            # no effect, so generation below runs inside `with` to actually
            # bypass the policy adapter and evaluate the merged SFT base.
            adapter_scope = model.decoder.disable_adapter
            logging.info("已禁用所有适配器，使用SFT基座模型。")

        logging.info("--- 4. 开始评测循环 ---")
        with open(args.output_file, 'a', encoding='utf-8') as f_out:
            for batch in tqdm(eval_dataloader, desc=f"Evaluating in '{args.model_mode}' mode"):
                audios = batch['audio'].to(args.device)
                tasks = batch['task']
                input_ids = tokenizer(tasks, return_tensors="pt", padding=True, truncation=True).input_ids.to(args.device)

                with torch.no_grad(), adapter_scope():
                    generated_ids = model.generate(
                        input_ids=input_ids,
                        audio=audios,
                        max_new_tokens=768,
                        do_sample=False)

                # Special tokens are kept so the answer tags remain available
                # to parse_answer. (`add_special_tokens` is an encode-time
                # option; `skip_special_tokens` is the decode-time one.)
                decoded_outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

                for i, raw_output in enumerate(decoded_outputs):
                    parsed_answer = parse_answer(raw_output)
                    ground_truth_option = batch["solution"][i]
                    is_correct = (parsed_answer == ground_truth_option)

                    result_item = {
                        "scene_id": batch["scene_id"][i],
                        "task_type": batch["task_type"][i],
                        "task": batch["task"][i],
                        "ground_truth_option": ground_truth_option,
                        "ground_truth_text": batch["ground_truth_text"][i],
                        "model_mode": args.model_mode,
                        "parsed_answer": parsed_answer,
                        "was_truncated": bool(batch["was_truncated"][i].item()),
                        "is_correct": is_correct,
                        "raw_output": raw_output
                    }
                    f_out.write(json.dumps(result_item, ensure_ascii=False) + '\n')
                # Persist each finished batch so an interrupted run can
                # resume from the results already on disk.
                f_out.flush()

    # --- Final report (current mode only) ---
    logging.info("--- 5. 评测完成，生成最终报告 ---")
    final_correct, final_total = _tally_results(args.output_file, args.model_mode)

    accuracy = (final_correct / final_total) * 100 if final_total > 0 else 0
    print("\n" + "=" * 50)
    print("           MMAU 评 测 报 告")
    print("=" * 50)
    print(f"  模型模式 (Model Mode): {args.model_mode.upper()}")
    print(f"  测试数据文件: {args.mmau_json_path}")
    print(f"  详细结果已保存至: {args.output_file}")
    print("-" * 50)
    print(f"  总评测样本数 (当前模式): {final_total}")
    print(f"  正确预测数 (当前模式): {final_correct}")
    print(f"  => 准确率 (Accuracy): {accuracy:.2f}%")
    print("=" * 50)

# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger model loading.
if __name__ == "__main__":
    main()

"""
CUDA_VISIBLE_DEVICES=5 python evaluation_mmau_grpo.py \
  --sft_checkpoint_path assets/checkpoints/sft2_checkpoint-2502/pytorch_model.bin \
  --mmau_json_path datasets/mmau/mmau-test-mini-short.json \
  --mmau_audio_path /data2/wl/test-mini-audios-short \
  --output_file results_mmau_grpo_10s.jsonl \
  --model_mode policy \
  --batch_size 6
"""