from typing import Optional

from jamtorch.data import get_batch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from src.datamodules.datasets.point_cloud import MVP_CP

# pylint: disable=abstract-method


class PointCloudModule(LightningDataModule):
    """Lightning data module serving the MVP_CP point-cloud dataset.

    All keyword arguments are stored in an OmegaConf config (``self.cfg``).
    Keys read by this module:

    * ``ds_path`` (required) -- passed as ``path`` to ``MVP_CP`` for both splits.
    * ``batch_size`` (required) -- batch size of the train/val loaders.
    * ``shuffle`` (optional, default ``True``) -- whether the training loader
      reshuffles samples every epoch.
    * ``num_workers`` (optional, default ``0``) -- worker processes used by the
      train/val loaders (0 loads in the main process, the DataLoader default).
    """

    def __init__(self, **kwargs):
        super().__init__()
        cfg = OmegaConf.create(kwargs)
        self.cfg = cfg

        # Populated lazily by ``setup``.
        self.data_train: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None):
        """Build the train and test splits; later calls reuse them.

        Lightning may invoke ``setup`` once per stage, so the ``None`` guards
        avoid reconstructing the datasets each time. ``stage`` is ignored
        because both splits are always needed (validation uses the test split).
        """
        if self.data_train is None:
            self.data_train = MVP_CP(prefix="train", path=self.cfg.ds_path)
        if self.data_test is None:
            self.data_test = MVP_CP(prefix="test", path=self.cfg.ds_path)

    @staticmethod
    def _require(dataset: Optional[Dataset], name: str) -> Dataset:
        """Return *dataset*, raising a clear error if ``setup`` has not run yet."""
        if dataset is None:
            raise RuntimeError(f"{name} is not initialised; call setup() first.")
        return dataset

    def _num_workers(self) -> int:
        """Worker count for the train/val loaders (config ``num_workers``, default 0)."""
        return self.cfg.get("num_workers", 0)

    def train_dataloader(self):
        """Return the training loader, shuffled by default so epochs differ in order."""
        return DataLoader(
            dataset=self._require(self.data_train, "data_train"),
            batch_size=self.cfg.batch_size,
            shuffle=self.cfg.get("shuffle", True),
            num_workers=self._num_workers(),
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the *test* split (used for validation)."""
        return DataLoader(
            dataset=self._require(self.data_test, "data_test"),
            batch_size=self.cfg.batch_size,
            num_workers=self._num_workers(),
        )

    def get_test_samples(self, batch_size=100):
        """Return one batch drawn from a shuffled loader over the test split.

        :param batch_size: number of samples requested per batch.
        :returns: whatever ``jamtorch.data.get_batch`` yields for the loader
            (presumably its first batch -- confirm against jamtorch).
        :raises RuntimeError: if ``setup`` has not been called.
        """
        test_dl = DataLoader(
            dataset=self._require(self.data_test, "data_test"),
            batch_size=batch_size,
            shuffle=True,
        )
        return get_batch(test_dl)
