import os
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class EMGPretrainDataset(Dataset):
    """Unlabeled pre-training dataset built from every ``.h5`` file in one folder.

    Each ``.h5`` file is expected to contain a ``"data"`` dataset of shape
    ``(N, C, T)``. All files are read eagerly (in sorted path order) and
    concatenated along the first axis into one in-memory array
    ``(N_total, C, T)``. Because of the concatenation, every file must share
    the same trailing ``(C, T)`` shape.
    """

    def __init__(
        self,
        data_path: str,         # path to a folder (not a single file)
        transform=True,
        squeeze=False,
        zero_pad_chans=None,
        zero_pad_toks=None
    ):
        """
        Args:
            data_path (str): Folder containing one or more ``.h5`` files.
                Only the top level is scanned (no recursion).
            transform (bool, optional): If truthy, each sample is min-max
                normalized to ``[-1, 1]`` in ``__getitem__``. This is a flag,
                not a callable: any truthy value enables the same min-max step.
            squeeze (bool, optional): If True, *insert* a leading dimension of
                size 1, i.e. ``(C, T) -> (1, C, T)``. Despite its name, this
                option unsqueezes.
            zero_pad_chans (int, optional): Number of zero channels appended
                to the end of the C axis.
            zero_pad_toks (int, optional): Number of zero time steps appended
                to the end of the T axis.

        Raises:
            FileNotFoundError: ``data_path`` does not exist.
            NotADirectoryError: ``data_path`` exists but is not a directory.
            ValueError: ``data_path`` contains no ``.h5`` files.
        """
        super().__init__()
        self.transform = transform
        self.squeeze = squeeze
        self.zero_pad_chans = zero_pad_chans
        self.zero_pad_toks = zero_pad_toks

        if not os.path.exists(data_path):
            raise FileNotFoundError(f"{data_path} does not exist.")
        if not os.path.isdir(data_path):
            raise NotADirectoryError(f"{data_path} is not a directory.")

        # Sorting makes the concatenation (and hence sample index) order
        # deterministic across platforms; os.listdir order is arbitrary.
        self.file_paths = sorted(
            os.path.join(data_path, fn)
            for fn in os.listdir(data_path)
            if fn.endswith(".h5")
        )

        if not self.file_paths:
            raise ValueError(f"No .h5 files found in directory: {data_path}")

        # Load each file's full "data" array into memory, then stack all
        # files along the sample axis => (N_total, C, T).
        data_list = []
        for fp in self.file_paths:
            with h5py.File(fp, "r") as h5f:
                data_list.append(h5f["data"][:])

        self._data = np.concatenate(data_list, axis=0)
        self._num_samples = self._data.shape[0]

    def __len__(self):
        """Return the total number of samples across all loaded files."""
        return self._num_samples

    def __getitem__(self, idx: int):
        """Return ``{"input": x}`` for sample ``idx`` as a float32 tensor.

        The output shape is ``(C + zero_pad_chans, T + zero_pad_toks)``. When
        ``squeeze`` is True, an extra leading dimension of size 1 is added.
        Padding amounts of ``None`` mean no padding on that axis.
        """
        # torch.tensor copies the (C, T) numpy slice into a new float32 tensor.
        x = torch.tensor(self._data[idx], dtype=torch.float32)

        if self.squeeze:
            x = x.unsqueeze(0)  # (C, T) -> (1, C, T)

        if self.transform:
            # Per-sample min-max over all channels and time steps jointly,
            # mapped to [-1, 1]. The epsilon avoids division by zero, so a
            # constant sample maps to all -1.
            max_val = x.max()
            min_val = x.min()
            x = (x - min_val) / (max_val - min_val + 1e-10)
            x = (x - 0.5) * 2

        if self.zero_pad_chans is not None:
            # Pad the second-to-last axis (C) on the right. F.pad's tuple is
            # (last_left, last_right, second_last_left, second_last_right).
            # Addressing axis -2 directly keeps this correct whether the
            # sample is (C, T) or (1, C, T). Swapping dims 0/1 and padding the
            # last axis would pad T in the 3-D case.
            x = F.pad(x, (0, 0, 0, self.zero_pad_chans), value=0.0)

        if self.zero_pad_toks is not None:
            # Pad the last axis (T) on the right.
            x = F.pad(x, (0, self.zero_pad_toks), value=0.0)

        return {"input": x}
