{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ProToken 1.0 Multimer Example\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Copyright 2024 Changping Laboratory & Peking University. All Rights Reserved.\n",
    "# Licensed under the Apache License, Version 2.0 (the “License”);\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an “AS IS” BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load basic libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-04-23 21:18:38.715040: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import os, jax\n",
    "import pickle as pkl\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from data_process.preprocess import save_pdb_from_aux, protoken_encoder_preprocess, protoken_decoder_preprocess, init_protoken_model\n",
    "from data_process.preprocess import protoken_encoder_input_features, protoken_decoder_input_features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B. Multimer Structures Encoding and Decoding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. Prepare the task information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multimer example: each PDB file found in `pdb_input_dir` is expected to hold a single chain.\n",
    "\n",
    "task_mode = 'multi' # 'single' or 'multi'\n",
    "task_name = 'multimer_example'\n",
    "pdb_input_dir = './examples/multimer'\n",
    "\n",
    "# Fail early, with a clear message, before any output directory is created.\n",
    "if task_mode not in ('single', 'multi'):\n",
    "    raise ValueError(f\"task_mode must be 'single' or 'multi', got {task_mode!r}\")\n",
    "if not os.path.isdir(pdb_input_dir):\n",
    "    raise FileNotFoundError(f'Input directory not found: {pdb_input_dir}')\n",
    "\n",
    "# All artifacts of this run are written below ./results/<task_name>.\n",
    "saving_dir = f'./results/{task_name}'\n",
    "os.makedirs(saving_dir, exist_ok=True)\n",
    "\n",
    "pdb_saving_path = os.path.join(saving_dir, 'reconstructed_protein.pdb')\n",
    "code_saving_path = os.path.join(saving_dir, 'protoken_index.pkl')\n",
    "\n",
    "# Notes:\n",
    "# Three models are available, covering sequence lengths 0-512, 512-1024 and 1024-2048.\n",
    "# Choose the model based on the sequence length of your protein.\n",
    "# If the sequence length falls outside the current model's range, the model must be re-initialized.\n",
    "# Have fun!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Prepare the encoder inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 2 pdb files in the input directory.\n",
      "WARNING: The input pdb files should be single-chain pdb files.\n",
      "seq_mask (1024,)\n",
      "residue_index (1024,)\n",
      "backbone_atom_masks (1024, 37)\n",
      "backbone_atom_positions (1024, 37, 3)\n",
      "ca_pos (1024, 3)\n",
      "backbone_affine_tensor (1024, 7)\n",
      "torsion_angles_sin_cos (1024, 6)\n",
      "torsion_angles_mask (1024, 3)\n"
     ]
    }
   ],
   "source": [
    "# Turn the input PDB files into the encoder features plus auxiliary information.\n",
    "encoder_inputs, encoder_aux, seq_len = protoken_encoder_preprocess(pdb_input_dir, task_mode=task_mode)\n",
    "\n",
    "# Show the shape of every feature, pairing names and arrays by position.\n",
    "for feature_name, feature in zip(protoken_encoder_input_features, encoder_inputs):\n",
    "    print(feature_name, feature.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: {'pdb_name': '7W51_B.pdb', 'seq_len': 154, 'start_idx': 0},\n",
       " 1: {'pdb_name': '7W51_A.pdb', 'seq_len': 361, 'start_idx': 154}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Multimer auxiliary information: one entry per input chain, giving its PDB file name,\n",
    "# its length and its start offset in the concatenated sequence.\n",
    "encoder_aux['chain_length_info']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. Warm up the encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-04-23 21:19:30.098989: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Did not find GPU, will use CPU for prediction\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Importing a function (__inference_internal_grad_fn_113922) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
      "WARNING:absl:Importing a function (__inference_internal_grad_fn_195207) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
      "2024-04-23 21:19:53.781014: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 4294967296 exceeds 10% of free system memory.\n",
      "2024-04-23 21:19:54.894360: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 4294967296 exceeds 10% of free system memory.\n",
      "2024-04-23 21:19:57.575232: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 4294967296 exceeds 10% of free system memory.\n",
      "2024-04-23 21:19:58.284182: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2147483648 exceeds 10% of free system memory.\n",
      "2024-04-23 21:19:58.284224: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2147483648 exceeds 10% of free system memory.\n"
     ]
    }
   ],
   "source": [
    "# Initialize the model (it exposes `.encoder` and `.decoder`, used in the cells below) for this\n",
    "# sequence length. Per the notes in the task-setup cell, each model covers a fixed length range,\n",
    "# so this cell must be re-run if `seq_len` moves outside the current range.\n",
    "model = init_protoken_model(seq_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4. Encode the protein structure and get the ProToken Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the encoder on the preprocessed features, passed positionally.\n",
    "# NOTE(review): presumably the argument order is that of `protoken_encoder_input_features`\n",
    "# (the shape printout above pairs the two by position) -- confirm against the model signature.\n",
    "encoder_results = model.encoder(*encoder_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PDB ID:  7W51_B.pdb\n",
      "Chain Length:  154\n",
      "ProToken Index: (154,)\n",
      "[ 65 471 223 368  25 167  29 336  87 155  49 109 125 268 486  30  97 324\n",
      " 303 212 393 415 411 369  38  38  50 385 487 383 346 463  47 337 346 466\n",
      " 254 415  12 227 454  18 161 346 268 333  80 277  26 174 111 301 354  82\n",
      " 495 249 297  46 291 313 227 134  72 307 132  64 235  71 422 184 485 226\n",
      "  34 272 507  30 509 507 361 261  60 257 240 210 466 293 354 459 391 230\n",
      " 256 267 227  51 315 389 495  25 405 343 314 219 337 305 268  97 212   5\n",
      " 254 367 449 215 131 294 106 319 249 461 262 347 368 369 354 471 315  47\n",
      " 440  65 232 113 177  87 159  58 118 254 167 327 439 191 440 215 469  33\n",
      " 206 379  25 507 504 216  15 140 202  91]\n",
      "PDB ID:  7W51_A.pdb\n",
      "Chain Length:  361\n",
      "ProToken Index: (361,)\n",
      "[307 154 463 345 496 481 146 361  54 176  16  72 328 259 483 395 479 346\n",
      " 385 342 504 224   7 215 348 113 478 470 346 414 346  50 449 200 300 141\n",
      " 367 139 274 492 265  58  26 240 257  95 369 472 216 241 350 459 427 293\n",
      "  51 389 442  38 404 109 393 391 229 329  49 324  51 416 461   6 315 383\n",
      " 240 430 414 216  34 231 470 383 329 343 196 269 176 304 261  63 384 504\n",
      " 231 142 191  79 174 423  45 387 393  14 469 337 407  42  45 113 104 110\n",
      " 201  82  38 110 252   9  30  92 321 161 318 493 266 486   9  65 324 120\n",
      " 315 179 503 118  70  96 161  41 227  14 164 240  25   1 134 439 395 220\n",
      "   9 395 457 469 256 160 411 446 415  34 218 248 240 174 164 164 461 262\n",
      " 118 328 438   0 393 349 140 293  78 201 254 130 302 109 264 483  95 284\n",
      "  94 132 427 476  16 101 395 152 411  38  51 329 346 291 123 283   3  91\n",
      " 243  50 174 231 306 200 414 395 425 392 163  94 122  49  50 274 487 136\n",
      " 233 264 358 395 196 411 315  47 330 216  28 241 452 202  34 330 190 287\n",
      " 167 405 152 333  72 110 393  51 123 109 167  51 228 248 262 336 257 476\n",
      " 130 333  99 196 257 442 152 389 130 231  28 177 256 197  59 212 270  87\n",
      "  49 279 256  14 215 509 231 327 118 210  25 116 384 255 237 131 313  93\n",
      " 325  67 149 175 256  51 469 336 368 267 349 439 471  65  13 114 167 194\n",
      " 367 392  87  25 399  76 331 486 160  84 315   1   0  76 500 427   0 231\n",
      "  94 275 404  46 159 329 213 351 159 224 297 319  63 278 177 463  38 391\n",
      " 202 119 201 415 142 136 453   4 428 291 194 167 361 176 483 384 162 368\n",
      " 479]\n"
     ]
    }
   ],
   "source": [
    "# Keep only the tokens of real residues: select with the boolean sequence mask\n",
    "# instead of looping over every position in Python.\n",
    "valid_residue_mask = np.asarray(encoder_aux['seq_mask']).astype(bool)\n",
    "protoken_index_ = np.asarray(encoder_results['protoken_index'])[valid_residue_mask]\n",
    "\n",
    "chain_length_info = encoder_aux['chain_length_info']\n",
    "\n",
    "# Split the flat token array into one array per chain. Each chain is paired with its own\n",
    "# tokens inside the loop, so the dict key is never used as a list index, and the records\n",
    "# to be saved go into a new dict so that `encoder_aux` is left untouched (re-runnable cell).\n",
    "protoken_index_multimer = []\n",
    "protoken_records = {}\n",
    "for chain_id, chain_info in chain_length_info.items():\n",
    "    chain_start = chain_info['start_idx']\n",
    "    chain_tokens = protoken_index_[chain_start:chain_start + chain_info['seq_len']]\n",
    "    protoken_index_multimer.append(chain_tokens)\n",
    "    protoken_records[chain_id] = {**chain_info, 'protoken_index': chain_tokens}\n",
    "\n",
    "    print('PDB ID: ', chain_info['pdb_name'])\n",
    "    print('Chain Length: ', chain_info['seq_len'])\n",
    "    print(f'ProToken Index: {chain_tokens.shape}\\n{chain_tokens}')\n",
    "\n",
    "# Persist the per-chain records (chain info plus its ProToken indices).\n",
    "with open(code_saving_path, 'wb') as f:\n",
    "    pkl.dump(protoken_records, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5. Prepare the decoder inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "protoken_index: (1024,)\n",
      "protoken_mask: (1024,)\n",
      "residue_index: (1024,)\n"
     ]
    }
   ],
   "source": [
    "# For a multimer, the decoder preprocessing takes a list with one ProToken index array\n",
    "# (np.ndarray) per chain.\n",
    "decoder_inputs = protoken_decoder_preprocess(protoken_index_multimer, task_mode=task_mode)\n",
    "\n",
    "# Show the shape of every decoder feature, pairing names and arrays by position.\n",
    "for feature_name, feature in zip(protoken_decoder_input_features, decoder_inputs):\n",
    "    print(f'{feature_name}: {feature.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6. Decode the ProToken Index and get the reconstructed protein structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the decoder on the preprocessed ProToken inputs, passed positionally.\n",
    "decoder_results = model.decoder(*decoder_inputs)\n",
    "# Convert to a NumPy array: the following cells index it and call `.astype` on it.\n",
    "reconstructed_atom_positions = np.asarray(decoder_results['reconstructed_atom_positions'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 7. Compare the original and reconstructed protein structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average lDDT: 0.9068252570445748\n"
     ]
    }
   ],
   "source": [
    "from data_process.preprocess import lddt\n",
    "\n",
    "# Only one atom per residue is compared.\n",
    "# NOTE(review): index 1 is presumably the C-alpha slot of the 37-atom layout -- confirm.\n",
    "CA_ATOM_INDEX = 1\n",
    "\n",
    "# Per-residue lDDT between the reconstructed and the original coordinates\n",
    "# (a leading batch axis is added for the call and removed again with [0]).\n",
    "lDDT = lddt(reconstructed_atom_positions[None, ...][:, :, CA_ATOM_INDEX, :],\n",
    "            encoder_aux['backbone_atom_positions'][None, ...][:, :, CA_ATOM_INDEX, :],\n",
    "            encoder_aux['seq_mask'][None, ..., None], per_residue=True)[0]\n",
    "\n",
    "# Average over real residues only. They are selected with the sequence mask itself, the same\n",
    "# rule used when extracting the ProToken indices, rather than assuming that the valid\n",
    "# residues occupy the first sum(mask) positions.\n",
    "valid_residue_mask = np.asarray(encoder_aux['seq_mask']).astype(bool)\n",
    "average_lddt = np.mean(np.asarray(lDDT)[valid_residue_mask])\n",
    "print(f'Average lDDT: {average_lddt}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 8. Save the reconstructed protein structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect everything needed to write the reconstruction as a PDB file.\n",
    "# `decoder_inputs[-1]` is the last decoder feature (residue_index in the shape printout above);\n",
    "# it is shifted by +1 before saving.\n",
    "# NOTE(review): `plddt` is filled with the per-residue lDDT computed above, while the\n",
    "# alternative below uses 99.99 -- confirm which scale `save_pdb_from_aux` expects.\n",
    "partial_aux = dict(\n",
    "    aatype=encoder_aux['aatype'].astype(np.int32),\n",
    "    residue_index=decoder_inputs[-1].astype(np.int32) + 1,\n",
    "    atom_positions=reconstructed_atom_positions.astype(np.float32),\n",
    "    atom_mask=encoder_aux['backbone_atom_masks'].astype(np.float32),\n",
    "    plddt=lDDT.astype(np.float32),\n",
    ")\n",
    "save_pdb_from_aux(partial_aux, pdb_saving_path)\n",
    "\n",
    "# if you want to save the protein without encoder_aux, \n",
    "# use the following code to save the protein\n",
    "# aatype_all_gly = np.asarray(decoder_inputs[1]).astype(np.int32)*7\n",
    "# backbone_atom_mask = np.repeat(np.asarray([1,1,1,0,1]+[0]*32)[None,...], aatype_all_gly.shape[0], axis=0).astype(np.float32)*decoder_inputs[1][..., None]\n",
    "# plddt = np.ones_like(aatype_all_gly).astype(np.float32)*99.99\n",
    "# partial_aux = {\"aatype\": aatype_all_gly,\n",
    "#                \"residue_index\": decoder_inputs[-1].astype(np.int32)+1,\n",
    "#                \"atom_positions\": reconstructed_atom_positions.astype(np.float32),\n",
    "#                \"atom_mask\": backbone_atom_mask,\n",
    "#                \"plddt\": plddt}\n",
    "# save_pdb_from_aux(partial_aux, pdb_saving_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PDB saved at: ./results/multimer_example/reconstructed_protein.pdb\n",
      "ProTokens saved at: ./results/multimer_example/protoken_index.pkl\n",
      "Average lDDT: 0.907 Seq_Len: 515\n",
      "Job finished!\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Average lDDT over real residues only, selected with the sequence mask itself rather than\n",
    "# by assuming that the valid residues occupy the first sum(mask) positions.\n",
    "valid_residue_mask = np.asarray(encoder_aux['seq_mask']).astype(bool)\n",
    "average_lddt = np.mean(np.asarray(lDDT)[valid_residue_mask])\n",
    "\n",
    "# Final summary of where the results were written and how good the reconstruction is.\n",
    "print(f'PDB saved at: {pdb_saving_path}')\n",
    "print(f'ProTokens saved at: {code_saving_path}')\n",
    "print('Average lDDT:', round(average_lddt, 3), 'Seq_Len:', seq_len)\n",
    "print('Job finished!\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow_2.16",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
