# 文件名: evaluate_qwen2audio_on_mmau.py

import argparse
import json
import logging
import os
import re
from typing import List, Dict, Set

import torch
import librosa
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig
import numpy as np

# =====================================================================================
# 1. MMAU 数据集类 (参考 evaluation_mmau.py)
# =====================================================================================
class MMAUDataset(Dataset):
    """Dataset over an MMAU JSON file, producing Qwen2-Audio prompts.

    - Reads a JSON list of records; each record is expected to carry "id",
      "audio_id", "question", "choices" and "answer" keys.
    - Loads each clip as a 1-D mono waveform resampled to the processor's
      feature-extractor sampling rate, then zero-pads or truncates it to
      ``target_duration`` seconds (10 s by default).
    - Builds a multiple-choice prompt with lettered options (A, B, C, ...).

    Args:
        json_path: Path to the MMAU JSON file.
        audio_path: Directory holding the audio files; only the basename of each
            record's "audio_id" is joined onto it.
        processor: Object exposing ``feature_extractor.sampling_rate``.
        skip_ids: Record ids to leave out (used to resume an interrupted run).
        target_duration: Length, in seconds, every clip is padded/truncated to.
    """

    def __init__(self, json_path: str, audio_path: str, processor, skip_ids: Set[str] = None,
                 target_duration: float = 10):
        super().__init__()
        self.audio_base_path = audio_path
        self.processor = processor
        self.sample_rate = processor.feature_extractor.sampling_rate

        # Every clip is forced to this fixed length (zero-pad or truncate).
        self.target_duration = target_duration  # seconds
        self.target_length = int(round(target_duration * self.sample_rate))  # samples

        logging.info(f"正在从 {json_path} 加载 MMAU 数据...")
        with open(json_path, 'r', encoding='utf8') as fin:
            all_data = json.load(fin)

        if skip_ids:
            self.data_list = [item for item in all_data if item.get("id") not in skip_ids]
            skipped_count = len(all_data) - len(self.data_list)
            logging.info(f"检测到断点。已跳过 {skipped_count} 条已完成的记录。")
        else:
            self.data_list = all_data

        logging.info(f"成功加载 {len(self.data_list)} 条待评测数据。")

    def __len__(self):
        return len(self.data_list)

    def _pad_or_truncate(self, audio: np.ndarray) -> np.ndarray:
        """Zero-pad (at the end) or truncate the last axis to ``self.target_length`` samples."""
        current_length = audio.shape[-1]
        if current_length > self.target_length:
            return audio[..., :self.target_length]
        if current_length < self.target_length:
            pad_width = [(0, 0)] * (audio.ndim - 1) + [(0, self.target_length - current_length)]
            return np.pad(audio, pad_width, mode='constant')
        return audio

    @staticmethod
    def _find_solution_letter(options_map: Dict[str, str], ground_truth_text) -> str:
        """Return the option letter whose choice equals the answer text, or "N/A".

        An exact match is tried first (the last duplicate choice wins, as with a
        reversed dict). If there is no exact match, both sides are compared with
        surrounding whitespace stripped, so an answer of "dog " still maps to the
        choice "dog".
        """
        reverse_options_map = {v: k for k, v in options_map.items()}
        if ground_truth_text in reverse_options_map:
            return reverse_options_map[ground_truth_text]
        target = str(ground_truth_text).strip()
        for letter, choice in options_map.items():
            if str(choice).strip() == target:
                return letter
        return "N/A"

    def __getitem__(self, index):
        json_obj = self.data_list[index]

        # 1. Load the audio. The processor's (Whisper-style) feature extractor
        #    takes one 1-D waveform per sample, so downmix to mono here rather
        #    than returning a (channels, T) array.
        audio_filename = os.path.basename(json_obj["audio_id"])
        audio_full_path = os.path.join(self.audio_base_path, audio_filename)
        try:
            audio, _ = librosa.load(audio_full_path, sr=self.sample_rate, mono=True)
            audio = self._pad_or_truncate(audio)
        except Exception as e:
            # Best effort per sample: a broken/missing file yields audio=None and
            # the collate step decides what to do with it.
            logging.error(f"加载或处理音频文件失败: {audio_full_path}, 错误: {e}")
            audio = None

        # 2. Build the question, lettered options and the ground-truth letter.
        question = json_obj["question"]
        choices = json_obj["choices"]
        ground_truth_text = json_obj["answer"]

        options_map = {chr(ord('A') + i): choice for i, choice in enumerate(choices)}
        choices_str = "\n".join(f"{key}: {value}" for key, value in options_map.items())

        # Qwen2-Audio prompt: audio placeholder tokens followed by the task text.
        task_prompt = (
            f"<|audio_bos|><|AUDIO|><|audio_eos|>"
            f"Based on the audio, answer the following question.\n"
            f"Question: {question}\n"
            f"Options:\n{choices_str}\n"
            f"The correct option is:"
        )

        solution_letter = self._find_solution_letter(options_map, ground_truth_text)

        return {
            "audio": audio,
            "prompt": task_prompt,
            "solution": solution_letter,
            "options": options_map,
            "scene_id": json_obj.get("id"),
        }

