{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ProToken 1.0 Single Chain Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Copyright 2024 Changping Laboratory & Peking University. All Rights Reserved.\n",
    "# Licensed under the Apache License, Version 2.0 (the “License”);\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an “AS IS” BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load basic libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, jax\n",
    "import pickle as pkl\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from data_process.preprocess import save_pdb_from_aux, protoken_encoder_preprocess, protoken_decoder_preprocess, init_protoken_model\n",
    "from data_process.preprocess import protoken_encoder_input_features, protoken_decoder_input_features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A. Single Chain Protein Structure Encoding and Decoding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. Prepare the task information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# single chain example\n",
    "\n",
    "task_mode = 'single' # 'single' or 'multi'\n",
    "task_name = 'single_example'\n",
    "pdb_input_dir = './examples/single'\n",
    "\n",
    "saving_dir = f'./results/{task_name}'\n",
    "os.makedirs(saving_dir, exist_ok=True)\n",
    "\n",
    "pdb_saving_path = os.path.join(saving_dir, 'reconstructed_protein.pdb')\n",
    "code_saving_path = os.path.join(saving_dir, 'protoken_index.pkl')\n",
    "\n",
    "# Notes:\n",
    "# We have 3 models for different sequence lengths range from 0-512, 512-1024, 1024-2048\n",
    "# You can choose the model based on the sequence length of your protein,\n",
    "# Once the sequence length is beyond the current model's range, you need to reinitialize the model.\n",
    "# Have fun!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Prepare the encoder inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1 pdb files in the input directory.\n",
      "WARNING: The input pdb files should be single-chain pdb files.\n",
      "seq_mask (512,)\n",
      "residue_index (512,)\n",
      "backbone_atom_masks (512, 37)\n",
      "backbone_atom_positions (512, 37, 3)\n",
      "ca_pos (512, 3)\n",
      "backbone_affine_tensor (512, 7)\n",
      "torsion_angles_sin_cos (512, 6)\n",
      "torsion_angles_mask (512, 3)\n"
     ]
    }
   ],
   "source": [
    "encoder_inputs, encoder_aux, seq_len = protoken_encoder_preprocess(pdb_input_dir, task_mode=task_mode)\n",
    "for k, v in zip(protoken_encoder_input_features, encoder_inputs):\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. Warmup the encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Did not find GPU, will use CPU for prediction\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-04-23 20:48:00.354811: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n",
      "WARNING:absl:Importing a function (__inference_internal_grad_fn_113922) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n",
      "WARNING:absl:Importing a function (__inference_internal_grad_fn_195207) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.\n"
     ]
    }
   ],
   "source": [
    "model = init_protoken_model(seq_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4. Encode the protein structure and get the ProToken Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ProToken Index: (193,)\n",
      "[258 384 416 294 324 454 324 227 127 104 342 100 373 381  92 215 487 403\n",
      "  92 250 509 324 240 177 256 472 384  74 228 471  24 241 329 202 369 132\n",
      " 458 487  47 333 151 267 231 483 133  51  28 132  32   0 362  78 493 220\n",
      "  24  12 196 364 337 210 358 439 367 161 293 216 450 110 106 266 257 473\n",
      " 495 291  46 503  92 328 214  48 384 360 146 266 476  50 297 185 241  50\n",
      "  34 362 241 485 163 237 304  27 419 299  72  42 293 329 430  76 315 152\n",
      " 481 268 315 123 361  59 194 262 372 248 130 268 425 109 256 118 386 264\n",
      " 393 305 347 190 411 403 106 407 446  14  38 487 161 342 190 254  42 334\n",
      "  49 125 187 466 143 457 324 439 109 161 456 163  30 161 415 440 151 170\n",
      " 291 395 274  42 457 246  25  42 224 315 442 471 349 303 442 202 451 261\n",
      "  38 272 165 230 466 168 434 247 450 411  52  95 264]\n"
     ]
    }
   ],
   "source": [
    "encoder_results = model.encoder(*encoder_inputs)\n",
    "protoken_index = np.asarray([encoder_results[\"protoken_index\"][p] for p in range(encoder_aux['seq_mask'].shape[0]) \\\n",
    "                                if encoder_aux['seq_mask'][p]])\n",
    "print(f'ProToken Index: {protoken_index.shape}\\n{protoken_index}')\n",
    "with open(code_saving_path, 'wb') as f:\n",
    "    pkl.dump(protoken_index, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5. Prepare the decoder inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "protoken_index: (512,)\n",
      "protoken_mask: (512,)\n",
      "residue_index: (512,)\n"
     ]
    }
   ],
   "source": [
    "# Single chain ProToken decoder's inputs should be a array of ProToken indexes in the np.ndarray format.\n",
    "decoder_inputs = protoken_decoder_preprocess(protoken_index, task_mode=task_mode)\n",
    "for k, v in zip(protoken_decoder_input_features, decoder_inputs):\n",
    "    print(f'{k}: {v.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6. Decode the ProToken Index and get the reconstructed protein structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder_results = model.decoder(*decoder_inputs)\n",
    "reconstructed_atom_positions = np.asarray(decoder_results['reconstructed_atom_positions'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 7. Compare the original and reconstructed protein structures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average lDDT: 0.9680577494332036\n"
     ]
    }
   ],
   "source": [
    "from data_process.preprocess import lddt\n",
    "lDDT = lddt(reconstructed_atom_positions[None, ...][:,:,1,:], \n",
    "            encoder_aux['backbone_atom_positions'][None, ...][:,:,1,:],\n",
    "            encoder_aux['seq_mask'][None,...,None], per_residue=True)[0]\n",
    "print(f\"Average lDDT: {np.mean(lDDT[:np.sum(encoder_aux['seq_mask'])])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 8. Save the reconstructed protein structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "partial_aux = {\"aatype\": encoder_aux[\"aatype\"].astype(np.int32),\n",
    "               \"residue_index\": decoder_inputs[-1].astype(np.int32)+1,\n",
    "               \"atom_positions\": reconstructed_atom_positions.astype(np.float32),\n",
    "               \"atom_mask\": encoder_aux[\"backbone_atom_masks\"].astype(np.float32),\n",
    "               \"plddt\": lDDT.astype(np.float32)}\n",
    "save_pdb_from_aux(partial_aux, pdb_saving_path)\n",
    "\n",
    "# if you want to save the protein without encoder_aux, \n",
    "# use the following code to save the protein\n",
    "# aatype_all_gly = np.asarray(decoder_inputs[1]).astype(np.int32)*7\n",
    "# backbone_atom_mask = np.repeat(np.asarray([1,1,1,0,1]+[0]*32)[None,...], aatype_all_gly.shape[0], axis=0).astype(np.float32)*decoder_inputs[1][..., None]\n",
    "# plddt = np.ones_like(aatype_all_gly).astype(np.float32)*99.99\n",
    "# partial_aux = {\"aatype\": aatype_all_gly,\n",
    "#                \"residue_index\": decoder_inputs[-1].astype(np.int32)+1,\n",
    "#                \"atom_positions\": reconstructed_atom_positions.astype(np.float32),\n",
    "#                \"atom_mask\": backbone_atom_mask,\n",
    "#                \"plddt\": plddt}\n",
    "# save_pdb_from_aux(partial_aux, pdb_saving_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PDB saved at: ./results/single_example/reconstructed_protein.pdb\n",
      "ProTokens saved at: ./results/single_example/protoken_index.pkl\n",
      "Average lDDT: 0.968 Seq_Len: 193\n",
      "Job finished!\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(f'PDB saved at: {pdb_saving_path}')\n",
    "print(f'ProTokens saved at: {code_saving_path}')\n",
    "print('Average lDDT:', round(np.mean(lDDT[:np.sum(encoder_aux['seq_mask'])]), 3), 'Seq_Len:', seq_len)\n",
    "print(f'Job finished!\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow_2.16",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
