{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a32dd142",
   "metadata": {},
   "source": [
    "### Generate Spurious-Motif Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "142774b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from BA3_loc import *\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import random\n",
    "import torch\n",
    "import copy\n",
    "import itertools\n",
    "# Make sure the data directory exists before the rest of the notebook runs.\n",
    "# NOTE(review): `os` is not imported by name in this cell (`import os.path as osp`\n",
    "# only binds `osp`); presumably it arrives through `from BA3_loc import *` - confirm.\n",
    "data_dir = './data/CRCG-MOTIF/raw/'\n",
    "os.makedirs(data_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "742338c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def create_motif_star_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a star graph: node 0 is the hub, nodes 1..size-1 are its leaves.\n",
    "\n",
    "    Every node receives one np.random.normal(node_feature_mean, std) draw,\n",
    "    stored under both the 'features' and 'feature' attributes, plus an\n",
    "    'is_center' flag. `branches` is accepted but not used.\n",
    "\n",
    "    Returns (graph, role_id); role_id is 0 for the hub and 1 for each leaf.\n",
    "    \"\"\"\n",
    "    star = nx.Graph()\n",
    "    roles = []\n",
    "    for node in range(size):\n",
    "        hub = node == 0\n",
    "        sample = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        star.add_node(node, features=sample, is_center=hub)\n",
    "        star.nodes[node]['feature'] = sample\n",
    "        roles.append(0 if hub else 1)\n",
    "    # Spokes from the hub to every other node.\n",
    "    star.add_edges_from((0, leaf) for leaf in range(1, size))\n",
    "    return star, roles\n",
    "def create_motif_path_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a simple path 0 - 1 - ... - (size-1).\n",
    "\n",
    "    Each node receives one np.random.normal(node_feature_mean, std) draw. The\n",
    "    draw is stored unchanged under 'feature' and as a plain Python list (or\n",
    "    number) under 'features'. `branches` is accepted but not used.\n",
    "\n",
    "    Returns (graph, role_id) with role 1 for every node.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    for i in range(size):\n",
    "        node_features = np.random.normal(node_feature_mean, std)\n",
    "        # np.random.normal hands back a bare Python float for scalar mean/std,\n",
    "        # and floats have no .tolist(); np.asarray covers scalar and array draws.\n",
    "        G.add_node(i, features=np.asarray(node_features).tolist())\n",
    "        G.nodes[i]['feature'] = node_features\n",
    "        role_id.append(1)\n",
    "        if i > 0:\n",
    "            G.add_edge(i - 1, i)\n",
    "    return G, role_id\n",
    "def create_motif_fan_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a hub-and-spoke graph with extra leaf-to-leaf links.\n",
    "\n",
    "    Node 0 is linked to every other node. Each node i in 1..branches-1 is\n",
    "    additionally linked to the nodes j*branches + i for\n",
    "    j in 1..(size // branches)-1. Every node receives one np.random.normal\n",
    "    draw stored under 'features' and 'feature', plus an 'is_center' flag.\n",
    "\n",
    "    Returns (graph, role_id); role_id is 0 for node 0 and 1 otherwise.\n",
    "    \"\"\"\n",
    "    fan = nx.Graph()\n",
    "    roles = []\n",
    "    for node in range(size):\n",
    "        hub = node == 0\n",
    "        sample = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        fan.add_node(node, features=sample, is_center=hub)\n",
    "        fan.nodes[node]['feature'] = sample\n",
    "        roles.append(0 if hub else 1)\n",
    "    fan.add_edges_from((0, leaf) for leaf in range(1, size))\n",
    "    # Extra links between leaves that share the same remainder modulo `branches`.\n",
    "    for first in range(1, branches):\n",
    "        for step in range(1, size // branches):\n",
    "            fan.add_edge(first, step * branches + first)\n",
    "    return fan, roles\n",
    "def create_motif_cuspedPolygon_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build the 'cusped polygon' motif on nodes 0..size-1.\n",
    "\n",
    "    With half = size // 2, the last lower-half node (half - 1) and its partner\n",
    "    (half - 1 + half) % size are each linked to `branches` upper-half nodes,\n",
    "    and the edge (half - 1, half) is added. If that yields no edge, one edge\n",
    "    between two randomly picked nodes (possibly the same node) is added.\n",
    "    Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) with role 1 for each of the `size` nodes.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    # One draw is taken and discarded, as before, so the global NumPy random\n",
    "    # stream advances exactly as it always did.\n",
    "    np.random.normal(loc=node_feature_mean, scale=std)\n",
    "    G.add_nodes_from(range(size))\n",
    "    # One role per node; the previous version returned a single-entry list\n",
    "    # no matter how many nodes the graph had.\n",
    "    role_id = [1] * size\n",
    "\n",
    "    half = size // 2\n",
    "    # size < 2 used to hit an unbound node_a / modulo by zero; skip the wiring.\n",
    "    if half > 0:\n",
    "        # NOTE(review): only the last lower-half index anchors the edges, which\n",
    "        # reproduces the original control flow; the original layout suggests a\n",
    "        # loop over every lower-half index may have been intended - confirm.\n",
    "        anchor = half - 1\n",
    "        node_a = anchor\n",
    "        node_b = (anchor + half) % size\n",
    "        for b in range(branches):\n",
    "            next_node = (anchor * branches + b + 1) % half + half\n",
    "            G.add_edge(node_a, next_node)\n",
    "            G.add_edge(node_b, next_node)\n",
    "            G.add_edge(half - 1, half)\n",
    "\n",
    "    # Fallback edge; skipped for an empty graph, where random.choice would fail.\n",
    "    if G.number_of_edges() == 0 and G.number_of_nodes() > 0:\n",
    "        node_a = random.choice(list(G.nodes()))\n",
    "        node_b = random.choice(list(G.nodes()))\n",
    "        G.add_edge(node_a, node_b)\n",
    "    return G, role_id\n",
    "def create_motif_random_bipartite_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a random bipartite graph between the two halves of 0..size-1.\n",
    "\n",
    "    Nodes carry a 'feat' attribute (one np.random.normal draw each). Every\n",
    "    pair (lower-half node, upper-half node) is linked when\n",
    "    np.random.rand() < 0.5. If no edge results, one edge between two randomly\n",
    "    picked nodes (possibly the same node) is added. `branches` is not used.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    bip = nx.Graph()\n",
    "    bip.add_nodes_from(\n",
    "        [(node, {'feat': np.random.normal(node_feature_mean, std)}) for node in range(size)]\n",
    "    )\n",
    "    roles = list(range(size))\n",
    "\n",
    "    cut = size // 2\n",
    "    for left in range(cut):\n",
    "        for right in range(cut, size):\n",
    "            if np.random.rand() < 0.5:\n",
    "                bip.add_edge(left, right)\n",
    "\n",
    "    if bip.number_of_edges() == 0:\n",
    "        first = random.choice(list(bip.nodes()))\n",
    "        second = random.choice(list(bip.nodes()))\n",
    "        bip.add_edge(first, second)\n",
    "    return bip, roles\n",
    "def create_motif_tree_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Grow a random tree over nodes 0..size-1.\n",
    "\n",
    "    Each node k > 0 is attached to a parent picked with random.randint from\n",
    "    max(0, k - branches) to k - 1 inclusive. Every node receives one\n",
    "    np.random.normal draw stored under 'feature'.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    Raises AssertionError when branches <= 0.\n",
    "    \"\"\"\n",
    "    assert branches > 0, \"Number of branches must be greater than 0 for tree motif.\"\n",
    "    tree = nx.Graph()\n",
    "    roles = []\n",
    "    for child in range(size):\n",
    "        tree.add_node(child, feature=np.random.normal(node_feature_mean, std))\n",
    "        roles.append(child)\n",
    "        if child == 0:\n",
    "            continue\n",
    "        parent = random.randint(max(0, child - branches), child - 1)\n",
    "        tree.add_edge(parent, child)\n",
    "    return tree, roles\n",
    "def create_motif_trident_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a chain of `branches` three-node forks.\n",
    "\n",
    "    Nodes 0..size-1 each receive one np.random.normal draw under 'feature'.\n",
    "    Fork b uses nodes 3b, 3b+1 and 3b+2: the first is linked to the other two,\n",
    "    and for b > 0 each prong is also linked to the matching prong of fork b-1.\n",
    "    Fork ids beyond size-1 are created implicitly by add_edge, without a\n",
    "    'feature' attribute and without a role entry.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i for i in 0..size-1.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    for i in range(size):\n",
    "        G.add_node(i, feature=np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(i)\n",
    "    for i in range(branches):\n",
    "        start_node = i * 3\n",
    "        end_node_1 = start_node + 1\n",
    "        end_node_2 = start_node + 2\n",
    "        G.add_edge(start_node, end_node_1)\n",
    "        G.add_edge(start_node, end_node_2)\n",
    "        if i > 0:\n",
    "            # Tie each prong to the matching prong of the previous fork.\n",
    "            G.add_edge(end_node_1 - 3, end_node_1)\n",
    "            G.add_edge(end_node_2 - 3, end_node_2)\n",
    "    # The return used to sit inside the loop above: only the first fork was\n",
    "    # ever built, and branches == 0 fell through and returned None.\n",
    "    return G, role_id\n",
    "def create_motif_conicalConnection_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build the 'conical connection' motif.\n",
    "\n",
    "    Nodes 0..size-1 are created first; mid = size // 2 gets role 1, lower ids\n",
    "    role 2 and higher ids role 3. Every non-mid node is linked to mid and to\n",
    "    its neighbour on the side facing mid. For each b in range(branches) the\n",
    "    node id size + b is (re)added `size` times, appending one role entry\n",
    "    (4 or 5) per pass, and a second fan is wired around mid + b + 1 using ids\n",
    "    shifted by `size`.\n",
    "\n",
    "    NOTE(review): role_id grows by `size` entries per branch, so it does not\n",
    "    line up one-to-one with the graph nodes, and most shifted ids are created\n",
    "    implicitly by add_edge without a 'feature'. This looks unintended -\n",
    "    confirm before relying on role_id.\n",
    "\n",
    "    Returns (graph, role_id).\n",
    "    \"\"\"\n",
    "    cone = nx.Graph()\n",
    "    roles = []\n",
    "    apex = size // 2\n",
    "    for node in range(size):\n",
    "        cone.add_node(node, feature=np.random.normal(node_feature_mean, std))\n",
    "        if node == apex:\n",
    "            roles.append(1)\n",
    "        elif node < apex:\n",
    "            roles.append(2)\n",
    "        else:\n",
    "            roles.append(3)\n",
    "\n",
    "    for node in range(size):\n",
    "        if node != apex:\n",
    "            cone.add_edge(apex, node)\n",
    "            toward_apex = node + 1 if node < apex else node - 1\n",
    "            cone.add_edge(node, toward_apex)\n",
    "\n",
    "    for b in range(branches):\n",
    "        sub_apex = apex + b + 1\n",
    "        extra = size + b\n",
    "        for k in range(size):\n",
    "            # Same node id on every pass: only its feature is redrawn.\n",
    "            cone.add_node(extra, feature=np.random.normal(node_feature_mean, std))\n",
    "            roles.append(4 if k == sub_apex else 5)\n",
    "        for k in range(size):\n",
    "            if k == sub_apex:\n",
    "                continue\n",
    "            shifted = k + size\n",
    "            cone.add_edge(sub_apex, shifted)\n",
    "            cone.add_edge(shifted, shifted + 1 if k < sub_apex else shifted - 1)\n",
    "        cone.add_edge(apex, sub_apex)\n",
    "    return cone, roles\n",
    "def create_motif_chainBypass_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a chain 0 - 1 - ... - (size-1) with bypass edges.\n",
    "\n",
    "    The chain nodes after node 0 are cut into `branches` consecutive groups of\n",
    "    (size - 2) // branches nodes; every node of group b is linked to node\n",
    "    size - 2*branches + b. Every node receives one np.random.normal draw under\n",
    "    'features'. `branches` must be non-zero (the group size divides by it).\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    chain = nx.Graph()\n",
    "    for node in range(size):\n",
    "        chain.add_node(node, features=np.random.normal(node_feature_mean, std))\n",
    "    roles = list(range(size))\n",
    "\n",
    "    chain.add_edges_from((node, node + 1) for node in range(size - 1))\n",
    "\n",
    "    span = (size - 2) // branches\n",
    "    for b in range(branches):\n",
    "        target = size - (2 * branches) + b\n",
    "        first = 1 + b * span\n",
    "        for offset in range(span):\n",
    "            chain.add_edge(first + offset, target)\n",
    "    return chain, roles\n",
    "def create_motif_partPolygon_branch(size, branches, node_feature_mean, std):\n",
    "    \"\"\"Build a chain 0..size-1 whose last node is also linked to nodes 1..size-2.\n",
    "\n",
    "    Nodes exist only through their edges and carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) where role_id is a float array of shape\n",
    "    (size, branches): row 0 is filled with 1, the last row with 2 and the rows\n",
    "    in between with 3.\n",
    "    \"\"\"\n",
    "    last = size - 1\n",
    "    poly = nx.Graph()\n",
    "    poly.add_edges_from((k, k + 1) for k in range(last))\n",
    "    poly.add_edges_from((last, k) for k in range(1, last))\n",
    "\n",
    "    # Drawn and discarded, as in the original, so the random stream is unchanged.\n",
    "    np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    roles = np.zeros((size, branches))\n",
    "    roles[0] = 1\n",
    "    roles[last] = 2\n",
    "    roles[1:last] = 3\n",
    "    return poly, roles\n",
    "def create_motif_completeGraph(size, node_feature_mean, std):\n",
    "    \"\"\"Build the complete graph on `size` nodes.\n",
    "\n",
    "    Node i gets a scalar-parameter draw np.random.normal(mean_i, std_i) under\n",
    "    'feature', where mean_i / std_i cycle through the node_feature_mean and\n",
    "    std sequences (both are indexed, so they must be sequences).\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    clique = nx.complete_graph(size)\n",
    "    for node in range(size):\n",
    "        mean_i = node_feature_mean[node % len(node_feature_mean)]\n",
    "        std_i = std[node % len(std)]\n",
    "        clique.nodes[node]['feature'] = np.random.normal(mean_i, std_i)\n",
    "    return clique, list(range(size))\n",
    "def create_motif_netShape(size, node_feature_mean, std):\n",
    "    \"\"\"Build an n x n grid with n = int(sqrt(size)).\n",
    "\n",
    "    Node row*n + col is linked to the node above it and the node to its left.\n",
    "    Every node receives one np.random.normal draw under 'feature'.\n",
    "\n",
    "    Returns (graph, role_id) where each node's role is its row index.\n",
    "    \"\"\"\n",
    "    grid = nx.Graph()\n",
    "    roles = []\n",
    "    side = int(np.sqrt(size))\n",
    "    for row in range(side):\n",
    "        for col in range(side):\n",
    "            here = row * side + col\n",
    "            grid.add_node(here, feature=np.random.normal(node_feature_mean, std))\n",
    "            roles.append(row)\n",
    "            if row > 0:\n",
    "                grid.add_edge(here, here - side)\n",
    "            if col > 0:\n",
    "                grid.add_edge(here, here - 1)\n",
    "    return grid, roles\n",
    "def create_motif_dircycle(size, node_feature_mean, std):\n",
    "    \"\"\"Build an undirected cycle 0 - 1 - ... - (size-1) - 0.\n",
    "\n",
    "    Every node receives one np.random.normal draw under 'features'.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    cycle = nx.Graph()\n",
    "    for node in range(size):\n",
    "        cycle.add_node(node, features=np.random.normal(loc=node_feature_mean, scale=std))\n",
    "    cycle.add_edges_from((node, (node + 1) % size) for node in range(size))\n",
    "    return cycle, list(range(size))\n",
    "def create_motif_dualRing(size, node_feature_mean, std):\n",
    "    \"\"\"Build two rings joined by rungs.\n",
    "\n",
    "    The first ring uses nodes 0..size//2 - 1, the second ring the remaining\n",
    "    nodes; node k of the first ring is linked to node k + size//2. Every node\n",
    "    receives one np.random.normal draw under 'features'.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    inner = size // 2\n",
    "    outer = size - inner\n",
    "    rings = nx.empty_graph(n=inner + outer)\n",
    "    for node in range(size):\n",
    "        rings.add_node(node, features=np.random.normal(loc=node_feature_mean, scale=std))\n",
    "    rings.add_edges_from((k, (k + 1) % inner) for k in range(inner))\n",
    "    rings.add_edges_from((k + inner, (k + 1) % outer + inner) for k in range(outer))\n",
    "    rings.add_edges_from((k, k + inner) for k in range(inner))\n",
    "    return rings, list(range(size))\n",
    "def create_motif_triangle(size, node_feature_mean, std):\n",
    "    \"\"\"Build a triangular lattice with n = int(sqrt(2 * size)) rows.\n",
    "\n",
    "    Row i holds i + 1 nodes, numbered i*(i+1)//2 + j for j in 0..i. A node is\n",
    "    linked to the node directly above it (when one exists, i.e. j < i), to the\n",
    "    node above-left (j > 0) and to its left neighbour in the same row (j > 0).\n",
    "    Every node receives one np.random.normal draw under 'feature'.\n",
    "\n",
    "    Returns (graph, role_id) where each node's role is its row index.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n = int(np.sqrt(2 * size))\n",
    "    for i in range(n):\n",
    "        row_start = i * (i + 1) // 2\n",
    "        prev_start = i * (i - 1) // 2\n",
    "        for j in range(i + 1):\n",
    "            node_idx = row_start + j\n",
    "            G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))\n",
    "            role_id.append(i)\n",
    "            if i > 0:\n",
    "                # The previous row has only i nodes, so j == i has nothing above\n",
    "                # it; prev_start + i is the first node of the *current* row, and\n",
    "                # linking to it unconditionally added a wrap-around edge.\n",
    "                if j < i:\n",
    "                    G.add_edge(node_idx, prev_start + j)\n",
    "                if j > 0:\n",
    "                    G.add_edge(node_idx, prev_start + j - 1)\n",
    "            if j > 0:\n",
    "                G.add_edge(node_idx, node_idx - 1)\n",
    "    return G, role_id\n",
    "\n",
    "def create_motif_ringShape(size, node_feature_mean, std):\n",
    "    \"\"\"Build a ring 0..size-1 plus random chords.\n",
    "\n",
    "    Up to `size` distinct node pairs (u != v) are sampled with\n",
    "    np.random.randint and added on top of the ring edges; a sampled pair may\n",
    "    coincide with a ring edge. Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id): role 1 for node 0, role 2 for the last node and\n",
    "    role 0 for the rest (an empty list when size == 0).\n",
    "    \"\"\"\n",
    "    # The original drew one value per node for coordinates it never used; the\n",
    "    # draws are kept so the global NumPy random stream is consumed as before.\n",
    "    for _ in range(size):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    ring_edges = [(i, (i + 1) % size) for i in range(size)]\n",
    "\n",
    "    # Only size*(size-1)//2 distinct pairs exist. Asking for `size` of them can\n",
    "    # never succeed for size 1 or 2, which used to spin forever.\n",
    "    wanted = min(size, size * (size - 1) // 2)\n",
    "    extra_edges = set()\n",
    "    while len(extra_edges) < wanted:\n",
    "        u = np.random.randint(0, size)\n",
    "        v = np.random.randint(0, size)\n",
    "        if u == v or (u, v) in extra_edges or (v, u) in extra_edges:\n",
    "            continue\n",
    "        extra_edges.add((u, v))\n",
    "\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(size))\n",
    "    G.add_edges_from(ring_edges + list(extra_edges))\n",
    "\n",
    "    # Second batch of unused per-node draws, kept for the same reason as above.\n",
    "    for _ in range(size):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    role_id = [0] * size\n",
    "    if size > 0:\n",
    "        role_id[0] = 1\n",
    "        role_id[-1] = 2\n",
    "    return G, role_id\n",
    "def create_motif_diamond(size, node_feature_mean, std):\n",
    "    \"\"\"Build a strip of `size` triangles on 2*size + 1 nodes.\n",
    "\n",
    "    For k in 0..size-1 the nodes k, k+1 and size+k+1 form a triangle; a final\n",
    "    edge (size, 2*size) is added. Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id): role 1 for node 0, role 2 for the last node and\n",
    "    role 0 for the rest.\n",
    "    \"\"\"\n",
    "    total = 2 * size + 1\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    links = []\n",
    "    for k in range(size):\n",
    "        top = size + k + 1\n",
    "        links += [(k, k + 1), (k, top), (top, k + 1)]\n",
    "    links.append((size, size * 2))\n",
    "\n",
    "    diamond = nx.Graph()\n",
    "    diamond.add_nodes_from(range(total))\n",
    "    diamond.add_edges_from(links)\n",
    "\n",
    "    roles = [0] * total\n",
    "    roles[0] = 1\n",
    "    roles[-1] = 2\n",
    "    return diamond, roles\n",
    "def create_motif_HShape(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 'H shape' motif on 2*size + 1 nodes.\n",
    "\n",
    "    For k in 0..size-1 the edges (k, k+1), (k, size+k+1) and\n",
    "    (size+k, size+k+1) are added. Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id): role 1 for node 0, role 2 for the last node and\n",
    "    role 0 for the rest.\n",
    "    \"\"\"\n",
    "    total = 2 * size + 1\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    links = [\n",
    "        pair\n",
    "        for k in range(size)\n",
    "        for pair in ((k, k + 1), (k, size + k + 1), (size + k, size + k + 1))\n",
    "    ]\n",
    "\n",
    "    h_graph = nx.Graph()\n",
    "    h_graph.add_nodes_from(range(total))\n",
    "    h_graph.add_edges_from(links)\n",
    "\n",
    "    roles = [0] * total\n",
    "    roles[0] = 1\n",
    "    roles[-1] = 2\n",
    "    return h_graph, roles\n",
    "def create_motif_wheel(size, node_feature_mean, std):\n",
    "    \"\"\"Build a wheel: hub 0 plus a rim of nodes 1..size.\n",
    "\n",
    "    Each rim node is linked to the hub and to the next rim node (the last rim\n",
    "    node wraps around to node 1). Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    total = size + 1\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    links = []\n",
    "    for rim in range(1, total):\n",
    "        links.append((0, rim))\n",
    "        links.append((rim, (rim % size) + 1))\n",
    "\n",
    "    wheel = nx.Graph()\n",
    "    wheel.add_nodes_from(range(total))\n",
    "    wheel.add_edges_from(links)\n",
    "    return wheel, list(range(total))\n",
    "def create_motif_hourglass(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 'hourglass' motif on 2*size + 1 nodes.\n",
    "\n",
    "    Node 0 is linked to nodes 1..size; node k+1 is linked to node k+size+1,\n",
    "    which in turn is linked to the last node (2*size). For k == size-1 the\n",
    "    middle id equals 2*size, so that final link is a self-loop. Nodes carry no\n",
    "    feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    total = 2 * size + 1\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    sink = total - 1\n",
    "    links = []\n",
    "    for k in range(size):\n",
    "        upper, lower = k + 1, k + size + 1\n",
    "        links += [(0, upper), (upper, lower), (lower, sink)]\n",
    "\n",
    "    glass = nx.Graph()\n",
    "    glass.add_nodes_from(range(total))\n",
    "    glass.add_edges_from(links)\n",
    "    return glass, list(range(total))\n",
    "def create_motif_DCD(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 'DCD' motif on 3*size + 1 nodes.\n",
    "\n",
    "    For k in 0..size-1 the edges (k, k+1), (k, size+k+1), (size+k, size+k+1),\n",
    "    (size+k, 2*size+k+1) and (2*size+k, 2*size+k+1) are added. Nodes carry no\n",
    "    feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    total = 3 * size + 1\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    links = []\n",
    "    for k in range(size):\n",
    "        mid, far = size + k, 2 * size + k\n",
    "        links += [\n",
    "            (k, k + 1),\n",
    "            (k, mid + 1),\n",
    "            (mid, mid + 1),\n",
    "            (mid, far + 1),\n",
    "            (far, far + 1),\n",
    "        ]\n",
    "\n",
    "    dcd = nx.Graph()\n",
    "    dcd.add_nodes_from(range(total))\n",
    "    dcd.add_edges_from(links)\n",
    "    return dcd, list(range(total))\n",
    "def create_motif_Cyclocross(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 'Cyclocross' motif on 2*size + 1 nodes.\n",
    "\n",
    "    Two chains 0..size and size..2*size are laid end to end, then node `size`\n",
    "    is linked to node 0 and to node 2*size. Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    total = 2 * size + 1\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    links = []\n",
    "    for k in range(size):\n",
    "        links.append((k, k + 1))\n",
    "        links.append((size + k, size + k + 1))\n",
    "    links += [(size, 0), (size, 2 * size)]\n",
    "\n",
    "    cross = nx.Graph()\n",
    "    cross.add_nodes_from(range(total))\n",
    "    cross.add_edges_from(links)\n",
    "    return cross, list(range(total))\n",
    "def create_motif_ladder(size, node_feature_mean, std):\n",
    "    \"\"\"Build a ladder with `size` rungs on 2*size nodes.\n",
    "\n",
    "    Rails are the chains 0..size-1 and size..2*size-1; rung k joins node k to\n",
    "    node k + size. Nodes carry no feature attribute.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i.\n",
    "    \"\"\"\n",
    "    total = 2 * size\n",
    "    # One unused draw per node, kept so the random stream is consumed as before.\n",
    "    for _ in range(total):\n",
    "        np.random.normal(node_feature_mean, std)\n",
    "\n",
    "    first_rail = [(k, k + 1) for k in range(size - 1)]\n",
    "    second_rail = [(k + size, k + size + 1) for k in range(size - 1)]\n",
    "    rungs = [(k, k + size) for k in range(size)]\n",
    "\n",
    "    ladder = nx.Graph()\n",
    "    ladder.add_nodes_from(range(total))\n",
    "    ladder.add_edges_from(first_rail + second_rail + rungs)\n",
    "    return ladder, list(range(total))\n",
    "def create_motif_bowtie(size, node_feature_mean, std):\n",
    "    \"\"\"Build two stars whose centres (nodes 0 and 1) are linked.\n",
    "\n",
    "    Even-numbered nodes from 2 on hang off node 0, odd-numbered ones off\n",
    "    node 1. Every node in 0..size-1 receives one np.random.normal draw under\n",
    "    'feature'.\n",
    "\n",
    "    Returns (graph, role_id) where role_id is a float array of length `size`:\n",
    "    0 for node 0, 1 for node 1 and for even nodes, 2 for odd nodes from 3 on.\n",
    "    \"\"\"\n",
    "    tie = nx.Graph()\n",
    "    for node in range(size):\n",
    "        tie.add_node(node, feature=np.random.normal(node_feature_mean, std))\n",
    "    tie.add_edge(0, 1)\n",
    "\n",
    "    roles = np.zeros((size,))\n",
    "    roles[0] = 0\n",
    "    roles[1] = 1\n",
    "    for node in range(2, size):\n",
    "        centre = node % 2  # even -> node 0, odd -> node 1\n",
    "        tie.add_edge(centre, node)\n",
    "        roles[node] = centre + 1\n",
    "    return tie, roles\n",
    "def create_motif_cross(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 'cross' motif on nodes 0..size-1.\n",
    "\n",
    "    With half = size // 2, node `half` is linked to node 0; every other node k\n",
    "    is linked to (k + 1) % half and to (k + half) % size. Every node receives\n",
    "    a 5-element np.random.normal draw under 'feature'.\n",
    "\n",
    "    Returns (graph, role_id) where role_id is an int32 array: 1 for node\n",
    "    `half`, 2 for lower ids and 3 for higher ids.\n",
    "    \"\"\"\n",
    "    cross = nx.Graph()\n",
    "    roles = np.zeros(size, dtype=np.int32)\n",
    "    half = size // 2\n",
    "    for node in range(size):\n",
    "        cross.add_node(node, feature=np.random.normal(node_feature_mean, std, [5]))\n",
    "        if node == half:\n",
    "            cross.add_edge(node, 0)\n",
    "            roles[node] = 1\n",
    "            continue\n",
    "        cross.add_edge(node, (node + 1) % half)\n",
    "        cross.add_edge(node, (node + half) % size)\n",
    "        roles[node] = 2 if node < half else 3\n",
    "    return cross, roles\n",
    "\n",
    "\n",
    "\n",
    "def create_benzene_ring(size, node_feature_mean, std):\n",
    "    \"\"\"Build a ring of `size` nodes, each with the constant feature [0, 0, 0, 0, 0].\n",
    "\n",
    "    node_feature_mean and std are accepted but not used.\n",
    "\n",
    "    Returns (graph, role_id) with role 0 for every node.\n",
    "    \"\"\"\n",
    "    ring = nx.Graph()\n",
    "    for atom in range(size):\n",
    "        ring.add_node(atom, feature=np.array([0, 0, 0, 0, 0]))\n",
    "        if atom > 0:\n",
    "            ring.add_edge(atom, atom - 1)\n",
    "        if atom == size - 1:\n",
    "            # Close the ring once the last node is in place.\n",
    "            ring.add_edge(atom, 0)\n",
    "    return ring, [0] * size\n",
    "def create_methane(size, node_feature_mean, std):\n",
    "    \"\"\"Build a five-node star: node 0 bonded to nodes 1..4.\n",
    "\n",
    "    Every node carries the constant feature [1, 1, 1, 1, 1]. size,\n",
    "    node_feature_mean and std are accepted but not used.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i for all five nodes.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    # The bond (0, 4) needs a fifth node. It used to be created implicitly by\n",
    "    # add_edge, leaving node 4 without a feature and without a role entry.\n",
    "    n_atoms = 5\n",
    "    role_id = list(range(n_atoms))\n",
    "    for i in range(n_atoms):\n",
    "        G.add_node(i, feature=np.array([1, 1, 1, 1, 1]))\n",
    "    G.add_edges_from((0, i) for i in range(1, n_atoms))\n",
    "    return G, role_id\n",
    "def create_ethane(size, node_feature_mean, std):\n",
    "    \"\"\"Build an eight-node graph: node 0 bonded to 1..4, node 4 bonded to 5..7.\n",
    "\n",
    "    Every node carries the constant feature [2, 2, 2, 2, 2]. size,\n",
    "    node_feature_mean and std are accepted but not used.\n",
    "\n",
    "    Returns (graph, role_id) with role_id[i] == i for all eight nodes.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    bonds = [(0, 1), (0, 2), (0, 3), (0, 4), (4, 5), (4, 6), (4, 7)]\n",
    "    # The bonds reference nodes 6 and 7. They used to be created implicitly by\n",
    "    # add_edge, leaving them without a feature and without a role entry.\n",
    "    n_atoms = 8\n",
    "    role_id = list(range(n_atoms))\n",
    "    for i in range(n_atoms):\n",
    "        G.add_node(i, feature=np.array([2, 2, 2, 2, 2]))\n",
    "    G.add_edges_from(bonds)\n",
    "    return G, role_id\n",
    "def create_benzoic_acid(size, node_feature_mean, std):\n",
    "    \"\"\"Build a six-node ring (nodes 0-5) with a five-node chain (nodes 6-10) on node 0.\n",
    "\n",
    "    Every node carries the constant feature [3, 3, 3, 3, 3]. size,\n",
    "    node_feature_mean and std are accepted but not used; the dead\n",
    "    int(np.sqrt(size)) computation was removed so that any `size` value works.\n",
    "\n",
    "    Returns (graph, role_id): role 0 for ring nodes, role 1 for chain nodes.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    benzene_ring = list(range(6))\n",
    "    carboxyl_group = list(range(6, 11))\n",
    "\n",
    "    for role, group in enumerate((benzene_ring, carboxyl_group)):\n",
    "        for node_idx in group:\n",
    "            G.add_node(node_idx, feature=np.array([3, 3, 3, 3, 3]))\n",
    "            role_id.append(role)\n",
    "\n",
    "    # Close the ring.\n",
    "    for i in range(len(benzene_ring)):\n",
    "        G.add_edge(benzene_ring[i], benzene_ring[(i + 1) % len(benzene_ring)])\n",
    "\n",
    "    # Attach the side chain to ring node 0 and link it up as a path.\n",
    "    G.add_edge(benzene_ring[0], carboxyl_group[0])\n",
    "    for i in range(len(carboxyl_group) - 1):\n",
    "        G.add_edge(carboxyl_group[i], carboxyl_group[i + 1])\n",
    "\n",
    "    return G, role_id\n",
    "def create_nitrobenzene(size, node_feature_mean, std):\n",
    "    \"\"\"Build a six-node ring (nodes 0-5) with a three-node group (nodes 6-8) on node 0.\n",
    "\n",
    "    Node 6 is bonded to ring node 0 and to nodes 7 and 8. Every node carries\n",
    "    the constant feature [4, 4, 4, 4, 4]. size, node_feature_mean and std are\n",
    "    accepted but not used; the dead int(np.sqrt(size)) computation was removed\n",
    "    so that any `size` value works.\n",
    "\n",
    "    Returns (graph, role_id): role 0 for ring nodes, role 1 for group nodes.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    benzene_ring = list(range(6))\n",
    "    nitro_group = list(range(6, 9))\n",
    "\n",
    "    for role, group in enumerate((benzene_ring, nitro_group)):\n",
    "        for node_idx in group:\n",
    "            G.add_node(node_idx, feature=np.array([4, 4, 4, 4, 4]))\n",
    "            role_id.append(role)\n",
    "\n",
    "    # Close the ring.\n",
    "    for i in range(len(benzene_ring)):\n",
    "        G.add_edge(benzene_ring[i], benzene_ring[(i + 1) % len(benzene_ring)])\n",
    "\n",
    "    G.add_edge(benzene_ring[0], nitro_group[0])\n",
    "    G.add_edge(nitro_group[0], nitro_group[1])\n",
    "    G.add_edge(nitro_group[0], nitro_group[2])\n",
    "\n",
    "    return G, role_id\n",
    "def create_ethanol(size, node_feature_mean, std):\n",
    "    \"\"\"Build a three-node cycle (nodes 0-2) with a fourth node attached to node 2.\n",
    "\n",
    "    Every node receives its own np.random.normal draw under 'feature'. `size`\n",
    "    is accepted but not used; the dead int(np.sqrt(size)) computation was\n",
    "    removed so that any `size` value works.\n",
    "\n",
    "    Returns (graph, role_id): role 0 for nodes 0-2, role 1 for node 3.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "\n",
    "    ethanol_structure = []\n",
    "    for node_idx in range(3):\n",
    "        G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(0)\n",
    "        ethanol_structure.append(node_idx)\n",
    "\n",
    "    for i in range(len(ethanol_structure)):\n",
    "        G.add_edge(ethanol_structure[i], ethanol_structure[(i + 1) % len(ethanol_structure)])\n",
    "\n",
    "    hydroxyl_group_node = 3\n",
    "    # This node keeps its own draw. It used to be overwritten right away with\n",
    "    # the leftover feature object of node 2 from the loop above.\n",
    "    G.add_node(hydroxyl_group_node, feature=np.random.normal(node_feature_mean, std))\n",
    "    role_id.append(1)\n",
    "\n",
    "    G.add_edge(ethanol_structure[-1], hydroxyl_group_node)\n",
    "\n",
    "    return G, role_id\n",
    "def create_thioether(size, node_feature_mean, std):\n",
    "    \"\"\"Build a simple path 0 - 1 - ... - (size-1).\n",
    "\n",
    "    Every node receives one np.random.normal draw under 'feature'.\n",
    "\n",
    "    Returns (graph, role_id) with role 0 for every node.\n",
    "    \"\"\"\n",
    "    chain = nx.Graph()\n",
    "    for atom in range(size):\n",
    "        chain.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "    chain.add_edges_from((atom, atom + 1) for atom in range(size - 1))\n",
    "    return chain, [0] * size\n",
    "def create_simplified_dopamine(size, node_feature_mean, std):\n",
    "    \"\"\"Build a fixed 11-node graph (nodes numbered 1..11) from a hard-coded bond list.\n",
    "\n",
    "    Nodes 1-6 form a ring; nodes 7, 8, 9 and 10 hang off ring nodes 1, 2, 4\n",
    "    and 6, and node 11 hangs off node 10. Every node receives one\n",
    "    np.random.normal draw under 'feature' and a 'role_id' attribute of 0.\n",
    "    `size` is accepted but not used.\n",
    "\n",
    "    Returns (graph, roles) where roles lists the 'role_id' of every node.\n",
    "    \"\"\"\n",
    "    bonds = [\n",
    "        (1, 2), (2, 3), (3, 4), (4, 5), (5, 6),\n",
    "        (6, 1), (1, 7), (2, 8), (4, 9), (6, 10),\n",
    "        (10, 11),\n",
    "    ]\n",
    "    molecule = nx.Graph()\n",
    "    for atom in range(1, 12):\n",
    "        molecule.add_node(atom, feature=np.random.normal(node_feature_mean, std), role_id=0)\n",
    "    molecule.add_edges_from(bonds)\n",
    "\n",
    "    return molecule, [molecule.nodes[atom]['role_id'] for atom in molecule.nodes]\n",
    "def create_hexamethylbenzene(size, node_feature_mean, std):\n",
    "    \"\"\"Build a six-node ring (nodes 0-5) with nodes 6..size-1 attached around it.\n",
    "\n",
    "    Extra node i is bonded to ring node i % 6. Every node receives one\n",
    "    np.random.normal draw under 'feature'. The dead int(np.sqrt(size))\n",
    "    computation was removed; `size` only bounds the number of extra nodes.\n",
    "\n",
    "    Returns (graph, role_id): role 0 for ring nodes, role 1 for extra nodes.\n",
    "    \"\"\"\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "\n",
    "    hexamethylbenzene_structure = []\n",
    "    for node_idx in range(6):\n",
    "        G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(0)\n",
    "        hexamethylbenzene_structure.append(node_idx)\n",
    "\n",
    "    for i in range(len(hexamethylbenzene_structure) - 1):\n",
    "        G.add_edge(hexamethylbenzene_structure[i], hexamethylbenzene_structure[i + 1])\n",
    "    # Close the ring.\n",
    "    G.add_edge(hexamethylbenzene_structure[-1], hexamethylbenzene_structure[0])\n",
    "\n",
    "    for node_idx in range(6, size):\n",
    "        G.add_node(node_idx, feature=np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(1)\n",
    "        G.add_edge(node_idx, hexamethylbenzene_structure[node_idx % 6])\n",
    "    return G, role_id\n",
    "def create_acetic_acid(size, node_feature_mean, std):\n",
    "    \"\"\"Build a fixed six-node graph in three role groups.\n",
    "\n",
    "    Nodes 0-1 get role 0 (the original calls them carbon atoms), nodes 2-3\n",
    "    role 1 (hydrogen atoms) and nodes 4-5 role 2 (oxygen atoms). Every node\n",
    "    receives one np.random.normal draw under 'feature'. `size` is accepted\n",
    "    but not used.\n",
    "\n",
    "    Returns (graph, role_id).\n",
    "    \"\"\"\n",
    "    molecule = nx.Graph()\n",
    "    roles = []\n",
    "    for role, atoms in enumerate((range(0, 2), range(2, 4), range(4, 6))):\n",
    "        for atom in atoms:\n",
    "            molecule.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "            roles.append(role)\n",
    "\n",
    "    c0, c1, h0, h1, o0, o1 = range(6)\n",
    "    molecule.add_edges_from([(c0, h0), (c0, c1), (c1, o0), (c1, o1), (o0, h1)])\n",
    "\n",
    "    return molecule, roles\n",
    "def create_ammonia(size, node_feature_mean, std):\n",
    "    \"\"\"Build a four-node star: node 0 bonded to nodes 1, 2 and 3.\n",
    "\n",
    "    Every node receives one np.random.normal draw under 'feature'. `size` is\n",
    "    accepted but not used.\n",
    "\n",
    "    Returns (graph, role_id): role 0 for node 0, role 1 for the other three.\n",
    "    \"\"\"\n",
    "    molecule = nx.Graph()\n",
    "    roles = []\n",
    "    for atom in range(4):\n",
    "        molecule.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "        roles.append(0 if atom == 0 else 1)\n",
    "    molecule.add_edges_from((0, outer) for outer in range(1, 4))\n",
    "    return molecule, roles\n",
    "def create_vitamin_c(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 18-node vitamin C motif.\n",
    "\n",
    "    Nodes 0-5 are carbons (role 0), 6-11 hydrogens (role 1) and 12-17\n",
    "    oxygens (role 2). `size` is accepted but not used. Every node receives\n",
    "    a `feature` drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    carbon, hydrogen, oxygen = range(0, 6), range(6, 12), range(12, 18)\n",
    "    roles = [0] * len(carbon) + [1] * len(hydrogen) + [2] * len(oxygen)\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # Bonds are listed in the order they must be inserted.\n",
    "    graph.add_edges_from([\n",
    "        (carbon[0], carbon[1]),\n",
    "        (carbon[1], carbon[2]),\n",
    "        (carbon[2], carbon[3]),\n",
    "        (carbon[2], carbon[4]),\n",
    "        (carbon[3], carbon[5]),\n",
    "        (carbon[0], oxygen[0]),\n",
    "        (carbon[1], hydrogen[0]),\n",
    "        (carbon[1], hydrogen[1]),\n",
    "        (carbon[3], hydrogen[2]),\n",
    "        (carbon[3], hydrogen[3]),\n",
    "        (carbon[4], hydrogen[4]),\n",
    "        (carbon[5], hydrogen[5]),\n",
    "        (carbon[5], oxygen[1]),\n",
    "        (oxygen[1], oxygen[2]),\n",
    "        (oxygen[1], oxygen[3]),\n",
    "        (oxygen[2], oxygen[4]),\n",
    "        (oxygen[2], oxygen[5]),\n",
    "    ])\n",
    "    return graph, roles\n",
    "def create_adrenaline(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 19-node adrenaline motif.\n",
    "\n",
    "    Nodes 0-5 are carbons (role 0), 6-12 hydrogens (role 1), 13-14\n",
    "    nitrogens (role 2) and 15-18 oxygens (role 3). `size` is accepted but\n",
    "    not used. Every node receives a `feature` drawn from\n",
    "    np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    carbon, hydrogen = range(0, 6), range(6, 13)\n",
    "    nitrogen, oxygen = range(13, 15), range(15, 19)\n",
    "    roles = [0] * len(carbon) + [1] * len(hydrogen) + [2] * len(nitrogen) + [3] * len(oxygen)\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # Bonds are listed in the order they must be inserted.\n",
    "    graph.add_edges_from([\n",
    "        (carbon[0], carbon[1]),\n",
    "        (carbon[1], carbon[2]),\n",
    "        (carbon[1], carbon[3]),\n",
    "        (carbon[2], carbon[4]),\n",
    "        (carbon[2], carbon[5]),\n",
    "        (carbon[5], carbon[3]),\n",
    "        (carbon[5], oxygen[0]),\n",
    "        (carbon[5], oxygen[1]),\n",
    "        (carbon[0], hydrogen[0]),\n",
    "        (carbon[1], hydrogen[1]),\n",
    "        (carbon[3], hydrogen[2]),\n",
    "        (carbon[4], hydrogen[3]),\n",
    "        (carbon[4], oxygen[2]),\n",
    "        (carbon[4], oxygen[3]),\n",
    "        (carbon[5], hydrogen[4]),\n",
    "        (nitrogen[0], carbon[0]),\n",
    "        (nitrogen[0], hydrogen[5]),\n",
    "        (nitrogen[1], carbon[0]),\n",
    "        (nitrogen[1], carbon[1]),\n",
    "        (nitrogen[1], carbon[2]),\n",
    "        (nitrogen[1], hydrogen[6]),\n",
    "    ])\n",
    "    return graph, roles\n",
    "def create_glucose(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 24-node glucose motif.\n",
    "\n",
    "    Nodes 0-5 are carbons (role 0), 6-11 oxygens (role 1) and 12-23\n",
    "    hydrogens (role 2). The carbons form a chain; carbon i is also bonded to\n",
    "    oxygen i and to hydrogens 2i and 2i+1. `size` is accepted but not used.\n",
    "    Every node receives a `feature` drawn from\n",
    "    np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    carbon, oxygen, hydrogen = range(0, 6), range(6, 12), range(12, 24)\n",
    "    roles = [0] * len(carbon) + [1] * len(oxygen) + [2] * len(hydrogen)\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    for k in range(5):\n",
    "        graph.add_edges_from([\n",
    "            (carbon[k], carbon[k + 1]),\n",
    "            (carbon[k], oxygen[k]),\n",
    "            (carbon[k], hydrogen[2 * k]),\n",
    "            (carbon[k], hydrogen[2 * k + 1]),\n",
    "        ])\n",
    "    # The last carbon has no successor in the chain, only its substituents.\n",
    "    graph.add_edges_from([\n",
    "        (carbon[5], oxygen[5]),\n",
    "        (carbon[5], hydrogen[10]),\n",
    "        (carbon[5], hydrogen[11]),\n",
    "    ])\n",
    "    return graph, roles\n",
    "def create_fullerenes(size, node_feature_mean, std):\n",
    "    \"\"\"Build the fullerene motif as a single 60-node cycle of carbons (role 0).\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_carbon = 60\n",
    "    roles = [0] * n_carbon\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(n_carbon):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # i -> i+1 for every node, wrapping the last node back to node 0.\n",
    "    graph.add_edges_from((atom, (atom + 1) % n_carbon) for atom in range(n_carbon))\n",
    "    return graph, roles\n",
    "def create_pyridine(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 7-node pyridine motif: a 6-carbon ring (role 0) with a nitrogen (node 6, role 1) on carbon 0.\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    ring_size = 6\n",
    "    nitrogen = ring_size\n",
    "    roles = [0] * ring_size + [1]\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    graph.add_edge(0, nitrogen)\n",
    "    graph.add_edges_from((atom, (atom + 1) % ring_size) for atom in range(ring_size))\n",
    "    return graph, roles\n",
    "def create_pyrrole(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 6-node pyrrole motif: a 5-carbon ring (role 0) with a nitrogen (node 5, role 1) on carbon 0.\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    ring_size = 5\n",
    "    nitrogen = ring_size\n",
    "    roles = [0] * ring_size + [1]\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    graph.add_edge(0, nitrogen)\n",
    "    graph.add_edges_from((atom, (atom + 1) % ring_size) for atom in range(ring_size))\n",
    "    return graph, roles\n",
    "def create_indole(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 10-node indole motif: nine carbons (role 0) closed into a ring plus a nitrogen (node 9, role 1) on carbon 0.\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_carbon = 9\n",
    "    nitrogen = n_carbon\n",
    "    roles = [0] * n_carbon + [1]\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # Attach the nitrogen and the ring-closing bond first, then the carbon chain.\n",
    "    graph.add_edges_from([(0, nitrogen), (0, n_carbon - 1)])\n",
    "    graph.add_edges_from((atom, atom + 1) for atom in range(n_carbon - 1))\n",
    "    return graph, roles\n",
    "def create_thiazole(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 6-node thiazole motif.\n",
    "\n",
    "    Nodes 0-3 are carbons (role 0), node 4 is the nitrogen (role 1) and\n",
    "    node 5 the sulfur (role 2). `size` is accepted but not used. Every node\n",
    "    receives a `feature` drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_carbon = 4\n",
    "    nitrogen, sulfur = n_carbon, n_carbon + 1\n",
    "    roles = [0] * n_carbon + [1, 2]\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    graph.add_edges_from([(0, nitrogen), (0, n_carbon - 1), (n_carbon - 1, sulfur)])\n",
    "    graph.add_edges_from((atom, atom + 1) for atom in range(n_carbon - 1))\n",
    "    return graph, roles\n",
    "def create_imidazole(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 7-node imidazole motif: a 5-carbon ring (role 0) with nitrogens (nodes 5-6, role 1) on carbons 1 and 3.\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    carbon, nitrogen = range(0, 5), range(5, 7)\n",
    "    roles = [0] * len(carbon) + [1] * len(nitrogen)\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    graph.add_edges_from((carbon[k], carbon[(k + 1) % len(carbon)]) for k in range(len(carbon)))\n",
    "    graph.add_edges_from([(carbon[1], nitrogen[0]), (carbon[3], nitrogen[1])])\n",
    "    return graph, roles\n",
    "def create_pyrimidine(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 8-node pyrimidine motif: an open 6-carbon chain (role 0) with nitrogens (nodes 6-7, role 1) on carbons 0 and 2.\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    carbon, nitrogen = range(0, 6), range(6, 8)\n",
    "    roles = [0] * len(carbon) + [1] * len(nitrogen)\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # The carbon chain is not closed: there is no bond from carbon 5 back to carbon 0.\n",
    "    graph.add_edges_from((carbon[k], carbon[k + 1]) for k in range(len(carbon) - 1))\n",
    "    graph.add_edges_from([(carbon[0], nitrogen[0]), (carbon[2], nitrogen[1])])\n",
    "    return graph, roles\n",
    "def create_porphyrin(size, node_feature_mean, std):\n",
    "    \"\"\"Build the porphyrin motif as a single 24-node cycle of carbons (role 0).\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_carbon = 24\n",
    "    roles = [0] * n_carbon\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(n_carbon):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # i -> i+1 for every node, wrapping the last node back to node 0.\n",
    "    graph.add_edges_from((atom, (atom + 1) % n_carbon) for atom in range(n_carbon))\n",
    "    return graph, roles\n",
    "def create_nitrophenol(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 7-node nitrophenol motif.\n",
    "\n",
    "    Nodes 0-5 have role 0 and node 6 has role 1. Nodes 0..6 form a cycle and\n",
    "    node 6 carries one extra bond to node 1. `size` is accepted but not used.\n",
    "    Every node receives a `feature` drawn from\n",
    "    np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_atoms = 7\n",
    "    last = n_atoms - 1\n",
    "    roles = [0 if atom < 6 else 1 for atom in range(n_atoms)]\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(n_atoms):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    graph.add_edges_from((atom, atom + 1) for atom in range(last))\n",
    "    # Close the cycle, then add the chord from the last node to node 1.\n",
    "    graph.add_edges_from([(last, 0), (last, 1)])\n",
    "    return graph, roles\n",
    "def create_hydrated_sulfuric_acid(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 8-node hydrated sulfuric acid motif.\n",
    "\n",
    "    Nodes 0-3 (role 0) form a path 0-1-2-3; each node k in 4-7 (role 1) is\n",
    "    bonded to backbone node k-4. `size` is accepted but not used. Every node\n",
    "    receives a `feature` drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    backbone, pendant = range(0, 4), range(4, 8)\n",
    "    roles = [0] * len(backbone) + [1] * len(pendant)\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(len(roles)):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    graph.add_edges_from((backbone[k], backbone[k + 1]) for k in range(len(backbone) - 1))\n",
    "    graph.add_edges_from(zip(pendant, backbone))\n",
    "    return graph, roles\n",
    "def create_methyl_anthranilate(size, node_feature_mean, std):\n",
    "    \"\"\"Build the 11-node methyl anthranilate motif.\n",
    "\n",
    "    Nodes 0-6 have role 0 and nodes 7-10 have role 1. Nodes 0..9 form a\n",
    "    path, and nodes 8, 9 and 10 are each bonded back to node 0. `size` is\n",
    "    accepted but not used. Every node receives a `feature` drawn from\n",
    "    np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_atoms = 11\n",
    "    roles = [0 if atom < 7 else 1 for atom in range(n_atoms)]\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(n_atoms):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # Path over all but the last node: 0-1, 1-2, ..., 8-9.\n",
    "    graph.add_edges_from((atom, atom + 1) for atom in range(n_atoms - 2))\n",
    "    # The last three nodes all bond back to node 0.\n",
    "    graph.add_edges_from((atom, 0) for atom in range(n_atoms - 3, n_atoms))\n",
    "    return graph, roles\n",
    "def create_anthracene(size, node_feature_mean, std):\n",
    "    \"\"\"Build the anthracene motif as a single 9-node cycle (all role 0).\n",
    "\n",
    "    `size` is accepted but not used. Every node receives a `feature`\n",
    "    drawn from np.random.normal(node_feature_mean, std).\n",
    "    Returns (graph, roles) where roles[i] is the role of node i.\n",
    "    \"\"\"\n",
    "    n_atoms = 9\n",
    "    roles = [0] * n_atoms\n",
    "\n",
    "    graph = nx.Graph()\n",
    "    for atom in range(n_atoms):\n",
    "        graph.add_node(atom, feature=np.random.normal(node_feature_mean, std))\n",
    "\n",
    "    # i -> i+1 for every node, wrapping the last node back to node 0.\n",
    "    graph.add_edges_from((atom, (atom + 1) % n_atoms) for atom in range(n_atoms))\n",
    "    return graph, roles\n",
    "\n",
    "\n",
    "# Registry of molecule-shaped graph builders.\n",
    "# Maps an integer id (1-26) to a (builder function, description string) pair.\n",
    "# The description strings are runtime data; several are written in Chinese.\n",
    "molecular_generators = {\n",
    "    1: (create_benzene_ring, \"Benzene ring, adjustable node features\"),\n",
    "    2: (create_methane, \"Methane, adjustable node features\"),\n",
    "    3: (create_ethane, \"Ethane, adjustable node features\"),\n",
    "    4: (create_benzoic_acid, \"Benzoic acid, adjustable node features\"),\n",
    "    5: (create_nitrobenzene, \"Nitrobenzene, adjustable node features\"),\n",
    "    6: (create_ethanol, \"乙醇,adjustable node features\"),\n",
    "    7: (create_thioether, \"硫醚,adjustable node features\"),\n",
    "    8: (create_simplified_dopamine, \"简化多巴胺,adjustable node features\"),\n",
    "    9: (create_hexamethylbenzene, \"六甲基苯,adjustable node features\"),\n",
    "    10: (create_acetic_acid, \"乙酸,adjustable node features\"),\n",
    "    11: (create_ammonia, \"氨,特征可调\"),\n",
    "    12: (create_vitamin_c, \"维生素C,特征可调\"),\n",
    "    13: (create_adrenaline, \"肾上腺素,特征可调\"),\n",
    "    14: (create_glucose, \"葡萄糖,特征可调\"),\n",
    "    15: (create_fullerenes, \"富勒烯,特征可调\"),\n",
    "    16: (create_pyridine, \"吡啶,特征可调\"),\n",
    "    17: (create_pyrrole, \"吡咯,特征可调\"),\n",
    "    18: (create_indole, \"吲哚,特征可调\"),\n",
    "    19: (create_thiazole, \"噻唑,特征可调\"),\n",
    "    20: (create_imidazole, \"咪唑,特征可调\"),\n",
    "    21: (create_pyrimidine, \"嘧啶,特征可调\"),\n",
    "    22: (create_porphyrin, \"卟啉,特征可调\"),\n",
    "    23: (create_nitrophenol, \"硝基酚,特征可调\"),\n",
    "    24: (create_hydrated_sulfuric_acid, \"水合硫酸,特征可调\"),\n",
    "    25: (create_methyl_anthranilate, \"甲基蒽酸甲酯,特征可调\"),\n",
    "    26: (create_anthracene, \"蒽,特征可调\")\n",
    "}\n",
    "\n",
    "\n",
    "# Registry of abstract-shape (motif) graph builders.\n",
    "# Maps an integer id (1-25) to a (builder function, description string) pair.\n",
    "# The description strings are runtime data, mostly written in Chinese.\n",
    "motif_generators = {\n",
    "    1: (create_motif_star_branch, \"星形,星的节点数,分支,adjustable node features\"),\n",
    "    2: (create_motif_path_branch, \"路径形,路径节点数,分支,adjustable node features\"),\n",
    "    3: (create_motif_fan_branch, \"扇形,扇形的节点数,分支,adjustable node features\"),\n",
    "    4: (create_motif_cuspedPolygon_branch, \"尖角多边形,尖角多边形的节点数,分支,adjustable node features\"),\n",
    "    5: (create_motif_random_bipartite_branch, \"随机二分图,随机二分图的节点数,分支,adjustable node features\"),\n",
    "    6: (create_motif_tree_branch, \"树形,树形的节点数,分支,adjustable node features\"),\n",
    "    7: (create_motif_trident_branch, \"三叉戟形,三叉戟形的节点数,分支,adjustable node features\"),\n",
    "    8: (create_motif_conicalConnection_branch, \"锥形连接图,锥形连接图的节点数,分支,adjustable node features\"),\n",
    "    9: (create_motif_chainBypass_branch, \"链旁路形,链旁路形图的节点数,分支,adjustable node features\"),\n",
    "    # NOTE(review): id 10 is described as a partial polygon but points at\n",
    "    # create_motif_trident_branch, the same builder as id 7 -- looks like a\n",
    "    # copy-paste slip; confirm which builder was intended.\n",
    "    10: (create_motif_trident_branch, \"部分多边形,部分多边形的节点数,分支,adjustable node features\"),\n",
    "    11: (create_motif_completeGraph, \"完全图,完全图的节点数,特征可调\"),\n",
    "    12: (create_motif_dircycle, \"循环图,循环图的节点数,特征可调\"),\n",
    "    13: (create_motif_dualRing, \"双环图,双环图的节点数,特征可调\"),\n",
    "    14: (create_motif_triangle, \"三角图,三角图的节点数,特征可调\"),\n",
    "    15: (create_motif_ringShape, \"带环形状图,带环形状图的节点数,特征可调\"),\n",
    "    16: (create_motif_diamond, \"钻石图,钻石图的节点数,特征可调\"),\n",
    "    17: (create_motif_HShape, \"H形图,H形图的节点数,特征可调\"),\n",
    "    18: (create_motif_wheel, \"车轮图,车轮图的节点数,特征可调\"),\n",
    "    19: (create_motif_hourglass, \"沙漏图,沙漏图的节点数,特征可调\"),\n",
    "    20: (create_motif_DCD, \"DCD三钻石形状图,DCD三钻石图的节点数,特征可调\"),\n",
    "    21: (create_motif_Cyclocross, \"循环十字图,循环十字图的节点数,特征可调\"),\n",
    "    22: (create_motif_netShape, \"网形图,网形图的节点数,特征可调\"),\n",
    "    23: (create_motif_ladder, \"梯子图,梯子图的节点数,特征可调\"),\n",
    "    24: (create_motif_bowtie, \"领结图,领结图的节点数,特征可调\"),\n",
    "    25: (create_motif_cross, \"十字臂图,十字臂图的节点数,特征可调\")\n",
    "}\n",
    "\n",
    "\n",
    "def include_smaller_graph(G1, G2):\n",
    "    \"\"\"Overlay the smaller of two graphs onto the larger one via shared node labels.\n",
    "\n",
    "    The node labels of both graphs are merged into a new graph (node\n",
    "    attributes are not copied). An edge of the smaller graph is kept only\n",
    "    when both endpoints are also labels of the larger graph; every edge of\n",
    "    the larger graph is kept. If the result is disconnected, consecutive\n",
    "    components are bridged with one random edge each.\n",
    "\n",
    "    Args:\n",
    "        G1, G2: networkx graphs with non-negative integer node labels.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id, edge_index): the merged graph; a list with one role per\n",
    "        node (0 everywhere except the last up-to-six entries, which are\n",
    "        random integers in 1..3); and a 2 x E long tensor of the edges of G\n",
    "        expressed in G's own node labels. A graph without edges yields a\n",
    "        single placeholder self-loop.\n",
    "    \"\"\"\n",
    "    # Make G1 the smaller graph.\n",
    "    if len(G1.nodes()) > len(G2.nodes()):\n",
    "        G1, G2 = G2, G1\n",
    "\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(G2.nodes())\n",
    "    G.add_nodes_from(G1.nodes())\n",
    "\n",
    "    G.add_edges_from((u, v) for u, v in G1.edges() if u in G2 and v in G2)\n",
    "    G.add_edges_from(G2.edges())\n",
    "\n",
    "    # random.sample() rejects sets on Python >= 3.11, so each component is\n",
    "    # materialised as a list before sampling from it.\n",
    "    components = [list(component) for component in nx.connected_components(G)]\n",
    "    for left, right in zip(components, components[1:]):\n",
    "        u = random.sample(left, 1)[0]\n",
    "        v = random.sample(right, 1)[0]\n",
    "        G.add_edge(u, v)\n",
    "\n",
    "    num_nodes = G.number_of_nodes()\n",
    "    role_id = [0] * num_nodes\n",
    "    # Never tag more entries than there are nodes; a fixed [-6:] slice would\n",
    "    # grow the list beyond the node count on graphs with fewer than 6 nodes.\n",
    "    num_tagged = min(6, num_nodes)\n",
    "    if num_tagged:\n",
    "        role_id[-num_tagged:] = torch.randint(1, 4, (num_tagged,)).tolist()\n",
    "\n",
    "    if G.number_of_edges() == 0:\n",
    "        # Placeholder self-loop so downstream code always receives a 2 x E tensor.\n",
    "        default_value = random.randint(0, max(num_nodes - 1, 0))\n",
    "        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "    else:\n",
    "        edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "    return G, role_id, edge_index\n",
    "\n",
    "def multiple_edges_connection(G1, G2, common_edges):\n",
    "    \"\"\"Join two graphs through cross edges derived from randomly chosen edges.\n",
    "\n",
    "    `common_edges` edges are sampled from each graph (clamped to the smaller\n",
    "    edge count). The sampled edges of G2 are dropped, the graphs are combined\n",
    "    with nx.disjoint_union, and for every sampled G1 edge (u, v) a cross\n",
    "    edge u -- v + |G1| is added when that target exists in the union.\n",
    "    Isolated nodes and stray components are then wired into the graph with\n",
    "    random edges. The caller's graphs are left unmodified.\n",
    "\n",
    "    Args:\n",
    "        G1, G2: networkx graphs; G1 is assumed to use labels 0..|G1|-1 so that\n",
    "            its labels survive disjoint_union unchanged.\n",
    "        common_edges: number of edges to sample from each graph.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id, edge_index): the joined graph; a list with one role per\n",
    "        node (0 everywhere except the last up-to-six entries, which are\n",
    "        random integers in 1..3); and a 2 x E long tensor of the edges of G.\n",
    "        A graph without edges yields a single placeholder self-loop.\n",
    "    \"\"\"\n",
    "    common_edges = min(common_edges, G1.number_of_edges(), G2.number_of_edges())\n",
    "\n",
    "    # random.sample() needs a real sequence; edge views are rejected on Python >= 3.11.\n",
    "    common_edges1 = random.sample(list(G1.edges()), common_edges)\n",
    "    common_edges2 = random.sample(list(G2.edges()), common_edges)\n",
    "\n",
    "    # Work on a copy so the caller's G2 keeps its edges.\n",
    "    G2 = G2.copy()\n",
    "    G2.remove_edges_from(common_edges2)\n",
    "\n",
    "    G = nx.disjoint_union(G1, G2)\n",
    "    offset = G1.number_of_nodes()\n",
    "    total_nodes = G.number_of_nodes()\n",
    "    for u, v in common_edges1:\n",
    "        # Skip targets outside the union; add_edge would otherwise create a\n",
    "        # phantom node whenever G2 has fewer nodes than G1.\n",
    "        if v + offset < total_nodes:\n",
    "            G.add_edge(u, v + offset)\n",
    "\n",
    "    isolated_nodes = [n for n in G.nodes() if G.degree(n) == 0]\n",
    "    for u in isolated_nodes:\n",
    "        # Choose among valid partners directly so a one-node graph cannot loop forever.\n",
    "        candidates = [n for n in G.nodes() if n != u and not G.has_edge(u, n)]\n",
    "        if candidates:\n",
    "            G.add_edge(u, random.choice(candidates))\n",
    "\n",
    "    if G.number_of_nodes() > 0 and not nx.is_connected(G):\n",
    "        largest_component = max(nx.connected_components(G), key=len)\n",
    "        anchors = list(largest_component)\n",
    "        stragglers = [n for n in G.nodes() if n not in largest_component]\n",
    "        for u in stragglers:\n",
    "            G.add_edge(u, random.choice(anchors))\n",
    "\n",
    "    num_nodes = G.number_of_nodes()\n",
    "    role_id = [0] * num_nodes\n",
    "    # Never tag more entries than there are nodes; a fixed [-6:] slice would\n",
    "    # grow the list beyond the node count on graphs with fewer than 6 nodes.\n",
    "    num_tagged = min(6, num_nodes)\n",
    "    if num_tagged:\n",
    "        role_id[-num_tagged:] = torch.randint(1, 4, (num_tagged,)).tolist()\n",
    "\n",
    "    if G.number_of_edges() == 0:\n",
    "        # Placeholder self-loop so downstream code always receives a 2 x E tensor.\n",
    "        default_value = random.randint(0, max(num_nodes - 1, 0))\n",
    "        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "    else:\n",
    "        edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "    return G, role_id, edge_index\n",
    "\n",
    "def multiple_nodes_connection(G1, G2, common_nodes):\n",
    "    \"\"\"Join two graphs through cross edges derived from randomly chosen nodes.\n",
    "\n",
    "    `common_nodes` nodes are sampled from each graph (clamped to the smaller\n",
    "    node count). The sampled nodes of G2 are dropped, the graphs are combined\n",
    "    with nx.disjoint_union, and every sampled G1 node n is linked to node\n",
    "    n + |G1| when that target exists in the union. Remaining components are\n",
    "    chained together with random edges. The caller's graphs are left\n",
    "    unmodified.\n",
    "\n",
    "    Args:\n",
    "        G1, G2: networkx graphs; G1 is assumed to use labels 0..|G1|-1 so that\n",
    "            its labels survive disjoint_union unchanged.\n",
    "        common_nodes: number of nodes to sample from each graph.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id, edge_index): the joined graph; a list with one role per\n",
    "        node (0 everywhere except the last up-to-six entries, which are\n",
    "        random integers in 1..3); and a 2 x E long tensor of the edges of G.\n",
    "        A graph without edges yields a single placeholder self-loop.\n",
    "    \"\"\"\n",
    "    # Clamp like multiple_edges_connection does, instead of letting sample() raise.\n",
    "    common_nodes = min(common_nodes, G1.number_of_nodes(), G2.number_of_nodes())\n",
    "\n",
    "    # random.sample() needs a real sequence; node views are rejected on Python >= 3.11.\n",
    "    common_nodes1 = set(random.sample(list(G1.nodes()), common_nodes))\n",
    "    common_nodes2 = set(random.sample(list(G2.nodes()), common_nodes))\n",
    "\n",
    "    # Work on a copy so the caller's G2 keeps its nodes.\n",
    "    G2 = G2.copy()\n",
    "    G2.remove_nodes_from(common_nodes2)\n",
    "\n",
    "    G = nx.disjoint_union(G1, G2)\n",
    "    offset = G1.number_of_nodes()\n",
    "    total_nodes = G.number_of_nodes()\n",
    "    for node in common_nodes1:\n",
    "        if node + offset < total_nodes:\n",
    "            G.add_edge(node, node + offset)\n",
    "\n",
    "    if G.number_of_nodes() > 0 and not nx.is_connected(G):\n",
    "        components = list(nx.connected_components(G))\n",
    "        for left, right in zip(components, components[1:]):\n",
    "            u = random.choice(list(left))\n",
    "            v = random.choice(list(right))\n",
    "            G.add_edge(u, v)\n",
    "\n",
    "    num_nodes = G.number_of_nodes()\n",
    "    role_id = [0] * num_nodes\n",
    "    # Never tag more entries than there are nodes; a fixed [-6:] slice would\n",
    "    # grow the list beyond the node count on graphs with fewer than 6 nodes.\n",
    "    num_tagged = min(6, num_nodes)\n",
    "    if num_tagged:\n",
    "        role_id[-num_tagged:] = torch.randint(1, 4, (num_tagged,)).tolist()\n",
    "\n",
    "    if G.number_of_edges() == 0:\n",
    "        # Placeholder self-loop so downstream code always receives a 2 x E tensor.\n",
    "        default_value = random.randint(0, max(num_nodes - 1, 0))\n",
    "        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "    else:\n",
    "        edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "    return G, role_id, edge_index\n",
    "\n",
    "\n",
    "def adjacent_connection(G1,G2):\n",
    "    \"\"\"Connect two motifs through a single bridge edge.\n",
    "\n",
    "    One random edge of each graph is subdivided by a fresh node and the two\n",
    "    fresh nodes are joined by an edge. Nodes that carry the same label in both\n",
    "    graphs are merged by ``nx.compose`` and receive a random ``role_id``\n",
    "    attribute. The inputs are copied first, so the caller's graphs are not\n",
    "    modified.\n",
    "\n",
    "    Args:\n",
    "        G1: first networkx graph with integer node labels.\n",
    "        G2: second networkx graph with integer node labels.\n",
    "\n",
    "    Returns:\n",
    "        ``(G, role_id, edge_index, node_features)``. ``node_features`` has one\n",
    "        row per node id up to the largest id used by an edge; rows are filled\n",
    "        in iteration order from the nodes that have a ``'feature'`` attribute\n",
    "        and zero-padded afterwards. When either graph has no nodes or no edges\n",
    "        the result is an empty graph, an empty list, an empty long tensor and\n",
    "        a ``0 x 0`` array.\n",
    "    \"\"\"\n",
    "    nodes1 = set(G1.nodes())\n",
    "    nodes2 = set(G2.nodes())\n",
    "    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():\n",
    "        # Same arity as the regular return so callers can always unpack four values.\n",
    "        return nx.Graph(), [], torch.tensor([], dtype=torch.long), np.zeros((0, 0))\n",
    "\n",
    "    # Work on copies: the edge subdivision below must not leak into the caller's graphs.\n",
    "    G1 = G1.copy()\n",
    "    G2 = G2.copy()\n",
    "\n",
    "    common_nodes = nodes1 & nodes2\n",
    "\n",
    "    edge1 = random.choice(list(G1.edges()))\n",
    "    edge2 = random.choice(list(G2.edges()))\n",
    "\n",
    "    highest_label = max(nodes1 | nodes2)\n",
    "    new_node1 = highest_label + 1\n",
    "    new_node2 = highest_label + 2\n",
    "\n",
    "    # Subdivide the chosen edge of each graph with its fresh node.\n",
    "    G1.remove_edge(*edge1)\n",
    "    G1.add_edge(edge1[0], new_node1)\n",
    "    G1.add_edge(new_node1, edge1[1])\n",
    "    G2.remove_edge(*edge2)\n",
    "    G2.add_edge(edge2[0], new_node2)\n",
    "    G2.add_edge(new_node2, edge2[1])\n",
    "\n",
    "    G1.add_node(new_node2)\n",
    "    G2.add_node(new_node1)\n",
    "\n",
    "    G = nx.compose(G1, G2)\n",
    "    G.add_edge(new_node1, new_node2)\n",
    "\n",
    "    for node in common_nodes:\n",
    "        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))\n",
    "\n",
    "    if not nx.is_connected(G):\n",
    "        # Attach every node outside the largest component to a random member of it.\n",
    "        largest_component = max(nx.connected_components(G), key=len)\n",
    "        anchors = list(largest_component)\n",
    "        for u in [n for n in G.nodes() if n not in largest_component]:\n",
    "            G.add_edge(u, random.choice(anchors))\n",
    "\n",
    "    node_features = np.array([\n",
    "        np.asarray(G.nodes[node]['feature'], dtype=float).flatten()\n",
    "        for node in G.nodes if 'feature' in G.nodes[node]\n",
    "    ])\n",
    "    # No node may carry a 'feature' attribute; fall back to zero-width rows then.\n",
    "    feature_dim = node_features.shape[1] if node_features.ndim == 2 else 0\n",
    "\n",
    "    total_nodes = G.number_of_nodes()\n",
    "    role_id = [0] * total_nodes\n",
    "    # Clamp the tail so a graph with fewer than six nodes keeps one role id per node.\n",
    "    tail = min(6, total_nodes)\n",
    "    role_id[-tail:] = torch.randint(1, 4, (tail,)).tolist()\n",
    "\n",
    "    # The bridge edge guarantees at least one edge, so edge_index is never empty here.\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    num_nodes = int(edge_index.max()) + 1\n",
    "\n",
    "    padded_features = np.zeros((num_nodes, feature_dim))\n",
    "    if feature_dim:\n",
    "        rows = min(len(node_features), num_nodes)\n",
    "        padded_features[:rows] = node_features[:rows]\n",
    "\n",
    "    return G,role_id,edge_index, padded_features\n",
    "\n",
    "def reindex_graph(graph, start_index=0):\n",
    "    \"\"\"Copy ``graph`` into a new ``nx.Graph`` with consecutively numbered nodes.\n",
    "\n",
    "    Args:\n",
    "        graph: source networkx graph.\n",
    "        start_index: label given to the first node; later nodes count up from it.\n",
    "\n",
    "    Returns:\n",
    "        ``(new_graph, node_mapping)`` where ``node_mapping`` maps every old\n",
    "        label to its new one. Node and edge attributes are carried over.\n",
    "    \"\"\"\n",
    "    node_mapping = {\n",
    "        old_label: new_label\n",
    "        for new_label, old_label in enumerate(graph.nodes(), start=start_index)\n",
    "    }\n",
    "\n",
    "    relabeled = nx.Graph()\n",
    "    for old_label, new_label in node_mapping.items():\n",
    "        relabeled.add_node(new_label, **graph.nodes[old_label])\n",
    "    for src, dst, attrs in graph.edges(data=True):\n",
    "        relabeled.add_edge(node_mapping[src], node_mapping[dst], **attrs)\n",
    "\n",
    "    return relabeled, node_mapping\n",
    "\n",
    "def feature_connection(graph1, graph2):\n",
    "    \"\"\"Place two graphs side by side and join them with one random edge.\n",
    "\n",
    "    Both graphs are renumbered with ``reindex_graph`` so that ``graph2``\n",
    "    continues where ``graph1`` ends, then one random node of each is linked.\n",
    "\n",
    "    Args:\n",
    "        graph1: first networkx graph, must not be empty.\n",
    "        graph2: second networkx graph, must not be empty.\n",
    "\n",
    "    Returns:\n",
    "        ``(combined_graph, role_id, edge_index, node_features)`` where\n",
    "        ``node_features`` lists the ``'feature'`` attribute of every node\n",
    "        that has one, as plain Python lists.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either graph has no nodes.\n",
    "    \"\"\"\n",
    "    if len(graph1.nodes) == 0 or len(graph2.nodes) == 0:\n",
    "        raise ValueError(\"One or both graphs are empty\")\n",
    "\n",
    "    G1_reindexed, mapping1 = reindex_graph(graph1)\n",
    "    G2_reindexed, mapping2 = reindex_graph(graph2, start_index=len(G1_reindexed.nodes))\n",
    "\n",
    "    combined_graph = nx.Graph()\n",
    "    for part in (G1_reindexed, G2_reindexed):\n",
    "        combined_graph.add_nodes_from(part.nodes(data=True))\n",
    "        combined_graph.add_edges_from(part.edges(data=True))\n",
    "\n",
    "    node1 = random.choice(list(G1_reindexed.nodes))\n",
    "    node2 = random.choice(list(G2_reindexed.nodes))\n",
    "    combined_graph.add_edge(node1, node2)\n",
    "\n",
    "    # np.asarray accepts features stored either as arrays or as plain lists.\n",
    "    node_features = [\n",
    "        np.asarray(attrs['feature']).tolist()\n",
    "        for _, attrs in combined_graph.nodes(data=True) if 'feature' in attrs\n",
    "    ]\n",
    "\n",
    "    total_nodes = combined_graph.number_of_nodes()\n",
    "    role_id = [0] * total_nodes\n",
    "    # Clamp the tail so a graph with fewer than six nodes keeps one role id per node.\n",
    "    tail = min(6, total_nodes)\n",
    "    role_id[-tail:] = torch.randint(1, 4, (tail,)).tolist()\n",
    "\n",
    "    edge_index = torch.tensor(list(combined_graph.edges()), dtype=torch.long).t().contiguous()\n",
    "    return combined_graph, role_id,edge_index,node_features\n",
    "\n",
    "\n",
    "# Registry of motif connectors: key -> (connector function, description).\n",
    "# English reading of the descriptions (the strings themselves are left as written):\n",
    "#   1: adjacent - the two motifs are joined by a single edge\n",
    "#   2: crossing - the two motifs share some vertices; the count is configurable\n",
    "#   3: entangled - the two motifs are joined by several edges; the count is configurable\n",
    "#   4: containment - the motif with fewer vertices is fully contained in the other\n",
    "motif_connectors = {\n",
    "    1: (adjacent_connection, \"相邻,即两个motif通过一条边连接在一起\"),\n",
    "    2: (multiple_nodes_connection, \"交叉,两个motif间公用一些顶点,公用顶点数可配置\"),\n",
    "    3: (multiple_edges_connection, \"纠缠,即两个motif通过多条边连接在一起,边数可配置\"),\n",
    "    4: (include_smaller_graph, \"包含,两个motif中顶点较少的完全被另一个所包含\"),\n",
    "}\n",
    "\n",
    "\n",
    "def generate_graph_dataset(molecular_generators, motif_connectors):\n",
    "    \"\"\"Assemble one random three-motif graph and derive its label.\n",
    "\n",
    "    A motif from generators 1..10 is joined to a motif from generators 11..25\n",
    "    with connector 1, then a third motif from generators 1..10 is attached\n",
    "    with a random connector (2, 3 or 4).\n",
    "\n",
    "    Args:\n",
    "        molecular_generators: mapping ``index -> (generator, ...)``.\n",
    "        motif_connectors: mapping ``index -> (connector, ...)``.\n",
    "\n",
    "    Returns:\n",
    "        ``(G, role_id, label, edge_index)`` with ``label == (m + a) % 3`` for\n",
    "        the first generator index ``m`` and the branch parameter ``a``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "\n",
    "    m = random.randint(1, 10)\n",
    "    a = random.randint(2, 4)\n",
    "\n",
    "    # Each generator is retried until it yields a non-empty graph.\n",
    "    while True:\n",
    "        motif_m, _ = molecular_generators[m][0](random.randint(5, 20), a, list(base), list(base))\n",
    "        if motif_m:\n",
    "            break\n",
    "\n",
    "    k = m + 10 if random.random() < 0.8 else random.randint(11, 25)\n",
    "    while True:\n",
    "        motif_k, _ = molecular_generators[k][0](2 * a, list(base), list(base))\n",
    "        if motif_k:\n",
    "            break\n",
    "\n",
    "    G, _, _, _ = motif_connectors[1][0](motif_m, motif_k)\n",
    "\n",
    "    n = random.randint(1, 10)\n",
    "    while True:\n",
    "        motif_n, _ = molecular_generators[n][0](random.randint(5, 20), int(0.5 * a), list(base), list(base))\n",
    "        if motif_n:\n",
    "            break\n",
    "\n",
    "    r2 = random.randint(2, 4)\n",
    "    connector = motif_connectors[r2][0]\n",
    "    # Connectors 2 and 3 take an extra count argument; connector 4 does not.\n",
    "    if r2 in (2, 3):\n",
    "        G, role_id, edge_index = connector(G, motif_n, 3)\n",
    "    else:\n",
    "        G, role_id, edge_index = connector(G, motif_n)\n",
    "\n",
    "    label = (m + a) % 3\n",
    "    return G, role_id, label,edge_index\n",
    "def generate_Y0():\n",
    "    \"\"\"Build a label-0 sample: motifs 1 and 2 joined by connector 2.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)`` with ``label == 0``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "    motif1, _ = molecular_generators[1][0](random.randint(5, 10), list(base), list(base))\n",
    "    motif2, _ = molecular_generators[2][0](random.randint(5, 10), list(base), list(base))\n",
    "    motifY0, role_id, edge_index = motif_connectors[2][0](motif1, motif2, 2)\n",
    "    return motifY0, role_id, 0, edge_index\n",
    "\n",
    "\n",
    "def generate_Y1():\n",
    "    \"\"\"Build a label-1 sample from motifs 1, 3 and 5.\n",
    "\n",
    "    Motifs 1 and 3 are joined by connector 1 and motif 5 is attached with\n",
    "    connector 2. The third motif comes from generator 5, which matches its\n",
    "    name and the ``motif5_present`` flag that ``generate_real_dataset``\n",
    "    reports for this class.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)`` with ``label == 1``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "    motif1, _ = molecular_generators[1][0](random.randint(5, 10), list(base), list(base))\n",
    "    motif3, _ = molecular_generators[3][0](random.randint(5, 10), list(base), list(base))\n",
    "    G, _, _, _ = motif_connectors[1][0](motif1, motif3)\n",
    "    motif5, _ = molecular_generators[5][0](random.randint(5, 10), list(base), list(base))\n",
    "    motifY1, role_id, edge_index = motif_connectors[2][0](G, motif5, 2)\n",
    "    return motifY1, role_id, 1, edge_index\n",
    "def generate_Y2():\n",
    "    \"\"\"Build a label-2 sample from motifs 1, 2 and 5.\n",
    "\n",
    "    Motifs 1 and 2 are joined by connector 3 and motif 5 is attached with\n",
    "    connector 2. The third motif comes from generator 5, which matches its\n",
    "    name and the presence flags that ``generate_real_dataset`` reports for\n",
    "    this class (motif 3 absent, motif 5 present).\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)`` with ``label == 2``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "    motif1, _ = molecular_generators[1][0](random.randint(5, 10), list(base), list(base))\n",
    "    motif2, _ = molecular_generators[2][0](random.randint(5, 10), list(base), list(base))\n",
    "    G, _, _ = motif_connectors[3][0](motif1, motif2, 2)\n",
    "    motif5, _ = molecular_generators[5][0](random.randint(5, 10), list(base), list(base))\n",
    "    motifY2, role_id, edge_index = motif_connectors[2][0](G, motif5, 2)\n",
    "    return motifY2, role_id, 2, edge_index\n",
    "def generate_Y3():\n",
    "    \"\"\"Build a label-3 sample: motifs 4 and 5 joined by connector 1.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)`` with ``label == 3``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "    motif4, _ = molecular_generators[4][0](random.randint(5, 10), list(base), list(base))\n",
    "    motif5, _ = molecular_generators[5][0](random.randint(5, 10), list(base), list(base))\n",
    "    motifY3, role_id, edge_index, _ = motif_connectors[1][0](motif4, motif5)\n",
    "    return motifY3, role_id, 3, edge_index\n",
    "def generate_Y4():\n",
    "    \"\"\"Build a label-4 sample: motifs 3 and 4 joined by connector 2.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)`` with ``label == 4``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "    motif3, _ = molecular_generators[3][0](random.randint(5, 10), list(base), list(base))\n",
    "    motif4, _ = molecular_generators[4][0](random.randint(5, 10), list(base), list(base))\n",
    "    motifY4, role_id, edge_index = motif_connectors[2][0](motif3, motif4, 2)\n",
    "    return motifY4, role_id, 4, edge_index\n",
    "def generate_real_dataset():\n",
    "    \"\"\"Draw a class uniformly from 0..4 and build one sample of it.\n",
    "\n",
    "    Returns:\n",
    "        ``(G, role_id, label, edge_index, motif1_present, ..., motif5_present)``;\n",
    "        the five booleans state which base motifs the chosen class is\n",
    "        declared to contain.\n",
    "    \"\"\"\n",
    "    # class -> (builder, presence flags for motifs 1..5)\n",
    "    recipes = {\n",
    "        0: (generate_Y0, (True, True, False, False, False)),\n",
    "        1: (generate_Y1, (True, False, True, False, True)),\n",
    "        2: (generate_Y2, (True, True, False, False, True)),\n",
    "        3: (generate_Y3, (False, False, False, True, True)),\n",
    "        4: (generate_Y4, (False, False, True, True, False)),\n",
    "    }\n",
    "    y = random.choice([0, 1, 2, 3, 4])\n",
    "    builder, presence = recipes[y]\n",
    "    G, role_id, label, edge_index = builder()\n",
    "    return (G, role_id, label, edge_index, *presence)\n",
    "\n",
    "\n",
    "def generate_false_cause_dataset(motif_generators, motif_connectors):\n",
    "    \"\"\"Attach one motif from generator 7 to a real sample via connector 1.\n",
    "\n",
    "    NOTE: a second ``generate_false_cause_dataset`` is defined further down in\n",
    "    this cell and replaces this one once the cell has run.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)``; ``label`` is the real sample's label.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    graph, role_id, label, edge_index, *_ = generate_real_dataset()\n",
    "    a = random.randint(2, 4)\n",
    "    motif6, _ = motif_generators[7][0](random.randint(5, 10), a, list(base), list(base))\n",
    "    graph, role_id, edge_index, _ = motif_connectors[1][0](graph, motif6)\n",
    "    return graph,role_id,label,edge_index\n",
    "\n",
    "def generate_false_cause_dataset0(motif_generators, motif_connectors):\n",
    "    \"\"\"With probability 0.9, attach a random motif from generators 6..10 to a real sample.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)``; the real sample is returned\n",
    "        unchanged when nothing is attached.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    graph, role_id, label, edge_index, *_ = generate_real_dataset()\n",
    "    a = random.randint(2, 4)\n",
    "    if random.random() >= 0.9:\n",
    "        return graph,role_id,label,edge_index\n",
    "\n",
    "    generator_index = random.sample(range(6, 11), 1)[0]\n",
    "    motifr1, _ = motif_generators[generator_index][0](10, a, list(base), list(base))\n",
    "    graph, role_id, edge_index, _ = motif_connectors[1][0](graph, motifr1)\n",
    "    return graph,role_id,label,edge_index\n",
    "\n",
    "def generate_false_cause_dataset1(molecular_generators, motif_connectors):\n",
    "    \"\"\"Attach the spurious motif paired with the first base motif that is present.\n",
    "\n",
    "    Base motif ``i`` (1..5) is paired with generator ``i + 5``; the new motif\n",
    "    is attached with connector 1. If no base motif is flagged as present the\n",
    "    sample is replaced by one from ``generate_false_dataset``.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    graph, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "\n",
    "    for generator_index, present in enumerate(presence, start=6):\n",
    "        if not present:\n",
    "            continue\n",
    "        spurious, _ = molecular_generators[generator_index][0](random.randint(5, 10), list(base), list(base))\n",
    "        graph, role_id, edge_index, _ = motif_connectors[1][0](graph, spurious)\n",
    "        break\n",
    "    else:\n",
    "        graph, role_id, label, edge_index = generate_false_dataset(molecular_generators, motif_connectors)\n",
    "    return graph,role_id,label,edge_index\n",
    "\n",
    "def generate_false_cause_dataset2(molecular_generators, motif_connectors):\n",
    "    \"\"\"Attach a paired spurious motif with probability 0.2 per present base motif.\n",
    "\n",
    "    The base motifs are visited in order 1..5; each one that is present gets\n",
    "    one 0.2 chance to attach the motif from generator ``i + 5`` via connector\n",
    "    1, and the first success ends the search. When nothing is attached the\n",
    "    sample is replaced by one from ``generate_false_dataset``.\n",
    "\n",
    "    NOTE(review): ``generate_false_dataset`` itself starts by calling this\n",
    "    function, so the fallback is mutually recursive - confirm this is intended.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    graph, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "\n",
    "    for generator_index, present in enumerate(presence, start=6):\n",
    "        # The random draw only happens for motifs that are present.\n",
    "        if present and random.random() < 0.2:\n",
    "            spurious, _ = molecular_generators[generator_index][0](random.randint(5, 10), list(base), list(base))\n",
    "            graph, role_id, edge_index, _ = motif_connectors[1][0](graph, spurious)\n",
    "            break\n",
    "    else:\n",
    "        graph, role_id, label, edge_index = generate_false_dataset(molecular_generators, motif_connectors)\n",
    "    return graph, role_id, label,edge_index\n",
    "\n",
    "def generate_false_cause_dataset(molecular_generators, motif_connectors):\n",
    "    \"\"\"Attach a paired spurious motif with probability 0.2 per present base motif.\n",
    "\n",
    "    Same search as ``generate_false_cause_dataset2`` but without a fallback:\n",
    "    when nothing is attached the real sample is returned unchanged.\n",
    "\n",
    "    NOTE: this definition replaces the earlier function of the same name in\n",
    "    this cell.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    graph, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "\n",
    "    for generator_index, present in enumerate(presence, start=6):\n",
    "        # The random draw only happens for motifs that are present.\n",
    "        if present and random.random() < 0.2:\n",
    "            spurious, _ = molecular_generators[generator_index][0](random.randint(5, 10), list(base), list(base))\n",
    "            graph, role_id, edge_index, _ = motif_connectors[1][0](graph, spurious)\n",
    "            break\n",
    "    return graph, role_id, label,edge_index\n",
    "def generate_false_cause_dataset3(molecular_generators, motif_connectors):\n",
    "    \"\"\"Attach a paired spurious motif with probability 0.05 per present base motif.\n",
    "\n",
    "    The base motifs are visited in order 1..5; each one that is present gets\n",
    "    one 0.05 chance to attach the motif from generator ``i + 5`` (called with\n",
    "    the branch parameter ``a``) via connector 1. When nothing is attached the\n",
    "    real sample is returned unchanged.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    G, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    a = random.randint(2, 4)\n",
    "\n",
    "    # Most draws attach nothing, so ``graph`` must be bound before the search.\n",
    "    graph = G\n",
    "    for generator_index, present in enumerate(presence, start=6):\n",
    "        # The random draw only happens for motifs that are present.\n",
    "        if present and random.random() < 0.05:\n",
    "            spurious, _ = molecular_generators[generator_index][0](random.randint(5, 10), a, list(base), list(base))\n",
    "            graph, role_id, edge_index, _ = motif_connectors[1][0](G, spurious)\n",
    "            break\n",
    "\n",
    "    return graph, role_id, label,edge_index\n",
    "\n",
    "\n",
    "def generate_false_dataset(molecular_generators, motif_connectors):\n",
    "    \"\"\"Extend a spurious-motif sample with five extra motifs and random noise.\n",
    "\n",
    "    Five distinct generators are drawn from 6..26, each builds a motif of\n",
    "    size 25, and the motifs are attached one after another with connector 1.\n",
    "    The result is passed through ``add_noise`` (10% edge and node additions,\n",
    "    no deletions).\n",
    "\n",
    "    NOTE(review): the returned ``edge_index`` comes from the last connector\n",
    "    call, i.e. it describes the graph before ``add_noise`` ran - confirm\n",
    "    this is intended.\n",
    "\n",
    "    Returns:\n",
    "        ``(G_noisy, role_id_noisy, label_noisy, edge_index)``.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    G, role_id, label, edge_index = generate_false_cause_dataset2(molecular_generators, motif_connectors)\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "\n",
    "    numbers = random.sample(range(6, 27), 5)\n",
    "    # Build all five motifs before connecting any, keeping the order of random draws.\n",
    "    extra_motifs = []\n",
    "    for generator_index in numbers:\n",
    "        motif, _ = molecular_generators[generator_index][0](25, list(base), list(base))\n",
    "        extra_motifs.append(motif)\n",
    "\n",
    "    graph = G\n",
    "    for motif in extra_motifs:\n",
    "        graph, _, edge_index, _ = motif_connectors[1][0](graph, motif)\n",
    "\n",
    "    G_noisy, role_id_noisy, label_noisy = add_noise(graph, 0, 0.1, 0, 0.1, label)\n",
    "    return G_noisy, role_id_noisy, label_noisy,edge_index\n",
    "\n",
    "def generate_false_dataset2(molecular_generators, motif_connectors):\n",
    "    \"\"\"Attach one random motif from generators 6..10 to a spurious-motif sample, then add noise.\n",
    "\n",
    "    The base sample comes from ``generate_false_cause_dataset2``; the extra\n",
    "    motif (size 10, branch parameter ``a``) is attached with connector 1 and\n",
    "    the result goes through ``add_noise`` (10% edge and node additions).\n",
    "\n",
    "    Args:\n",
    "        molecular_generators: mapping ``index -> (generator, ...)``; the extra\n",
    "            motif is taken from this mapping.\n",
    "        motif_connectors: mapping ``index -> (connector, ...)``.\n",
    "\n",
    "    Returns:\n",
    "        ``(G_noisy, role_id_noisy, label_noisy, edge_index)``; ``edge_index``\n",
    "        comes from the connector call, i.e. before the noise is added.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    G, role_id, label, edge_index = generate_false_cause_dataset2(molecular_generators, motif_connectors)\n",
    "\n",
    "    a = random.randint(2, 4)\n",
    "    # The index is always within 6..10, so the generator is called with ``a``.\n",
    "    generator_index = random.sample(range(6, 11), 1)[0]\n",
    "    motifr1, _ = molecular_generators[generator_index][0](10, a, list(base), list(base))\n",
    "    graph, _, edge_index, _ = motif_connectors[1][0](G, motifr1)\n",
    "    G_noisy, role_id_noisy, label_noisy = add_noise(graph, 0, 0.1, 0, 0.1, label)\n",
    "    return  G_noisy, role_id_noisy, label_noisy,edge_index\n",
    "\n",
    "def generate_false_dataset3(molecular_generators, motif_connectors):\n",
    "    \"\"\"Attach one random motif from generators 6..10 to a real sample via connector 1.\n",
    "\n",
    "    Returns:\n",
    "        ``(graph, role_id, label, edge_index)`` where ``role_id`` and\n",
    "        ``edge_index`` come from the connector and ``label`` from the real sample.\n",
    "    \"\"\"\n",
    "    base = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the feature mean and the std\n",
    "    G, _, label, _, *_ = generate_real_dataset()\n",
    "\n",
    "    a = random.randint(2, 4)\n",
    "    # The index is always within 6..10, so the generator is called with ``a``.\n",
    "    generator_index = random.sample(range(6, 11), 1)[0]\n",
    "    motifr1, _ = molecular_generators[generator_index][0](10, a, list(base), list(base))\n",
    "    graph, role_id1, edge_index, _ = motif_connectors[1][0](G, motifr1)\n",
    "\n",
    "    return  graph,role_id1, label,edge_index\n",
    "\n",
    "def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob,label=None):\n",
    "    \"\"\"Return a randomly perturbed deep copy of ``G``.\n",
    "\n",
    "    The steps run in this order: delete edges, add edges, delete nodes, add\n",
    "    nodes. Every added node is wired to a random set of existing nodes and\n",
    "    the wiring is retried until the whole graph is connected.\n",
    "\n",
    "    Args:\n",
    "        G: networkx graph; it is not modified.\n",
    "        delete_edge_prob: fraction of the existing edges to delete.\n",
    "        add_edge_prob: fraction of all possible node pairs to try to add as edges.\n",
    "        delete_node_prob: fraction of the nodes to delete.\n",
    "        add_node_prob: number of nodes to add, as a fraction of the node count.\n",
    "        label: passed through unchanged.\n",
    "\n",
    "    Returns:\n",
    "        ``(G_noisy, role_id_noisy, label_noisy)``.\n",
    "    \"\"\"\n",
    "    G_noisy = copy.deepcopy(G)\n",
    "\n",
    "    # random.sample needs a real sequence; graph views are rejected on Python >= 3.11.\n",
    "    num_edges_to_delete = int(delete_edge_prob * G_noisy.number_of_edges())\n",
    "    G_noisy.remove_edges_from(random.sample(list(G_noisy.edges()), num_edges_to_delete))\n",
    "\n",
    "    node_count = G_noisy.number_of_nodes()\n",
    "    num_edges_to_add = int(add_edge_prob * node_count * (node_count - 1) / 2)\n",
    "    existing_nodes = list(G_noisy.nodes())  # the node set is fixed while edges are added\n",
    "    for _ in range(num_edges_to_add):\n",
    "        node1, node2 = random.sample(existing_nodes, 2)\n",
    "        if not G_noisy.has_edge(node1, node2):\n",
    "            G_noisy.add_edge(node1, node2)\n",
    "\n",
    "    num_nodes_to_delete = int(delete_node_prob * G_noisy.number_of_nodes())\n",
    "    G_noisy.remove_nodes_from(random.sample(list(G_noisy.nodes()), num_nodes_to_delete))\n",
    "\n",
    "    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())\n",
    "    for _ in range(num_nodes_to_add):\n",
    "        # Neighbours are drawn from the nodes that existed before the new one,\n",
    "        # so the new node can never be linked to itself.\n",
    "        candidates = list(G_noisy.nodes())\n",
    "\n",
    "        # Skip ids that are already taken (possible once nodes have been deleted).\n",
    "        node_id = G_noisy.number_of_nodes() + 1\n",
    "        while node_id in G_noisy:\n",
    "            node_id += 1\n",
    "        G_noisy.add_node(node_id)\n",
    "\n",
    "        connected = False\n",
    "        while not connected:\n",
    "            nodes_to_connect = random.sample(candidates, random.randint(1, len(candidates)))\n",
    "            new_edges = [(node_id, n) for n in nodes_to_connect]\n",
    "            G_noisy.add_edges_from(new_edges)\n",
    "            connected = nx.is_connected(G_noisy)\n",
    "            if not connected:\n",
    "                # Roll the attempt back and try another neighbour set.\n",
    "                G_noisy.remove_edges_from(new_edges)\n",
    "\n",
    "    # NOTE(review): the list is sized by the number of edges, not nodes. It is\n",
    "    # left as is because callers index it with node ids taken from an\n",
    "    # edge_index built elsewhere - confirm the intended length before changing it.\n",
    "    role_id_noisy = [random.randint(0, 2) for _ in range(G_noisy.number_of_edges())]\n",
    "    label_noisy = label\n",
    "    return G_noisy, role_id_noisy, label_noisy\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "    \n",
    "def generate_graph_dataset_with_noise(molecular_generators, motif_connectors, num_samples, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob):\n",
    "    \"\"\"Generate ``num_samples`` graphs and perturb each one with ``add_noise``.\n",
    "\n",
    "    Args:\n",
    "        molecular_generators: mapping ``index -> (generator, ...)``.\n",
    "        motif_connectors: mapping ``index -> (connector, ...)``.\n",
    "        num_samples: number of samples to generate.\n",
    "        delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob:\n",
    "            forwarded to ``add_noise``.\n",
    "\n",
    "    Returns:\n",
    "        A list with one ``(G_noisy, role_id_noisy, label_noisy)`` tuple per\n",
    "        sample; empty when ``num_samples`` is 0.\n",
    "    \"\"\"\n",
    "    dataset = []\n",
    "    for _ in range(num_samples):\n",
    "        # generate_graph_dataset returns four values; edge_index is not needed here.\n",
    "        G, role_id, label, edge_index = generate_graph_dataset(molecular_generators, motif_connectors)\n",
    "        dataset.append(add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob, label))\n",
    "    return dataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fde7c855",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5f5dc82f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:50<00:00, 19.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 49.90    #Edges: 831.49 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-sample accumulators for the training split.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "e_mean = []\n",
    "n_mean = []\n",
    "\n",
    "for _ in tqdm(range(1000)):\n",
    "    G, role_id, label, edge_index = generate_false_cause_dataset2(molecular_generators, motif_connectors)\n",
    "    role_id = np.array(role_id)\n",
    "\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(\n",
    "    osp.join(data_dir, 'train.npy'),\n",
    "    {\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "    },\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3587610",
   "metadata": {},
   "source": [
    "## Validation Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ebe75786",
   "metadata": {
    "jupyter": {
     "is_executing": true
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:46<00:00, 21.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 48.59    #Edges: 771.98 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-sample accumulators for the validation split.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "e_mean = []\n",
    "n_mean = []\n",
    "\n",
    "for _ in tqdm(range(1000)):\n",
    "    G, role_id, label, edge_index = generate_false_cause_dataset2(molecular_generators, motif_connectors)\n",
    "    role_id = np.array(role_id)\n",
    "\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(\n",
    "    osp.join(data_dir, 'val.npy'),\n",
    "    {\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "    },\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d858281",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aff6470",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Per-sample accumulators for the test split.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "e_mean = []\n",
    "n_mean = []\n",
    "\n",
    "for _ in tqdm(range(2000)):\n",
    "    G, role_id, label, edge_index = generate_false_dataset(molecular_generators, motif_connectors)\n",
    "    role_id = np.array(role_id)\n",
    "\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(\n",
    "    osp.join(data_dir, 'test.npy'),\n",
    "    {\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "    },\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ef3453ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:11<00:00, 42.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 500    #Nodes: 83.01    #Edges: 147.83 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def ensure_connected(G):\n",
    "\n",
    "    if nx.is_connected(G):\n",
    "        return G\n",
    "    \n",
    "    largest_cc = max(nx.connected_components(G), key=len)\n",
    "    G_connected = G.subgraph(largest_cc).copy()\n",
    "    return G_connected\n",
    "\n",
    "\n",
    "def generateBaseshape(index, node_feature_mean=[1.5, 2.0, 1.2, 1.3, 1.8], std=[1.5, 2.0, 1.2, 1.3, 1.8]):\n",
    "    motif, role_id = molecular_generators[index][0](random.randint(5,10), node_feature_mean, std)\n",
    "    motif = ensure_connected(motif)\n",
    "\n",
    "    nodes_to_remove = [node for node in motif.nodes if 'feature' not in motif.nodes[node]]\n",
    "    motif.remove_nodes_from(nodes_to_remove)\n",
    "    motif, mapping = reindex_graph(motif)\n",
    "\n",
    "    role_id = [0] * motif.number_of_nodes()\n",
    "    role_id[-6:] = torch.randint(1, 4, (6,)).tolist()\n",
    "    edge_index = torch.tensor(list(motif.edges()), dtype=torch.long).t().contiguous()\n",
    "    if edge_index.size(1) == 0:\n",
    "\n",
    "        default_value = random.randint(0, motif.number_of_nodes() - 1)\n",
    "        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "    node_features = [motif.nodes[node]['feature'].tolist() for node in motif.nodes if 'feature' in motif.nodes[node]]\n",
    "    return motif, role_id, edge_index, node_features\n",
    "\n",
    "\n",
    "def addOtherShape(indexs, G, num):\n",
    "    G_other = copy.deepcopy(G)\n",
    "    additional_nodes_count = 0\n",
    "    for _ in range(num):\n",
    "            additional_index = random.choice([i for i in range(1, 27) if i not in indexs])\n",
    "            node_feature_mean = [index] * 5\n",
    "            std = [index] * 5\n",
    "            motif, role_id, edge_index, node_features = generateBaseshape(additional_index, node_feature_mean=node_feature_mean, std=std)\n",
    "            G_other, role_id, edge_index, node_features = feature_connection(G_other, motif)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "            additional_nodes_count += motif.number_of_nodes()\n",
    "    return G_other, role_id, edge_index, additional_nodes_count, node_features\n",
    "\n",
    "\n",
    "def addNoise(G, G_other, num_nodes_to_add):\n",
    "\n",
    "    G_noisy = copy.deepcopy(G_other)\n",
    "\n",
    "    for _ in range(num_nodes_to_add):\n",
    "        node_id = G_noisy.number_of_nodes()\n",
    "        G_noisy.add_node(node_id)\n",
    "        G_noisy.nodes[node_id]['feature'] = np.random.normal([1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])\n",
    "\n",
    "        connected = False\n",
    "        while not connected:\n",
    "\n",
    "            nodes_to_connect = random.sample(G_noisy.nodes(), 3)\n",
    "\n",
    "            for n in nodes_to_connect:\n",
    "                if not G.has_node(n) and not G_noisy.has_edge(node_id, n):\n",
    "                    G_noisy.add_edge(node_id, n)\n",
    "            connected = nx.is_connected(G_noisy)\n",
    "            if not connected:\n",
    "                for n in nodes_to_connect:\n",
    "                    if G_noisy.has_edge(node_id, n):\n",
    "                        G_noisy.remove_edge(node_id, n)\n",
    "\n",
    "\n",
    "\n",
    "    node_features = [G_noisy.nodes[node]['feature'].tolist() for node in G_noisy.nodes if 'feature' in G_noisy.nodes[node]]\n",
    "    role_id = [0] * G_noisy.number_of_nodes()\n",
    "    role_id[-6:] = torch.randint(1, 4, (6,)).tolist()\n",
    "\n",
    "    edge_index = torch.tensor(list(G_noisy.edges()), dtype=torch.long).t().contiguous()\n",
    "    return G_noisy, role_id, edge_index, node_features\n",
    "\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "e_mean, n_mean = [], []\n",
    "node_features_list = []\n",
    "indexs = [1, 2, 3, 4, 5]\n",
    "for _ in tqdm(range(500)):\n",
    "    index = random.choice(indexs)\n",
    "    node_feature_mean = [index-1] * 5\n",
    "    std = [index-1] * 5\n",
    "    label=index-1\n",
    "\n",
    "    G, role_id, edge_index, node_features = generateBaseshape(index, node_feature_mean=node_feature_mean, std=std)  \n",
    "   \n",
    "    G_other, role_id, edge_index, additional_nodes_count, node_features = addOtherShape(indexs, G, 3)\n",
    "    G, role_id, edge_index, node_features = addNoise(G, G_other, additional_nodes_count)\n",
    "\n",
    "\n",
    "\n",
    "    node_features_list.append(node_features)\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "\n",
    "    edge_index=edge_index\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'datasets_3_test.npy'), {'node_features':node_features_list, 'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3c1095e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x2190555cbb0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import inspect\n",
    "import os\n",
    "\n",
    "\n",
    "random.seed(42)\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec866c54",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def standardize_feature(feature, target_length=5):\n",
    "    \n",
    "\n",
    "    if hasattr(feature, 'tolist'):\n",
    "        feature = feature.tolist()\n",
    "    elif not isinstance(feature, list):\n",
    "        feature = [float(feature)]\n",
    "    else:\n",
    "        feature = list(feature)\n",
    "    \n",
    "\n",
    "    if len(feature) < target_length:\n",
    "\n",
    "        feature = feature + [0.0] * (target_length - len(feature))\n",
    "    elif len(feature) > target_length:\n",
    "\n",
    "        feature = feature[:target_length]\n",
    "    \n",
    "    return feature\n",
    "\n",
    "\n",
    "def safe_process_feature(feature, target_length=5):\n",
    "    \n",
    "    try:\n",
    "        return standardize_feature(feature, target_length)\n",
    "    except Exception as e:\n",
    "\n",
    "        return [0.0] * target_length\n",
    "\n",
    "\n",
    "def safe_connect_motifs(G1, G2, feature_length=5):\n",
    "    \n",
    "    G = nx.Graph()\n",
    "    \n",
    "\n",
    "    G.add_nodes_from(G1.nodes(data=True))\n",
    "    G.add_edges_from(G1.edges())\n",
    "    \n",
    "\n",
    "    mapping = {old_id: old_id + G1.number_of_nodes() for old_id in G2.nodes()}\n",
    "    G2_remapped = nx.relabel_nodes(G2, mapping)\n",
    "    \n",
    "\n",
    "    G.add_nodes_from(G2_remapped.nodes(data=True))\n",
    "    G.add_edges_from(G2_remapped.edges())\n",
    "    \n",
    "\n",
    "    node1 = random.choice(list(G1.nodes()))\n",
    "    node2 = random.choice(list(G2_remapped.nodes()))\n",
    "    G.add_edge(node1, node2)\n",
    "    \n",
    "\n",
    "    node_features = []\n",
    "    for node in G.nodes():\n",
    "        if 'feature' in G.nodes[node]:\n",
    "            try:\n",
    "\n",
    "                feature = safe_process_feature(G.nodes[node]['feature'], feature_length)\n",
    "                node_features.append(feature)\n",
    "            except Exception as e:\n",
    "\n",
    "                node_features.append([0.0] * feature_length)\n",
    "        else:\n",
    "\n",
    "            node_features.append([0.0] * feature_length)\n",
    "    \n",
    "\n",
    "    role_id = [0] * G.number_of_nodes()\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    \n",
    "    return G, role_id, edge_index, node_features\n",
    "\n",
    "\n",
    "def extract_node_features(G, feature_length=5):\n",
    "    \n",
    "    node_features = []\n",
    "    for node in G.nodes():\n",
    "        if 'feature' in G.nodes[node]:\n",
    "            try:\n",
    "                feature = safe_process_feature(G.nodes[node]['feature'], feature_length)\n",
    "                node_features.append(feature)\n",
    "            except Exception as e:\n",
    "                node_features.append([0.0] * feature_length)\n",
    "        else:\n",
    "            node_features.append([0.0] * feature_length)\n",
    "    return node_features\n",
    "\n",
    "\n",
    "def wrap_original_generator(original_func, feature_length=5):\n",
    "    \n",
    "    def wrapper(*args, **kwargs):\n",
    "        G, role_id, label, edge_index = original_func(*args, **kwargs)\n",
    "        node_features = extract_node_features(G, feature_length)\n",
    "        return G, role_id, label, edge_index, node_features\n",
    "    return wrapper\n",
    "\n",
    "\n",
    "generate_Y0_with_features = wrap_original_generator(generate_Y0)\n",
    "generate_Y1_with_features = wrap_original_generator(generate_Y1)\n",
    "generate_Y2_with_features = wrap_original_generator(generate_Y2)\n",
    "generate_Y3_with_features = wrap_original_generator(generate_Y3)\n",
    "generate_Y4_with_features = wrap_original_generator(generate_Y4)\n",
    "\n",
    "def add_confounding_motif(graph, label, motif_generators, motif_connectors, is_train=True, feature_length=5):\n",
    "\n",
    "    G = graph.copy()\n",
    "    a = random.randint(2, 4)\n",
    "    role_id = None\n",
    "    edge_index = None\n",
    "    node_features = None\n",
    "    \n",
    "\n",
    "    if is_train:\n",
    "\n",
    "        if random.random() < 0.9:\n",
    "            confound_motif_id = 11 + label\n",
    "            \n",
    "\n",
    "            try:\n",
    "\n",
    "                sig = inspect.signature(motif_generators[confound_motif_id][0])\n",
    "                param_count = len(sig.parameters)\n",
    "                \n",
    "\n",
    "                if param_count == 2:\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                elif param_count == 3:\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                elif param_count == 4:\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), a, [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                else:\n",
    "\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "            except Exception as e:\n",
    "                print(f\"Error calling motif generator {confound_motif_id}: {e}\")\n",
    "\n",
    "                role_id = [0] * G.number_of_nodes()\n",
    "                edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "                node_features = extract_node_features(G, feature_length)\n",
    "                return G, role_id, label, edge_index, node_features\n",
    "            \n",
    "\n",
    "            G, role_id, edge_index, node_features = safe_connect_motifs(G, confound_motif, feature_length)\n",
    "\n",
    "    else:\n",
    "\n",
    "        confound_types = list(range(11, 16))\n",
    "        if random.random() < 1/len(confound_types):\n",
    "            confound_motif_id = random.choice(confound_types)\n",
    "            \n",
    "\n",
    "            try:\n",
    "\n",
    "                sig = inspect.signature(motif_generators[confound_motif_id][0])\n",
    "                param_count = len(sig.parameters)\n",
    "                \n",
    "\n",
    "                if param_count == 2:\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                elif param_count == 3:\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                elif param_count == 4:\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), a, [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                else:\n",
    "\n",
    "                    confound_motif, role_id_conf = motif_generators[confound_motif_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "            except Exception as e:\n",
    "                print(f\"Error calling motif generator {confound_motif_id}: {e}\")\n",
    "\n",
    "                role_id = [0] * G.number_of_nodes()\n",
    "                edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "                node_features = extract_node_features(G, feature_length)\n",
    "                return G, role_id, label, edge_index, node_features\n",
    "            \n",
    "\n",
    "            G, role_id, edge_index, node_features = safe_connect_motifs(G, confound_motif, feature_length)\n",
    "    \n",
    "\n",
    "    if role_id is None:\n",
    "\n",
    "        role_id = [0] * G.number_of_nodes()\n",
    "        edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "        node_features = extract_node_features(G, feature_length)\n",
    "    \n",
    "    return G, role_id, label, edge_index, node_features\n",
    "\n",
    "def add_confounding_molecular(graph, label, molecular_generators, motif_connectors, is_train=True, feature_length=5):\n",
    "\n",
    "    G = graph.copy()\n",
    "    role_id = None\n",
    "    edge_index = None\n",
    "    node_features = None\n",
    "    \n",
    "\n",
    "    if is_train:\n",
    "\n",
    "        if random.random() < 0.9:\n",
    "            confound_mol_id = 6 + label\n",
    "            \n",
    "\n",
    "            try:\n",
    "\n",
    "                sig = inspect.signature(molecular_generators[confound_mol_id][0])\n",
    "                param_count = len(sig.parameters)\n",
    "                \n",
    "\n",
    "                if param_count == 2:\n",
    "                    confound_mol, role_id_conf = molecular_generators[confound_mol_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                elif param_count == 3:\n",
    "                    confound_mol, role_id_conf = molecular_generators[confound_mol_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                else:\n",
    "\n",
    "                    confound_mol, role_id_conf = molecular_generators[confound_mol_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "            except Exception as e:\n",
    "                print(f\"Error calling molecular generator {confound_mol_id}: {e}\")\n",
    "\n",
    "                role_id = [0] * G.number_of_nodes()\n",
    "                edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "                node_features = extract_node_features(G, feature_length)\n",
    "                return G, role_id, label, edge_index, node_features\n",
    "            \n",
    "\n",
    "            G, role_id, edge_index, node_features = safe_connect_motifs(G, confound_mol, feature_length)\n",
    "\n",
    "    else:\n",
    "\n",
    "        confound_types = list(range(6, 11))\n",
    "        if random.random() < 1/len(confound_types):\n",
    "            confound_mol_id = random.choice(confound_types)\n",
    "            \n",
    "\n",
    "            try:\n",
    "\n",
    "                sig = inspect.signature(molecular_generators[confound_mol_id][0])\n",
    "                param_count = len(sig.parameters)\n",
    "                \n",
    "\n",
    "                if param_count == 2:\n",
    "                    confound_mol, role_id_conf = molecular_generators[confound_mol_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                elif param_count == 3:\n",
    "                    confound_mol, role_id_conf = molecular_generators[confound_mol_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "                else:\n",
    "\n",
    "                    confound_mol, role_id_conf = molecular_generators[confound_mol_id][0](\n",
    "                        random.randint(5, 10), [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "                    )\n",
    "            except Exception as e:\n",
    "                print(f\"Error calling molecular generator {confound_mol_id}: {e}\")\n",
    "\n",
    "                role_id = [0] * G.number_of_nodes()\n",
    "                edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "                node_features = extract_node_features(G, feature_length)\n",
    "                return G, role_id, label, edge_index, node_features\n",
    "            \n",
    "\n",
    "            G, role_id, edge_index, node_features = safe_connect_motifs(G, confound_mol, feature_length)\n",
    "    \n",
    "\n",
    "    if role_id is None:\n",
    "\n",
    "        role_id = [0] * G.number_of_nodes()\n",
    "        edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "\n",
    "        node_features = extract_node_features(G, feature_length)\n",
    "    \n",
    "    return G, role_id, label, edge_index, node_features\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7898f893",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ensure_valid_graph(G, edge_index):\n",
    "    \n",
    "\n",
    "    mapping = {old_id: new_id for new_id, old_id in enumerate(G.nodes())}\n",
    "    G = nx.relabel_nodes(G, mapping)\n",
    "    \n",
    "\n",
    "    if isinstance(edge_index, torch.Tensor):\n",
    "\n",
    "        new_edge_index = torch.tensor([[mapping[int(edge_index[0, i])], mapping[int(edge_index[1, i])]] \n",
    "                                      for i in range(edge_index.size(1))], dtype=torch.long).t()\n",
    "    else:\n",
    "\n",
    "        new_edge_index = np.array([[mapping[edge_index[0][i]], mapping[edge_index[1][i]]] \n",
    "                                  for i in range(len(edge_index[0]))]).T\n",
    "    \n",
    "\n",
    "    max_node_idx = len(G.nodes()) - 1\n",
    "    if isinstance(new_edge_index, torch.Tensor):\n",
    "        assert new_edge_index.max() <= max_node_idx, f\"Out of range: {new_edge_index.max()} > {max_node_idx}\"\n",
    "    else:\n",
    "        assert np.max(new_edge_index) <= max_node_idx, f\"Out of range: {np.max(new_edge_index)} > {max_node_idx}\"\n",
    "    \n",
    "    return G, new_edge_index\n",
    "\n",
    "\n",
    "\n",
    "def generate_base_dataset(num_samples=1000, train_ratio=0.7, val_ratio=0.15, save_dir=None):\n",
    "    train_size = int(num_samples * train_ratio)\n",
    "    val_size = int(num_samples * val_ratio)\n",
    "    test_size = num_samples - train_size - val_size\n",
    "\n",
    "    train_data = []\n",
    "    test_data = []\n",
    "    train_size = int(num_samples * train_ratio)\n",
    "    \n",
    "\n",
    "\n",
    "    train_data, val_data, test_data = [], [], []\n",
    "    train_edge_index_list, val_edge_index_list, test_edge_index_list = [], [], []\n",
    "    train_label_list, val_label_list, test_label_list = [], [], []\n",
    "    train_ground_truth_list, val_ground_truth_list, test_ground_truth_list = [], [], []\n",
    "    train_role_id_list, val_role_id_list, test_role_id_list = [], [], []\n",
    "    train_pos_list, val_pos_list, test_pos_list = [], [], []\n",
    "    train_node_feature_list, val_node_feature_list, test_node_feature_list = [], [], []\n",
    "    \n",
    "    for i in tqdm(range(num_samples)):\n",
    "\n",
    "        y = random.choice([0, 1, 2, 3, 4])\n",
    "        \n",
    "\n",
    "        if y == 0:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y0_with_features()\n",
    "        elif y == 1:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y1_with_features()\n",
    "        elif y == 2:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y2_with_features()\n",
    "        elif y == 3:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y3_with_features()\n",
    "        elif y == 4:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y4_with_features()\n",
    "        \n",
    "        is_train = i < train_size\n",
    "        is_val = train_size <= i < (train_size + val_size)\n",
    "        is_test = i >= (train_size + val_size)\n",
    "\n",
    "        processed_features = [np.asarray(feature, dtype=float).flatten() for feature in node_features]\n",
    "\n",
    "        if is_train:\n",
    "            train_data.append((G, role_id, label, edge_index, node_features))\n",
    "\n",
    "            train_edge_index_list.append(edge_index)\n",
    "            train_label_list.append(label)\n",
    "            train_role_id_list.append(np.array(role_id))\n",
    "            train_pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "            train_node_feature_list.append(np.array(processed_features))\n",
    "        elif is_val:\n",
    "            val_data.append((G, role_id, label, edge_index, node_features))\n",
    "\n",
    "            val_edge_index_list.append(edge_index)\n",
    "            val_label_list.append(label)\n",
    "            val_role_id_list.append(np.array(role_id))\n",
    "            val_pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "            val_node_feature_list.append(np.array(processed_features))\n",
    "        else:\n",
    "            test_data.append((G, role_id, label, edge_index, node_features))\n",
    "\n",
    "            test_edge_index_list.append(edge_index)\n",
    "            test_label_list.append(label)\n",
    "            test_role_id_list.append(np.array(role_id))\n",
    "            test_pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "            test_node_feature_list.append(np.array(processed_features))\n",
    "    \n",
    "\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "        train_data_dict = {\n",
    "            'node_features': train_node_feature_list,\n",
    "            'edge_index': train_edge_index_list, \n",
    "            'label': train_label_list, \n",
    "            'role_id': train_role_id_list, \n",
    "            'pos': train_pos_list\n",
    "        }\n",
    "        np.save(os.path.join(save_dir, 'base_train.npy'), train_data_dict)\n",
    "        \n",
    "\n",
    "        val_data_dict = {\n",
    "            'node_features': val_node_feature_list,\n",
    "            'edge_index': val_edge_index_list, \n",
    "            'label': val_label_list, \n",
    "            'role_id': val_role_id_list, \n",
    "            'pos': val_pos_list\n",
    "        }\n",
    "        np.save(os.path.join(save_dir, 'base_val.npy'), val_data_dict)\n",
    "\n",
    "\n",
    "        test_data_dict = {\n",
    "            'node_features': test_node_feature_list,\n",
    "            'edge_index': test_edge_index_list, \n",
    "            'label': test_label_list, \n",
    "            'role_id': test_role_id_list, \n",
    "            'pos': test_pos_list\n",
    "        }\n",
    "        np.save(os.path.join(save_dir, 'base_test.npy'), test_data_dict)\n",
    "    \n",
    "    return train_data, val_data, test_data\n",
    "\n",
    "\n",
    "def generate_motif_confounded_dataset(num_samples=1000, train_ratio=0.7, val_ratio=0.15, save_dir=None):\n",
    "    train_size = int(num_samples * train_ratio)\n",
    "    val_size = int(num_samples * val_ratio)\n",
    "    test_size = num_samples - train_size - val_size\n",
    "\n",
    "    train_data = []\n",
    "    test_data = []\n",
    "    train_size = int(num_samples * train_ratio)\n",
    "    \n",
    "\n",
    "\n",
    "    train_data, val_data, test_data = [], [], []\n",
    "    train_edge_index_list, val_edge_index_list, test_edge_index_list = [], [], []\n",
    "    train_label_list, val_label_list, test_label_list = [], [], []\n",
    "    train_ground_truth_list, val_ground_truth_list, test_ground_truth_list = [], [], []\n",
    "    train_role_id_list, val_role_id_list, test_role_id_list = [], [], []\n",
    "    train_pos_list, val_pos_list, test_pos_list = [], [], []\n",
    "    train_node_feature_list, val_node_feature_list, test_node_feature_list = [], [], []\n",
    "\n",
    "    \n",
    "    for i in tqdm(range(num_samples)):\n",
    "\n",
    "        y = random.choice([0, 1, 2, 3, 4])\n",
    "        \n",
    "\n",
    "        if y == 0:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y0_with_features()\n",
    "        elif y == 1:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y1_with_features()\n",
    "        elif y == 2:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y2_with_features()\n",
    "        elif y == 3:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y3_with_features()\n",
    "        elif y == 4:\n",
    "            G, role_id, label, edge_index, node_features = generate_Y4_with_features()\n",
    "        \n",
    "\n",
    "        is_train = i < train_size\n",
    "        is_val = train_size <= i < (train_size + val_size)\n",
    "        is_test = i >= (train_size + val_size)\n",
    "        \n",
    "        G_conf, role_id_conf, label, edge_index, node_features = add_confounding_motif(\n",
    "            G, label, motif_generators, motif_connectors, is_train=is_train\n",
    "        )\n",
    "        \n",
    "\n",
    "        processed_features = [np.asarray(feature, dtype=float).flatten() for feature in node_features]\n",
    "    \n",
    "        if is_train:\n",
    "            train_data.append((G_conf, role_id_conf, label, edge_index, processed_features))\n",
    "            train_edge_index_list.append(edge_index)\n",
    "            train_label_list.append(label)\n",
    "            train_role_id_list.append(np.array(role_id_conf))\n",
    "            train_pos_list.append(np.array(list(nx.spring_layout(G_conf).values())))\n",
    "            train_node_feature_list.append(np.array(processed_features))\n",
    "        elif is_val:\n",
    "            val_data.append((G_conf, role_id_conf, label, edge_index, processed_features))\n",
    "            val_edge_index_list.append(edge_index)\n",
    "            val_label_list.append(label)\n",
    "            val_role_id_list.append(np.array(role_id_conf))\n",
    "            val_pos_list.append(np.array(list(nx.spring_layout(G_conf).values())))\n",
    "            val_node_feature_list.append(np.array(processed_features))\n",
    "        else:\n",
    "            test_data.append((G_conf, role_id_conf, label, edge_index, processed_features))\n",
    "            test_edge_index_list.append(edge_index)\n",
    "            test_label_list.append(label)\n",
    "            test_role_id_list.append(np.array(role_id_conf))\n",
    "            test_pos_list.append(np.array(list(nx.spring_layout(G_conf).values())))\n",
    "            test_node_feature_list.append(np.array(processed_features))\n",
    "    \n",
    "\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "        train_data_dict = {\n",
    "            'node_features': train_node_feature_list,\n",
    "            'edge_index': train_edge_index_list, \n",
    "            'label': train_label_list, \n",
    "            'role_id': train_role_id_list, \n",
    "            'pos': train_pos_list\n",
    "        }\n",
    "        np.save(os.path.join(save_dir, 'motif_conf_train.npy'), train_data_dict)\n",
    "        \n",
    "\n",
    "        val_data_dict = {\n",
    "            'node_features': val_node_feature_list,\n",
    "            'edge_index': val_edge_index_list, \n",
    "            'label': val_label_list, \n",
    "            'role_id': val_role_id_list, \n",
    "            'pos': val_pos_list\n",
    "        }\n",
    "        np.save(os.path.join(save_dir, 'motif_conf_val.npy'), val_data_dict)\n",
    "        \n",
    "\n",
    "        test_data_dict = {\n",
    "            'node_features': test_node_feature_list,\n",
    "            'edge_index': test_edge_index_list, \n",
    "            'label': test_label_list, \n",
    "            'role_id': test_role_id_list, \n",
    "            'pos': test_pos_list\n",
    "        }\n",
    "        np.save(os.path.join(save_dir, 'motif_conf_test.npy'), test_data_dict)\n",
    "    \n",
    "    return train_data, val_data, test_data\n",
    "\n",
    "\n",
    "def generate_molecular_confounded_dataset(num_samples=1000, train_ratio=0.7, val_ratio=0.15, save_dir=None):\n",
    "    \"\"\"Generate the molecular-confounded dataset and split it into train/val/test.\n",
    "\n",
    "    For every sample a class id is drawn with ``random.choice`` from 0..4 and\n",
    "    the matching ``generate_Y<k>_with_features`` generator builds the base\n",
    "    graph. The graph and its label are then handed to\n",
    "    ``add_confounding_molecular`` together with the notebook-level\n",
    "    ``molecular_generators`` and ``motif_connectors``, and the result is passed\n",
    "    through ``ensure_valid_graph`` before being stored.\n",
    "\n",
    "    Samples are assigned to splits by index: the first\n",
    "    ``int(num_samples * train_ratio)`` go to train, the next\n",
    "    ``int(num_samples * val_ratio)`` to validation and the remainder to test.\n",
    "    ``add_confounding_molecular`` receives ``is_train=True`` only for the\n",
    "    training samples.\n",
    "\n",
    "    Args:\n",
    "        num_samples: Total number of graphs to generate across all splits.\n",
    "        train_ratio: Fraction of ``num_samples`` assigned to the training split.\n",
    "        val_ratio: Fraction of ``num_samples`` assigned to the validation split.\n",
    "        save_dir: Optional output directory. When truthy it is created if\n",
    "            missing and ``mol_conf_train.npy``, ``mol_conf_val.npy`` and\n",
    "            ``mol_conf_test.npy`` are written into it. Each file is a dict of\n",
    "            per-sample lists (keys ``node_features``, ``edge_index``,\n",
    "            ``label``, ``role_id``, ``pos``) stored through ``np.save``, so it\n",
    "            has to be read back with\n",
    "            ``np.load(path, allow_pickle=True).item()``.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_data, val_data, test_data)``; each is a list of\n",
    "        ``(graph, role_id, label, edge_index, node_features)`` tuples.\n",
    "    \"\"\"\n",
    "    train_size = int(num_samples * train_ratio)\n",
    "    val_end = train_size + int(num_samples * val_ratio)\n",
    "\n",
    "    # Indexed by the sampled class id.\n",
    "    base_generators = (\n",
    "        generate_Y0_with_features,\n",
    "        generate_Y1_with_features,\n",
    "        generate_Y2_with_features,\n",
    "        generate_Y3_with_features,\n",
    "        generate_Y4_with_features,\n",
    "    )\n",
    "\n",
    "    split_names = ('train', 'val', 'test')\n",
    "    samples = {name: [] for name in split_names}\n",
    "    # Key order matches the layout of the dicts written to disk.\n",
    "    fields = {\n",
    "        name: {'node_features': [], 'edge_index': [], 'label': [], 'role_id': [], 'pos': []}\n",
    "        for name in split_names\n",
    "    }\n",
    "\n",
    "    for i in tqdm(range(num_samples)):\n",
    "        y = random.choice([0, 1, 2, 3, 4])\n",
    "        # Only the graph and label of the base sample are used; role ids,\n",
    "        # edge index and features come from the confounded graph below.\n",
    "        G, _, label, _, _ = base_generators[y]()\n",
    "\n",
    "        is_train = i < train_size\n",
    "        if is_train:\n",
    "            split = 'train'\n",
    "        elif i < val_end:\n",
    "            split = 'val'\n",
    "        else:\n",
    "            split = 'test'\n",
    "\n",
    "        G_conf, role_id_conf, label, edge_index, node_features = add_confounding_molecular(\n",
    "            G, label, molecular_generators, motif_connectors, is_train=is_train\n",
    "        )\n",
    "\n",
    "        # Flatten every node feature to a 1-D float vector.\n",
    "        processed_features = [np.asarray(feature, dtype=float).flatten() for feature in node_features]\n",
    "        G_conf, edge_index = ensure_valid_graph(G_conf, edge_index)\n",
    "\n",
    "        samples[split].append((G_conf, role_id_conf, label, edge_index, processed_features))\n",
    "        record = fields[split]\n",
    "        record['edge_index'].append(edge_index)\n",
    "        record['label'].append(label)\n",
    "        record['role_id'].append(np.array(role_id_conf))\n",
    "        # spring_layout is called without a seed argument.\n",
    "        record['pos'].append(np.array(list(nx.spring_layout(G_conf).values())))\n",
    "        record['node_features'].append(np.array(processed_features))\n",
    "\n",
    "    if save_dir:\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "        for name in split_names:\n",
    "            np.save(os.path.join(save_dir, f'mol_conf_{name}.npy'), fields[name])\n",
    "\n",
    "    return samples['train'], samples['val'], samples['test']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "d7b4987b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成不带混杂特征的基础数据集...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 276.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "基础数据集已保存到 ./data/CRCG-CONFOUND/raw/\n",
      "基础数据集生成完成，训练集大小: 700，验证集大小: 150，测试集大小: 150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Output directory reused by the confounded-dataset cells further down.\n",
    "confound_data_dir = './data/CRCG-CONFOUND/raw/'\n",
    "os.makedirs(confound_data_dir, exist_ok=True)\n",
    "\n",
    "# Build the base dataset (1000 samples), passing confound_data_dir as save_dir,\n",
    "# then report the size of each split.\n",
    "print(\"生成不带混杂特征的基础数据集...\")\n",
    "base_train_data, base_val_data, base_test_data = generate_base_dataset(num_samples=1000, save_dir=confound_data_dir)\n",
    "print(f\"基础数据集生成完成，训练集大小: {len(base_train_data)}，验证集大小: {len(base_val_data)}，测试集大小: {len(base_test_data)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "e5fdc846",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成带有基元混杂特征的数据集...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:04<00:00, 219.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "带基元混杂特征的数据集已保存到 ./data/CRCG-CONFOUND/raw/\n",
      "带基元混杂特征的数据集生成完成，训练集大小: 700，验证集大小: 150，测试集大小: 150\n"
     ]
    }
   ],
   "source": [
    "# Build the motif-confounded dataset (1000 samples) with confound_data_dir as\n",
    "# save_dir, then report the size of each split. confound_data_dir is defined in\n",
    "# the base-dataset cell, which must run first.\n",
    "print(\"生成带有基元混杂特征的数据集...\")\n",
    "motif_conf_train_data, motif_conf_val_data, motif_conf_test_data = generate_motif_confounded_dataset(\n",
    "    num_samples=1000,\n",
    "    save_dir=confound_data_dir,\n",
    ")\n",
    "print(\n",
    "    f\"带基元混杂特征的数据集生成完成，训练集大小: {len(motif_conf_train_data)}，验证集大小: {len(motif_conf_val_data)}，测试集大小: {len(motif_conf_test_data)}\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "34515f40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成带有分子结构混杂特征的数据集...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:05<00:00, 195.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "带分子结构混杂特征的数据集已保存到 ./data/CRCG-CONFOUND/raw/\n",
      "带分子结构混杂特征的数据集生成完成，训练集大小: 700，验证集大小: 150，测试集大小: 150\n"
     ]
    }
   ],
   "source": [
    "# Build the molecular-confounded dataset (1000 samples) with confound_data_dir\n",
    "# as save_dir, then report the size of each split. confound_data_dir is defined\n",
    "# in the base-dataset cell, which must run first.\n",
    "print(\"生成带有分子结构混杂特征的数据集...\")\n",
    "mol_conf_train_data, mol_conf_val_data, mol_conf_test_data = generate_molecular_confounded_dataset(\n",
    "    num_samples=1000,\n",
    "    save_dir=confound_data_dir,\n",
    ")\n",
    "print(\n",
    "    f\"带分子结构混杂特征的数据集生成完成，训练集大小: {len(mol_conf_train_data)}，验证集大小: {len(mol_conf_val_data)}，测试集大小: {len(mol_conf_test_data)}\"\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CRCG",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