# =====================================================================================
# 2. 答案解析函数 (鲁棒的模糊匹配)
# =====================================================================================
def parse_answer_fuzzy(text: str, options: Dict[str, str]) -> str:
    """Extract the option letter that a model response selects.

    Rules are tried in order; the first one that fires wins:
      1. An explicit declaration such as "The answer is B" or "答案是C".
         Keywords match case-insensitively. The letter must be an option key
         written exactly as in ``options`` (uppercase for the A/B/C keys) and
         must not be followed by another Latin letter, so "The answer is
         Applause" is not read as "A".
      2. A standalone option letter (e.g. "B", "(C)", "D."); the last occurrence
         wins. This is case-sensitive, so the English article "a" is not
         mistaken for option A.
      3. The full text of an option appears in the response (case-insensitive
         substring). If several do, the longest option text wins, so a short
         choice such as "dog" does not shadow "dog barking".

    Args:
        text: Raw decoded model output.
        options: Mapping of option letter (e.g. "A") to option text.

    Returns:
        The selected option letter, or "N/A" if no rule matched or if ``text``
        or ``options`` is empty.
    """
    # An empty alternation would match the empty string and return "".
    if not text or not options:
        return "N/A"

    raw_text = text.strip()
    upper_text = raw_text.upper()
    option_keys_str = "|".join(re.escape(key) for key in options)

    # Rule 1: explicit declaration ("The answer is A", "答案是A", ...).
    keywords = ["ANSWER IS", "THE ANSWER IS", "OPTION IS", "CORRECT OPTION IS", "CORRECT ANSWER IS",
                "I CHOOSE", "FINAL ANSWER IS", "答案是", "选项是"]
    keywords_str = "|".join(keywords)
    # (?i:...) scopes case-insensitivity to the keywords only; the letter stays
    # case-sensitive and must not be the first letter of a longer word.
    pattern1 = re.compile(rf"(?i:{keywords_str})\s*:?\s*\(?({option_keys_str})(?![A-Za-z])")
    match1 = pattern1.search(raw_text)
    if match1:
        return match1.group(1)

    # Rule 2: a standalone option letter; the last one is usually the final answer.
    matches2 = re.findall(rf"\b({option_keys_str})\b", raw_text)
    if matches2:
        return matches2[-1]

    # Rule 3: the option's full text appears in the response.
    found_matches = {}
    for key, value in options.items():
        processed_option_text = str(value).upper().strip()
        if processed_option_text and processed_option_text in upper_text:
            found_matches[key] = processed_option_text
    if found_matches:
        # max() keeps the first key on ties, matching dict insertion order.
        return max(found_matches, key=lambda k: len(found_matches[k]))

    return "N/A"

