import torch
from torch.utils.data import Dataset
from torch_geometric.data import Batch
from utils import SequenceBatch
from datasets.unimodal_dataset import SmilesGraphDataset, PeptideGraphDataset, Smiles3DGraphDataset


class FusionDataset(Dataset):
    """
    Combine three graph views built from the same CSV into one sample
    for multi-modal regression.

    Every sample bundles a molecule graph (``SmilesGraphDataset``), a
    geometry graph (``Smiles3DGraphDataset``) and a peptide graph
    (``PeptideGraphDataset``), plus the target read from the molecule
    graph's ``.y`` attribute.
    """
    def __init__(
        self,
        csv_path,
        vocab_peptide=None
    ):
        """
        Build the three per-modality datasets from ``csv_path``.

        Args:
            csv_path: Path handed unchanged to all three underlying datasets.
            vocab_peptide: Forwarded to ``PeptideGraphDataset`` as ``vocab``.

        Raises:
            ValueError: If the three datasets do not have the same length,
                since samples are paired purely by index.
        """
        super().__init__()

        # build individual datasets
        self.ds_smiles = SmilesGraphDataset(csv_path)
        self.ds_geometry = Smiles3DGraphDataset(csv_path)
        self.ds_peptide = PeptideGraphDataset(csv_path, vocab=vocab_peptide)

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let misaligned modalities be paired silently.
        n_mol = len(self.ds_smiles)
        n_geo = len(self.ds_geometry)
        n_pep = len(self.ds_peptide)
        if not (n_mol == n_pep == n_geo):
            raise ValueError(
                "Modality datasets must have the same number of samples "
                "(one-to-one alignment); got "
                f"smiles={n_mol}, geometry={n_geo}, peptide={n_pep} "
                f"from {csv_path!r}."
            )

    def __len__(self):
        """Return the number of samples (length of the molecule dataset)."""
        return len(self.ds_smiles)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th multi-modal sample.

        Returns:
            dict with keys ``"mol"``, ``"pep"``, ``"geo"`` (the items of the
            respective datasets at ``idx``) and ``"y"`` (the ``.y`` attribute
            of the molecule item).
        """
        mol_graph = self.ds_smiles[idx]     # PyG Data (contains .y)
        pep_graph = self.ds_peptide[idx]    # PyG Data (contains .y)
        geo_graph = self.ds_geometry[idx]   # PyG Data (contains .y)

        # The molecule graph's target is used as the sample label; the .y
        # attributes of the other two graphs are not consulted here.
        label = mol_graph.y

        return {
            "mol": mol_graph,
            "pep": pep_graph,
            "geo": geo_graph,
            "y":   label
        }

    @staticmethod
    def collate_fn(batch):
        """
        Merge a list of samples (as returned by ``__getitem__``) into one
        mini-batch.

        Returns:
            ``((mol_batch, pep_batch, geo_batch), labels)`` where each
            ``*_batch`` is a PyG ``Batch`` built from the corresponding graphs
            and ``labels`` is the per-sample ``"y"`` tensors stacked along a
            new leading dimension.
            NOTE(review): labels is presumably (B, 1), i.e. each ``y`` has
            shape (1,) -- confirm against the underlying datasets.
        """
        mol_batch = Batch.from_data_list([b["mol"] for b in batch])
        pep_batch = Batch.from_data_list([b["pep"] for b in batch])
        geo_batch = Batch.from_data_list([b["geo"] for b in batch])

        labels = torch.stack([b["y"] for b in batch], dim=0)
        return (mol_batch, pep_batch, geo_batch), labels
