from torch.utils.data import Dataset
import re


class NodeDataset(Dataset):
    """Map-style dataset yielding one node (id + label) per index.

    The dataset is a thin view over ``node_level_data_dict``, which must
    provide the keys ``"node_dataset_name"``, ``"dataset_task_name"``,
    ``"main_graph_dict_key"``, ``"num_nodes"``, ``"graph_ratio_in_epoch"``,
    ``"node_id"`` and ``"y"``.

    Args:
        node_level_data_dict: Mapping holding the metadata and the
            index-aligned ``"node_id"`` / ``"y"`` sequences.
        mask: Name of the mask this dataset represents. ``"train_mask"``,
            ``"val_mask"`` and ``"test_mask"`` map to the splits ``"train"``,
            ``"val"`` and ``"test"``; ``None`` maps to ``"train"``. Any other
            value leaves ``split`` as ``None``.
    """

    # Compiled once for the class instead of on every __getitem__ call.
    _TASK_TYPE_PATTERN = re.compile(r"(graph_classification|node_classification)")

    def __init__(self, node_level_data_dict, mask="train_mask"):
        super().__init__()
        self.node_level_data_dict = node_level_data_dict
        self.node_dataset_name = node_level_data_dict["node_dataset_name"]
        self.dataset_task_name = node_level_data_dict["dataset_task_name"]
        self.main_graph_dict_key = node_level_data_dict["main_graph_dict_key"]
        self.num_nodes = node_level_data_dict["num_nodes"]
        self.graph_ratio_in_epoch = node_level_data_dict["graph_ratio_in_epoch"]
        self.dataset_type = "NC"
        self.mask = mask

        if mask is None or mask == "train_mask":
            self.split = "train"
        elif mask == "val_mask":
            self.split = "val"
        elif mask == "test_mask":
            self.split = "test"
        else:
            # Previously the attribute was simply never created for an
            # unrecognised mask, so any later read raised AttributeError.
            self.split = None

    def __len__(self) -> int:
        """Return the number of entries in the ``"y"`` sequence."""
        return len(self.node_level_data_dict["y"])

    def __getitem__(self, idx) -> dict:
        """Return the node id, label and dataset metadata for position ``idx``.

        Args:
            idx: Index into the ``"node_id"`` and ``"y"`` sequences.

        Returns:
            A dict with the keys ``node_id``, ``node_label``,
            ``node_dataset_name``, ``node_dataset_task_name``,
            ``node_dataset_main_graph_dict_key``, ``dataset_id_name``,
            ``task_type`` and ``num_nodes``.

        Raises:
            IndexError: If ``idx`` is out of range for the underlying
                sequences, or if the task name contains neither
                ``graph_classification`` nor ``node_classification``.
        """
        # The lookup is the same whether or not a mask is set; the mask only
        # determines ``self.split``.
        node_id = self.node_level_data_dict["node_id"][idx]
        node_label = self.node_level_data_dict["y"][idx]

        task_name = self.dataset_task_name
        match = self._TASK_TYPE_PATTERN.search(task_name)
        if match is None:
            # IndexError (not ValueError) is kept so callers observe the same
            # exception type the former ``findall(...)[0]`` produced.
            raise IndexError(
                f"dataset_task_name {task_name!r} contains neither "
                "'graph_classification' nor 'node_classification'"
            )

        graph_key = self.main_graph_dict_key
        return dict(
            node_id=node_id,
            node_label=node_label,
            node_dataset_name=self.node_dataset_name,
            node_dataset_task_name=task_name,
            node_dataset_main_graph_dict_key=graph_key,
            dataset_id_name=f"{graph_key}",
            task_type=match.group(1),
            num_nodes=self.num_nodes,
        )
