from torch.utils.data import Dataset


class NodeDataset(Dataset):
    """Map-style dataset yielding one node id/label record per index.

    ``node_level_data_dict`` must provide:
      * ``"node_id"`` and ``"y"``: indexable sequences, read at position ``idx``;
        ``len(node_level_data_dict["y"])`` defines the dataset length.
      * ``"node_dataset_name"``, ``"dataset_task_name"`` and
        ``"main_graph_dict_key"``: values returned unchanged with every item.

    NOTE(review): ``mask`` only determines ``self.split``. It does not select
    rows in ``__getitem__``, so the dict presumably already holds just the
    rows for that split. Confirm against callers.
    """

    # Recognised mask names and the split label each one maps to.
    _MASK_TO_SPLIT = {
        "train_mask": "train",
        "val_mask": "val",
        "test_mask": "test",
    }

    def __init__(self, node_level_data_dict, mask="train_mask"):
        self.node_level_data_dict = node_level_data_dict
        self.mask = mask
        # Previously an unrecognised mask left ``self.split`` unset, so any
        # later access raised AttributeError. Now it is always assigned.
        self.split = self._split_for_mask(mask)

    @classmethod
    def _split_for_mask(cls, mask):
        """Return the split label for ``mask``.

        ``None`` maps to ``"train"``. Known mask names map via
        ``_MASK_TO_SPLIT``. Anything else maps to ``None`` (unknown split).
        """
        if mask is None:
            return "train"
        if isinstance(mask, str):
            return cls._MASK_TO_SPLIT.get(mask)
        return None

    def __len__(self):
        """Return the number of labels, i.e. ``len(data["y"])``."""
        return len(self.node_level_data_dict["y"])

    def __getitem__(self, idx):
        """Return a dict with the node id and label at ``idx`` plus dataset metadata."""
        data = self.node_level_data_dict
        return dict(
            node_id=data["node_id"][idx],
            node_label=data["y"][idx],
            node_dataset_name=data["node_dataset_name"],
            node_dataset_task_name=data["dataset_task_name"],
            node_dataset_main_graph_dict_key=data["main_graph_dict_key"],
        )
