{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the notebook folder and its parent importable so the local `gvci` and `node` packages resolve.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "for module_path in map(os.path.abspath, ('.', '..')):\n",
    "    if module_path not in sys.path:\n",
    "        sys.path.append(module_path)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scanpy as sc\n",
    "import networkx as nx\n",
    "\n",
    "import torch\n",
    "import torch_geometric as pyg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to process (other dataset names noted by the author: marson, sciplex).\n",
    "dataset = \"L008\"\n",
    "\n",
    "# 1. initial configuration: GRN pickle to load and the gene2vec embedding size.\n",
    "# NOTE(review): the saved outputs further down show 128-dimensional embeddings, i.e. they were\n",
    "# produced with a different `dimension` than the value set here - re-run to refresh them.\n",
    "graph = \"grn.pkl\"\n",
    "dimension = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. updated configuration: dataset-specific GRN pickle and 128-dimensional gene2vec embeddings.\n",
    "# Set USE_UPDATED_CONFIG = True to override the initial settings from the previous cell.\n",
    "USE_UPDATED_CONFIG = False\n",
    "if USE_UPDATED_CONFIG:\n",
    "    graph = \"%s_grn.pkl\" % dataset\n",
    "    dimension = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mito</th>\n",
       "      <th>n_cells</th>\n",
       "      <th>highly_variable</th>\n",
       "      <th>means</th>\n",
       "      <th>dispersions</th>\n",
       "      <th>dispersions_norm</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>PERM1</th>\n",
       "      <td>False</td>\n",
       "      <td>48</td>\n",
       "      <td>True</td>\n",
       "      <td>0.000282</td>\n",
       "      <td>1.064582</td>\n",
       "      <td>2.723460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AL645608.7</th>\n",
       "      <td>False</td>\n",
       "      <td>8</td>\n",
       "      <td>True</td>\n",
       "      <td>0.000048</td>\n",
       "      <td>1.515906</td>\n",
       "      <td>4.124676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TTLL10</th>\n",
       "      <td>False</td>\n",
       "      <td>136</td>\n",
       "      <td>True</td>\n",
       "      <td>0.000833</td>\n",
       "      <td>0.976748</td>\n",
       "      <td>2.320261</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TNFRSF18</th>\n",
       "      <td>False</td>\n",
       "      <td>86547</td>\n",
       "      <td>True</td>\n",
       "      <td>0.938895</td>\n",
       "      <td>0.915269</td>\n",
       "      <td>5.125483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TNFRSF4</th>\n",
       "      <td>False</td>\n",
       "      <td>88190</td>\n",
       "      <td>True</td>\n",
       "      <td>1.004453</td>\n",
       "      <td>0.994387</td>\n",
       "      <td>5.840080</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RPS4Y1</th>\n",
       "      <td>False</td>\n",
       "      <td>37645</td>\n",
       "      <td>True</td>\n",
       "      <td>0.490005</td>\n",
       "      <td>1.509542</td>\n",
       "      <td>20.917027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PRKY</th>\n",
       "      <td>False</td>\n",
       "      <td>16216</td>\n",
       "      <td>True</td>\n",
       "      <td>0.097039</td>\n",
       "      <td>0.925138</td>\n",
       "      <td>4.583006</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>USP9Y</th>\n",
       "      <td>False</td>\n",
       "      <td>21406</td>\n",
       "      <td>True</td>\n",
       "      <td>0.131560</td>\n",
       "      <td>0.896381</td>\n",
       "      <td>5.459471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>DDX3Y</th>\n",
       "      <td>False</td>\n",
       "      <td>27168</td>\n",
       "      <td>True</td>\n",
       "      <td>0.205420</td>\n",
       "      <td>1.001347</td>\n",
       "      <td>9.897668</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>EIF1AY</th>\n",
       "      <td>False</td>\n",
       "      <td>32750</td>\n",
       "      <td>True</td>\n",
       "      <td>0.301572</td>\n",
       "      <td>1.123131</td>\n",
       "      <td>13.103132</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2048 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "             mito  n_cells  highly_variable     means  dispersions  \\\n",
       "PERM1       False       48             True  0.000282     1.064582   \n",
       "AL645608.7  False        8             True  0.000048     1.515906   \n",
       "TTLL10      False      136             True  0.000833     0.976748   \n",
       "TNFRSF18    False    86547             True  0.938895     0.915269   \n",
       "TNFRSF4     False    88190             True  1.004453     0.994387   \n",
       "...           ...      ...              ...       ...          ...   \n",
       "RPS4Y1      False    37645             True  0.490005     1.509542   \n",
       "PRKY        False    16216             True  0.097039     0.925138   \n",
       "USP9Y       False    21406             True  0.131560     0.896381   \n",
       "DDX3Y       False    27168             True  0.205420     1.001347   \n",
       "EIF1AY      False    32750             True  0.301572     1.123131   \n",
       "\n",
       "            dispersions_norm  \n",
       "PERM1               2.723460  \n",
       "AL645608.7          4.124676  \n",
       "TTLL10              2.320261  \n",
       "TNFRSF18            5.125483  \n",
       "TNFRSF4             5.840080  \n",
       "...                      ...  \n",
       "RPS4Y1             20.917027  \n",
       "PRKY                4.583006  \n",
       "USP9Y               5.459471  \n",
       "DDX3Y               9.897668  \n",
       "EIF1AY             13.103132  \n",
       "\n",
       "[2048 rows x 6 columns]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the preprocessed AnnData for the selected dataset and show its per-gene table.\n",
    "adata_path = '../datasets/%s_prepped.h5ad' % dataset\n",
    "adata = sc.read(adata_path)\n",
    "adata.var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['PERM1', 'AL645608.7', 'TTLL10', ..., 'USP9Y', 'DDX3Y', 'EIF1AY'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gene names in adata.var order; they become the graph nodes and fix the feature-row order.\n",
    "# Names must be unique, otherwise the gene -> index mapping built later silently keeps only the last duplicate.\n",
    "assert adata.var.index.is_unique, \"duplicate gene names in adata.var\"\n",
    "nodes = np.array(adata.var.index)\n",
    "nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset folders: the gene list is written to source_path and gene2vec output is read back from export_path.\n",
    "source_path = \"./node/source/%s\" % dataset\n",
    "export_path = \"./node/export/%s\" % dataset\n",
    "for directory in (source_path, export_path):\n",
    "    os.makedirs(directory, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the gene names as a single column (one name per line) for gene2vec.\n",
    "genes_file = os.path.join(source_path, \"genes.txt\")\n",
    "np.savetxt(genes_file, nodes[:, np.newaxis], fmt=\"%s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gvci.utils.general_utils import Suppressor\n",
    "from node.gene2vec import gene2vec\n",
    "\n",
    "# Run gene2vec on the gene list in source_path; its embeddings of size `dimension` are read from export_path next.\n",
    "# Suppressor() presumably silences gene2vec's console output - confirm in gvci.utils.general_utils.\n",
    "gene2vec_in = os.path.abspath(source_path)\n",
    "gene2vec_out = os.path.abspath(export_path)\n",
    "with Suppressor():\n",
    "    gene2vec(gene2vec_in, gene2vec_out, 'txt', dimension=dimension)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>...</th>\n",
       "      <th>119</th>\n",
       "      <th>120</th>\n",
       "      <th>121</th>\n",
       "      <th>122</th>\n",
       "      <th>123</th>\n",
       "      <th>124</th>\n",
       "      <th>125</th>\n",
       "      <th>126</th>\n",
       "      <th>127</th>\n",
       "      <th>128</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>LINC00858</th>\n",
       "      <td>0.003542</td>\n",
       "      <td>-0.001215</td>\n",
       "      <td>0.001945</td>\n",
       "      <td>0.001278</td>\n",
       "      <td>-0.001829</td>\n",
       "      <td>0.001718</td>\n",
       "      <td>0.000143</td>\n",
       "      <td>-0.001030</td>\n",
       "      <td>-0.002358</td>\n",
       "      <td>0.002222</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000305</td>\n",
       "      <td>0.001580</td>\n",
       "      <td>-0.002127</td>\n",
       "      <td>-0.002126</td>\n",
       "      <td>0.000488</td>\n",
       "      <td>0.002571</td>\n",
       "      <td>0.003811</td>\n",
       "      <td>-0.000152</td>\n",
       "      <td>-0.002099</td>\n",
       "      <td>0.000958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TRAV2</th>\n",
       "      <td>0.002181</td>\n",
       "      <td>-0.002596</td>\n",
       "      <td>0.001249</td>\n",
       "      <td>0.002189</td>\n",
       "      <td>0.001884</td>\n",
       "      <td>-0.001652</td>\n",
       "      <td>0.000192</td>\n",
       "      <td>0.002426</td>\n",
       "      <td>-0.000076</td>\n",
       "      <td>0.003001</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000096</td>\n",
       "      <td>-0.000737</td>\n",
       "      <td>-0.000560</td>\n",
       "      <td>0.002024</td>\n",
       "      <td>-0.001201</td>\n",
       "      <td>-0.002491</td>\n",
       "      <td>-0.001386</td>\n",
       "      <td>-0.001832</td>\n",
       "      <td>-0.002688</td>\n",
       "      <td>0.002967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AC078883.2</th>\n",
       "      <td>-0.003033</td>\n",
       "      <td>-0.001809</td>\n",
       "      <td>0.001240</td>\n",
       "      <td>0.003574</td>\n",
       "      <td>-0.001829</td>\n",
       "      <td>0.001150</td>\n",
       "      <td>-0.003513</td>\n",
       "      <td>0.003745</td>\n",
       "      <td>-0.000522</td>\n",
       "      <td>-0.001208</td>\n",
       "      <td>...</td>\n",
       "      <td>0.002082</td>\n",
       "      <td>-0.002157</td>\n",
       "      <td>-0.001686</td>\n",
       "      <td>0.001767</td>\n",
       "      <td>0.002506</td>\n",
       "      <td>-0.003903</td>\n",
       "      <td>0.003278</td>\n",
       "      <td>0.000143</td>\n",
       "      <td>0.000004</td>\n",
       "      <td>0.001928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>TRBV20-1</th>\n",
       "      <td>0.001685</td>\n",
       "      <td>0.003851</td>\n",
       "      <td>-0.003494</td>\n",
       "      <td>0.000167</td>\n",
       "      <td>0.003281</td>\n",
       "      <td>-0.002159</td>\n",
       "      <td>0.000452</td>\n",
       "      <td>0.002975</td>\n",
       "      <td>0.003568</td>\n",
       "      <td>-0.002158</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000272</td>\n",
       "      <td>0.003697</td>\n",
       "      <td>0.003641</td>\n",
       "      <td>-0.000085</td>\n",
       "      <td>0.001998</td>\n",
       "      <td>-0.000632</td>\n",
       "      <td>0.001540</td>\n",
       "      <td>-0.000766</td>\n",
       "      <td>0.000373</td>\n",
       "      <td>-0.003807</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AL356433.1</th>\n",
       "      <td>0.000107</td>\n",
       "      <td>0.002853</td>\n",
       "      <td>-0.001400</td>\n",
       "      <td>0.002992</td>\n",
       "      <td>0.001033</td>\n",
       "      <td>-0.001295</td>\n",
       "      <td>-0.001350</td>\n",
       "      <td>-0.000953</td>\n",
       "      <td>-0.001533</td>\n",
       "      <td>-0.003903</td>\n",
       "      <td>...</td>\n",
       "      <td>0.002857</td>\n",
       "      <td>-0.001940</td>\n",
       "      <td>0.000215</td>\n",
       "      <td>0.000149</td>\n",
       "      <td>-0.001111</td>\n",
       "      <td>-0.003449</td>\n",
       "      <td>0.000961</td>\n",
       "      <td>-0.002117</td>\n",
       "      <td>0.001884</td>\n",
       "      <td>-0.003056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ACTBL2</th>\n",
       "      <td>0.002277</td>\n",
       "      <td>-0.003387</td>\n",
       "      <td>0.002414</td>\n",
       "      <td>-0.000591</td>\n",
       "      <td>-0.000002</td>\n",
       "      <td>0.002597</td>\n",
       "      <td>-0.002413</td>\n",
       "      <td>-0.003434</td>\n",
       "      <td>0.000297</td>\n",
       "      <td>-0.003735</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.002573</td>\n",
       "      <td>-0.003697</td>\n",
       "      <td>-0.000489</td>\n",
       "      <td>0.002347</td>\n",
       "      <td>0.002833</td>\n",
       "      <td>-0.000664</td>\n",
       "      <td>0.001193</td>\n",
       "      <td>0.003360</td>\n",
       "      <td>-0.002272</td>\n",
       "      <td>-0.003755</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AL365184.1</th>\n",
       "      <td>0.000186</td>\n",
       "      <td>-0.003339</td>\n",
       "      <td>0.002609</td>\n",
       "      <td>0.001022</td>\n",
       "      <td>0.001954</td>\n",
       "      <td>-0.003444</td>\n",
       "      <td>0.000040</td>\n",
       "      <td>-0.003451</td>\n",
       "      <td>0.000744</td>\n",
       "      <td>-0.003660</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.000292</td>\n",
       "      <td>-0.002637</td>\n",
       "      <td>-0.002500</td>\n",
       "      <td>-0.001029</td>\n",
       "      <td>0.003164</td>\n",
       "      <td>-0.001360</td>\n",
       "      <td>-0.001486</td>\n",
       "      <td>0.003028</td>\n",
       "      <td>0.000205</td>\n",
       "      <td>0.000859</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F7</th>\n",
       "      <td>-0.003686</td>\n",
       "      <td>-0.002746</td>\n",
       "      <td>-0.001082</td>\n",
       "      <td>0.000925</td>\n",
       "      <td>-0.001716</td>\n",
       "      <td>-0.003214</td>\n",
       "      <td>-0.000732</td>\n",
       "      <td>-0.000915</td>\n",
       "      <td>-0.002174</td>\n",
       "      <td>0.000795</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.000083</td>\n",
       "      <td>0.003070</td>\n",
       "      <td>0.001057</td>\n",
       "      <td>-0.003901</td>\n",
       "      <td>-0.003571</td>\n",
       "      <td>0.001401</td>\n",
       "      <td>-0.002979</td>\n",
       "      <td>0.003123</td>\n",
       "      <td>-0.003853</td>\n",
       "      <td>-0.001550</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LAYN</th>\n",
       "      <td>-0.003777</td>\n",
       "      <td>-0.003208</td>\n",
       "      <td>-0.000018</td>\n",
       "      <td>-0.003389</td>\n",
       "      <td>0.001076</td>\n",
       "      <td>0.000279</td>\n",
       "      <td>-0.000475</td>\n",
       "      <td>0.000826</td>\n",
       "      <td>0.001353</td>\n",
       "      <td>-0.000708</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.000322</td>\n",
       "      <td>-0.002708</td>\n",
       "      <td>-0.003624</td>\n",
       "      <td>0.003009</td>\n",
       "      <td>-0.001911</td>\n",
       "      <td>0.000925</td>\n",
       "      <td>0.001040</td>\n",
       "      <td>-0.003085</td>\n",
       "      <td>0.000399</td>\n",
       "      <td>-0.000879</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>CYMP-AS1</th>\n",
       "      <td>0.003671</td>\n",
       "      <td>-0.000911</td>\n",
       "      <td>-0.000822</td>\n",
       "      <td>-0.000258</td>\n",
       "      <td>0.001072</td>\n",
       "      <td>0.000860</td>\n",
       "      <td>-0.002000</td>\n",
       "      <td>0.001827</td>\n",
       "      <td>-0.001006</td>\n",
       "      <td>-0.000821</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.002661</td>\n",
       "      <td>0.003220</td>\n",
       "      <td>0.003562</td>\n",
       "      <td>-0.000228</td>\n",
       "      <td>-0.001937</td>\n",
       "      <td>0.003155</td>\n",
       "      <td>-0.000226</td>\n",
       "      <td>-0.002597</td>\n",
       "      <td>0.003290</td>\n",
       "      <td>0.001762</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2048 rows × 128 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                 1         2         3         4         5         6    \\\n",
       "0                                                                        \n",
       "LINC00858   0.003542 -0.001215  0.001945  0.001278 -0.001829  0.001718   \n",
       "TRAV2       0.002181 -0.002596  0.001249  0.002189  0.001884 -0.001652   \n",
       "AC078883.2 -0.003033 -0.001809  0.001240  0.003574 -0.001829  0.001150   \n",
       "TRBV20-1    0.001685  0.003851 -0.003494  0.000167  0.003281 -0.002159   \n",
       "AL356433.1  0.000107  0.002853 -0.001400  0.002992  0.001033 -0.001295   \n",
       "...              ...       ...       ...       ...       ...       ...   \n",
       "ACTBL2      0.002277 -0.003387  0.002414 -0.000591 -0.000002  0.002597   \n",
       "AL365184.1  0.000186 -0.003339  0.002609  0.001022  0.001954 -0.003444   \n",
       "F7         -0.003686 -0.002746 -0.001082  0.000925 -0.001716 -0.003214   \n",
       "LAYN       -0.003777 -0.003208 -0.000018 -0.003389  0.001076  0.000279   \n",
       "CYMP-AS1    0.003671 -0.000911 -0.000822 -0.000258  0.001072  0.000860   \n",
       "\n",
       "                 7         8         9         10   ...       119       120  \\\n",
       "0                                                   ...                       \n",
       "LINC00858   0.000143 -0.001030 -0.002358  0.002222  ...  0.000305  0.001580   \n",
       "TRAV2       0.000192  0.002426 -0.000076  0.003001  ...  0.000096 -0.000737   \n",
       "AC078883.2 -0.003513  0.003745 -0.000522 -0.001208  ...  0.002082 -0.002157   \n",
       "TRBV20-1    0.000452  0.002975  0.003568 -0.002158  ...  0.000272  0.003697   \n",
       "AL356433.1 -0.001350 -0.000953 -0.001533 -0.003903  ...  0.002857 -0.001940   \n",
       "...              ...       ...       ...       ...  ...       ...       ...   \n",
       "ACTBL2     -0.002413 -0.003434  0.000297 -0.003735  ... -0.002573 -0.003697   \n",
       "AL365184.1  0.000040 -0.003451  0.000744 -0.003660  ... -0.000292 -0.002637   \n",
       "F7         -0.000732 -0.000915 -0.002174  0.000795  ... -0.000083  0.003070   \n",
       "LAYN       -0.000475  0.000826  0.001353 -0.000708  ... -0.000322 -0.002708   \n",
       "CYMP-AS1   -0.002000  0.001827 -0.001006 -0.000821  ... -0.002661  0.003220   \n",
       "\n",
       "                 121       122       123       124       125       126  \\\n",
       "0                                                                        \n",
       "LINC00858  -0.002127 -0.002126  0.000488  0.002571  0.003811 -0.000152   \n",
       "TRAV2      -0.000560  0.002024 -0.001201 -0.002491 -0.001386 -0.001832   \n",
       "AC078883.2 -0.001686  0.001767  0.002506 -0.003903  0.003278  0.000143   \n",
       "TRBV20-1    0.003641 -0.000085  0.001998 -0.000632  0.001540 -0.000766   \n",
       "AL356433.1  0.000215  0.000149 -0.001111 -0.003449  0.000961 -0.002117   \n",
       "...              ...       ...       ...       ...       ...       ...   \n",
       "ACTBL2     -0.000489  0.002347  0.002833 -0.000664  0.001193  0.003360   \n",
       "AL365184.1 -0.002500 -0.001029  0.003164 -0.001360 -0.001486  0.003028   \n",
       "F7          0.001057 -0.003901 -0.003571  0.001401 -0.002979  0.003123   \n",
       "LAYN       -0.003624  0.003009 -0.001911  0.000925  0.001040 -0.003085   \n",
       "CYMP-AS1    0.003562 -0.000228 -0.001937  0.003155 -0.000226 -0.002597   \n",
       "\n",
       "                 127       128  \n",
       "0                               \n",
       "LINC00858  -0.002099  0.000958  \n",
       "TRAV2      -0.002688  0.002967  \n",
       "AC078883.2  0.000004  0.001928  \n",
       "TRBV20-1    0.000373 -0.003807  \n",
       "AL356433.1  0.001884 -0.003056  \n",
       "...              ...       ...  \n",
       "ACTBL2     -0.002272 -0.003755  \n",
       "AL365184.1  0.000205  0.000859  \n",
       "F7         -0.003853 -0.001550  \n",
       "LAYN        0.000399 -0.000879  \n",
       "CYMP-AS1    0.003290  0.001762  \n",
       "\n",
       "[2048 rows x 128 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the gene2vec output: space-separated text, gene name in the first column, one column per dimension.\n",
    "# Format the file name before joining so a '%' in export_path can never be read as a format directive.\n",
    "embedding_file = \"gene2vec_dim_%s_iter_10.txt\" % dimension\n",
    "nodes_emb = pd.read_csv(os.path.join(export_path, embedding_file), sep=\" \", index_col=0, header=None)\n",
    "# Every graph node needs a feature vector; fail here with a clear message instead of a KeyError later.\n",
    "missing_genes = set(nodes).difference(nodes_emb.index)\n",
    "assert not missing_genes, \"genes without a gene2vec embedding: %s\" % sorted(missing_genes)[:10]\n",
    "nodes_emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each gene name to its embedding vector as a plain Python list of floats.\n",
    "nodes_emb_dict = {gene: row.tolist() for gene, row in zip(nodes_emb.index, nodes_emb.to_numpy())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# How the generic GRN pickle was built (kept for provenance, not executed):\n",
    "# from src.utils.graph_utils import parse_grn\n",
    "# graph_df = pd.read_parquet('hg19_TFinfo_dataframe_gimmemotifsv5_fpr2_threshold_10_20210630.parquet')\n",
    "# nx_graph = parse_grn(graph_df, 'gene_short_name')\n",
    "# pickle.dump(nx_graph, open('./grn.pkl', 'wb'))\n",
    "\n",
    "# Load the networkx GRN. NOTE: pickle.load can execute arbitrary code - only load trusted files.\n",
    "with open(graph, 'rb') as graph_file:\n",
    "    nx_graph = pickle.load(graph_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make every adata gene a node (isolated if the GRN does not mention it), then hide all\n",
    "# other GRN nodes (and their edges) behind a read-only subgraph view.\n",
    "for n in nodes:\n",
    "    if not nx_graph.has_node(n):\n",
    "        nx_graph.add_node(n)\n",
    "\n",
    "# Set membership is O(1); `n in nodes` on a numpy array scans every gene, and the view\n",
    "# applies this filter lazily on each node/edge lookup.\n",
    "gene_set = frozenset(nodes)\n",
    "\n",
    "def filter_node(n):\n",
    "    return n in gene_set\n",
    "\n",
    "nx_graph = nx.subgraph_view(nx_graph, filter_node=filter_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gene name -> position in `nodes` (the row index used for features and edges below).\n",
    "nodes_dict = {gene: idx for idx, gene in enumerate(nodes)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature rows follow `nodes` order. For each gene, every incoming GRN edge (u, gene)\n",
    "# becomes an index pair (index of u, index of gene).\n",
    "x = [nodes_emb_dict[gene] for gene in nodes]\n",
    "edge_index = [\n",
    "    (nodes_dict[src], target)\n",
    "    for target, gene in enumerate(nodes)\n",
    "    for src, _ in nx_graph.in_edges(gene)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2048, 128])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the per-gene embeddings into a (num_genes, dimension) floating-point tensor.\n",
    "x = torch.tensor(x, dtype=torch.get_default_dtype())\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([25288, 2])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (num_edges, 2) int64 tensor of (source, target) pairs. reshape(-1, 2) keeps that shape even when the\n",
    "# GRN contributes no edges (torch.LongTensor([]) is 1-D, so the later .t() would not give [2, 0]).\n",
    "# as_tensor also makes the cell safe to re-run once edge_index is already a tensor.\n",
    "edge_index = torch.as_tensor(edge_index, dtype=torch.long).reshape(-1, 2)\n",
    "edge_index.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle node features and the [2, num_edges] edge index (PyG layout) into one graph and save it.\n",
    "# .t() alone returns a strided view; PyG expects a contiguous edge_index, hence .contiguous().\n",
    "grn = pyg.data.Data(x=x, edge_index=edge_index.t().contiguous())\n",
    "torch.save(grn, '%s_grn_%s.pth' % (dataset, dimension))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(x=[2048, 128], edge_index=[2, 25288])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the saved graph and check that it round-trips the object built above.\n",
    "# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may refuse to unpickle a\n",
    "# PyG Data object; if so, pass weights_only=False (the file is produced locally by this notebook).\n",
    "grn_reloaded = torch.load('%s_grn_%s.pth' % (dataset, dimension))\n",
    "assert torch.equal(grn_reloaded.x, grn.x)\n",
    "assert torch.equal(grn_reloaded.edge_index, grn.edge_index)\n",
    "grn_reloaded"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.0 ('gvci-env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a18dafcc48613bb6c3783e5aa2cfbbab56400df7d9096bc06931f344217b23cf"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
