{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yoonhyeok/anaconda3/envs/grit/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Import libraries\n",
    "import networkx as nx\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.utils import from_networkx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ckt-Bench101: entries 0 and 1 of the pickle (presumably two dataset splits - confirm)\n",
    "# are concatenated into one sequence of circuits.\n",
    "# NOTE(review): unpickling can execute arbitrary code; only load files from a trusted source.\n",
    "ckt_101_igraph_path = \"CktBench101/ckt_bench_101.pkl\"\n",
    "ckt_101_parts = pd.read_pickle(ckt_101_igraph_path)\n",
    "ckt_101_igraph = ckt_101_parts[0] + ckt_101_parts[1]\n",
    "\n",
    "len(ckt_101_igraph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ckt-Bench301: the unpickled object is used directly as the sequence of circuits\n",
    "ckt_301_igraph_path = \"CktBench301/ckt_bench_301.pkl\"\n",
    "ckt_301_igraph = pd.read_pickle(ckt_301_igraph_path)\n",
    "\n",
    "len(ckt_301_igraph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ckt-Bench performance labels (the first CSV column is used as the row index)\n",
    "ckt_101_label = pd.read_csv(\"CktBench101/perform101.csv\", index_col=0)\n",
    "ckt_301_label = pd.read_csv(\"CktBench301/perform301.csv\", index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the validity flag aside so the min-max scaling of the metric columns leaves it untouched.\n",
    "ckt_101_valid_index = ckt_101_label.loc[:, \"valid\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Min-max scale every metric column (all columns except `valid`): (x - min) / (max - min).\n",
    "ckt_101_metrics = ckt_101_label.drop(columns=\"valid\")\n",
    "ckt_101_label = (ckt_101_metrics - ckt_101_metrics.min()) / (ckt_101_metrics.max() - ckt_101_metrics.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-attach the unscaled validity flag to the scaled metrics.\n",
    "ckt_101_label = ckt_101_label.assign(valid=ckt_101_valid_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>gain</th>\n",
       "      <th>pm</th>\n",
       "      <th>bw</th>\n",
       "      <th>fom</th>\n",
       "      <th>valid</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.741644</td>\n",
       "      <td>0.288365</td>\n",
       "      <td>0.391568</td>\n",
       "      <td>0.388257</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.544696</td>\n",
       "      <td>0.499622</td>\n",
       "      <td>0.003782</td>\n",
       "      <td>0.017046</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.912797</td>\n",
       "      <td>0.729811</td>\n",
       "      <td>0.547656</td>\n",
       "      <td>0.560851</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.874374</td>\n",
       "      <td>0.687333</td>\n",
       "      <td>0.014866</td>\n",
       "      <td>0.039415</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.751872</td>\n",
       "      <td>0.558718</td>\n",
       "      <td>0.473774</td>\n",
       "      <td>0.479860</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>0.195525</td>\n",
       "      <td>0.285147</td>\n",
       "      <td>0.006768</td>\n",
       "      <td>0.007099</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>0.535768</td>\n",
       "      <td>0.274264</td>\n",
       "      <td>0.670378</td>\n",
       "      <td>0.657082</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>0.379575</td>\n",
       "      <td>0.602285</td>\n",
       "      <td>0.006815</td>\n",
       "      <td>0.022485</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>0.483916</td>\n",
       "      <td>0.473328</td>\n",
       "      <td>0.194293</td>\n",
       "      <td>0.200923</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>0.349384</td>\n",
       "      <td>0.295981</td>\n",
       "      <td>0.004618</td>\n",
       "      <td>0.007118</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10000 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          gain        pm        bw       fom  valid\n",
       "0     0.741644  0.288365  0.391568  0.388257      1\n",
       "1     0.544696  0.499622  0.003782  0.017046      1\n",
       "2     0.912797  0.729811  0.547656  0.560851      1\n",
       "3     0.874374  0.687333  0.014866  0.039415      1\n",
       "4     0.751872  0.558718  0.473774  0.479860      1\n",
       "...        ...       ...       ...       ...    ...\n",
       "9995  0.195525  0.285147  0.006768  0.007099      1\n",
       "9996  0.535768  0.274264  0.670378  0.657082      1\n",
       "9997  0.379575  0.602285  0.006815  0.022485      1\n",
       "9998  0.483916  0.473328  0.194293  0.200923      1\n",
       "9999  0.349384  0.295981  0.004618  0.007118      1\n",
       "\n",
       "[10000 rows x 5 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ckt_101_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the validity flag aside so the min-max scaling of the metric columns leaves it untouched.\n",
    "ckt_301_valid_index = ckt_301_label.loc[:, \"valid\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Min-max scale every metric column (all columns except `valid`): (x - min) / (max - min).\n",
    "# The min and max are taken over all rows, including rows whose `valid` flag is 0.\n",
    "ckt_301_metrics = ckt_301_label.drop(columns=\"valid\")\n",
    "ckt_301_label = (ckt_301_metrics - ckt_301_metrics.min()) / (ckt_301_metrics.max() - ckt_301_metrics.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-attach the unscaled validity flag to the scaled metrics.\n",
    "ckt_301_label = ckt_301_label.assign(valid=ckt_301_valid_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>gain</th>\n",
       "      <th>pm</th>\n",
       "      <th>bw</th>\n",
       "      <th>fom</th>\n",
       "      <th>valid</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>36229</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44535</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90358</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46650</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30550</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85197</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64332</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45225</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43605</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20522</th>\n",
       "      <td>0.278091</td>\n",
       "      <td>0.499986</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.012204</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2752 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           gain        pm   bw       fom  valid\n",
       "36229  0.278091  0.499986  0.0  0.012204      0\n",
       "44535  0.278091  0.499986  0.0  0.012204      0\n",
       "90358  0.278091  0.499986  0.0  0.012204      0\n",
       "46650  0.278091  0.499986  0.0  0.012204      0\n",
       "30550  0.278091  0.499986  0.0  0.012204      0\n",
       "...         ...       ...  ...       ...    ...\n",
       "85197  0.278091  0.499986  0.0  0.012204      0\n",
       "64332  0.278091  0.499986  0.0  0.012204      0\n",
       "45225  0.278091  0.499986  0.0  0.012204      0\n",
       "43605  0.278091  0.499986  0.0  0.012204      0\n",
       "20522  0.278091  0.499986  0.0  0.012204      0\n",
       "\n",
       "[2752 rows x 5 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows flagged invalid (valid == 0); the recorded run shows 2752 of them.\n",
    "ckt_301_label.loc[ckt_301_label[\"valid\"] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row positions (0-based, not index labels) of the circuits whose `valid` flag is 0.\n",
    "invalid_idx = [pos for pos, flag in enumerate(ckt_301_label[\"valid\"]) if flag == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only circuits whose position is not flagged invalid.\n",
    "# A set makes each membership test O(1) instead of scanning the whole invalid_idx list.\n",
    "invalid_idx_set = set(invalid_idx)\n",
    "ckt_301_igraph_valid = [ckt for idx, ckt in enumerate(ckt_301_igraph) if idx not in invalid_idx_set]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Converting to the format used by the GraphGPS framework"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_networks(data):\n",
    "    \"\"\"Return a NetworkX DiGraph mirroring `data`.\n",
    "\n",
    "    Nodes are the integers 0 .. data.num_nodes - 1, and one directed edge\n",
    "    (source, target) is added for every column of `data.edge_index`.\n",
    "    \"\"\"\n",
    "    graph = nx.DiGraph()\n",
    "    graph.add_nodes_from(range(data.num_nodes))\n",
    "    graph.add_edges_from(tuple(pair) for pair in data.edge_index.t().tolist())\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yoonhyeok/anaconda3/envs/grit/lib/python3.9/site-packages/torch_geometric/data/storage.py:280: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'edge_index'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Convert every Ckt-Bench101 circuit into a dict of tensors (x, edge_index, edge_attr, y, gain, pm, bw).\n",
    "ckt_101_pt = []\n",
    "\n",
    "for idx, ckt in enumerate(ckt_101_igraph):\n",
    "    _, node_ckt = ckt\n",
    "    # Node features: column 0 is the vertex `type` attribute, column 1 the vertex `feat` attribute.\n",
    "    node_types = torch.tensor(node_ckt.vs[\"type\"], dtype=torch.float)\n",
    "    node_feat = torch.tensor(node_ckt.vs[\"feat\"], dtype=torch.float)\n",
    "    node_features = torch.stack([node_types, node_feat], dim=1)\n",
    "    edge_index = torch.tensor([e.tuple for e in node_ckt.es], dtype=torch.long).t().contiguous()\n",
    "    # Labels are matched to circuits by row position, so the label rows must be in circuit order.\n",
    "    labels = ckt_101_label.iloc[idx]\n",
    "    fom = torch.tensor([labels.fom], dtype=torch.float)\n",
    "    gain = torch.tensor([labels.gain], dtype=torch.float)\n",
    "    pm = torch.tensor([labels.pm], dtype=torch.float)\n",
    "    bw = torch.tensor([labels.bw], dtype=torch.float)\n",
    "    # Pass num_nodes explicitly: inferring it from edge_index alone raises a PyG warning and\n",
    "    # can undercount vertices that have no incident edge.\n",
    "    tmp_data = Data(edge_index=edge_index, num_nodes=node_features.size(0))\n",
    "    G = convert_to_networks(tmp_data)\n",
    "    edge_betweenness = nx.edge_betweenness_centrality(G, normalized=True)\n",
    "    edge_load_centrality = nx.edge_load_centrality(G)\n",
    "    trophic_differences = nx.trophic_differences(G)\n",
    "    # Build edge_attr by walking edge_index column by column so that row i of edge_attr always\n",
    "    # describes edge i of edge_index. Reading the attributes back through from_networkx would\n",
    "    # return them in NetworkX adjacency order, which need not match the igraph edge order.\n",
    "    edge_attr = torch.tensor(\n",
    "        [\n",
    "            [edge_betweenness[(u, v)], edge_load_centrality[(u, v)], trophic_differences[(u, v)]]\n",
    "            for u, v in edge_index.t().tolist()\n",
    "        ],\n",
    "        dtype=torch.float,\n",
    "    ).reshape(-1, 3)\n",
    "    data = {\"x\":node_features, \"edge_index\":edge_index, \"edge_attr\":edge_attr, \"y\":fom, \"gain\":gain, \"pm\":pm, \"bw\":bw}\n",
    "    ckt_101_pt.append(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ckt_101_pt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(ckt_101_pt, \"CktBench101/ckt_bench_101.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yoonhyeok/anaconda3/envs/grit/lib/python3.9/site-packages/torch_geometric/data/storage.py:280: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'edge_index'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Convert every Ckt-Bench301 circuit into a dict of tensors (x, edge_index, edge_attr, y, gain, pm, bw).\n",
    "ckt_301_pt = []\n",
    "\n",
    "for idx, ckt in enumerate(ckt_301_igraph):\n",
    "    _, node_ckt = ckt\n",
    "    # Node features: column 0 is the vertex `type` attribute, column 1 the vertex `feat` attribute.\n",
    "    node_types = torch.tensor(node_ckt.vs[\"type\"], dtype=torch.float)\n",
    "    node_feat = torch.tensor(node_ckt.vs[\"feat\"], dtype=torch.float)\n",
    "    node_features = torch.stack([node_types, node_feat], dim=1)\n",
    "    edge_index = torch.tensor([e.tuple for e in node_ckt.es], dtype=torch.long).t().contiguous()\n",
    "    # Labels are matched to circuits by row position, so the label rows must be in circuit order.\n",
    "    labels = ckt_301_label.iloc[idx]\n",
    "    fom = torch.tensor([labels.fom], dtype=torch.float)\n",
    "    gain = torch.tensor([labels.gain], dtype=torch.float)\n",
    "    pm = torch.tensor([labels.pm], dtype=torch.float)\n",
    "    bw = torch.tensor([labels.bw], dtype=torch.float)\n",
    "    # Pass num_nodes explicitly: inferring it from edge_index alone raises a PyG warning and\n",
    "    # can undercount vertices that have no incident edge.\n",
    "    tmp_data = Data(edge_index=edge_index, num_nodes=node_features.size(0))\n",
    "    G = convert_to_networks(tmp_data)\n",
    "    edge_betweenness = nx.edge_betweenness_centrality(G, normalized=True)\n",
    "    edge_load_centrality = nx.edge_load_centrality(G)\n",
    "    trophic_differences = nx.trophic_differences(G)\n",
    "    # Build edge_attr by walking edge_index column by column so that row i of edge_attr always\n",
    "    # describes edge i of edge_index. Reading the attributes back through from_networkx would\n",
    "    # return them in NetworkX adjacency order, which need not match the igraph edge order.\n",
    "    edge_attr = torch.tensor(\n",
    "        [\n",
    "            [edge_betweenness[(u, v)], edge_load_centrality[(u, v)], trophic_differences[(u, v)]]\n",
    "            for u, v in edge_index.t().tolist()\n",
    "        ],\n",
    "        dtype=torch.float,\n",
    "    ).reshape(-1, 3)\n",
    "    data = {\"x\":node_features, \"edge_index\":edge_index, \"edge_attr\":edge_attr, \"y\":fom, \"gain\":gain, \"pm\":pm, \"bw\":bw}\n",
    "    ckt_301_pt.append(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ckt_301_pt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the converted circuits whose position is flagged invalid before saving.\n",
    "# A set makes each membership test O(1) instead of scanning the whole invalid_idx list.\n",
    "invalid_idx_set = set(invalid_idx)\n",
    "ckt_301_pt_save = [ckt for idx, ckt in enumerate(ckt_301_pt) if idx not in invalid_idx_set]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "47248"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ckt_301_pt_save)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(ckt_301_pt_save, \"CktBench301/ckt_bench_301.pt\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "grit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
