{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import dgl\n",
    "from dgl.nn import ChebConv\n",
    "import networkx as nx\n",
    "import networkx.algorithms.community as nx_comm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.notebook import tqdm\n",
    "from functools import reduce\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def node_connects_cluster(node):\n",
    "    \"\"\"Return the cluster ids touched by `node`: its own cluster plus every neighbour's cluster.\n",
    "\n",
    "    Reads the notebook globals `g` (the networkx graph) and `inverse_cluster_dict`\n",
    "    (node -> cluster id), so it can only be called after the clustering step has run.\n",
    "    \"\"\"\n",
    "    touched = {inverse_cluster_dict[nbr] for nbr in g[node]}\n",
    "    touched.add(inverse_cluster_dict[node])\n",
    "    return touched\n",
    "\n",
    "\n",
    "# Friendship network to analyse (space-separated edge list); swap the comments to use Caltech36.\n",
    "# path = 'Dataset/socfb-Caltech36.mtx'\n",
    "path = 'Dataset/socfb-Stanford3.mtx'\n",
    "\n",
    "# skiprows=1 drops only the first line of the file. NOTE(review): this assumes no further\n",
    "# header/size lines precede the edge rows - confirm for each dataset used.\n",
    "df = pd.read_table(path, sep=\" \", names=[\"source\", \"target\"], skiprows=1)\n",
    "g = nx.from_pandas_edgelist(df)\n",
    "\n",
    "# basic graph statistics; degs is listed in g.nodes order\n",
    "num_nodes, num_edges = g.number_of_nodes(), g.number_of_edges()\n",
    "degs = [deg for _, deg in g.degree]\n",
    "avg_deg = sum(degs) / len(degs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DGL copy of g, used as the message-passing graph of the GCN. Node features are built later\n",
    "# by iterating g.nodes, so this relies on DGL numbering nodes in g.nodes order\n",
    "# (NOTE(review): confirm against the dgl.from_networkx documentation).\n",
    "dgl_G = dgl.from_networkx(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row-normalised adjacency P = D^{-1} A and the two-hop matrix P^2, both dense and ordered like g.nodes.\n",
    "A = np.array(nx.adjacency_matrix(g).todense(), dtype = np.float64)\n",
    "deg_array = np.array(list(dict(g.degree).values()))\n",
    "# Broadcasting divides row i by deg_array[i] in one vectorised step (same values as a per-row loop).\n",
    "# NOTE(review): rows sum to 1 only when deg_array equals A's row sums, i.e. the graph has no self-loops.\n",
    "D_inv_A = A / deg_array[:, None]\n",
    "\n",
    "# Two-step random-walk probabilities; a dense n x n matrix product, the costly part of this cell.\n",
    "multi_hop_A = D_inv_A @ D_inv_A\n",
    "# Zero the diagonal so a walk that returns to its start node does not count as 2-hop exposure.\n",
    "np.fill_diagonal(multi_hop_A, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed node ordering shared by the model arrays: position k <-> k-th node of g.nodes.\n",
    "node_list = np.array(list(g.nodes))\n",
    "\n",
    "# from node to index in node_list, built in one pass (an np.where scan per node was O(n^2))\n",
    "node_list_index_dict = {node: idx for idx, node in enumerate(g.nodes)}\n",
    "\n",
    "# from graph index to feature vector index; feature tensors are built by iterating g.nodes,\n",
    "# so this is the same mapping as node_list_index_dict\n",
    "dict_from_graph_to_feat = {node: i for i, node in enumerate(g.nodes)}\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def po_2hop_linear_model(graph, z_vec, alpha=0, beta=1, sigma=0.1, gamma=1, r1=1, r2=0, cov_weight=1):\n",
    "    \"\"\"Simulate potential outcomes for the treatment column vector `z_vec` and store them on `graph`.\n",
    "\n",
    "    Noise-free outcome:\n",
    "        g = alpha + beta*z + cov_weight*(normalized degree + normalized #connected clusters)*z\n",
    "            + gamma*(r1 * D_inv_A @ z + r2 * multi_hop_A @ z)\n",
    "    Observed outcome: y = g + sigma * standard-normal noise drawn from np.random.\n",
    "\n",
    "    Row i of `z_vec` belongs to node_list[i]; that node receives the attributes \"g\", \"y\", \"z\".\n",
    "    Relies on the globals normalized_deg_array, number_of_connected_clusters_array,\n",
    "    D_inv_A, multi_hop_A, num_nodes and node_list. Returns None.\n",
    "    \"\"\"\n",
    "    covariate = normalized_deg_array.reshape(-1, 1) + number_of_connected_clusters_array.reshape(-1, 1)\n",
    "    direct = alpha + beta * z_vec + cov_weight * covariate * z_vec\n",
    "    spillover = gamma * (r1 * np.matmul(D_inv_A, z_vec) + r2 * np.matmul(multi_hop_A, z_vec))\n",
    "    g_vec = direct + spillover\n",
    "    y_vec = g_vec + sigma * np.random.normal(size=(num_nodes, 1))\n",
    "\n",
    "    # write the simulated values back, one node at a time in node_list order\n",
    "    for node, g_val, y_val, z_val in zip(node_list, g_vec[:, 0], y_vec[:, 0], z_vec[:, 0]):\n",
    "        attrs = graph.nodes[node]\n",
    "        attrs[\"g\"] = g_val\n",
    "        attrs[\"y\"] = y_val\n",
    "        attrs[\"z\"] = z_val\n",
    "\n",
    "\n",
    "# def po_2hop_linear_model_gt(graph, z_vec, alpha=0, beta=1, sigma=0.1, gamma=1, r1=1, r2=0, cov_weight=1):\n",
    "#     z_vec = z_vec - z_vec + 1   \n",
    "#     g_vec = alpha + beta * z_vec + cov_weight * (normalized_deg_array + number_of_connected_clusters_array) * z_vec + gamma * (\n",
    "#     r1 * np.matmul(D_inv_A, z_vec) + r2 * np.matmul(multi_hop_A, z_vec)) # is_inner_node_array\n",
    "#     return g_vec\n",
    "        \n",
    "def po_2hop_linear_model_without_covariate(graph, z_vec, alpha=0, beta=1, sigma=0.1, gamma=1, r1=1, r2=0):\n",
    "    \"\"\"Variant of po_2hop_linear_model without the covariate x treatment term.\n",
    "\n",
    "    y = alpha + beta*z + sigma*noise + gamma*(r1 * D_inv_A @ z + r2 * multi_hop_A @ z).\n",
    "    Stores \"y\" and \"z\" (no noise-free \"g\") on each node of `graph`, in node_list order.\n",
    "    Returns None.\n",
    "    \"\"\"\n",
    "    noise = sigma * np.random.normal(size=(num_nodes, 1))\n",
    "    neighbourhood = r1 * np.matmul(D_inv_A, z_vec) + r2 * np.matmul(multi_hop_A, z_vec)\n",
    "    y_vec = alpha + beta * z_vec + noise + gamma * neighbourhood\n",
    "    for node, y_val, z_val in zip(node_list, y_vec[:, 0], z_vec[:, 0]):\n",
    "        graph.nodes[node][\"y\"] = y_val\n",
    "        graph.nodes[node][\"z\"] = z_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCN(nn.Module):\n",
    "    \"\"\"Three stacked Chebyshev graph convolutions mapping node features to one output per node.\n",
    "\n",
    "    Widths are num_inputs -> 16 -> 16 -> 1; the first layer uses Chebyshev order k=2 and the\n",
    "    other two k=1. Any activation / bias is whatever dgl's ChebConv applies by default.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_inputs=2):\n",
    "        super().__init__()\n",
    "        self.conv1 = ChebConv(num_inputs, 16, 2)\n",
    "        self.conv2 = ChebConv(16, 16, 1)\n",
    "        self.conv3 = ChebConv(16, 1, 1)\n",
    "\n",
    "    def forward(self, g, features):\n",
    "        \"\"\"Run conv1 -> conv2 -> conv3 over graph `g` and return conv3's output.\"\"\"\n",
    "        h = features\n",
    "        for conv in (self.conv1, self.conv2, self.conv3):\n",
    "            h = conv(g, h)\n",
    "        return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRUE GATE: 3.000\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fe063a536d004175b8ca60812dce2ad4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/dgl/nn/pytorch/conv/chebconv.py:108: DGLWarning: lambda_max is not provided, using default value of 2.  Please use dgl.laplacian_lambda_max to compute the eigenvalues.\n",
      "  dgl_warning(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -0.965\t CAE: -0.947\t MII: -0.955\t GNN: -1.491\t PPI: -0.593\n",
      "Std\t Hajek: 0.558\t CAE: 0.342\t MII: 0.242\t GNN: 1.001\t PPI: 0.406\n",
      "MSE\t Hajek: 1.243\t CAE: 1.014\t MII: 0.971\t GNN: 3.225\t PPI: 0.516\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ec4386cd83748a0ad1302ec810504b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -0.908\t CAE: -0.948\t MII: -0.961\t GNN: -0.653\t PPI: -0.250\n",
      "Std\t Hajek: 0.420\t CAE: 0.193\t MII: 0.130\t GNN: 0.188\t PPI: 0.159\n",
      "MSE\t Hajek: 1.001\t CAE: 0.937\t MII: 0.940\t GNN: 0.462\t PPI: 0.088\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84f6f8c7fbf6491eadf1e6b3a609f203",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -0.874\t CAE: -0.948\t MII: -0.961\t GNN: -0.437\t PPI: -0.168\n",
      "Std\t Hajek: 0.370\t CAE: 0.136\t MII: 0.092\t GNN: 0.106\t PPI: 0.097\n",
      "MSE\t Hajek: 0.901\t CAE: 0.917\t MII: 0.933\t GNN: 0.202\t PPI: 0.038\n",
      "\n",
      "\n",
      "TRUE GATE: 3.978\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5206a16595914fe8860195e8012be649",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.1\n",
      "Bias\t Hajek: -1.480\t CAE: -1.555\t MII: -1.471\t GNN: -1.902\t PPI: -1.079\n",
      "Std\t Hajek: 0.575\t CAE: 0.345\t MII: 0.274\t GNN: 0.899\t PPI: 0.389\n",
      "MSE\t Hajek: 2.520\t CAE: 2.538\t MII: 2.239\t GNN: 4.427\t PPI: 1.316\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "38f687d063ae4fc0870ce3a7f950bac6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.3\n",
      "Bias\t Hajek: -1.274\t CAE: -1.428\t MII: -1.341\t GNN: -1.150\t PPI: -0.733\n",
      "Std\t Hajek: 0.438\t CAE: 0.195\t MII: 0.154\t GNN: 0.178\t PPI: 0.165\n",
      "MSE\t Hajek: 1.814\t CAE: 2.076\t MII: 1.822\t GNN: 1.355\t PPI: 0.564\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e6ddafb2fd048b9907f709e6e77b1d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Treatment Proportion: 0.5\n",
      "Bias\t Hajek: -1.129\t CAE: -1.293\t MII: -1.239\t GNN: -0.790\t PPI: -0.480\n",
      "Std\t Hajek: 0.377\t CAE: 0.138\t MII: 0.100\t GNN: 0.117\t PPI: 0.101\n",
      "MSE\t Hajek: 1.416\t CAE: 1.692\t MII: 1.544\t GNN: 0.638\t PPI: 0.241\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Simulation settings\n",
    "sigma = 2            # std of the outcome noise\n",
    "repeat_num = 1000    # Monte Carlo replications per (r2, treatment proportion) setting\n",
    "cov_weight = 0.5     # weight of the covariate x treatment term in the outcome model\n",
    "\n",
    "# torch.save at the end of every setting writes into result/; create it up front so a long\n",
    "# run does not crash on a fresh checkout after the first setting has finished.\n",
    "os.makedirs(\"result\", exist_ok=True)\n",
    "\n",
    "\n",
    "for res in [5]:\n",
    "\n",
    "    # clustering\n",
    "    # generally, we fix the outcome of clustering (fixed Louvain seed)\n",
    "    clusters = nx_comm.louvain_communities(g, seed = 10, resolution=res)\n",
    "    clusters = sorted(clusters, key = len, reverse=True)\n",
    "    cluster_sizes = list(map(len, clusters))\n",
    "    num_cluster = len(clusters)\n",
    "\n",
    "    # dict: from node to its cluster\n",
    "    inverse_cluster_dict = {\n",
    "        node: cl for cl in range(num_cluster) for node in clusters[cl]\n",
    "    }\n",
    "\n",
    "    # dict: from node to the set of clusters it touches (own + neighbours').\n",
    "    # Iterate g.nodes rather than assuming the labels are exactly 1..num_nodes.\n",
    "    node_to_connected_clusters = {node: node_connects_cluster(node) for node in g.nodes}\n",
    "\n",
    "    for i in g.nodes:\n",
    "        g.nodes[i][\"n_cl\"] = len(node_to_connected_clusters[i])\n",
    "        g.nodes[i][\"deg\"] = g.degree[i]\n",
    "\n",
    "    # interior nodes: every neighbour lies in the node's own cluster\n",
    "    interior_nodes = [\n",
    "        node for node in g.nodes\n",
    "        if all(inverse_cluster_dict[nb] == inverse_cluster_dict[node] for nb in g[node])\n",
    "    ]\n",
    "    interior_set = set(interior_nodes)\n",
    "    # boundary nodes: all remaining nodes (the old range(1, num_nodes) silently dropped the last label)\n",
    "    boundary_nodes = np.array([node for node in g.nodes if node not in interior_set])\n",
    "    boundary_nodes_feat_index = [dict_from_graph_to_feat[node] for node in boundary_nodes]\n",
    "    interior_nodes_feat_index = [dict_from_graph_to_feat[node] for node in interior_nodes]\n",
    "\n",
    "    # node covariates indexed like node_list; the two used by the outcome model are scaled to mean 1\n",
    "    normalized_deg_array = np.array(degs)/avg_deg\n",
    "    is_inner_node_array = np.zeros(num_nodes)\n",
    "    number_of_connected_clusters_array = np.zeros(num_nodes)\n",
    "    for node in g.nodes:\n",
    "        idx = node_list_index_dict[node]\n",
    "        if node in interior_set:  # set lookup; the former list lookup made this loop O(n^2)\n",
    "            is_inner_node_array[idx] = 1\n",
    "        number_of_connected_clusters_array[idx] = g.nodes[node][\"n_cl\"]\n",
    "\n",
    "    number_of_connected_clusters_array /= number_of_connected_clusters_array.mean()\n",
    "\n",
    "    for r2 in [0, 1]:\n",
    "        # True GATE = noise-free mean outcome of po_2hop_linear_model (defaults alpha=0, beta=1,\n",
    "        # gamma=1, r1=1) when every node is treated (z = 1); D_inv_A @ 1 and multi_hop_A @ 1 are\n",
    "        # their row sums. Computed exactly instead of hard-coding a Stanford3-specific constant.\n",
    "        all_treated_outcome = (\n",
    "            1\n",
    "            + cov_weight * (normalized_deg_array + number_of_connected_clusters_array)\n",
    "            + D_inv_A.sum(axis=1) + r2 * multi_hop_A.sum(axis=1)\n",
    "        )\n",
    "        true_GATE = float(all_treated_outcome.mean())\n",
    "\n",
    "        print(\"TRUE GATE: {:.3F}\".format(true_GATE))\n",
    "\n",
    "        for p2 in [0.1, 0.3, 0.5]:\n",
    "\n",
    "            bias_CAE_list = []\n",
    "            bias_HT_list = []\n",
    "            bias_MII_list = []\n",
    "            bias_PPI_list = []\n",
    "            bias_GNN_list = []\n",
    "\n",
    "            p_list = [p2]\n",
    "            # p_list = [0.1, 0.5]\n",
    "\n",
    "            for seed in tqdm(range(repeat_num)):\n",
    "                np.random.seed(seed)\n",
    "                rollout_index = np.random.uniform(0, 1, size=(num_cluster,))\n",
    "                torch.manual_seed(1)\n",
    "\n",
    "                g_feat_list = []\n",
    "                y_list = []\n",
    "\n",
    "                # initialize the HT weights\n",
    "                nx.set_node_attributes(g, 0, \"w_HT\")\n",
    "\n",
    "                for p in p_list:\n",
    "                    z_vector = np.zeros((num_nodes))\n",
    "                    nx.set_node_attributes(g, 0, \"z\")\n",
    "                    # complete randomization: treat the clusters whose draw lies below the p-quantile\n",
    "                    tr_clusters = np.arange(num_cluster)[rollout_index<np.quantile(rollout_index, p)]\n",
    "                    if len(tr_clusters) > 0:\n",
    "                        tr_units = reduce(lambda x, y: x.union(y), [clusters[i] for i in tr_clusters])\n",
    "                        nx.set_node_attributes(g, {unit:1 for unit in tr_units}, \"z\")\n",
    "                        for node in tr_units:\n",
    "                            z_vector[node_list_index_dict[node]] = 1\n",
    "\n",
    "                    po_2hop_linear_model(g, z_vector.reshape(-1,1), sigma=sigma, r2=r2, cov_weight=cov_weight) # add dimension\n",
    "\n",
    "                    # GCN inputs: (treatment, raw degree) per node, plus the observed outcome\n",
    "                    g_feat_list.append(torch.tensor([[g.nodes[n]['z'], g.nodes[n]['deg']] for n in g.nodes], dtype=torch.float))\n",
    "                    y_list.append(torch.tensor([[g.nodes[n]['y']] for n in g.nodes], dtype=torch.float).reshape(-1))\n",
    "\n",
    "                # Instantiate the model and optimizer\n",
    "                model = GCN()\n",
    "                optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "\n",
    "                # Train the model on every simulated design\n",
    "                for epoch in range(300):\n",
    "                    for feats, target in zip(g_feat_list, y_list):\n",
    "                        optimizer.zero_grad()\n",
    "                        out = model(dgl_G, feats).squeeze()\n",
    "                        loss = F.mse_loss(out, target)\n",
    "                        # loss = F.mse_loss(out[boundary_nodes_feat_index], target[boundary_nodes_feat_index])\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # GNN prediction of every node's outcome under global treatment (z = 1 everywhere)\n",
    "                g_feat = g_feat_list[0].clone()\n",
    "                g_feat[:,0] = 1\n",
    "                global_treat_pred = model(dgl_G, g_feat).detach().numpy()\n",
    "\n",
    "                # MII: mean observed outcome of the treated interior nodes\n",
    "                treated_interior = [node for node in interior_nodes if g.nodes[node][\"z\"]==1]\n",
    "                true_Y_interior = [g.nodes[node][\"y\"] for node in treated_interior]\n",
    "                treated_interior_feat_indices = [dict_from_graph_to_feat[node] for node in treated_interior]\n",
    "\n",
    "                MII = np.mean(true_Y_interior)\n",
    "                # PPI: MII corrected by the GNN's (all nodes - treated interior nodes) prediction gap\n",
    "                MII_PPI = MII + global_treat_pred.mean() - global_treat_pred[treated_interior_feat_indices].mean()\n",
    "\n",
    "                # CAE & HT: a treated node counts when all of its neighbours are treated as well.\n",
    "                # p_list[-1] is the design that produced the current \"z\" attributes.\n",
    "                p_treat = p_list[-1]\n",
    "                cl_level_Y_interior = []\n",
    "                for cl in range(num_cluster):\n",
    "                    cl_Y = []\n",
    "                    for node in clusters[cl]:\n",
    "                        if g.nodes[node][\"z\"]==1:\n",
    "                            node_deg = g.nodes[node]['deg']\n",
    "                            num_treated_ngbr = sum(g.nodes[ngbr]['z'] for ngbr in g[node])\n",
    "                            # lies in interior or clean boundary node\n",
    "                            if num_treated_ngbr == node_deg:\n",
    "                                cl_Y.append(g.nodes[node][\"y\"])\n",
    "                                g.nodes[node]['w_HT'] = (1/p_treat) ** (g.nodes[node]['n_cl'])\n",
    "                    if len(cl_Y) > 0:\n",
    "                        cl_level_Y_interior.append(np.mean(cl_Y))\n",
    "\n",
    "                ht_weights = np.array([g.nodes[i][\"w_HT\"] for i in g.nodes])\n",
    "                ht_y_array = np.array([g.nodes[i][\"y\"] for i in g.nodes])\n",
    "\n",
    "                cl_level_Y_interior = np.array(cl_level_Y_interior)\n",
    "\n",
    "                HT = (ht_y_array * ht_weights).sum()/ht_weights.sum()  # Hajek (self-normalised) estimator\n",
    "                CAE = cl_level_Y_interior.mean()\n",
    "                EST_GNN = global_treat_pred.mean()\n",
    "\n",
    "                bias_HT_list.append(HT - true_GATE)\n",
    "                bias_CAE_list.append(CAE - true_GATE)\n",
    "                bias_MII_list.append(MII - true_GATE)\n",
    "                bias_GNN_list.append(EST_GNN - true_GATE)\n",
    "                bias_PPI_list.append(MII_PPI - true_GATE)\n",
    "\n",
    "            print(\"Treatment Proportion: {}\".format(p2))\n",
    "\n",
    "            HT_array = np.array(bias_HT_list)\n",
    "            CAE_array = np.array(bias_CAE_list)\n",
    "            MII_array = np.array(bias_MII_list)\n",
    "            GNN_array = np.array(bias_GNN_list)\n",
    "            PPI_array = np.array(bias_PPI_list)\n",
    "\n",
    "            # each array holds (estimate - true GATE) across the replications\n",
    "            estm_list = [HT_array, CAE_array, MII_array, GNN_array, PPI_array]\n",
    "\n",
    "            bias_list = [x.mean() for x in estm_list]\n",
    "            std_list = [x.std() for x in estm_list]\n",
    "            mse_list = [x.mean()**2 + x.var() for x in estm_list]  # MSE = bias^2 + variance\n",
    "\n",
    "            print(\"Bias\\t Hajek: {:.3f}\\t CAE: {:.3f}\\t MII: {:.3f}\\t GNN: {:.3f}\\t PPI: {:.3f}\".format(*bias_list))\n",
    "            print(\"Std\\t Hajek: {:.3f}\\t CAE: {:.3f}\\t MII: {:.3f}\\t GNN: {:.3f}\\t PPI: {:.3f}\".format(*std_list))\n",
    "            print(\"MSE\\t Hajek: {:.3f}\\t CAE: {:.3f}\\t MII: {:.3f}\\t GNN: {:.3f}\\t PPI: {:.3f}\".format(*mse_list))\n",
    "            print(\"\\n\")\n",
    "\n",
    "            torch.save(estm_list, \"result/estm_list_res{}_p{}_cov{}_r2{}.pkl\".format(res, p2, cov_weight, r2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myconda",
   "language": "python",
   "name": "myconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
