import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class FakeNewsTwitterDataset(Dataset):
    """
    Dataset with one sample per graph.

    Each sample is a tuple ``(s_dict, label)`` where ``s_dict`` has the form
    ``{'propagate': {node_id: (x, dist)}}`` and ``label`` is read from the
    first node of the graph (in the graph dict's iteration order).
    NOTE(review): this presumes every node of a graph carries the same
    label -- confirm against the code that produces `data`.

    All samples are built eagerly in ``__init__`` and kept in
    ``self.samples``.

    Args:
        data: mapping ``root_id -> graph``, where each graph is a mapping
            ``node_id -> (x, dist, lbl)``.
        roots: optional iterable of root_ids. When given, only those ids
            that are present in `data` are used, in the order they appear
            in `roots`; ids missing from `data` are ignored.

    Raises:
        ValueError: if a selected graph has no nodes, since no label can
            be read from it.
    """
    def __init__(self, data, roots=None):
        super().__init__()

        # If a subset of roots is specified, keep only those present in data.
        if roots is not None:
            data = {r: data[r] for r in roots if r in data}

        self.samples = []  # list of (s_dict, label)
        for root, graph in tqdm(data.items(), desc="Building samples", total=len(data)):
            # graph: dict node_id -> (x, dist, lbl)
            if not graph:
                # Without this guard, next() on an empty iterator would leak
                # a bare StopIteration that does not say which graph is at
                # fault.
                raise ValueError(
                    f"Graph for root {root!r} has no nodes; cannot determine its label"
                )
            # Drop the per-node label; only (x, dist) is kept for each node.
            node_dict = {nid: (x, dist) for nid, (x, dist, _) in graph.items()}
            # The graph-level label is taken from the first node.
            label = next(iter(graph.values()))[2]
            # Wrap under the single key 'propagate'.
            s_dict = {'propagate': node_dict}
            self.samples.append((s_dict, label))

    def __len__(self):
        """Return the number of graphs (samples) in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(s_dict, label)`` tuple stored at position `idx`."""
        return self.samples[idx]