import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader as TorchDataLoader
from torch_geometric.loader import DataLoader as GeoDataLoader
from datasets.unimodal_dataset import Smiles3DGraphDataset, SmilesGraphDataset, PeptideGraphDataset
from datasets.multimodal_dataset import FusionDataset


class DataModule(pl.LightningDataModule):
    """Lightning data module serving the train/valid/test CSV splits of one modality.

    ``mode`` selects the dataset class built in :meth:`setup`:

    * ``'SMILES'``   -> ``SmilesGraphDataset``
    * ``'PEPTIDE'``  -> ``PeptideGraphDataset``
    * ``'GEOMETRY'`` -> ``Smiles3DGraphDataset``
    * ``'FUSION'``   -> ``FusionDataset``

    ``data_dir`` must contain ``train.csv``, ``valid.csv`` and ``test.csv``.

    Args:
        mode: One of the mode strings listed above (case-sensitive).
        data_dir: Directory holding the three split CSV files.
        batch_size: Batch size used by the train, validation and test loaders.
        num_workers: Worker process count passed to every loader.
        global_mean_override: Stored on the instance as-is; this class does
            not read it.
        vocab: Optional peptide vocabulary. When given, it is forwarded to the
            training dataset in ``'PEPTIDE'`` and ``'FUSION'`` modes instead
            of being discarded. After :meth:`setup`, ``self.vocab`` holds the
            vocabulary reported by the training dataset.
    """

    # Modes whose datasets are batched with the torch_geometric loader; every
    # other mode is batched with ``FusionDataset.collate_fn``.
    _GRAPH_MODES = ('SMILES', 'PEPTIDE', 'GEOMETRY')

    def __init__(self, mode: str, data_dir: str, batch_size: int, num_workers: int = 4, global_mean_override: float | None = None, vocab: dict | None = None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mode = mode
        self.global_mean_override = global_mean_override
        self.vocab = vocab
        # Keep the caller-supplied vocab separately: ``self.vocab`` is
        # overwritten by ``setup`` with the training dataset's vocabulary.
        self._initial_vocab = vocab
        # (mode, data_dir) of the last successful ``setup``; ``None`` until then.
        self._setup_key = None

    def setup(self, stage=None):
        """Build ``train_ds``, ``val_ds`` and ``test_ds`` for the configured mode.

        The call is a no-op when the datasets were already built for the
        current ``(mode, data_dir)``, so calling it manually and then letting
        the trainer call it again does not re-read the CSV files.

        Args:
            stage: Accepted for Lightning compatibility; all three splits are
                always built regardless of its value.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        setup_key = (self.mode, self.data_dir)
        if self._setup_key == setup_key:
            return

        train_csv = os.path.join(self.data_dir, 'train.csv')
        val_csv = os.path.join(self.data_dir, 'valid.csv')
        test_csv = os.path.join(self.data_dir, 'test.csv')

        if self.mode == 'SMILES':
            self.train_ds = SmilesGraphDataset(train_csv)
            self.val_ds = SmilesGraphDataset(val_csv)
            self.test_ds = SmilesGraphDataset(test_csv)

        elif self.mode == 'PEPTIDE':
            # Only pass ``vocab`` when the caller supplied one, so the default
            # path constructs the training dataset exactly as before.
            train_kwargs = {} if self._initial_vocab is None else {'vocab': self._initial_vocab}
            self.train_ds = PeptideGraphDataset(train_csv, **train_kwargs)
            # Validation and test splits reuse the training vocabulary.
            self.vocab = self.train_ds.vocab
            self.val_ds = PeptideGraphDataset(val_csv, vocab=self.vocab)
            self.test_ds = PeptideGraphDataset(test_csv, vocab=self.vocab)

        elif self.mode == 'GEOMETRY':
            self.train_ds = Smiles3DGraphDataset(train_csv)
            self.val_ds = Smiles3DGraphDataset(val_csv)
            self.test_ds = Smiles3DGraphDataset(test_csv)

        elif self.mode == 'FUSION':
            train_kwargs = {} if self._initial_vocab is None else {'vocab_peptide': self._initial_vocab}
            self.train_ds = FusionDataset(train_csv, **train_kwargs)
            # Validation and test splits reuse the training peptide vocabulary.
            self.vocab = self.train_ds.ds_peptide.vocab
            self.val_ds = FusionDataset(val_csv, vocab_peptide=self.vocab)
            self.test_ds = FusionDataset(test_csv, vocab_peptide=self.vocab)

            # Inverse of the peptide vocabulary (value -> key).
            self.vocab_peptide_inv = {i: tok for tok, i in self.vocab.items()}

        else:
            raise ValueError(f"Unsupported mode: {self.mode}")

        # Record success only after every split was built without error.
        self._setup_key = setup_key

    def _make_dataloader(self, dataset, shuffle=False, bs=1):
        """Wrap *dataset* in the loader type that matches ``self.mode``.

        Graph modes (SMILES, PEPTIDE, GEOMETRY) use the torch_geometric
        ``DataLoader``; any other mode (FUSION in practice) uses the plain
        torch ``DataLoader`` with ``FusionDataset.collate_fn``.

        Args:
            dataset: Dataset to iterate over.
            shuffle: Whether the loader shuffles samples.
            bs: Batch size.

        Returns:
            The configured data loader.
        """
        if self.mode in self._GRAPH_MODES:
            return GeoDataLoader(
                dataset,
                batch_size=bs,
                shuffle=shuffle,
                num_workers=self.num_workers,
                pin_memory=True,
            )

        return TorchDataLoader(
            dataset,
            batch_size=bs,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=FusionDataset.collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._make_dataloader(self.train_ds, shuffle=True, bs=self.batch_size)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._make_dataloader(self.val_ds, shuffle=False, bs=self.batch_size)

    def test_dataloader(self):
        """Return the unshuffled loader over the test split."""
        return self._make_dataloader(self.test_ds, shuffle=False, bs=self.batch_size)


def build_datamodule(mode, args):
    """Instantiate a ``DataModule`` for *mode* and run its ``setup`` step.

    Args:
        mode: Mode string forwarded to ``DataModule``.
        args: Object exposing ``data_dir``, ``batch_size`` and ``num_workers``
            attributes.

    Returns:
        The ``DataModule`` after ``setup()`` has been called on it.
    """
    settings = {
        "mode": mode,
        "data_dir": args.data_dir,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
    }
    datamodule = DataModule(**settings)
    # No stage is passed, matching a plain ``setup()`` call.
    datamodule.setup()
    return datamodule