# =====================================================================================
# 3. 主函数
# =====================================================================================
def _load_completed_ids(output_file: str) -> Set[str]:
    """Return the ``scene_id`` of every well-formed record already in *output_file*.

    This lets an interrupted run resume. Lines that are not JSON, are not
    objects, or lack ``scene_id`` are ignored. A missing file gives an empty
    set. If reading fails midway, the ids collected so far are kept.
    """
    completed_ids = set()
    if not os.path.exists(output_file):
        return completed_ids
    try:
        with open(output_file, 'r', encoding='utf-8') as f_in:
            for line in f_in:
                try:
                    completed_ids.add(json.loads(line)['scene_id'])
                except (json.JSONDecodeError, KeyError, TypeError):
                    pass  # blank/truncated/corrupt line, e.g. from a crash mid-write
        logging.info(f"检测到已存在的输出文件，已加载 {len(completed_ids)} 条已完成的记录。")
    except Exception as e:
        logging.warning(f"读取已有结果文件失败: {e}。将从头开始评测。")
    return completed_ids


def _collate_fn(batch: List[Dict]) -> Dict[str, list]:
    """Collate dataset items into index-aligned lists, dropping items whose audio is None.

    Each prompt carries exactly one ``<|AUDIO|>`` placeholder, so prompts and
    audios must stay aligned one-to-one. Dropping only the audio while keeping
    its prompt would pair later prompts with the wrong clip. The ids of dropped
    items are returned under "dropped_ids" for logging. They are never written
    to the output file, so a resumed run retries them.
    """
    kept = [item for item in batch if item['audio'] is not None]
    return {
        "prompts": [item['prompt'] for item in kept],
        "audios": [item['audio'] for item in kept],
        "solutions": [item['solution'] for item in kept],
        "options": [item['options'] for item in kept],
        "scene_ids": [item['scene_id'] for item in kept],
        "dropped_ids": [item['scene_id'] for item in batch if item['audio'] is None],
    }


def _summarize_results(output_file: str):
    """Return ``(correct, total)`` over the well-formed records in *output_file*.

    Only lines that parse as JSON objects with an "is_correct" field are
    counted. A missing file gives ``(0, 0)``.
    """
    final_correct = 0
    final_total = 0
    if not os.path.exists(output_file):
        return final_correct, final_total
    with open(output_file, 'r', encoding='utf-8') as f_final:
        for line in f_final:
            try:
                is_correct = bool(json.loads(line)["is_correct"])
            except (json.JSONDecodeError, KeyError, TypeError):
                continue  # do not let blank/corrupt lines inflate the denominator
            final_total += 1
            final_correct += is_correct
    return final_correct, final_total


