{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Spurious-Motif Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Star import stays first so the explicit imports below win on any name clash.\n",
    "# The project helpers used later (synthetic_structsim, featgen, perturb, find_gd,\n",
    "# plt, figsize) are not defined in this notebook, so presumably they come from here.\n",
    "from BA3_loc import *\n",
    "\n",
    "import os\n",
    "import os.path as osp\n",
    "import warnings\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "# NOTE(review): this silences every warning, including NumPy deprecation notices.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "global_b = '0.9' # Set bias degree here\n",
    "# Output directory for the generated train/val/test .npy files.\n",
    "data_dir = f'../data/SPMotif-{global_b}/raw/'\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_motif_graph(shape_name, basis_type, nb_shapes, width_basis, feature_generator, draw):\n",
    "    \"\"\"Attach `nb_shapes` copies of one motif to a base graph and generate node features.\n",
    "\n",
    "    Shared implementation behind get_house, get_cycle and get_crane.\n",
    "\n",
    "    Args:\n",
    "        shape_name: motif identifier forwarded to synthetic_structsim.build_graph\n",
    "            ('house', 'dircycle' or 'varcycle' in this notebook).\n",
    "        basis_type: base-graph identifier forwarded to build_graph.\n",
    "        nb_shapes: number of motif entries placed in the shape list.\n",
    "        width_basis: base-graph size argument forwarded to build_graph.\n",
    "        feature_generator: object exposing gen_node_features(G); when None,\n",
    "            featgen.ConstFeatureGen(1) is used.\n",
    "        draw: when True, open a matplotlib figure before building the graph.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id, name): the graph returned by perturb(), the role ids returned\n",
    "        by build_graph, and the string '<basis_type>_<width_basis>_<nb_shapes>'.\n",
    "    \"\"\"\n",
    "    list_shapes = [[shape_name]] * nb_shapes\n",
    "\n",
    "    if draw:\n",
    "        # NOTE(review): `figsize` is not defined in this notebook; presumably BA3_loc provides it.\n",
    "        plt.figure(figsize=figsize)\n",
    "\n",
    "    G, role_id, _ = synthetic_structsim.build_graph(\n",
    "        width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True\n",
    "    )\n",
    "    # Second argument is 0.00; a note on the original get_cycle recorded 0.05 as the original value.\n",
    "    G = perturb([G], 0.00, id=role_id)[0]\n",
    "\n",
    "    if feature_generator is None:\n",
    "        feature_generator = featgen.ConstFeatureGen(1)\n",
    "    feature_generator.gen_node_features(G)\n",
    "\n",
    "    name = f\"{basis_type}_{width_basis}_{nb_shapes}\"\n",
    "    return G, role_id, name\n",
    "\n",
    "\n",
    "def get_house(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\"Synthetic graph: a `basis_type` base graph with house-shaped ('house') motifs attached.\n",
    "\n",
    "    `m` is unused and kept only so existing calls keep working.\n",
    "    Returns (G, role_id, name) as described in _build_motif_graph.\n",
    "    \"\"\"\n",
    "    return _build_motif_graph(\"house\", basis_type, nb_shapes, width_basis, feature_generator, draw)\n",
    "\n",
    "\n",
    "def get_cycle(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\"Synthetic graph: a `basis_type` base graph with cycle-shaped ('dircycle') motifs attached.\n",
    "\n",
    "    `m` is unused and kept only so existing calls keep working.\n",
    "    Returns (G, role_id, name) as described in _build_motif_graph.\n",
    "    \"\"\"\n",
    "    return _build_motif_graph(\"dircycle\", basis_type, nb_shapes, width_basis, feature_generator, draw)\n",
    "\n",
    "\n",
    "def get_crane(basis_type, nb_shapes=80, width_basis=8, feature_generator=None, m=3, draw=True):\n",
    "    \"\"\"Synthetic graph: a `basis_type` base graph with crane-shaped ('varcycle') motifs attached.\n",
    "\n",
    "    `m` is unused and kept only so existing calls keep working.\n",
    "    Returns (G, role_id, name) as described in _build_motif_graph.\n",
    "    \"\"\"\n",
    "    return _build_motif_graph(\"varcycle\", basis_type, nb_shapes, width_basis, feature_generator, draw)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:03<00:00, 273.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 11.48    #Edges: 12.86 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:04<00:00, 224.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 23.20    #Edges: 32.68 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:04<00:00, 236.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 3000    #Nodes: 21.49    #Edges: 38.33 \n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (5, 3000) + inhomogeneous part.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[3], line 81\u001b[0m\n\u001b[0;32m     77\u001b[0m     ground_truth_list\u001b[38;5;241m.\u001b[39mappend(find_gd(edge_index, role_id))\n\u001b[0;32m     79\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m#Graphs: \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m    #Nodes: \u001b[39m\u001b[38;5;132;01m%.2f\u001b[39;00m\u001b[38;5;124m    #Edges: \u001b[39m\u001b[38;5;132;01m%.2f\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (\u001b[38;5;28mlen\u001b[39m(ground_truth_list), np\u001b[38;5;241m.\u001b[39mmean(n_mean), np\u001b[38;5;241m.\u001b[39mmean(e_mean)))\n\u001b[1;32m---> 81\u001b[0m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43mosp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain.npy\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_index_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mground_truth_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrole_id_list\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos_list\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32md:\\Anaconda\\envs\\CRCG\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:573\u001b[0m, in \u001b[0;36msave\u001b[1;34m(file, arr, allow_pickle, fix_imports)\u001b[0m\n\u001b[0;32m    570\u001b[0m     file_ctx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mopen\u001b[39m(file, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m    572\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m file_ctx \u001b[38;5;28;01mas\u001b[39;00m fid:\n\u001b[1;32m--> 573\u001b[0m     arr \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masanyarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    574\u001b[0m     \u001b[38;5;28mformat\u001b[39m\u001b[38;5;241m.\u001b[39mwrite_array(fid, arr, allow_pickle\u001b[38;5;241m=\u001b[39mallow_pickle,\n\u001b[0;32m    575\u001b[0m                        pickle_kwargs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(fix_imports\u001b[38;5;241m=\u001b[39mfix_imports))\n",
      "\u001b[1;31mValueError\u001b[0m: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (5, 3000) + inhomogeneous part."
     ]
    }
   ],
   "source": [
    "# Per-graph accumulators for the training split; entry i of every list describes the same graph.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "# Probability with which each class is paired with one particular base graph below.\n",
    "bias = float(global_b)\n",
    "\n",
    "def graph_stats(base_num):\n",
    "    \"\"\"Map a base-graph id to its name and draw a random width for it.\n",
    "\n",
    "    Args:\n",
    "        base_num: 1 -> 'tree', 2 -> 'ladder', 3 -> 'wheel'.\n",
    "\n",
    "    Returns:\n",
    "        (base, width_basis): the base name and a width drawn with np.random.choice\n",
    "        from range(3), range(8, 12) or range(15, 20) respectively.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if base_num is not 1, 2 or 3 (this used to surface as an\n",
    "            UnboundLocalError at the return statement).\n",
    "    \"\"\"\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis = np.random.choice(range(3))\n",
    "    elif base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis = np.random.choice(range(8, 12))\n",
    "    elif base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis = np.random.choice(range(15, 20))\n",
    "    else:\n",
    "        raise ValueError(f\"base_num must be 1, 2 or 3, got {base_num!r}\")\n",
    "    return base, width_basis\n",
    "\n",
    "# Training split: each class draws one particular base graph with probability `bias`\n",
    "# (cycle -> tree, house -> ladder, crane -> wheel); the remainder is split evenly\n",
    "# between the other two bases.\n",
    "off_bias = (1 - bias) / 2\n",
    "class_specs = [\n",
    "    (0, get_cycle, [bias, off_bias, off_bias]),\n",
    "    (1, get_house, [off_bias, bias, off_bias]),\n",
    "    (2, get_crane, [off_bias, off_bias, bias]),\n",
    "]\n",
    "\n",
    "for label, build_motif, base_probs in class_specs:\n",
    "    e_mean, n_mean = [], []\n",
    "    for _ in tqdm(range(1000)):\n",
    "        base_num = np.random.choice([1, 2, 3], p=base_probs)\n",
    "        base, width_basis = graph_stats(base_num)\n",
    "\n",
    "        G, role_id, name = build_motif(basis_type=base, nb_shapes=1,\n",
    "                                       width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "        label_list.append(label)\n",
    "        e_mean.append(len(G.edges))\n",
    "        n_mean.append(len(G.nodes))\n",
    "\n",
    "        role_id = np.array(role_id)\n",
    "        edge_index = np.array(G.edges, dtype=int).T\n",
    "\n",
    "        role_id_list.append(role_id)\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "        ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "    # Graph count is cumulative over classes; node/edge means cover the current class only.\n",
    "    print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "# The per-graph arrays have different shapes, so NumPy >= 1.24 refuses to build the\n",
    "# container implicitly; an explicit object array is required. The file is then\n",
    "# pickled and must be read back with np.load(..., allow_pickle=True).\n",
    "split_payload = np.array(\n",
    "    (edge_index_list, label_list, ground_truth_list, role_id_list, pos_list), dtype=object\n",
    ")\n",
    "np.save(osp.join(data_dir, 'train.npy'), split_payload)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:08<00:00, 119.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 18.79    #Edges: 27.38 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:08<00:00, 118.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 18.77    #Edges: 28.27 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:08<00:00, 120.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# Graphs: 3000    # Nodes: 18.58    # Edges: 28.72 \n"
     ]
    }
   ],
   "source": [
    "# Per-graph accumulators for the validation split; entry i of every list describes the same graph.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "# Re-derived from the notebook-level setting so the cell keeps defining `bias`.\n",
    "bias = float(global_b)\n",
    "\n",
    "def graph_stats(base_num):\n",
    "    \"\"\"Map a base-graph id to its name and draw a random width for it.\n",
    "\n",
    "    Args:\n",
    "        base_num: 1 -> 'tree', 2 -> 'ladder', 3 -> 'wheel'.\n",
    "\n",
    "    Returns:\n",
    "        (base, width_basis): the base name and a width drawn with np.random.choice\n",
    "        from range(3), range(8, 12) or range(15, 20) respectively.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if base_num is not 1, 2 or 3 (this used to surface as an\n",
    "            UnboundLocalError at the return statement).\n",
    "    \"\"\"\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis = np.random.choice(range(3))\n",
    "    elif base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis = np.random.choice(range(8, 12))\n",
    "    elif base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis = np.random.choice(range(15, 20))\n",
    "    else:\n",
    "        raise ValueError(f\"base_num must be 1, 2 or 3, got {base_num!r}\")\n",
    "    return base, width_basis\n",
    "\n",
    "# Validation split: the base graph is drawn uniformly for every class.\n",
    "stats_fmt = \"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \"\n",
    "class_specs = [\n",
    "    (0, get_cycle, stats_fmt),\n",
    "    (1, get_house, stats_fmt),\n",
    "    # The original cell prints the crane summary with spaces after '#'; kept byte-identical.\n",
    "    (2, get_crane, \"# Graphs: %d    # Nodes: %.2f    # Edges: %.2f \"),\n",
    "]\n",
    "\n",
    "for label, build_motif, summary_fmt in class_specs:\n",
    "    e_mean, n_mean = [], []\n",
    "    for _ in tqdm(range(1000)):\n",
    "        base_num = np.random.choice([1, 2, 3])\n",
    "        base, width_basis = graph_stats(base_num)\n",
    "\n",
    "        G, role_id, name = build_motif(basis_type=base, nb_shapes=1,\n",
    "                                       width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "        label_list.append(label)\n",
    "        e_mean.append(len(G.edges))\n",
    "        n_mean.append(len(G.nodes))\n",
    "\n",
    "        role_id = np.array(role_id)\n",
    "        # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.\n",
    "        edge_index = np.array(G.edges, dtype=int).T\n",
    "\n",
    "        role_id_list.append(role_id)\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "        ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "    # Graph count is cumulative over classes; node/edge means cover the current class only.\n",
    "    print(summary_fmt % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "# The per-graph arrays have different shapes, so NumPy >= 1.24 refuses to build the\n",
    "# container implicitly; an explicit object array is required. The file is then\n",
    "# pickled and must be read back with np.load(..., allow_pickle=True).\n",
    "split_payload = np.array(\n",
    "    (edge_index_list, label_list, ground_truth_list, role_id_list, pos_list), dtype=object\n",
    ")\n",
    "np.save(osp.join(data_dir, 'val.npy'), split_payload)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [01:51<00:00, 18.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 90.84    #Edges: 126.40 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [01:48<00:00, 18.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 4000    #Nodes: 90.22    #Edges: 126.15 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [01:47<00:00, 18.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 6000    #Nodes: 89.57    #Edges: 127.93 \n"
     ]
    }
   ],
   "source": [
    "# Test split: no bias is applied, so this cell does not set `bias`.\n",
    "# Per-graph accumulators; entry i of every list describes the same graph.\n",
    "edge_index_list = []\n",
    "label_list = []\n",
    "ground_truth_list = []\n",
    "role_id_list = []\n",
    "pos_list = []\n",
    "\n",
    "def graph_stats_large(base_num):\n",
    "    \"\"\"Map a base-graph id to its name and draw a random width from the test-split ranges.\n",
    "\n",
    "    Args:\n",
    "        base_num: 1 -> 'tree', 2 -> 'ladder', 3 -> 'wheel'.\n",
    "\n",
    "    Returns:\n",
    "        (base, width_basis): the base name and a width drawn with np.random.choice\n",
    "        from range(3, 6), range(30, 50) or range(60, 80) respectively.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if base_num is not 1, 2 or 3 (this used to surface as an\n",
    "            UnboundLocalError at the return statement).\n",
    "    \"\"\"\n",
    "    if base_num == 1:\n",
    "        base = 'tree'\n",
    "        width_basis = np.random.choice(range(3, 6))\n",
    "    elif base_num == 2:\n",
    "        base = 'ladder'\n",
    "        width_basis = np.random.choice(range(30, 50))\n",
    "    elif base_num == 3:\n",
    "        base = 'wheel'\n",
    "        width_basis = np.random.choice(range(60, 80))\n",
    "    else:\n",
    "        raise ValueError(f\"base_num must be 1, 2 or 3, got {base_num!r}\")\n",
    "    return base, width_basis\n",
    "\n",
    "# Test split: the base graph is drawn uniformly for every class, using the\n",
    "# graph_stats_large width ranges.\n",
    "class_specs = [(0, get_cycle), (1, get_house), (2, get_crane)]\n",
    "\n",
    "for label, build_motif in class_specs:\n",
    "    e_mean, n_mean = [], []\n",
    "    for _ in tqdm(range(2000)):\n",
    "        base_num = np.random.choice([1, 2, 3])  # uniform\n",
    "        base, width_basis = graph_stats_large(base_num)\n",
    "\n",
    "        G, role_id, name = build_motif(basis_type=base, nb_shapes=1,\n",
    "                                       width_basis=width_basis, feature_generator=None, m=3, draw=False)\n",
    "        label_list.append(label)\n",
    "        e_mean.append(len(G.edges))\n",
    "        n_mean.append(len(G.nodes))\n",
    "\n",
    "        role_id = np.array(role_id)\n",
    "        # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.\n",
    "        edge_index = np.array(G.edges, dtype=int).T\n",
    "\n",
    "        role_id_list.append(role_id)\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "        ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "\n",
    "    # Graph count is cumulative over classes; node/edge means cover the current class only.\n",
    "    print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "\n",
    "# The per-graph arrays have different shapes, so NumPy >= 1.24 refuses to build the\n",
    "# container implicitly; an explicit object array is required. The file is then\n",
    "# pickled and must be read back with np.load(..., allow_pickle=True).\n",
    "split_payload = np.array(\n",
    "    (edge_index_list, label_list, ground_truth_list, role_id_list, pos_list), dtype=object\n",
    ")\n",
    "np.save(osp.join(data_dir, 'test.npy'), split_payload)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CRCG",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
