import os
import numpy as np
import random
from torch.utils.data import Dataset, get_worker_info
from datasets import load_from_disk
from .data_collator import data_collator_token
from concurrent.futures import ThreadPoolExecutor, TimeoutError


class AudioDatasetHF(Dataset):
    def __init__(
        self,
        dataset_paths,
        dataset_ratios,
        splits,
        audio_token_dir,
        len_factor=1,
        timeout=5.0,  # 添加超时参数，默认为 5 秒
    ):
        """
        初始化混合数据集。
        :param dataset_paths: 数据集路径列表。
        :param dataset_ratios: 每个数据集的采样比例（权重）。
        :param splits: 数据集的分割类型（如 'train', 'val' 等）。
        :param audio_token_dir: 音频 token 目录。
        :param len_factor: 数据集长度因子。
        :param interleave_seed: 用于 interleave_datasets 的随机种子。
        :param timeout: __getitem__ 方法的最大执行时间（秒）。
        """
        self.data_collator = data_collator_token
        self.audio_token_dir = audio_token_dir
        self.timeout = timeout

        # 计算每个数据集的采样权重
        self.dataset_paths = dataset_paths
        self.dataset_ratios = dataset_ratios
        self.dataset_weights = [ratio / sum(dataset_ratios) for ratio in dataset_ratios]

        # 加载数据集
        self.datasets = []
        for i, dataset_path in enumerate(dataset_paths):
            dataset_path = os.path.join(dataset_path, splits[i])
            dataset = load_from_disk(dataset_path)
            self.datasets.append(dataset)

        # For sampling diversity
        self.arnold_id = int(os.environ.get("ARNOLD_ID", 0)) + 1
        self.dataset_total = sum(len(d) for d in self.datasets)

        print(f"{splits[-1]} Dataset number: {self.dataset_total}")

        # TODO: 因为 epoch卡顿的问题，临时增加数据集长度
        self.dataset_total *= len_factor

    def __len__(self):
        """
        返回混合数据集的长度。
        """
        return self.dataset_total

    def get_sample_distribution(self, num_samples=1000):
        """
        检查采样分布是否符合预期比例。
        :param num_samples: 采样数量。
        :return: 每个数据集的采样次数。
        """
        from collections import defaultdict

        distribution = defaultdict(int)
        for _ in range(num_samples):
            selected_dataset = random.choices(
                self.datasets, weights=self.dataset_weights, k=1
            )[0]
            dataset_idx = self.datasets.index(selected_dataset)
            distribution[dataset_idx] += 1
        return distribution

    def _get_audio_token_path(self, audio_path):
        audio_tokens_path = audio_path.replace(
            "/audio/", f"/{self.audio_token_dir}/"
        ).replace(".wav", ".npy")
        return audio_tokens_path

    def _load_audio_tokens(self, audio_path):
        audio_token_path = self._get_audio_token_path(audio_path)
        if not os.path.exists(audio_token_path):
            return None
        try:
            audio_data = np.load(audio_token_path, allow_pickle=False)
        except:
            return None
        return {
            "audio_tokens": audio_data,
            "audio_lengths": len(audio_data),
        }

    def _getitem_impl(self, idx):
        """
        根据权重随机选择一个数据集，然后从该数据集中随机抽取一个样本。
        :param idx: 索引（未直接使用，因为采样是随机的）。
        """
        # Get worker information for unique seed per worker
        worker_info = get_worker_info()
        if worker_info is not None:
            # Seed random functions differently per worker
            random.seed(worker_info.id * self.arnold_id + idx)

        # 根据权重随机选择一个数据集
        selected_dataset_index = random.choices(
            range(len(self.datasets)), weights=self.dataset_weights, k=1
        )[0]
        selected_dataset = self.datasets[selected_dataset_index]

        # 从选中的数据集中随机抽取一个样本
        random_idx = random.randint(0, len(selected_dataset) - 1)
        sample = selected_dataset[random_idx]

        data = {
            "instruction": "",
            "response": sample["text"],
        }

        # Load audio tokens
        relative_audio_path = sample["audio"]
        audio_path = os.path.join(
            self.dataset_paths[selected_dataset_index], relative_audio_path
        )
        audio_dict = self._load_audio_tokens(audio_path)

        if audio_dict is None:
            print(f"Failed to load audio tokens for {audio_path}. Returning None.")
            return None
        else:
            data.update(audio_dict)
            return data

    def __getitem__(self, idx):
        """
        根据权重随机选择一个数据集，然后从该数据集中随机抽取一个样本。
        :param idx: 索引（未直接使用，因为采样是随机的）。
        """
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self._getitem_impl, idx)
            try:
                return future.result(timeout=self.timeout)
            except TimeoutError:
                print(f"Timeout occurred for index {idx}. Returning None.")
                return None
