from typing import Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torchvision import transforms


class ConcatenatedEEGDataModule(pl.LightningDataModule):
    """Lightning data module that concatenates several datasets and splits them.

    Every non-``None`` value of ``train_dict`` is concatenated into one dataset,
    which is then randomly partitioned into a training and a validation subset
    according to ``train_val_split_ratio``. An optional ``test`` dataset is
    served unchanged.

    Args:
        train_dict: Mapping whose values are datasets (or ``None`` for entries
            to skip). All non-``None`` values are concatenated.
        test: Optional dataset exposed through ``test_dataloader``.
        val: Stored on the instance but not read by this class.
            NOTE(review): validation data always comes from the random split;
            confirm whether an explicit ``val`` dataset was meant to override it.
        cfg: Configuration object; ``cfg.batch_size`` and ``cfg.num_workers``
            are read when building the data loaders.
        name: Free-form label stored on the instance.
        train_val_split_ratio: Fraction of the concatenated dataset assigned
            to the training subset; the remainder becomes the validation subset.
    """

    def __init__(self, train_dict, test=None, val=None, cfg=None, name="", train_val_split_ratio=0.8):
        super().__init__()
        self.train_dict = train_dict
        self.test = test
        self.val = val
        self.cfg = cfg
        self.name = name
        self.train_val_split_ratio = train_val_split_ratio

        # Populated by ``setup``; declared here so the loaders can detect an
        # un-setup module instead of failing with AttributeError.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Build the datasets needed for ``stage``.

        ``"fit"`` and ``"validate"`` both need the train/validation split;
        ``"test"`` needs the test dataset; ``None`` prepares everything.
        """
        if stage in ("fit", "validate", None):
            # Split only once: re-drawing the random split on a later call
            # (e.g. fit followed by validate) would move training samples
            # into the validation subset.
            if self.train_dataset is None or self.val_dataset is None:
                self._split_train_val()

        if stage in ("test", None):
            self.test_dataset = self.test

    def _split_train_val(self):
        """Concatenate the training datasets and randomly split them."""
        # Construct datasets here at runtime
        datasets = [ds for ds in self.train_dict.values() if ds is not None]
        combined = ConcatDataset(datasets)
        n_train = int(self.train_val_split_ratio * len(combined))
        n_val = len(combined) - n_train
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(combined, [n_train, n_val])

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Create a ``DataLoader`` over ``dataset`` using the settings in ``cfg``."""
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            # Optionally add worker_init_fn here if needed
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training subset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffled loader over the validation subset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffled loader over the test dataset.

        Falls back to the base-class behaviour when no test dataset was given.
        """
        dataset = self.test_dataset if self.test_dataset is not None else self.test
        if dataset is None:
            return super().test_dataloader()
        return self._make_loader(dataset, shuffle=False)