from torch.utils.data import Dataset


class NodeDataset(Dataset):
    """Map-style dataset that serves one node record per integer index.

    Wraps a dict of columns. ``"node_id"`` and ``"node_label"`` are indexed
    per item. ``"node_dataset_name"``, ``"dataset_task_name"`` and
    ``"dataset_name"`` are returned as-is (not indexed) with every item.
    """

    def __init__(self, node_level_data_dict):
        # Store a reference to the caller's dict; nothing is copied or validated.
        self.node_level_data_dict = node_level_data_dict

    def __len__(self):
        """Return the number of entries in the ``"node_label"`` column."""
        labels = self.node_level_data_dict["node_label"]
        return len(labels)

    def __getitem__(self, idx):
        """Return the node id and label at ``idx`` plus the shared name fields.

        The output keys are, in order: ``node_id``, ``node_label``,
        ``node_dataset_name``, ``node_dataset_task_name``, ``dataset_name``.
        Note that ``node_dataset_task_name`` is read from the input key
        ``"dataset_task_name"``.
        """
        columns = self.node_level_data_dict
        # The values are evaluated left to right, which matches the original
        # lookup order, so a missing key raises the same KeyError.
        # NOTE(review): the three name fields are not indexed by ``idx``. If
        # "node_dataset_name" is actually a per-node column, this returns the
        # whole column for every item. Confirm against the producer of this dict.
        return {
            "node_id": columns["node_id"][idx],
            "node_label": columns["node_label"][idx],
            "node_dataset_name": columns["node_dataset_name"],
            "node_dataset_task_name": columns["dataset_task_name"],
            "dataset_name": columns["dataset_name"],
        }
