{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a32dd142",
   "metadata": {},
   "source": [
    "### Generate CRCG-NODE Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "142774b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from BA3_loc import *\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import random\n",
    "import math\n",
    "import torch\n",
    "import copy\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import gompertz\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from scipy.stats import weibull_min\n",
    "from scipy.special import gamma, gammaincinv\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "import os  # `import os.path as osp` does not bind `os`; import it explicitly instead of relying on the star import\n",
    "data_dir = '../data/CRCG-NODE/raw/'  # raw-data directory for the CRCG-NODE dataset (created if missing)\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "397bda31",
   "metadata": {},
   "outputs": [],
   "source": [
    "#对于节点信息进行符合一定规律的生成\n",
    "# 生成Gamma分布\n",
    "def generate_gamma(mu, sigma, size):\n",
    "    # 根据平均值和标准差计算形状参数k和尺度参数theta\n",
    "    var = np.power(sigma, 2)\n",
    "    theta = np.divide(var, mu)\n",
    "    k = np.divide(mu, theta)\n",
    "    # 生成Gamma分布\n",
    "    return gamma.rvs(a=k, scale=theta, size=size)\n",
    "def generate_nodes_normal_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "          nodes[:, i] = np.random.normal(loc=mean[i], scale=std[i], size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_uniform_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "          nodes[:, i] = np.random.uniform(low=mean[i]- std[i], high=mean[i] + std[i], size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id   \n",
    "def generate_nodes_exponential_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "          nodes[:, i] = np.random.exponential(scale=std[i]/mean[i], size=num_nodes) * mean[i]\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_lognormal_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "            mu = np.log(mean[i]**2/np.sqrt(std[i]**2+mean[i]**2))\n",
    "            sigma = np.sqrt(np.log(std[i]**2/mean[i]**2 + 1))\n",
    "            nodes[:, i] = np.random.lognormal(mu, sigma, num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_gamma_distributed(num_nodes,mean, std):\n",
    "    \"\"\"Fill column i with Gamma samples of mean mean[i] and std std[i] (via generate_gamma).\n",
    "\n",
    "    Returns the feature matrix and role ids 0..num_nodes-1.\n",
    "    \"\"\"\n",
    "    features = np.zeros((num_nodes, len(mean)))\n",
    "    for col in range(len(mean)):\n",
    "        features[:, col] = generate_gamma(mean[col], std[col], num_nodes)\n",
    "    return features, list(range(num_nodes))\n",
    "def generate_nodes_beta_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        a = ((1 - np.array(mean)) / np.array(std) ** 2 - 1 / np.array(mean)) * np.array(mean) ** 2\n",
    "        b = a * (1 / np.array(mean) - 1)\n",
    "        nodes[:,i] = np.random.beta(a=a[i], b=b[i], size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_weibull_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      # 随机设定形状参数k和尺度参数lambda\n",
    "      k = np.random.uniform(low=0.5, high=2.0, size=5)\n",
    "      lam = np.random.uniform(low=0.5, high=2.0, size=5)\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        nodes[:, i] = weibull_min.rvs(k[i], scale=lam[i], size=num_nodes)\n",
    "        nodes[:, i] = (nodes[:, i] - nodes[:, i].mean()) / nodes[:, i].std()\n",
    "        nodes[:, i] = nodes[:, i] * std[i] + mean[i]\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_laplace_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        mu = mean[i]\n",
    "        b = std[i] / np.sqrt(2)\n",
    "        nodes[:, i] = np.random.laplace(mu, b, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_logistic_distributed(num_nodes,mean, std):    \n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        mu = mean[i]\n",
    "        s = std[i] / np.sqrt(3)\n",
    "        nodes[:, i] = np.random.logistic(mu, s, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_rayleigh_distributed(num_nodes,mean, std):\n",
    "      dim = len(mean)    \n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, dim))\n",
    "      for i in range(dim):\n",
    "        sigma = std[i] / np.sqrt(2*np.pi)\n",
    "        nodes[:, i] = np.random.rayleigh(sigma, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_pareto_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        alpha = mean[i] / std[i]\n",
    "        nodes[:, i] = np.random.pareto(alpha, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_cauchy_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        x0 = mean[i]\n",
    "        gamma = std[i] * np.sqrt(np.pi / 2)\n",
    "        nodes[:, i] = np.random.standard_cauchy(size=num_nodes) * gamma + x0\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_neg_binom_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        p = mean[i] / (std[i]**2 + mean[i])\n",
    "        r = mean[i]**2 / (std[i]**2 - mean[i])\n",
    "        nodes[:, i]=np.random.negative_binomial(r, 1-p, size=num_nodes)\n",
    "      role_id = list(range(num_nodes)) \n",
    "      return nodes,role_id \n",
    "def generate_nodes_gumbel_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "         loc = mean[i] - 0.57721 * std[i]\n",
    "         scale = np.sqrt(6) * std[i] / np.pi\n",
    "         nodes[:, i]=np.random.gumbel(loc=loc, scale=scale, size=num_nodes)\n",
    "      role_id = list(range(num_nodes)) \n",
    "      return nodes,role_id   \n",
    "def generate_nodes_gompertz_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "         loc = mean[i] - np.log(np.log(2)/2)*std[i]\n",
    "         scale = np.exp(std[i]/np.log(2))\n",
    "         nodes[:, i]=gompertz.rvs(c=loc, scale=scale, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))               \n",
    "      return nodes,role_id\n",
    "def rectangle_sequence(n):\n",
    "    seq = [0]\n",
    "    for i in range(1, n+1):\n",
    "        seq.append(seq[-1] + i*2)\n",
    "    return seq[1:]\n",
    "def binomial_coefficients(n, dim):\n",
    "    \"\"\"Generate a binomial coefficient sequence of length `dim` for a given `n`\"\"\"\n",
    "    seq = []\n",
    "    for i in range(1, dim+1):\n",
    "        seq.append(math.comb(n, i))\n",
    "    return seq\n",
    "def generate_nodes_arithmetic(num_nodes,dims,step):\n",
    "      role_id = []\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          start = random.uniform(0,10)  # 随机生成起始值\n",
    "          node = [start + j * step for j in range(dims)] # 生成等差数列\n",
    "          nodes.append(node)\n",
    "          #print(nodes)\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_geometric(num_nodes,dims,step):\n",
    "      role_id = []\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          start = random.uniform(0,10)  # 随机生成起始值\n",
    "          node = [start*step**j for j in range(dims)]  # 生成等比数列\n",
    "          nodes.append(node)\n",
    "          #print(nodes)\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_fibonacci(num_nodes,dims,step):\n",
    "      role_id = []\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          # 根据节点索引 i 计算斐波那契数列\n",
    "          start1 = round(random.uniform(0, 10), 1)\n",
    "          start2 = round(random.uniform(0, 10), 1)\n",
    "          fib_nums = [start1, start2]\n",
    "          for j in range(dims - 2):\n",
    "              fib_nums.append(fib_nums[-1] + fib_nums[-2])\n",
    "          #fib_nums = np.array(fib_nums) * (i + 1)\n",
    "          # 将节点特征添加到节点数组\n",
    "          nodes.append(fib_nums)\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_square(num_nodes,dims,step):\n",
    "        role_id = []\n",
    "        nodes = []\n",
    "        for i in range(num_nodes):\n",
    "          initial_val = random.uniform(0, 10)\n",
    "          node = [initial_val]\n",
    "          for j in range(dims-1):\n",
    "              node.append(initial_val ** 2)\n",
    "              initial_val=initial_val ** 2\n",
    "          nodes.append(node)\n",
    "        role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "        return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_cube(num_nodes,dims,step):\n",
    "        role_id = []\n",
    "        nodes = []\n",
    "        for i in range(num_nodes):\n",
    "          initial_val = random.uniform(0, 10)\n",
    "          node = [initial_val]\n",
    "          for j in range(dims-1):\n",
    "              node.append(initial_val ** 3)\n",
    "              initial_val=initial_val ** 3\n",
    "          nodes.append(node)\n",
    "        role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "        return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_prime(num_nodes,dims,step):\n",
    "       nodes = []\n",
    "       for i in range(num_nodes):\n",
    "          # 随机生成起始值\n",
    "          start = random.uniform(0, 10)\n",
    "          # 生成质数数列\n",
    "          primes = []\n",
    "          n = 2\n",
    "          while len(primes) < dims:\n",
    "              if all(n % p != 0 for p in primes):\n",
    "                  primes.append(n)\n",
    "              n += 1\n",
    "          node = np.array(primes) * start\n",
    "          nodes.append(node)\n",
    "       role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "       return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_triangular(num_nodes,dims,step):\n",
    "      #三角数列是指数列中的每一项都是由1到n连续相加所得到的和，它们的值恰好是一个等边三角形的面积\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          initial_value = random.uniform(0, 10)\n",
    "          features = []\n",
    "          for j in range(dims):\n",
    "              feature = 0.5 * initial_value * (initial_value + 1)\n",
    "              features.append(feature)\n",
    "              initial_value += 1\n",
    "          nodes.append(features)\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_rectangular(num_nodes,dims,step):\n",
    "    \"\"\"Each node is x + k*(k+1) for k = 1..dims (pronic numbers 2, 6, 12, ...) with x ~ U(0, 10).\n",
    "\n",
    "    `step` is unused. Returns the (num_nodes, dims) matrix and random role ids in {0, 1, 2}.\n",
    "    \"\"\"\n",
    "    # The offsets do not depend on the node, so build them once rather than per node.\n",
    "    offsets = rectangle_sequence(dims)\n",
    "    nodes = []\n",
    "    for _ in range(num_nodes):\n",
    "        initial_value = random.uniform(0, 10)\n",
    "        nodes.append([initial_value + item for item in offsets])\n",
    "    role_id = [random.randint(0, 2) for _ in range(num_nodes)]\n",
    "    return np.array(nodes).reshape((num_nodes, dims)), role_id\n",
    "def generate_nodes_binomial(num_nodes,dims,step):\n",
    "    \"\"\"Each node is [C(n, 1), ..., C(n, dims)] for a random integer n in [1, 10].\n",
    "\n",
    "    `step` is unused. Returns the (num_nodes, dims) matrix and random role ids in {0, 1, 2}.\n",
    "    \"\"\"\n",
    "    rows = [binomial_coefficients(random.randint(1, 10), dims) for _ in range(num_nodes)]\n",
    "    labels = [random.randint(0, 2) for _ in range(num_nodes)]\n",
    "    return np.array(rows).reshape((num_nodes, dims)), labels\n",
    "def generate_nodes_hamilton(num_nodes,dims,step):\n",
    "      #哈密顿数列是一个由正整数构成的数列，它的通项公式为H_n = 2^n-1，其中n为自然数\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          initial_value = random.randint(1, 10)\n",
    "          hamilton_seq = [2**n - 1 for n in range(initial_value, initial_value+dims)]\n",
    "          nodes.append(hamilton_seq)\n",
    "      print(nodes)\n",
    "      role_id = [random.randint(0,2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def merge_nodes(node_set1, node_set2):\n",
    "    \"\"\"\n",
    "    将两个节点集合按垂直方向合并\n",
    "    :param node_set1: 第一个节点集合\n",
    "    :param node_set2: 第二个节点集合\n",
    "    :return: 合并后的节点集合\n",
    "    \"\"\"\n",
    "    return np.vstack((node_set1, node_set2))\n",
    "def build_sim_edges(nodes,sim_threshold):\n",
    "    \"\"\"Return the set of pairs (i, j), i < j, whose rows have cosine similarity > sim_threshold.\n",
    "\n",
    "    The full similarity matrix is computed in one cosine_similarity call instead of one\n",
    "    call per pair.\n",
    "    \"\"\"\n",
    "    nodes = np.asarray(nodes)\n",
    "    if nodes.shape[0] < 2:\n",
    "        return set()\n",
    "    sim = cosine_similarity(nodes)\n",
    "    # k=1 keeps the strict upper triangle, i.e. only pairs with i < j.\n",
    "    rows, cols = np.nonzero(np.triu(sim > sim_threshold, k=1))\n",
    "    return {(int(i), int(j)) for i, j in zip(rows, cols)}\n",
    "def create_partial_sim_edges(nodes, partial_sim_threshold,dims):\n",
    "    \"\"\"Return pairs (i, j), i < j, whose features restricted to `dims` have cosine similarity > threshold.\n",
    "\n",
    "    `dims` indexes one feature row: an int selects a single feature (then two non-zero\n",
    "    values have similarity +/-1), and a list or slice selects several. Pairs are returned\n",
    "    in row-major (i, then j) order.\n",
    "    \"\"\"\n",
    "    arr = np.asarray(nodes)\n",
    "    num_nodes = len(arr)\n",
    "    if num_nodes < 2:\n",
    "        return []\n",
    "    # Keep only the selected features; reshape keeps the result 2-D for an int `dims`.\n",
    "    sub = arr[:, dims].reshape(num_nodes, -1)\n",
    "    sim = cosine_similarity(sub)\n",
    "    rows, cols = np.nonzero(np.triu(sim > partial_sim_threshold, k=1))\n",
    "    return [(int(i), int(j)) for i, j in zip(rows, cols)]\n",
    "# Registry of node-feature generators for the synthetic node-classification data:\n",
    "# type id -> (generator function, short description of the generation rule).\n",
    "# NOTE: types 1-15 take (num_nodes, mean, std); types 16-25 take (num_nodes, dims, step).\n",
    "node_generators = {\n",
    "    1: (generate_nodes_normal_distributed, \"基于平均值与标准差正态分布生成\"),\n",
    "    2: (generate_nodes_uniform_distributed, \"基于平均值与标准差均匀分布生成\"),\n",
    "    3: (generate_nodes_exponential_distributed, \"基于平均值与标准差指数分布生成\"),\n",
    "    4: (generate_nodes_lognormal_distributed, \"基于平均值与标准差对数正态分布生成\"),\n",
    "    5: (generate_nodes_weibull_distributed, \"基于平均值与标准差Weibull分布生成\"),\n",
    "    6: (generate_nodes_laplace_distributed, \"基于平均值与标准差Laplace分布生成\"),\n",
    "    7: (generate_nodes_logistic_distributed, \"基于平均值与标准差Logistic分布生成\"),\n",
    "    8: (generate_nodes_rayleigh_distributed, \"基于平均值与标准差Rayleigh分布生成\"),\n",
    "    9: (generate_nodes_pareto_distributed, \"基于平均值与标准差Pareto分布生成\"),\n",
    "    10: (generate_nodes_cauchy_distributed, \"基于平均值与标准差Cauchy分布生成\"),\n",
    "    11: (generate_nodes_neg_binom_distributed, \"基于平均值与标准差负二项分布生成\"),\n",
    "    12: (generate_nodes_gumbel_distributed, \"基于平均值与标准差Gumbel分布生成\"),\n",
    "    13: (generate_nodes_gompertz_distributed, \"基于平均值与标准差Gompertz分布生成\"),\n",
    "    14: (generate_nodes_gamma_distributed, \"基于平均值与标准差Gamma分布生成\"),\n",
    "    15: (generate_nodes_beta_distributed, \"基于平均值与标准差Beta分布生成\"),\n",
    "    16: (generate_nodes_arithmetic, \"基于等差数列生成\"),\n",
    "    17: (generate_nodes_geometric, \"基于等比数列生成\"),\n",
    "    18: (generate_nodes_fibonacci, \"基于斐波那契数列生成\"),\n",
    "    19: (generate_nodes_square, \"基于平方数列生成\"),\n",
    "    20: (generate_nodes_cube, \"基于立方数列生成\"),\n",
    "    21: (generate_nodes_prime, \"基于质数数列生成\"),\n",
    "    22: (generate_nodes_triangular, \"基于三角数列生成\"),\n",
    "    23: (generate_nodes_rectangular, \"基于矩形数列生成\"),\n",
    "    24: (generate_nodes_binomial, \"基于二项式系数数列生成\"),\n",
    "    25: (generate_nodes_hamilton, \"基于哈密顿数列生成\")\n",
    "}\n",
    "# Build a single graph from one node generator.\n",
    "def generate_graph(type,num_nodes, mean, std,sim_threshold):\n",
    "    \"\"\"Generate nodes with generator `type` and link pairs whose features are cosine-similar.\n",
    "\n",
    "    Args:\n",
    "        type: key into `node_generators`; the generator is called as (num_nodes, mean, std).\n",
    "        num_nodes: number of nodes to generate.\n",
    "        mean, std: per-feature parameters forwarded to the generator.\n",
    "        sim_threshold: add edge (i, j) when cos(x_i, x_j) > sim_threshold; None adds no edges.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): a networkx graph on nodes 0..num_nodes-1, and random role ids in\n",
    "        {0, 1, 2} (the generator's own ids are discarded).\n",
    "    \"\"\"\n",
    "    nodes, _ = node_generators[type][0](num_nodes, mean, std)\n",
    "    edges = []\n",
    "    if sim_threshold is not None:\n",
    "        # build_sim_edges returns a set; sort it so edges go in (i, j) row-major order.\n",
    "        edges = sorted(build_sim_edges(nodes, sim_threshold))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    role_id = [random.randint(0, 2) for _ in range(num_nodes)]\n",
    "    return G, role_id\n",
    "# Cosine similarity on full feature vectors (similar edges).\n",
    "def generate_graph1(type1,type2,num_nodes, mean, std,sim_threshold):\n",
    "    \"\"\"Build a graph from two node groups, linking pairs with cosine similarity > sim_threshold.\n",
    "\n",
    "    Rows 0..num_nodes-1 come from generator `type1` and rows num_nodes..2*num_nodes-1 from\n",
    "    `type2` (both called with (num_nodes, mean, std)). Edge (i, j) is added when the cosine\n",
    "    similarity of the full feature vectors exceeds `sim_threshold`; None adds no edges.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the graph on 2*num_nodes nodes and random role ids in {0, 1, 2}.\n",
    "    \"\"\"\n",
    "    nodes1, _ = node_generators[type1][0](num_nodes, mean, std)\n",
    "    nodes2, _ = node_generators[type2][0](num_nodes, mean, std)\n",
    "    nodes = np.concatenate((nodes1, nodes2), axis=0)\n",
    "    total = len(nodes)\n",
    "    edges = []\n",
    "    if sim_threshold is not None:\n",
    "        # Sorted so edges are inserted in (i, j) row-major order.\n",
    "        edges = sorted(build_sim_edges(nodes, sim_threshold))\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(total))\n",
    "    G.add_edges_from(edges)\n",
    "    role_id = [random.randint(0, 2) for _ in range(total)]\n",
    "    return G, role_id\n",
    "# Cosine similarity on a subset of features (partially similar edges).\n",
    "def generate_graph2(type1,type2,num_nodes, mean, std,partial_sim_threshold, dims):\n",
    "    \"\"\"Build a graph from two node groups, linking pairs similar on the feature subset `dims`.\n",
    "\n",
    "    Rows 0..num_nodes-1 come from generator `type1` and rows num_nodes..2*num_nodes-1 from\n",
    "    `type2`. Edge (i, j) is added when the cosine similarity of x_i[dims] and x_j[dims]\n",
    "    exceeds `partial_sim_threshold`. If the threshold or `dims` is None, no edges are added.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the graph on 2*num_nodes nodes and random role ids in {0, 1, 2}.\n",
    "    \"\"\"\n",
    "    nodes1, _ = node_generators[type1][0](num_nodes, mean, std)\n",
    "    nodes2, _ = node_generators[type2][0](num_nodes, mean, std)\n",
    "    nodes = np.concatenate((nodes1, nodes2), axis=0)\n",
    "    total = len(nodes)\n",
    "    edges = []\n",
    "    if partial_sim_threshold is not None and dims is not None:\n",
    "        # Reuse the shared helper; its pairs already come in row-major (i, j) order.\n",
    "        edges = create_partial_sim_edges(nodes, partial_sim_threshold, dims)\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(total))\n",
    "    G.add_edges_from(edges)\n",
    "    role_id = [random.randint(0, 2) for _ in range(total)]\n",
    "    return G, role_id\n",
    "# Registry of edge-construction strategies for the synthetic node-classification data:\n",
    "# id -> (graph builder, short description of how nodes are connected).\n",
    "node_connectors = dict([\n",
    "    (1, (generate_graph1, \"相似边，自主判断节点相似度\")),\n",
    "    (2, (generate_graph2, \"部分相似边，自主判断节点某几个维度相似度\")),\n",
    "])\n",
    "def generate_Y0():\n",
    "    \"\"\"Class-0 graph: node types 1 and 2 (5 nodes each), full-vector cosine edges above 0.3.\n",
    "\n",
    "    Returns (graph, role_id, label) with label 0.\n",
    "    \"\"\"\n",
    "    params = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "    graph, role_id = generate_graph1(1, 2, 5, params, list(params), 0.3)\n",
    "    return graph, role_id, 0\n",
    "def generate_Y1():\n",
    "    \"\"\"Class-1 graph: node types 1 and 3 (5 nodes each), edges by similarity of feature 3 above 0.5.\n",
    "\n",
    "    NOTE(review): with a single feature index the cosine similarity is +/-1 (or 0), so this\n",
    "    effectively links nodes whose feature 3 has the same sign. Returns (graph, role_id, 1).\n",
    "    \"\"\"\n",
    "    params = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "    graph, role_id = generate_graph2(1, 3, 5, params, list(params), 0.5, 3)\n",
    "    return graph, role_id, 1\n",
    "def generate_Y2():\n",
    "    \"\"\"Class-2 graph: node types 2 and 5 (5 nodes each), edges by similarity of feature 2 above 0.5.\n",
    "\n",
    "    Returns (graph, role_id, label) with label 2.\n",
    "    \"\"\"\n",
    "    params = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "    graph, role_id = generate_graph2(2, 5, 5, params, list(params), 0.5, 2)\n",
    "    return graph, role_id, 2\n",
    "def generate_Y3():\n",
    "    \"\"\"Class-3 graph: node types 4 and 5 (5 nodes each), full-vector cosine edges above 0.5.\n",
    "\n",
    "    Returns (graph, role_id, label) with label 3.\n",
    "    \"\"\"\n",
    "    params = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "    graph, role_id = generate_graph1(4, 5, 5, params, list(params), 0.5)\n",
    "    return graph, role_id, 3\n",
    "def generate_Y4():\n",
    "    \"\"\"Class-4 graph: node types 3 and 4 (5 nodes each), full-vector cosine edges above 0.3.\n",
    "\n",
    "    Returns (graph, role_id, label) with label 4.\n",
    "    \"\"\"\n",
    "    params = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "    graph, role_id = generate_graph1(3, 4, 5, params, list(params), 0.3)\n",
    "    return graph, role_id, 4\n",
    "def generate_real_dataset():\n",
    "    \"\"\"Pick a label y uniformly from {0, ..., 4} and build the matching class graph.\n",
    "\n",
    "    Each class has a fixed pattern of five motif flags. Callers use these flags to choose\n",
    "    which spurious motif to attach.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id, label, motif1_present, motif2_present, motif3_present,\n",
    "         motif4_present, motif5_present)\n",
    "    \"\"\"\n",
    "    # label -> (class builder, (motif1, motif2, motif3, motif4, motif5))\n",
    "    class_table = {\n",
    "        0: (generate_Y0, (True, True, False, False, False)),\n",
    "        1: (generate_Y1, (True, False, True, False, False)),\n",
    "        2: (generate_Y2, (False, True, False, False, True)),\n",
    "        3: (generate_Y3, (False, False, False, True, True)),\n",
    "        4: (generate_Y4, (False, False, True, True, False)),\n",
    "    }\n",
    "    y = random.choice([0, 1, 2 ,3 ,4])\n",
    "    builder, motif_flags = class_table[y]\n",
    "    G, role_id, label = builder()\n",
    "    return (G, role_id, label) + motif_flags\n",
    "# Adjacent splice: the two motifs are joined through a single bridging edge.\n",
    "def adjacent_connection(G1,G2):\n",
    "    \"\"\"Splice G2 onto G1 through one bridging edge between two fresh nodes.\n",
    "\n",
    "    One random edge of each graph is subdivided by a new node (new_node1 in G1, new_node2\n",
    "    in G2). The graphs are composed and the edge new_node1-new_node2 is added. If the\n",
    "    result is disconnected, every node outside the largest component is linked to a random\n",
    "    node inside it. The input graphs are copied, not modified.\n",
    "\n",
    "    NOTE(review): nx.compose merges nodes that share a label. The graph builders in this\n",
    "    notebook label nodes 0..n-1, so overlapping labels (`common_nodes`) are fused rather\n",
    "    than kept side by side. Relabel G2 first if the motifs are meant to be disjoint.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the spliced graph and one role id per node. All ids are 0 except the\n",
    "        last (up to) six, which are random integers in {1, 2, 3}.\n",
    "        If either input has no nodes or no edges, returns (nx.Graph(), []).\n",
    "    \"\"\"\n",
    "    nodes1 = set(G1.nodes())\n",
    "    nodes2 = set(G2.nodes())\n",
    "    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():\n",
    "        # Two values, like the normal path: every caller unpacks `graph, role_id = ...`.\n",
    "        return nx.Graph(), []\n",
    "    common_nodes = nodes1 & nodes2\n",
    "    # Choose the edges from the inputs, then copy so the caller's graphs are not mutated.\n",
    "    edge1 = random.choice(list(G1.edges()))\n",
    "    edge2 = random.choice(list(G2.edges()))\n",
    "    G1 = G1.copy()\n",
    "    G2 = G2.copy()\n",
    "    top = max(nodes1 | nodes2)\n",
    "    new_node1, new_node2 = top + 1, top + 2\n",
    "    # Subdivide edge1 with new_node1 and edge2 with new_node2.\n",
    "    G1.remove_edge(*edge1)\n",
    "    G1.add_edge(edge1[0], new_node1)\n",
    "    G1.add_edge(new_node1, edge1[1])\n",
    "    G2.remove_edge(*edge2)\n",
    "    G2.add_edge(edge2[0], new_node2)\n",
    "    G2.add_edge(new_node2, edge2[1])\n",
    "    G1.add_node(new_node2)\n",
    "    G2.add_node(new_node1)\n",
    "    G = nx.compose(G1, G2)\n",
    "    # The bridging edge between the two motifs.\n",
    "    G.add_edge(new_node1, new_node2)\n",
    "    # Shared labels get a random `role_id` node attribute (separate from the returned list).\n",
    "    for node in common_nodes:\n",
    "        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))\n",
    "    # Make the graph connected by attaching stragglers to the largest component.\n",
    "    if not nx.is_connected(G):\n",
    "        largest_component = max(nx.connected_components(G), key=len)\n",
    "        isolated_nodes = [n for n in G.nodes() if n not in largest_component]\n",
    "        for u in isolated_nodes:\n",
    "            v = random.choice(list(largest_component))\n",
    "            G.add_edge(u, v)\n",
    "    role_id = [0] * G.number_of_nodes()\n",
    "    # Draw six values as before, but assign at most len(role_id) of them so the list keeps\n",
    "    # one entry per node.\n",
    "    tail = torch.randint(1, 4, (6,)).tolist()\n",
    "    n_tail = min(6, len(role_id))\n",
    "    role_id[len(role_id) - n_tail:] = tail[:n_tail]\n",
    "    return G, role_id\n",
    "# Spurious motif with high saliency and the smallest std.\n",
    "def generate_false_cause_dataset1():\n",
    "    \"\"\"Build a class graph and splice on one spurious motif chosen from its motif flags.\n",
    "\n",
    "    The first present motif k (0-based, 0..4) selects node generator type 6 + k. That motif\n",
    "    is built with random.randint(5, 10) nodes, mean = std = [1.0, 1.2, 1.0, 1.5, 1.0], and\n",
    "    similarity threshold 0.3 (motifs 1-3) or 0.5 (motifs 4-5). It is then attached with\n",
    "    adjacent_connection.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    G, role_id, label, *motif_flags = generate_real_dataset()\n",
    "    spurious_params = [1.0, 1.2, 1.0, 1.5, 1.0]\n",
    "    thresholds = (0.3, 0.3, 0.3, 0.5, 0.5)\n",
    "    for k, present in enumerate(motif_flags):\n",
    "        if present:\n",
    "            spurious, role_id = generate_graph(6 + k, random.randint(5,10), spurious_params, list(spurious_params), thresholds[k])\n",
    "            graph, role_id = adjacent_connection(G, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "# Spurious motif with high saliency and a small std.\n",
    "def generate_false_cause_dataset2():\n",
    "    \"\"\"Build a class graph and splice on one spurious motif chosen from its motif flags.\n",
    "\n",
    "    The first present motif k (0-based, 0..4) selects node generator type 6 + k. That motif\n",
    "    is built with random.randint(5, 10) nodes, mean = std = [1.0, 2.0, 1.0, 1.5, 3.0], and\n",
    "    similarity threshold 0.3 (motifs 1-3) or 0.5 (motifs 4-5). It is then attached with\n",
    "    adjacent_connection.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    G, role_id, label, *motif_flags = generate_real_dataset()\n",
    "    spurious_params = [1.0, 2.0, 1.0, 1.5, 3.0]\n",
    "    thresholds = (0.3, 0.3, 0.3, 0.5, 0.5)\n",
    "    for k, present in enumerate(motif_flags):\n",
    "        if present:\n",
    "            spurious, role_id = generate_graph(6 + k, random.randint(5,10), spurious_params, list(spurious_params), thresholds[k])\n",
    "            graph, role_id = adjacent_connection(G, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "# Medium salience / medium standard deviation.\n",
    "def generate_false_cause_dataset3():\n",
    "    \"\"\"Real graph plus one spurious subgraph chosen by its motif flags.\n",
    "\n",
    "    generate_real_dataset() yields (graph, role_id, label, motif1..motif5\n",
    "    flags). For the first flag equal to True, generate_graph builds a subgraph\n",
    "    of type 6..10 (second argument random.randint(5, 10)) with feature\n",
    "    parameters [5.0, 7.0, 4.0, 6.0, 10.0], and adjacent_connection joins it to\n",
    "    the real graph; the returned role_id is the one from adjacent_connection.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label). If no flag is True, the real graph and its\n",
    "        role ids are returned unchanged.\n",
    "    \"\"\"\n",
    "    G, role_id, label, motif1_present, motif2_present, motif3_present, motif4_present, motif5_present = generate_real_dataset()\n",
    "    # (flag, spurious subgraph type, last generate_graph argument), in priority order.\n",
    "    spurious_rules = (\n",
    "        (motif1_present, 6, 0.3),\n",
    "        (motif2_present, 7, 0.3),\n",
    "        (motif3_present, 8, 0.3),\n",
    "        (motif4_present, 9, 0.5),\n",
    "        (motif5_present, 10, 0.5),\n",
    "    )\n",
    "    graph = G\n",
    "    for present, graph_type, graph_param in spurious_rules:\n",
    "        if present == True:  # noqa: E712\n",
    "            spurious_graph, role_id = generate_graph(graph_type, random.randint(5, 10), [5.0, 7.0, 4.0, 6.0, 10.0], [5.0, 7.0, 4.0, 6.0, 10.0], graph_param)\n",
    "            graph, role_id = adjacent_connection(G, spurious_graph)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "# Low salience / large standard deviation.\n",
    "def generate_false_cause_dataset4():\n",
    "    \"\"\"Real graph plus one spurious subgraph chosen by its motif flags.\n",
    "\n",
    "    generate_real_dataset() yields (graph, role_id, label, motif1..motif5\n",
    "    flags). For the first flag equal to True, generate_graph builds a subgraph\n",
    "    of type 6..10 (second argument random.randint(5, 10)) with feature\n",
    "    parameters [10.0, 15.0, 8.0, 12.0, 20.0], and adjacent_connection joins it\n",
    "    to the real graph; the returned role_id is the one from adjacent_connection.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label). If no flag is True, the real graph and its\n",
    "        role ids are returned unchanged (this used to raise UnboundLocalError).\n",
    "    \"\"\"\n",
    "    G, role_id, label, motif1_present, motif2_present, motif3_present, motif4_present, motif5_present = generate_real_dataset()\n",
    "    # (flag, spurious subgraph type, last generate_graph argument), in priority order.\n",
    "    spurious_rules = (\n",
    "        (motif1_present, 6, 0.3),\n",
    "        (motif2_present, 7, 0.3),\n",
    "        (motif3_present, 8, 0.3),\n",
    "        (motif4_present, 9, 0.5),\n",
    "        (motif5_present, 10, 0.5),\n",
    "    )\n",
    "    graph = G  # fallback when no flag is set\n",
    "    for present, graph_type, graph_param in spurious_rules:\n",
    "        if present == True:  # noqa: E712\n",
    "            spurious_graph, role_id = generate_graph(graph_type, random.randint(5, 10), [10.0, 15.0, 8.0, 12.0, 20.0], [10.0, 15.0, 8.0, 12.0, 20.0], graph_param)\n",
    "            graph, role_id = adjacent_connection(G, spurious_graph)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "\n",
    "# Medium salience / medium-high standard deviation.\n",
    "def generate_false_cause_dataset5():\n",
    "    \"\"\"Real graph plus one spurious subgraph chosen by its motif flags.\n",
    "\n",
    "    generate_real_dataset() yields (graph, role_id, label, motif1..motif5\n",
    "    flags). For the first flag equal to True, generate_graph builds a subgraph\n",
    "    of type 6..10 (second argument random.randint(5, 10)) with feature\n",
    "    parameters [8.0, 10.0, 7.0, 9.0, 10.0], and adjacent_connection joins it to\n",
    "    the real graph; the returned role_id is the one from adjacent_connection.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label). If no flag is True, the real graph and its\n",
    "        role ids are returned unchanged.\n",
    "    \"\"\"\n",
    "    G, role_id, label, motif1_present, motif2_present, motif3_present, motif4_present, motif5_present = generate_real_dataset()\n",
    "    # (flag, spurious subgraph type, last generate_graph argument), in priority order.\n",
    "    spurious_rules = (\n",
    "        (motif1_present, 6, 0.3),\n",
    "        (motif2_present, 7, 0.3),\n",
    "        (motif3_present, 8, 0.3),\n",
    "        (motif4_present, 9, 0.5),\n",
    "        (motif5_present, 10, 0.5),\n",
    "    )\n",
    "    graph = G\n",
    "    for present, graph_type, graph_param in spurious_rules:\n",
    "        if present == True:  # noqa: E712\n",
    "            spurious_graph, role_id = generate_graph(graph_type, random.randint(5, 10), [8.0, 10.0, 7.0, 9.0, 10.0], [8.0, 10.0, 7.0, 9.0, 10.0], graph_param)\n",
    "            graph, role_id = adjacent_connection(G, spurious_graph)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "# Low salience / largest standard deviation.\n",
    "def generate_false_cause_dataset6():\n",
    "    \"\"\"Real graph plus one spurious subgraph chosen by its motif flags.\n",
    "\n",
    "    generate_real_dataset() yields (graph, role_id, label, motif1..motif5\n",
    "    flags). For the first flag equal to True, generate_graph builds a subgraph\n",
    "    of type 6..10 (second argument random.randint(5, 10)) with feature\n",
    "    parameters [20.0, 25.0, 18.0, 22.0, 20.0], and adjacent_connection joins it\n",
    "    to the real graph; the returned role_id is the one from adjacent_connection.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label). If no flag is True, the real graph and its\n",
    "        role ids are returned unchanged (this used to raise UnboundLocalError).\n",
    "    \"\"\"\n",
    "    G, role_id, label, motif1_present, motif2_present, motif3_present, motif4_present, motif5_present = generate_real_dataset()\n",
    "    # (flag, spurious subgraph type, last generate_graph argument), in priority order.\n",
    "    spurious_rules = (\n",
    "        (motif1_present, 6, 0.3),\n",
    "        (motif2_present, 7, 0.3),\n",
    "        (motif3_present, 8, 0.3),\n",
    "        (motif4_present, 9, 0.5),\n",
    "        (motif5_present, 10, 0.5),\n",
    "    )\n",
    "    graph = G  # fallback when no flag is set\n",
    "    for present, graph_type, graph_param in spurious_rules:\n",
    "        if present == True:  # noqa: E712\n",
    "            spurious_graph, role_id = generate_graph(graph_type, random.randint(5, 10), [20.0, 25.0, 18.0, 22.0, 20.0], [20.0, 25.0, 18.0, 22.0, 20.0], graph_param)\n",
    "            graph, role_id = adjacent_connection(G, spurious_graph)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "def generate_false_dataset():\n",
    "    \"\"\"Build one noisy test graph.\n",
    "\n",
    "    The graph from generate_real_dataset() is joined (adjacent_connection) with\n",
    "    a label-independent distractor graph from generate_graph1, and the result is\n",
    "    perturbed by add_noise with add_edge_prob=0.1 and add_node_prob=0.1.\n",
    "\n",
    "    Returns:\n",
    "        (noisy_graph, noisy_role_id, label) as returned by add_noise.\n",
    "    \"\"\"\n",
    "    real_graph, _real_roles, label, _m1, _m2, _m3, _m4, _m5 = generate_real_dataset()\n",
    "    # Distractor graph unrelated to the label; both size arguments lie in [6, 15].\n",
    "    size_a, size_b = random.randint(6, 15), random.randint(6, 15)\n",
    "    distractor, _distractor_roles = generate_graph1(size_a, size_b, 10, [1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8], 0.5)\n",
    "    combined, _combined_roles = adjacent_connection(real_graph, distractor)\n",
    "    # Positional args: delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob, label.\n",
    "    noisy_graph, noisy_role_id, noisy_label = add_noise(combined, 0, 0.1, 0, 0.1, label)\n",
    "    return noisy_graph, noisy_role_id, noisy_label\n",
    "'''\n",
    "数据集中还要额外包含添加随机噪声数据的功能：\n",
    "1.随机删除或者创建固定数量边(百分比)\n",
    "2.随机删除一定数量的节点（百分比）\n",
    "3.随机创建一定数量的节点，并且将这些节点与已经存在的图随机相连（百分比）\n",
    "'''\n",
    "def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob, label=None):\n",
    "    \"\"\"Return a randomly perturbed deep copy of G plus random role ids.\n",
    "\n",
    "    Perturbations are applied in this order (G itself is not modified):\n",
    "      1. delete int(delete_edge_prob * |E|) randomly chosen edges;\n",
    "      2. draw int(add_edge_prob * |V| * (|V| - 1) / 2) random node pairs and\n",
    "         add an edge for each pair that is not already connected;\n",
    "      3. delete int(delete_node_prob * |V|) randomly chosen nodes;\n",
    "      4. add int(add_node_prob * |V|) new nodes, each wired to a random subset\n",
    "         of the existing nodes, re-drawing the subset until the whole graph\n",
    "         is connected.\n",
    "    |V| and |E| are measured at the start of each step.\n",
    "\n",
    "    Args:\n",
    "        G: networkx graph with integer node ids (new ids are max id + 1).\n",
    "        delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob:\n",
    "            fractions controlling steps 1-4.\n",
    "        label: returned unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (G_noisy, role_id_noisy, label) where role_id_noisy holds one random\n",
    "        int in [0, 2] per node-id slot (max node id + 1 entries), so it can be\n",
    "        indexed by the node ids that appear in G_noisy.edges.\n",
    "    \"\"\"\n",
    "    G_noisy = copy.deepcopy(G)\n",
    "\n",
    "    # 1. Randomly delete edges. random.sample needs a sequence (Python 3.11+\n",
    "    #    rejects networkx set-like views), hence list().\n",
    "    num_edges_to_delete = int(delete_edge_prob * G_noisy.number_of_edges())\n",
    "    G_noisy.remove_edges_from(random.sample(list(G_noisy.edges()), num_edges_to_delete))\n",
    "\n",
    "    # 2. Randomly add edges; sample() returns two distinct nodes, so no self-loops.\n",
    "    nodes = list(G_noisy.nodes())\n",
    "    num_edges_to_add = int(add_edge_prob * len(nodes) * (len(nodes) - 1) / 2)\n",
    "    for _ in range(num_edges_to_add):\n",
    "        node1, node2 = random.sample(nodes, 2)\n",
    "        if not G_noisy.has_edge(node1, node2):\n",
    "            G_noisy.add_edge(node1, node2)\n",
    "\n",
    "    # 3. Randomly delete nodes (and their incident edges).\n",
    "    num_nodes_to_delete = int(delete_node_prob * G_noisy.number_of_nodes())\n",
    "    G_noisy.remove_nodes_from(random.sample(list(G_noisy.nodes()), num_nodes_to_delete))\n",
    "\n",
    "    # 4. Randomly add nodes, each connected so that the graph ends up connected.\n",
    "    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())\n",
    "    for _ in range(num_nodes_to_add):\n",
    "        # Next unused id. The old `number_of_nodes() + 1` skipped an id (leaving\n",
    "        # a hole in edge_index) and could collide with an existing node after\n",
    "        # node deletion.\n",
    "        node_id = max(G_noisy.nodes(), default=-1) + 1\n",
    "        # Existing nodes only, so the new node is never linked to itself.\n",
    "        candidates = list(G_noisy.nodes())\n",
    "        G_noisy.add_node(node_id)\n",
    "        if not candidates:\n",
    "            continue  # empty graph: nothing to connect to\n",
    "        while True:\n",
    "            nodes_to_connect = random.sample(candidates, random.randint(1, len(candidates)))\n",
    "            G_noisy.add_edges_from((node_id, n) for n in nodes_to_connect)\n",
    "            if nx.is_connected(G_noisy):\n",
    "                break\n",
    "            G_noisy.remove_edges_from((node_id, n) for n in nodes_to_connect)\n",
    "\n",
    "    # Random role ids, one per node-id slot (previously sized by the edge count).\n",
    "    num_role_slots = max(G_noisy.nodes(), default=-1) + 1\n",
    "    role_id_noisy = [random.randint(0, 2) for _ in range(num_role_slots)]\n",
    "    return G_noisy, role_id_noisy, label"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fde7c855",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5f5dc82f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:35<00:00, 27.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 12.00    #Edges: 45.12 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Training split: 1000 graphs from generate_false_cause_dataset1().\n",
    "# (Swap in generate_false_dataset() here to train on the noisy generator.)\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    G, role_id, label = generate_false_cause_dataset1()\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=int).T  # (2, num_edges) when G has edges\n",
    "    # Size statistics for the summary line below.\n",
    "    n_mean.append(len(G.nodes))\n",
    "    e_mean.append(len(G.edges))\n",
    "    label_list.append(label)\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Layout coordinates; spring_layout is randomized because no seed is passed.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(f'#Graphs: {len(ground_truth_list):d}    #Nodes: {np.mean(n_mean):.2f}    #Edges: {np.mean(e_mean):.2f} ')\n",
    "train_split = {\n",
    "    'edge_index': edge_index_list,\n",
    "    'label': label_list,\n",
    "    'ground_truth': ground_truth_list,\n",
    "    'role_id': role_id_list,\n",
    "    'pos': pos_list,\n",
    "}\n",
    "np.save(osp.join(data_dir, 'train.npy'), train_split)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c3587610",
   "metadata": {},
   "source": [
    "## Validation Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ebe75786",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:35<00:00, 27.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 1000    #Nodes: 12.00    #Edges: 44.80 \n"
     ]
    }
   ],
   "source": [
    "# Validation split: 1000 graphs from the same generator as the training split.\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(1000)):\n",
    "    G, role_id, label = generate_false_cause_dataset1()\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=int).T  # (2, num_edges) when G has edges\n",
    "    # Size statistics for the summary line below.\n",
    "    n_mean.append(len(G.nodes))\n",
    "    e_mean.append(len(G.edges))\n",
    "    label_list.append(label)\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Layout coordinates; spring_layout is randomized because no seed is passed.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(f'#Graphs: {len(ground_truth_list):d}    #Nodes: {np.mean(n_mean):.2f}    #Edges: {np.mean(e_mean):.2f} ')\n",
    "val_split = {\n",
    "    'edge_index': edge_index_list,\n",
    "    'label': label_list,\n",
    "    'ground_truth': ground_truth_list,\n",
    "    'role_id': role_id_list,\n",
    "    'pos': pos_list,\n",
    "}\n",
    "np.save(osp.join(data_dir, 'val.npy'), val_split)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9d858281",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1aff6470",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [03:27<00:00,  9.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 24.00    #Edges: 164.70 \n"
     ]
    }
   ],
   "source": [
    "# Test split: 2000 noisy graphs from generate_false_dataset() (real graph +\n",
    "# label-independent distractor + add_noise), unlike the train/val generator.\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list = [], [], []\n",
    "e_mean, n_mean = [], []\n",
    "for _ in tqdm(range(2000)):\n",
    "    G, role_id, label = generate_false_dataset()\n",
    "    role_id = np.array(role_id)\n",
    "    edge_index = np.array(G.edges, dtype=int).T  # (2, num_edges) when G has edges\n",
    "    # Size statistics for the summary line below.\n",
    "    n_mean.append(len(G.nodes))\n",
    "    e_mean.append(len(G.edges))\n",
    "    label_list.append(label)\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Layout coordinates; spring_layout is randomized because no seed is passed.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "print(f'#Graphs: {len(ground_truth_list):d}    #Nodes: {np.mean(n_mean):.2f}    #Edges: {np.mean(e_mean):.2f} ')\n",
    "test_split = {\n",
    "    'edge_index': edge_index_list,\n",
    "    'label': label_list,\n",
    "    'ground_truth': ground_truth_list,\n",
    "    'role_id': role_id_list,\n",
    "    'pos': pos_list,\n",
    "}\n",
    "np.save(osp.join(data_dir, 'test.npy'), test_split)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "c71b0b87ea436ae79e2503ec051639fc2420e91bd742cb356b7debceb9d5ed19"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
