{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import csv\n",
    "\n",
    "with open('/home/pj20/gode/data_process/valid_smiles_ids.json', 'r') as f:\n",
    "    valid_smiles_ids = json.load(f)\n",
    "\n",
    "with open('/home/pj20/gode/data_process/valid_smiles.csv') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader)  # skip header\n",
    "\n",
    "    valid_smiles = []\n",
    "    for line in reader:\n",
    "        smiles = line[0]\n",
    "        valid_smiles.append(line[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(41212, 41212)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(valid_smiles_ids), len(valid_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "smile2entid = {}\n",
    "\n",
    "for i in range(len(valid_smiles)):\n",
    "    smile2entid[valid_smiles[i]] = valid_smiles_ids[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from grover.model.models import GroverFpGeneration, GroverFinetuneTask, GroverFinetuneKGE, GroverKGNNFinetuneTask, KGNN, MGNN\n",
    "\n",
    "def get_everything(data_path):\n",
    "    # Training Labels\n",
    "    ## Load entity type labels\n",
    "    print('Loading entity type labels...')\n",
    "    ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy')) # (num_ent, num_ent_type)\n",
    "\n",
    "    ## Load center molecule motifs\n",
    "    print('Loading center molecule motifs...')\n",
    "    motifs = []\n",
    "    with open(f'{data_path}/id2motifs.json', 'r') as f:\n",
    "        id2motifs = json.load(f)\n",
    "    motif_len = len(id2motifs['0'])\n",
    "    for i in range(len(ent_type)):\n",
    "        if str(i) in id2motifs.keys():\n",
    "            motifs.append(np.array(id2motifs[str(i)]))\n",
    "        else:\n",
    "            motifs.append(np.array([0] * motif_len))\n",
    "\n",
    "    motifs = torch.tensor(np.array(motifs), dtype=torch.long) # (num_ent, motif_len)\n",
    "\n",
    "\n",
    "    # Entire Knowledge Graph (MolKG)\n",
    "    print('Loading entire knowledge graph...')\n",
    "    with open(f'{data_path}/graph.pt', 'rb') as f:\n",
    "        G_tg = torch.load(f)\n",
    "\n",
    "    return ent_type, motifs, G_tg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "KHOP = 3\n",
    "KGE = True\n",
    "HIDDEN_EMB =1200\n",
    "def build_model_kgnn():\n",
    "    print(\"Preparing KGNN data...\")\n",
    "    data_path = '/data/pj20/molkg/pretrain_data'\n",
    "    ent_type, motifs, _ = get_everything(data_path)\n",
    "\n",
    "    kgnn = KGNN(\n",
    "        node_emb=None,\n",
    "        rel_emb=None,\n",
    "        num_nodes=ent_type.shape[0],\n",
    "        num_rels=39,\n",
    "        embedding_dim=512,\n",
    "        hidden_dim=200,\n",
    "        num_motifs=motifs.shape[1],\n",
    "    )\n",
    "\n",
    "    print(\"Loading Pre-trained KGNN ...\")\n",
    "    # kgnn.load_state_dict(torch.load(f'/data/pj20/molkg/kgnn_last_{KHOP}_hops_kge_{KGE}_{HIDDEN_EMB}.pkl', map_location='cuda:0'), strict=False)\n",
    "    kgnn.load_state_dict(torch.load(f'/data/pj20/molkg/kgnn_last_{KHOP}_hops_kge_{KGE}.pkl', map_location='cuda:0'), strict=False)\n",
    "\n",
    "    kgnn = kgnn.cuda()\n",
    "\n",
    "    return kgnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing KGNN data...\n",
      "Loading entity type labels...\n",
      "Loading center molecule motifs...\n",
      "Loading entire knowledge graph...\n",
      "Loading Pre-trained KGNN ...\n"
     ]
    }
   ],
   "source": [
    "kgnn = build_model_kgnn()\n",
    "kgnn.add_embedding()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "kgnn_emb = kgnn.node_emb.weight.data.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "184820"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(kgnn_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "512"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(kgnn_emb[-2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "def emb_map(task):\n",
    "    with open(f\"./exampledata/finetune/{task}.csv\") as f:\n",
    "        reader = csv.reader(f)\n",
    "        next(reader)  # skip header\n",
    "\n",
    "        smiles = []\n",
    "        for line in reader:\n",
    "            smiles.append(line[0])\n",
    "\n",
    "\n",
    "    pre_feature = features = np.load(f\"./exampledata/finetune/{task}.npz\")['features']\n",
    "    \n",
    "    kgnn_emb_ = []\n",
    "    for i in tqdm(range(len(smiles))):\n",
    "        smile = smiles[i]\n",
    "        if smile in smile2entid.keys():\n",
    "            id_ = smile2entid[smile]\n",
    "            emb = kgnn_emb[id_]\n",
    "            kgnn_emb_.append(emb)\n",
    "        else:\n",
    "            emb = np.zeros(len(kgnn_emb[0]))\n",
    "            kgnn_emb_.append(emb)\n",
    "    \n",
    "    kgnn_emb_ = np.array(kgnn_emb_)\n",
    "\n",
    "    post_feature = np.concatenate((pre_feature, kgnn_emb_), axis=1)\n",
    "\n",
    "    np.save(f\"./exampledata/finetune/{task}_kgnn_3hop.npy\", kgnn_emb_)\n",
    "    # np.save(f\"./exampledata/finetune/{task}_fg_kgnn.npy\", post_feature)\n",
    "\n",
    "    return kgnn_emb_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1513/1513 [00:00<00:00, 515514.37it/s]\n",
      "100%|██████████| 2039/2039 [00:00<00:00, 931306.31it/s]\n",
      "100%|██████████| 1478/1478 [00:00<00:00, 517504.07it/s]\n",
      "100%|██████████| 1128/1128 [00:00<00:00, 963775.70it/s]\n",
      "100%|██████████| 642/642 [00:00<00:00, 1011169.05it/s]\n",
      "100%|██████████| 4200/4200 [00:00<00:00, 1087479.28it/s]\n",
      "100%|██████████| 6830/6830 [00:00<00:00, 960184.22it/s]\n",
      "100%|██████████| 21786/21786 [00:00<00:00, 558946.34it/s]\n",
      "100%|██████████| 1427/1427 [00:00<00:00, 385053.51it/s]\n",
      "100%|██████████| 7831/7831 [00:00<00:00, 958911.47it/s]\n",
      "100%|██████████| 8576/8576 [00:00<00:00, 826961.65it/s]\n"
     ]
    }
   ],
   "source": [
    "tasks = ['bace', 'bbbp', 'clintox', 'esol', 'freesolv', 'lipo', 'qm7', 'qm8', 'sider', 'tox21', 'toxcast']\n",
    "\n",
    "for task in tasks:\n",
    "    emb_map(task=task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
