{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Spurious-Motif Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from BA3_loc import *\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import random\n",
    "global_b = '0.9' # Set bias degree here\n",
    "data_dir = f'../data/SPMotif-{global_b}/raw/'\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#def create_motif_type_star(size=100, branches=10, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8]):\n",
    "def create_motif_type_branch(size, branches, node_feature_mean, std,motif_type):\n",
    "    # 添加不同形状的节点\n",
    "    #星形\n",
    "    if motif_type == 'star':\n",
    "       G = nx.Graph() # 创建一个空图\n",
    "       role_id = []\n",
    "    # 添加节点\n",
    "       for i in range(size):\n",
    "          if i == 0:\n",
    "            features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "            G.add_node(i, features=features, is_center=True)\n",
    "            role_id.append(0)\n",
    "          else:\n",
    "            features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "            G.add_node(i, features=features, is_center=False)\n",
    "            role_id.append(1)\n",
    "    # 添加边\n",
    "       for i in range(1, size):\n",
    "              G.add_edge(0, i)\n",
    "\n",
    "    #路径形状\n",
    "    elif motif_type == 'path':\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        role_id = []\n",
    "        # 添加路径节点，并为其添加特征\n",
    "        for i in range(size):\n",
    "            node_id = i\n",
    "            node_features = np.random.normal(node_feature_mean, std)\n",
    "            G.add_node(node_id, features=node_features.tolist())\n",
    "            role_id.append(1)\n",
    "            if i > 0:\n",
    "                G.add_edge(i-1, i)\n",
    "    #扇形\n",
    "    elif motif_type=='fan':\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        role_id = []\n",
    "    # 添加节点\n",
    "        for i in range(size):\n",
    "          if i == 0:\n",
    "           features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "           G.add_node(i, features=features, is_center=True)\n",
    "           role_id.append(0)\n",
    "          else:\n",
    "           features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "           G.add_node(i, features=features, is_center=False)\n",
    "           role_id.append(1)\n",
    "    # 添加边\n",
    "        for i in range(1, size):\n",
    "          G.add_edge(0, i)\n",
    "        for i in range(1, branches):\n",
    "          for j in range(1, size // branches):\n",
    "           G.add_edge(i, j * branches + i)\n",
    "       #尖角多边形：cuspedPolygon\n",
    "    elif motif_type=='cuspedPolygon':\n",
    "        assert size % 2 == 0, \"Size of motif should be even\"\n",
    "        assert size >= branches * 2, \"Size of motif should be at least twice the number of branches\"\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "        # Add nodes\n",
    "        nodes = range(size)\n",
    "        features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        G.add_nodes_from(nodes)\n",
    "    # Add edges\n",
    "        for i in range(size // 2):\n",
    "           node_a = i\n",
    "           node_b = (i + size // 2) % size\n",
    "        for b in range(branches):\n",
    "            next_node = (i * branches + b + 1) % (size // 2) + size // 2\n",
    "            G.add_edge(node_a, next_node)\n",
    "            G.add_edge(node_b, next_node)\n",
    "            G.add_edge(size // 2 - 1, size // 2)\n",
    "    #随机二分图：一个包含两个等大部分的随机二分图\n",
    "    elif motif_type=='random_bipartite':\n",
    "        G = nx.Graph()        # 创建一个空图\n",
    "    # define the node attributes\n",
    "        nodes = []\n",
    "        for i in range(size):\n",
    "           nodes.append((i, {'feat': np.random.normal(node_feature_mean, std)}))\n",
    "        G.add_nodes_from(nodes)\n",
    "    # create the bipartite structure\n",
    "        num_nodes_per_part = size // 2\n",
    "        part1 = list(range(num_nodes_per_part))\n",
    "        part2 = list(range(num_nodes_per_part, size))\n",
    "        for i in part1:\n",
    "           for j in part2:\n",
    "              if np.random.rand() < 0.5:\n",
    "                  G.add_edge(i, j)\n",
    "    return G,role_id\n",
    "def create_motif_type(size,node_feature_mean,std,motif_type):\n",
    "   if motif_type == 'completeGraph':\n",
    "    G = nx.complete_graph(size)\n",
    "    for i in range(size):\n",
    "        G.nodes[i][\"feature\"] = np.random.normal(node_feature_mean[i % len(node_feature_mean)], std[i % len(std)])\n",
    "   elif motif_type == 'netShape':\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n = int(np.sqrt(size))\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            node_idx = i * n + j\n",
    "            node_feature = np.random.normal(node_feature_mean, std)\n",
    "            G.add_node(node_idx, feature=node_feature)\n",
    "            role_id.append(i)\n",
    "            if i > 0:\n",
    "                upper_node_idx = (i - 1) * n + j\n",
    "                G.add_edge(node_idx, upper_node_idx)\n",
    "            if j > 0:\n",
    "                left_node_idx = i * n + (j - 1)\n",
    "                G.add_edge(node_idx, left_node_idx)\n",
    "   elif motif_type == 'dircycle':\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    for i in range(size):\n",
    "        features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        G.add_node(i, features=features)\n",
    "        role_id.append(i)\n",
    "    for i in range(size):\n",
    "        G.add_edge(i, (i + 1) % size)\n",
    "   elif motif_type == 'dualRing':\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n1 = size//2\n",
    "    n2 = size-n1\n",
    "    G = nx.empty_graph(n=n1+n2)\n",
    "    for i in range(size):\n",
    "        features = np.random.normal(loc=node_feature_mean, scale=std)\n",
    "        G.add_node(i, features=features)\n",
    "        role_id.append(i)\n",
    "    G.add_edges_from([(i,(i+1)%n1) for i in range(n1)] + [(i+n1,(i+1)%n2+n1) for i in range(n2)] + [(i,i+n1) for i in range(n1)])\n",
    "   elif motif_type=='triangle':\n",
    "    # 创建一个空图\n",
    "    G = nx.Graph()\n",
    "    role_id = []\n",
    "    n = int(np.sqrt(2 * size))\n",
    "    for i in range(n):\n",
    "        for j in range(i + 1):\n",
    "            node_idx = (i * (i + 1) // 2) + j\n",
    "            node_feature = np.random.normal(node_feature_mean, std)\n",
    "            G.add_node(node_idx, feature=node_feature)\n",
    "            role_id.append(i)\n",
    "            if i > 0:\n",
    "                upper_node_idx = (i * (i - 1) // 2) + j\n",
    "                G.add_edge(node_idx, upper_node_idx)\n",
    "                if j > 0:\n",
    "                    left_upper_node_idx = upper_node_idx - 1\n",
    "                    G.add_edge(node_idx, left_upper_node_idx)\n",
    "                if j < i:\n",
    "                    right_upper_node_idx = upper_node_idx\n",
    "                    G.add_edge(node_idx, right_upper_node_idx)\n",
    "            if j > 0:\n",
    "                left_node_idx = (i * (i + 1) // 2) + (j - 1)\n",
    "                G.add_edge(node_idx, left_node_idx)\n",
    "   return G,role_id\n",
    "#包含，两个motif中顶点较少的完全被另一个所包含\n",
    "def include_smaller_graph(G1, G2):\n",
    "    \"\"\"\n",
    "    将顶点较少的图 G1 完全包含在顶点较多的图 G2 中。\n",
    "\n",
    "    Args:\n",
    "        G1: networkx.Graph, 第一个图。\n",
    "        G2: networkx.Graph, 第二个图。\n",
    "\n",
    "    Returns:\n",
    "        G: networkx.Graph, 新的图，包含 G1 和 G2 中的所有节点和边。\n",
    "    \"\"\"\n",
    "    if len(G1.nodes()) > len(G2.nodes()):\n",
    "        G1, G2 = G2, G1  # 确保 G1 是顶点较少的图\n",
    "    # 在 G2 中添加 G1 中的节点\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(G2.nodes())\n",
    "    G.add_nodes_from(G1.nodes())\n",
    "\n",
    "    # 将 G1 中的边添加到 G 中，同时保证 G1 的所有节点都在 G2 中\n",
    "    for u, v in G1.edges():\n",
    "        if u in G2.nodes() and v in G2.nodes():\n",
    "            G.add_edge(u, v)\n",
    "    # 添加 G2 中除 G1 外的边\n",
    "    for u, v in G2.edges():\n",
    "        if u in G.nodes() and v in G.nodes():\n",
    "            G.add_edge(u, v)\n",
    "    # 生成与节点数相同的 role_id 列表\n",
    "    role_id = [i for i in range(G.number_of_nodes())]\n",
    "    return G,role_id\n",
    "#纠缠，即两个motif通过多条边连接在一起，边数可配置\n",
    "def multiple_edges_connection(G1, G2, common_edges):\n",
    "    # 随机选择一些边作为公用边\n",
    "    common_edges1 = random.sample(G1.edges(), common_edges)\n",
    "    common_edges2 = random.sample(G2.edges(), common_edges)\n",
    "    # 将公用边从第二个图中删除\n",
    "    G2.remove_edges_from(common_edges2)\n",
    "    # 创建新的图并将两个图连接起来\n",
    "    G = nx.disjoint_union(G1, G2)\n",
    "    for u, v in common_edges1:\n",
    "        G.add_edge(u, v + G1.number_of_nodes())\n",
    "    # 为孤立节点添加随机边\n",
    "    isolated_nodes = [n for n in G.nodes() if G.degree(n) == 0]\n",
    "    for u in isolated_nodes:\n",
    "        v = random.choice(list(G.nodes()))\n",
    "        while G.has_edge(u, v) or u == v:\n",
    "            v = random.choice(list(G.nodes()))\n",
    "        G.add_edge(u, v)\n",
    "    # 生成与节点数相同的 role_id 列表\n",
    "    role_id = [i for i in range(G.number_of_nodes())]\n",
    "    # 删除节点编号大于图的总节点数的节点\n",
    "    max_node_id = G.number_of_nodes() - 1\n",
    "    G.remove_nodes_from([n for n in G.nodes() if n > max_node_id])\n",
    "    return G, role_id\n",
    "#交叉，两个motif间公用一些顶点，公用顶点数可配置\n",
    "def multiple_nodes_connection(G1, G2,common_nodes):\n",
    "# 随机选择一些节点作为公用节点\n",
    "  common_nodes1 = set(random.sample(G1.nodes(), common_nodes))\n",
    "  common_nodes2 = set(random.sample(G2.nodes(), common_nodes))\n",
    "# 将公用节点从第二个图中删除\n",
    "  G2.remove_nodes_from(common_nodes2)\n",
    "# 创建新的图并将两个图连接起来\n",
    "  G = nx.disjoint_union(G1, G2)\n",
    "  while G.number_of_nodes() > G1.number_of_nodes() + G2.number_of_nodes():\n",
    "      G.remove_node(random.choice(list(G.nodes())))\n",
    "  for node in common_nodes1:\n",
    "      if node + G1.number_of_nodes() < G.number_of_nodes():\n",
    "        G.add_edge(node, node + G1.number_of_nodes())\n",
    "   # 生成与节点数相同的 role_id 列表\n",
    "  role_id = [i for i in range(G.number_of_nodes())]\n",
    "  return G, role_id\n",
    "#相邻，即两个motif通过一条边连接在一起\n",
    "def adjacent_connection(G1,G2):\n",
    "    nodes1 = set(G1.nodes())\n",
    "    nodes2 = set(G2.nodes())\n",
    "    common_nodes = nodes1.intersection(nodes2)\n",
    "    # 随机选择两个图中的一条边作为相邻拼接的公共边\n",
    "    edge1 = random.choice(list(G1.edges()))\n",
    "    edge2 = random.choice(list(G2.edges()))\n",
    "    # 为相邻拼接的公共边创建新节点\n",
    "    new_node1 = max(nodes1.union(nodes2)) + 1\n",
    "    new_node2 = max(nodes1.union(nodes2)) + 2\n",
    "    # 在两个图中替换相邻拼接的公共边\n",
    "    G1.remove_edge(*edge1)\n",
    "    G1.add_edge(edge1[0], new_node1)\n",
    "    G1.add_edge(new_node1, edge1[1])\n",
    "    G2.remove_edge(*edge2)\n",
    "    G2.add_edge(edge2[0], new_node2)\n",
    "    G2.add_edge(new_node2, edge2[1])\n",
    "    # 在两个图中加入新节点\n",
    "    G1.add_node(new_node2)\n",
    "    G2.add_node(new_node1)\n",
    "    # 合并两个图\n",
    "    new_G = nx.compose(G1, G2)\n",
    "    # 添加相邻拼接的边\n",
    "    new_G.add_edge(new_node1, new_node2)\n",
    "    # 确保在新图中添加两个原图中所有节点，包括共享的节点\n",
    "    for node in common_nodes:\n",
    "        new_G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))\n",
    "    # 生成与节点数相同的 role_id 列表\n",
    "    role_id = [i for i in range(new_G.number_of_nodes())]\n",
    "    return new_G, role_id\n",
    "def get_diamond(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "    Start with a tree and attach DIAMOND-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"diamond\"]] * nb_shapes #diamond\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]#扰动函数\n",
    "    #特征生成器\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "    return G, role_id, name\n",
    "def get_crossgrid(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "    Start with a tree and attach CROSSGRID-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"crossgrid\"]] * nb_shapes # crossgrid\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]#扰动函数\n",
    "    #特征生成器\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)#？？什么意思\n",
    "    feature_generator.gen_node_features(G)\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "    return G, role_id, name\n",
    "def get_dircycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "    Start with a tree and attach DIRCYCLE-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"dircycle\"]] * nb_shapes # dircycle\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]#扰动函数\n",
    "\n",
    "    #特征生成器\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "    return G, role_id, name\n",
    "def get_varcycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "    Start with a tree and attach VARCYCLE-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"varcycle\"]] * nb_shapes # varcycle\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]#扰动函数\n",
    "\n",
    "    #特征生成器\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "    return G, role_id, name\n",
    "def get_house(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "    Start with a tree and attach HOUSE-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"house\"]] * nb_shapes # house\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]#扰动函数\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)#？？什么意思\n",
    "    feature_generator.gen_node_features(G)\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "    return G, role_id, name\n",
    "def get_cycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "    Start with a tree and attach cycle-shaped (directed edges) subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"dircycle\"]] * nb_shapes\n",
    "\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]       # 0.05 original\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "    return G, role_id, name\n",
    "def get_crane(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\" Synthetic Graph:\n",
    "\n",
    "    Start with a tree and attach crane-shaped subgraphs.\n",
    "    \"\"\"\n",
    "    list_shapes = [[\"varcycle\"]] * nb_shapes   # crane\n",
    "    if draw:\n",
    "        plt.figure(figsize=figsize)\n",
    "\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    G = perturb([G], 0.00, id=role_id)[0]\n",
    "\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "\n",
    "    name = basis_type + \"_\" + str(width_basis) + \"_\" + str(nb_shapes)\n",
    "\n",
    "    return G, role_id, name\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 106.50it/s]\n",
      "  1%|▊                                                                              | 11/1000 [00:00<00:09, 100.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 19.00    #Edges: 20.94 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 112.57it/s]\n",
      "  2%|█▋                                                                             | 21/1000 [00:00<00:05, 190.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 22.49    #Edges: 29.85 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:07<00:00, 135.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 3000    #Nodes: 12.00    #Edges: 26.48 \n"
     ]
    }
   ],
   "source": [
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "bias = float(global_b)\n",
    "\n",
    "def graph_stats(base_num):\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis=np.random.choice(range(3))\n",
    "    if base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis=np.random.choice(range(8,12))\n",
    "    if base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis=np.random.choice(range(15,20))\n",
    "    return base, width_basis\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3], p=[bias,(1-bias)/2,(1-bias)/2])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    #G, role_id, name = get_cycle(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    G1,role_id1=create_motif_type_branch(size=10,branches=4, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='star')\n",
    "    G2,role_id2=create_motif_type_branch(size=8,branches=4, node_feature_mean=[1,2,1,1,1], std = [1,2,1,1,1],motif_type='star')\n",
    "    G3,role_id3=adjacent_connection(G1,G2)\n",
    "    G4,role_id4=create_motif_type_branch(size=10,branches=4, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='fan')\n",
    "    G,role_id=multiple_nodes_connection(G3,G4,3)\n",
    "    label_list.append(0)#在列表末尾添加元素\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3], p=[(1-bias)/2,bias,(1-bias)/2])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "\n",
    "    #G, role_id, name = get_house(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    G1,role_id1=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dircycle')\n",
    "    G2,role_id2=create_motif_type(size=6,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dircycle')\n",
    "    G3,role_id3=adjacent_connection(G1,G2)\n",
    "    G4,role_id4=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dualRing')\n",
    "    G,role_id=multiple_edges_connection(G3,G4,3)\n",
    "    label_list.append(1)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3], p=[(1-bias)/2,(1-bias)/2,bias])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    #G, role_id, name = get_crane(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    G1,role_id1=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='netShape')\n",
    "    G2,role_id2=create_motif_type(size=6,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='netShape')\n",
    "    G3,role_id3=include_smaller_graph(G1,G2)\n",
    "    G4,role_id4=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='triangle')\n",
    "    G,role_id=adjacent_connection(G3,G4)\n",
    "    #nx.draw_networkx(G)\n",
    "    #plt.show()\n",
    "    label_list.append(2)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'train.npy'), (edge_index_list, label_list, ground_truth_list, role_id_list, pos_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Val Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:07<00:00, 138.27it/s]\n",
      "  1%|█                                                                              | 13/1000 [00:00<00:08, 116.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 19.00    #Edges: 20.99 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 71.20it/s]\n",
      "  1%|▋                                                                                | 8/1000 [00:00<00:13, 73.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 22.52    #Edges: 29.82 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 81.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# Graphs: 3000    # Nodes: 12.00    # Edges: 26.48 \n"
     ]
    }
   ],
   "source": [
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "bias = float(global_b)\n",
    "\n",
    "def graph_stats(base_num):\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis=np.random.choice(range(3))\n",
    "    if base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis=np.random.choice(range(8,12))\n",
    "    if base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis=np.random.choice(range(15,20))\n",
    "    return base, width_basis\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    G1,role_id1=create_motif_type_branch(size=10,branches=4, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='star')\n",
    "    G2,role_id2=create_motif_type_branch(size=8,branches=4, node_feature_mean=[1,2,1,1,1], std = [1,2,1,1,1],motif_type='star')\n",
    "    G3,role_id3=adjacent_connection(G1,G2)\n",
    "    G4,role_id4=create_motif_type_branch(size=10,branches=4, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='fan')\n",
    "    G,role_id=multiple_nodes_connection(G3,G4,3)\n",
    "    #G, role_id, name = get_cycle(basis_type=base, nb_shapes=1, \n",
    "    #                               width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(0)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    G1,role_id1=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dircycle')\n",
    "    G2,role_id2=create_motif_type(size=6,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dircycle')\n",
    "    G3,role_id3=adjacent_connection(G1,G2)\n",
    "    G4,role_id4=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dualRing')\n",
    "    G,role_id=multiple_edges_connection(G3,G4,3)\n",
    "    #G, role_id, name = get_house(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(1)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats(base_num)\n",
    "    G1,role_id1=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='netShape')\n",
    "    G2,role_id2=create_motif_type(size=6,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='netShape')\n",
    "    G3,role_id3=include_smaller_graph(G1,G2)\n",
    "    G4,role_id4=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='triangle')\n",
    "    G,role_id=adjacent_connection(G3,G4)\n",
    "    #G, role_id, name = get_crane(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    label_list.append(2)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"# Graphs: %d    # Nodes: %.2f    # Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'val.npy'), (edge_index_list, label_list, ground_truth_list, role_id_list, pos_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:26<00:00, 76.81it/s]\n",
      "  0%|▎                                                                                | 8/2000 [00:00<00:26, 73.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 19.00    #Edges: 21.20 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:27<00:00, 72.16it/s]\n",
      "  0%|▎                                                                                | 8/2000 [00:00<00:29, 66.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 4000    #Nodes: 22.51    #Edges: 29.81 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:21<00:00, 94.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 6000    #Nodes: 12.00    #Edges: 26.46 \n"
     ]
    }
   ],
   "source": [
    "# no bias for test dataset\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "\n",
    "def graph_stats_large(base_num):\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis=np.random.choice(range(3,6))\n",
    "    if base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis=np.random.choice(range(30,50))\n",
    "    if base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis=np.random.choice(range(60,80))\n",
    "    return base, width_basis\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    base_num = np.random.choice([1,2,3]) # uniform\n",
    "    base, width_basis = graph_stats_large(base_num)\n",
    "    #G, role_id, name = get_cycle(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    G1,role_id1=create_motif_type_branch(size=10,branches=4, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='star')\n",
    "    G2,role_id2=create_motif_type_branch(size=8,branches=4, node_feature_mean=[1,2,1,1,1], std = [1,2,1,1,1],motif_type='star')\n",
    "    G3,role_id3=adjacent_connection(G1,G2)\n",
    "    G4,role_id4=create_motif_type_branch(size=10,branches=4, node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='fan')\n",
    "    G,role_id=multiple_nodes_connection(G3,G4,3)\n",
    "    label_list.append(0)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats_large(base_num)\n",
    "\n",
    "    #G, role_id, name = get_house(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    G1,role_id1=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dircycle')\n",
    "    G2,role_id2=create_motif_type(size=6,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dircycle')\n",
    "    G3,role_id3=adjacent_connection(G1,G2)\n",
    "    G4,role_id4=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='dualRing')\n",
    "    G,role_id=multiple_edges_connection(G3,G4,3)\n",
    "    label_list.append(1)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    base_num = np.random.choice([1,2,3])\n",
    "    base, width_basis = graph_stats_large(base_num)\n",
    "    #G, role_id, name = get_crane(basis_type=base, nb_shapes=1, \n",
    "    #                                width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "    G1,role_id1=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='netShape')\n",
    "    G2,role_id2=create_motif_type(size=6,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='netShape')\n",
    "    G3,role_id3=include_smaller_graph(G1,G2)\n",
    "    G4,role_id4=create_motif_type(size=10,node_feature_mean=[1.5,2.0,1.2,1.3,1.8], std = [1.5,2.0,1.2,1.3,1.8],motif_type='triangle')\n",
    "    G,role_id=adjacent_connection(G3,G4)\n",
    "    label_list.append(2)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=np.int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "np.save(osp.join(data_dir, 'test.npy'), (edge_index_list, label_list, ground_truth_list, role_id_list, pos_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "c71b0b87ea436ae79e2503ec051639fc2420e91bd742cb356b7debceb9d5ed19"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
