import json
import logging
import os

import torch
import torchaudio
from torch.utils.data import Dataset

class AudioDataset(Dataset):
    """
    一个用于处理 GRPO 训练所需音频问答数据的 PyTorch 数据集类。

    该类负责：
    1. 读取包含音频路径和问答对的 jsonl 文件。
    2. 加载并对音频波形进行重采样。
    3. 根据指定的格式将问题和选项拼接成一个 prompt (task)。
    4. 提取正确的答案选项作为 solution。
    """
    def __init__(self, data_file, sample_rate=44100):
        """
        初始化数据集。

        Args:
            data_file (str): jsonl 数据文件的路径。
            sample_rate (int, optional): 目标音频采样率。默认为 16000。
        """
        super().__init__()
        self.data_list = []
        # 获取数据文件所在的目录，用于拼接相对音频路径
        self.data_dir = os.path.dirname(data_file)
        
        with open(data_file, 'r', encoding='utf8') as fin:
            for line in fin:
                self.data_list.append(json.loads(line))

        self.sample_rate = sample_rate
        logging.info(f"成功加载 {len(self.data_list)} 条数据来自 {data_file}")

    def __len__(self):
        """返回数据集中的样本总数。"""
        return len(self.data_list)

    def __getitem__(self, index):
        """
        获取并处理数据集中的一个样本。

        Args:
            index (int): 样本的索引。

        Returns:
            dict: 一个包含 'audio', 'task', 和 'solution' 的字典。
        """
        # 1. 获取原始数据项
        json_obj = self.data_list[index]

        # 2. 加载和处理音频
        # 注意：这里的 audio_path 可能是相对路径，需要与 data_dir 拼接
        audio_path = json_obj["source_metadata"]["audio_path"]
        if not os.path.isabs(audio_path):
             audio_path = os.path.join(os.path.dirname(self.data_dir), audio_path)
        
        try:
            waveform, original_sr = torchaudio.load(audio_path)
            if original_sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=self.sample_rate)
                waveform = resampler(waveform)
            # 我们通常使用单声道音频进行处理
            audio_tensor = waveform
        except Exception as e:
            logging.error(f"加载或处理音频文件失败: {audio_path}, 错误: {e}")
            # 返回一个静音张量或跳过此样本（这里选择返回静音）
            audio_tensor = torch.zeros(self.sample_rate * 10) # 假设10秒的静音

        # 3. 解析问题、选项和答案
        question_data = json_obj["question_data"]
        question = question_data["question"]
        options = question_data["options"]
        answer_key = question_data["answer"]

        # 4. 按指定格式构建 task (prompt)
        # 将选项字典转换为 "A: ...\nB: ...\n" 格式的字符串
        choices_str = "\n".join([f"{key}: {value}" for key, value in options.items()])
        
        # 拼接最终的 prompt
        task_prompt = f"{question}. Please choose the answer from the following options: {choices_str}. "

        # 5. 提取 solution
        # solution 就是正确的答案选项，例如 "B"
        solution = answer_key

        # 6. 返回处理好的数据
        return {
            "audio": audio_tensor,
            "task": task_prompt,
            "solution": solution,
        }

if __name__ == '__main__':
    # === 用于测试数据集类的示例代码 ===
    logging.basicConfig(level=logging.INFO)

    dataset = AudioDataset("/data2/wl/RL_benchmark/output/benchmark/benchmark_questions.jsonl")
    
    # 3. 获取第一个样本并打印
    if len(dataset) > 0:
        sample = dataset[0]
        print("\n成功获取第一个样本:")
        print(f"  - Audio Tensor Shape: {sample['audio'].shape}")
        print(f"  - Task (Prompt): \n---\n{sample['task']}\n---")
        print(f"  - Solution: {sample['solution']}")
    else:
        print("数据集为空，无法获取样本。")