def main():
    """Evaluate a Qwen2-Audio checkpoint on MMAU and print an accuracy report.

    One JSON line per scored sample is appended to ``--output_file``. Samples
    whose id is already in that file are skipped, so rerunning the same command
    resumes an interrupted run. The final accuracy covers every well-formed
    record in the file, including those from earlier runs.
    """
    parser = argparse.ArgumentParser(description="使用 Qwen2-Audio 在 MMAU 数据集上进行评测的脚本")
    parser.add_argument("--model_id", type=str, default="Qwen/Qwen2-Audio-7B", help="要评测的 Hugging Face 模型ID")
    parser.add_argument("--mmau_json_path", type=str, required=True, help="评测用的 MMAU json 数据文件")
    parser.add_argument("--mmau_audio_path", type=str, required=True, help="MMAU 音频文件所在的根目录")
    parser.add_argument("--output_file", type=str, required=True, help="保存详细评测结果的jsonl文件路径")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
    parser.add_argument("--batch_size", type=int, default=4, help="批处理大小")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    # --- 1. Resume: collect ids already scored in the output file ---
    completed_ids = _load_completed_ids(args.output_file)

    # --- 2. Load model (4-bit quantized) and processor ---
    logging.info(f"--- 正在加载模型: {args.model_id} ---")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        quantization_config=quantization_config
    )
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    # Batched generation with a decoder-only LM needs left padding. Every row
    # then ends at input_ids.size(1), so slicing the generated ids at that
    # offset (below) removes exactly the prompt from every row.
    processor.tokenizer.padding_side = "left"
    model.eval()
    logging.info("--- 模型和处理器加载完成 ---")

    # --- 3. Dataset and loader ---
    eval_dataset = MMAUDataset(
        json_path=args.mmau_json_path,
        audio_path=args.mmau_audio_path,
        processor=processor,
        skip_ids=completed_ids
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=_collate_fn
    )

    # --- 4. Evaluation loop ---
    logging.info("--- 开始评测循环 ---")
    with open(args.output_file, 'a', encoding='utf-8') as f_out:
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            if not batch["audios"]:
                logging.warning(f"跳过一个批次，因为所有音频都无法加载。IDs: {batch['dropped_ids']}")
                continue
            if batch["dropped_ids"]:
                logging.warning(f"以下样本的音频无法加载，已从本批次中移除: {batch['dropped_ids']}")

            inputs = processor(
                text=batch["prompts"],
                audios=batch["audios"],
                return_tensors="pt",
                padding=True
            ).to(args.device)

            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_new_tokens=256)

            # Keep only the newly generated tokens (prompts are left-padded).
            prompt_input_ids_length = inputs.input_ids.size(1)
            response_ids = generated_ids[:, prompt_input_ids_length:]

            decoded_responses = processor.batch_decode(
                response_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

            rows = zip(batch["scene_ids"], batch["prompts"], batch["solutions"],
                       batch["options"], decoded_responses)
            for scene_id, prompt, solution, options, response in rows:
                parsed_answer = parse_answer_fuzzy(response, options)
                # "N/A" is both the parser's failure value and the dataset's marker
                # for an answer missing from the choices; they must never "match".
                is_correct = solution != "N/A" and parsed_answer == solution

                result_item = {
                    "scene_id": scene_id,
                    "prompt": prompt,
                    "ground_truth": solution,
                    "parsed_answer": parsed_answer,
                    "is_correct": is_correct,
                    "raw_output": response,
                }
                f_out.write(json.dumps(result_item, ensure_ascii=False) + '\n')
            # Flush per batch so finished records reach the file (and the resume
            # logic) even if the process dies during a later batch.
            f_out.flush()

    # --- 5. Final report ---
    logging.info("--- 评测完成，生成最终报告 ---")
    final_correct, final_total = _summarize_results(args.output_file)

    accuracy = (final_correct / final_total) * 100 if final_total > 0 else 0
    print("\n" + "=" * 60)
    print("           Qwen2-Audio on MMAU 评 测 报 告")
    print("=" * 60)
    print(f"  模型 (Model): {args.model_id}")
    print(f"  测试数据 (Test Data): {args.mmau_json_path}")
    print(f"  详细结果已保存至 (Output File): {args.output_file}")
    print("-" * 60)
    print(f"  总样本数 (Total Samples): {final_total}")
    print(f"  正确预测数 (Correct Predictions): {final_correct}")
    print(f"  => 准确率 (Accuracy): {accuracy:.2f}%")
    print("=" * 60)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()

# Example invocation, kept as an inert module-level string (never executed).
# NOTE(review): the script name used below differs from the filename given in the
# header comment at the top of this file; use whichever name the file actually has.
"""
CUDA_VISIBLE_DEVICES=3 python evaluation_mmau_qwen2audio.py \
  --mmau_json_path datasets/mmau/mmau-test-mini-short.json \
  --mmau_audio_path /data2/wl/test-mini-audios-short \
  --output_file mmau_results_base_10s.jsonl \
  --batch_size 8
"""