{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from functools import reduce\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import dgl\n",
    "from dgl.nn import ChebConv\n",
    "import networkx as nx\n",
    "import networkx.algorithms.community as nx_comm\n",
    "import pandas as pd\n",
    "from tqdm.notebook import tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def node_connects_cluster(node):\n",
    "    \"\"\"Return the set of cluster ids that `node` touches.\n",
    "\n",
    "    The set holds the cluster of every neighbour of `node` in the global graph `g`,\n",
    "    plus the cluster of `node` itself, all looked up in the global `inverse_cluster_dict`.\n",
    "    \"\"\"\n",
    "    touched_clusters = {inverse_cluster_dict[nbr] for nbr in g[node]}\n",
    "    touched_clusters.add(inverse_cluster_dict[node])\n",
    "    return touched_clusters\n",
    "\n",
    "\n",
    "# Edge-list file of the social graph (alternative dataset kept for reference).\n",
    "# path = 'Dataset/socfb-Caltech36.mtx'\n",
    "path = 'Dataset/socfb-Stanford3.mtx'\n",
    "\n",
    "# Every line after the first is parsed as one space-separated source/target pair.\n",
    "# NOTE(review): skiprows=1 drops only the first line -- confirm the file carries no\n",
    "# additional header/size line that would otherwise be read as an edge.\n",
    "df = pd.read_table(path, sep=\" \", skiprows=1, names=[\"source\", \"target\"])\n",
    "g = nx.from_pandas_edgelist(df)\n",
    "\n",
    "# Basic graph statistics reused by the cells below (degrees are in g.nodes order).\n",
    "num_nodes = g.number_of_nodes()\n",
    "num_edges = g.number_of_edges()\n",
    "degs = [deg for _, deg in g.degree()]\n",
    "avg_deg = sum(degs) / len(degs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relabel the nodes 0..n-1 in g.nodes iteration order before handing the graph to DGL, so that\n",
    "# DGL node i is the i-th node of g.nodes -- the row order of the feature/target tensors that are\n",
    "# built with `for n in g.nodes`. dgl.from_networkx relabels nodes by *sorted* label, which differs\n",
    "# from g.nodes order whenever the edge list does not introduce the nodes in ascending order.\n",
    "dgl_G = dgl.from_networkx(nx.convert_node_labels_to_integers(g, ordering=\"default\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense adjacency matrix; rows and columns follow g.nodes order.\n",
    "A = np.array(nx.adjacency_matrix(g).todense(), dtype=np.float64)\n",
    "deg_array = np.array([deg for _, deg in g.degree()])\n",
    "\n",
    "# Row i of D_inv_A is row i of A divided by the degree of node i.\n",
    "D_inv_A = A / deg_array[:, np.newaxis]\n",
    "\n",
    "# Square of D_inv_A (2-hop weights) with its diagonal set to 0.\n",
    "multi_hop_A = np.linalg.matrix_power(D_inv_A, 2)\n",
    "np.fill_diagonal(multi_hop_A, 0)\n",
    "\n",
    "# Optional on-disk cache of the 2-hop matrix:\n",
    "# torch.save(multi_hop_A, \"A_2hop.pkl\")\n",
    "# multi_hop_A = torch.load(\"A_2hop.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node labels in g.nodes iteration order.\n",
    "node_list = np.array(list(g.nodes.keys()))\n",
    "\n",
    "# from node to index in node_list\n",
    "# (enumerate yields the position directly; scanning node_list with np.where for every node\n",
    "# was quadratic in the number of nodes and returned the same index)\n",
    "node_list_index_dict = {node: idx for idx, node in enumerate(g.nodes)}\n",
    "\n",
    "# from graph index to feature vector index (rows of the feature tensors follow g.nodes order)\n",
    "dict_from_graph_to_feat = {node: idx for idx, node in enumerate(g.nodes)}\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def po_2hop_linear_model(graph, z_vec, alpha=0, beta=1, sigma=0.1, gamma=1, r1=1, r2=0, cov_weight=1):\n",
    "    \"\"\"Simulate outcomes under the linear interference model and store them on `graph`.\n",
    "\n",
    "    Noiseless outcome:\n",
    "        alpha + beta * z\n",
    "        + cov_weight * (normalized_deg_array + number_of_connected_clusters_array) * z\n",
    "        + gamma * (r1 * D_inv_A @ z + r2 * multi_hop_A @ z)\n",
    "    The observed outcome adds sigma * standard-normal noise drawn from the global NumPy RNG.\n",
    "\n",
    "    Reads the globals normalized_deg_array, number_of_connected_clusters_array, D_inv_A,\n",
    "    multi_hop_A, num_nodes and node_list.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    graph : graph whose node attribute dicts receive the results.\n",
    "    z_vec : 2-D column array of treatment values, row i belonging to node_list[i].\n",
    "    alpha, beta, gamma, r1, r2, cov_weight : model coefficients (see formula above).\n",
    "    sigma : scale of the additive noise.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None. Sets the node attributes \"g\" (noiseless outcome), \"y\" (noisy outcome) and \"z\".\n",
    "    \"\"\"\n",
    "    covariate = normalized_deg_array.reshape(-1, 1) + number_of_connected_clusters_array.reshape(-1, 1)\n",
    "    neighbour_exposure = r1 * np.matmul(D_inv_A, z_vec) + r2 * np.matmul(multi_hop_A, z_vec)\n",
    "    g_vec = alpha + beta * z_vec + cov_weight * covariate * z_vec + gamma * neighbour_exposure\n",
    "    y_vec = g_vec + sigma * np.random.normal(size=(num_nodes, 1))\n",
    "\n",
    "    for row in range(num_nodes):\n",
    "        node_attrs = graph.nodes[node_list[row]]\n",
    "        node_attrs[\"g\"] = g_vec[row][0]\n",
    "        node_attrs[\"y\"] = y_vec[row][0]\n",
    "        node_attrs[\"z\"] = z_vec[row][0]\n",
    "\n",
    "\n",
    "def po_2hop_linear_model_gt(graph, z_vec, alpha=0, beta=1, sigma=0.1, gamma=1, r1=1, r2=0, cov_weight=1):\n",
    "    \"\"\"Noiseless outcomes of the linear interference model when every unit is treated.\n",
    "\n",
    "    Only the shape of `z_vec` is used: it is replaced by an all-ones array of the same shape.\n",
    "    `graph` and `sigma` are accepted for signature parity with po_2hop_linear_model and are\n",
    "    not used. Reads the globals normalized_deg_array, number_of_connected_clusters_array,\n",
    "    D_inv_A and multi_hop_A.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Array with one row per node (shape (n, 1) for a column-vector `z_vec`) holding\n",
    "    alpha + beta + cov_weight * covariate + gamma * (r1 * D_inv_A @ 1 + r2 * multi_hop_A @ 1).\n",
    "    \"\"\"\n",
    "    z_vec = np.ones_like(z_vec, dtype=float)\n",
    "    # The covariates are 1-D; reshape them to a column so the product with the (n, 1) z_vec\n",
    "    # stays (n, 1) instead of broadcasting to an (n, n) matrix.\n",
    "    covariate = (normalized_deg_array + number_of_connected_clusters_array).reshape(-1, 1)\n",
    "    g_vec = alpha + beta * z_vec + cov_weight * covariate * z_vec + gamma * (\n",
    "        r1 * np.matmul(D_inv_A, z_vec) + r2 * np.matmul(multi_hop_A, z_vec))\n",
    "    return g_vec\n",
    "        \n",
    "def po_2hop_linear_model_without_covariate(graph, z_vec, alpha=0, beta=1, sigma=0.1, gamma=1, r1=1, r2=0):\n",
    "    \"\"\"Simulate outcomes from the linear interference model without the covariate term.\n",
    "\n",
    "    Outcome: alpha + beta * z + sigma * noise + gamma * (r1 * D_inv_A @ z + r2 * multi_hop_A @ z),\n",
    "    with standard-normal noise drawn from the global NumPy RNG. Reads the globals D_inv_A,\n",
    "    multi_hop_A, num_nodes and node_list.\n",
    "\n",
    "    Returns None. Sets the node attributes \"y\" and \"z\" on `graph`, row i of `z_vec`\n",
    "    belonging to node_list[i].\n",
    "    \"\"\"\n",
    "    noise = sigma * np.random.normal(size=(num_nodes, 1))\n",
    "    neighbour_exposure = gamma * (r1 * np.matmul(D_inv_A, z_vec) + r2 * np.matmul(multi_hop_A, z_vec))\n",
    "    y_vec = alpha + beta * z_vec + noise + neighbour_exposure\n",
    "\n",
    "    for row in range(num_nodes):\n",
    "        node_attrs = graph.nodes[node_list[row]]\n",
    "        node_attrs[\"y\"] = y_vec[row][0]\n",
    "        node_attrs[\"z\"] = z_vec[row][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCN(nn.Module):\n",
    "    \"\"\"Three stacked Chebyshev graph convolutions producing one value per node.\n",
    "\n",
    "    Layer sizes: num_inputs -> 16 (k=2), 16 -> 16 (k=1), 16 -> 1 (k=1).\n",
    "\n",
    "    NOTE(review): no activation is passed explicitly, so each ChebConv uses the library\n",
    "    default for its `activation` argument -- confirm that default is what is wanted on the\n",
    "    output layer as well.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_inputs=2):\n",
    "        super().__init__()\n",
    "        hidden_size = 16\n",
    "        self.conv1 = ChebConv(num_inputs, hidden_size, 2)\n",
    "        self.conv2 = ChebConv(hidden_size, hidden_size, 1)\n",
    "        self.conv3 = ChebConv(hidden_size, 1, 1)\n",
    "\n",
    "    def forward(self, g, features):\n",
    "        \"\"\"Apply the three convolutions in sequence on graph `g` and return the last output.\"\"\"\n",
    "        hidden = features\n",
    "        for conv in (self.conv1, self.conv2, self.conv3):\n",
    "            hidden = conv(g, hidden)\n",
    "        return hidden\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRUE GATE: 4.000\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c7e64294a4745b5beb165308c9b710b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -1.856\t CAE: -1.886\t MII: -1.892\t GNN: -1.891\t PPI: -1.030\n",
      "Std\t Hajek: 0.408\t CAE: 0.455\t MII: 0.302\t GNN: 1.406\t PPI: 0.743\n",
      "MSE\t Hajek: 3.611\t CAE: 3.765\t MII: 3.672\t GNN: 5.553\t PPI: 1.611\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9284425d1092479b87eb88bc19c54964",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -1.754\t CAE: -1.875\t MII: -1.887\t GNN: -0.718\t PPI: -0.361\n",
      "Std\t Hajek: 0.301\t CAE: 0.237\t MII: 0.130\t GNN: 0.225\t PPI: 0.198\n",
      "MSE\t Hajek: 3.169\t CAE: 3.574\t MII: 3.579\t GNN: 0.567\t PPI: 0.169\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d493f77474994b198c7b41d4b3125d3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -1.660\t CAE: -1.844\t MII: -1.882\t GNN: -0.486\t PPI: -0.256\n",
      "Std\t Hajek: 0.230\t CAE: 0.160\t MII: 0.092\t GNN: 0.178\t PPI: 0.130\n",
      "MSE\t Hajek: 2.809\t CAE: 3.426\t MII: 3.552\t GNN: 0.268\t PPI: 0.082\n",
      "\n",
      "\n",
      "TRUE GATE: 4.978\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a547cd8eed724a23acbb18971d712695",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -2.327\t CAE: -2.411\t MII: -2.368\t GNN: -2.363\t PPI: -1.444\n",
      "Std\t Hajek: 0.436\t CAE: 0.460\t MII: 0.333\t GNN: 1.429\t PPI: 0.760\n",
      "MSE\t Hajek: 5.605\t CAE: 6.024\t MII: 5.719\t GNN: 7.629\t PPI: 2.663\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "729956ebf00a4298b3094f37592daed3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -2.086\t CAE: -2.288\t MII: -2.237\t GNN: -1.184\t PPI: -0.805\n",
      "Std\t Hajek: 0.329\t CAE: 0.244\t MII: 0.152\t GNN: 0.312\t PPI: 0.218\n",
      "MSE\t Hajek: 4.460\t CAE: 5.294\t MII: 5.029\t GNN: 1.499\t PPI: 0.696\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2036bf3936f848d8bbb2e919d6365f36",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -1.893\t CAE: -2.148\t MII: -2.143\t GNN: -0.807\t PPI: -0.582\n",
      "Std\t Hajek: 0.259\t CAE: 0.171\t MII: 0.104\t GNN: 0.201\t PPI: 0.149\n",
      "MSE\t Hajek: 3.650\t CAE: 4.641\t MII: 4.602\t GNN: 0.691\t PPI: 0.360\n",
      "\n",
      "\n",
      "TRUE GATE: 4.000\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6697a0b9f2a846cbb9b63fd4135add98",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -1.912\t CAE: -1.916\t MII: -1.922\t GNN: -1.432\t PPI: -0.777\n",
      "Std\t Hajek: 0.558\t CAE: 0.342\t MII: 0.243\t GNN: 1.132\t PPI: 0.629\n",
      "MSE\t Hajek: 3.967\t CAE: 3.789\t MII: 3.752\t GNN: 3.331\t PPI: 0.999\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef297ea3309f4f63a0f85943ad6057d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -1.816\t CAE: -1.913\t MII: -1.925\t GNN: -0.754\t PPI: -0.320\n",
      "Std\t Hajek: 0.430\t CAE: 0.193\t MII: 0.131\t GNN: 0.244\t PPI: 0.205\n",
      "MSE\t Hajek: 3.482\t CAE: 3.696\t MII: 3.722\t GNN: 0.628\t PPI: 0.145\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76fcdb076ba842aeacbb7214422b765e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -1.735\t CAE: -1.904\t MII: -1.925\t GNN: -0.492\t PPI: -0.195\n",
      "Std\t Hajek: 0.381\t CAE: 0.136\t MII: 0.092\t GNN: 0.147\t PPI: 0.115\n",
      "MSE\t Hajek: 3.154\t CAE: 3.642\t MII: 3.713\t GNN: 0.264\t PPI: 0.052\n",
      "\n",
      "\n",
      "TRUE GATE: 4.978\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c80869e20384f2a94ab9f2506476846",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -2.426\t CAE: -2.525\t MII: -2.438\t GNN: -1.877\t PPI: -1.203\n",
      "Std\t Hajek: 0.580\t CAE: 0.346\t MII: 0.277\t GNN: 0.871\t PPI: 0.526\n",
      "MSE\t Hajek: 6.223\t CAE: 6.494\t MII: 6.019\t GNN: 4.282\t PPI: 1.725\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b32baeaa9e9496383314878b1a7d3d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -2.181\t CAE: -2.392\t MII: -2.305\t GNN: -1.221\t PPI: -0.785\n",
      "Std\t Hajek: 0.453\t CAE: 0.195\t MII: 0.157\t GNN: 0.260\t PPI: 0.205\n",
      "MSE\t Hajek: 4.963\t CAE: 5.760\t MII: 5.337\t GNN: 1.559\t PPI: 0.659\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "19995c23cf2b46d38e95e8527e74f8f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -1.989\t CAE: -2.249\t MII: -2.202\t GNN: -0.815\t PPI: -0.538\n",
      "Std\t Hajek: 0.392\t CAE: 0.139\t MII: 0.102\t GNN: 0.136\t PPI: 0.117\n",
      "MSE\t Hajek: 4.112\t CAE: 5.077\t MII: 4.859\t GNN: 0.682\t PPI: 0.303\n",
      "\n",
      "\n",
      "TRUE GATE: 4.000\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9c5a9d17d204024869bd4ee11a78ca8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -1.909\t CAE: -1.957\t MII: -1.948\t GNN: -1.510\t PPI: -0.793\n",
      "Std\t Hajek: 0.593\t CAE: 0.349\t MII: 0.242\t GNN: 1.120\t PPI: 0.649\n",
      "MSE\t Hajek: 3.995\t CAE: 3.950\t MII: 3.854\t GNN: 3.535\t PPI: 1.050\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b6407c30206c429e9497cfe7ac6cc560",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -1.849\t CAE: -1.951\t MII: -1.945\t GNN: -0.782\t PPI: -0.289\n",
      "Std\t Hajek: 0.472\t CAE: 0.192\t MII: 0.129\t GNN: 0.221\t PPI: 0.172\n",
      "MSE\t Hajek: 3.642\t CAE: 3.843\t MII: 3.801\t GNN: 0.660\t PPI: 0.113\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1acf2291838349d8abae381770b0381c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -1.766\t CAE: -1.930\t MII: -1.942\t GNN: -0.514\t PPI: -0.170\n",
      "Std\t Hajek: 0.411\t CAE: 0.140\t MII: 0.098\t GNN: 0.160\t PPI: 0.117\n",
      "MSE\t Hajek: 3.287\t CAE: 3.746\t MII: 3.781\t GNN: 0.290\t PPI: 0.043\n",
      "\n",
      "\n",
      "TRUE GATE: 4.978\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d85526a551cd48709d8668f3415d3ffe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -2.448\t CAE: -2.603\t MII: -2.487\t GNN: -1.910\t PPI: -1.221\n",
      "Std\t Hajek: 0.608\t CAE: 0.353\t MII: 0.266\t GNN: 0.804\t PPI: 0.529\n",
      "MSE\t Hajek: 6.364\t CAE: 6.899\t MII: 6.255\t GNN: 4.293\t PPI: 1.770\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "289efa99154541e9858e2f4aa461ef5e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -2.239\t CAE: -2.456\t MII: -2.350\t GNN: -1.274\t PPI: -0.799\n",
      "Std\t Hajek: 0.485\t CAE: 0.194\t MII: 0.142\t GNN: 0.267\t PPI: 0.204\n",
      "MSE\t Hajek: 5.246\t CAE: 6.071\t MII: 5.544\t GNN: 1.693\t PPI: 0.680\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "108627ff5ec2422bbfaf85fa7f666413",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -2.040\t CAE: -2.296\t MII: -2.242\t GNN: -0.856\t PPI: -0.545\n",
      "Std\t Hajek: 0.423\t CAE: 0.141\t MII: 0.104\t GNN: 0.133\t PPI: 0.117\n",
      "MSE\t Hajek: 4.340\t CAE: 5.293\t MII: 5.036\t GNN: 0.751\t PPI: 0.311\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Initial value only: the `for r2 in [0, 1]` loop below rebinds r2 before it is read.\n",
    "r2 = 0\n",
    "# Noise scale passed as `sigma` to po_2hop_linear_model.\n",
    "sigma = 2\n",
    "# Number of seeds (Monte Carlo repetitions) run per configuration.\n",
    "repeat_num = 1000\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Run the whole experiment once per Louvain resolution.\n",
    "for res in [2, 5, 10]:\n",
    "    \n",
    "    # clustering\n",
    "    # the Louvain seed is fixed, so the partition for a given resolution is reproducible\n",
    "    clusters = nx_comm.louvain_communities(g, seed = 10, resolution=res)\n",
    "    # largest cluster first\n",
    "    clusters = sorted(clusters, key = len, reverse=True)\n",
    "    cluster_sizes = list(map(len, clusters))\n",
    "    num_cluster = len(clusters)\n",
    "    \n",
    "    # dict: from node to the index of its cluster\n",
    "    inverse_cluster_dict = {\n",
    "        node: cl for cl, members in enumerate(clusters) for node in members\n",
    "    }\n",
    "    \n",
    "    # dict: from node to the set of clusters it touches (its own plus its neighbours').\n",
    "    # Iterate over the graph's actual node labels: range(1, num_nodes + 1) is only correct\n",
    "    # when the labels are exactly 1..num_nodes and raises KeyError otherwise.\n",
    "    node_to_connected_clusters = {\n",
    "        node: node_connects_cluster(node) for node in g.nodes\n",
    "    }\n",
    "    \n",
    "    # Per-node covariates stored on the graph: number of touched clusters and degree.\n",
    "    for node, deg in g.degree():\n",
    "        g.nodes[node][\"n_cl\"] = len(node_to_connected_clusters[node])\n",
    "        g.nodes[node][\"deg\"] = deg\n",
    "    \n",
    "    # Clusters and their sizes in reverse order (smallest cluster first).\n",
    "    rev_cluster_sizes = cluster_sizes[::-1]\n",
    "    rev_clusters = clusters[::-1]\n",
    "    rev_cluster_sizes = np.array(rev_cluster_sizes)\n",
    "        \n",
    "    # Interior nodes: every neighbour lies in the node's own cluster (kept in g.nodes order).\n",
    "    interior_nodes = [\n",
    "        i for i in g.nodes\n",
    "        if all(inverse_cluster_dict[nb] == inverse_cluster_dict[i] for nb in g[i])\n",
    "    ]\n",
    "        \n",
    "    # Boundary nodes: all remaining nodes of the graph. Built from g.nodes rather than\n",
    "    # range(1, num_nodes), which skipped the node labelled num_nodes and assumed 1..n labels.\n",
    "    interior_node_set = set(interior_nodes)\n",
    "    boundary_nodes = np.array([node for node in g.nodes if node not in interior_node_set])\n",
    "    boundary_nodes_feat_index = [dict_from_graph_to_feat[node] for node in boundary_nodes]\n",
    "    interior_nodes_feat_index = [dict_from_graph_to_feat[node] for node in interior_nodes]\n",
    "        \n",
    "\n",
    "    \n",
    "    \n",
    "    # Covariate arrays indexed like node_list; both covariates are scaled to have mean 1.\n",
    "    normalized_deg_array = np.array(degs)/avg_deg\n",
    "    is_inner_node_array = np.zeros(num_nodes)\n",
    "    number_of_connected_clusters_array = np.zeros(num_nodes)\n",
    "    # Set membership keeps this loop linear; `node in interior_nodes` on the list made it quadratic.\n",
    "    interior_node_set = set(interior_nodes)\n",
    "    for node in g.nodes:\n",
    "        row = node_list_index_dict[node]\n",
    "        if node in interior_node_set:\n",
    "            is_inner_node_array[row] = 1\n",
    "        number_of_connected_clusters_array[row] = g.nodes[node][\"n_cl\"]\n",
    "    \n",
    "    number_of_connected_clusters_array /= number_of_connected_clusters_array.mean()\n",
    "\n",
    "    # Weight of the covariate-by-treatment term in the outcome model.\n",
    "    cov_weight = 1\n",
    "    \n",
    "    for r2 in [0, 1]:\n",
    "        # True GATE = mean noiseless outcome when every unit is treated (z = 1), for the\n",
    "        # coefficients po_2hop_linear_model is called with below: its defaults alpha=0, beta=1,\n",
    "        # gamma=1, r1=1 together with the current r2 and cov_weight.\n",
    "        # It is computed from the graph matrices instead of the former hard-coded constants\n",
    "        # (2 and 2.978): the latter depends on the mean row sum of multi_hop_A, i.e. on the\n",
    "        # dataset, and values of r2 other than 0 or 1 left true_GATE undefined.\n",
    "        all_treated = np.ones((num_nodes, 1))\n",
    "        exposure_all_treated = np.matmul(D_inv_A, all_treated) + r2 * np.matmul(multi_hop_A, all_treated)\n",
    "        covariate_all_treated = (normalized_deg_array + number_of_connected_clusters_array).reshape(-1, 1)\n",
    "        true_GATE = float((all_treated + cov_weight * covariate_all_treated + exposure_all_treated).mean())\n",
    "            \n",
    "        print(\"TRUE GATE: {:.3F}\".format(true_GATE))\n",
    "        \n",
    "        \n",
    "        for p2 in [0.1, 0.3, 0.5]: \n",
    "            \n",
    "            # Per-repetition errors (estimate - true GATE), one list per estimator.\n",
    "            bias_CAE_list, bias_HT_list, bias_MII_list = [], [], []\n",
    "            bias_PPI_list, bias_GNN_list = [], []\n",
    "               \n",
    "            global_pred_list = np.zeros((repeat_num,))\n",
    "            boundary_pred_list = np.zeros((repeat_num,))\n",
    "            \n",
    "            # Treatment proportions simulated within one repetition (a single one here).\n",
    "            p_list = [p2]\n",
    "                \n",
    "            for seed in tqdm(range(repeat_num)):\n",
    "                # NumPy is seeded per repetition (cluster draws + outcome noise); torch is reseeded\n",
    "                # with the same constant, so its RNG state is identical at the start of every repetition.\n",
    "                np.random.seed(seed)\n",
    "                rollout_index = np.random.uniform(0, 1, size=(num_cluster,))\n",
    "                torch.manual_seed(1)\n",
    "                \n",
    "                g_feat_list = []\n",
    "                y_list = []\n",
    "                \n",
    "                # reset the HT weight of every node\n",
    "                nx.set_node_attributes(g, 0, \"w_HT\")\n",
    "                \n",
    "                for p in p_list:\n",
    "                    z_vector = np.zeros((num_nodes))\n",
    "                    nx.set_node_attributes(g, 0, \"z\")\n",
    "                    # Complete randomization: treat the clusters whose draw lies below the p-quantile.\n",
    "                    treat_threshold = np.quantile(rollout_index, p)\n",
    "                    tr_clusters = np.flatnonzero(rollout_index < treat_threshold)\n",
    "                    if len(tr_clusters) > 0:\n",
    "                        tr_units = set().union(*(clusters[i] for i in tr_clusters))\n",
    "                        nx.set_node_attributes(g, dict.fromkeys(tr_units, 1), \"z\")\n",
    "                        treated_rows = [node_list_index_dict[node] for node in tr_units]\n",
    "                        z_vector[treated_rows] = 1\n",
    "                    \n",
    "                    # Writes the attributes \"g\", \"y\" and \"z\" on every node of g.\n",
    "                    po_2hop_linear_model(g, z_vector.reshape(-1,1), sigma=sigma, r2=r2, cov_weight=cov_weight)\n",
    "                    # po_2hop_linear_model_without_covariate(g, z_vector.reshape(-1,1), sigma=sigma, r2=r2)\n",
    "                    \n",
    "                    # Node features (treatment, degree) and targets, both in g.nodes order.\n",
    "                    node_attr_dicts = [g.nodes[n] for n in g.nodes]\n",
    "                    g_feat_list.append(torch.tensor([[attrs['z'], attrs['deg']] for attrs in node_attr_dicts], dtype=torch.float))\n",
    "                    y_list.append(torch.tensor([attrs['y'] for attrs in node_attr_dicts], dtype=torch.float))\n",
    "        \n",
    "                # Fresh model and optimizer for this repetition\n",
    "                model = GCN()\n",
    "                optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "                \n",
    "                # Full-graph training: 300 epochs, one gradient step per simulated treatment vector\n",
    "                for epoch in range(300):\n",
    "                    for feats, target in zip(g_feat_list, y_list):\n",
    "                        optimizer.zero_grad()\n",
    "                        out = model(dgl_G, feats).squeeze()\n",
    "                        loss = F.mse_loss(out, target)\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "                \n",
    "                # Counterfactual input: same degree column, treatment column set to 1 for every node\n",
    "                g_feat = g_feat_list[0].clone()\n",
    "                g_feat[:, 0] = 1\n",
    "                global_treat_pred = model(dgl_G, g_feat).detach().numpy()\n",
    "                \n",
    "                # MII: mean observed outcome over the treated interior nodes\n",
    "                treated_interior = [node for node in interior_nodes if g.nodes[node][\"z\"]==1]\n",
    "                true_Y_interior = [g.nodes[node][\"y\"] for node in treated_interior]\n",
    "                treated_interior_feat_indices = [dict_from_graph_to_feat[node] for node in treated_interior]\n",
    "                \n",
    "                MII = np.mean(true_Y_interior)\n",
    "                # PPI: MII plus the mean all-treated prediction over all nodes, minus the mean\n",
    "                # all-treated prediction over the treated interior nodes\n",
    "                pred_mean_all = global_treat_pred.mean()\n",
    "                pred_mean_treated_interior = global_treat_pred[treated_interior_feat_indices].mean()\n",
    "                MII_PPI = MII + pred_mean_all - pred_mean_treated_interior\n",
    "                \n",
    "                \n",
    "                # CAE & HT\n",
    "                # Only treated nodes whose neighbours are all treated contribute.\n",
    "                cl_level_Y_interior = []\n",
    "                HT_exposure_weight = []\n",
    "                for cl_members in clusters:\n",
    "                    cl_Y = []\n",
    "                    for node in cl_members:\n",
    "                        node_data = g.nodes[node]\n",
    "                        if node_data[\"z\"] != 1:\n",
    "                            continue\n",
    "                        num_treated_ngbr = sum(g.nodes[ngbr]['z'] for ngbr in g[node])\n",
    "                        # lies in interior or clean boundary node\n",
    "                        if num_treated_ngbr != node_data['deg']:\n",
    "                            continue\n",
    "                        cl_Y.append(node_data[\"y\"])\n",
    "                        # weight (1/p) ** (number of clusters touched); `p` is the value left by\n",
    "                        # the `for p in p_list` loop above\n",
    "                        node_data['w_HT'] = (1/p) ** node_data['n_cl']\n",
    "                    \n",
    "                    # cluster-level mean, only for clusters with at least one contributing node\n",
    "                    if cl_Y:\n",
    "                        cl_level_Y_interior.append(np.array(cl_Y).mean())\n",
    "        \n",
    "                ht_weights = np.array([weight for _, weight in g.nodes(data=\"w_HT\")])\n",
    "                ht_y_array = np.array([outcome for _, outcome in g.nodes(data=\"y\")])\n",
    "                cl_level_Y_interior = np.array(cl_level_Y_interior)\n",
    "                \n",
    "                # Weighted mean of outcomes (weights are 0 for nodes that do not contribute)\n",
    "                HT = (ht_y_array * ht_weights).sum()/ht_weights.sum()\n",
    "                # Unweighted mean of the cluster-level means\n",
    "                CAE = cl_level_Y_interior.mean()\n",
    "                # Mean GCN prediction with every node's treatment feature set to 1\n",
    "                EST_GNN = global_treat_pred.mean()\n",
    "        \n",
    "                # record the error of each estimator for this repetition\n",
    "                for error_list, estimate in (\n",
    "                    (bias_HT_list, HT),\n",
    "                    (bias_CAE_list, CAE),\n",
    "                    (bias_MII_list, MII),\n",
    "                    (bias_GNN_list, EST_GNN),\n",
    "                    (bias_PPI_list, MII_PPI),\n",
    "                ):\n",
    "                    error_list.append(estimate - true_GATE)\n",
    "        \n",
    "            print(\"Treatment Proportion: {}\".format(p2))\n",
    "        \n",
    "            # Errors (estimate - true GATE) across repetitions, one array per estimator.\n",
    "            HT_array = np.array(bias_HT_list)\n",
    "            CAE_array = np.array(bias_CAE_list)\n",
    "            MII_array = np.array(bias_MII_list)\n",
    "            GNN_array = np.array(bias_GNN_list)\n",
    "            PPI_array = np.array(bias_PPI_list)\n",
    "        \n",
    "            estm_list = [HT_array, CAE_array, MII_array, GNN_array, PPI_array]\n",
    "        \n",
    "            # bias = mean error, std = spread of the errors, MSE = bias^2 + variance\n",
    "            bias_list = [errors.mean() for errors in estm_list]\n",
    "            std_list = [errors.std() for errors in estm_list]\n",
    "            mse_list = [errors.mean()**2 + errors.var() for errors in estm_list]\n",
    "        \n",
    "            print(\"Bias\\t Hajek: {:.3f}\\t CAE: {:.3f}\\t MII: {:.3f}\\t GNN: {:.3f}\\t PPI: {:.3f}\".format(*bias_list))\n",
    "            print(\"Std\\t Hajek: {:.3f}\\t CAE: {:.3f}\\t MII: {:.3f}\\t GNN: {:.3f}\\t PPI: {:.3f}\".format(*std_list))\n",
    "            print(\"MSE\\t Hajek: {:.3f}\\t CAE: {:.3f}\\t MII: {:.3f}\\t GNN: {:.3f}\\t PPI: {:.3f}\".format(*mse_list))\n",
    "            print(\"\\n\")\n",
    "            \n",
    "            # Create the output directory first so a missing folder cannot discard the\n",
    "            # results of a finished Monte Carlo run.\n",
    "            os.makedirs(\"result\", exist_ok=True)\n",
    "            torch.save(estm_list, \"result/estm_list_res{}_p{}_cov{}_r2{}.pkl\".format(res, p2, cov_weight, r2))\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myconda",
   "language": "python",
   "name": "myconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
