from typing import Optional
import pytorch_lightning as pl
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torchvision import transforms

class ConcatenatedEEGDataModule(pl.LightningDataModule):
    """Pool several datasets into one and randomly split it into train/val subsets.

    Every non-``None`` dataset in ``train`` is concatenated with ``ConcatDataset``.
    The result is partitioned once, in ``__init__``, via
    ``torch.utils.data.random_split``: ``int(train_val_split_ratio * N)`` samples
    go to training and the remainder to validation.

    Args:
        train: Mapping of dataset name -> ``torch.utils.data.Dataset``. Entries whose
            value is ``None`` are skipped.
        test: Stored as ``self.test``. It is not consumed by any method in this class.
        val: Accepted but ignored. The validation subset is always carved out of
            the pooled ``train`` datasets.
        cfg: Config object. It must expose ``batch_size`` and ``num_workers``.
        name: Identifier stored as ``self.name``.
        train_val_split_ratio: Fraction (in ``[0, 1]``) of pooled samples assigned
            to the training subset.
        seed: Optional integer seed for the random split. ``None`` (the default)
            uses torch's global RNG. Pass an int for a reproducible split.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``train_val_split_ratio`` is outside ``[0, 1]``, ``cfg`` is
            ``None``, or ``train`` holds no non-``None`` dataset.
    """

    def __init__(self, train: dict, test=None, val=None, cfg=None, name="", train_val_split_ratio=0.8, *, seed: Optional[int] = None, **kwargs):
        super().__init__()

        # Validate up front: an out-of-range ratio yields negative/oversized split
        # lengths that random_split may not reject, producing wrong subsets.
        if not 0.0 <= train_val_split_ratio <= 1.0:
            raise ValueError(
                f"train_val_split_ratio must be in [0, 1], got {train_val_split_ratio!r}"
            )
        if cfg is None:
            raise ValueError("cfg must be provided and expose 'batch_size' and 'num_workers'.")

        # Concatenate multiple datasets for training (None entries are skipped).
        datasets_list = [dataset for dataset in train.values() if dataset is not None]
        print('datasets list:', datasets_list)
        if not datasets_list:
            # ConcatDataset would otherwise fail with an opaque AssertionError.
            raise ValueError("train must contain at least one non-None dataset.")
        combined_datasets = ConcatDataset(datasets_list)
        print('len of data', len(combined_datasets))

        train_size = int(train_val_split_ratio * len(combined_datasets))
        val_size = len(combined_datasets) - train_size

        # Only pass a generator when a seed is requested, so the default path keeps
        # using torch's global RNG exactly as before.
        split_kwargs = {}
        if seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(seed)
        self.train, self.val = torch.utils.data.random_split(
            combined_datasets, [train_size, val_size], **split_kwargs
        )

        # NOTE(review): `val` is ignored and `test` is stored but never read here.
        self.test = test
        self.name = name
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size

    def setup(self, stage: Optional[str] = None):
        """Expose the pre-computed splits as the datasets used by the dataloaders.

        Args:
            stage: ``"fit"`` or ``None`` sets both train and val datasets.
                ``"validate"`` sets only the val dataset. ``"test"`` sets
                ``test_dataset``. Any other value is a no-op.
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.train_dataset = self.train
            self.val_dataset = self.val
        elif stage == "validate":
            self.val_dataset = self.val
        elif stage == "test":
            # NOTE(review): the test stage reuses the validation split, not
            # `self.test`; confirm this is intended.
            self.test_dataset = self.val

    def _build_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
            drop_last=False,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split.

        Raises:
            ValueError: If :meth:`setup` has not assigned ``train_dataset`` yet.
        """
        if not hasattr(self, 'train_dataset'):
            raise ValueError("Setup method must be called before accessing train_dataloader.")
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation split.

        Raises:
            ValueError: If :meth:`setup` has not assigned ``val_dataset`` yet.
        """
        if not hasattr(self, 'val_dataset'):
            raise ValueError("Setup method must be called before accessing val_dataloader.")
        return self._build_loader(self.val_dataset, shuffle=False)
