# file: prism/data/base_datamodule.py
from abc import ABC, abstractmethod

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms

class BaseDataModule(pl.LightningDataModule, ABC):
    """Abstract Lightning data module with shared config and loader setup.

    Subclasses must implement ``prepare_data``, ``setup`` and the
    ``style_feature_map`` property. They are expected to populate
    ``train_ds`` / ``val_ds`` / ``test_ds`` (presumably inside ``setup`` —
    confirm against concrete subclasses). The dataloader hooks then wrap
    those datasets with the batch size and worker count taken from the
    config.
    """

    def __init__(self, config):
        """Read data/run settings from ``config`` and initialise state.

        Args:
            config: Object exposing ``config.data`` (with ``dir`` and
                ``batch_size``) and ``config.run`` (with ``num_workers``).
        """
        super().__init__()
        self.config = config
        self.data_config = config.data
        self.run_config = config.run

        self.data_dir = self.data_config.dir
        self.batch_size = self.data_config.batch_size
        self.num_workers = self.run_config.num_workers

        # Filled in by subclasses; ``test_ds`` may legitimately stay None,
        # in which case ``test_dataloader`` returns None.
        self.train_ds = None
        self.val_ds = None
        self.test_ds = None

        self.transform = self._get_default_transform()

    def _get_default_transform(self):
        """Return the default preprocessing pipeline.

        ``ToTensor`` yields values in [0, 1]; ``Normalize`` with mean 0.5 and
        std 0.5 then maps them to [-1, 1].
        """
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    @abstractmethod
    def prepare_data(self):
        """Download or otherwise prepare the dataset (subclass hook)."""
        raise NotImplementedError

    @abstractmethod
    def setup(self, stage=None):
        """Create ``train_ds`` / ``val_ds`` / ``test_ds`` for ``stage``."""
        raise NotImplementedError

    @property
    @abstractmethod
    def style_feature_map(self):
        """Dataset-specific style feature mapping (defined by subclasses)."""
        raise NotImplementedError

    def _make_dataloader(self, dataset, *, shuffle, drop_last=False):
        """Wrap ``dataset`` in a DataLoader using the module's shared settings.

        Args:
            dataset: Dataset to iterate over.
            shuffle: Whether to reshuffle samples each epoch.
            drop_last: Whether to drop the final incomplete batch.

        Returns:
            A configured ``torch.utils.data.DataLoader``.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=drop_last,
            # DataLoader rejects persistent_workers=True when num_workers == 0,
            # so only keep workers alive when there are workers to keep.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (drops the last partial batch)."""
        # Dropping the ragged final batch keeps training batch sizes uniform.
        return self._make_dataloader(self.train_ds, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader in a fixed order, keeping every sample."""
        return self._make_dataloader(self.val_ds, shuffle=False)

    def test_dataloader(self):
        """Return the test loader, or None when no test set was configured.

        Every sample is kept (``drop_last=False``) so that test metrics are
        computed over the full test set rather than silently skipping the
        final partial batch.
        """
        if self.test_ds is None:
            return None

        return self._make_dataloader(self.test_ds, shuffle=False)