from typing import Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class EMGPretrainDataModule(pl.LightningDataModule):
    """
    用于加载 EMGPretrainDataset 的 DataModule，适合无标签（纯预训练）场景。
    """

    def __init__(self, 
                 train_dataset, 
                 val_dataset, 
                 test_dataset=None, 
                 cfg=None, 
                 name="", 
                 **kwargs):
        """
        Args:
            train_dataset (EMGPretrainDataset): 训练数据集（无标签）
            val_dataset (EMGPretrainDataset): 验证数据集（无标签）
            test_dataset (EMGPretrainDataset, optional): 测试数据集（无标签），可选
            cfg (object, optional): 配置对象或字典，一般包含 batch_size、num_workers 等信息
            name (str, optional): 数据模块名称（可用于日志/调试）
        """
        super().__init__()
        self.train = train_dataset
        self.val = val_dataset
        self.test = test_dataset
        self.cfg = cfg
        self.name = name

    def setup(self, stage: Optional[str] = None):
        """
        按照 Lightning 的流程，会在 'fit'、'validate'、'test' 等阶段调用。
        这里将 Dataset 分配给相应的属性，供后续 DataLoader 使用。
        """
        if stage == "fit" or stage is None:
            self.train_dataset = self.train
            self.val_dataset = self.val

        elif stage == "validate":
            self.val_dataset = self.val

        elif stage == "test":
            self.test_dataset = self.test

    def train_dataloader(self):
        """
        返回一个 DataLoader，用于训练阶段。
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            num_workers=self.cfg.num_workers,
            drop_last=False,
            pin_memory=True,
        )

    def val_dataloader(self):
        """
        返回一个 DataLoader，用于验证阶段。
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            num_workers=self.cfg.num_workers,
            drop_last=False,
            pin_memory=True,
        )

    def test_dataloader(self):
        """
        返回一个 DataLoader，用于测试阶段（若提供了测试集）。
        """
        if self.test_dataset is None:
            return None

        return DataLoader(
            self.test_dataset,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            num_workers=self.cfg.num_workers,
            drop_last=False,
            pin_memory=True,
        )
