{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import csv\n",
    "\n",
    "# Ids loaded here are paired by position with the CSV rows read below\n",
    "# (see the cell that builds smile2entid), so both files must stay in the same order.\n",
    "with open('/home/pj20/gode/data_process/valid_smiles_ids.json', 'r') as f:\n",
    "    valid_smiles_ids = json.load(f)\n",
    "\n",
    "# newline='' is what the csv module asks for when it is handed a file object.\n",
    "with open('/home/pj20/gode/data_process/valid_smiles.csv', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader)  # skip header\n",
    "\n",
    "    # The first column of every data row is taken as the SMILES string.\n",
    "    valid_smiles = [line[0] for line in reader]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(41212, 41212)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The two counts should match: the next cell pairs SMILES strings with ids by position.\n",
    "len(valid_smiles_ids), len(valid_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SMILES string -> id, paired by position. When a SMILES string occurs more\n",
    "# than once, the id belonging to its last occurrence is the one that is kept.\n",
    "smile2entid = {\n",
    "    smi: valid_smiles_ids[pos]\n",
    "    for pos, smi in enumerate(valid_smiles)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Load the pickled entity-embedding object. The next cell calls .cpu() on it,\n",
    "# so it is presumably a torch tensor — TODO confirm.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.\n",
    "with open('/data/pj20/molkg_kge/transe/entity_embedding_1200.pkl', 'rb') as f:\n",
    "    ent_emb = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "184819"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert the loaded embedding table to a NumPy array exactly once.\n",
    "# The result is stored back into `ent_emb`, so without this guard a second run\n",
    "# of the cell would fail: NumPy arrays have no .cpu()/.detach() methods.\n",
    "if hasattr(ent_emb, 'detach'):\n",
    "    ent_emb = ent_emb.cpu().detach().numpy()\n",
    "len(ent_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.05708027, -0.01096982, -0.02211162, ...,  0.03796372,\n",
       "       -0.00046855, -0.01706253], dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the last row of the embedding table as a quick sanity check.\n",
    "ent_emb[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "def emb_map(task):\n",
    "    \"\"\"Append knowledge-graph embeddings to the pre-computed features of `task`.\n",
    "\n",
    "    Reads the first column of ./exampledata/finetune/{task}.csv (header skipped)\n",
    "    as SMILES strings and the 'features' array of ./exampledata/finetune/{task}.npz.\n",
    "    Every SMILES string found in the notebook-level `smile2entid` gets the matching\n",
    "    row of `ent_emb`; all other rows stay zero. The two matrices are joined\n",
    "    column-wise and saved to ./exampledata/finetune/{task}_fg_kge.npy.\n",
    "\n",
    "    Args:\n",
    "        task: Dataset name used to build the three file paths.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (number of CSV data rows, embedding width) holding only the\n",
    "        knowledge-graph embeddings, in CSV row order. It is not written to disk.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the CSV and the feature array have different row counts.\n",
    "    \"\"\"\n",
    "    with open(f\"./exampledata/finetune/{task}.csv\", newline='') as f:\n",
    "        reader = csv.reader(f)\n",
    "        next(reader)  # skip header\n",
    "        smiles = [line[0] for line in reader]\n",
    "\n",
    "    # np.load returns an NpzFile that holds the archive open; read the array\n",
    "    # inside a `with` so the file handle is released straight away.\n",
    "    with np.load(f\"./exampledata/finetune/{task}.npz\") as archive:\n",
    "        pre_feature = archive['features']\n",
    "\n",
    "    if len(pre_feature) != len(smiles):\n",
    "        raise ValueError(\n",
    "            f\"{task}: {len(smiles)} SMILES rows but {len(pre_feature)} feature rows\"\n",
    "        )\n",
    "\n",
    "    # Pre-allocate with the embedding dtype: rows without a match stay zero, the\n",
    "    # dtype no longer depends on whether any SMILES was missing, and an empty\n",
    "    # CSV still yields a 2-D array that can be concatenated along axis 1.\n",
    "    kge_emb_ = np.zeros((len(smiles), ent_emb.shape[1]), dtype=ent_emb.dtype)\n",
    "    for i, smile in enumerate(tqdm(smiles)):\n",
    "        if smile in smile2entid:\n",
    "            kge_emb_[i] = ent_emb[smile2entid[smile]]\n",
    "\n",
    "    post_feature = np.concatenate((pre_feature, kge_emb_), axis=1)\n",
    "    np.save(f\"./exampledata/finetune/{task}_fg_kge.npy\", post_feature)\n",
    "\n",
    "    return kge_emb_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1513/1513 [00:00<00:00, 210271.11it/s]\n",
      "100%|██████████| 2039/2039 [00:00<00:00, 680148.39it/s]\n",
      "100%|██████████| 1478/1478 [00:00<00:00, 761570.19it/s]\n",
      "100%|██████████| 1128/1128 [00:00<00:00, 1008134.44it/s]\n",
      "100%|██████████| 642/642 [00:00<00:00, 1009652.48it/s]\n",
      "100%|██████████| 4200/4200 [00:00<00:00, 1050389.17it/s]\n",
      "100%|██████████| 6830/6830 [00:00<00:00, 823973.78it/s]\n",
      "100%|██████████| 21786/21786 [00:00<00:00, 513923.31it/s]\n",
      "100%|██████████| 1427/1427 [00:00<00:00, 465634.96it/s]\n",
      "100%|██████████| 7831/7831 [00:00<00:00, 743466.23it/s]\n",
      "100%|██████████| 8576/8576 [00:00<00:00, 623208.55it/s]\n"
     ]
    }
   ],
   "source": [
    "# Datasets whose feature files get the knowledge-graph embeddings appended.\n",
    "tasks = [\n",
    "    'bace', 'bbbp', 'clintox', 'esol', 'freesolv', 'lipo',\n",
    "    'qm7', 'qm8', 'sider', 'tox21', 'toxcast',\n",
    "]\n",
    "\n",
    "# The returned matrices are discarded here; each call writes its own .npy file.\n",
    "for task in tasks:\n",
    "    emb_map(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1400"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Width of the combined feature matrix that emb_map('bace') saved: the\n",
    "# pre-computed feature columns followed by the knowledge-graph embedding columns.\n",
    "np.load('./exampledata/finetune/bace_fg_kge.npy').shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
