from typing import Literal

import pytorch_lightning as pl
from torch.utils.data.dataset import Dataset


# Stage names a dataset can be registered under. ``Literal`` is only a
# static-typing hint; the stage value is not checked at runtime.
DatasetStage = Literal["train", "val", "test"]


class DatasetAccessor:
    """Mixin that keeps a registry of datasets keyed by stage name.

    Host classes normally call :meth:`setup_accessors` during initialisation.
    The registry is also created lazily on first use, so the accessors still
    work if that call was skipped.
    """

    def setup_accessors(self) -> None:
        """Create (or reset) the stage -> dataset registry as an empty dict."""
        self.datasets: dict[DatasetStage, Dataset] = {}

    def _dataset_registry(self) -> "dict[DatasetStage, Dataset]":
        """Return the registry, creating it first if setup_accessors() never ran."""
        try:
            return self.datasets
        except AttributeError:
            self.setup_accessors()
            return self.datasets

    def set_dataset(self, stage: DatasetStage, dataset: Dataset) -> None:
        """Register *dataset* under *stage*, replacing any previous entry."""
        self._dataset_registry()[stage] = dataset

    def get_dataset(self, stage: DatasetStage) -> Dataset:
        """Return the dataset registered under *stage*.

        Raises:
            KeyError: if no dataset has been registered for *stage*; the
                message lists the stages that are currently registered.
        """
        registry = self._dataset_registry()
        try:
            return registry[stage]
        except KeyError:
            # Re-raise the same exception type callers already expect, but
            # with enough context to spot a missing setup/registration step.
            raise KeyError(
                f"No dataset registered for stage {stage!r}; "
                f"registered stages: {list(registry)}"
            ) from None

class LightningDataAccessor(pl.LightningDataModule, DatasetAccessor):
    """LightningDataModule that also stores per-stage datasets.

    Combines Lightning's data-module base with the ``DatasetAccessor`` mixin,
    whose registry is initialised as part of construction.
    """

    def __init__(self) -> None:
        # Step 1: cooperative initialisation along the MRO (Lightning first).
        super().__init__()
        # Step 2: start with an empty stage -> dataset registry.
        self.setup_accessors()
