from typing import Optional

from hydra.utils import instantiate
from lightning import LightningDataModule
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from data.dataset import MyDataset
from data.dataset_noise import DatasetNoise
from data.preprocessing import Pipeline


class DDMDataModule(LightningDataModule):
    """Lightning data module that builds its datasets from Hydra configs.

    The constructor instantiates an optional preprocessing pipeline and three
    datasets (train, validation, predict). Each ``*_dataloader`` method wraps
    the matching dataset in a ``DataLoader`` that shares the same batch size,
    worker count and pin-memory setting.

    Attributes:
        batch_size: Number of samples per batch for every loader.
        num_workers: Number of worker processes handed to every loader.
        pin_memory: Whether every loader pins memory.
        pipeline: The instantiated pipeline, or ``None`` when the ``pipeline``
            config is falsy.
        dataset_train: Dataset built from ``dataset.train``.
        dataset_val: Dataset built from ``dataset.val``.
        dataset_predict: Dataset built from ``dataset.predict``.
    """

    def __init__(
        self,
        batch_size: int,
        num_workers: int,
        pin_memory: bool,
        pipeline: DictConfig,
        dataset: DictConfig,
    ) -> None:
        """Store the loader settings and instantiate the pipeline and datasets.

        Args:
            batch_size: Batch size used by all three dataloaders.
            num_workers: Worker count used by all three dataloaders.
            pin_memory: ``pin_memory`` flag used by all three dataloaders.
            pipeline: Hydra config of the preprocessing pipeline; when falsy,
                no pipeline is instantiated and ``None`` is used instead.
            dataset: Hydra config exposing ``train``, ``val`` and ``predict``
                sub-configs, one per dataset.
        """
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        if pipeline:
            self.pipeline: Optional[Pipeline] = instantiate(pipeline)
        else:
            self.pipeline = None

        # Train and validation datasets receive the (possibly None) pipeline as
        # a keyword override; the predict dataset is built without one.
        shared_pipeline = self.pipeline
        self.dataset_train: MyDataset = instantiate(dataset.train, pipeline=shared_pipeline)
        self.dataset_val: MyDataset = instantiate(dataset.val, pipeline=shared_pipeline)
        self.dataset_predict: DatasetNoise = instantiate(dataset.predict)

    def _make_loader(self, source, **loader_kwargs) -> DataLoader:
        """Wrap ``source`` in a ``DataLoader`` using the module-wide settings.

        Args:
            source: The dataset to load from.
            **loader_kwargs: Extra keyword arguments forwarded verbatim to
                ``DataLoader`` (used here only for ``shuffle``).

        Returns:
            A ``DataLoader`` over ``source``.
        """
        # NOTE(review): drop_last=True is applied to every loader, so a trailing
        # partial batch is discarded for validation and prediction as well as
        # training; confirm this is intended.
        return DataLoader(
            source,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
            **loader_kwargs,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the non-shuffled loader over the validation dataset."""
        return self._make_loader(self.dataset_val, shuffle=False)

    def predict_dataloader(self) -> DataLoader:
        """Return the loader over the predict dataset (``shuffle`` left at its default)."""
        return self._make_loader(self.dataset_predict)

    def __repr__(self) -> str:
        """Return a multi-line summary: class name, batch size, pipeline and datasets."""
        summary_lines = [
            f"{self.__class__.__name__}",
            f"\tbatch_size={self.batch_size}",
            f"\tpipeline={self.pipeline}",
            f"\tdataset_train={self.dataset_train}",
            f"\tdataset_val={self.dataset_val}",
            f"\tdataset_predict={self.dataset_predict}",
        ]
        return "\n".join(summary_lines)
