import os

from torch.utils.data import Dataset

from .data_collator import data_collator_seed_eval


class AudioDatasetSeedEval(Dataset):
    """Evaluation dataset built from a ``|``-separated metadata file.

    Each non-blank line of the metadata file has either 5 fields::

        utt|prompt_text|prompt_wav|infer_text|infer_wav

    or 4 fields (no reference audio for the inferred text)::

        utt|prompt_text|prompt_wav|infer_text

    Wav paths are joined onto ``dataset_path``. Samples are returned as plain
    dicts; batching is delegated to ``self.data_collator``.
    """

    def __init__(
        self,
        meta_file,
        dataset_path,
    ):
        """Load and index the metadata file.

        Args:
            meta_file: Path to the ``|``-separated metadata file (read as UTF-8).
            dataset_path: Root directory that relative wav paths are joined onto.

        Raises:
            ValueError: If a non-blank line has neither 4 nor 5 fields.
        """
        self.dataset_path = dataset_path

        # utt id -> sample dict, and utt ids in first-seen file order.
        self.data, self.data_names = self._load_meta(meta_file, dataset_path)

        self.data_collator = data_collator_seed_eval

    @staticmethod
    def _load_meta(meta_file, dataset_path):
        """Parse the metadata file into ``(data, data_names)``.

        ``data`` maps utt id -> sample dict; ``data_names`` lists each utt id
        exactly once, in first-seen order. If a utt id is repeated, the last
        line's content wins, but the id is not duplicated in ``data_names``.
        Without that, ``data_names`` would be longer than ``data`` and trailing
        samples would become unreachable.
        """
        data = {}
        data_names = []
        with open(meta_file, encoding="utf-8") as f:
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an extra trailing newline).
                    continue

                fields = line.split("|")
                if len(fields) == 5:
                    utt, prompt_text, prompt_wav, infer_text, infer_wav = fields
                    audio_path = os.path.join(dataset_path, infer_wav)
                elif len(fields) == 4:
                    utt, prompt_text, prompt_wav, infer_text = fields
                    audio_path = None
                else:
                    raise ValueError(
                        f"Invalid line format in {meta_file} at line {lineno}: "
                        f"expected 4 or 5 '|'-separated fields, got {len(fields)}"
                    )

                if utt not in data:
                    data_names.append(utt)
                data[utt] = {
                    "text": infer_text,
                    "prompt_text": prompt_text,
                    "audio": audio_path,
                    "prompt_audio": os.path.join(dataset_path, prompt_wav),
                }
        return data, data_names

    def __len__(self):
        """Return the number of samples (unique utt ids) in the dataset."""
        # Index space of __getitem__ is data_names, so size by that list.
        return len(self.data_names)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict.

        Keys: ``utt``, ``text``, ``prompt_text``, ``audio_path`` (None for
        4-field lines), and ``prompt_audio_path``.
        """
        utt = self.data_names[idx]
        sample = self.data[utt]
        return {
            "utt": utt,
            "text": sample["text"],
            "prompt_text": sample["prompt_text"],
            "audio_path": sample["audio"],
            "prompt_audio_path": sample["prompt_audio"],
        }
