from bisect import bisect_right
from itertools import accumulate
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

class CombinedDataset(Dataset):
    """Concatenate several map-style datasets into one indexable dataset.

    Global index ``i`` maps to the first dataset for ``0 <= i < len(d0)``,
    to the second dataset for the next ``len(d1)`` indices, and so on.
    Empty sub-datasets are skipped. Negative indices count from the end of
    the combined dataset, as with a Python list.

    Args:
        datasets: An iterable of objects supporting ``len()`` and integer
            ``__getitem__``. It is materialized into a list once and the
            sub-dataset lengths are recorded at construction time, so the
            sub-datasets must not change size afterwards.
    """

    def __init__(self, datasets):
        # Materialize so a generator/iterator argument is not exhausted by
        # the first len() call.
        self.datasets = list(datasets)
        # _cumulative_sizes[k] == total number of samples in datasets[0..k];
        # lets __getitem__ find the owning dataset with a binary search
        # instead of re-walking every dataset on each lookup.
        self._cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return self._cumulative_sizes[-1] if self._cumulative_sizes else 0

    def __getitem__(self, index):
        """Return the sample at global position ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(
                f"index {index} is out of range for CombinedDataset of length {total}"
            )
        # The owning dataset is the first whose cumulative size exceeds the
        # position; bisect_right also skips over empty sub-datasets.
        dataset_index = bisect_right(self._cumulative_sizes, position)
        start = self._cumulative_sizes[dataset_index - 1] if dataset_index > 0 else 0
        return self.datasets[dataset_index][position - start]

class all_data_module(pl.LightningDataModule):
    """LightningDataModule serving concatenated train/val/test splits.

    Each split is supplied as a collection of datasets that ``setup`` wraps
    in a ``CombinedDataset``. The dataloaders read ``batch_size`` and
    ``num_workers`` from ``cfg``.

    Args:
        train_datasets: Datasets combined into the training split.
        val_datasets: Datasets combined into the validation split.
        test_datasets: Datasets combined into the test split. When ``None``,
            the validation datasets are reused for testing.
        cfg: Object exposing ``batch_size`` and ``num_workers`` attributes.
            It must be set before any dataloader is requested.
        name: Identifier stored on the instance as ``self.name``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, train_datasets, val_datasets, test_datasets=None, cfg=None, name="", **kwargs):
        super().__init__()
        self.train_datasets = train_datasets
        self.val_datasets = val_datasets
        self.test_datasets = test_datasets
        self.name = name
        self.cfg = cfg

    def _test_source_datasets(self):
        """Return the datasets backing the test split (validation as fallback)."""
        return self.val_datasets if self.test_datasets is None else self.test_datasets

    def setup(self, stage: Optional[str] = None):
        """Build the combined datasets required by ``stage``.

        Args:
            stage: ``"fit"`` builds train and val; ``"validate"`` builds val;
                ``"test"`` builds test; ``None`` builds every split. Other
                values build nothing.
        """
        if stage in (None, "fit"):
            self.train_dataset = CombinedDataset(self.train_datasets)
        if stage in (None, "fit", "validate"):
            self.val_dataset = CombinedDataset(self.val_datasets)
        if stage in (None, "test"):
            # Previously this always used val_datasets, silently ignoring the
            # test_datasets passed to the constructor.
            self.test_dataset = CombinedDataset(self._test_source_datasets())

    def _build_dataloader(self, dataset, *, shuffle, drop_last):
        """Create a DataLoader for ``dataset`` using the shared cfg settings."""
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            drop_last=drop_last,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled training loader.

        The final incomplete batch is dropped so every training step sees a
        full ``batch_size`` batch.
        """
        return self._build_dataloader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader (ordered, keeps the last partial batch)."""
        return self._build_dataloader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the test loader (ordered, keeps the last partial batch)."""
        return self._build_dataloader(self.test_dataset, shuffle=False, drop_last=False)
