{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from grover.data import MoleculeDatapoint, MoleculeDataset, StandardScaler\n",
    "import csv\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader\n",
    "from grover.data import MolCollator \n",
    "from argparse import ArgumentParser, Namespace\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings object handed to MolCollator further down, built in a single call.\n",
    "args = Namespace(no_cache=True, bond_drop_rate=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute, machine-specific location of the SMILES CSV opened in the next cell.\n",
    "path = '/home/pj20/gode/data_process/valid_smiles.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every data row first so the file is closed before the dataset is built.\n",
    "# newline='' is what the csv module asks for when it is handed a file object.\n",
    "with open(path, newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip header; the default keeps an empty file from raising\n",
    "    # csv.reader yields [] for blank lines; drop them instead of failing on them.\n",
    "    lines = [row for row in reader if row]\n",
    "\n",
    "# One MoleculeDatapoint per CSV row, wrapped in a MoleculeDataset.\n",
    "data = MoleculeDataset([MoleculeDatapoint(line=line) for line in lines])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collate function used by the loader to merge datapoints into one batch.\n",
    "mol_collator = MolCollator(shared_dict={}, args=args)\n",
    "\n",
    "# NOTE(review): named train_data, but `path` points at valid_smiles.csv - confirm.\n",
    "train_data = DataLoader(\n",
    "    dataset=data,\n",
    "    batch_size=64,\n",
    "    num_workers=10,\n",
    "    collate_fn=mol_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(['CC1COC2=C3N1C=C(C(=O)C3=CC(=C2N4CCN(CC4)C)F)C(=O)O', '[H]Oc1nn([H])c(C([H])([H])[H])c1C([H])([H])[H]', 'C=C(C)C(=O)OCCC(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F', 'C1CCCNCCC1', 'CNC(=O)C[C@@H](N)C(=O)N[C@@H](C(=O)N[C@@H]1C(=O)N2[C@@H](C(=O)O)C(C)(C)S[C@H]12)c1ccc(O)cc1', '[H]C1=C([H])C([H])(C([H])([H])C([H])([H])[H])C([H])([H])C1([H])C([H])([H])[H]', 'N=C1NC(=N)c2ccccc21', '[H]OC([H])([H])c1c([H])oc([H])c1C([H])([H])[H]', 'C1=CC=CC=C1CN2C(CC(C2)CN)=O', 'OC1=CC(Cl)=C(Cl)C(Cl)=C1', '[H]OC1([H])C([H])([H])C([H])=C([H])C([H])([H])C1([H])C([H])([H])[H]', '[H]C1=NC([H])([H])C([H])([H])C(=O)O1', 'C[C@H]1[C@H]2[C@H](C[C@H]3[C@@H]4CC[C@H]5C[C@@H](O)CC[C@]5(C)[C@H]4CC(=O)[C@]23C)O[C@]11CC[C@@H](C)CO1', 'O(C(C)(C)C)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@H]([C@@H](O)C[C@H](C(=O)NCC(C)C)C)CC(C)C', 'CC[C@@H](N)CO', 'COCC(=O)N(C(C)C(=O)OC)C1=C(C)C=CC=C1C', '[H]C(=O)C1(C#N)C([H])([H])C1([H])[H]', 'CCCCCCCCCCCCCCCC', 'CC(=O)O[C@H]1CC[C@@]2(C)C(=CC[C@H]3[C@@H]4CC=C(C(C)=O)[C@@]4(C)CC[C@@H]32)C1', 'O=C(O)CO', 'Nc1nc(N)nc(-c2cc(Cl)ccc2Cl)n1', '[Cl].CCCCOc1ccc(cc1)C(=O)CCN2CCCCC2', 'CC/C=C\\\\CC/C=C/CO', '[H]OC([H])([H])C1(O[H])C([H])([H])OC([H])([H])C1([H])[H]', '[H]C(=O)c1c([H])c(N([H])[H])n([H])c1[H]', '[H]OC([H])([H])C([H])([H])N1C([H])([H])C([H])([H])C1([H])[H]', 'CC1=CC=C(Cl)C=C1', 'OCCC1=CC=CC=C1', 'CC(=O)O[C@H]1CC[C@@]2([C@H]3CC[C@]4([C@H]([C@@H]3CC=C2C1)CC=C4c5cccnc5)C)C', '[H]c1c([H])n([H])c([H])c([H])c1=O', '[H]C([H])([H])C1(C([H])([H])[H])C([H])([H])C([H])([H])C1([H])[H]', '[H]N([H])C(=O)N(C([H])([H])[H])C([H])([H])C([H])([H])[H]', '[H]C#CC([H])([H])C([H])([H])N([H])C(=O)N([H])[H]', 'CCC1C(C(C(C(=O)C(CC(C(C(C(C(C(=O)O1)C)OC2CC(C(C(O2)C)O)(C)OC)C)OC3C(C(CC(O3)C)N(C)C)O)(C)O)C)C)O)(C)O', '[H]C([H])=C([H])C#CC([H])([H])C([H])([H])C([H])([H])[H]', '[H]C(=O)C(C(=O)N([H])[H])(C([H])([H])[H])C([H])([H])[H]', 'OC(=O)COc1ccc(cc1c2ccccc2)c3nccs3', 
'CCC[C@@H](C(=O)C(=O)NC1CC1)NC(=O)[C@@H]2[C@H]3CCC[C@H]3CN2C(=O)[C@H](C(C)(C)C)NC(=O)[C@H](C4CCCCC4)NC(=O)c5cnccn5', '[H]N1C([H])([H])C([H])(C1([H])[H])C1([H])C([H])([H])C1([H])[H]', 'CN(C(=O)Cc1ccc(Cl)c(Cl)c1)[C@H]1CC[C@@]2(CCCO2)C[C@@H]1N1CCCC1', 'CN(C)C=NC1=CC=C(Cl)C=C1C', '[H]c1nc([H])n(C([H])([H])C#N)n1', 'CC(C)n1c2ccccc2c3c(C)c(NC(=O)N4CCOCC4)ccc13', '[H]c1nn(C([H])([H])[H])nc1C([H])([H])[H]', 'O=C1CC2=CC=CC=C2C1', 'CC1CCC(CO)CC1', 'CO[C@H](C(=O)[C@@H](O)[C@@H](C)O)C1Cc2cc3cc(O[C@H]4C[C@@H](O[C@H]5C[C@@H](O)[C@H](O)[C@@H](C)O5)[C@@H](O)[C@@H](C)O4)c(C)c(O)c3c(O)c2C(=O)[C@H]1O[C@H]1C[C@@H](O[C@H]2C[C@@H](O[C@H]3C[C@](C)(O)[C@H](O)[C@@H](C)O3)[C@H](O)[C@@H](C)O2)[C@H](O)[C@@H](C)O1', '[H]C1=C([H])S(=O)(=O)N([H])C1([H])[H]', 'Cl.CN1CCCC(CC1)N1N=C(CC2=CC=C(Cl)C=C2)C2=CC=CC=C2C1=O', '[H]Oc1nc(N([H])[H])c([H])nc1[H]', '[H]C(=O)c1onc(C([H])([H])[H])c1[H]', '[H]C1=C([H])C([H])(C(=O)N1[H])C([H])([H])[H]', '[H]C(=O)C(=NC([H])([H])C([H])([H])[H])N([H])C([H])([H])[H]', 'CCOC(=O)NCCOc2ccc(Oc1ccccc1)cc2', 'ClC(Cl)=C(C1=CC=C(Cl)C=C1)C1=CC=C(Cl)C=C1', 'CP(=O)([O-])OCCC[Si](O)(O)O', '[H]OC([H])([H])C([H])(C([H])([H])[H])C([H])([H])C([H])(C([H])([H])[H])C([H])([H])[H]', '[H]OC([H])(C(=O)C([H])(O[H])C([H])([H])[H])C([H])([H])[H]', 'COc1cc(OC)c(cc1NC(=O)CCC(=O)O)S(=O)(=O)N(c2ccc(C)cc2)c3ccc(C)cc3', '[H]C1([H])C(=O)OC2([H])C([H])([H])C([H])([H])C12[H]', 'C[C@]12CC[C@H]3[C@@H](CCC4=CC(=O)C=C[C@]34C)[C@@H]1CCC(=O)O2', 'CC(C)(C)OC(=O)NCCCSC[C@H]1O[C@H]([C@H](O)[C@@H]1O)n2cnc3c(N)ncnc23', 'CCN(CC)C(=O)c1ccc(cc1)C(N2CCNCC2)c3cccc4cccnc34', '[H]C([H])([H])C(=O)C([H])([H])C([H])([H])C([H])([H])[H]'], (tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 1., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]), tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]), tensor([[   0,    0,    0,    0],\n",
      "        [   2,    0,    0,    0],\n",
      "        [   1,    4,    6,    0],\n",
      "        ...,\n",
      "        [2197, 2200,    0,    0],\n",
      "        [2199, 2202,    0,    0],\n",
      "        [2201,    0,    0,    0]]), tensor([   0,    1,    2,  ..., 1056, 1056, 1057]), tensor([   0,    2,    1,  ..., 2199, 2202, 2201]), tensor([[   1,   26],\n",
      "        [  27,    8],\n",
      "        [  35,   27],\n",
      "        [  62,    8],\n",
      "        [  70,   34],\n",
      "        [ 104,    8],\n",
      "        [ 112,   11],\n",
      "        [ 123,    8],\n",
      "        [ 131,   15],\n",
      "        [ 146,   10],\n",
      "        [ 156,    8],\n",
      "        [ 164,    7],\n",
      "        [ 171,   31],\n",
      "        [ 202,   44],\n",
      "        [ 246,    6],\n",
      "        [ 252,   20],\n",
      "        [ 272,    7],\n",
      "        [ 279,   16],\n",
      "        [ 295,   26],\n",
      "        [ 321,    5],\n",
      "        [ 326,   16],\n",
      "        [ 342,   22],\n",
      "        [ 364,   10],\n",
      "        [ 374,    8],\n",
      "        [ 382,    8],\n",
      "        [ 390,    7],\n",
      "        [ 397,    8],\n",
      "        [ 405,    9],\n",
      "        [ 414,   29],\n",
      "        [ 443,    7],\n",
      "        [ 450,    6],\n",
      "        [ 456,    7],\n",
      "        [ 463,    8],\n",
      "        [ 471,   51],\n",
      "        [ 522,    7],\n",
      "        [ 529,    8],\n",
      "        [ 537,   22],\n",
      "        [ 559,   49],\n",
      "        [ 608,    7],\n",
      "        [ 615,   28],\n",
      "        [ 643,   13],\n",
      "        [ 656,    8],\n",
      "        [ 664,   26],\n",
      "        [ 690,    7],\n",
      "        [ 697,   10],\n",
      "        [ 707,    9],\n",
      "        [ 716,   76],\n",
      "        [ 792,    7],\n",
      "        [ 799,   28],\n",
      "        [ 827,    8],\n",
      "        [ 835,    8],\n",
      "        [ 843,    7],\n",
      "        [ 850,    8],\n",
      "        [ 858,   22],\n",
      "        [ 880,   18],\n",
      "        [ 898,   12],\n",
      "        [ 910,    8],\n",
      "        [ 918,    8],\n",
      "        [ 926,   36],\n",
      "        [ 962,    8],\n",
      "        [ 970,   22],\n",
      "        [ 992,   30],\n",
      "        [1022,   30],\n",
      "        [1052,    6]]), tensor([[   1,   58],\n",
      "        [  59,   16],\n",
      "        [  75,   52],\n",
      "        [ 127,   16],\n",
      "        [ 143,   72],\n",
      "        [ 215,   16],\n",
      "        [ 231,   24],\n",
      "        [ 255,   16],\n",
      "        [ 271,   32],\n",
      "        [ 303,   20],\n",
      "        [ 323,   16],\n",
      "        [ 339,   14],\n",
      "        [ 353,   72],\n",
      "        [ 425,   88],\n",
      "        [ 513,   10],\n",
      "        [ 523,   40],\n",
      "        [ 563,   14],\n",
      "        [ 577,   30],\n",
      "        [ 607,   58],\n",
      "        [ 665,    8],\n",
      "        [ 673,   34],\n",
      "        [ 707,   44],\n",
      "        [ 751,   18],\n",
      "        [ 769,   16],\n",
      "        [ 785,   16],\n",
      "        [ 801,   14],\n",
      "        [ 815,   16],\n",
      "        [ 831,   18],\n",
      "        [ 849,   66],\n",
      "        [ 915,   14],\n",
      "        [ 929,   12],\n",
      "        [ 941,   12],\n",
      "        [ 953,   14],\n",
      "        [ 967,  106],\n",
      "        [1073,   12],\n",
      "        [1085,   14],\n",
      "        [1099,   48],\n",
      "        [1147,  106],\n",
      "        [1253,   16],\n",
      "        [1269,   62],\n",
      "        [1331,   26],\n",
      "        [1357,   16],\n",
      "        [1373,   58],\n",
      "        [1431,   14],\n",
      "        [1445,   22],\n",
      "        [1467,   18],\n",
      "        [1485,  166],\n",
      "        [1651,   14],\n",
      "        [1665,   60],\n",
      "        [1725,   16],\n",
      "        [1741,   16],\n",
      "        [1757,   14],\n",
      "        [1771,   14],\n",
      "        [1785,   46],\n",
      "        [1831,   38],\n",
      "        [1869,   22],\n",
      "        [1891,   14],\n",
      "        [1905,   14],\n",
      "        [1919,   76],\n",
      "        [1995,   18],\n",
      "        [2013,   50],\n",
      "        [2063,   64],\n",
      "        [2127,   66],\n",
      "        [2193,   10]]), tensor([[   0,    0,    0,    0],\n",
      "        [   2,    0,    0,    0],\n",
      "        [   1,    3,    7,    0],\n",
      "        ...,\n",
      "        [1053, 1056,    0,    0],\n",
      "        [1055, 1057,    0,    0],\n",
      "        [1056,    0,    0,    0]])), [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None], tensor([], size=(64, 0)), tensor([], size=(64, 0)))\n"
     ]
    }
   ],
   "source": [
    "# Smoke test: print only the first collated batch, then stop iterating.\n",
    "for item in train_data:\n",
    "    print(item)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/kgc/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "path = \"/data/pj20/grover/pretrain/grover_large.pt\"\n",
    "\n",
    "state = torch.load(path, map_location=lambda storage, loc: storage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the two checkpoint entries used below out of the loaded dict.\n",
    "# NOTE(review): this rebinds `args`, replacing the Namespace built for MolCollator.\n",
    "args = state[\"args\"]\n",
    "loaded_state_dict = state[\"state_dict\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Namespace(activation='PReLU', backbone='dualtrans', bias=False, bond_drop_rate=0, cuda=True, dense=False, depth=6, dist_coff=0.1, dropout=0.0, embedding_output_type='both', fine_tune_coff=1, hidden_size=1200, input_layer='fc', no_attach_fea=True, num_attn_head=4, num_mt_block=1, select_by_loss=True, self_attention=False, skip_epoch=0, undirected=False, weight_decay=2e-07)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the settings recovered from the checkpoint.\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mgnn import build_model\n",
    "\n",
    "# Build the network from `args`; an earlier cell rebinds `args` to the Namespace\n",
    "# stored in the checkpoint, so that cell has to run before this one.\n",
    "m_gnn = build_model(args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model's current state dict (parameter name -> tensor); the next cell checks\n",
    "# checkpoint entries against its names and shapes.\n",
    "model_state_dict = m_gnn.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the checkpoint tensors whose name and shape both match the built model.\n",
    "pretrained_state_dict = {}\n",
    "for param_name, loaded_param in loaded_state_dict.items():\n",
    "    if param_name not in model_state_dict:\n",
    "        print(f'Pretrained parameter \"{param_name}\" cannot be found in model parameters.')\n",
    "        continue\n",
    "\n",
    "    expected_shape = model_state_dict[param_name].shape\n",
    "    if expected_shape != loaded_param.shape:\n",
    "        print(f'Pretrained parameter \"{param_name}\" '\n",
    "              f'of shape {loaded_param.shape} does not match corresponding '\n",
    "              f'model parameter of shape {expected_shape}.')\n",
    "        continue\n",
    "\n",
    "    print(f'Loading pretrained parameter \"{param_name}\".')\n",
    "    pretrained_state_dict[param_name] = loaded_param\n",
    "\n",
    "# Overlay the matching tensors on the model's own state and load the merged dict.\n",
    "model_state_dict.update(pretrained_state_dict)\n",
    "m_gnn.load_state_dict(model_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('kgc')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3d0509d9aa81f2882b18eeb72d4d23c32cae9029e9b99f63cde94ba86c35ac78"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
