import pytorch_lightning as pl
import hydra
from torch.utils.data import DataLoader
from typing import Optional


class FinetuneModule(pl.LightningDataModule):
    """DataModule that lazily builds the fine-tuning datasets from Hydra configs.

    ``train``, ``val`` and ``test`` are expected to be *un-instantiated* Hydra
    configs (mappings carrying a ``_target_`` entry). They are turned into
    dataset objects in :meth:`setup` via ``hydra.utils.instantiate``. The
    dataloaders are built from ``cfg.batch_size`` and ``cfg.num_workers``.

    Args:
        train: Hydra config describing the training dataset.
        val: Hydra config describing the validation dataset.
        test: Optional Hydra config describing the test dataset.
        cfg: Object exposing ``batch_size`` and ``num_workers`` attributes.
            Although it defaults to ``None``, it is required as soon as any
            dataloader is requested.
        name: Free-form label stored on the instance as ``self.name``.
        **kwargs: Accepted and ignored by this class.
    """

    def __init__(self, train, val, test=None, cfg=None, name="", **kwargs):
        super().__init__()
        self.train_cfg = train
        self.val_cfg = val
        self.test_cfg = test
        self.name = name
        self.cfg = cfg

        # The datasets are created in setup(); until then they stay None.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the datasets required for ``stage``.

        * ``"fit"``      -> train and validation datasets
        * ``"validate"`` -> validation dataset only
        * ``"test"``     -> test dataset (only if a test config was given)
        * ``None``       -> everything listed above

        Any other stage value leaves the datasets untouched.
        """
        if stage == "fit" or stage is None:
            self.train_dataset = hydra.utils.instantiate(self.train_cfg)
            self.val_dataset = hydra.utils.instantiate(self.val_cfg)
        elif stage == "validate":
            # Stand-alone validation needs the validation dataset only.
            self.val_dataset = hydra.utils.instantiate(self.val_cfg)

        # Kept as an independent check so that stage=None builds the test
        # dataset in addition to the train/val datasets above.
        if (stage == "test" or stage is None) and self.test_cfg is not None:
            self.test_dataset = hydra.utils.instantiate(self.test_cfg)

    def _make_dataloader(self, dataset, training: bool) -> DataLoader:
        """Build a DataLoader over ``dataset`` using the shared ``cfg`` settings.

        When ``training`` is true the loader shuffles and drops the last
        incomplete batch; otherwise it does neither.

        Raises:
            AttributeError: If no ``cfg`` was supplied to the constructor.
        """
        if self.cfg is None:
            # Same exception type the bare attribute access would raise,
            # but with a message that points at the actual cause.
            raise AttributeError(
                f"{type(self).__name__} requires `cfg` with `batch_size` and "
                "`num_workers` to build dataloaders, but cfg is None"
            )
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=training,
            num_workers=self.cfg.num_workers,
            drop_last=training,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader (incomplete last batch dropped)."""
        return self._make_dataloader(self.train_dataset, training=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (no shuffling, all samples kept)."""
        return self._make_dataloader(self.val_dataset, training=False)

    def test_dataloader(self) -> Optional[DataLoader]:
        """Return the test loader, or ``None`` if no test dataset was built."""
        if self.test_dataset is None:
            return None
        return self._make_dataloader(self.test_dataset, training=False)
