{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a32dd142",
   "metadata": {},
   "source": [
    "### Generate CRCG-NODE Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "142774b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from BA3_loc import *\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import random\n",
    "import math\n",
    "import torch\n",
    "import copy\n",
    "\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import gompertz\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from scipy.stats import weibull_min\n",
    "from scipy.special import gamma, gammaincinv\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from collections import deque\n",
    "# Directory that receives the generated dataset files; created if it does not exist yet.\n",
    "# NOTE(review): only `os.path` is imported explicitly (as `osp`); the name `os` is presumably\n",
    "# provided by the `from BA3_loc import *` star import -- confirm.\n",
    "data_dir = './paper/'\n",
    "os.makedirs(data_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "397bda31",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def generate_gamma(mu, sigma, size):\n",
    "\n",
    "    var = np.power(sigma, 2)\n",
    "    theta = np.divide(var, mu)\n",
    "    k = np.divide(mu, theta)\n",
    "\n",
    "    return gamma.rvs(a=k, scale=theta, size=size)\n",
    "def generate_nodes_normal_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "          nodes[:, i] = np.random.normal(loc=mean[i], scale=std[i], size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_uniform_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "          nodes[:, i] = np.random.uniform(low=mean[i]- std[i], high=mean[i] + std[i], size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id   \n",
    "def generate_nodes_exponential_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "          nodes[:, i] = np.random.exponential(scale=std[i]/mean[i], size=num_nodes) * mean[i]\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_lognormal_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "            mu = np.log(mean[i]**2/np.sqrt(std[i]**2+mean[i]**2))\n",
    "            sigma = np.sqrt(np.log(std[i]**2/mean[i]**2 + 1))\n",
    "            nodes[:, i] = np.random.lognormal(mu, sigma, num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_gamma_distributed(num_nodes,mean, std):\n",
    "    \"\"\"Draw node features column by column via ``generate_gamma``.\n",
    "\n",
    "    Column ``c`` is ``generate_gamma(mean[c], std[c], num_nodes)``.\n",
    "\n",
    "    Returns:\n",
    "        (features, role_id): a (num_nodes, len(mean)) array and the list\n",
    "        [0, 1, ..., num_nodes - 1].\n",
    "    \"\"\"\n",
    "    n_features = len(mean)\n",
    "    features = np.zeros((num_nodes, n_features))\n",
    "    for col in range(n_features):\n",
    "        features[:, col] = generate_gamma(mean[col], std[col], num_nodes)\n",
    "    return features, list(range(num_nodes))\n",
    "def generate_nodes_beta_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        a = ((1 - np.array(mean)) / np.array(std) ** 2 - 1 / np.array(mean)) * np.array(mean) ** 2\n",
    "        b = a * (1 / np.array(mean) - 1)\n",
    "        nodes[:,i] = np.random.beta(a=a[i], b=b[i], size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_weibull_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "\n",
    "      k = np.random.uniform(low=0.5, high=2.0, size=5)\n",
    "      lam = np.random.uniform(low=0.5, high=2.0, size=5)\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        nodes[:, i] = weibull_min.rvs(k[i], scale=lam[i], size=num_nodes)\n",
    "        nodes[:, i] = (nodes[:, i] - nodes[:, i].mean()) / nodes[:, i].std()\n",
    "        nodes[:, i] = nodes[:, i] * std[i] + mean[i]\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_laplace_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        mu = mean[i]\n",
    "        b = std[i] / np.sqrt(2)\n",
    "        nodes[:, i] = np.random.laplace(mu, b, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_logistic_distributed(num_nodes,mean, std):    \n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        mu = mean[i]\n",
    "        s = std[i] / np.sqrt(3)\n",
    "        nodes[:, i] = np.random.logistic(mu, s, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_rayleigh_distributed(num_nodes,mean, std):\n",
    "      dim = len(mean)    \n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, dim))\n",
    "      for i in range(dim):\n",
    "        sigma = std[i] / np.sqrt(2*np.pi)\n",
    "        nodes[:, i] = np.random.rayleigh(sigma, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_pareto_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        alpha = mean[i] / std[i]\n",
    "        nodes[:, i] = np.random.pareto(alpha, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_cauchy_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        x0 = mean[i]\n",
    "        gamma = std[i] * np.sqrt(np.pi / 2)\n",
    "        nodes[:, i] = np.random.standard_cauchy(size=num_nodes) * gamma + x0\n",
    "      role_id = list(range(num_nodes))\n",
    "      return nodes,role_id\n",
    "def generate_nodes_neg_binom_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "        p = mean[i] / (std[i]**2 + mean[i])\n",
    "        r = mean[i]**2 / (std[i]**2 - mean[i])\n",
    "        nodes[:, i]=np.random.negative_binomial(r, 1-p, size=num_nodes)\n",
    "      role_id = list(range(num_nodes)) \n",
    "      return nodes,role_id \n",
    "def generate_nodes_gumbel_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "         loc = mean[i] - 0.57721 * std[i]\n",
    "         scale = np.sqrt(6) * std[i] / np.pi\n",
    "         nodes[:, i]=np.random.gumbel(loc=loc, scale=scale, size=num_nodes)\n",
    "      role_id = list(range(num_nodes)) \n",
    "      return nodes,role_id   \n",
    "def generate_nodes_gompertz_distributed(num_nodes,mean, std):\n",
    "      role_id = []\n",
    "      nodes = np.zeros((num_nodes, len(mean)))\n",
    "      for i in range(len(mean)):\n",
    "         loc = mean[i] - np.log(np.log(2)/2)*std[i]\n",
    "         scale = np.exp(std[i]/np.log(2))\n",
    "         nodes[:, i]=gompertz.rvs(c=loc, scale=scale, size=num_nodes)\n",
    "      role_id = list(range(num_nodes))               \n",
    "      return nodes,role_id\n",
    "def rectangle_sequence(n):\n",
    "    seq = [0]\n",
    "    for i in range(1, n+1):\n",
    "        seq.append(seq[-1] + i*2)\n",
    "    return seq[1:]\n",
    "def binomial_coefficients(n, dim):\n",
    "    \n",
    "    seq = []\n",
    "    for i in range(1, dim+1):\n",
    "        seq.append(math.comb(n, i))\n",
    "    return seq\n",
    "def generate_nodes_arithmetic(num_nodes,dims,step):\n",
    "      role_id = []\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          start = random.uniform(0,10)\n",
    "          node = [start + j * step for j in range(dims)]\n",
    "          nodes.append(node)\n",
    "\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_geometric(num_nodes,dims,step):\n",
    "      role_id = []\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          start = random.uniform(0,10)\n",
    "          node = [start*step**j for j in range(dims)]\n",
    "          nodes.append(node)\n",
    "\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_fibonacci(num_nodes,dims,step):\n",
    "      role_id = []\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "\n",
    "          start1 = round(random.uniform(0, 10), 1)\n",
    "          start2 = round(random.uniform(0, 10), 1)\n",
    "          fib_nums = [start1, start2]\n",
    "          for j in range(dims - 2):\n",
    "              fib_nums.append(fib_nums[-1] + fib_nums[-2])\n",
    "\n",
    "\n",
    "          nodes.append(fib_nums)\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_square(num_nodes,dims,step):\n",
    "        role_id = []\n",
    "        nodes = []\n",
    "        for i in range(num_nodes):\n",
    "          initial_val = random.uniform(0, 10)\n",
    "          node = [initial_val]\n",
    "          for j in range(dims-1):\n",
    "              node.append(initial_val ** 2)\n",
    "              initial_val=initial_val ** 2\n",
    "          nodes.append(node)\n",
    "        role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "        return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_cube(num_nodes,dims,step):\n",
    "        role_id = []\n",
    "        nodes = []\n",
    "        for i in range(num_nodes):\n",
    "          initial_val = random.uniform(0, 10)\n",
    "          node = [initial_val]\n",
    "          for j in range(dims-1):\n",
    "              node.append(initial_val ** 3)\n",
    "              initial_val=initial_val ** 3\n",
    "          nodes.append(node)\n",
    "        role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "        return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_prime(num_nodes,dims,step):\n",
    "       nodes = []\n",
    "       for i in range(num_nodes):\n",
    "\n",
    "          start = random.uniform(0, 10)\n",
    "\n",
    "          primes = []\n",
    "          n = 2\n",
    "          while len(primes) < dims:\n",
    "              if all(n % p != 0 for p in primes):\n",
    "                  primes.append(n)\n",
    "              n += 1\n",
    "          node = np.array(primes) * start\n",
    "          nodes.append(node)\n",
    "       role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "       return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_triangular(num_nodes,dims,step):\n",
    "\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          initial_value = random.uniform(0, 10)\n",
    "          features = []\n",
    "          for j in range(dims):\n",
    "              feature = 0.5 * initial_value * (initial_value + 1)\n",
    "              features.append(feature)\n",
    "              initial_value += 1\n",
    "          nodes.append(features)\n",
    "      role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def generate_nodes_rectangular(num_nodes,dims,step):\n",
    "    \"\"\"Generate nodes as a random offset added to ``rectangle_sequence(dims)``.\n",
    "\n",
    "    Each node draws an offset uniformly from [0, 10] and adds it to every term of the\n",
    "    sequence. ``step`` is not used.\n",
    "\n",
    "    Returns:\n",
    "        (features, role_id): a (num_nodes, dims) array and a list of random labels in {0, 1, 2}.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for _ in range(num_nodes):\n",
    "        offset = random.uniform(0, 10)\n",
    "        rows.append([offset + term for term in rectangle_sequence(dims)])\n",
    "    role_id = [random.randint(0, 2) for _ in range(num_nodes)]\n",
    "    return np.array(rows).reshape((num_nodes, dims)), role_id\n",
    "def generate_nodes_binomial(num_nodes,dims,step):\n",
    "    \"\"\"Generate nodes whose features are binomial coefficients C(n, 1..dims).\n",
    "\n",
    "    Each node draws n as a random integer in [1, 10] and uses\n",
    "    ``binomial_coefficients(n, dims)`` as its feature row. ``step`` is not used.\n",
    "\n",
    "    Returns:\n",
    "        (features, role_id): a (num_nodes, dims) array and a list of random labels in {0, 1, 2}.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for _ in range(num_nodes):\n",
    "        n_trials = random.randint(1, 10)\n",
    "        rows.append(binomial_coefficients(n_trials, dims))\n",
    "    role_id = [random.randint(0, 2) for _ in range(num_nodes)]\n",
    "    return np.array(rows).reshape((num_nodes, dims)), role_id\n",
    "def generate_nodes_hamilton(num_nodes,dims,step):\n",
    "\n",
    "      nodes = []\n",
    "      for i in range(num_nodes):\n",
    "          initial_value = random.randint(1, 10)\n",
    "          hamilton_seq = [2**n - 1 for n in range(initial_value, initial_value+dims)]\n",
    "          nodes.append(hamilton_seq)\n",
    "      print(nodes)\n",
    "      role_id = [random.randint(0,2) for i in range(num_nodes)]\n",
    "      return np.array(nodes).reshape((num_nodes, dims)),role_id\n",
    "def merge_nodes(node_set1, node_set2):\n",
    "\n",
    "    return np.vstack((node_set1, node_set2))\n",
    "def build_sim_edges(nodes,sim_threshold):\n",
    "    \"\"\"Return the set of index pairs (i, j), i < j, whose rows are similar enough.\n",
    "\n",
    "    A pair is included when the cosine similarity of rows i and j of ``nodes`` is strictly\n",
    "    greater than ``sim_threshold``.\n",
    "    \"\"\"\n",
    "    n_rows = nodes.shape[0]\n",
    "    edges = set()\n",
    "    for i in range(n_rows):\n",
    "        row_i = nodes[i:i+1, :]\n",
    "        for j in range(i+1, n_rows):\n",
    "            if cosine_similarity(row_i, nodes[j:j+1, :])[0, 0] > sim_threshold:\n",
    "                edges.add((i, j))\n",
    "    return edges\n",
    "def create_partial_sim_edges(nodes, partial_sim_threshold,dims):\n",
    "    \"\"\"Return index pairs (i, j), i < j, that are similar on a subset of features.\n",
    "\n",
    "    For each pair the sub-vectors ``nodes[i][dims]`` and ``nodes[j][dims]`` are compared by\n",
    "    cosine similarity; the pair is kept when the value is strictly greater than\n",
    "    ``partial_sim_threshold``.\n",
    "    NOTE(review): ``dims`` is used as an index into each row, so it is presumably a list or\n",
    "    array of feature indices -- confirm against the callers.\n",
    "    \"\"\"\n",
    "    n_rows = len(nodes)\n",
    "    edges = []\n",
    "    for i in range(n_rows):\n",
    "        part_i = nodes[i][dims].reshape(1, -1)\n",
    "        for j in range(i+1, n_rows):\n",
    "            part_j = nodes[j][dims].reshape(1, -1)\n",
    "            if cosine_similarity(part_i, part_j)[0][0] > partial_sim_threshold:\n",
    "                edges.append((i, j))\n",
    "    return edges\n",
    "\n",
    "# Registry mapping an integer type id to (generator function, human-readable description).\n",
    "# Ids 1-15 point at distribution-based generators with signature f(num_nodes, mean, std);\n",
    "# ids 16-25 point at sequence-based generators with signature f(num_nodes, dims, step).\n",
    "node_generators = {\n",
    "    1: (generate_nodes_normal_distributed, \"Normal distribution generation based on mean and standard deviation\"),\n",
    "    2: (generate_nodes_uniform_distributed, \"Uniform distribution generation based on mean and standard deviation\"),\n",
    "    3: (generate_nodes_exponential_distributed, \"Exponential distribution generation based on mean and standard deviation\"),\n",
    "    4: (generate_nodes_lognormal_distributed, \"Log-normal distribution generation based on mean and standard deviation\"),\n",
    "    5: (generate_nodes_weibull_distributed, \"Weibull distribution generation based on mean and standard deviation\"),\n",
    "    6: (generate_nodes_laplace_distributed, \"Laplace distribution generation based on mean and standard deviation\"),\n",
    "    7: (generate_nodes_logistic_distributed, \"Logistic distribution generation based on mean and standard deviation\"),\n",
    "    8: (generate_nodes_rayleigh_distributed, \"Rayleigh distribution generation based on mean and standard deviation\"),\n",
    "    9: (generate_nodes_pareto_distributed, \"Pareto distribution generation based on mean and standard deviation\"),\n",
    "    10: (generate_nodes_cauchy_distributed, \"Cauchy distribution generation based on mean and standard deviation\"),\n",
    "    11: (generate_nodes_neg_binom_distributed, \"Negative binomial distribution generation based on mean and standard deviation\"),\n",
    "    12: (generate_nodes_gumbel_distributed, \"Gumbel distribution generation based on mean and standard deviation\"),\n",
    "    13: (generate_nodes_gompertz_distributed, \"Gompertz distribution generation based on mean and standard deviation\"),\n",
    "    # NOTE(review): ids 14 and 15 are described as Gamma and Beta but both point at\n",
    "    # generate_nodes_normal_distributed; generate_nodes_gamma_distributed and\n",
    "    # generate_nodes_beta_distributed are defined in this cell but never registered --\n",
    "    # confirm whether this is intentional before relying on these two types.\n",
    "    14: (generate_nodes_normal_distributed, \"Gamma distribution generation based on mean and standard deviation\"),\n",
    "    15: (generate_nodes_normal_distributed, \"Beta distribution generation based on mean and standard deviation\"),\n",
    "    16: (generate_nodes_arithmetic, \"Arithmetic sequence generation\"),\n",
    "    17: (generate_nodes_geometric, \"Geometric sequence generation\"),\n",
    "    18: (generate_nodes_fibonacci, \"Fibonacci sequence generation\"),\n",
    "    19: (generate_nodes_square, \"Square number sequence generation\"),\n",
    "    20: (generate_nodes_cube, \"Cube number sequence generation\"),\n",
    "    21: (generate_nodes_prime, \"Prime number sequence generation\"),\n",
    "    22: (generate_nodes_triangular, \"Triangular number sequence generation\"),\n",
    "    23: (generate_nodes_rectangular, \"Rectangular number sequence generation\"),\n",
    "    24: (generate_nodes_binomial, \"Binomial coefficient sequence generation\"),\n",
    "    25: (generate_nodes_hamilton, \"Hamiltonian sequence generation\")\n",
    "}\n",
    "\n",
    "def generate_graph(type, num_nodes, mean, std, sim_threshold):\n",
    "    \"\"\"Build an undirected similarity graph over freshly generated node features.\n",
    "\n",
    "    Args:\n",
    "        type: key into ``node_generators`` selecting the feature generator.\n",
    "        num_nodes: number of nodes to generate.\n",
    "        mean, std: forwarded positionally as the generator's second and third arguments.\n",
    "        sim_threshold: nodes i < j are linked when the cosine similarity of their feature\n",
    "            rows is strictly greater than this value; ``None`` produces no edges.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): a networkx Graph on nodes 0..num_nodes-1 and a list of random labels\n",
    "        in {0, 1, 2}. The role ids returned by the generator itself are discarded.\n",
    "    \"\"\"\n",
    "    nodes, _ = node_generators[type][0](num_nodes,mean,std)\n",
    "\n",
    "    edges = []\n",
    "    # The None check does not depend on (i, j), so skip the O(n^2) scan entirely when no\n",
    "    # threshold is given.\n",
    "    if sim_threshold is not None:\n",
    "        for i in range(num_nodes):\n",
    "            for j in range(i+1, num_nodes):\n",
    "                similarity = cosine_similarity(nodes[i:i+1, :], nodes[j:j+1, :])[0, 0]\n",
    "                if similarity > sim_threshold:\n",
    "                    edges.append((i, j))\n",
    "\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "\n",
    "    # No edge-index tensor is built here: the relabelling that used to follow was an identity\n",
    "    # mapping whose result was never returned, and calling size(1) on the 1-D tensor produced\n",
    "    # by an edgeless graph raised IndexError. Edgeless graphs are now returned as they are.\n",
    "    return G,role_id\n",
    "\n",
    "def create_paper_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a random directed citation graph without reciprocal or self citations.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector (``np.random.normal(mean, std)``) and a\n",
    "    random class label in [0, num_classes). Each paper then cites a Poisson-distributed\n",
    "    number of papers chosen among those that do not already cite it.\n",
    "\n",
    "    Args:\n",
    "        num_papers: number of nodes.\n",
    "        avg_citations_per_paper: Poisson rate for the number of outgoing citations.\n",
    "        num_classes: number of distinct labels.\n",
    "        mean, std: parameters of the normal feature distribution.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std,size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        # Candidates are the *other* papers that do not already cite this one; excluding\n",
    "        # paper_id itself prevents self-citations, in line with the other citation generators.\n",
    "        remaining_papers = [\n",
    "            p_id for p_id in range(num_papers)\n",
    "            if p_id != paper_id and not G.has_edge(p_id, paper_id)\n",
    "        ]\n",
    "        num_remaining_papers = len(remaining_papers)\n",
    "        adjusted_avg_citations_per_paper = min(avg_citations_per_paper, num_remaining_papers)\n",
    "        num_citations = np.random.poisson(adjusted_avg_citations_per_paper)\n",
    "\n",
    "        if num_remaining_papers == 0:\n",
    "            continue\n",
    "        cited_papers = np.random.choice(remaining_papers, size=min(num_citations, num_remaining_papers), replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_paper_citation_graph2(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a citation graph in which every paper cites the highest-ranked papers.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector, a random class label and a\n",
    "    Poisson-distributed score. Each paper then cites the first\n",
    "    ``avg_citations_per_paper`` papers (excluding itself) in descending score order.\n",
    "\n",
    "    Args:\n",
    "        num_papers: number of nodes.\n",
    "        avg_citations_per_paper: Poisson rate of the scores and number of citations per\n",
    "            paper (used as a slice bound, so it is expected to be an int).\n",
    "        num_classes: number of distinct labels.\n",
    "        mean, std: parameters of the normal feature distribution.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    citation_counts = {paper_id: np.random.poisson(avg_citations_per_paper) for paper_id in range(num_papers)}\n",
    "\n",
    "    # The ranking is the same for every citing paper, so sort once outside the loop.\n",
    "    sorted_papers = sorted(range(num_papers), key=lambda x: citation_counts[x], reverse=True)\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        cited_papers = [p_id for p_id in sorted_papers if p_id != paper_id]\n",
    "        num_citations = min(avg_citations_per_paper, len(cited_papers))\n",
    "        for cited_paper_id in cited_papers[:num_citations]:\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_paper_citation_graph3(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a citation graph in which papers cite other papers sharing an author.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector, a random class label and between 1 and\n",
    "    4 authors named Author_0, Author_1, ... A paper cites up to ``avg_citations_per_paper``\n",
    "    randomly chosen other papers that have at least one author name in common with it.\n",
    "    NOTE(review): author names always start at Author_0 for every paper, so every pair of\n",
    "    papers shares an author -- confirm this is intended.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    # Nodes with features and labels.\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    # Author list per paper.\n",
    "    author_paper_map = {}\n",
    "    for paper_id in range(num_papers):\n",
    "        num_authors = np.random.randint(1, 5)\n",
    "        author_paper_map[paper_id] = [f\"Author_{i}\" for i in range(num_authors)]\n",
    "\n",
    "    # Citations among papers with overlapping author lists.\n",
    "    for paper_id in range(num_papers):\n",
    "        own_authors = set(author_paper_map[paper_id])\n",
    "        other_papers = [\n",
    "            p_id for p_id, authors in author_paper_map.items()\n",
    "            if (own_authors & set(authors)) and p_id != paper_id\n",
    "        ]\n",
    "        num_citations = min(avg_citations_per_paper, len(other_papers))\n",
    "        for cited_paper_id in np.random.choice(other_papers, size=num_citations, replace=False):\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_paper_citation_graph4(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a citation graph with random citations plus breadth-first propagation.\n",
    "\n",
    "    Step 1: every paper gets a 5-entry normal feature vector and a random class label.\n",
    "    Step 2: every paper cites a Poisson-distributed number of other papers chosen at random.\n",
    "    Step 3: for every paper a breadth-first walk over outgoing edges adds a direct edge to\n",
    "    each newly reached paper; the walk stops at the first dequeued entry whose depth is at\n",
    "    least ``avg_citations_per_paper``.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    # Random direct citations.\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = np.random.poisson(avg_citations_per_paper)\n",
    "        available_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        # Cannot cite more papers than exist.\n",
    "        if num_citations > len(available_papers):\n",
    "            num_citations = len(available_papers)\n",
    "        for cited_paper_id in np.random.choice(available_papers, size=num_citations, replace=False):\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    # Propagate citations along existing edges.\n",
    "    for paper_id in range(num_papers):\n",
    "        seen = {paper_id}\n",
    "        frontier = deque([(paper_id, 0)])\n",
    "        while frontier:\n",
    "            current_paper_id, depth = frontier.popleft()\n",
    "            if depth >= avg_citations_per_paper:\n",
    "                break\n",
    "            for neighbor_id in G.neighbors(current_paper_id):\n",
    "                if neighbor_id in seen:\n",
    "                    continue\n",
    "                seen.add(neighbor_id)\n",
    "                G.add_edge(paper_id, neighbor_id)\n",
    "                frontier.append((neighbor_id, depth + 1))\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_paper_citation_graph5(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a citation graph from random citations plus topic-similarity citations.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector and a random class label, then cites a\n",
    "    Poisson-distributed number of random other papers. Afterwards a random 10-dimensional\n",
    "    topic vector is drawn per paper and each paper additionally cites its\n",
    "    ``avg_citations_per_paper`` most cosine-similar papers (if not already cited).\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    # Random citations.\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)\n",
    "        candidate_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        for cited_paper_id in np.random.choice(candidate_papers, size=num_citations, replace=False):\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    topic_similarities = cosine_similarity(np.random.rand(num_papers, 10))\n",
    "\n",
    "    # Similarity-driven citations, most similar first (ties broken by the larger index).\n",
    "    for paper_id in range(num_papers):\n",
    "        ranked = sorted(\n",
    "            ((similarity, idx) for idx, similarity in enumerate(topic_similarities[paper_id]) if idx != paper_id),\n",
    "            reverse=True,\n",
    "        )\n",
    "        num_citations = min(avg_citations_per_paper, len(ranked))\n",
    "        for _, cited_paper_id in ranked[:num_citations]:\n",
    "            if not G.has_edge(paper_id, cited_paper_id):\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def create_paper_citation_graph6(num_papers, avg_citations_per_paper, num_classes, mean, std,publication_years=None):\n",
    "    \"\"\"Create a citation graph from random citations plus one citation to an older paper.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector and a random class label, then cites a\n",
    "    Poisson-distributed number of random other papers. Finally each paper cites the first\n",
    "    paper (in index order) with a strictly earlier publication year that it does not cite yet.\n",
    "\n",
    "    Args:\n",
    "        publication_years: optional per-paper years; when None, random integer years in\n",
    "            [2000, 2022) are drawn before anything else.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    if publication_years is None:\n",
    "        publication_years = np.random.randint(2000, 2022, size=num_papers)\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    # Random citations.\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)\n",
    "        candidate_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        for cited_paper_id in np.random.choice(candidate_papers, size=num_citations, replace=False):\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    # One extra citation to an older, not yet cited paper.\n",
    "    for paper_id in range(num_papers):\n",
    "        own_year = publication_years[paper_id]\n",
    "        for cited_paper_id in range(num_papers):\n",
    "            is_older = publication_years[cited_paper_id] < own_year\n",
    "            if is_older and not G.has_edge(paper_id, cited_paper_id):\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "                break\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def create_author_influence_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a citation graph from random citations plus one shared-author citation.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector, a random class label and 1-4 random\n",
    "    author ids (each with a uniform influence score). Each paper cites\n",
    "    ``min(avg_citations_per_paper, num_papers - 1)`` random other papers and then the first\n",
    "    paper (in index order) that shares an author id with it and is not cited yet.\n",
    "    The influence scores are stored but not used when creating edges.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    # author id -> influence score, per paper.\n",
    "    author_influence_data = {}\n",
    "    for paper_id in range(num_papers):\n",
    "        num_authors = np.random.randint(1, 5)\n",
    "        authors = np.random.randint(0, num_papers, size=num_authors)\n",
    "        influence_scores = np.random.uniform(0, 1, size=num_authors)\n",
    "        author_influence_data[paper_id] = dict(zip(authors, influence_scores))\n",
    "\n",
    "    # Random citations.\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(avg_citations_per_paper, num_papers - 1)\n",
    "        candidates = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        for cited_paper_id in np.random.choice(candidates, size=num_citations, replace=False):\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    # One extra citation to the first not-yet-cited paper sharing an author.\n",
    "    for paper_id in range(num_papers):\n",
    "        own_authors = author_influence_data[paper_id].keys()\n",
    "        for cited_paper_id in range(num_papers):\n",
    "            if cited_paper_id == paper_id:\n",
    "                continue\n",
    "            shares_author = not own_authors.isdisjoint(author_influence_data[cited_paper_id].keys())\n",
    "            if shares_author and not G.has_edge(paper_id, cited_paper_id):\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "                break\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_common_citation_count_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    \"\"\"Create a citation graph from random citations plus one count-based citation.\n",
    "\n",
    "    Every paper gets a 5-entry normal feature vector and a random class label, then cites a\n",
    "    Poisson-distributed number of random other papers. Each paper is also assigned an array\n",
    "    of ``num_papers`` random integers in [0, 10); it additionally cites the first paper (in\n",
    "    index order) whose array contains one of its own positive values and that it does not\n",
    "    cite yet.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the networkx DiGraph and the list of node labels in node order.\n",
    "    \"\"\"\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "    # Random citations.\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)\n",
    "        candidates = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        for cited_paper_id in np.random.choice(candidates, size=num_citations, replace=False):\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    # Random count array per paper.\n",
    "    common_citation_count_data = {}\n",
    "    for paper_id in range(num_papers):\n",
    "        common_citation_count_data[paper_id] = np.random.randint(0, 10, size=num_papers)\n",
    "\n",
    "    # One extra citation to the first not-yet-cited paper sharing a positive count value.\n",
    "    for paper_id in range(num_papers):\n",
    "        own_counts = common_citation_count_data[paper_id]\n",
    "        own_positive = own_counts[own_counts > 0]\n",
    "        for cited_paper_id in range(num_papers):\n",
    "            if cited_paper_id == paper_id:\n",
    "                continue\n",
    "            other_counts = common_citation_count_data[cited_paper_id]\n",
    "            shares_value = np.isin(own_positive, other_counts).any()\n",
    "            if shares_value and not G.has_edge(paper_id, cited_paper_id):\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "                break\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_citation_graph_based_on_citation_density(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(0, 1, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)\n",
    "\n",
    "        candidate_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        cited_papers = np.random.choice(candidate_papers, size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "    citation_density = {paper_id: len(list(G.predecessors(paper_id))) / num_papers for paper_id in range(num_papers)}\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "\n",
    "        sorted_papers = sorted(citation_density.keys(), key=lambda x: citation_density[x], reverse=True)\n",
    "\n",
    "        for cited_paper_id in sorted_papers[:avg_citations_per_paper]:\n",
    "            if cited_paper_id != paper_id and not G.has_edge(paper_id, cited_paper_id):\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "                break\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_citation_graph_based_on_network_structure(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std,size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)\n",
    "\n",
    "        candidate_papers = [p_id for p_id in range(num_papers) if p_id != paper_id]\n",
    "        cited_papers = np.random.choice(candidate_papers, size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        neighbors = list(G.neighbors(paper_id))\n",
    "        if neighbors:\n",
    "            cited_paper_id = np.random.choice(neighbors)\n",
    "            if not G.has_edge(paper_id, cited_paper_id):\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def generate_author_list(num_authors=3, num_fields=5):\n",
    "\n",
    "    authors = []\n",
    "    for _ in range(num_authors):\n",
    "        author_id = np.random.randint(1, 1000)\n",
    "        author_fields = np.random.choice(range(num_fields), size=np.random.randint(1, 4), replace=False)\n",
    "        authors.append((author_id, author_fields))\n",
    "    return authors\n",
    "def get_paper_field(authors):\n",
    "\n",
    "    paper_field = set()\n",
    "    for author in authors:\n",
    "        paper_field.update(author[1])\n",
    "    return paper_field\n",
    "def create_citation_graph_based_on_author_field(num_papers, avg_citations_per_paper, num_classes, mean, std):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(mean,std,size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = np.random.poisson(avg_citations_per_paper)\n",
    "        cited_papers = np.random.choice([p_id for p_id in range(num_papers) if p_id != paper_id], size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            if cited_paper_id != paper_id:\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        authors = generate_author_list()\n",
    "        paper_field = get_paper_field(authors)\n",
    "        for cited_paper_id in range(num_papers):\n",
    "            if cited_paper_id != paper_id:\n",
    "                cited_paper_authors = generate_author_list()\n",
    "                cited_paper_field = get_paper_field(cited_paper_authors)\n",
    "                if paper_field == cited_paper_field:\n",
    "                    G.add_edge(paper_id, cited_paper_id)\n",
    "                    break\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def create_citation_graph_based_on_centrality(num_papers, avg_citations_per_paper, num_classes, centrality_measure='degree'):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(0, 1, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = min(np.random.poisson(avg_citations_per_paper), num_papers - 1)\n",
    "        cited_papers = np.random.choice([p_id for p_id in range(num_papers) if p_id != paper_id], size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            if cited_paper_id != paper_id:\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "    if centrality_measure == 'degree':\n",
    "        centrality = nx.degree_centrality(G)\n",
    "    elif centrality_measure == 'betweenness':\n",
    "        if nx.is_connected(G):\n",
    "            centrality = nx.betweenness_centrality(G)\n",
    "        else:\n",
    "            raise ValueError(\"Graph must be connected for betweenness centrality!\")\n",
    "    elif centrality_measure == 'closeness':\n",
    "        centrality = nx.closeness_centrality(G)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported centrality measure!\")\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        candidates = list(set(range(num_papers)) - {paper_id})\n",
    "        candidate_centralities = {candidate: centrality.get(candidate, 0) for candidate in candidates}\n",
    "        if not candidate_centralities:\n",
    "            G.add_edge(paper_id, np.random.randint(num_papers))\n",
    "        else:\n",
    "            max_centrality_paper = max(candidate_centralities, key=candidate_centralities.get)\n",
    "            G.add_edge(paper_id, max_centrality_paper)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def create_citation_graph_based_on_geographical_location(num_papers, avg_citations_per_paper, num_classes):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "\n",
    "    locations = [\"USA\", \"UK\", \"Canada\", \"Germany\", \"France\", \"China\", \"Japan\", \"Australia\", \"India\", \"Brazil\"]\n",
    "    author_locations = [random.choice(locations) for _ in range(num_papers)]\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(0, 1, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = np.random.poisson(avg_citations_per_paper)\n",
    "        cited_papers = np.random.choice([p_id for p_id in range(num_papers) if p_id != paper_id], size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            if cited_paper_id != paper_id:\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        author_location = author_locations[paper_id]\n",
    "        for other_paper_id, location in enumerate(author_locations):\n",
    "            if other_paper_id != paper_id and location == author_location:\n",
    "                G.add_edge(paper_id, other_paper_id)\n",
    "                break\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "def create_team_size_citation_graph(num_papers, avg_citations_per_paper, num_classes):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "\n",
    "    team_sizes = np.random.randint(50, 101, size=num_papers)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(0, 1, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = np.random.poisson(avg_citations_per_paper)\n",
    "        cited_papers = np.random.choice([p_id for p_id in range(num_papers) if p_id != paper_id], size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        paper_team_size = team_sizes[paper_id]\n",
    "        for cited_paper_id in range(num_papers):\n",
    "            if paper_id == cited_paper_id:\n",
    "                continue\n",
    "            cited_paper_team_size = team_sizes[cited_paper_id]\n",
    "            if abs(paper_team_size - cited_paper_team_size) <= 5:\n",
    "                G.add_edge(paper_id, cited_paper_id)\n",
    "    return G, role_id\n",
    "\n",
    "def create_citation_graph_based_on_credibility(num_papers, avg_citations_per_paper, num_classes):\n",
    "    G = nx.DiGraph()\n",
    "    role_id = []\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        features = np.random.normal(0, 1, size=(5,))\n",
    "        label = np.random.randint(0, num_classes)\n",
    "        role_id.append(label)\n",
    "        G.add_node(paper_id, features=features, label=label)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        num_citations = np.random.poisson(avg_citations_per_paper)\n",
    "        cited_papers = np.random.choice([p_id for p_id in range(num_papers) if p_id != paper_id], size=num_citations, replace=False)\n",
    "        for cited_paper_id in cited_papers:\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "\n",
    "    credibility_scores = np.random.rand(num_papers)\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        credibility_threshold = np.percentile(credibility_scores, 75)\n",
    "        candidates = [paper_id for paper_id, score in enumerate(credibility_scores) if score >= credibility_threshold and paper_id != paper_id]\n",
    "        if candidates:\n",
    "            cited_paper_id = np.random.choice(candidates)\n",
    "            G.add_edge(paper_id, cited_paper_id)\n",
    "\n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "\n",
    "def generate_citation_graph(num_papers, num_authors_per_paper, num_authors):\n",
    "\n",
    "    G = nx.DiGraph()\n",
    "    author_list = []\n",
    "    paper_author_relations = {}\n",
    "\n",
    "\n",
    "    for paper_id in range(num_papers):\n",
    "        authors = np.random.choice(num_authors, size=num_authors_per_paper, replace=False)\n",
    "        paper_author_relations[paper_id] = authors\n",
    "        for author in authors:\n",
    "            author_list.append(author)\n",
    "\n",
    "\n",
    "    for paper_id, authors in paper_author_relations.items():\n",
    "        for author in authors:\n",
    "            G.add_node(author)\n",
    "            for cited_paper_id, cited_authors in paper_author_relations.items():\n",
    "                if cited_paper_id != paper_id:\n",
    "                    for cited_author in cited_authors:\n",
    "                        if author != cited_author and not G.has_edge(author, cited_author):\n",
    "                            G.add_edge(author, cited_author)\n",
    "    \n",
    "\n",
    "    citations = {}\n",
    "    for paper_id, authors in paper_author_relations.items():\n",
    "        for author in authors:\n",
    "            for successor_author in G.successors(author):\n",
    "                if successor_author in paper_author_relations and paper_id not in citations:\n",
    "                    citations[paper_id] = successor_author\n",
    "\n",
    "\n",
    "    role_id = np.random.randint(0, 2, size=num_papers)\n",
    "\n",
    "    return G, role_id, citations\n",
    "\n",
    "def convert_to_networkx_graph(citation_network):\n",
    "    G = nx.DiGraph()\n",
    "    num_papers = citation_network.shape[0]\n",
    "    for i in range(num_papers):\n",
    "        for j in range(num_papers):\n",
    "            if citation_network[i][j] == 1:\n",
    "                G.add_edge(i, j)\n",
    "    return G\n",
    "def generate_triangle_citation_network(num_papers, avg_citations_per_paper, num_classes):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "\n",
    "\n",
    "    num_citations = [random.randint(avg_citations_per_paper - 2, avg_citations_per_paper + 2) for _ in range(num_papers)]\n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, num_classes - 1)\n",
    "        for _ in range(num_citations[i]):\n",
    "\n",
    "            citation_paper = i\n",
    "            while citation_paper == i:\n",
    "                citation_paper = random.randint(0, num_papers - 1)\n",
    "            citation_network[i][citation_paper] = 1\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "\n",
    "def generate_citation_network_with_distance(num_papers, max_distance, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    citation_distances = np.zeros((num_papers, num_papers), dtype=int)\n",
    "\n",
    "\n",
    "    for i in range(num_papers):\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                distance = random.randint(1, max_distance)\n",
    "                citation_distances[i][j] = distance\n",
    "                \n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "                distance = citation_distances[i][j]\n",
    "                if distance <= max_distance:\n",
    "                    probability = 1 / distance\n",
    "\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1 if random.random() < probability else 0\n",
    "                        \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "\n",
    "def generate_flow_direction_prob(num_domains):\n",
    "    flow_direction_prob = np.zeros((num_domains, num_domains))\n",
    "    for i in range(num_domains):\n",
    "        for j in range(num_domains):\n",
    "            if i != j:\n",
    "\n",
    "                flow_direction_prob[i][j] = np.random.rand()\n",
    "\n",
    "    flow_direction_prob /= np.sum(flow_direction_prob, axis=1, keepdims=True)\n",
    "    return flow_direction_prob\n",
    "def generate_citation_network_with_knowledge_flow(num_papers, num_domains):\n",
    "\n",
    "    flow_direction_prob = generate_flow_direction_prob(num_domains)\n",
    "    \n",
    "\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                probability = flow_direction_prob[i % num_domains][j % num_domains]\n",
    "\n",
    "                if random.random() < probability:\n",
    "                    citation_network[i][j] = 1\n",
    "\n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def generate_citation_network_with_chain_length(num_papers, max_chain_length, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                chain_length = random.randint(1, max_chain_length)\n",
    "\n",
    "                probability = 1 / chain_length\n",
    "\n",
    "                if random.random() < self_citation_prob:\n",
    "                    citation_network[i][j] = 1 if random.random() < probability else 0\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "def generate_citation_network_with_diversity(num_papers, num_domains, diversity_threshold, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    references = np.random.randint(0, num_domains, size=(num_papers, num_papers))\n",
    "    \n",
    "\n",
    "    diversity_scores = np.zeros(num_papers)\n",
    "    for i in range(num_papers):\n",
    "        diversity_scores[i] = len(set(references[i]))\n",
    "    \n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                if diversity_scores[i] > diversity_threshold:\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1\n",
    "                else:\n",
    "\n",
    "                    probability = 1 / (1 + diversity_threshold - diversity_scores[i])\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1 if random.random() < probability else 0\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def generate_citation_network_with_reference_count(num_papers, max_reference_count, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    reference_counts = np.random.randint(1, max_reference_count + 1, size=num_papers)\n",
    "    \n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                if reference_counts[i] > reference_counts[j]:\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1\n",
    "                else:\n",
    "\n",
    "                    probability = reference_counts[i] / reference_counts[j]\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1 if random.random() < probability else 0\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def generate_citation_network_with_research_object(num_papers, num_objects, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    object_popularity = np.random.rand(num_objects)\n",
    "    \n",
    "\n",
    "    paper_objects = np.random.randint(0, num_objects, size=num_papers)\n",
    "    \n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                if object_popularity[paper_objects[i]] > object_popularity[paper_objects[j]]:\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1\n",
    "                else:\n",
    "\n",
    "                    probability = object_popularity[paper_objects[i]] / object_popularity[paper_objects[j]]\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1 if random.random() < probability else 0\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def generate_citation_network_with_journal_reputation(num_papers, num_journals, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    journal_reputation = np.random.rand(num_journals)\n",
    "    \n",
    "\n",
    "    paper_journals = np.random.randint(0, num_journals, size=num_papers)\n",
    "    \n",
    "\n",
    "    for i in range(num_papers):\n",
    "\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                if journal_reputation[paper_journals[i]] > journal_reputation[paper_journals[j]]:\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1\n",
    "                else:\n",
    "\n",
    "                    probability = journal_reputation[paper_journals[i]] / journal_reputation[paper_journals[j]]\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1 if random.random() < probability else 0\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "def generate_citation_network_with_open_access(num_papers, open_access_prob, self_citation_prob):\n",
    "    citation_network = np.zeros((num_papers, num_papers), dtype=int)\n",
    "    role_id = np.zeros(num_papers, dtype=int)\n",
    "    \n",
    "\n",
    "    open_access_status = np.random.rand(num_papers) < open_access_prob\n",
    "    \n",
    "\n",
    "    for i in range(num_papers):\n",
    "        role_id[i] = random.randint(0, 1)\n",
    "        for j in range(num_papers):\n",
    "            if i != j:\n",
    "\n",
    "                if open_access_status[i]:\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1\n",
    "                else:\n",
    "\n",
    "                    if random.random() < self_citation_prob:\n",
    "                        citation_network[i][j] = 1 if random.random() < self_citation_prob else 0\n",
    "    \n",
    "\n",
    "    G = convert_to_networkx_graph(citation_network)\n",
    "    \n",
    "    return G, role_id\n",
    "\n",
    "\n",
    "\n",
    "paper_generators = {\n",
    "    1: (create_paper_citation_graph, \"Random citation relationship generation\"),\n",
    "    2: (create_paper_citation_graph2, \"Citation based on paper citation count\"),\n",
    "    3: (create_paper_citation_graph3, \"Citation based on author co-citation\"),\n",
    "    4: (create_paper_citation_graph4, \"Citation based on citation relationship propagation\"),\n",
    "    5: (create_paper_citation_graph5, \"Citation based on topic similarity\"),\n",
    "    6: (create_paper_citation_graph6, \"Citation based on citation time\"),\n",
    "    7: (create_author_influence_citation_graph, \"Citation based on author influence\"),\n",
    "    8: (create_common_citation_count_citation_graph, \"Citation based on common citation count\"),\n",
    "    9: (create_citation_graph_based_on_citation_density, \"Citation based on citation density\"),\n",
    "    10: (create_citation_graph_based_on_network_structure, \"Citation based on network structure\"),\n",
    "    11: (create_citation_graph_based_on_author_field, \"Citation based on author field\"),\n",
    "    12: (create_citation_graph_based_on_centrality, \"Citation based on citation network centrality\"),\n",
    "    13: (create_citation_graph_based_on_geographical_location, \"Citation based on author geographical location\"),\n",
    "    14: (create_team_size_citation_graph, \"Citation based on research team size\"),\n",
    "    15: (create_citation_graph_based_on_credibility, \"Citation based on citation credibility\"),\n",
    "    16: (generate_citation_graph, \"Citation based on academic lineage relationship\"),\n",
    "    17: (generate_triangle_citation_network, \"Citation based on citation structure\"),\n",
    "    18: (generate_citation_network_with_distance, \"Citation based on citation distance\"),\n",
    "    19: (generate_citation_network_with_knowledge_flow, \"Rules based on knowledge flow\"),\n",
    "    20: (generate_citation_network_with_chain_length, \"Rules based on citation chain\"),\n",
    "    21: (generate_citation_network_with_diversity, \"Citation based on diversity\"),\n",
    "    22: (generate_citation_network_with_reference_count, \"Citation based on reference count\"),\n",
    "    23: (generate_citation_network_with_research_object, \"Citation based on research object\"),\n",
    "    24: (generate_citation_network_with_journal_reputation, \"Citation based on journal/conference reputation\"),\n",
    "    25: (generate_citation_network_with_open_access, \"Rules based on open access\")\n",
    "}\n",
    "\n",
    "def generate_graph1(type1,type2,num_nodes, mean, std,sim_threshold):\n",
    "    nodes1,role_id1 = node_generators[type1][0](num_nodes,mean,std)\n",
    "    nodes2,role_id2 = node_generators[type2][0](num_nodes,mean,std)\n",
    "    nodes = np.concatenate((nodes1, nodes2), axis=0)\n",
    "\n",
    "    num_nodes = len(nodes)\n",
    "    edges = []\n",
    "    for i in range(num_nodes):\n",
    "        for j in range(i+1, num_nodes):\n",
    "\n",
    "            if sim_threshold is not None:\n",
    "\n",
    "               similarity = cosine_similarity(nodes[i:i+1, :], nodes[j:j+1, :])[0, 0]\n",
    "               if similarity > sim_threshold:\n",
    "                  edges.append((i, j))   \n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(num_nodes))\n",
    "    G.add_edges_from(edges)\n",
    "    role_id = [random.randint(0, 2) for i in range(num_nodes)]\n",
    "\n",
    "    edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()\n",
    "    if edge_index.size(1) == 0:\n",
    "\n",
    "       default_value = random.randint(0, num_nodes - 1)\n",
    "       edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)\n",
    "\n",
    "    node_idx = torch.unique(edge_index.flatten())\n",
    "    num_nodes = node_idx.size(0)\n",
    "    max_node_idx = torch.max(node_idx).item()\n",
    "    if max_node_idx >= num_nodes or max_node_idx < num_nodes - 1:\n",
    "       num_nodes = max_node_idx + 1\n",
    "       node_idx = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)\n",
    "\n",
    "    node_idx_map = torch.zeros(num_nodes, dtype=torch.long, device=node_idx.device)\n",
    "    node_idx_map[node_idx] = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)[node_idx]\n",
    "    edge_index = node_idx_map[edge_index]\n",
    "    assert node_idx.max() == node_idx.size(0) - 1\n",
    "    return G,role_id\n",
    "\n",
    "def generate_graph2(type1,type2,num_nodes, mean, std,partial_sim_threshold, dims):\n",
    "    \"\"\"Build an undirected graph whose edges follow partial cosine similarity.\n",
    "\n",
    "    Two node sets of ``num_nodes`` rows each are drawn with the generators\n",
    "    stored under ``type1`` and ``type2`` in ``node_generators`` and stacked.\n",
    "    A pair (i, j), i < j, is linked when the cosine similarity of the two\n",
    "    feature rows restricted to ``dims`` is greater than\n",
    "    ``partial_sim_threshold``. If either ``partial_sim_threshold`` or\n",
    "    ``dims`` is None the graph has no edges.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the ``nx.Graph`` and a list with one random integer\n",
    "        in {0, 1, 2} per node.\n",
    "    \"\"\"\n",
    "    first_half, _ = node_generators[type1][0](num_nodes, mean, std)\n",
    "    second_half, _ = node_generators[type2][0](num_nodes, mean, std)\n",
    "    nodes = np.concatenate((first_half, second_half), axis=0)\n",
    "    total = len(nodes)\n",
    "\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(range(total))\n",
    "\n",
    "    # The None check does not depend on the pair, so it is done once up front.\n",
    "    if partial_sim_threshold is not None and dims is not None:\n",
    "        for i in range(total):\n",
    "            for j in range(i + 1, total):\n",
    "                sim = cosine_similarity(nodes[i][dims].reshape(1, -1), nodes[j][dims].reshape(1, -1))[0][0]\n",
    "                if sim > partial_sim_threshold:\n",
    "                    G.add_edge(i, j)\n",
    "\n",
    "    role_id = [random.randint(0, 2) for _ in range(total)]\n",
    "    return G, role_id\n",
    "\n",
    "# Registry of graph builders: id -> (builder function, description).\n",
    "node_connectors = dict([\n",
    "    (1, (generate_graph1, \"Similar edges with autonomous node similarity judgment\")),\n",
    "    (2, (generate_graph2, \"Partial similar edges with autonomous multi-dimensional node similarity judgment\")),\n",
    "])\n",
    "\n",
    "# Module-level defaults read by the generate_Y* helpers.\n",
    "num_papers, avg_citations_per_paper, num_classes = 10, 3, 5\n",
    "mean = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "std = [1.5, 2.0, 1.2, 1.3, 1.8]\n",
    "def generate_Y0():\n",
    "    \"\"\"Build one class-0 sample with ``create_paper_citation_graph2``.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, 0)\n",
    "    \"\"\"\n",
    "    graph, roles = create_paper_citation_graph2(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    return graph, roles, 0\n",
    "def generate_Y1():\n",
    "    \"\"\"Build one class-1 sample with ``create_paper_citation_graph3``.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, 1)\n",
    "    \"\"\"\n",
    "    graph, roles = create_paper_citation_graph3(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    return graph, roles, 1\n",
    "def generate_Y2():\n",
    "    \"\"\"Build one class-2 sample with ``create_paper_citation_graph4``.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, 2)\n",
    "    \"\"\"\n",
    "    graph, roles = create_paper_citation_graph4(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    return graph, roles, 2\n",
    "def generate_Y3():\n",
    "    \"\"\"Build one class-3 sample with ``create_paper_citation_graph5``.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, 3)\n",
    "    \"\"\"\n",
    "    graph, roles = create_paper_citation_graph5(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    return graph, roles, 3\n",
    "def generate_Y4():\n",
    "    \"\"\"Build one class-4 sample with ``create_paper_citation_graph6``.\n",
    "\n",
    "    Unlike the other ``generate_Y*`` helpers, mean and std are passed as\n",
    "    literals here instead of being read from the module-level ``mean`` / ``std``.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, 4)\n",
    "    \"\"\"\n",
    "    graph, roles = create_paper_citation_graph6(\n",
    "        num_papers,\n",
    "        avg_citations_per_paper,\n",
    "        num_classes,\n",
    "        mean=[1.5, 2.0, 1.2, 1.3, 1.8],\n",
    "        std=[1.5, 2.0, 1.2, 1.3, 1.8],\n",
    "        publication_years=None,\n",
    "    )\n",
    "    return graph, roles, 4\n",
    "def generate_real_dataset():\n",
    "    \"\"\"Draw one labelled sample from a uniformly chosen class in {0, ..., 4}.\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id, label, motif1_present, motif2_present, motif3_present,\n",
    "        motif4_present, motif5_present): the output of the chosen\n",
    "        ``generate_Y*`` helper followed by the five booleans registered for\n",
    "        that class.\n",
    "    \"\"\"\n",
    "    # class id -> (generator, (motif1, motif2, motif3, motif4, motif5) flags)\n",
    "    class_table = {\n",
    "        0: (generate_Y0, (True, True, False, False, False)),\n",
    "        1: (generate_Y1, (True, False, True, False, False)),\n",
    "        2: (generate_Y2, (False, True, False, False, True)),\n",
    "        3: (generate_Y3, (False, False, False, True, True)),\n",
    "        4: (generate_Y4, (False, False, True, True, False)),\n",
    "    }\n",
    "    y = random.choice([0, 1, 2, 3, 4])\n",
    "    generator, motif_flags = class_table[y]\n",
    "    G, role_id, label = generator()\n",
    "    return (G, role_id, label, *motif_flags)\n",
    "\n",
    "def adjacent_connection(G1,G2):\n",
    "    \"\"\"Join two graphs through a pair of freshly inserted bridge nodes.\n",
    "\n",
    "    One random edge of each input is split by a new node (``new_node1`` in\n",
    "    ``G1``, ``new_node2`` in ``G2``), the graphs are merged with\n",
    "    ``nx.compose`` and the two new nodes are linked to each other. Any node\n",
    "    left outside the largest connected component is then attached to a\n",
    "    random member of that component.\n",
    "\n",
    "    Note:\n",
    "        ``G1`` and ``G2`` are modified in place (edge split, extra nodes).\n",
    "        Node ids must support ``max`` and ``+ 1`` (integers in practice).\n",
    "\n",
    "    Returns:\n",
    "        (G, role_id): the merged graph and a list with one entry per node,\n",
    "        all 0 except the last two positions, which receive random integers\n",
    "        in [0, 5). When either input has no nodes or no edges,\n",
    "        ``(nx.Graph(), [])`` is returned.\n",
    "    \"\"\"\n",
    "    nodes1 = set(G1.nodes())\n",
    "    nodes2 = set(G2.nodes())\n",
    "    if not nodes1 or not nodes2 or not G1.edges() or not G2.edges():\n",
    "        # Same arity as the normal return so callers can always unpack two values.\n",
    "        return nx.Graph(), []\n",
    "    common_nodes = nodes1 & nodes2\n",
    "\n",
    "    edge1 = random.choice(list(G1.edges()))\n",
    "    edge2 = random.choice(list(G2.edges()))\n",
    "\n",
    "    highest = max(nodes1 | nodes2)\n",
    "    new_node1 = highest + 1\n",
    "    new_node2 = highest + 2\n",
    "\n",
    "    # Split the chosen edge of each graph with its new bridge node.\n",
    "    G1.remove_edge(*edge1)\n",
    "    G1.add_edge(edge1[0], new_node1)\n",
    "    G1.add_edge(new_node1, edge1[1])\n",
    "    G2.remove_edge(*edge2)\n",
    "    G2.add_edge(edge2[0], new_node2)\n",
    "    G2.add_edge(new_node2, edge2[1])\n",
    "\n",
    "    # Make both bridge nodes known to both graphs before merging.\n",
    "    G1.add_node(new_node2)\n",
    "    G2.add_node(new_node1)\n",
    "\n",
    "    G = nx.compose(G1, G2)\n",
    "    G.add_edge(new_node1, new_node2)\n",
    "\n",
    "    for node in common_nodes:\n",
    "        G.add_node(node, role_id=np.random.randint(low=1, high=len(common_nodes) + 3))\n",
    "\n",
    "    # networkx implements the weak-connectivity helpers only for directed\n",
    "    # graphs and the plain ones only for undirected graphs, so pick the pair\n",
    "    # that matches the merged graph.\n",
    "    if G.is_directed():\n",
    "        is_connected, components_of = nx.is_weakly_connected, nx.weakly_connected_components\n",
    "    else:\n",
    "        is_connected, components_of = nx.is_connected, nx.connected_components\n",
    "\n",
    "    if not is_connected(G):\n",
    "        largest_component = max(components_of(G), key=len)\n",
    "        anchors = list(largest_component)\n",
    "        isolated_nodes = [n for n in G.nodes() if n not in largest_component]\n",
    "        for u in isolated_nodes:\n",
    "            G.add_edge(u, random.choice(anchors))\n",
    "\n",
    "    role_id = [0] * G.number_of_nodes()\n",
    "    role_id[-2:] = torch.randint(0, 5, (2,)).tolist()\n",
    "    return G,role_id\n",
    "\n",
    "def generate_false_cause_dataset1():\n",
    "    \"\"\"Attach a spurious citation graph to a real sample.\n",
    "\n",
    "    A real sample is drawn with ``generate_real_dataset``. The first motif\n",
    "    flag that is set selects which ``create_*`` builder produces the\n",
    "    spurious graph, which is joined to the real graph with\n",
    "    ``adjacent_connection``. If no flag is set, the sample comes from\n",
    "    ``generate_false_dataset`` instead, together with its own label.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    mean = [1.0, 2.0, 1.0, 1.5, 3.0]\n",
    "    std = [1.0, 2.0, 1.0, 1.5, 3.0]\n",
    "    G, role_id, label, motif1_present, motif2_present, motif3_present, motif4_present, motif5_present = generate_real_dataset()\n",
    "\n",
    "    # The chain is kept (rather than a table) so only the selected builder\n",
    "    # name is looked up.\n",
    "    if motif1_present:\n",
    "        spurious, _ = create_author_influence_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    elif motif2_present:\n",
    "        spurious, _ = create_common_citation_count_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    elif motif3_present:\n",
    "        spurious, _ = create_citation_graph_based_on_citation_density(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    elif motif4_present:\n",
    "        spurious, _ = create_citation_graph_based_on_network_structure(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    elif motif5_present:\n",
    "        spurious, _ = create_citation_graph_based_on_author_field(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "    else:\n",
    "        # generate_false_dataset returns three values; its label belongs to\n",
    "        # the graph it built, so it replaces the label drawn above.\n",
    "        graph, role_id, label = generate_false_dataset()\n",
    "        return graph, role_id, label\n",
    "\n",
    "    graph, role_id = adjacent_connection(G, spurious)\n",
    "    return graph, role_id, label\n",
    "\n",
    "def generate_false_cause_dataset2():\n",
    "    \"\"\"Attach a ``generate_graph`` sample (stats [1, 2, 1, 1.5, 3]) to a real sample.\n",
    "\n",
    "    The first motif flag that is set selects the arguments passed to\n",
    "    ``generate_graph``; the result is joined to the real graph with\n",
    "    ``adjacent_connection``. If no flag is set the real sample is returned\n",
    "    unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    stats = [1.0, 2.0, 1.0, 1.5, 3.0]\n",
    "    graph, role_id, label, *motif_flags = generate_real_dataset()\n",
    "\n",
    "    # (first argument, last argument) handed to generate_graph, in motif order.\n",
    "    # NOTE(review): the last argument looks like a similarity threshold -- confirm against generate_graph.\n",
    "    variants = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))\n",
    "    for present, (graph_kind, last_arg) in zip(motif_flags, variants):\n",
    "        if present:\n",
    "            # Separate list copies: the same numbers are passed twice as distinct lists.\n",
    "            spurious, _ = generate_graph(graph_kind, random.randint(5, 10), list(stats), list(stats), last_arg)\n",
    "            graph, role_id = adjacent_connection(graph, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "\n",
    "def generate_false_cause_dataset3():\n",
    "    \"\"\"Attach a ``generate_graph`` sample (stats [5, 7, 4, 6, 10]) to a real sample.\n",
    "\n",
    "    The first motif flag that is set selects the arguments passed to\n",
    "    ``generate_graph``; the result is joined to the real graph with\n",
    "    ``adjacent_connection``. If no flag is set the real sample is returned\n",
    "    unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    stats = [5.0, 7.0, 4.0, 6.0, 10.0]\n",
    "    graph, role_id, label, *motif_flags = generate_real_dataset()\n",
    "\n",
    "    # (first argument, last argument) handed to generate_graph, in motif order.\n",
    "    variants = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))\n",
    "    for present, (graph_kind, last_arg) in zip(motif_flags, variants):\n",
    "        if present:\n",
    "            spurious, _ = generate_graph(graph_kind, random.randint(5, 10), list(stats), list(stats), last_arg)\n",
    "            graph, role_id = adjacent_connection(graph, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "\n",
    "def generate_false_cause_dataset4():\n",
    "    \"\"\"Attach a ``generate_graph`` sample (stats [10, 15, 8, 12, 20]) to a real sample.\n",
    "\n",
    "    The first motif flag that is set selects the arguments passed to\n",
    "    ``generate_graph``; the result is joined to the real graph with\n",
    "    ``adjacent_connection``. If no flag is set the real sample is returned\n",
    "    unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    stats = [10.0, 15.0, 8.0, 12.0, 20.0]\n",
    "    graph, role_id, label, *motif_flags = generate_real_dataset()\n",
    "\n",
    "    # (first argument, last argument) handed to generate_graph, in motif order.\n",
    "    # NOTE(review): the last argument looks like a similarity threshold -- confirm against generate_graph.\n",
    "    variants = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))\n",
    "    for present, (graph_kind, last_arg) in zip(motif_flags, variants):\n",
    "        if present:\n",
    "            # Separate list copies: the same numbers are passed twice as distinct lists.\n",
    "            spurious, _ = generate_graph(graph_kind, random.randint(5, 10), list(stats), list(stats), last_arg)\n",
    "            graph, role_id = adjacent_connection(graph, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "\n",
    "\n",
    "def generate_false_cause_dataset5():\n",
    "    \"\"\"Attach a ``generate_graph`` sample (stats [8, 10, 7, 9, 10]) to a real sample.\n",
    "\n",
    "    The first motif flag that is set selects the arguments passed to\n",
    "    ``generate_graph``; the result is joined to the real graph with\n",
    "    ``adjacent_connection``. If no flag is set the real sample is returned\n",
    "    unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    stats = [8.0, 10.0, 7.0, 9.0, 10.0]\n",
    "    graph, role_id, label, *motif_flags = generate_real_dataset()\n",
    "\n",
    "    # (first argument, last argument) handed to generate_graph, in motif order.\n",
    "    variants = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))\n",
    "    for present, (graph_kind, last_arg) in zip(motif_flags, variants):\n",
    "        if present:\n",
    "            spurious, _ = generate_graph(graph_kind, random.randint(5, 10), list(stats), list(stats), last_arg)\n",
    "            graph, role_id = adjacent_connection(graph, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "\n",
    "def generate_false_cause_dataset6():\n",
    "    \"\"\"Attach a ``generate_graph`` sample (stats [20, 25, 18, 22, 20]) to a real sample.\n",
    "\n",
    "    The first motif flag that is set selects the arguments passed to\n",
    "    ``generate_graph``; the result is joined to the real graph with\n",
    "    ``adjacent_connection``. If no flag is set the real sample is returned\n",
    "    unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label)\n",
    "    \"\"\"\n",
    "    stats = [20.0, 25.0, 18.0, 22.0, 20.0]\n",
    "    graph, role_id, label, *motif_flags = generate_real_dataset()\n",
    "\n",
    "    # (first argument, last argument) handed to generate_graph, in motif order.\n",
    "    # NOTE(review): the last argument looks like a similarity threshold -- confirm against generate_graph.\n",
    "    variants = ((6, 0.3), (7, 0.3), (8, 0.3), (9, 0.5), (10, 0.5))\n",
    "    for present, (graph_kind, last_arg) in zip(motif_flags, variants):\n",
    "        if present:\n",
    "            # Separate list copies: the same numbers are passed twice as distinct lists.\n",
    "            spurious, _ = generate_graph(graph_kind, random.randint(5, 10), list(stats), list(stats), last_arg)\n",
    "            graph, role_id = adjacent_connection(graph, spurious)\n",
    "            break\n",
    "    return graph, role_id, label\n",
    "def generate_false_dataset():\n",
    "    \"\"\"Join a false-cause sample with one more paper-citation graph.\n",
    "\n",
    "    A sample from ``generate_false_cause_dataset1`` is connected, through\n",
    "    ``adjacent_connection``, to a graph built by\n",
    "    ``create_paper_citation_graph`` with mean = std = [1, 2, 1, 1.5, 3].\n",
    "\n",
    "    Returns:\n",
    "        (graph, role_id, label): ``role_id`` comes from\n",
    "        ``adjacent_connection``; ``label`` is the label of the false-cause sample.\n",
    "    \"\"\"\n",
    "    G, _, label = generate_false_cause_dataset1()\n",
    "\n",
    "    mean = [1.0, 2.0, 1.0, 1.5, 3.0]\n",
    "    std = [1.0, 2.0, 1.0, 1.5, 3.0]\n",
    "    pgraph1, _ = create_paper_citation_graph(num_papers, avg_citations_per_paper, num_classes, mean, std)\n",
    "\n",
    "    graph, role_id = adjacent_connection(G, pgraph1)\n",
    "    return graph,role_id,label\n",
    "\n",
    "def add_noise(G, delete_edge_prob, add_edge_prob, delete_node_prob, add_node_prob,label=None):\n",
    "    \"\"\"Return a randomly perturbed deep copy of ``G``.\n",
    "\n",
    "    Steps, in order: remove a fraction of the edges, try to add random\n",
    "    edges, remove a fraction of the nodes, then add new nodes, each wired to\n",
    "    a random set of existing nodes that leaves the whole graph connected.\n",
    "\n",
    "    Args:\n",
    "        G: networkx graph with integer node ids; it is deep-copied, not modified.\n",
    "        delete_edge_prob: fraction of the current edges to remove.\n",
    "        add_edge_prob: fraction of the n*(n-1)/2 node pairs to try to link\n",
    "            (a drawn pair that is already linked is skipped).\n",
    "        delete_node_prob: fraction of the current nodes to remove.\n",
    "        add_node_prob: number of new nodes as a fraction of the node count\n",
    "            left after deletion.\n",
    "        label: returned unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (G_noisy, role_id_noisy, label_noisy): ``role_id_noisy`` holds one\n",
    "        random integer in {0, 1, 2} per node of ``G_noisy``.\n",
    "    \"\"\"\n",
    "    G_noisy = copy.deepcopy(G)\n",
    "\n",
    "    # random.sample needs a real sequence; graph views are set-like and are\n",
    "    # rejected with TypeError on Python >= 3.11, hence the list() snapshots.\n",
    "    edge_pool = list(G_noisy.edges())\n",
    "    G_noisy.remove_edges_from(random.sample(edge_pool, int(delete_edge_prob * len(edge_pool))))\n",
    "\n",
    "    node_pool = list(G_noisy.nodes())\n",
    "    num_edges_to_add = int(add_edge_prob * len(node_pool) * (len(node_pool) - 1) / 2)\n",
    "    for _ in range(num_edges_to_add):\n",
    "        node1, node2 = random.sample(node_pool, 2)\n",
    "        if not G_noisy.has_edge(node1, node2):\n",
    "            G_noisy.add_edge(node1, node2)\n",
    "\n",
    "    G_noisy.remove_nodes_from(random.sample(node_pool, int(delete_node_prob * len(node_pool))))\n",
    "\n",
    "    # is_weakly_connected exists only for directed graphs, is_connected only\n",
    "    # for undirected ones.\n",
    "    is_connected = nx.is_weakly_connected if G_noisy.is_directed() else nx.is_connected\n",
    "\n",
    "    num_nodes_to_add = int(add_node_prob * G_noisy.number_of_nodes())\n",
    "    for _ in range(num_nodes_to_add):\n",
    "        existing = list(G_noisy.nodes())\n",
    "        # After deletions the ids are no longer contiguous, so the node count\n",
    "        # is not a safe source for a fresh id; go past the largest id instead.\n",
    "        node_id = max(existing) + 1\n",
    "        G_noisy.add_node(node_id)\n",
    "\n",
    "        connected = False\n",
    "        while not connected:\n",
    "            # Targets are drawn from the pre-existing nodes only, so the new\n",
    "            # node is never linked to itself.\n",
    "            targets = random.sample(existing, random.randint(1, len(existing)))\n",
    "            G_noisy.add_edges_from((node_id, t) for t in targets)\n",
    "            connected = is_connected(G_noisy)\n",
    "            if not connected:\n",
    "                # Undo this attempt and draw a different target set.\n",
    "                G_noisy.remove_edges_from((node_id, t) for t in targets)\n",
    "\n",
    "    # Roles are per node (they are indexed by node elsewhere), so size the\n",
    "    # list by the node count.\n",
    "    role_id_noisy = [random.randint(0, 2) for _ in range(G_noisy.number_of_nodes())]\n",
    "    label_noisy = label\n",
    "\n",
    "    return G_noisy, role_id_noisy, label_noisy\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "fde7c855",
   "metadata": {},
   "source": [
    "## Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5f5dc82f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [00:10<00:00, 188.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 2000    #Nodes: 12.00    #Edges: 67.84 \n"
     ]
    }
   ],
   "source": [
    "# Build the training split: 2000 graphs from generate_false_cause_dataset1,\n",
    "# saved to <data_dir>/form_train.npy.\n",
    "# Per-graph accumulators, one entry appended per generated graph.\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "# Edge / node counts per graph; only used for the summary print below.\n",
    "e_mean, n_mean = [], []\n",
    "# Probe one sample to learn the feature width; it is the fallback width for\n",
    "# graphs that carry no 'features' node attribute at all.\n",
    "# Raises StopIteration if the probe graph itself has no 'features' attribute.\n",
    "G0, _, _ = generate_false_cause_dataset1()\n",
    "feat_dict0 = nx.get_node_attributes(G0, 'features')\n",
    "feature_dim = next(iter(feat_dict0.values())).shape[0]\n",
    "for _ in tqdm(range(2000)):\n",
    "    G,role_id,label=generate_false_cause_dataset1()\n",
    "\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    # Transposed edge list: row 0 holds the first endpoints, row 1 the second.\n",
    "    edge_index = np.array(G.edges, dtype=int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Spring-layout coordinates, in the node order returned by spring_layout.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    # find_gd comes from BA3_loc; presumably it marks ground-truth edges -- confirm there.\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "    feat_dict = nx.get_node_attributes(G, 'features')\n",
    "    cur_dim = (next(iter(feat_dict.values())).shape[0]\n",
    "               if feat_dict else feature_dim)\n",
    "\n",
    "    # One row per node in ascending node-id order; nodes without a 'features'\n",
    "    # attribute get a zero row.\n",
    "    feat_mat = np.vstack([\n",
    "        feat_dict.get(n, np.zeros(cur_dim))\n",
    "        for n in sorted(G.nodes())\n",
    "    ])\n",
    "    features_list.append(feat_mat)\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "# The dict is stored as a pickled object array; read it back with\n",
    "# np.load(path, allow_pickle=True).item().\n",
    "np.save(osp.join(data_dir, 'form_train.npy'), {'features': features_list, 'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c3587610",
   "metadata": {},
   "source": [
    "## Validation Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ebe75786",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:02<00:00, 187.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 500    #Nodes: 12.00    #Edges: 67.27 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the validation split: 500 graphs from generate_false_cause_dataset1,\n",
    "# saved to <data_dir>/form_val.npy.\n",
    "# Per-graph accumulators, one entry appended per generated graph.\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "# Edge / node counts per graph; only used for the summary print below.\n",
    "e_mean, n_mean = [], []\n",
    "# Probe one sample to learn the feature width; it is the fallback width for\n",
    "# graphs that carry no 'features' node attribute at all.\n",
    "# Raises StopIteration if the probe graph itself has no 'features' attribute.\n",
    "G0, _, _ = generate_false_cause_dataset1()\n",
    "feat_dict0 = nx.get_node_attributes(G0, 'features')\n",
    "feature_dim = next(iter(feat_dict0.values())).shape[0]\n",
    "for _ in tqdm(range(500)):\n",
    "    G,role_id,label=generate_false_cause_dataset1()\n",
    "\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    # Transposed edge list: row 0 holds the first endpoints, row 1 the second.\n",
    "    edge_index = np.array(G.edges, dtype=int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Spring-layout coordinates, in the node order returned by spring_layout.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    # find_gd comes from BA3_loc; presumably it marks ground-truth edges -- confirm there.\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "    feat_dict = nx.get_node_attributes(G, 'features')\n",
    "    cur_dim = (next(iter(feat_dict.values())).shape[0]\n",
    "               if feat_dict else feature_dim)\n",
    "\n",
    "    # One row per node in ascending node-id order; nodes without a 'features'\n",
    "    # attribute get a zero row.\n",
    "    feat_mat = np.vstack([\n",
    "        feat_dict.get(n, np.zeros(cur_dim))\n",
    "        for n in sorted(G.nodes())\n",
    "    ])\n",
    "    features_list.append(feat_mat)\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "# The dict is stored as a pickled object array; read it back with\n",
    "# np.load(path, allow_pickle=True).item().\n",
    "np.save(osp.join(data_dir, 'form_val.npy'), {'features': features_list, 'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9d858281",
   "metadata": {},
   "source": [
    "## Testing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1aff6470",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [00:02<00:00, 186.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#Graphs: 500    #Nodes: 12.00    #Edges: 67.20 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the test split: 500 graphs from generate_false_cause_dataset1,\n",
    "# saved to <data_dir>/form_test.npy.\n",
    "# Per-graph accumulators, one entry appended per generated graph.\n",
    "edge_index_list, label_list = [], []\n",
    "ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "# Edge / node counts per graph; only used for the summary print below.\n",
    "e_mean, n_mean = [], []\n",
    "# Probe one sample to learn the feature width; it is the fallback width for\n",
    "# graphs that carry no 'features' node attribute at all.\n",
    "# Raises StopIteration if the probe graph itself has no 'features' attribute.\n",
    "G0, _, _ = generate_false_cause_dataset1()\n",
    "feat_dict0 = nx.get_node_attributes(G0, 'features')\n",
    "feature_dim = next(iter(feat_dict0.values())).shape[0]\n",
    "for _ in tqdm(range(500)):\n",
    "    G,role_id,label=generate_false_cause_dataset1()\n",
    "\n",
    "    label_list.append(label)\n",
    "    e_mean.append(len(G.edges))\n",
    "    n_mean.append(len(G.nodes))\n",
    "    role_id = np.array(role_id)\n",
    "    # Transposed edge list: row 0 holds the first endpoints, row 1 the second.\n",
    "    edge_index = np.array(G.edges, dtype=int).T\n",
    "    role_id_list.append(role_id)\n",
    "    edge_index_list.append(edge_index)\n",
    "    # Spring-layout coordinates, in the node order returned by spring_layout.\n",
    "    pos_list.append(np.array(list(nx.spring_layout(G).values())))\n",
    "    # find_gd comes from BA3_loc; presumably it marks ground-truth edges -- confirm there.\n",
    "    ground_truth_list.append(find_gd(edge_index, role_id))\n",
    "    feat_dict = nx.get_node_attributes(G, 'features')\n",
    "    cur_dim = (next(iter(feat_dict.values())).shape[0]\n",
    "               if feat_dict else feature_dim)\n",
    "\n",
    "    # One row per node in ascending node-id order; nodes without a 'features'\n",
    "    # attribute get a zero row.\n",
    "    feat_mat = np.vstack([\n",
    "        feat_dict.get(n, np.zeros(cur_dim))\n",
    "        for n in sorted(G.nodes())\n",
    "    ])\n",
    "    features_list.append(feat_mat)\n",
    "print(\"#Graphs: %d    #Nodes: %.2f    #Edges: %.2f \" % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))\n",
    "# The dict is stored as a pickled object array; read it back with\n",
    "# np.load(path, allow_pickle=True).item().\n",
    "np.save(osp.join(data_dir, 'form_test.npy'), {'features': features_list, 'edge_index': edge_index_list, 'label': label_list, 'ground_truth': ground_truth_list, 'role_id': role_id_list, 'pos': pos_list})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50035316",
   "metadata": {},
   "source": [
    "## Intervention Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "796dbe94",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1500/1500 [00:08<00:00, 169.54it/s]\n",
      "100%|██████████| 250/250 [00:01<00:00, 172.64it/s]\n",
      "100%|██████████| 250/250 [00:01<00:00, 183.03it/s]\n"
     ]
    }
   ],
   "source": [
    "def create_fixed_intervention_confounders(num_nodes):\n",
    "    \"\"\"Return confounder columns that hold the same fixed value for every node.\n",
    "\n",
    "    Args:\n",
    "        num_nodes: length of each per-node list.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each confounder name to a list made of ``num_nodes``\n",
    "        copies of its fixed value.\n",
    "    \"\"\"\n",
    "    fixed_values = dict(\n",
    "        social_status='medium',\n",
    "        time_period='middle',\n",
    "        domain_preference=3,\n",
    "        network_influence=2,\n",
    "        resource_access=2,\n",
    "        collaboration_tendency='cooperative',\n",
    "    )\n",
    "    return {name: [value] * num_nodes for name, value in fixed_values.items()}\n",
    "\n",
    "def apply_fixed_confounding_bias(original_role_id, confounders, confounding_prob=0.9):\n",
    "    \"\"\"Replace roles by the constant 5 with probability ``confounding_prob``.\n",
    "\n",
    "    Args:\n",
    "        original_role_id: per-node roles; copied with ``.copy()``, not modified.\n",
    "        confounders: dict of per-node lists under the keys 'social_status',\n",
    "            'time_period', 'domain_preference', 'network_influence',\n",
    "            'resource_access' and 'collaboration_tendency'.\n",
    "        confounding_prob: chance, drawn independently per position, that the\n",
    "            role at that position is replaced.\n",
    "\n",
    "    Returns:\n",
    "        The copy, with every selected position set to 5.\n",
    "    \"\"\"\n",
    "    status_weights = {\"elite\": 4, \"high\": 3, \"medium\": 2}\n",
    "    time_weights = {\"ancient\": 0, \"early\": 1, \"middle\": 2, \"late\": 3, \"modern\": 4}\n",
    "\n",
    "    confounded_role_id = original_role_id.copy()\n",
    "    for i, _ in enumerate(original_role_id):\n",
    "        if not (random.random() < confounding_prob):\n",
    "            continue\n",
    "\n",
    "        social_status = confounders['social_status'][i]\n",
    "        time_period = confounders['time_period'][i]\n",
    "        domain_pref = confounders['domain_preference'][i]\n",
    "        network_influence = confounders['network_influence'][i]\n",
    "        resource_access = confounders['resource_access'][i]\n",
    "        collaboration = confounders['collaboration_tendency'][i]\n",
    "\n",
    "        # Additive score over all confounders (unknown status counts 1,\n",
    "        # non-cooperative collaboration counts 1).\n",
    "        confounded_role = (\n",
    "            status_weights.get(social_status, 1)\n",
    "            + time_weights[time_period]\n",
    "            + domain_pref\n",
    "            + network_influence * 2\n",
    "            + resource_access\n",
    "            + (3 if collaboration == \"cooperative\" else 1)\n",
    "        )\n",
    "\n",
    "        # NOTE(review): the score above is computed but never used; the role\n",
    "        # is always overwritten with 5. Confirm this is intended.\n",
    "        confounded_role_id[i] = 5\n",
    "\n",
    "    return confounded_role_id\n",
    "\n",
    "def generate_enhanced_confounded_features(G, original_role_id, confounders, confounding_prob):\n",
    "    \"\"\"Sample per-node features by role and, with some probability, distort them.\n",
    "\n",
    "    For every node whose id is smaller than ``len(original_role_id)`` a\n",
    "    5-dimensional vector is drawn from the normal distribution registered\n",
    "    for the node's role (roles 0-4). With probability ``confounding_prob`` a\n",
    "    bias assembled from the node's confounder values is multiplied by a\n",
    "    random factor from ``random.uniform(-20, 20)``, shifted by a random\n",
    "    offset from ``random.uniform(-30, 30)`` and added to the vector.\n",
    "    The vector is also stored in ``G.nodes[node_id]['features']``.\n",
    "\n",
    "    Args:\n",
    "        G: graph whose node ids index ``original_role_id`` and the\n",
    "            confounder lists.\n",
    "        original_role_id: role per node id.\n",
    "        confounders: dict of per-node lists under the keys 'social_status',\n",
    "            'time_period', 'domain_preference', 'network_influence',\n",
    "            'resource_access' and 'collaboration_tendency'.\n",
    "        confounding_prob: per-node chance of adding the confounding bias.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping node id to its feature vector (the same array objects\n",
    "        that are stored on ``G``). Nodes with an id outside\n",
    "        ``original_role_id`` are skipped.\n",
    "    \"\"\"\n",
    "    # Role -> parameters of the base feature distribution.\n",
    "    base_feature_params = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}\n",
    "    }\n",
    "    # Additive bias contributed by each categorical confounder value.\n",
    "    status_effects = {\n",
    "        \"elite\": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),\n",
    "        \"high\": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),\n",
    "        \"medium\": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),\n",
    "        \"low\": np.array([-1.5, -2.0, 4.0, -1.0, 1.0])\n",
    "    }\n",
    "    time_effects = {\n",
    "        \"ancient\": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),\n",
    "        \"early\": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),\n",
    "        \"middle\": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),\n",
    "        \"late\": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),\n",
    "        \"modern\": np.array([-2.0, 4.0, 2.0, -3.0, 4.5])\n",
    "    }\n",
    "    cooperative_bias = np.array([2.5, 3.0, -1.5, 2.0, 3.5])\n",
    "    non_cooperative_bias = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])\n",
    "\n",
    "    features_dict = {}\n",
    "    for node_id in G.nodes():\n",
    "        if not node_id < len(original_role_id):\n",
    "            continue\n",
    "\n",
    "        original_role = original_role_id[node_id]\n",
    "        social_status = confounders['social_status'][node_id]\n",
    "        time_period = confounders['time_period'][node_id]\n",
    "        domain_pref = confounders['domain_preference'][node_id]\n",
    "        network_influence = confounders['network_influence'][node_id]\n",
    "        resource_access = confounders['resource_access'][node_id]\n",
    "        collaboration = confounders['collaboration_tendency'][node_id]\n",
    "\n",
    "        role_params = base_feature_params[original_role]\n",
    "        features = np.random.normal(role_params['mean'], role_params['std'])\n",
    "\n",
    "        if random.random() < confounding_prob:\n",
    "            domain_bias = np.array([\n",
    "                domain_pref * 0.8,\n",
    "                (domain_pref - 3) * 1.0,\n",
    "                (6 - domain_pref) * 1.2,\n",
    "                domain_pref * 0.6,\n",
    "                (domain_pref + 2) * 0.9\n",
    "            ])\n",
    "            influence_bias = np.array([\n",
    "                network_influence * 1.5,\n",
    "                network_influence * 1.2,\n",
    "                network_influence * -0.8,\n",
    "                network_influence * 2.0,\n",
    "                network_influence * 1.0\n",
    "            ])\n",
    "            resource_bias = np.array([\n",
    "                resource_access * 0.7,\n",
    "                resource_access * -0.5,\n",
    "                resource_access * 1.3,\n",
    "                resource_access * 0.9,\n",
    "                resource_access * -0.8\n",
    "            ])\n",
    "            collab_bias = cooperative_bias if collaboration == \"cooperative\" else non_cooperative_bias\n",
    "\n",
    "            # Terms are summed left to right in this fixed order.\n",
    "            confounding_bias = (\n",
    "                np.zeros(5)\n",
    "                + status_effects[social_status]\n",
    "                + time_effects[time_period]\n",
    "                + domain_bias\n",
    "                + influence_bias\n",
    "                + resource_bias\n",
    "                + collab_bias\n",
    "            )\n",
    "            # Scale factor is drawn before the offset.\n",
    "            features += confounding_bias * random.uniform(-20, 20) + random.uniform(-30, 30)\n",
    "\n",
    "        features_dict[node_id] = features\n",
    "        G.nodes[node_id]['features'] = features\n",
    "\n",
    "    return features_dict\n",
    "\n",
    "def generate_intervention_dataset(confounding_prob=0.9, dataset_type='train', num_samples=2000):\n",
    "    \"\"\"Generate graphs whose confounders are fixed to constant values.\n",
    "\n",
    "    For every sample a graph is drawn with generate_false_cause_dataset1, all\n",
    "    nodes receive the constant confounders returned by\n",
    "    create_fixed_intervention_confounders, and node features come from\n",
    "    generate_enhanced_confounded_features.\n",
    "\n",
    "    Args:\n",
    "        confounding_prob: Probability forwarded to the role-bias and feature\n",
    "            generators.\n",
    "        dataset_type: Split tag stored in the output (e.g. 'train').\n",
    "        num_samples: Number of graphs to attempt. A sample that raises is\n",
    "            skipped as a whole, so the result may hold fewer entries.\n",
    "\n",
    "    Returns:\n",
    "        dict with the aligned per-sample lists 'features', 'edge_index',\n",
    "        'label', 'ground_truth', 'role_id', 'pos' and 'confounding_info', plus\n",
    "        the scalars 'confounding_prob', 'dataset_type' and\n",
    "        'intervention_applied'.\n",
    "    \"\"\"\n",
    "    edge_index_list, label_list = [], []\n",
    "    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "    confounding_info_list = []\n",
    "    num_skipped, last_error = 0, None\n",
    "\n",
    "    for _ in tqdm(range(num_samples)):\n",
    "        # Build every piece of the sample first; nothing is appended until all\n",
    "        # of them succeeded, so the output lists can never get out of step.\n",
    "        try:\n",
    "            G, original_role_id, label = generate_false_cause_dataset1()\n",
    "            confounders = create_fixed_intervention_confounders(len(original_role_id))\n",
    "            confounded_role_id = apply_fixed_confounding_bias(original_role_id, confounders, confounding_prob)\n",
    "            features_dict = generate_enhanced_confounded_features(G, original_role_id, confounders, confounding_prob)\n",
    "\n",
    "            original_roles = np.array(original_role_id)\n",
    "\n",
    "            if G.number_of_edges() > 0:\n",
    "                edge_index = np.array(list(G.edges), dtype=int).T\n",
    "            else:\n",
    "                edge_index = np.array([[], []], dtype=int)\n",
    "\n",
    "            # The layout is auxiliary: fall back to an empty array if it fails.\n",
    "            try:\n",
    "                pos = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)\n",
    "                pos_array = np.array(list(pos.values()))\n",
    "            except Exception:\n",
    "                pos_array = np.array([])\n",
    "\n",
    "            # An edge counts as ground truth when both endpoints have an\n",
    "            # original role id greater than zero.\n",
    "            try:\n",
    "                if edge_index.size > 0:\n",
    "                    row, col = edge_index\n",
    "                    gd = np.array(original_roles[row] > 0, dtype=np.float64) * np.array(original_roles[col] > 0, dtype=np.float64)\n",
    "                else:\n",
    "                    gd = np.array([])\n",
    "            except Exception:\n",
    "                gd = np.array([])\n",
    "\n",
    "            # Nodes without generated features get a zero row.\n",
    "            if G.number_of_nodes() > 0:\n",
    "                feat_mat = np.vstack([\n",
    "                    features_dict.get(n, np.zeros(5))\n",
    "                    for n in sorted(G.nodes())\n",
    "                ])\n",
    "            else:\n",
    "                feat_mat = np.zeros((0, 5))\n",
    "\n",
    "            # Fraction of nodes whose role differs after the bias step.\n",
    "            confounding_strength = np.mean([\n",
    "                1 if original_role_id[j] != confounded_role_id[j] else 0\n",
    "                for j in range(len(original_role_id))\n",
    "            ])\n",
    "\n",
    "            info = {\n",
    "                'confounding_prob': confounding_prob,\n",
    "                'dataset_type': dataset_type,\n",
    "                'original_role_id': original_role_id,\n",
    "                # np.asarray makes this work for plain lists as well as arrays.\n",
    "                'confounded_role_id': np.asarray(confounded_role_id).tolist(),\n",
    "                'confounding_strength': confounding_strength,\n",
    "                'confounders': confounders,\n",
    "                'intervention_applied': True,\n",
    "                'intervention_type': 'fixed_confounders',\n",
    "                'label': label\n",
    "            }\n",
    "        except Exception as exc:\n",
    "            num_skipped += 1\n",
    "            last_error = exc\n",
    "            continue\n",
    "\n",
    "        label_list.append(label)\n",
    "        role_id_list.append(original_roles)\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(pos_array)\n",
    "        ground_truth_list.append(gd)\n",
    "        features_list.append(feat_mat)\n",
    "        confounding_info_list.append(info)\n",
    "\n",
    "    if num_skipped:\n",
    "        print(f'generate_intervention_dataset: skipped {num_skipped}/{num_samples} samples; last error: {last_error!r}')\n",
    "\n",
    "    dataset_dict = {\n",
    "        'features': features_list,\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "        'confounding_info': confounding_info_list,\n",
    "        'confounding_prob': confounding_prob,\n",
    "        'dataset_type': dataset_type,\n",
    "        'intervention_applied': True\n",
    "    }\n",
    "\n",
    "    return dataset_dict\n",
    "\n",
    "def generate_intervention_experiment_datasets(base_dir='./data/paper/'):\n",
    "    \"\"\"Generate and save the train/val/test intervention datasets.\n",
    "\n",
    "    Args:\n",
    "        base_dir: Output directory; created if it does not exist.\n",
    "\n",
    "    Returns:\n",
    "        List of the three written file paths in train, val, test order.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    os.makedirs(base_dir, exist_ok=True)\n",
    "\n",
    "    confounding_prob = 0.9\n",
    "    # (split tag, confounding probability, number of graphs). The test split is\n",
    "    # generated with probability 0 but keeps the shared 0.9 file suffix.\n",
    "    split_specs = [\n",
    "        ('train', confounding_prob, 1500),\n",
    "        ('val', confounding_prob, 250),\n",
    "        ('test', 0, 250),\n",
    "    ]\n",
    "\n",
    "    generated_files = []\n",
    "    for split, split_prob, num_samples in split_specs:\n",
    "        dataset = generate_intervention_dataset(\n",
    "            confounding_prob=split_prob,\n",
    "            dataset_type=split,\n",
    "            num_samples=num_samples\n",
    "        )\n",
    "        # os.path.join avoids the doubled separator that plain string\n",
    "        # formatting produces when base_dir already ends with '/'.\n",
    "        out_file = os.path.join(base_dir, f'{split}_intervened_0.9.npy')\n",
    "        # The dataset is a dict, so np.save stores it as a pickled object;\n",
    "        # reading it back needs np.load(..., allow_pickle=True).\n",
    "        np.save(out_file, dataset)\n",
    "        generated_files.append(out_file)\n",
    "\n",
    "    return generated_files\n",
    "\n",
    "def run_intervention_experiment():\n",
    "    \"\"\"Generate the intervention datasets with default settings and return their file paths.\"\"\"\n",
    "    return generate_intervention_experiment_datasets()\n",
    "\n",
    "# Notebook cells execute with __name__ == \"__main__\", so running this cell\n",
    "# generates the intervention datasets and keeps the written paths.\n",
    "if __name__ == \"__main__\":\n",
    "    generated_files = run_intervention_experiment()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62eaf13f",
   "metadata": {},
   "source": [
    "## Confounded datasets (`conf`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b5a9c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/3000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3000/3000 [00:17<00:00, 168.67it/s]\n",
      "100%|██████████| 375/375 [00:02<00:00, 174.73it/s]\n",
      "100%|██████████| 375/375 [00:02<00:00, 161.38it/s]\n"
     ]
    }
   ],
   "source": [
    "def create_confounding_variables(original_role_id, confounding_prob):\n",
    "\n",
    "    num_nodes = len(original_role_id)\n",
    "    confounders = {}\n",
    "    \n",
    "\n",
    "    social_status = []\n",
    "    for i in range(num_nodes):\n",
    "        if i < num_nodes // 4:\n",
    "            status = \"elite\"\n",
    "        elif i < num_nodes // 2:\n",
    "            status = \"high\"\n",
    "        elif i < 3 * num_nodes // 4:\n",
    "            status = \"medium\"\n",
    "        else:\n",
    "            status = \"low\"\n",
    "        social_status.append(status)\n",
    "    confounders['social_status'] = social_status\n",
    "    \n",
    "\n",
    "    time_period = []\n",
    "    for i in range(num_nodes):\n",
    "        period_idx = i % 5\n",
    "        periods = [\"ancient\", \"early\", \"middle\", \"late\", \"modern\"]\n",
    "        time_period.append(periods[period_idx])\n",
    "    confounders['time_period'] = time_period\n",
    "    \n",
    "\n",
    "    domain_preference = []\n",
    "    for i in range(num_nodes):\n",
    "        domain = i % 7\n",
    "        domain_preference.append(domain)\n",
    "    confounders['domain_preference'] = domain_preference\n",
    "    \n",
    "\n",
    "    network_influence = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        influence = int((np.sin(i * 0.5) + 1) * 2.5)\n",
    "        network_influence.append(influence)\n",
    "    confounders['network_influence'] = network_influence\n",
    "    \n",
    "\n",
    "    resource_access = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        access_level = (i * 3 + 7) % 6\n",
    "        resource_access.append(access_level)\n",
    "    confounders['resource_access'] = resource_access\n",
    "    \n",
    "\n",
    "    collaboration_tendency = []\n",
    "    for i in range(num_nodes):\n",
    "        if i % 2 == 0:\n",
    "            tendency = \"cooperative\"\n",
    "        else:\n",
    "            tendency = \"competitive\"\n",
    "        collaboration_tendency.append(tendency)\n",
    "    confounders['collaboration_tendency'] = collaboration_tendency\n",
    "    \n",
    "    return confounders\n",
    "\n",
    "def apply_confounding_bias(original_role_id, confounders, confounding_prob):\n",
    "\n",
    "    confounded_role_id = original_role_id.copy()\n",
    "    \n",
    "    for i in range(len(original_role_id)):\n",
    "        if random.random() < confounding_prob:\n",
    "\n",
    "            social_status = confounders['social_status'][i]\n",
    "            time_period = confounders['time_period'][i]\n",
    "            domain_pref = confounders['domain_preference'][i]\n",
    "            network_influence = confounders['network_influence'][i]\n",
    "            resource_access = confounders['resource_access'][i]\n",
    "            collaboration = confounders['collaboration_tendency'][i]\n",
    "            \n",
    "\n",
    "            confounded_role = 0\n",
    "            \n",
    "\n",
    "            if social_status == \"elite\":\n",
    "                confounded_role += 4\n",
    "            elif social_status == \"high\":\n",
    "                confounded_role += 3\n",
    "            elif social_status == \"medium\":\n",
    "                confounded_role += 2\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "            time_weights = {\"ancient\": 0, \"early\": 1, \"middle\": 2, \"late\": 3, \"modern\": 4}\n",
    "            confounded_role += time_weights[time_period]\n",
    "            \n",
    "\n",
    "            confounded_role += domain_pref\n",
    "            \n",
    "\n",
    "            confounded_role += network_influence * 2\n",
    "            \n",
    "\n",
    "            confounded_role += resource_access\n",
    "            \n",
    "\n",
    "            if collaboration == \"cooperative\":\n",
    "                confounded_role += 3\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "            for j in range(0, 10):\n",
    "                if random.random() < confounding_prob:\n",
    "                    confounded_role_id[j] = (confounded_role + random.randint(0, 5)) % 5\n",
    "                else:\n",
    "                    confounded_role_id[j] = 0\n",
    "    \n",
    "    return confounded_role_id\n",
    "\n",
    "def generate_confounded_features(G, original_role_id, confounders, confounding_prob):\n",
    "\n",
    "    features_dict = {}\n",
    "    \n",
    "\n",
    "    base_feature_params = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}\n",
    "    }\n",
    "    \n",
    "    for node_id in G.nodes():\n",
    "        if node_id < len(original_role_id):\n",
    "            original_role = original_role_id[node_id]\n",
    "            \n",
    "\n",
    "            social_status = confounders['social_status'][node_id]\n",
    "            time_period = confounders['time_period'][node_id]\n",
    "            domain_pref = confounders['domain_preference'][node_id]\n",
    "            network_influence = confounders['network_influence'][node_id]\n",
    "            resource_access = confounders['resource_access'][node_id]\n",
    "            collaboration = confounders['collaboration_tendency'][node_id]\n",
    "            \n",
    "\n",
    "            base_mean = base_feature_params[original_role]['mean']\n",
    "            base_std = base_feature_params[original_role]['std']\n",
    "            features = np.random.normal(base_mean, base_std)\n",
    "            \n",
    "\n",
    "            if random.random() < confounding_prob:\n",
    "\n",
    "                confounding_bias = np.zeros(5)\n",
    "                \n",
    "\n",
    "                status_effects = {\n",
    "                    \"elite\": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),\n",
    "                    \"high\": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),\n",
    "                    \"medium\": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),\n",
    "                    \"low\": np.array([-1.5, -2.0, 4.0, -1.0, 1.0])\n",
    "                }\n",
    "                confounding_bias += status_effects[social_status]\n",
    "                \n",
    "\n",
    "                time_effects = {\n",
    "                    \"ancient\": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),\n",
    "                    \"early\": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),\n",
    "                    \"middle\": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),\n",
    "                    \"late\": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),\n",
    "                    \"modern\": np.array([-2.0, 4.0, 2.0, -3.0, 4.5])\n",
    "                }\n",
    "                confounding_bias += time_effects[time_period]\n",
    "                \n",
    "\n",
    "                domain_bias = np.array([\n",
    "                    domain_pref * 0.8,\n",
    "                    (domain_pref - 3) * 1.0,\n",
    "                    (6 - domain_pref) * 1.2,\n",
    "                    domain_pref * 0.6,\n",
    "                    (domain_pref + 2) * 0.9\n",
    "                ])\n",
    "                confounding_bias += domain_bias\n",
    "                \n",
    "\n",
    "                influence_bias = np.array([\n",
    "                    network_influence * 1.5,\n",
    "                    network_influence * 1.2,\n",
    "                    network_influence * -0.8,\n",
    "                    network_influence * 2.0,\n",
    "                    network_influence * 1.0\n",
    "                ])\n",
    "                confounding_bias += influence_bias\n",
    "                \n",
    "\n",
    "                resource_bias = np.array([\n",
    "                    resource_access * 0.7,\n",
    "                    resource_access * -0.5,\n",
    "                    resource_access * 1.3,\n",
    "                    resource_access * 0.9,\n",
    "                    resource_access * -0.8\n",
    "                ])\n",
    "                confounding_bias += resource_bias\n",
    "                \n",
    "\n",
    "                if collaboration == \"cooperative\":\n",
    "                    collab_bias = np.array([2.5, 3.0, -1.5, 2.0, 3.5])\n",
    "                else:\n",
    "                    collab_bias = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])\n",
    "                confounding_bias += collab_bias\n",
    "                \n",
    "\n",
    "                features += confounding_bias * 3.5\n",
    "            \n",
    "            features_dict[node_id] = features\n",
    "            G.nodes[node_id]['features'] = features\n",
    "    \n",
    "    return features_dict\n",
    "\n",
    "def generate_confounding_dataset(confounding_prob, dataset_type='train', num_samples=2000):\n",
    "    \"\"\"Generate graphs whose roles and features are biased by index-based confounders.\n",
    "\n",
    "    For every sample a graph is drawn with generate_false_cause_dataset1,\n",
    "    confounders come from create_confounding_variables, roles are biased with\n",
    "    apply_confounding_bias, and features are generated from the biased roles.\n",
    "    The stored 'role_id' holds the biased roles, while 'ground_truth' is\n",
    "    computed from the original roles.\n",
    "\n",
    "    Args:\n",
    "        confounding_prob: Probability forwarded to the role-bias and feature\n",
    "            generators.\n",
    "        dataset_type: Split tag stored in the output (e.g. 'train').\n",
    "        num_samples: Number of graphs to attempt. A sample that raises is\n",
    "            skipped as a whole, so the result may hold fewer entries.\n",
    "\n",
    "    Returns:\n",
    "        dict with the aligned per-sample lists 'features', 'edge_index',\n",
    "        'label', 'ground_truth', 'role_id', 'pos' and 'confounding_info', plus\n",
    "        the scalars 'confounding_prob' and 'dataset_type'.\n",
    "    \"\"\"\n",
    "    edge_index_list, label_list = [], []\n",
    "    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "    confounding_info_list = []\n",
    "    num_skipped, last_error = 0, None\n",
    "\n",
    "    for _ in tqdm(range(num_samples)):\n",
    "        # Build every piece of the sample first; nothing is appended until all\n",
    "        # of them succeeded, so the output lists can never get out of step.\n",
    "        try:\n",
    "            G, original_role_id, label = generate_false_cause_dataset1()\n",
    "            confounders = create_confounding_variables(original_role_id, confounding_prob)\n",
    "            confounded_role_id = apply_confounding_bias(original_role_id, confounders, confounding_prob)\n",
    "            features_dict = generate_confounded_features(G, confounded_role_id, confounders, confounding_prob)\n",
    "\n",
    "            confounded_roles = np.array(confounded_role_id)\n",
    "\n",
    "            if G.number_of_edges() > 0:\n",
    "                edge_index = np.array(list(G.edges), dtype=int).T\n",
    "            else:\n",
    "                edge_index = np.array([[], []], dtype=int)\n",
    "\n",
    "            # The layout is auxiliary: fall back to an empty array if it fails.\n",
    "            try:\n",
    "                pos = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)\n",
    "                pos_array = np.array(list(pos.values()))\n",
    "            except Exception:\n",
    "                pos_array = np.array([])\n",
    "\n",
    "            # An edge counts as ground truth when both endpoints have an\n",
    "            # original role id greater than zero.\n",
    "            try:\n",
    "                if edge_index.size > 0:\n",
    "                    row, col = edge_index\n",
    "                    original_role_id_array = np.array(original_role_id)\n",
    "                    gd = np.array(original_role_id_array[row] > 0, dtype=np.float64) * np.array(original_role_id_array[col] > 0, dtype=np.float64)\n",
    "                else:\n",
    "                    gd = np.array([])\n",
    "            except Exception:\n",
    "                gd = np.array([])\n",
    "\n",
    "            # Nodes without generated features get a zero row.\n",
    "            if G.number_of_nodes() > 0:\n",
    "                feat_mat = np.vstack([\n",
    "                    features_dict.get(n, np.zeros(5))\n",
    "                    for n in sorted(G.nodes())\n",
    "                ])\n",
    "            else:\n",
    "                feat_mat = np.zeros((0, 5))\n",
    "\n",
    "            # Fraction of nodes whose role differs after the bias step.\n",
    "            confounding_strength = np.mean([\n",
    "                1 if original_role_id[j] != confounded_role_id[j] else 0\n",
    "                for j in range(len(original_role_id))\n",
    "            ])\n",
    "\n",
    "            info = {\n",
    "                'confounding_prob': confounding_prob,\n",
    "                'dataset_type': dataset_type,\n",
    "                'original_role_id': original_role_id,\n",
    "                # np.asarray makes this work for plain lists as well as arrays.\n",
    "                'confounded_role_id': np.asarray(confounded_role_id).tolist(),\n",
    "                'confounding_strength': confounding_strength,\n",
    "                'confounders': confounders,\n",
    "                'label': label\n",
    "            }\n",
    "        except Exception as exc:\n",
    "            num_skipped += 1\n",
    "            last_error = exc\n",
    "            continue\n",
    "\n",
    "        label_list.append(label)\n",
    "        role_id_list.append(confounded_roles)\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(pos_array)\n",
    "        ground_truth_list.append(gd)\n",
    "        features_list.append(feat_mat)\n",
    "        confounding_info_list.append(info)\n",
    "\n",
    "    if num_skipped:\n",
    "        print(f'generate_confounding_dataset: skipped {num_skipped}/{num_samples} samples; last error: {last_error!r}')\n",
    "\n",
    "    dataset_dict = {\n",
    "        'features': features_list,\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "        'confounding_info': confounding_info_list,\n",
    "        'confounding_prob': confounding_prob,\n",
    "        'dataset_type': dataset_type\n",
    "    }\n",
    "\n",
    "    return dataset_dict\n",
    "\n",
    "def generate_experiment_datasets(confounding_probs, base_dir='./data/paper/'):\n",
    "    \"\"\"Generate and save train/val/test confounded datasets for each probability.\n",
    "\n",
    "    Args:\n",
    "        confounding_probs: Iterable of confounding probabilities; each one\n",
    "            yields three files named '<split>_conf_<prob to 1 decimal>.npy'.\n",
    "        base_dir: Output directory; created if it does not exist.\n",
    "\n",
    "    Returns:\n",
    "        List of written file paths, grouped per probability in train, val,\n",
    "        test order.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    os.makedirs(base_dir, exist_ok=True)\n",
    "\n",
    "    # (split tag, number of graphs)\n",
    "    split_sizes = (('train', 1500), ('val', 250), ('test', 250))\n",
    "\n",
    "    generated_files = []\n",
    "    for prob in confounding_probs:\n",
    "        for split, num_samples in split_sizes:\n",
    "            dataset = generate_confounding_dataset(\n",
    "                confounding_prob=prob,\n",
    "                dataset_type=split,\n",
    "                num_samples=num_samples\n",
    "            )\n",
    "            # os.path.join avoids the doubled separator that plain string\n",
    "            # formatting produces when base_dir already ends with '/'.\n",
    "            out_file = os.path.join(base_dir, f'{split}_conf_{prob:.1f}.npy')\n",
    "            # The dataset is a dict, so np.save stores it as a pickled object;\n",
    "            # reading it back needs np.load(..., allow_pickle=True).\n",
    "            np.save(out_file, dataset)\n",
    "            generated_files.append(out_file)\n",
    "\n",
    "    return generated_files\n",
    "\n",
    "def run_confounding_experiment():\n",
    "    \"\"\"Generate the confounded datasets for probabilities 0.9 down to 0.1 and return their paths.\"\"\"\n",
    "    return generate_experiment_datasets(\n",
    "        [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]\n",
    "    )\n",
    "\n",
    "# Notebook cells execute with __name__ == \"__main__\", so running this cell\n",
    "# generates the confounded datasets and keeps the written paths.\n",
    "if __name__ == \"__main__\":\n",
    "    generated_files = run_confounding_experiment()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93a02c36",
   "metadata": {},
   "source": [
    "## Confounded and intervened datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e1deedf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/3000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3000/3000 [00:18<00:00, 163.45it/s]\n",
      "100%|██████████| 375/375 [00:02<00:00, 168.92it/s]\n",
      "100%|██████████| 375/375 [00:02<00:00, 154.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated files: ['./data/casual//train_casual_2_3.npy', './data/casual//val_casual_2_3.npy', './data/casual//test_casual_2_3.npy']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def create_confounding_variables(original_role_id, confounding_prob):\n",
    "\n",
    "    num_nodes = len(original_role_id)\n",
    "    confounders = {}\n",
    "    \n",
    "\n",
    "    social_status = []\n",
    "    for i in range(num_nodes):\n",
    "        if i < num_nodes // 4:\n",
    "            status = \"elite\"\n",
    "        elif i < num_nodes // 2:\n",
    "            status = \"high\"\n",
    "        elif i < 3 * num_nodes // 4:\n",
    "            status = \"medium\"\n",
    "        else:\n",
    "            status = \"low\"\n",
    "        social_status.append(status)\n",
    "    confounders['social_status'] = social_status\n",
    "    \n",
    "\n",
    "    time_period = []\n",
    "    for i in range(num_nodes):\n",
    "        period_idx = i % 5\n",
    "        periods = [\"ancient\", \"early\", \"middle\", \"late\", \"modern\"]\n",
    "        time_period.append(periods[period_idx])\n",
    "    confounders['time_period'] = time_period\n",
    "    \n",
    "\n",
    "    domain_preference = []\n",
    "    for i in range(num_nodes):\n",
    "        domain = i % 7\n",
    "        domain_preference.append(domain)\n",
    "    confounders['domain_preference'] = domain_preference\n",
    "    \n",
    "\n",
    "    network_influence = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        influence = int((np.sin(i * 0.5) + 1) * 2.5)\n",
    "        network_influence.append(influence)\n",
    "    confounders['network_influence'] = network_influence\n",
    "    \n",
    "\n",
    "    resource_access = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        access_level = (i * 3 + 7) % 6\n",
    "        resource_access.append(access_level)\n",
    "    confounders['resource_access'] = resource_access\n",
    "    \n",
    "\n",
    "    collaboration_tendency = []\n",
    "    for i in range(num_nodes):\n",
    "        if i % 2 == 0:\n",
    "            tendency = \"cooperative\"\n",
    "        else:\n",
    "            tendency = \"competitive\"\n",
    "        collaboration_tendency.append(tendency)\n",
    "    confounders['collaboration_tendency'] = collaboration_tendency\n",
    "    \n",
    "    return confounders\n",
    "\n",
    "def create_fixed_intervention_confounders(num_nodes):\n",
    "\n",
    "\n",
    "    fixed_values = {\n",
    "        'social_status': 'medium',\n",
    "        'time_period': 'middle', \n",
    "        'domain_preference': 3,\n",
    "        'network_influence': 2,\n",
    "        'resource_access': 2,\n",
    "        'collaboration_tendency': 'cooperative'\n",
    "    }\n",
    "    \n",
    "\n",
    "    confounders = {}\n",
    "    for key, value in fixed_values.items():\n",
    "        confounders[key] = [value] * num_nodes\n",
    "    \n",
    "    return confounders\n",
    "\n",
    "def apply_confounding_bias_gen_conf_style(original_role_id, confounders, confounding_prob):\n",
    "\n",
    "    confounded_role_id = original_role_id.copy()\n",
    "    \n",
    "    for i in range(len(original_role_id)):\n",
    "        if random.random() < confounding_prob:\n",
    "\n",
    "            social_status = confounders['social_status'][i]\n",
    "            time_period = confounders['time_period'][i]\n",
    "            domain_pref = confounders['domain_preference'][i]\n",
    "            network_influence = confounders['network_influence'][i]\n",
    "            resource_access = confounders['resource_access'][i]\n",
    "            collaboration = confounders['collaboration_tendency'][i]\n",
    "            \n",
    "\n",
    "            confounded_role = 0\n",
    "            \n",
    "\n",
    "            if social_status == \"elite\":\n",
    "                confounded_role += 4\n",
    "            elif social_status == \"high\":\n",
    "                confounded_role += 3\n",
    "            elif social_status == \"medium\":\n",
    "                confounded_role += 2\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "            time_weights = {\"ancient\": 0, \"early\": 1, \"middle\": 2, \"late\": 3, \"modern\": 4}\n",
    "            confounded_role += time_weights[time_period]\n",
    "            \n",
    "\n",
    "            confounded_role += domain_pref\n",
    "            \n",
    "\n",
    "            confounded_role += network_influence * 2\n",
    "            \n",
    "\n",
    "            confounded_role += resource_access\n",
    "            \n",
    "\n",
    "            if collaboration == \"cooperative\":\n",
    "                confounded_role += 3\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "            for j in range(0, 10):\n",
    "                if random.random() < confounding_prob:\n",
    "                    confounded_role_id[j] = (confounded_role + random.randint(0, 5)) % 5\n",
    "                else:\n",
    "                    confounded_role_id[j] = 0\n",
    "    \n",
    "    return confounded_role_id\n",
    "\n",
    "def apply_fixed_confounding_bias_gen_intervened_style(original_role_id, confounders, confounding_prob=0.9):\n",
    "\n",
    "    confounded_role_id = original_role_id.copy()\n",
    "    \n",
    "    for i in range(len(original_role_id)):\n",
    "        if random.random() < confounding_prob:\n",
    "\n",
    "            social_status = confounders['social_status'][i]\n",
    "            time_period = confounders['time_period'][i]\n",
    "            domain_pref = confounders['domain_preference'][i]\n",
    "            network_influence = confounders['network_influence'][i]\n",
    "            resource_access = confounders['resource_access'][i]\n",
    "            collaboration = confounders['collaboration_tendency'][i]\n",
    "            \n",
    "\n",
    "            confounded_role = 0\n",
    "            \n",
    "\n",
    "            if social_status == \"elite\":\n",
    "                confounded_role += 4\n",
    "            elif social_status == \"high\":\n",
    "                confounded_role += 3\n",
    "            elif social_status == \"medium\":\n",
    "                confounded_role += 2\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "            time_weights = {\"ancient\": 0, \"early\": 1, \"middle\": 2, \"late\": 3, \"modern\": 4}\n",
    "            confounded_role += time_weights[time_period]\n",
    "            \n",
    "\n",
    "            confounded_role += domain_pref\n",
    "            \n",
    "\n",
    "            confounded_role += network_influence * 2\n",
    "            \n",
    "\n",
    "            confounded_role += resource_access\n",
    "            \n",
    "\n",
    "            if collaboration == \"cooperative\":\n",
    "                confounded_role += 3\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "\n",
    "            confounded_role_id[i] = 1\n",
    "    \n",
    "    return confounded_role_id\n",
    "\n",
    "def generate_confounded_features_gen_conf_style(G, original_role_id, confounders, confounding_prob):\n",
    "\n",
    "    features_dict = {}\n",
    "    \n",
    "\n",
    "    base_feature_params = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}\n",
    "    }\n",
    "    \n",
    "    for node_id in G.nodes():\n",
    "        if node_id < len(original_role_id):\n",
    "            original_role = original_role_id[node_id]\n",
    "            \n",
    "\n",
    "            social_status = confounders['social_status'][node_id]\n",
    "            time_period = confounders['time_period'][node_id]\n",
    "            domain_pref = confounders['domain_preference'][node_id]\n",
    "            network_influence = confounders['network_influence'][node_id]\n",
    "            resource_access = confounders['resource_access'][node_id]\n",
    "            collaboration = confounders['collaboration_tendency'][node_id]\n",
    "            \n",
    "\n",
    "            base_mean = base_feature_params[original_role]['mean']\n",
    "            base_std = base_feature_params[original_role]['std']\n",
    "            features = np.random.normal(base_mean, base_std)\n",
    "            \n",
    "\n",
    "            if random.random() < confounding_prob:\n",
    "\n",
    "                confounding_bias = np.zeros(5)\n",
    "                \n",
    "\n",
    "                status_effects = {\n",
    "                    \"elite\": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),\n",
    "                    \"high\": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),\n",
    "                    \"medium\": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),\n",
    "                    \"low\": np.array([-1.5, -2.0, 4.0, -1.0, 1.0])\n",
    "                }\n",
    "                confounding_bias += status_effects[social_status]\n",
    "                \n",
    "\n",
    "                time_effects = {\n",
    "                    \"ancient\": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),\n",
    "                    \"early\": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),\n",
    "                    \"middle\": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),\n",
    "                    \"late\": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),\n",
    "                    \"modern\": np.array([-2.0, 4.0, 2.0, -3.0, 4.5])\n",
    "                }\n",
    "                confounding_bias += time_effects[time_period]\n",
    "                \n",
    "\n",
    "                domain_bias = np.array([\n",
    "                    domain_pref * 0.8,\n",
    "                    (domain_pref - 3) * 1.0,\n",
    "                    (6 - domain_pref) * 1.2,\n",
    "                    domain_pref * 0.6,\n",
    "                    (domain_pref + 2) * 0.9\n",
    "                ])\n",
    "                confounding_bias += domain_bias\n",
    "                \n",
    "\n",
    "                influence_bias = np.array([\n",
    "                    network_influence * 1.5,\n",
    "                    network_influence * 1.2,\n",
    "                    network_influence * -0.8,\n",
    "                    network_influence * 2.0,\n",
    "                    network_influence * 1.0\n",
    "                ])\n",
    "                confounding_bias += influence_bias\n",
    "                \n",
    "\n",
    "                resource_bias = np.array([\n",
    "                    resource_access * 0.7,\n",
    "                    resource_access * -0.5,\n",
    "                    resource_access * 1.3,\n",
    "                    resource_access * 0.9,\n",
    "                    resource_access * -0.8\n",
    "                ])\n",
    "                confounding_bias += resource_bias\n",
    "                \n",
    "\n",
    "                if collaboration == \"cooperative\":\n",
    "                    collab_bias = np.array([2.5, 3.0, -1.5, 2.0, 3.5])\n",
    "                else:\n",
    "                    collab_bias = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])\n",
    "                confounding_bias += collab_bias\n",
    "                \n",
    "\n",
    "                features += confounding_bias * 3.5\n",
    "            \n",
    "            features_dict[node_id] = features\n",
    "            G.nodes[node_id]['features'] = features\n",
    "    \n",
    "    return features_dict\n",
    "\n",
    "def generate_enhanced_confounded_features_gen_intervened_style(G, original_role_id, confounders, confounding_prob):\n",
    "\n",
    "    features_dict = {}\n",
    "    \n",
    "\n",
    "    base_feature_params = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}\n",
    "    }\n",
    "    \n",
    "    for node_id in G.nodes():\n",
    "        if node_id < len(original_role_id):\n",
    "            original_role = original_role_id[node_id]\n",
    "            \n",
    "\n",
    "            social_status = confounders['social_status'][node_id]\n",
    "            time_period = confounders['time_period'][node_id]\n",
    "            domain_pref = confounders['domain_preference'][node_id]\n",
    "            network_influence = confounders['network_influence'][node_id]\n",
    "            resource_access = confounders['resource_access'][node_id]\n",
    "            collaboration = confounders['collaboration_tendency'][node_id]\n",
    "            \n",
    "\n",
    "            base_mean = base_feature_params[original_role]['mean']\n",
    "            base_std = base_feature_params[original_role]['std']\n",
    "            features = np.random.normal(base_mean, base_std)\n",
    "            \n",
    "\n",
    "            if random.random() < confounding_prob:\n",
    "\n",
    "                confounding_bias = np.zeros(5)\n",
    "                \n",
    "\n",
    "                status_effects = {\n",
    "                    \"elite\": np.array([4.5, 6.0, 1.5, 3.0, 7.5]),\n",
    "                    \"high\": np.array([3.0, 4.0, 1.0, 2.0, 5.0]),\n",
    "                    \"medium\": np.array([1.5, 2.0, 2.5, 1.5, 2.5]),\n",
    "                    \"low\": np.array([-1.5, -2.0, 4.0, -1.0, 1.0])\n",
    "                }\n",
    "                confounding_bias += status_effects[social_status]\n",
    "                \n",
    "\n",
    "                time_effects = {\n",
    "                    \"ancient\": np.array([2.0, -1.0, 3.0, 4.0, -2.0]),\n",
    "                    \"early\": np.array([1.5, 2.5, -2.0, 3.0, 2.0]),\n",
    "                    \"middle\": np.array([-1.0, 3.5, -1.5, 1.0, 3.0]),\n",
    "                    \"late\": np.array([3.0, 1.0, -4.0, -2.0, 5.0]),\n",
    "                    \"modern\": np.array([-2.0, 4.0, 2.0, -3.0, 4.5])\n",
    "                }\n",
    "                confounding_bias += time_effects[time_period]\n",
    "                \n",
    "\n",
    "                domain_bias = np.array([\n",
    "                    domain_pref * 0.8,\n",
    "                    (domain_pref - 3) * 1.0,\n",
    "                    (6 - domain_pref) * 1.2,\n",
    "                    domain_pref * 0.6,\n",
    "                    (domain_pref + 2) * 0.9\n",
    "                ])\n",
    "                confounding_bias += domain_bias\n",
    "                \n",
    "\n",
    "                influence_bias = np.array([\n",
    "                    network_influence * 1.5,\n",
    "                    network_influence * 1.2,\n",
    "                    network_influence * -0.8,\n",
    "                    network_influence * 2.0,\n",
    "                    network_influence * 1.0\n",
    "                ])\n",
    "                confounding_bias += influence_bias\n",
    "                \n",
    "\n",
    "                resource_bias = np.array([\n",
    "                    resource_access * 0.7,\n",
    "                    resource_access * -0.5,\n",
    "                    resource_access * 1.3,\n",
    "                    resource_access * 0.9,\n",
    "                    resource_access * -0.8\n",
    "                ])\n",
    "                confounding_bias += resource_bias\n",
    "                \n",
    "\n",
    "                if collaboration == \"cooperative\":\n",
    "                    collab_bias = np.array([2.5, 3.0, -1.5, 2.0, 3.5])\n",
    "                else:\n",
    "                    collab_bias = np.array([-1.5, -2.0, 3.5, -1.0, -2.5])\n",
    "                confounding_bias += collab_bias\n",
    "                \n",
    "\n",
    "                features += confounding_bias * random.uniform(-20, 20) + random.uniform(-30, 30)\n",
    "            \n",
    "            features_dict[node_id] = features\n",
    "            G.nodes[node_id]['features'] = features\n",
    "    \n",
    "    return features_dict\n",
    "\n",
    "def generate_mixed_dataset(confounding_prob=0.7, intervention_prob=0.1, dataset_type='train', num_samples=2000):\n",
    "\n",
    "    edge_index_list, label_list = [], []\n",
    "    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "    confounding_info_list = []\n",
    "    e_mean, n_mean = [], []\n",
    "    \n",
    "    for i in tqdm(range(num_samples)):\n",
    "        try:\n",
    "\n",
    "            G, original_role_id, label = generate_false_cause_dataset1()\n",
    "            \n",
    "\n",
    "            if intervention_prob == 0.0:\n",
    "\n",
    "                confounders = create_confounding_variables(original_role_id, confounding_prob)\n",
    "                confounded_role_id = apply_confounding_bias_gen_conf_style(original_role_id, confounders, confounding_prob)\n",
    "                features_dict = generate_confounded_features_gen_conf_style(G, confounded_role_id, confounders, confounding_prob)\n",
    "                \n",
    "            elif intervention_prob == 1.0:\n",
    "\n",
    "                confounders = create_fixed_intervention_confounders(len(original_role_id))\n",
    "                confounded_role_id = apply_fixed_confounding_bias_gen_intervened_style(original_role_id, confounders, 0.9)\n",
    "                features_dict = generate_enhanced_confounded_features_gen_intervened_style(G, original_role_id, confounders, 0.9)\n",
    "                \n",
    "            else:\n",
    "\n",
    "\n",
    "                if random.random() < intervention_prob:\n",
    "\n",
    "                    confounders = create_fixed_intervention_confounders(len(original_role_id))\n",
    "                    confounded_role_id = apply_fixed_confounding_bias_gen_intervened_style(original_role_id, confounders, 0.9)\n",
    "                    features_dict = generate_enhanced_confounded_features_gen_intervened_style(G, original_role_id, confounders, 0.9)\n",
    "                else:\n",
    "\n",
    "                    confounders = create_confounding_variables(original_role_id, confounding_prob)\n",
    "                    confounded_role_id = apply_confounding_bias_gen_conf_style(original_role_id, confounders, confounding_prob)\n",
    "                    features_dict = generate_confounded_features_gen_conf_style(G, confounded_role_id, confounders, confounding_prob)\n",
    "            \n",
    "            label_list.append(label)\n",
    "            e_mean.append(len(G.edges))\n",
    "            n_mean.append(len(G.nodes))\n",
    "            \n",
    "\n",
    "            role_id_list.append(np.array(confounded_role_id))\n",
    "            \n",
    "            if G.number_of_edges() > 0:\n",
    "                edge_index = np.array(list(G.edges), dtype=int).T\n",
    "            else:\n",
    "                edge_index = np.array([[], []], dtype=int)\n",
    "            \n",
    "            edge_index_list.append(edge_index)\n",
    "            \n",
    "            try:\n",
    "                pos = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)\n",
    "                pos_list.append(np.array(list(pos.values())))\n",
    "            except:\n",
    "                pos_list.append(np.array([]))\n",
    "            \n",
    "            try:\n",
    "                if edge_index.size > 0:\n",
    "                    row, col = edge_index\n",
    "                    original_role_id_array = np.array(original_role_id)\n",
    "                    gd = np.array(original_role_id_array[row] > 0, dtype=np.float64) * np.array(original_role_id_array[col] > 0, dtype=np.float64)\n",
    "                else:\n",
    "                    gd = np.array([])\n",
    "                ground_truth_list.append(gd)\n",
    "            except:\n",
    "                ground_truth_list.append(np.array([]))\n",
    "            \n",
    "\n",
    "            if G.number_of_nodes() > 0:\n",
    "                feat_mat = np.vstack([\n",
    "                    features_dict.get(n, np.zeros(5))\n",
    "                    for n in sorted(G.nodes())\n",
    "                ])\n",
    "                features_list.append(feat_mat)\n",
    "            else:\n",
    "                features_list.append(np.zeros((0, 5)))\n",
    "            \n",
    "\n",
    "            confounding_strength = np.mean([\n",
    "                1 if original_role_id[j] != confounded_role_id[j] else 0\n",
    "                for j in range(len(original_role_id))\n",
    "            ])\n",
    "            \n",
    "            confounding_info_list.append({\n",
    "                'confounding_prob': confounding_prob,\n",
    "                'intervention_prob': intervention_prob,\n",
    "                'dataset_type': dataset_type,\n",
    "                'original_role_id': original_role_id,\n",
    "                'confounded_role_id': confounded_role_id.tolist(),\n",
    "                'confounding_strength': confounding_strength,\n",
    "                'confounders': confounders,\n",
    "                'label': label\n",
    "            })\n",
    "            \n",
    "        except Exception as e:\n",
    "            continue\n",
    "    \n",
    "    dataset_dict = {\n",
    "        'features': features_list,\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "        'confounding_info': confounding_info_list,\n",
    "        'confounding_prob': confounding_prob,\n",
    "        'intervention_prob': intervention_prob,\n",
    "        'dataset_type': dataset_type\n",
    "    }\n",
    "    \n",
    "    return dataset_dict\n",
    "\n",
    "def generate_mixed_experiment_datasets(intervention_prob=0.1, base_dir='./data/int/'):\n",
    "\n",
    "    import os\n",
    "    os.makedirs(base_dir, exist_ok=True)\n",
    "    \n",
    "    generated_files = []\n",
    "    confounding_prob = 0.7\n",
    "    \n",
    "\n",
    "    train_dataset = generate_mixed_dataset(\n",
    "        confounding_prob=confounding_prob,\n",
    "        intervention_prob=intervention_prob,\n",
    "        dataset_type='train',\n",
    "        num_samples=1500\n",
    "    )\n",
    "    train_file = f'{base_dir}/train_int_{intervention_prob:.1f}.npy'\n",
    "    np.save(train_file, train_dataset)\n",
    "    generated_files.append(train_file)\n",
    "    \n",
    "\n",
    "    val_dataset = generate_mixed_dataset(\n",
    "        confounding_prob=confounding_prob,\n",
    "        intervention_prob=intervention_prob,\n",
    "        dataset_type='val',\n",
    "        num_samples=200\n",
    "    )\n",
    "    val_file = f'{base_dir}/val_int_{intervention_prob:.1f}.npy'\n",
    "    np.save(val_file, val_dataset)\n",
    "    generated_files.append(val_file)\n",
    "    \n",
    "\n",
    "    test_dataset = generate_mixed_dataset(\n",
    "        confounding_prob=confounding_prob,\n",
    "        intervention_prob=intervention_prob,\n",
    "        dataset_type='test',\n",
    "        num_samples=200\n",
    "    )\n",
    "    test_file = f'{base_dir}/test_int_{intervention_prob:.1f}.npy'\n",
    "    np.save(test_file, test_dataset)\n",
    "    generated_files.append(test_file)\n",
    "    \n",
    "    return generated_files\n",
    "\n",
    "def run_mixed_experiment(intervention_prob=0.1):\n",
    "\n",
    "    generated_files = generate_mixed_experiment_datasets(intervention_prob)\n",
    "    return generated_files\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "\n",
    "    generated_files = run_mixed_experiment(intervention_prob=0.2)\n",
    "    print(f\"Generated files: {generated_files}\")\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dc24389",
   "metadata": {},
   "source": [
    "## struc conf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a081ff26",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3000/3000 [00:29<00:00, 100.22it/s]\n",
      "100%|██████████| 375/375 [00:03<00:00, 102.59it/s]\n",
      "100%|██████████| 375/375 [00:03<00:00, 98.48it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated files: ['./data/casual//train_casual_1_4.npy', './data/casual//val_casual_1_4.npy', './data/casual//test_casual_1_4.npy']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def create_complex_confounding_variables(original_role_id, G, confounding_prob):\n",
    "\n",
    "    num_nodes = len(original_role_id)\n",
    "    confounders = {}\n",
    "    \n",
    "\n",
    "    social_status = []\n",
    "    degree_centrality = nx.degree_centrality(G) if G.number_of_nodes() > 0 else {}\n",
    "    for i in range(num_nodes):\n",
    "        centrality = degree_centrality.get(i, 0)\n",
    "        if centrality > 0.8:\n",
    "            status = \"elite\"\n",
    "        elif centrality > 0.6:\n",
    "            status = \"high\"\n",
    "        elif centrality > 0.3:\n",
    "            status = \"medium\"\n",
    "        else:\n",
    "            status = \"low\"\n",
    "        social_status.append(status)\n",
    "    confounders['social_status'] = social_status\n",
    "    \n",
    "\n",
    "    citation_strength = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        in_degree = G.in_degree(i) if G.is_directed() else G.degree(i)\n",
    "        out_degree = G.out_degree(i) if G.is_directed() else G.degree(i)\n",
    "\n",
    "        strength = (in_degree * 2 + out_degree) / max(1, num_nodes * 0.1)\n",
    "        citation_strength.append(min(strength, 10))\n",
    "    confounders['citation_strength'] = citation_strength\n",
    "    \n",
    "\n",
    "    clustering_info = []\n",
    "    clustering_coeff = nx.clustering(G.to_undirected()) if G.number_of_nodes() > 0 else {}\n",
    "    for i in range(num_nodes):\n",
    "        coeff = clustering_coeff.get(i, 0)\n",
    "        if coeff > 0.7:\n",
    "            cluster_level = \"high_cluster\"\n",
    "        elif coeff > 0.4:\n",
    "            cluster_level = \"medium_cluster\"\n",
    "        else:\n",
    "            cluster_level = \"low_cluster\"\n",
    "        clustering_info.append(cluster_level)\n",
    "    confounders['clustering_info'] = clustering_info\n",
    "    \n",
    "\n",
    "    neighbor_influence = []\n",
    "    for i in range(num_nodes):\n",
    "        neighbors = list(G.neighbors(i))\n",
    "        if neighbors:\n",
    "\n",
    "            neighbor_degrees = [G.degree(n) for n in neighbors]\n",
    "            avg_neighbor_influence = np.mean(neighbor_degrees)\n",
    "            influence_level = int(avg_neighbor_influence / max(1, np.mean([G.degree(n) for n in G.nodes()])) * 5)\n",
    "        else:\n",
    "            influence_level = 0\n",
    "        neighbor_influence.append(min(influence_level, 9))\n",
    "    confounders['neighbor_influence'] = neighbor_influence\n",
    "    \n",
    "\n",
    "    path_centrality = []\n",
    "    try:\n",
    "        if G.number_of_nodes() > 1 and nx.is_connected(G.to_undirected()):\n",
    "            betweenness = nx.betweenness_centrality(G.to_undirected())\n",
    "        else:\n",
    "            betweenness = {i: 0 for i in range(num_nodes)}\n",
    "    except:\n",
    "        betweenness = {i: 0 for i in range(num_nodes)}\n",
    "    \n",
    "    for i in range(num_nodes):\n",
    "        centrality = betweenness.get(i, 0)\n",
    "        if centrality > 0.1:\n",
    "            path_level = \"high_path\"\n",
    "        elif centrality > 0.05:\n",
    "            path_level = \"medium_path\"\n",
    "        else:\n",
    "            path_level = \"low_path\"\n",
    "        path_centrality.append(path_level)\n",
    "    confounders['path_centrality'] = path_centrality\n",
    "    \n",
    "\n",
    "    temporal_correlation = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        time_factor = (i % 10) / 10.0\n",
    "        neighbor_time_influence = 0\n",
    "        neighbors = list(G.neighbors(i))\n",
    "        if neighbors:\n",
    "            neighbor_time_influence = np.mean([(n % 10) / 10.0 for n in neighbors])\n",
    "        \n",
    "        combined_temporal = (time_factor + neighbor_time_influence) / 2\n",
    "        if combined_temporal > 0.7:\n",
    "            temporal_level = \"future_oriented\"\n",
    "        elif combined_temporal > 0.3:\n",
    "            temporal_level = \"present_focused\"\n",
    "        else:\n",
    "            temporal_level = \"past_oriented\"\n",
    "        temporal_correlation.append(temporal_level)\n",
    "    confounders['temporal_correlation'] = temporal_correlation\n",
    "    \n",
    "\n",
    "    multilayer_interaction = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        degree_norm = G.degree(i) / max(1, max([G.degree(n) for n in G.nodes()]))\n",
    "        clustering_norm = clustering_coeff.get(i, 0)\n",
    "        centrality_norm = degree_centrality.get(i, 0)\n",
    "        \n",
    "        interaction_score = (degree_norm + clustering_norm + centrality_norm) / 3\n",
    "        if interaction_score > 0.6:\n",
    "            interaction_type = \"strong_interaction\"\n",
    "        elif interaction_score > 0.3:\n",
    "            interaction_type = \"moderate_interaction\"\n",
    "        else:\n",
    "            interaction_type = \"weak_interaction\"\n",
    "        multilayer_interaction.append(interaction_type)\n",
    "    confounders['multilayer_interaction'] = multilayer_interaction\n",
    "    \n",
    "\n",
    "    structural_equivalence = []\n",
    "    for i in range(num_nodes):\n",
    "\n",
    "        node_neighbors = set(G.neighbors(i))\n",
    "        max_similarity = 0\n",
    "        for j in range(num_nodes):\n",
    "            if i != j:\n",
    "                other_neighbors = set(G.neighbors(j))\n",
    "                if len(node_neighbors) > 0 or len(other_neighbors) > 0:\n",
    "                    jaccard_sim = len(node_neighbors.intersection(other_neighbors)) / max(1, len(node_neighbors.union(other_neighbors)))\n",
    "                    max_similarity = max(max_similarity, jaccard_sim)\n",
    "        \n",
    "        if max_similarity > 0.5:\n",
    "            equiv_level = \"high_equivalence\"\n",
    "        elif max_similarity > 0.2:\n",
    "            equiv_level = \"medium_equivalence\"\n",
    "        else:\n",
    "            equiv_level = \"low_equivalence\"\n",
    "        structural_equivalence.append(equiv_level)\n",
    "    confounders['structural_equivalence'] = structural_equivalence\n",
    "    \n",
    "    return confounders\n",
    "\n",
    "def apply_complex_confounding_bias(original_role_id, G, confounders, confounding_prob):\n",
    "\n",
    "    confounded_role_id = original_role_id.copy()\n",
    "    \n",
    "    for i in range(len(original_role_id)):\n",
    "        if random.random() < confounding_prob:\n",
    "\n",
    "            social_status = confounders['social_status'][i]\n",
    "            citation_strength = confounders['citation_strength'][i]\n",
    "            clustering_info = confounders['clustering_info'][i]\n",
    "            neighbor_influence = confounders['neighbor_influence'][i]\n",
    "            path_centrality = confounders['path_centrality'][i]\n",
    "            temporal_correlation = confounders['temporal_correlation'][i]\n",
    "            multilayer_interaction = confounders['multilayer_interaction'][i]\n",
    "            structural_equivalence = confounders['structural_equivalence'][i]\n",
    "            \n",
    "\n",
    "            confounded_role = 0\n",
    "            \n",
    "\n",
    "            if social_status == \"elite\":\n",
    "                confounded_role += 5\n",
    "            elif social_status == \"high\":\n",
    "                confounded_role += 4\n",
    "            elif social_status == \"medium\":\n",
    "                confounded_role += 2\n",
    "            else:\n",
    "                confounded_role += 1\n",
    "            \n",
    "\n",
    "            confounded_role += int(citation_strength * 0.8)\n",
    "            \n",
    "\n",
    "            clustering_weights = {\"high_cluster\": 4, \"medium_cluster\": 2, \"low_cluster\": 1}\n",
    "            confounded_role += clustering_weights[clustering_info]\n",
    "            \n",
    "\n",
    "            confounded_role += neighbor_influence\n",
    "            \n",
    "\n",
    "            path_weights = {\"high_path\": 5, \"medium_path\": 3, \"low_path\": 1}\n",
    "            confounded_role += path_weights[path_centrality]\n",
    "            \n",
    "\n",
    "            temporal_weights = {\"future_oriented\": 4, \"present_focused\": 2, \"past_oriented\": 1}\n",
    "            confounded_role += temporal_weights[temporal_correlation]\n",
    "            \n",
    "\n",
    "            interaction_weights = {\"strong_interaction\": 6, \"moderate_interaction\": 3, \"weak_interaction\": 1}\n",
    "            confounded_role += interaction_weights[multilayer_interaction]\n",
    "            \n",
    "\n",
    "            equiv_weights = {\"high_equivalence\": 3, \"medium_equivalence\": 2, \"low_equivalence\": 1}\n",
    "            confounded_role += equiv_weights[structural_equivalence]\n",
    "            \n",
    "\n",
    "            node_degree = G.degree(i) if i in G.nodes() else 0\n",
    "            degree_influence = min(node_degree, 5)\n",
    "            confounded_role += degree_influence\n",
    "            \n",
    "\n",
    "            final_role = confounded_role % 5\n",
    "            \n",
    "            confounded_role_id[i] = final_role\n",
    "    \n",
    "    return confounded_role_id\n",
    "\n",
    "def generate_complex_confounded_features(G, original_role_id, confounders, confounding_prob):\n",
    "\n",
    "    features_dict = {}\n",
    "    \n",
    "\n",
    "    base_feature_params = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}\n",
    "    }\n",
    "    \n",
    "    for node_id in G.nodes():\n",
    "        if node_id < len(original_role_id):\n",
    "            original_role = original_role_id[node_id]\n",
    "            \n",
    "\n",
    "            social_status = confounders['social_status'][node_id]\n",
    "            citation_strength = confounders['citation_strength'][node_id]\n",
    "            clustering_info = confounders['clustering_info'][node_id]\n",
    "            neighbor_influence = confounders['neighbor_influence'][node_id]\n",
    "            path_centrality = confounders['path_centrality'][node_id]\n",
    "            temporal_correlation = confounders['temporal_correlation'][node_id]\n",
    "            multilayer_interaction = confounders['multilayer_interaction'][node_id]\n",
    "            structural_equivalence = confounders['structural_equivalence'][node_id]\n",
    "            \n",
    "\n",
    "            base_mean = base_feature_params[original_role]['mean']\n",
    "            base_std = base_feature_params[original_role]['std']\n",
    "            features = np.random.normal(base_mean, base_std)\n",
    "            \n",
    "\n",
    "            if random.random() < confounding_prob:\n",
    "\n",
    "                confounding_bias = np.zeros(5)\n",
    "                \n",
    "\n",
    "                status_effects = {\n",
    "                    \"elite\": np.array([6.0, 8.0, 2.0, 4.0, 9.0]),\n",
    "                    \"high\": np.array([4.0, 5.0, 1.5, 2.5, 6.0]),\n",
    "                    \"medium\": np.array([2.0, 2.5, 3.0, 2.0, 3.0]),\n",
    "                    \"low\": np.array([-1.0, -1.5, 5.0, -0.5, 1.5])\n",
    "                }\n",
    "                confounding_bias += status_effects[social_status]\n",
    "                \n",
    "\n",
    "                citation_bias = np.array([\n",
    "                    citation_strength * 0.8,\n",
    "                    citation_strength * 0.6,\n",
    "                    citation_strength * -0.4,\n",
    "                    citation_strength * 1.0,\n",
    "                    citation_strength * 0.7\n",
    "                ])\n",
    "                confounding_bias += citation_bias\n",
    "                \n",
    "\n",
    "                clustering_effects = {\n",
    "                    \"high_cluster\": np.array([3.5, 4.0, -2.0, 3.0, 4.5]),\n",
    "                    \"medium_cluster\": np.array([1.5, 2.0, 0.5, 1.5, 2.0]),\n",
    "                    \"low_cluster\": np.array([-1.0, -1.0, 3.0, -0.5, -1.0])\n",
    "                }\n",
    "                confounding_bias += clustering_effects[clustering_info]\n",
    "                \n",
    "\n",
    "                neighbor_bias = np.array([\n",
    "                    neighbor_influence * 0.5,\n",
    "                    neighbor_influence * 0.7,\n",
    "                    neighbor_influence * -0.3,\n",
    "                    neighbor_influence * 0.8,\n",
    "                    neighbor_influence * 0.6\n",
    "                ])\n",
    "                confounding_bias += neighbor_bias\n",
    "                \n",
    "\n",
    "                path_effects = {\n",
    "                    \"high_path\": np.array([4.0, 5.0, -1.5, 4.5, 5.5]),\n",
    "                    \"medium_path\": np.array([2.0, 2.5, 0.5, 2.0, 2.5]),\n",
    "                    \"low_path\": np.array([0.5, 0.5, 1.5, 0.5, 0.5])\n",
    "                }\n",
    "                confounding_bias += path_effects[path_centrality]\n",
    "                \n",
    "\n",
    "                temporal_effects = {\n",
    "                    \"future_oriented\": np.array([3.0, 4.0, -1.0, 3.5, 4.0]),\n",
    "                    \"present_focused\": np.array([1.0, 1.5, 1.0, 1.5, 1.5]),\n",
    "                    \"past_oriented\": np.array([-1.0, -1.5, 3.0, -1.0, -1.5])\n",
    "                }\n",
    "                confounding_bias += temporal_effects[temporal_correlation]\n",
    "                \n",
    "\n",
    "                interaction_effects = {\n",
    "                    \"strong_interaction\": np.array([5.0, 6.0, -2.0, 5.5, 6.5]),\n",
    "                    \"moderate_interaction\": np.array([2.5, 3.0, 0.5, 2.5, 3.0]),\n",
    "                    \"weak_interaction\": np.array([0.5, 0.5, 2.0, 0.5, 0.5])\n",
    "                }\n",
    "                confounding_bias += interaction_effects[multilayer_interaction]\n",
    "                \n",
    "\n",
    "                equiv_effects = {\n",
    "                    \"high_equivalence\": np.array([2.5, 3.0, -1.0, 2.5, 3.0]),\n",
    "                    \"medium_equivalence\": np.array([1.0, 1.5, 0.5, 1.0, 1.5]),\n",
    "                    \"low_equivalence\": np.array([0.0, 0.0, 1.0, 0.0, 0.0])\n",
    "                }\n",
    "                confounding_bias += equiv_effects[structural_equivalence]\n",
    "                \n",
    "\n",
    "                node_degree = G.degree(node_id) if node_id in G.nodes() else 0\n",
    "                degree_bias = np.array([\n",
    "                    node_degree * 0.3,\n",
    "                    node_degree * 0.4,\n",
    "                    node_degree * -0.2,\n",
    "                    node_degree * 0.5,\n",
    "                    node_degree * 0.3\n",
    "                ])\n",
    "                confounding_bias += degree_bias\n",
    "                \n",
    "\n",
    "                features += confounding_bias * 4.0\n",
    "            \n",
    "            features_dict[node_id] = features\n",
    "            G.nodes[node_id]['features'] = features\n",
    "    \n",
    "    return features_dict\n",
    "\n",
    "def generate_complex_confounding_dataset(confounding_prob, dataset_type='train', num_samples=2000):\n",
    "    \"\"\"Generate a dataset of graphs whose roles and features are confounded.\n",
    "\n",
    "    Every sample is produced by ``generate_false_cause_dataset1`` and then passed\n",
    "    through ``create_complex_confounding_variables``,\n",
    "    ``apply_complex_confounding_bias`` and ``generate_complex_confounded_features``.\n",
    "    A sample is committed to the output lists only after all of its parts were\n",
    "    built, so the per-sample lists always stay aligned index by index.\n",
    "\n",
    "    Args:\n",
    "        confounding_prob: Probability forwarded to the role and feature\n",
    "            confounding helpers.\n",
    "        dataset_type: Tag stored in the output (e.g. 'train', 'val', 'test').\n",
    "        num_samples: Number of graphs to attempt; failed samples are skipped.\n",
    "\n",
    "    Returns:\n",
    "        dict with the per-sample lists 'features', 'edge_index', 'label',\n",
    "        'ground_truth', 'role_id', 'pos' and 'confounding_info', plus the\n",
    "        scalars 'confounding_prob' and 'dataset_type'.\n",
    "    \"\"\"\n",
    "    edge_index_list, label_list = [], []\n",
    "    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "    confounding_info_list = []\n",
    "    skipped_samples, last_error = 0, None\n",
    "\n",
    "    for _ in tqdm(range(num_samples)):\n",
    "        try:\n",
    "            G, original_role_id, label = generate_false_cause_dataset1()\n",
    "            confounders = create_complex_confounding_variables(original_role_id, G, confounding_prob)\n",
    "            confounded_role_id = apply_complex_confounding_bias(original_role_id, G, confounders, confounding_prob)\n",
    "            features_dict = generate_complex_confounded_features(G, confounded_role_id, confounders, confounding_prob)\n",
    "\n",
    "            # Edges are stored transposed (one row of sources, one of targets).\n",
    "            if G.number_of_edges() > 0:\n",
    "                edge_index = np.array(list(G.edges), dtype=int).T\n",
    "            else:\n",
    "                edge_index = np.array([[], []], dtype=int)\n",
    "\n",
    "            # The layout is best effort: an empty array is stored if it fails.\n",
    "            try:\n",
    "                pos = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)\n",
    "                pos_array = np.array(list(pos.values()))\n",
    "            except Exception:\n",
    "                pos_array = np.array([])\n",
    "\n",
    "            # An edge is marked 1.0 when both endpoints have an original role > 0.\n",
    "            try:\n",
    "                if edge_index.size > 0:\n",
    "                    row, col = edge_index\n",
    "                    original_roles = np.array(original_role_id)\n",
    "                    ground_truth = np.array(original_roles[row] > 0, dtype=np.float64) * np.array(original_roles[col] > 0, dtype=np.float64)\n",
    "                else:\n",
    "                    ground_truth = np.array([])\n",
    "            except Exception:\n",
    "                ground_truth = np.array([])\n",
    "\n",
    "            # One feature row per node in sorted node order; nodes without\n",
    "            # generated features fall back to a zero vector.\n",
    "            has_nodes = G.number_of_nodes() > 0\n",
    "            if has_nodes:\n",
    "                feat_mat = np.vstack([\n",
    "                    features_dict.get(n, np.zeros(5))\n",
    "                    for n in sorted(G.nodes())\n",
    "                ])\n",
    "            else:\n",
    "                feat_mat = np.zeros((0, 5))\n",
    "\n",
    "            # Fraction of nodes whose role was changed by the confounding step.\n",
    "            confounding_strength = np.mean([\n",
    "                1 if original_role_id[j] != confounded_role_id[j] else 0\n",
    "                for j in range(len(original_role_id))\n",
    "            ])\n",
    "\n",
    "            network_complexity = {\n",
    "                'avg_degree': np.mean([G.degree(n) for n in G.nodes()]) if has_nodes else 0,\n",
    "                'clustering_coefficient': nx.average_clustering(G.to_undirected()) if has_nodes else 0,\n",
    "                'density': nx.density(G) if has_nodes else 0\n",
    "            }\n",
    "\n",
    "            sample_info = {\n",
    "                'confounding_prob': confounding_prob,\n",
    "                'dataset_type': dataset_type,\n",
    "                'original_role_id': original_role_id,\n",
    "                # np.asarray makes this work whether the roles are a list or an array.\n",
    "                'confounded_role_id': np.asarray(confounded_role_id).tolist(),\n",
    "                'confounding_strength': confounding_strength,\n",
    "                'confounders': confounders,\n",
    "                'network_complexity': network_complexity,\n",
    "                'label': label\n",
    "            }\n",
    "        except Exception as exc:\n",
    "            # Keep the best-effort behaviour (skip the sample) but remember the failure.\n",
    "            skipped_samples += 1\n",
    "            last_error = exc\n",
    "            continue\n",
    "\n",
    "        # Commit the finished sample to every list at once.\n",
    "        label_list.append(label)\n",
    "        role_id_list.append(np.array(confounded_role_id))\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(pos_array)\n",
    "        ground_truth_list.append(ground_truth)\n",
    "        features_list.append(feat_mat)\n",
    "        confounding_info_list.append(sample_info)\n",
    "\n",
    "    if skipped_samples:\n",
    "        print(f'[{dataset_type}] skipped {skipped_samples}/{num_samples} samples; last error: {last_error!r}')\n",
    "\n",
    "    dataset_dict = {\n",
    "        'features': features_list,\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "        'confounding_info': confounding_info_list,\n",
    "        'confounding_prob': confounding_prob,\n",
    "        'dataset_type': dataset_type\n",
    "    }\n",
    "\n",
    "    return dataset_dict\n",
    "\n",
    "def generate_complex_experiment_datasets(confounding_prob=0.7, base_dir='./data/casual/'):\n",
    "    \"\"\"Generate and save the train, validation and test splits.\n",
    "\n",
    "    The train and validation splits are generated with a confounding\n",
    "    probability of 0; only the test split uses ``confounding_prob``.\n",
    "\n",
    "    Args:\n",
    "        confounding_prob: Confounding probability applied to the test split.\n",
    "        base_dir: Output directory; it is created when missing.\n",
    "\n",
    "    Returns:\n",
    "        List with the saved file paths in train, val, test order.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    os.makedirs(base_dir, exist_ok=True)\n",
    "\n",
    "    # (split name, confounding probability, number of graphs to attempt)\n",
    "    split_specs = [\n",
    "        ('train', 0, 3000),\n",
    "        ('val', 0, 375),\n",
    "        ('test', confounding_prob, 375),\n",
    "    ]\n",
    "\n",
    "    generated_files = []\n",
    "    for split_name, split_prob, split_size in split_specs:\n",
    "        split_dataset = generate_complex_confounding_dataset(\n",
    "            confounding_prob=split_prob,\n",
    "            dataset_type=split_name,\n",
    "            num_samples=split_size\n",
    "        )\n",
    "        # os.path.join avoids the doubled separator that f'{base_dir}/...'\n",
    "        # produced when base_dir already ends with '/'.\n",
    "        split_file = os.path.join(base_dir, f'{split_name}_casual_1_4.npy')\n",
    "        # The dataset is a dict, so np.save has to pickle it; reading it back\n",
    "        # needs np.load(..., allow_pickle=True).\n",
    "        np.save(split_file, split_dataset)\n",
    "        generated_files.append(split_file)\n",
    "\n",
    "    return generated_files\n",
    "\n",
    "def run_complex_confounding_experiment():\n",
    "    \"\"\"Generate all experiment splits with a confounding probability of 0.7.\n",
    "\n",
    "    Returns:\n",
    "        The list of file paths returned by ``generate_complex_experiment_datasets``.\n",
    "    \"\"\"\n",
    "    return generate_complex_experiment_datasets(0.7)\n",
    "\n",
    "# Entry point: generate every split and print the paths that were written.\n",
    "if __name__ == \"__main__\":\n",
    "    generated_files = run_complex_confounding_experiment()\n",
    "    print(f\"Generated files: {generated_files}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10a4f302",
   "metadata": {},
   "source": [
    "## Structural confounders and intervened dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3064d346",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/3000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3000/3000 [00:20<00:00, 147.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated files: ['./data/casual//train_casual_2_4.npy']\n"
     ]
    }
   ],
   "source": [
    "from BA3_loc import *\n",
    "from tqdm import tqdm\n",
    "import os.path as osp\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import random\n",
    "import math\n",
    "import torch\n",
    "import copy\n",
    "\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import gompertz\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from scipy.stats import weibull_min\n",
    "from scipy.special import gamma, gammaincinv\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from collections import deque\n",
    "from paper import *\n",
    "\n",
    "def create_complex_confounding_variables(original_role_id, G, confounding_prob):\n",
    "    \"\"\"Derive eight per-node confounders from the structure of ``G``.\n",
    "\n",
    "    Nodes are looked up by the integers ``0 .. len(original_role_id) - 1``.\n",
    "\n",
    "    Args:\n",
    "        original_role_id: Sequence of node roles; only its length is used.\n",
    "        G: networkx graph the confounders are computed from.\n",
    "        confounding_prob: Not used here; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each confounder name to a list with one entry per node:\n",
    "        'social_status' (bucketed degree centrality),\n",
    "        'citation_strength' ((2 * in_degree + out_degree) / max(1, 0.1 * n), capped at 10),\n",
    "        'clustering_info' (bucketed clustering coefficient),\n",
    "        'neighbor_influence' (int(5 * mean neighbour degree / max(1, mean degree)), capped at 9),\n",
    "        'path_centrality' (bucketed betweenness, 0 for disconnected graphs),\n",
    "        'temporal_correlation' (bucketed mean of the node's and its neighbours' index phase),\n",
    "        'multilayer_interaction' (bucketed mean of degree, clustering and centrality),\n",
    "        'structural_equivalence' (bucketed best Jaccard overlap of neighbour sets).\n",
    "    \"\"\"\n",
    "    def bucket(value, cutoffs, labels):\n",
    "        # Label of the first cutoff that value strictly exceeds, else the last label.\n",
    "        for cutoff, label in zip(cutoffs, labels):\n",
    "            if value > cutoff:\n",
    "                return label\n",
    "        return labels[-1]\n",
    "\n",
    "    num_nodes = len(original_role_id)\n",
    "    node_ids = range(num_nodes)\n",
    "    confounders = {}\n",
    "\n",
    "    # Graph-wide quantities are computed once here instead of once per node.\n",
    "    has_nodes = G.number_of_nodes() > 0\n",
    "    is_directed = G.is_directed()\n",
    "    undirected = G.to_undirected()\n",
    "    all_degrees = [G.degree(n) for n in G.nodes()]\n",
    "    mean_degree_scale = max(1, np.mean(all_degrees)) if all_degrees else 1\n",
    "    max_degree_scale = max(1, max(all_degrees)) if all_degrees else 1\n",
    "    degree_centrality = nx.degree_centrality(G) if has_nodes else {}\n",
    "    clustering_coeff = nx.clustering(undirected) if has_nodes else {}\n",
    "\n",
    "    # 1. Social status from degree centrality.\n",
    "    confounders['social_status'] = [\n",
    "        bucket(degree_centrality.get(i, 0), (0.8, 0.6, 0.3), ('elite', 'high', 'medium', 'low'))\n",
    "        for i in node_ids\n",
    "    ]\n",
    "\n",
    "    # 2. Citation strength: incoming links count double; undirected graphs use\n",
    "    #    the plain degree for both directions.\n",
    "    citation_scale = max(1, num_nodes * 0.1)\n",
    "    citation_strength = []\n",
    "    for i in node_ids:\n",
    "        in_degree = G.in_degree(i) if is_directed else G.degree(i)\n",
    "        out_degree = G.out_degree(i) if is_directed else G.degree(i)\n",
    "        strength = (in_degree * 2 + out_degree) / citation_scale\n",
    "        citation_strength.append(min(strength, 10))\n",
    "    confounders['citation_strength'] = citation_strength\n",
    "\n",
    "    # 3. Clustering level from the undirected clustering coefficient.\n",
    "    confounders['clustering_info'] = [\n",
    "        bucket(clustering_coeff.get(i, 0), (0.7, 0.4), ('high_cluster', 'medium_cluster', 'low_cluster'))\n",
    "        for i in node_ids\n",
    "    ]\n",
    "\n",
    "    # Neighbour lists are reused by sections 4, 6 and 8.\n",
    "    neighbor_lists = [list(G.neighbors(i)) for i in node_ids]\n",
    "\n",
    "    # 4. Neighbour influence: mean neighbour degree relative to the graph mean.\n",
    "    neighbor_influence = []\n",
    "    for neighbors in neighbor_lists:\n",
    "        if neighbors:\n",
    "            avg_neighbor_degree = np.mean([G.degree(n) for n in neighbors])\n",
    "            influence_level = int(avg_neighbor_degree / mean_degree_scale * 5)\n",
    "        else:\n",
    "            influence_level = 0\n",
    "        neighbor_influence.append(min(influence_level, 9))\n",
    "    confounders['neighbor_influence'] = neighbor_influence\n",
    "\n",
    "    # 5. Path centrality: betweenness is only computed for connected graphs\n",
    "    #    with more than one node; every other case counts as zero.\n",
    "    betweenness = {}\n",
    "    try:\n",
    "        if G.number_of_nodes() > 1 and nx.is_connected(undirected):\n",
    "            betweenness = nx.betweenness_centrality(undirected)\n",
    "    except Exception:\n",
    "        betweenness = {}\n",
    "    confounders['path_centrality'] = [\n",
    "        bucket(betweenness.get(i, 0), (0.1, 0.05), ('high_path', 'medium_path', 'low_path'))\n",
    "        for i in node_ids\n",
    "    ]\n",
    "\n",
    "    # 6. Temporal correlation: the node index modulo 10 acts as a time stamp,\n",
    "    #    averaged with the same quantity over the neighbours.\n",
    "    temporal_correlation = []\n",
    "    for i, neighbors in enumerate(neighbor_lists):\n",
    "        time_factor = (i % 10) / 10.0\n",
    "        neighbor_time = np.mean([(n % 10) / 10.0 for n in neighbors]) if neighbors else 0\n",
    "        combined_temporal = (time_factor + neighbor_time) / 2\n",
    "        temporal_correlation.append(\n",
    "            bucket(combined_temporal, (0.7, 0.3), ('future_oriented', 'present_focused', 'past_oriented'))\n",
    "        )\n",
    "    confounders['temporal_correlation'] = temporal_correlation\n",
    "\n",
    "    # 7. Multilayer interaction: mean of normalised degree, clustering and centrality.\n",
    "    multilayer_interaction = []\n",
    "    for i in node_ids:\n",
    "        degree_norm = G.degree(i) / max_degree_scale\n",
    "        interaction_score = (degree_norm + clustering_coeff.get(i, 0) + degree_centrality.get(i, 0)) / 3\n",
    "        multilayer_interaction.append(\n",
    "            bucket(interaction_score, (0.6, 0.3), ('strong_interaction', 'moderate_interaction', 'weak_interaction'))\n",
    "        )\n",
    "    confounders['multilayer_interaction'] = multilayer_interaction\n",
    "\n",
    "    # 8. Structural equivalence: best Jaccard similarity between this node's\n",
    "    #    neighbour set and any other node's neighbour set.\n",
    "    neighbor_sets = [set(neighbors) for neighbors in neighbor_lists]\n",
    "    structural_equivalence = []\n",
    "    for i, own_neighbors in enumerate(neighbor_sets):\n",
    "        max_similarity = 0\n",
    "        for j, other_neighbors in enumerate(neighbor_sets):\n",
    "            if i != j and (own_neighbors or other_neighbors):\n",
    "                jaccard_sim = len(own_neighbors & other_neighbors) / max(1, len(own_neighbors | other_neighbors))\n",
    "                max_similarity = max(max_similarity, jaccard_sim)\n",
    "        structural_equivalence.append(\n",
    "            bucket(max_similarity, (0.5, 0.2), ('high_equivalence', 'medium_equivalence', 'low_equivalence'))\n",
    "        )\n",
    "    confounders['structural_equivalence'] = structural_equivalence\n",
    "\n",
    "    return confounders\n",
    "\n",
    "def create_fixed_intervention_complex_confounders(num_nodes):\n",
    "    \"\"\"Build confounders in which every node shares the same fixed value.\n",
    "\n",
    "    Args:\n",
    "        num_nodes: Number of nodes, i.e. the length of every returned list.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each confounder name to a list holding its fixed value\n",
    "        repeated ``num_nodes`` times.\n",
    "    \"\"\"\n",
    "    intervention_levels = (\n",
    "        ('social_status', 'medium'),\n",
    "        ('citation_strength', 5.0),\n",
    "        ('clustering_info', 'medium_cluster'),\n",
    "        ('neighbor_influence', 4),\n",
    "        ('path_centrality', 'medium_path'),\n",
    "        ('temporal_correlation', 'present_focused'),\n",
    "        ('multilayer_interaction', 'moderate_interaction'),\n",
    "        ('structural_equivalence', 'medium_equivalence'),\n",
    "    )\n",
    "    return {name: [level] * num_nodes for name, level in intervention_levels}\n",
    "\n",
    "def apply_complex_confounding_bias(original_role_id, G, confounders, confounding_prob):\n",
    "    \"\"\"Overwrite node roles with a confounder-driven role.\n",
    "\n",
    "    With probability ``confounding_prob`` a node's role is replaced by the sum\n",
    "    of its confounder scores (plus its degree capped at 5), taken modulo 5.\n",
    "\n",
    "    Args:\n",
    "        original_role_id: Sequence of roles; it is copied, not modified.\n",
    "        G: Graph providing ``degree`` and ``nodes``.\n",
    "        confounders: dict of per-node confounder lists.\n",
    "        confounding_prob: Probability that a given node's role is overwritten.\n",
    "\n",
    "    Returns:\n",
    "        A copy of ``original_role_id`` with the selected roles replaced.\n",
    "    \"\"\"\n",
    "    # Score tables; a social status missing from the table scores 1.\n",
    "    status_points = {'elite': 5, 'high': 4, 'medium': 2}\n",
    "    cluster_points = {'high_cluster': 4, 'medium_cluster': 2, 'low_cluster': 1}\n",
    "    path_points = {'high_path': 5, 'medium_path': 3, 'low_path': 1}\n",
    "    temporal_points = {'future_oriented': 4, 'present_focused': 2, 'past_oriented': 1}\n",
    "    interaction_points = {'strong_interaction': 6, 'moderate_interaction': 3, 'weak_interaction': 1}\n",
    "    equivalence_points = {'high_equivalence': 3, 'medium_equivalence': 2, 'low_equivalence': 1}\n",
    "\n",
    "    biased_roles = original_role_id.copy()\n",
    "\n",
    "    for idx in range(len(original_role_id)):\n",
    "        # One random draw per node decides whether it is confounded at all.\n",
    "        if not (random.random() < confounding_prob):\n",
    "            continue\n",
    "\n",
    "        status = confounders['social_status'][idx]\n",
    "        citation = confounders['citation_strength'][idx]\n",
    "        cluster = confounders['clustering_info'][idx]\n",
    "        neighbor = confounders['neighbor_influence'][idx]\n",
    "        path = confounders['path_centrality'][idx]\n",
    "        temporal = confounders['temporal_correlation'][idx]\n",
    "        interaction = confounders['multilayer_interaction'][idx]\n",
    "        equivalence = confounders['structural_equivalence'][idx]\n",
    "\n",
    "        degree = G.degree(idx) if idx in G.nodes() else 0\n",
    "        score = sum((\n",
    "            status_points.get(status, 1),\n",
    "            int(citation * 0.8),\n",
    "            cluster_points[cluster],\n",
    "            neighbor,\n",
    "            path_points[path],\n",
    "            temporal_points[temporal],\n",
    "            interaction_points[interaction],\n",
    "            equivalence_points[equivalence],\n",
    "            min(degree, 5),\n",
    "        ))\n",
    "\n",
    "        # Fold the score back into the five available roles (0-4).\n",
    "        biased_roles[idx] = score % 5\n",
    "\n",
    "    return biased_roles\n",
    "\n",
    "def apply_fixed_complex_confounding_bias(original_role_id, G, confounders, confounding_prob=0.9, fixed_role=2):\n",
    "    \"\"\"Intervene on node roles by forcing them to one fixed role.\n",
    "\n",
    "    Each node is set to ``fixed_role`` with probability ``confounding_prob``.\n",
    "    The previous version also summed a confounder score for every selected\n",
    "    node but never used it; that dead computation has been removed, so ``G``\n",
    "    and ``confounders`` are no longer read.\n",
    "\n",
    "    Args:\n",
    "        original_role_id: Sequence of roles; it is copied, not modified.\n",
    "        G: Unused; kept so existing callers keep working.\n",
    "        confounders: Unused; kept so existing callers keep working.\n",
    "        confounding_prob: Probability that a given node's role is overwritten.\n",
    "        fixed_role: Role assigned to the selected nodes (default 2, the value\n",
    "            that used to be hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        A copy of ``original_role_id`` with the selected roles replaced.\n",
    "    \"\"\"\n",
    "    confounded_role_id = original_role_id.copy()\n",
    "\n",
    "    for i in range(len(original_role_id)):\n",
    "        # One random draw per node, exactly as before.\n",
    "        if random.random() < confounding_prob:\n",
    "            confounded_role_id[i] = fixed_role\n",
    "\n",
    "    return confounded_role_id\n",
    "\n",
    "def generate_complex_confounded_features(G, original_role_id, confounders, confounding_prob):\n",
    "    \"\"\"Sample 5-dimensional node features and add a confounder-driven bias.\n",
    "\n",
    "    Each node with an id below ``len(original_role_id)`` gets features drawn\n",
    "    from a Gaussian chosen by its role. With probability ``confounding_prob``\n",
    "    the summed confounder and degree offsets, scaled by 4.0, are added on top.\n",
    "\n",
    "    Args:\n",
    "        G: Graph whose nodes receive a 'features' attribute.\n",
    "        original_role_id: Role per node, indexed by node id.\n",
    "        confounders: dict of per-node confounder lists.\n",
    "        confounding_prob: Probability that a node's features are biased.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping node id to its feature vector; the same vector is also\n",
    "        stored in ``G.nodes[node_id]['features']``.\n",
    "    \"\"\"\n",
    "    # Role -> per-dimension mean and std of the unbiased features.\n",
    "    role_gaussians = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]}\n",
    "    }\n",
    "\n",
    "    # Categorical confounders: level -> additive offset per feature dimension.\n",
    "    categorical_offsets = {\n",
    "        'social_status': {\n",
    "            'elite': np.array([6.0, 8.0, 2.0, 4.0, 9.0]),\n",
    "            'high': np.array([4.0, 5.0, 1.5, 2.5, 6.0]),\n",
    "            'medium': np.array([2.0, 2.5, 3.0, 2.0, 3.0]),\n",
    "            'low': np.array([-1.0, -1.5, 5.0, -0.5, 1.5]),\n",
    "        },\n",
    "        'clustering_info': {\n",
    "            'high_cluster': np.array([3.5, 4.0, -2.0, 3.0, 4.5]),\n",
    "            'medium_cluster': np.array([1.5, 2.0, 0.5, 1.5, 2.0]),\n",
    "            'low_cluster': np.array([-1.0, -1.0, 3.0, -0.5, -1.0]),\n",
    "        },\n",
    "        'path_centrality': {\n",
    "            'high_path': np.array([4.0, 5.0, -1.5, 4.5, 5.5]),\n",
    "            'medium_path': np.array([2.0, 2.5, 0.5, 2.0, 2.5]),\n",
    "            'low_path': np.array([0.5, 0.5, 1.5, 0.5, 0.5]),\n",
    "        },\n",
    "        'temporal_correlation': {\n",
    "            'future_oriented': np.array([3.0, 4.0, -1.0, 3.5, 4.0]),\n",
    "            'present_focused': np.array([1.0, 1.5, 1.0, 1.5, 1.5]),\n",
    "            'past_oriented': np.array([-1.0, -1.5, 3.0, -1.0, -1.5]),\n",
    "        },\n",
    "        'multilayer_interaction': {\n",
    "            'strong_interaction': np.array([5.0, 6.0, -2.0, 5.5, 6.5]),\n",
    "            'moderate_interaction': np.array([2.5, 3.0, 0.5, 2.5, 3.0]),\n",
    "            'weak_interaction': np.array([0.5, 0.5, 2.0, 0.5, 0.5]),\n",
    "        },\n",
    "        'structural_equivalence': {\n",
    "            'high_equivalence': np.array([2.5, 3.0, -1.0, 2.5, 3.0]),\n",
    "            'medium_equivalence': np.array([1.0, 1.5, 0.5, 1.0, 1.5]),\n",
    "            'low_equivalence': np.array([0.0, 0.0, 1.0, 0.0, 0.0]),\n",
    "        },\n",
    "    }\n",
    "\n",
    "    # Numeric confounders: the value is multiplied by a per-dimension slope.\n",
    "    numeric_slopes = {\n",
    "        'citation_strength': np.array([0.8, 0.6, -0.4, 1.0, 0.7]),\n",
    "        'neighbor_influence': np.array([0.5, 0.7, -0.3, 0.8, 0.6]),\n",
    "    }\n",
    "    degree_slopes = np.array([0.3, 0.4, -0.2, 0.5, 0.3])\n",
    "\n",
    "    # The offsets are accumulated in exactly this order, followed by the degree term.\n",
    "    factor_order = (\n",
    "        'social_status',\n",
    "        'citation_strength',\n",
    "        'clustering_info',\n",
    "        'neighbor_influence',\n",
    "        'path_centrality',\n",
    "        'temporal_correlation',\n",
    "        'multilayer_interaction',\n",
    "        'structural_equivalence',\n",
    "    )\n",
    "\n",
    "    features_dict = {}\n",
    "\n",
    "    for node_id in G.nodes():\n",
    "        # Nodes without a role entry are left without features.\n",
    "        if not (node_id < len(original_role_id)):\n",
    "            continue\n",
    "\n",
    "        role = original_role_id[node_id]\n",
    "        levels = {name: confounders[name][node_id] for name in factor_order}\n",
    "\n",
    "        gaussian = role_gaussians[role]\n",
    "        features = np.random.normal(gaussian['mean'], gaussian['std'])\n",
    "\n",
    "        if random.random() < confounding_prob:\n",
    "            bias = np.zeros(5)\n",
    "            for name in factor_order:\n",
    "                if name in numeric_slopes:\n",
    "                    bias += levels[name] * numeric_slopes[name]\n",
    "                else:\n",
    "                    bias += categorical_offsets[name][levels[name]]\n",
    "\n",
    "            degree = G.degree(node_id) if node_id in G.nodes() else 0\n",
    "            bias += degree * degree_slopes\n",
    "\n",
    "            features += bias * 4.0\n",
    "\n",
    "        features_dict[node_id] = features\n",
    "        G.nodes[node_id]['features'] = features\n",
    "\n",
    "    return features_dict\n",
    "\n",
    "def generate_fixed_complex_confounded_features(G, original_role_id, confounders, confounding_prob):\n",
    "    \"\"\"Sample role-based node features and randomly perturb a share of them.\n",
    "\n",
    "    Every node whose id is below ``len(original_role_id)`` receives a 5-dim\n",
    "    draw from the Gaussian of its original role. With probability\n",
    "    ``confounding_prob`` the summed effect of its eight confounders is then\n",
    "    added, multiplied by a uniform(-20, 20) factor and shifted by a\n",
    "    uniform(-30, 30) offset.\n",
    "\n",
    "    Args:\n",
    "        G: graph object used through ``G.nodes()`` and ``G.nodes[node_id]``.\n",
    "        original_role_id: sequence indexed by node id; values must be keys 0..4.\n",
    "        confounders: dict of per-node sequences, one entry per confounder name.\n",
    "        confounding_prob: probability that a node's features get perturbed.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping node id -> feature array. The very same array object is\n",
    "        also stored under ``G.nodes[node_id]['features']``.\n",
    "    \"\"\"\n",
    "    # Role -> per-dimension mean / std of the unperturbed feature Gaussian.\n",
    "    role_gaussians = {\n",
    "        0: {'mean': [2.0, 3.0, 1.5, 2.5, 4.0], 'std': [0.3, 0.4, 0.2, 0.3, 0.5]},\n",
    "        1: {'mean': [1.0, 2.0, 3.0, 1.5, 2.0], 'std': [0.2, 0.3, 0.4, 0.2, 0.3]},\n",
    "        2: {'mean': [3.0, 1.0, 2.0, 3.5, 1.5], 'std': [0.4, 0.2, 0.3, 0.5, 0.2]},\n",
    "        3: {'mean': [2.5, 4.0, 1.0, 2.0, 3.0], 'std': [0.3, 0.5, 0.1, 0.2, 0.4]},\n",
    "        4: {'mean': [1.5, 2.5, 4.0, 1.0, 2.5], 'std': [0.2, 0.3, 0.6, 0.1, 0.3]},\n",
    "    }\n",
    "\n",
    "    # Additive shifts contributed by each categorical confounder level.\n",
    "    status_shift = {\n",
    "        'elite': np.array([6.0, 8.0, 2.0, 4.0, 9.0]),\n",
    "        'high': np.array([4.0, 5.0, 1.5, 2.5, 6.0]),\n",
    "        'medium': np.array([2.0, 2.5, 3.0, 2.0, 3.0]),\n",
    "        'low': np.array([-1.0, -1.5, 5.0, -0.5, 1.5]),\n",
    "    }\n",
    "    cluster_shift = {\n",
    "        'high_cluster': np.array([3.5, 4.0, -2.0, 3.0, 4.5]),\n",
    "        'medium_cluster': np.array([1.5, 2.0, 0.5, 1.5, 2.0]),\n",
    "        'low_cluster': np.array([-1.0, -1.0, 3.0, -0.5, -1.0]),\n",
    "    }\n",
    "    path_shift = {\n",
    "        'high_path': np.array([4.0, 5.0, -1.5, 4.5, 5.5]),\n",
    "        'medium_path': np.array([2.0, 2.5, 0.5, 2.0, 2.5]),\n",
    "        'low_path': np.array([0.5, 0.5, 1.5, 0.5, 0.5]),\n",
    "    }\n",
    "    temporal_shift = {\n",
    "        'future_oriented': np.array([3.0, 4.0, -1.0, 3.5, 4.0]),\n",
    "        'present_focused': np.array([1.0, 1.5, 1.0, 1.5, 1.5]),\n",
    "        'past_oriented': np.array([-1.0, -1.5, 3.0, -1.0, -1.5]),\n",
    "    }\n",
    "    interaction_shift = {\n",
    "        'strong_interaction': np.array([5.0, 6.0, -2.0, 5.5, 6.5]),\n",
    "        'moderate_interaction': np.array([2.5, 3.0, 0.5, 2.5, 3.0]),\n",
    "        'weak_interaction': np.array([0.5, 0.5, 2.0, 0.5, 0.5]),\n",
    "    }\n",
    "    equivalence_shift = {\n",
    "        'high_equivalence': np.array([2.5, 3.0, -1.0, 2.5, 3.0]),\n",
    "        'medium_equivalence': np.array([1.0, 1.5, 0.5, 1.0, 1.5]),\n",
    "        'low_equivalence': np.array([0.0, 0.0, 1.0, 0.0, 0.0]),\n",
    "    }\n",
    "    # Per-dimension weights applied to the two numeric confounders.\n",
    "    citation_weights = (0.8, 0.6, -0.4, 1.0, 0.7)\n",
    "    neighbor_weights = (0.5, 0.7, -0.3, 0.8, 0.6)\n",
    "\n",
    "    node_features = {}\n",
    "\n",
    "    for node in G.nodes():\n",
    "        # Nodes without a role entry get no features at all.\n",
    "        if not (node < len(original_role_id)):\n",
    "            continue\n",
    "\n",
    "        role = original_role_id[node]\n",
    "\n",
    "        # Read every confounder up front, even if this node ends up unperturbed.\n",
    "        status = confounders['social_status'][node]\n",
    "        citation = confounders['citation_strength'][node]\n",
    "        cluster = confounders['clustering_info'][node]\n",
    "        neighbor = confounders['neighbor_influence'][node]\n",
    "        path = confounders['path_centrality'][node]\n",
    "        temporal = confounders['temporal_correlation'][node]\n",
    "        interaction = confounders['multilayer_interaction'][node]\n",
    "        equivalence = confounders['structural_equivalence'][node]\n",
    "\n",
    "        gaussian = role_gaussians[role]\n",
    "        sample = np.random.normal(gaussian['mean'], gaussian['std'])\n",
    "\n",
    "        if random.random() < confounding_prob:\n",
    "            # Accumulate into a fresh array so the shared tables stay untouched.\n",
    "            shift = np.zeros(5)\n",
    "            shift += status_shift[status]\n",
    "            shift += np.array([citation * w for w in citation_weights])\n",
    "            shift += cluster_shift[cluster]\n",
    "            shift += np.array([neighbor * w for w in neighbor_weights])\n",
    "            shift += path_shift[path]\n",
    "            shift += temporal_shift[temporal]\n",
    "            shift += interaction_shift[interaction]\n",
    "            shift += equivalence_shift[equivalence]\n",
    "\n",
    "            # Scale is drawn before the offset; the order fixes the RNG stream.\n",
    "            scale = random.uniform(-20, 20)\n",
    "            offset = random.uniform(-30, 30)\n",
    "            sample += shift * scale + offset\n",
    "\n",
    "        node_features[node] = sample\n",
    "        G.nodes[node]['features'] = sample\n",
    "\n",
    "    return node_features\n",
    "\n",
    "def generate_complex_intervened_dataset(confounding_prob=0.7, intervention_prob=0.1, dataset_type='train', num_samples=2000):\n",
    "    \"\"\"Generate graphs whose roles/features are confounded or intervened on.\n",
    "\n",
    "    Each sample starts from ``generate_false_cause_dataset1()``. It is then\n",
    "    either confounded using ``confounding_prob`` or, when the intervention\n",
    "    branch is taken, given fixed-intervention confounders that are applied with\n",
    "    a hard-coded strength of 0.9.\n",
    "\n",
    "    Args:\n",
    "        confounding_prob: confounding probability for non-intervened samples.\n",
    "        intervention_prob: exactly 0.0 never intervenes, exactly 1.0 always\n",
    "            intervenes, any other value intervenes per sample with this chance.\n",
    "        dataset_type: tag copied into the output; not otherwise used here.\n",
    "        num_samples: number of generation attempts. Attempts that raise are\n",
    "            skipped, so the result may contain fewer samples.\n",
    "\n",
    "    Returns:\n",
    "        dict with the parallel per-sample lists ``features``, ``edge_index``,\n",
    "        ``label``, ``ground_truth``, ``role_id``, ``pos`` and\n",
    "        ``confounding_info`` plus the ``confounding_prob``,\n",
    "        ``intervention_prob`` and ``dataset_type`` settings.\n",
    "    \"\"\"\n",
    "    edge_index_list, label_list = [], []\n",
    "    ground_truth_list, role_id_list, pos_list, features_list = [], [], [], []\n",
    "    confounding_info_list = []\n",
    "    failed_samples = 0\n",
    "    first_error = None\n",
    "\n",
    "    for _ in tqdm(range(num_samples)):\n",
    "        try:\n",
    "            G, original_role_id, label = generate_false_cause_dataset1()\n",
    "\n",
    "            # The exact 0.0 / 1.0 settings deliberately consume no random number.\n",
    "            if intervention_prob == 0.0:\n",
    "                intervene = False\n",
    "            elif intervention_prob == 1.0:\n",
    "                intervene = True\n",
    "            else:\n",
    "                intervene = random.random() < intervention_prob\n",
    "\n",
    "            if intervene:\n",
    "                # Intervened samples: features are drawn from the ORIGINAL roles.\n",
    "                confounders = create_fixed_intervention_complex_confounders(len(original_role_id))\n",
    "                confounded_role_id = apply_fixed_complex_confounding_bias(original_role_id, G, confounders, 0.9)\n",
    "                features_dict = generate_fixed_complex_confounded_features(G, original_role_id, confounders, 0.9)\n",
    "            else:\n",
    "                # Observational samples: features follow the confounded roles.\n",
    "                confounders = create_complex_confounding_variables(original_role_id, G, confounding_prob)\n",
    "                confounded_role_id = apply_complex_confounding_bias(original_role_id, G, confounders, confounding_prob)\n",
    "                features_dict = generate_complex_confounded_features(G, confounded_role_id, confounders, confounding_prob)\n",
    "\n",
    "            # Works whether the helper returned a list or an ndarray.\n",
    "            role_array = np.array(confounded_role_id)\n",
    "\n",
    "            if G.number_of_edges() > 0:\n",
    "                edge_index = np.array(list(G.edges), dtype=int).T\n",
    "            else:\n",
    "                edge_index = np.array([[], []], dtype=int)\n",
    "\n",
    "            # Layout is best effort: fall back to an empty array if it fails.\n",
    "            try:\n",
    "                pos = nx.spring_layout(G) if G.number_of_nodes() <= 1000 else nx.random_layout(G)\n",
    "                pos_array = np.array(list(pos.values()))\n",
    "            except Exception:\n",
    "                pos_array = np.array([])\n",
    "\n",
    "            # An edge is ground truth (1.0) when both endpoints have original role > 0.\n",
    "            try:\n",
    "                if edge_index.size > 0:\n",
    "                    row, col = edge_index\n",
    "                    original_roles = np.array(original_role_id)\n",
    "                    gd = np.array(original_roles[row] > 0, dtype=np.float64) * np.array(original_roles[col] > 0, dtype=np.float64)\n",
    "                else:\n",
    "                    gd = np.array([])\n",
    "            except Exception:\n",
    "                gd = np.array([])\n",
    "\n",
    "            has_nodes = G.number_of_nodes() > 0\n",
    "            if has_nodes:\n",
    "                # One row per node in sorted id order; missing nodes get zeros.\n",
    "                feat_mat = np.vstack([\n",
    "                    features_dict.get(n, np.zeros(5))\n",
    "                    for n in sorted(G.nodes())\n",
    "                ])\n",
    "            else:\n",
    "                feat_mat = np.zeros((0, 5))\n",
    "\n",
    "            # Fraction of nodes whose role was changed by the confounding step.\n",
    "            confounding_strength = np.mean([\n",
    "                1 if original_role_id[j] != confounded_role_id[j] else 0\n",
    "                for j in range(len(original_role_id))\n",
    "            ])\n",
    "\n",
    "            network_complexity = {\n",
    "                'avg_degree': np.mean([G.degree(n) for n in G.nodes()]) if has_nodes else 0,\n",
    "                'clustering_coefficient': nx.average_clustering(G.to_undirected()) if has_nodes else 0,\n",
    "                'density': nx.density(G) if has_nodes else 0\n",
    "            }\n",
    "\n",
    "            sample_info = {\n",
    "                'confounding_prob': confounding_prob,\n",
    "                'intervention_prob': intervention_prob,\n",
    "                'dataset_type': dataset_type,\n",
    "                'original_role_id': original_role_id,\n",
    "                'confounded_role_id': role_array.tolist(),\n",
    "                'confounding_strength': confounding_strength,\n",
    "                'confounders': confounders,\n",
    "                'network_complexity': network_complexity,\n",
    "                'label': label\n",
    "            }\n",
    "        except Exception as e:\n",
    "            # Best effort: skip the sample, but remember that it happened.\n",
    "            failed_samples += 1\n",
    "            if first_error is None:\n",
    "                first_error = e\n",
    "            continue\n",
    "\n",
    "        # Commit only fully built samples so the parallel lists stay aligned.\n",
    "        label_list.append(label)\n",
    "        role_id_list.append(role_array)\n",
    "        edge_index_list.append(edge_index)\n",
    "        pos_list.append(pos_array)\n",
    "        ground_truth_list.append(gd)\n",
    "        features_list.append(feat_mat)\n",
    "        confounding_info_list.append(sample_info)\n",
    "\n",
    "    if failed_samples:\n",
    "        print(f'generate_complex_intervened_dataset: skipped {failed_samples}/{num_samples} samples; first error: {first_error!r}')\n",
    "\n",
    "    dataset_dict = {\n",
    "        'features': features_list,\n",
    "        'edge_index': edge_index_list,\n",
    "        'label': label_list,\n",
    "        'ground_truth': ground_truth_list,\n",
    "        'role_id': role_id_list,\n",
    "        'pos': pos_list,\n",
    "        'confounding_info': confounding_info_list,\n",
    "        'confounding_prob': confounding_prob,\n",
    "        'intervention_prob': intervention_prob,\n",
    "        'dataset_type': dataset_type\n",
    "    }\n",
    "\n",
    "    return dataset_dict\n",
    "\n",
    "def generate_complex_intervened_experiment_datasets(intervention_prob=0.1, base_dir='./data/casual/'):\n",
    "    \"\"\"Generate the training dataset and save it as a ``.npy`` file.\n",
    "\n",
    "    Args:\n",
    "        intervention_prob: accepted for interface compatibility; see the\n",
    "            review note below, it is not forwarded to the generator.\n",
    "        base_dir: output directory, created if it does not exist.\n",
    "\n",
    "    Returns:\n",
    "        list with the path of every file written (currently one).\n",
    "    \"\"\"\n",
    "    import os\n",
    "    os.makedirs(base_dir, exist_ok=True)\n",
    "\n",
    "    generated_files = []\n",
    "    confounding_prob = 0.7\n",
    "\n",
    "    # NOTE(review): the training split always uses intervention_prob=1.0, so the\n",
    "    # `intervention_prob` argument has no effect here - confirm this is intended.\n",
    "    train_dataset = generate_complex_intervened_dataset(\n",
    "        confounding_prob=confounding_prob,\n",
    "        intervention_prob=1.0,\n",
    "        dataset_type='train',\n",
    "        num_samples=3000\n",
    "    )\n",
    "    # os.path.join avoids the doubled separator when base_dir ends with '/'.\n",
    "    train_file = os.path.join(base_dir, 'train_casual_2_4.npy')\n",
    "    # The dataset is a dict, so np.save stores it as a pickled object array.\n",
    "    np.save(train_file, train_dataset)\n",
    "    generated_files.append(train_file)\n",
    "\n",
    "    return generated_files\n",
    "\n",
    "def run_complex_intervened_experiment(intervention_prob=0.1):\n",
    "    \"\"\"Run the dataset generation and return the list of written file paths.\"\"\"\n",
    "    return generate_complex_intervened_experiment_datasets(intervention_prob)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Generate the datasets, then report which files were written.\n",
    "    generated_files = run_complex_intervened_experiment(\n",
    "        intervention_prob=0.2,\n",
    "    )\n",
    "    print(f\"Generated files: {generated_files}\")\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CRCG",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}