# import torch
import random
from typing import Optional, Sequence
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import random_split
import lightning as L


class Dorschky2024DataModule(L.LightningDataModule):
    """LightningDataModule serving pre-built datasets to every Trainer stage.

    Datasets are passed in already constructed. Any that are missing are
    derived in :meth:`setup` with :func:`torch.utils.data.random_split`:

    * without ``test_dataset``, ``train_dataset`` is split into train/test
      using ``train_test_split``;
    * without ``val_dataset``, the (given or derived) test set is split into
      test/val using ``test_val_split``.

    Both splits draw from a ``torch.Generator`` seeded with ``split_seed`` so
    they are reproducible across runs.

    Args:
        train_dataset: Training data; also the source of the test/val splits
            when those datasets are not given explicitly.
        batch_size: Batch size of the train/val/test dataloaders.
        num_workers: Number of DataLoader worker processes.
        persistent_workers: Keep worker processes alive between epochs. Only
            honoured when ``num_workers > 0`` (PyTorch rejects it otherwise).
        train_test_split: ``lengths`` argument (counts or fractions) for the
            train/test ``random_split``.
        test_val_split: ``lengths`` argument (counts or fractions) for the
            test/val ``random_split``.
        val_dataset: Optional explicit validation dataset.
        test_dataset: Optional explicit test dataset.
        predict_dataset: Optional dataset for :meth:`predict_dataloader`.
        dataset_variables: Stored in ``self.hparams`` only; not read by this
            class.
        split_seed: Seed of the generators used for both random splits.
    """

    def __init__(
        self,
        train_dataset: Dataset,
        batch_size: int = 32,
        num_workers: int = 0,
        persistent_workers: bool = False,
        train_test_split: Sequence = (0.9, 0.1),
        test_val_split: Sequence = (0.5, 0.5),
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        predict_dataset: Optional[Dataset] = None,
        dataset_variables: Optional[dict] = None,
        split_seed: int = 42,
    ):
        super().__init__()

        # Record every constructor argument (datasets included) in
        # ``self.hparams``; ``logger=False`` keeps them out of the logger.
        self.save_hyperparameters(logger=False)

        self.data_train: Dataset = train_dataset
        self.data_val: Optional[Dataset] = val_dataset
        self.data_test: Optional[Dataset] = test_dataset
        self.data_pred: Optional[Dataset] = predict_dataset

    def prepare_data(self):
        """No-op: the datasets are handed to ``__init__`` ready-made."""

    def setup(self, stage: Optional[str] = None):
        """Derive any missing test/val splits.

        Sets ``self.data_train``, ``self.data_test`` and ``self.data_val``.
        ``stage`` is ignored. A split is only created while the corresponding
        attribute is still ``None``, so repeated calls (Lightning may call
        this once per stage) do not re-split already split data. An
        explicitly supplied dataset is always kept, even if it is empty.
        """
        if self.data_test is None:
            self.data_train, self.data_test = random_split(
                dataset=self.data_train,
                lengths=self.hparams.train_test_split,  # type: ignore
                generator=self._split_generator(),
            )
        if self.data_val is None:
            self.data_test, self.data_val = random_split(
                dataset=self.data_test,  # type: ignore
                lengths=self.hparams.test_val_split,  # type: ignore
                generator=self._split_generator(),
            )

    def _split_generator(self) -> torch.Generator:
        """Return a fresh generator seeded with ``split_seed``."""
        return torch.Generator().manual_seed(self.hparams.split_seed)  # type: ignore

    def _dataloader(
        self, dataset: Optional[Dataset], shuffle: bool, **kwargs
    ) -> DataLoader:
        """Build a DataLoader with the worker settings shared by all stages.

        Extra keyword arguments (e.g. ``batch_size``) are forwarded to
        :class:`~torch.utils.data.DataLoader`. Settings are read from
        ``self.hparams`` at call time so later changes to them take effect.
        """
        num_workers = self.hparams.num_workers  # type: ignore
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0, so only request it when worker processes exist.
        persistent = bool(self.hparams.persistent_workers) and num_workers > 0  # type: ignore
        return DataLoader(
            dataset=dataset,  # type: ignore
            num_workers=num_workers,
            persistent_workers=persistent,
            shuffle=shuffle,
            **kwargs,
        )

    def train_dataloader(self):
        """Return shuffled training batches of size ``batch_size``."""
        return self._dataloader(
            self.data_train,
            shuffle=True,
            batch_size=self.hparams.batch_size,  # type: ignore
        )

    def val_dataloader(self):
        """Return unshuffled validation batches of size ``batch_size``."""
        return self._dataloader(
            self.data_val,
            shuffle=False,
            batch_size=self.hparams.batch_size,  # type: ignore
        )

    def test_dataloader(self):
        """Return unshuffled test batches of size ``batch_size``."""
        return self._dataloader(
            self.data_test,
            shuffle=False,
            batch_size=self.hparams.batch_size,  # type: ignore
        )

    def predict_dataloader(self):
        """Return an unshuffled loader over ``predict_dataset``.

        ``batch_size`` is deliberately not forwarded, so DataLoader's default
        batch size of 1 applies.

        Raises:
            ValueError: if no ``predict_dataset`` was given.
        """
        if self.data_pred is None:
            raise ValueError(
                "Make sure to set predict_dataset when running prediction"
            )
        return self._dataloader(self.data_pred, shuffle=False)


if __name__ == "__main__":
    import hydra
    import omegaconf
    import pyrootutils

    def _print_first_batch(label: str, dataloader) -> None:
        """Print the shape of every entry in the first batch of *dataloader*.

        Assumes each batch is a dict of tensors (``.items()`` / ``.shape``);
        TODO confirm against the configured dataset.
        """
        for batch in dataloader:
            print(f"{label} data shape:")
            for key, value in batch.items():
                print(key, value.shape)
            break

    root = pyrootutils.setup_root(__file__, pythonpath=True)
    # NOTE(review): this loads the weiss2022 datamodule config although this
    # module defines Dorschky2024DataModule -- confirm the intended file.
    cfg = omegaconf.OmegaConf.load(root / "configs" / "datamodule" / "weiss2022.yaml")
    datamodule = hydra.utils.instantiate(cfg)

    # Outside a Trainer nothing calls setup(); do it per stage, as the
    # Trainer would, so derived splits exist before their loaders are built.
    datamodule.setup("fit")
    _print_first_batch("Training", datamodule.train_dataloader())

    datamodule.setup("test")
    _print_first_batch("Testing", datamodule.test_dataloader())

    datamodule.setup("predict")
    _print_first_batch("Prediction", datamodule.predict_dataloader())
