import pytorch_lightning as pl
from typing import Optional
import torchvision.transforms as T
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))


class GenericDataModule(pl.LightningDataModule):
    """Lightning data module that delegates construction to ``data.<name>.build``.

    Every keyword argument given to the constructor is stored as a
    hyperparameter (``self.hparams``). The ``name`` hyperparameter selects
    which submodule of the ``data`` package builds the datasets and loaders.
    All hyperparameters (including ``name``) are then forwarded to that
    module's ``build`` function.
    """

    # Values of ``hparams.name`` that have a matching ``data.<name>`` module.
    _SUPPORTED_DATASETS = ("synthetic", "shapes3D", "shapes3DGen", "syntheticGen")

    def __init__(self, *args, **data_args):
        """Record the constructor arguments as hyperparameters.

        Args:
            *args: Accepted for interface compatibility; not used directly here.
            **data_args: Dataset configuration. Must contain ``name``; the full
                mapping is later forwarded to the selected ``build`` function.
        """
        super().__init__()

        # Exposes the init params via the ``self.hparams`` attribute;
        # ``logger=False`` keeps them out of the experiment logger.
        self.save_hyperparameters(logger=False)

    def prepare_data(self, *args, **kwargs):
        """No-op: all dataset work happens in :meth:`setup`."""
        pass

    def setup(self, stage: Optional[str] = None):
        """Build train/val datasets and loaders for the configured dataset.

        Args:
            stage: Lightning stage name; ignored here. The same datasets and
                loaders are (re)built on every call.

        Raises:
            NotImplementedError: If ``hparams.name`` is not a supported dataset.
        """
        import importlib

        name = self.hparams.name
        if name not in self._SUPPORTED_DATASETS:
            raise NotImplementedError(
                f"Unknown dataset name {name!r}; expected one of: "
                f"{', '.join(self._SUPPORTED_DATASETS)}"
            )

        # Imported on demand so that only the selected builder module is loaded.
        builder = importlib.import_module(f"data.{name}")
        (
            self.train_dataset,
            self.val_dataset,
            self.train_loader,
            self.val_loader,
        ) = builder.build(**self.hparams)

    def train_dataloader(self):
        """Return the training loader created in :meth:`setup`."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader created in :meth:`setup`."""
        return self.val_loader

    def test_dataloader(self):
        """Return the validation loader; there is no separate test split here."""
        return self.val_loader