{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a32dd142",
   "metadata": {},
   "source": [
    "### Generate Spurious-Motif Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "142774b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from BA3_loc import *\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import random\n",
    "import torch\n",
    "import copy\n",
    "import itertools\n",
    "data_dir = f'../data/CRCG-MOTIF/raw/'\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "742338c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#motif类型\n",
    "def create_motif_star_branch(size, branches, node_feature_mean, std):\n",
    "    # 添加不同形状的节点\n",
    "    #星形\n",
    "       G = nx.Graph() # 创建一个空图\n",
    "       role_id = []\n",
    "    # 添加节点\n",
    "       for i in range(size):\n",
    "          if i == 0:\n",
    "            features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "            G.add_node(i, features=features, is_center=True)\n",
    "            role_id.append(0)\n",
    "          else:\n",
    "            features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "            G.add_node(i, features=features, is_center=False)\n",
    "            role_id.append(1)\n",
    "    # 添加边\n",
    "       for i in range(1, size):\n",
    "              G.add_edge(0, i)\n",
    "       return G,role_id\n",
    "def create_motif_path_branch(size, branches, node_feature_mean, std):\n",
    "    #路径形状\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        role_id = []\n",
    "        # 添加路径节点，并为其添加特征\n",
    "        for i in range(size):\n",
    "            node_id = i\n",
    "            node_features = np.random.normal(node_feature_mean, std)\n",
    "            G.add_node(node_id, features=node_features.tolist())\n",
    "            role_id.append(1)\n",
    "            if i > 0:\n",
    "                G.add_edge(i-1, i)\n",
    "        return G,role_id\n",
    "def create_motif_fan_branch(size, branches, node_feature_mean, std):\n",
    "    #扇形\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        role_id = []\n",
    "    # 添加节点\n",
    "        for i in range(size):\n",
    "          if i == 0:\n",
    "           features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "           G.add_node(i, features=features, is_center=True)\n",
    "           role_id.append(0)\n",
    "          else:\n",
    "           features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "           G.add_node(i, features=features, is_center=False)\n",
    "           role_id.append(1)\n",
    "    # 添加边\n",
    "        for i in range(1, size):\n",
    "          G.add_edge(0, i)\n",
    "        for i in range(1, branches):\n",
    "          for j in range(1, size // branches):\n",
    "           G.add_edge(i, j * branches + i)\n",
    "        return G,role_id\n",
    "def create_motif_cuspedPolygon_branch(size, branches, node_feature_mean, std):\n",
    "       #尖角多边形：cuspedPolygon\n",
    "        #assert size % 2 == 0, \"Size of motif should be even\"\n",
    "        #assert size >= branches * 2, \"Size of motif should be at least twice the number of branches\"\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        role_id = []\n",
    "        # Add nodes\n",
    "        nodes = range(size)\n",
    "        features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        role_id.append(1)\n",
    "        G.add_nodes_from(nodes)\n",
    "    # Add edges\n",
    "        for i in range(size // 2):\n",
    "           node_a = i\n",
    "           node_b = (i + size // 2) % size\n",
    "        for b in range(branches):\n",
    "            next_node = (i * branches + b + 1) % (size // 2) + size // 2\n",
    "            G.add_edge(node_a, next_node)\n",
    "            G.add_edge(node_b, next_node)\n",
    "            G.add_edge(size // 2 - 1, size // 2)\n",
    "        # Check if G has edges, if not, add a random edge\n",
    "        if len(G.edges()) == 0:\n",
    "           node_a = random.choice(list(G.nodes()))\n",
    "           node_b = random.choice(list(G.nodes()))\n",
    "           G.add_edge(node_a, node_b)\n",
    "        return G,role_id\n",
    "def create_motif_random_bipartite_branch(size, branches, node_feature_mean, std):    \n",
    "    #随机二分图：一个包含两个等大部分的随机二分图\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        role_id=[]\n",
    "    # define the node attributes\n",
    "        nodes = []\n",
    "        for i in range(size):\n",
    "           nodes.append((i, {'feat': np.random.normal(node_feature_mean, std)}))\n",
    "           role_id.append(i)\n",
    "        G.add_nodes_from(nodes)\n",
    "    # create the bipartite structure\n",
    "        num_nodes_per_part = size // 2\n",
    "        part1 = list(range(num_nodes_per_part))\n",
    "        part2 = list(range(num_nodes_per_part, size))\n",
    "        for i in part1:\n",
    "           for j in part2:\n",
    "              if np.random.rand() < 0.5:\n",
    "                  G.add_edge(i, j)\n",
    "        # Check if G has edges, if not, add a random edge\n",
    "        if len(G.edges()) == 0:\n",
    "           node_a = random.choice(list(G.nodes()))\n",
    "           node_b = random.choice(list(G.nodes()))\n",
    "           G.add_edge(node_a, node_b)\n",
    "        return G,role_id\n",
    "def create_motif_tree_branch(size, branches, node_feature_mean, std): \n",
    "        assert branches > 0, \"Number of branches must be greater than 0 for tree motif.\"\n",
    "        G = nx.Graph()\n",
    "        node_id = 0\n",
    "        role_id=[]\n",
    "        for i in range(size):\n",
    "            G.add_node(node_id, feature=np.random.normal(node_feature_mean, std))\n",
    "            role_id.append(i)\n",
    "            if i > 0:\n",
    "               parent_id = random.randint(max(0, node_id - branches), node_id - 1)\n",
    "               G.add_edge(parent_id, node_id)\n",
    "            node_id += 1 \n",
    "        return G,role_id  \n",
    "def create_motif_trident_branch(size, branches, node_feature_mean, std):      \n",
    "      #assert size >= 3 * branches, \"size must be greater than or equal to 3 * branches\"\n",
    "      #assert branches >= 1, \"branches must be greater than or equal to 1\"\n",
    "      G = nx.Graph()\n",
    "      role_id=[]\n",
    "      for i in range(size):\n",
    "         feature = np.random.normal(node_feature_mean, std)\n",
    "         G.add_node(i, feature=feature)\n",
    "         role_id.append(i)\n",
    "      for i in range(branches):\n",
    "          start_node = i * 3\n",
    "          end_node_1 = start_node + 1\n",
    "          end_node_2 = start_node + 2\n",
    "          G.add_edge(start_node, end_node_1)\n",
    "          G.add_edge(start_node, end_node_2)\n",
    "          if i > 0:\n",
    "             last_branch_end_1 = (i - 1) * 3 + 1\n",
    "             last_branch_end_2 = (i - 1) * 3 + 2\n",
    "             G.add_edge(last_branch_end_1, end_node_1)\n",
    "             G.add_edge(last_branch_end_2, end_node_2)\n",
    "          return G,role_id\n",
    "def create_motif_conicalConnection_branch(size, branches, node_feature_mean, std):  \n",
    "      #这里生成的形状图包括一个主干和若干分支，其中主干是一个大小为size的线性图，分为左右两个部分，分别与沙漏的两个锥形连接。\n",
    "      #沙漏上的每个分支是一个大小为size的线性图，也分为左右两个部分，分别与沙漏上的两个锥形连接。\n",
    "      #主干的中心节点与所有分支的中心节点都连接一条边，用于连接主干和分支。\n",
    "      G = nx.Graph()\n",
    "      role_id = []\n",
    "      mid = size // 2\n",
    "      for i in range(size):\n",
    "          if i == mid:\n",
    "              G.add_node(i, feature=np.random.normal(node_feature_mean, std))\n",
    "              role_id.append(1)\n",
    "          else:\n",
    "              G.add_node(i, feature=np.random.normal(node_feature_mean, std))\n",
    "              if i < mid:\n",
    "                  role_id.append(2)\n",
    "              else:\n",
    "                  role_id.append(3)\n",
    "      for i in range(size):\n",
    "          if i == mid:\n",
    "              continue\n",
    "          if i < mid:\n",
    "              G.add_edge(mid, i)\n",
    "              G.add_edge(i, i + 1)\n",
    "          else:\n",
    "              G.add_edge(mid, i)\n",
    "              G.add_edge(i, i - 1)\n",
    "      for b in range(branches):\n",
    "          mid_b = mid + b + 1\n",
    "          for i in range(size):\n",
    "              if i == mid_b:\n",
    "                  G.add_node(size + b, feature=np.random.normal(node_feature_mean, std))\n",
    "                  role_id.append(4)\n",
    "              else:\n",
    "                  G.add_node(size + b, feature=np.random.normal(node_feature_mean, std))\n",
    "                  role_id.append(5)\n",
    "          for i in range(size):\n",
    "              if i == mid_b:\n",
    "                  continue\n",
    "              if i < mid_b:\n",
    "                  G.add_edge(mid_b, i + size)\n",
    "                  G.add_edge(i + size, i + 1 + size)\n",
    "              else:\n",
    "                  G.add_edge(mid_b, i + size)\n",
    "                  G.add_edge(i + size, i - 1 + size)\n",
    "          G.add_edge(mid, mid_b)\n",
    "      return G,role_id\n",
    "def create_motif_chainBypass_branch(size, branches, node_feature_mean, std): \n",
    "      #该函数生成的形状图是一条链（size-2个节点）加上若干个大小为branch_size的分支。\n",
    "      #每个分支都与链中一个节点连接，并形成bypass\n",
    "      #assert size >= 4, \"Size should be at least 4 for this motif type.\"\n",
    "      #assert branches >= 1, \"Number of branches should be at least 1 for this motif type.\"\n",
    "      # Create empty graph\n",
    "      role_id=[]\n",
    "      G = nx.Graph()\n",
    "      # Add nodes to the graph with random features\n",
    "      for i in range(size):\n",
    "          features = np.random.normal(node_feature_mean, std)\n",
    "          G.add_node(i, features=features)\n",
    "          role_id.append(i)\n",
    "      # Add edges to the graph to form chain with bypass motif\n",
    "      for i in range(size-1):\n",
    "          G.add_edge(i, i+1)\n",
    "      branch_size = (size - 2) // branches  # size of each branch\n",
    "      start_node = 1\n",
    "      for i in range(branches):\n",
    "          for j in range(branch_size):\n",
    "              # Connect branch to bypass\n",
    "              branch_node = start_node + j\n",
    "              bypass_node = size - (2*branches) + i\n",
    "              G.add_edge(branch_node, bypass_node)\n",
    "          start_node += branch_size\n",
    "      # Assign role IDs to nodes\n",
    "      '''\n",
    "      role_id = np.zeros(size, dtype=np.int32)\n",
    "      role_id[0] = 1\n",
    "      role_id[size-1] = 2\n",
    "      for i in range(1, size-1):\n",
    "          if i % branch_size == 0:\n",
    "              role_id[i] = 3\n",
    "          else:\n",
    "              role_id[i] = 4\n",
    "      '''\n",
    "      return G,role_id\n",
    "def create_motif_partPolygon_branch(size, branches, node_feature_mean, std): \n",
    "    # assert size >= 2\n",
    "    # Define the number of nodes in the motif\n",
    "    n_nodes = size\n",
    "    # Define the edges for the motif\n",
    "    edges = [(i, i + 1) for i in range(n_nodes - 1)] + [(n_nodes - 1, i) for i in range(1, n_nodes - 1)]\n",
    "    # Create a directed graph\n",
    "    G = nx.Graph()\n",
    "    G.add_edges_from(edges) \n",
    "    # Define the node feature matrix\n",
    "    node_features = np.random.normal(node_feature_mean, std)  \n",
    "    # Define the role ids\n",
    "    role_id = np.zeros((n_nodes, branches))\n",
    "    role_id[0] = np.array([1] * branches)  # Input nodes\n",
    "    role_id[n_nodes - 1] = np.array([2] * branches)  # Output nodes\n",
    "    role_id[1: n_nodes - 1] = np.array([3] * branches)  # Hidden nodes\n",
    "    return G,role_id\n",
    "def create_motif_completeGraph(size,node_feature_mean,std):\n",
    "    G = nx.complete_graph(size)\n",
    "    role_id = []\n",
    "    for i in range(size):\n",
    "        G.nodes[i][\"feature\"] = np.random.normal(node_feature_mean[i % len(node_feature_mean)], std[i % len(std)])\n",
    "        role_id.append(i)\n",
    "    return G,role_id\n",
    "def create_motif_netShape(size,node_feature_mean,std):\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n = int(np.sqrt(size))\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            node_idx = i * n + j\n",
    "            node_feature = np.random.normal(node_feature_mean, std)\n",
    "            G.add_node(node_idx, feature=node_feature)\n",
    "            role_id.append(i)\n",
    "            if i > 0:\n",
    "                upper_node_idx = (i - 1) * n + j\n",
    "                G.add_edge(node_idx, upper_node_idx)\n",
    "            if j > 0:\n",
    "                left_node_idx = i * n + (j - 1)\n",
    "                G.add_edge(node_idx, left_node_idx)\n",
    "    return G,role_id\n",
    "def create_motif_dircycle(size,node_feature_mean,std):   \n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    for i in range(size):\n",
    "        features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        G.add_node(i, features=features)\n",
    "        role_id.append(i)\n",
    "    for i in range(size):\n",
    "        G.add_edge(i, (i + 1) % size)\n",
    "    return G,role_id\n",
    "def create_motif_dualRing(size,node_feature_mean,std):\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n1 = size//2\n",
    "    n2 = size-n1\n",
    "    G = nx.empty_graph(n=n1+n2)\n",
    "    for i in range(size):\n",
    "        features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        G.add_node(i, features=features)\n",
    "        role_id.append(i)\n",
    "    G.add_edges_from([(i,(i+1)%n1) for i in range(n1)] + [(i+n1,(i+1)%n2+n1) for i in range(n2)] + [(i,i+n1) for i in range(n1)])\n",
    "    return G,role_id\n",
    "def create_motif_triangle(size,node_feature_mean,std):\n",
    "    # 创建一个空图\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n = int(np.sqrt(2 * size))\n",
    "    for i in range(n):\n",
    "        for j in range(i + 1):\n",
    "            node_idx = (i * (i + 1) // 2) + j\n",
    "            node_feature = np.random.normal(node_feature_mean, std)\n",
    "            G.add_node(node_idx, feature=node_feature)\n",
    "            role_id.append(i)\n",
    "            if i > 0:\n",
    "                upper_node_idx = (i * (i - 1) // 2) + j\n",
    "                G.add_edge(node_idx, upper_node_idx)\n",
    "                if j > 0:\n",
    "                    left_upper_node_idx = upper_node_idx - 1\n",
    "                    G.add_edge(node_idx, left_upper_node_idx)\n",
    "                if j < i:\n",
    "                    right_upper_node_idx = upper_node_idx\n",
    "                    G.add_edge(node_idx, right_upper_node_idx)\n",
    "            if j > 0:\n",
    "                left_node_idx = (i * (i + 1) // 2) + (j - 1)\n",
    "                G.add_edge(node_idx, left_node_idx)\n",
    "    return G,role_id\n",
    "   #带环形状，具有一个环以及从环中的每个节点到其他节点的随机边\n",
    "def create_motif_ringShape(size,node_feature_mean,std):   \n",
    "    ring_nodes = []\n",
    "    for i in range(size):\n",
    "        theta = i / size * 2 * np.pi\n",
    "        r = 1.5 + np.random.normal(node_feature_mean, std)\n",
    "        x = r * np.cos(theta)\n",
    "        y = r * np.sin(theta)\n",
    "        ring_nodes.append((x, y))\n",
    "    # Connect the ring nodes to form a ring\n",
    "    ring_edges = []\n",
    "    for i in range(size):\n",
    "        ring_edges.append((i, (i+1)%size))\n",
    "    # Add random edges to the ring to form a connected graph\n",
    "    extra_edges = set()\n",
    "    while len(extra_edges) < size:\n",
    "        u = np.random.randint(0, size)\n",
    "        v = np.random.randint(0, size)\n",
    "        if u == v:\n",
    "            continue\n",
    "        if (u, v) in extra_edges or (v, u) in extra_edges:\n",
    "            continue\n",
    "        extra_edges.add((u, v))\n",
    "    # Combine all edges\n",
    "    edges = list(ring_edges) + list(extra_edges)\n",
    "    # Construct graph\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(size))\n",
    "    G.add_edges_from(edges)\n",
    "\n",
    "    # Assign node features\n",
    "    node_features = []\n",
    "    for i in range(size):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "    # Assign role ID\n",
    "    role_id = [0] * size\n",
    "    role_id[0] = 1\n",
    "    role_id[-1] = 2\n",
    "    return G,role_id\n",
    "def create_motif_diamond(size,node_feature_mean,std):\n",
    "    num_nodes = 2*size + 1\n",
    "    node_features = []\n",
    "    for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "    edges = []\n",
    "    for i in range(size):\n",
    "        edges.append((i, i+1))\n",
    "        edges.append((i, size+i+1))\n",
    "        edges.append((size+i+1, i+1))\n",
    "    edges.append((size, size*2))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    # Assign role ID\n",
    "    role_id = [0] * num_nodes\n",
    "    role_id[0] = 1\n",
    "    role_id[-1] = 2\n",
    "    return G,role_id\n",
    "def create_motif_HShape(size,node_feature_mean,std):\n",
    "     num_nodes = 2*size + 1\n",
    "     node_features = []\n",
    "     for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "     edges = []\n",
    "     for i in range(size):\n",
    "        edges.append((i, i+1))\n",
    "        edges.append((i, size+i+1))\n",
    "        edges.append((size+i, size+i+1))\n",
    "     G = nx.Graph()\n",
    "     G.add_nodes_from(range(num_nodes))\n",
    "     G.add_edges_from(edges)\n",
    "     # Assign role ID\n",
    "     role_id = [0] * num_nodes\n",
    "     role_id[0] = 1\n",
    "     role_id[-1] = 2\n",
    "     return G,role_id\n",
    "def create_motif_wheel(size,node_feature_mean,std):\n",
    "    num_nodes = size + 1\n",
    "    role_id = []\n",
    "    node_features = []\n",
    "    for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(i)\n",
    "    edges = []\n",
    "    for i in range(1, size+1):\n",
    "        edges.append((0, i))\n",
    "        edges.append((i, (i % size) + 1))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    return G,role_id\n",
    "def create_motif_hourglass(size,node_feature_mean,std):\n",
    "    num_nodes = 2*size + 1\n",
    "    role_id = []\n",
    "    node_features = []\n",
    "    for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(i)\n",
    "    edges = []\n",
    "    for i in range(size):\n",
    "        edges.append((0, i+1))\n",
    "        edges.append((i+1, i+size+1))\n",
    "        edges.append((i+size+1, num_nodes-1))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    return G,role_id\n",
    "def create_motif_DCD(size,node_feature_mean,std):\n",
    "    #DCD 形状图是一组由三个钻石形状组成的图形，其中一个中心钻石形状被两个周围钻石形状围绕\n",
    "    num_nodes = 3*size + 1\n",
    "    role_id = []\n",
    "    node_features = []\n",
    "    for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(i)\n",
    "    edges = []\n",
    "    for i in range(size):\n",
    "        edges.append((i, i+1))\n",
    "        edges.append((i, size+i+1))\n",
    "        edges.append((size+i, size+i+1))\n",
    "        edges.append((size+i, 2*size+i+1))\n",
    "        edges.append((2*size+i, 2*size+i+1))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    return G,role_id\n",
    "def create_motif_Cyclocross(size,node_feature_mean,std):\n",
    "    num_nodes = 2*size + 1\n",
    "    role_id = []\n",
    "    node_features = []\n",
    "    for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(i)\n",
    "    edges = []\n",
    "    for i in range(size):\n",
    "        edges.append((i, i+1))\n",
    "        edges.append((size+i, size+i+1))\n",
    "    edges.append((size, 0))\n",
    "    edges.append((size, 2*size))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    return G,role_id\n",
    "def create_motif_ladder(size,node_feature_mean,std):\n",
    "    num_nodes = 2 * size\n",
    "    role_id = []\n",
    "    node_features = []\n",
    "    for i in range(num_nodes):\n",
    "        node_features.append(np.random.normal(node_feature_mean, std))\n",
    "        role_id.append(i)\n",
    "    edges = [(i, i+1) for i in range(size-1)] + [(i+size, i+size+1) for i in range(size-1)]\n",
    "    edges += [(i, i+size) for i in range(size)]\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    return G,role_id\n",
    "def create_motif_bowtie(size,node_feature_mean,std):   \n",
    "    #\"bowtie\"形状图的特点是一个中心节点连接两个对称的亚图，亚图中每个节点与中心节点相连，因此中心节点的度为2，亚图中每个节点的度为1\n",
    "    G = nx.Graph()\n",
    "    for i in range(size):\n",
    "        G.add_node(i, feature=np.random.normal(node_feature_mean, std))\n",
    "    G.add_edge(0, 1)\n",
    "    for i in range(2, size):\n",
    "        if i % 2 == 0:\n",
    "            G.add_edge(0, i)\n",
    "        else:\n",
    "            G.add_edge(1, i)\n",
    "    role_id = np.zeros((size,))\n",
    "    role_id[0] = 0  # center node\n",
    "    role_id[1] = 1  # node connected to center node\n",
    "    for i in range(2, size):\n",
    "        if i % 2 == 0:\n",
    "            role_id[i] = 1\n",
    "        else:\n",
    "            role_id[i] = 2\n",
    "    return G,role_id\n",
    "def create_motif_cross(size,node_feature_mean,std):   \n",
    "    #每个节点都与中心节点相连，并且连接着一个水平和一个垂直臂\n",
    "    G = nx.Graph()\n",
    "    role_id = np.zeros(size, dtype=np.int32)\n",
    "    for i in range(size):\n",
    "        G.add_node(i, feature=np.random.normal(node_feature_mean, std, [5]))\n",
    "        if i == size // 2: # Connect to the central node\n",
    "            G.add_edge(i, 0)\n",
    "            role_id[i] = 1 # Central node\n",
    "        else: # Connect to the horizontal and vertical arms\n",
    "            G.add_edge(i, (i+1) % (size//2))\n",
    "            G.add_edge(i, (i+size//2) % size)\n",
    "            if i < size // 2:\n",
    "                role_id[i] = 2 # Horizontal arm\n",
    "            else:\n",
    "                role_id[i] = 3 # Vertical arm\n",
    "    return G,role_id\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "维护一个列表,简单描述各个motif生成方法(图分类人工合成数据集)\n",
    "'''\n",
    "motif_generators = {\n",
    "    1: (create_motif_star_branch, \"星形,星的节点数,分支,节点特征可调\"),\n",
    "    2: (create_motif_path_branch, \"路径形,路径节点数,分支,节点特征可调\"),\n",
    "    3: (create_motif_fan_branch, \"扇形,扇形的节点数,分支,节点特征可调\"),\n",
    "    4: (create_motif_cuspedPolygon_branch, \"尖角多边形,尖角多边形的节点数,分支,节点特征可调\"),\n",
    "    5: (create_motif_random_bipartite_branch, \"随机二分图,随机二分图的节点数,分支,节点特征可调\"),\n",
    "    6: (create_motif_tree_branch, \"树形,树形的节点数,分支,节点特征可调\"),\n",
    "    7: (create_motif_trident_branch, \"三叉戟形,三叉戟形的节点数,分支,节点特征可调\"),\n",
    "    8: (create_motif_conicalConnection_branch, \"锥形连接图,锥形连接图的节点数,分支,节点特征可调\"),\n",
    "    9: (create_motif_chainBypass_branch, \"链旁路形,链旁路形图的节点数,分支,节点特征可调\"),\n",
    "    10: (create_motif_trident_branch, \"部分多边形,部分多边形的节点数,分支,节点特征可调\"),\n",
    "    11: (create_motif_completeGraph, \"完全图,完全图的节点数,特征可调\"),\n",
    "    12: (create_motif_dircycle, \"循环图,循环图的节点数,特征可调\"),\n",
    "    13: (create_motif_dualRing, \"双环图,双环图的节点数,特征可调\"),\n",
    "    14: (create_motif_triangle, \"三角图,三角图的节点数,特征可调\"),\n",
    "    15: (create_motif_ringShape, \"带环形状图,带环形状图的节点数,特征可调\"),\n",
    "    16: (create_motif_diamond, \"钻石图,钻石图的节点数,特征可调\"),\n",
    "    17: (create_motif_HShape, \"H形图,H形图的节点数,特征可调\"),\n",
    "    18: (create_motif_wheel, \"车轮图,车轮图的节点数,特征可调\"),\n",
    "    19: (create_motif_hourglass, \"沙漏图,沙漏图的节点数,特征可调\"),\n",
    "    20: (create_motif_DCD, \"DCD三钻石形状图,DCD三钻石图的节点数,特征可调\"),\n",
    "    21: (create_motif_Cyclocross, \"循环十字图,循环十字图的节点数,特征可调\"),\n",
    "    22: (create_motif_netShape, \"网形图,网形图的节点数,特征可调\"),\n",
    "    23: (create_motif_ladder, \"梯子图,梯子图的节点数,特征可调\"),\n",
    "    24: (create_motif_bowtie, \"领结图,领结图的节点数,特征可调\"),\n",
    "    25: (create_motif_cross, \"十字臂图,十字臂图的节点数,特征可调\")\n",
    "}\n",
    "#relation类型\n",
    "#包含，两个motif中顶点较少的完全被另一个所包含\n",
    "def include_smaller_graph(G1, G2):\n",
    "    \"\"\"\n",
    "    将顶点较少的图 G1 完全包含在顶点较多的图 G2 中。\n",
    "    Args:\n",
    "        G1: networkx.Graph, 第一个图。\n",
    "        G2: networkx.Graph, 第二个图。\n",
    "    Returns:\n",
    "        G: networkx.Graph, 新的图，包含 G1 和 G2 中的所有节点和边。\n",
    "    \"\"\"\n",
    "    if len(G1.nodes()) > len(G2.nodes()):\n",
    "        G1, G2 = G2, G1  # 确保 G1 是顶点较少的图\n",
    "    # 在 G2 中添加 G1 中的节点\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(G2.nodes())\n",
    "    G.add_nodes_from(G1.nodes())\n",
    "    # 将 G1 中的边添加到 G 中，同时保证 G1 的所有节点都在 G2 中\n",
    "    for u, v in G1.edges():\n",
    "        if u in G2.nodes() and v in G2.nodes():\n",
    "            G.add_edge(u, v)\n",
    "    # 添加 G2 中除 G1 外的边\n",
    "    for u, v in G2.edges():\n",
    "        if u in G.nodes() and v in G.nodes():\n",
    "            G.add_edge(u, v)\n",
    "    # 判断新图是否连通，如果不连通则添加路径连接各个联通子图\n",
    "    if len(list(nx.connected_components(G))) > 1:\n",
    "        components = list(nx.connected_components(G))\n",
    "        for i in range(len(components)-1):\n",
    "            u = random.sample(components[i], 1)[0]\n",
    "            v = random.sample(components[i+1], 1)[0]\n",
    "            nx.add_path(G, [u, v])\n",
    "    # 生成与节点数相同的 role_id 列表\n",
    "    #role_id = [i for i in range(G.number_of_nodes())]\n",
    "    role_id = [0] * G.number_of_nodes()  # 创建一个初始全零的 role_id 列表\n",
    "    role_id[-6:] = torch.randint(1, 4, (6,)).tolist()  # 将后六位设置为 1、2 或 3 中的随机数\n",
    "    # Convert graph to edge index representation\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    if edge_index.size(1) == 0:\n",
    "    # Handle the case where edge_index is empty\n",
    "       default_value = random.randint(0, num_nodes - 1)  # Generate a random default value\n",
    "       edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "# Ensure node indices are consecutive and satisfy the assertion\n",
    "    node_idx = torch.unique(edge_index.flatten())\n",
    "    num_nodes = node_idx.size(0)\n",
    "    max_node_idx = torch.max(node_idx).item()\n",
    "    if max_node_idx >= num_nodes or max_node_idx < num_nodes - 1:\n",
    "       num_nodes = max_node_idx + 1\n",
    "       node_idx = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)\n",
    "    node_idx_map = torch.zeros_like(node_idx)\n",
    "    node_idx_map[node_idx] = torch.arange(num_nodes)\n",
    "    edge_index = node_idx_map[edge_index]\n",
    "    assert node_idx.max() == node_idx.size(0) - 1\n",
    "    return G,role_id,edge_index\n",
    "#纠缠，即两个motif通过多条边连接在一起，边数可配置\n",
    "def multiple_edges_connection(G1, G2, common_edges):\n",
    "    # 随机选择一些边作为公用边\n",
    "    if common_edges > min(G1.number_of_edges(), G2.number_of_edges()):\n",
    "        common_edges = min(G1.number_of_edges(), G2.number_of_edges())\n",
    "    # 随机选择一些边作为公用边\n",
    "    common_edges1 = random.sample(G1.edges(), common_edges)\n",
    "    common_edges2 = random.sample(G2.edges(), common_edges)\n",
    "    # 将公用边从第二个图中删除\n",
    "    G2.remove_edges_from(common_edges2)\n",
    "    # 创建新的图并将两个图连接起来\n",
    "    G = nx.disjoint_union(G1, G2)\n",
    "    for u, v in common_edges1:\n",
    "        G.add_edge(u, v + G1.number_of_nodes())\n",
    "    # 为孤立节点添加随机边\n",
    "    isolated_nodes = [n for n in G.nodes() if G.degree(n) == 0]\n",
    "    for u in isolated_nodes:\n",
    "        v = random.choice(list(G.nodes()))\n",
    "        while G.has_edge(u, v) or u == v:\n",
    "            v = random.choice(list(G.nodes()))\n",
    "        G.add_edge(u, v)\n",
    "    # 判断图是否连通，如果不连通就添加边\n",
    "    if not nx.is_connected(G):\n",
    "        components = nx.connected_components(G)\n",
    "        largest_component = max(components, key=len)\n",
    "        isolated_nodes = [n for n in G.nodes() if n not in largest_component]\n",
    "        for u in isolated_nodes:\n",
    "            v = random.choice(list(largest_component))\n",
    "            G.add_edge(u, v)\n",
    "    # 生成与节点数相同的 role_id 列表\n",
    "    #role_id = [i for i in range(G.number_of_nodes())]\n",
    "    role_id = [0] * G.number_of_nodes()  # 创建一个初始全零的 role_id 列表\n",
    "    role_id[-6:] = torch.randint(1, 4, (6,)).tolist()  # 将后六位设置为 1、2 或 3 中的随机数\n",
    "    # 删除节点编号大于图的总节点数的节点\n",
    "    max_node_id = G.number_of_nodes() - 1\n",
    "    G.remove_nodes_from([n for n in G.nodes() if n > max_node_id])\n",
    "    # Convert graph to edge index representation\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    if edge_index.size(1) == 0:\n",
    "    # Handle the case where edge_index is empty\n",
    "       default_value = random.randint(0, num_nodes - 1)  # Generate a random default value\n",
    "       edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "# Ensure node indices are consecutive and satisfy the assertion\n",
    "    node_idx = torch.unique(edge_index.flatten())\n",
    "    num_nodes = node_idx.size(0)\n",
    "    max_node_idx = torch.max(node_idx).item()\n",
    "    if max_node_idx >= num_nodes or max_node_idx < num_nodes - 1:\n",
    "       num_nodes = max_node_idx + 1\n",
    "       node_idx = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)\n",
    "    # Adjust the size of node_idx_map to match the updated num_nodes\n",
    "    node_idx_map = torch.zeros(num_nodes, dtype=torch.long, device=node_idx.device)\n",
    "    node_idx_map[node_idx] = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)[node_idx]\n",
    "    edge_index = node_idx_map[edge_index]\n",
    "    assert node_idx.max() == node_idx.size(0) - 1\n",
    "    return G,role_id,edge_index\n",
    "#交叉，两个motif间公用一些顶点，公用顶点数可配置\n",
    "def multiple_nodes_connection(G1, G2,common_nodes):\n",
    "    '''Cross-connect two motifs and return the merged graph, role ids and edge index.\n",
    "\n",
    "    ``common_nodes`` nodes are sampled from each graph. The nodes sampled from\n",
    "    G2 are dropped, the two graphs are merged with ``nx.disjoint_union`` (which\n",
    "    relabels nodes to 0..N-1, G1 first), and every node ``u`` sampled from G1 is\n",
    "    linked to node ``u + |G1|`` whenever that id exists in the merged graph.\n",
    "    Consecutive connected components are then bridged with random edges until\n",
    "    the graph is connected.\n",
    "\n",
    "    Args:\n",
    "        G1: first motif (networkx graph); assumed to use integer node labels.\n",
    "        G2: second motif; it is copied, so the caller's graph is not modified.\n",
    "        common_nodes: number of nodes to sample from each graph.\n",
    "\n",
    "    Returns:\n",
    "        ``(G, role_id, edge_index)`` where ``role_id`` has one entry per node\n",
    "        (zeros, with the last up-to-six entries drawn from {1, 2, 3}) and\n",
    "        ``edge_index`` is a ``(2, E)`` long tensor with consecutive node ids.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``common_nodes`` exceeds the node count of either graph.\n",
    "    '''\n",
    "    # Sample from explicit lists: random.sample() only accepts sequences on\n",
    "    # Python >= 3.11, and a NodeView is not one.\n",
    "    common_nodes1 = set(random.sample(list(G1.nodes()), common_nodes))\n",
    "    common_nodes2 = set(random.sample(list(G2.nodes()), common_nodes))\n",
    "    # Trim a copy so the caller's G2 keeps all of its nodes.\n",
    "    G2 = G2.copy()\n",
    "    G2.remove_nodes_from(common_nodes2)\n",
    "    offset = G1.number_of_nodes()\n",
    "    G = nx.disjoint_union(G1, G2)\n",
    "    for node in common_nodes1:\n",
    "        if node + offset < G.number_of_nodes():\n",
    "            G.add_edge(node, node + offset)\n",
    "    # Bridge consecutive components so the result is a single connected graph.\n",
    "    if not nx.is_connected(G):\n",
    "        components = list(nx.connected_components(G))\n",
    "        for i in range(len(components) - 1):\n",
    "            u = random.choice(list(components[i]))\n",
    "            v = random.choice(list(components[i + 1]))\n",
    "            G.add_edge(u, v)\n",
    "    # One role id per node; only as many trailing entries as actually exist are\n",
    "    # randomised, so graphs with fewer than six nodes keep the right length.\n",
    "    num_nodes = G.number_of_nodes()\n",
    "    role_id = [0] * num_nodes\n",
    "    tail = min(6, num_nodes)\n",
    "    if tail:\n",
    "        role_id[-tail:] = torch.randint(1, 4, (tail,)).tolist()\n",
    "    # Convert graph to edge index representation\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    if edge_index.numel() == 0:\n",
    "        # No edges at all (single-node graph): fall back to one self-loop.\n",
    "        default_value = random.choice(list(G.nodes()))\n",
    "        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "    # Compact the node ids that occur in edge_index to 0..K-1 (torch.unique sorts them).\n",
    "    node_idx = torch.unique(edge_index)\n",
    "    node_idx_map = torch.zeros(int(node_idx.max()) + 1, dtype=torch.long, device=node_idx.device)\n",
    "    node_idx_map[node_idx] = torch.arange(node_idx.size(0), dtype=torch.long, device=edge_index.device)\n",
    "    edge_index = node_idx_map[edge_index]\n",
    "    assert int(edge_index.max()) == node_idx.size(0) - 1\n",
    "    return G, role_id, edge_index\n",
    "\n",
    "#相邻，即两个motif通过一条边连接在一起\n",
    "def adjacent_connection(G1,G2):\n",
    "    '''Join two motifs through a single bridging edge.\n",
    "\n",
    "    One random edge of each graph is subdivided by a brand-new node, the two\n",
    "    graphs are merged with ``nx.compose`` (nodes carrying the same label in both\n",
    "    graphs become one node), and the two new nodes are linked by an edge. Nodes\n",
    "    whose label occurs in both inputs receive a random ``role_id`` attribute,\n",
    "    and any node outside the largest component is wired to it.\n",
    "\n",
    "    Args:\n",
    "        G1, G2: motifs with integer node labels. Both are copied, so the\n",
    "            caller's graphs are not modified.\n",
    "\n",
    "    Returns:\n",
    "        ``(G, role_id, edge_index)``. If either graph has no nodes or no edges,\n",
    "        an empty graph, an empty list and an empty long tensor are returned.\n",
    "        Otherwise ``role_id`` has one entry per node (zeros, with the last\n",
    "        up-to-six entries drawn from {1, 2, 3}) and ``edge_index`` is a\n",
    "        ``(2, E)`` long tensor with consecutive node ids.\n",
    "    '''\n",
    "    nodes1 = set(G1.nodes())\n",
    "    nodes2 = set(G2.nodes())\n",
    "    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():\n",
    "        # Nothing to connect: empty graph, empty role_id list, empty edge_index.\n",
    "        return nx.Graph(), [], torch.tensor([], dtype=torch.long)\n",
    "    common_nodes = nodes1.intersection(nodes2)\n",
    "    # Work on copies so the edge surgery below does not leak into the caller's graphs.\n",
    "    G1 = G1.copy()\n",
    "    G2 = G2.copy()\n",
    "    # Pick one edge in each graph to be subdivided.\n",
    "    edge1 = random.choice(list(G1.edges()))\n",
    "    edge2 = random.choice(list(G2.edges()))\n",
    "    # Fresh labels above every existing label in either graph.\n",
    "    highest = max(nodes1.union(nodes2))\n",
    "    new_node1 = highest + 1\n",
    "    new_node2 = highest + 2\n",
    "    # Replace each chosen edge (u, v) by the path u - new_node - v.\n",
    "    G1.remove_edge(*edge1)\n",
    "    G1.add_edge(edge1[0], new_node1)\n",
    "    G1.add_edge(new_node1, edge1[1])\n",
    "    G2.remove_edge(*edge2)\n",
    "    G2.add_edge(edge2[0], new_node2)\n",
    "    G2.add_edge(new_node2, edge2[1])\n",
    "    G1.add_node(new_node2)\n",
    "    G2.add_node(new_node1)\n",
    "    # Merge the graphs and add the bridging edge between the two new nodes.\n",
    "    G = nx.compose(G1, G2)\n",
    "    G.add_edge(new_node1, new_node2)\n",
    "    # Tag the labels shared by both inputs with a random role_id attribute.\n",
    "    for node in common_nodes:\n",
    "        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))\n",
    "    # Attach every node outside the largest component to a random node inside it.\n",
    "    if not nx.is_connected(G):\n",
    "        largest_component = max(nx.connected_components(G), key=len)\n",
    "        anchors = list(largest_component)\n",
    "        isolated_nodes = [n for n in G.nodes() if n not in largest_component]\n",
    "        for u in isolated_nodes:\n",
    "            G.add_edge(u, random.choice(anchors))\n",
    "    # One role id per node; only as many trailing entries as actually exist are\n",
    "    # randomised, so graphs with fewer than six nodes keep the right length.\n",
    "    num_nodes = G.number_of_nodes()\n",
    "    role_id = [0] * num_nodes\n",
    "    tail = min(6, num_nodes)\n",
    "    if tail:\n",
    "        role_id[-tail:] = torch.randint(1, 4, (tail,)).tolist()\n",
    "    # Convert graph to edge index representation\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    if edge_index.numel() == 0:\n",
    "        # Defensive fallback (the bridge edge normally guarantees E >= 1).\n",
    "        default_value = random.choice(list(G.nodes()))\n",
    "        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "    # Compact the node ids that occur in edge_index to 0..K-1 (torch.unique sorts them).\n",
    "    node_idx = torch.unique(edge_index.flatten())\n",
    "    node_idx_map = torch.zeros(int(node_idx.max()) + 1, dtype=torch.long, device=node_idx.device)\n",
    "    node_idx_map[node_idx] = torch.arange(node_idx.size(0), dtype=torch.long, device=edge_index.device)\n",
    "    edge_index = node_idx_map[edge_index]\n",
    "    assert int(edge_index.max()) == node_idx.size(0) - 1\n",
    "    return G,role_id,edge_index\n",
    "'''\n",
    "维护一个列表,简单描述各个motif连接方法(图分类人工合成数据集)\n",
    "'''\n",
    "# Registry of motif connectors: key -> (connector function, description string).\n",
    "# Elsewhere in this cell, connectors 2 and 3 are called with a third integer\n",
    "# argument and connector 1 with just the two graphs.\n",
    "motif_connectors = {\n",
    "    # adjacent: the two motifs are joined by a single edge\n",
    "    1: (adjacent_connection, \"相邻,即两个motif通过一条边连接在一起\"),\n",
    "    # crossing: the two motifs share some vertices; the number shared is configurable\n",
    "    2: (multiple_nodes_connection, \"交叉,两个motif间公用一些顶点,公用顶点数可配置\"),\n",
    "    # entangled: the two motifs are joined by several edges; the edge count is configurable\n",
    "    3: (multiple_edges_connection, \"纠缠,即两个motif通过多条边连接在一起,边数可配置\"),\n",
    "    # containment: the motif with fewer vertices is fully contained in the other\n",
    "    4: (include_smaller_graph, \"包含,两个motif中顶点较少的完全被另一个所包含\"),\n",
    "}\n",
    "#G,role_id=motif_generators[3][0](10,4,[1.5,2.0,1.2,1.3,1.8], [1.5,2.0,1.2,1.3,1.8])\n",
    "#基于组合模式构建数据集\n",
    "def generate_graph_dataset(motif_generators, motif_connectors):\n",
    "    '''Compose three random motifs into one labelled graph sample.\n",
    "\n",
    "    Motif ``m`` (generators 1-10) and motif ``k`` (generators 11-25; ``m + 10``\n",
    "    with probability 0.8) are joined with connector 1. A third motif ``n``\n",
    "    (generators 1-10) is then attached with a random connector 2-4.\n",
    "\n",
    "    Args:\n",
    "        motif_generators: mapping id -> (generator function, ...). Ids 1-10 are\n",
    "            called as ``(size, a, mean, std)`` and ids 11-25 as ``(size, mean, std)``.\n",
    "        motif_connectors: mapping id -> (connector function, ...). Connectors 2\n",
    "            and 3 are called with a third argument of 3; connectors 1 and 4 with\n",
    "            the two graphs only.\n",
    "\n",
    "    Returns:\n",
    "        ``(G, role_id, label, edge_index)`` with ``label = (m + a) % 3``.\n",
    "    '''\n",
    "    # Five-dimensional node features; the same values serve as mean and std.\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)\n",
    "    m = random.randint(1, 10)\n",
    "    a = random.randint(2, 4)\n",
    "    # Generators may hand back an empty graph; retry until one has nodes.\n",
    "    motif_m = nx.Graph()\n",
    "    while not motif_m:\n",
    "        motif_m, _ = motif_generators[m][0](random.randint(5, 20), a, list(feats), list(feats))\n",
    "    # Motif k is tied to m most of the time, otherwise drawn freely from 11-25.\n",
    "    if random.random() < 0.8:\n",
    "        k = m + 10\n",
    "    else:\n",
    "        k = random.randint(11, 25)\n",
    "    motif_k = nx.Graph()\n",
    "    while not motif_k:\n",
    "        motif_k, _ = motif_generators[k][0](2 * a, list(feats), list(feats))\n",
    "    # Combine with relation r1. This must go through motif_connectors: the\n",
    "    # entries of motif_generators build motifs and do not accept two graphs.\n",
    "    G, _, _ = motif_connectors[1][0](motif_m, motif_k)\n",
    "    n = random.randint(1, 10)\n",
    "    motif_n = nx.Graph()\n",
    "    while not motif_n:\n",
    "        motif_n, _ = motif_generators[n][0](random.randint(5, 20), int(0.5 * a), list(feats), list(feats))\n",
    "    # Combine with relation r2.\n",
    "    r2 = random.randint(2, 4)\n",
    "    if r2 in (2, 3):\n",
    "        G, role_id, edge_index = motif_connectors[r2][0](G, motif_n, 3)\n",
    "    else:\n",
    "        G, role_id, edge_index = motif_connectors[r2][0](G, motif_n)\n",
    "    label = (m + a) % 3\n",
    "    return G, role_id, label, edge_index\n",
    "def generate_Y0():\n",
    "    '''Build one label-0 sample: motifs 1 and 2 merged with connector 2.\n",
    "\n",
    "    Returns ``(graph, role_id, 0, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    shape_arg = random.randint(2, 4)\n",
    "    first, _ = motif_generators[1][0](random.randint(5, 10), shape_arg, list(feats), list(feats))\n",
    "    second, _ = motif_generators[2][0](random.randint(5, 10), shape_arg, list(feats), list(feats))\n",
    "    merged, roles, edges = motif_connectors[2][0](first, second, 2)\n",
    "    return merged, roles, 0, edges\n",
    "def generate_Y1():\n",
    "    '''Build one label-1 sample from motifs 1, 3 and 5.\n",
    "\n",
    "    Motifs 1 and 3 are joined with connector 1, then motif 5 is attached with\n",
    "    connector 2 (third argument 2). Motif 5 comes from generator 5 so that the\n",
    "    graph matches the presence flags generate_real_dataset reports for label 1\n",
    "    (motifs 1, 3 and 5).\n",
    "\n",
    "    Returns ``(graph, role_id, 1, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    a = random.randint(2, 4)\n",
    "    motif1, _ = motif_generators[1][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "    motif3, _ = motif_generators[3][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "    G, _, _ = motif_connectors[1][0](motif1, motif3)\n",
    "    motif5, _ = motif_generators[5][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "    motifY1, role_id, edge_index = motif_connectors[2][0](G, motif5, 2)\n",
    "    label = 1\n",
    "    return motifY1, role_id, label, edge_index\n",
    "def generate_Y2():\n",
    "    '''Build one label-2 sample from motifs 1, 2 and 5.\n",
    "\n",
    "    Motifs 1 and 2 are joined with connector 3 (third argument 2), then motif 5\n",
    "    is attached with connector 2 (third argument 2). Motif 5 comes from\n",
    "    generator 5 so that the graph matches the presence flags\n",
    "    generate_real_dataset reports for label 2 (motifs 1, 2 and 5; motif 3 absent).\n",
    "\n",
    "    Returns ``(graph, role_id, 2, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    a = random.randint(2, 4)\n",
    "    motif1, _ = motif_generators[1][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "    motif2, _ = motif_generators[2][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "    G, _, _ = motif_connectors[3][0](motif1, motif2, 2)\n",
    "    motif5, _ = motif_generators[5][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "    motifY2, role_id, edge_index = motif_connectors[2][0](G, motif5, 2)\n",
    "    label = 2\n",
    "    return motifY2, role_id, label, edge_index\n",
    "def generate_Y3():\n",
    "    '''Build one label-3 sample: motifs 4 and 5 joined with connector 1.\n",
    "\n",
    "    Returns ``(graph, role_id, 3, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    shape_arg = random.randint(2, 4)\n",
    "    first, _ = motif_generators[4][0](random.randint(5, 10), shape_arg, list(feats), list(feats))\n",
    "    second, _ = motif_generators[5][0](random.randint(5, 10), shape_arg, list(feats), list(feats))\n",
    "    merged, roles, edges = motif_connectors[1][0](first, second)\n",
    "    return merged, roles, 3, edges\n",
    "def generate_Y4():\n",
    "    '''Build one label-4 sample: motifs 3 and 4 merged with connector 2.\n",
    "\n",
    "    Returns ``(graph, role_id, 4, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    shape_arg = random.randint(2, 4)\n",
    "    first, _ = motif_generators[3][0](random.randint(5, 10), shape_arg, list(feats), list(feats))\n",
    "    second, _ = motif_generators[4][0](random.randint(5, 10), shape_arg, list(feats), list(feats))\n",
    "    merged, roles, edges = motif_connectors[2][0](first, second, 2)\n",
    "    return merged, roles, 4, edges\n",
    "def generate_real_dataset():\n",
    "    '''Draw one sample whose label is determined by the motifs it contains.\n",
    "\n",
    "    A label y in 0..4 is chosen uniformly and the matching generate_Y<y>()\n",
    "    builder is called.\n",
    "\n",
    "    Returns:\n",
    "        A 9-tuple ``(G, role_id, label, edge_index, motif1_present, ...,\n",
    "        motif5_present)``; the five booleans state which of motifs 1-5 the\n",
    "        chosen label is declared to contain.\n",
    "    '''\n",
    "    # label -> (builder, presence flags for motifs 1..5)\n",
    "    recipes = {\n",
    "        0: (generate_Y0, (True, True, False, False, False)),\n",
    "        1: (generate_Y1, (True, False, True, False, True)),\n",
    "        2: (generate_Y2, (True, True, False, False, True)),\n",
    "        3: (generate_Y3, (False, False, False, True, True)),\n",
    "        4: (generate_Y4, (False, False, True, True, False)),\n",
    "    }\n",
    "    y = random.choice([0, 1, 2, 3, 4])\n",
    "    build, presence = recipes[y]\n",
    "    G, role_id, label, edge_index = build()\n",
    "    return (G, role_id, label, edge_index) + presence\n",
    "\n",
    "def generate_false_cause_dataset1(motif_generators, motif_connectors):\n",
    "    '''Attach a spurious motif that is fully determined by the causal motifs.\n",
    "\n",
    "    Starting from a generate_real_dataset() sample, the first causal motif i\n",
    "    (checked in order 1..5) that is present triggers attaching motif i + 5 via\n",
    "    connector 1. If no motif is flagged present, the sample is replaced by the\n",
    "    output of generate_false_dataset().\n",
    "\n",
    "    Returns ``(graph, role_id, label, edge_index)``.\n",
    "    '''\n",
    "    graph, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    a = random.randint(2, 4)\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    for spurious_id, present in zip(range(6, 11), presence):\n",
    "        if present:\n",
    "            spurious, _ = motif_generators[spurious_id][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "            graph, role_id, edge_index = motif_connectors[1][0](graph, spurious)\n",
    "            break\n",
    "    else:\n",
    "        graph, role_id, label, edge_index = generate_false_dataset(motif_generators, motif_connectors)\n",
    "    return graph, role_id, label, edge_index\n",
    "\n",
    "def generate_false_cause_dataset2(motif_generators, motif_connectors):\n",
    "    '''Attach a spurious motif that correlates with the causal motifs at rate 0.8.\n",
    "\n",
    "    Starting from a generate_real_dataset() sample, each present causal motif i\n",
    "    (checked in order 1..5) gets a 0.8 chance of attaching motif i + 5 via\n",
    "    connector 1; the first success wins. If none fires, the sample (label\n",
    "    included) is replaced by the output of generate_false_dataset().\n",
    "\n",
    "    Returns ``(graph, role_id, label, edge_index)``.\n",
    "    '''\n",
    "    graph, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    a = random.randint(2, 4)\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    for spurious_id, present in zip(range(6, 11), presence):\n",
    "        # The coin is only tossed for motifs that are actually present.\n",
    "        if present and random.random() < 0.8:\n",
    "            spurious, _ = motif_generators[spurious_id][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "            graph, role_id, edge_index = motif_connectors[1][0](graph, spurious)\n",
    "            break\n",
    "    else:\n",
    "        graph, role_id, label, edge_index = generate_false_dataset(motif_generators, motif_connectors)\n",
    "    return graph, role_id, label, edge_index\n",
    "\n",
    "\n",
    "def generate_false_cause_dataset3(motif_generators, motif_connectors):\n",
    "    '''Attach a spurious motif that correlates with the causal motifs at rate 0.05.\n",
    "\n",
    "    Starting from a generate_real_dataset() sample, each present causal motif i\n",
    "    (checked in order 1..5) gets a 0.05 chance of attaching motif i + 5 via\n",
    "    connector 1; the first success wins. If none fires, the sample (label\n",
    "    included) is replaced by the output of generate_false_dataset().\n",
    "\n",
    "    Returns ``(graph, role_id, label, edge_index)``.\n",
    "    '''\n",
    "    G, role_id, label, edge_index, *presence = generate_real_dataset()\n",
    "    a = random.randint(2, 4)\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    for spurious_id, present in zip(range(6, 11), presence):\n",
    "        # The coin is only tossed for motifs that are actually present.\n",
    "        if present and random.random() < 0.05:\n",
    "            spurious, _ = motif_generators[spurious_id][0](random.randint(5, 10), a, list(feats), list(feats))\n",
    "            graph, role_id, edge_index = motif_connectors[1][0](G, spurious)\n",
    "            break\n",
    "    else:\n",
    "        graph, role_id, label, edge_index = generate_false_dataset(motif_generators, motif_connectors)\n",
    "    return graph, role_id, label, edge_index\n",
    "\n",
    "def generate_false_dataset(motif_generators, motif_connectors):\n",
    "    '''Build a sample carrying a label-unrelated motif plus random noise.\n",
    "\n",
    "    A sample from generate_false_cause_dataset2() is extended, via connector 1,\n",
    "    with one size-10 motif drawn from generators 11-25, then passed through\n",
    "    add_noise() with 10% edge additions and 10% node additions.\n",
    "\n",
    "    Returns ``(G_noisy, role_id_noisy, label_noisy, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    base_graph, _, label, _ = generate_false_cause_dataset2(motif_generators, motif_connectors)\n",
    "    random.randint(2, 4)  # value unused; the draw is kept so the random stream is unchanged\n",
    "    # Motif that has no relation to the label.\n",
    "    picked = random.sample(range(11, 26), 1)[0]\n",
    "    distractor, _ = motif_generators[picked][0](10, list(feats), list(feats))\n",
    "    graph, _, edge_index = motif_connectors[1][0](base_graph, distractor)\n",
    "    # NOTE(review): edge_index is computed before the noise is applied, so it\n",
    "    # does not contain the edges/nodes add_noise() introduces - confirm intended.\n",
    "    G_noisy, role_id_noisy, label_noisy = add_noise(graph, 0, 0.1, 0, 0.1, label)\n",
    "    return G_noisy, role_id_noisy, label_noisy, edge_index\n",
    "\n",
    "def generate_false_dataset2(motif_generators, motif_connectors):\n",
    "    '''Build a noisy sample extended with one extra motif from generators 6-10.\n",
    "\n",
    "    A sample from generate_false_cause_dataset2() is extended, via connector 1,\n",
    "    with one size-10 motif drawn uniformly from generators 6-10, then passed\n",
    "    through add_noise() with 10% edge additions and 10% node additions.\n",
    "\n",
    "    Returns ``(G_noisy, role_id_noisy, label_noisy, edge_index)``.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    # Only this sample is used; no separate generate_real_dataset() call is\n",
    "    # needed because its result would be overwritten right away.\n",
    "    G, _, label, _ = generate_false_cause_dataset2(motif_generators, motif_connectors)\n",
    "    a = random.randint(2, 4)\n",
    "    # Generators 6-10 take (size, a, mean, std); the id is always in that range.\n",
    "    picked = random.sample(range(6, 11), 1)[0]\n",
    "    motifr1, _ = motif_generators[picked][0](10, a, list(feats), list(feats))\n",
    "    graph, _, edge_index = motif_connectors[1][0](G, motifr1)\n",
    "    G_noisy, role_id_noisy, label_noisy = add_noise(graph, 0, 0.1, 0, 0.1, label)\n",
    "    return G_noisy, role_id_noisy, label_noisy, edge_index\n",
    "\n",
    "def generate_false_dataset3(motif_generators, motif_connectors):\n",
    "    '''Extend a real sample with one motif from generators 6-10, without noise.\n",
    "\n",
    "    The sample comes from generate_real_dataset(); the extra size-10 motif is\n",
    "    attached with connector 1.\n",
    "\n",
    "    Returns ``(graph, role_id, label, edge_index)`` where ``label`` is the label\n",
    "    of the real sample and the other three come from the connector.\n",
    "    '''\n",
    "    feats = (1.5, 2.0, 1.2, 1.3, 1.8)  # used for both the mean and the std argument\n",
    "    real_sample = generate_real_dataset()\n",
    "    base_graph, label = real_sample[0], real_sample[2]\n",
    "    a = random.randint(2, 4)\n",
    "    picked = random.sample(range(6, 11), 1)[0]\n",
    "    build = motif_generators[picked][0]\n",
    "    if picked < 6 or picked > 10:\n",
    "        # Three-argument generator form (size, mean, std).\n",
    "        extra, _ = build(10, list(feats), list(feats))\n",
    "    else:\n",
    "        # Generators 6-10 take (size, a, mean, std).\n",
    "        extra, _ = build(10, a, list(feats), list(feats))\n",
    "    graph, role_id1, edge_index = motif_connectors[1][0](base_graph, extra)\n",
    "    return graph, role_id1, label, edge_index\n",
    "'''\n",
    "数据集中还要额外包含添加随机噪声数据的功能：\n",
    "1.随机删除或者创建固定数量边(百分比)\n",
    "2.随机删除一定数量的节点（百分比）\n",
    "3.随机创建一定数量的节点，并且将这些节点与已经存在的图随机相连（百分比）\n",
    "'''\n",
    "def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob,label=None):\n",
    "    '''Return a randomly perturbed deep copy of ``G``.\n",
    "\n",
    "    The steps run in this order: delete edges, add edges, delete nodes, add\n",
    "    nodes. ``G`` itself is never modified.\n",
    "\n",
    "    Args:\n",
    "        G: input networkx graph.\n",
    "        delete_edge_prob: fraction of the existing edges to remove.\n",
    "        add_edge_prob: fraction of the n*(n-1)/2 possible node pairs to try as\n",
    "            new edges; pairs that are already connected are skipped.\n",
    "        delete_node_prob: fraction of the nodes to remove.\n",
    "        add_node_prob: fraction of the (remaining) node count to add as new\n",
    "            nodes. Each new node is wired to a random non-empty subset of the\n",
    "            existing nodes, redrawn until the whole graph is connected.\n",
    "        label: passed through unchanged.\n",
    "\n",
    "    Returns:\n",
    "        ``(G_noisy, role_id_noisy, label_noisy)`` where ``role_id_noisy`` holds\n",
    "        one random value in {0, 1, 2} per node of ``G_noisy``.\n",
    "    '''\n",
    "    G_noisy = copy.deepcopy(G)\n",
    "    # random.sample() needs a real sequence on Python >= 3.11, hence the list() calls.\n",
    "    # randomly delete edges\n",
    "    num_edges_to_delete = int(delete_edge_prob * G_noisy.number_of_edges())\n",
    "    edges_to_delete = random.sample(list(G_noisy.edges()), num_edges_to_delete)\n",
    "    G_noisy.remove_edges_from(edges_to_delete)\n",
    "    # randomly add edges (the node set does not change in this step)\n",
    "    candidates = list(G_noisy.nodes())\n",
    "    num_edges_to_add = int(add_edge_prob * len(candidates) * (len(candidates) - 1) / 2)\n",
    "    for _ in range(num_edges_to_add):\n",
    "        node1, node2 = random.sample(candidates, 2)\n",
    "        if not G_noisy.has_edge(node1, node2):\n",
    "            G_noisy.add_edge(node1, node2)\n",
    "    # randomly delete nodes\n",
    "    num_nodes_to_delete = int(delete_node_prob * G_noisy.number_of_nodes())\n",
    "    nodes_to_delete = random.sample(list(G_noisy.nodes()), num_nodes_to_delete)\n",
    "    G_noisy.remove_nodes_from(nodes_to_delete)\n",
    "    # randomly add nodes\n",
    "    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())\n",
    "    for _ in range(num_nodes_to_add):\n",
    "        existing = list(G_noisy.nodes())\n",
    "        # Smallest unused integer label >= the node count: never collides with\n",
    "        # an existing node and leaves no gap when labels are 0..n-1.\n",
    "        node_id = len(existing)\n",
    "        while node_id in G_noisy:\n",
    "            node_id += 1\n",
    "        G_noisy.add_node(node_id)\n",
    "        # Wire the new node to random existing nodes (never to itself) and\n",
    "        # retry with a fresh draw until the graph is connected.\n",
    "        connected = False\n",
    "        while not connected:\n",
    "            nodes_to_connect = random.sample(existing, random.randint(1, len(existing)))\n",
    "            new_edges = [(node_id, n) for n in nodes_to_connect]\n",
    "            G_noisy.add_edges_from(new_edges)\n",
    "            connected = nx.is_connected(G_noisy)\n",
    "            if not connected:\n",
    "                G_noisy.remove_edges_from(new_edges)\n",
    "    # Role ids describe nodes, so draw exactly one per node of the noisy graph.\n",
    "    role_id_noisy = [random.randint(0, 2) for _ in range(G_noisy.number_of_nodes())]\n",
    "    label_noisy = label\n",
    "    return G_noisy, role_id_noisy, label_noisy\n",
    "def generate_graph_dataset_with_noise(motif_generators, motif_connectors, num_samples, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob):\n",
    "    '''Generate ``num_samples`` composed graphs and perturb each with add_noise().\n",
    "\n",
    "    Args:\n",
    "        motif_generators, motif_connectors: registries forwarded to\n",
    "            generate_graph_dataset().\n",
    "        num_samples: number of samples to produce.\n",
    "        delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob:\n",
    "            noise fractions forwarded to add_noise().\n",
    "\n",
    "    Returns:\n",
    "        A list of ``(G_noisy, role_id_noisy, label_noisy)`` tuples, one per\n",
    "        sample (empty when ``num_samples`` is 0).\n",
    "    '''\n",
    "    dataset = []\n",
    "    for _ in range(num_samples):\n",
    "        # generate_graph_dataset() returns four values; edge_index is not needed here.\n",
    "        G, _, label, _ = generate_graph_dataset(motif_generators, motif_connectors)\n",
    "        # add_noise() takes the graph, the four fractions, then the label.\n",
    "        G_noisy, role_id_noisy, label_noisy = add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob, label)\n",
    "        dataset.append((G_noisy, role_id_noisy, label_noisy))\n",
    "    return dataset\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fde7c855",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5f5dc82f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:09<00:00, 107.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 17.45    #Edges: 32.69 \n"
     ]
    }
   ],
   "source": [
    "# Build the training split: 1000 samples from generate_false_cause_dataset2.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "e_mean = []\n",
    "n_mean = []\n",
    "for _ in tqdm(range(1000)):\n",
    "    G, role_id, label, edge_index = generate_false_cause_dataset2(motif_generators, motif_connectors)\n",
    "    role_id = np.array(role_id)\n",
    "    # Per-graph size statistics for the summary line below.\n",
    "    n_mean.append(len(G.nodes))\n",
    "    e_mean.append(len(G.edges))\n",
    "    label_list.append(label)\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Layout coordinates, in the order spring_layout returns them.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'train.npy'), {'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c3587610",
   "metadata": {},
   "source": [
    "## Val Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ebe75786",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:09<00:00, 108.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 17.52    #Edges: 32.28 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the validation split: 1000 samples from generate_false_cause_dataset2.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "e_mean = []\n",
    "n_mean = []\n",
    "for _ in tqdm(range(1000)):\n",
    "    G, role_id, label, edge_index = generate_false_cause_dataset2(motif_generators, motif_connectors)\n",
    "    role_id = np.array(role_id)\n",
    "    # Per-graph size statistics for the summary line below.\n",
    "    n_mean.append(len(G.nodes))\n",
    "    e_mean.append(len(G.edges))\n",
    "    label_list.append(label)\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Layout coordinates, in the order spring_layout returns them.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'val.npy'), {'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9d858281",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1aff6470",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [00:24<00:00, 83.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 23.36    #Edges: 90.51 \n"
     ]
    }
   ],
   "source": [
    "# no bias for test dataset\n",
    "# Build the test split: 2000 samples from generate_false_dataset.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "e_mean = []\n",
    "n_mean = []\n",
    "for _ in tqdm(range(2000)):\n",
    "    G, role_id, label, edge_index = generate_false_dataset(motif_generators, motif_connectors)\n",
    "    role_id = np.array(role_id)\n",
    "    # Per-graph size statistics for the summary line below.\n",
    "    n_mean.append(len(G.nodes))\n",
    "    e_mean.append(len(G.edges))\n",
    "    label_list.append(label)\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Layout coordinates, in the order spring_layout returns them.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'test.npy'), {'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "c71b0b87ea436ae79e2503ec051639fc2420e91bd742cb356b7debceb9d5ed19"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
