{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4eem1ns5n7Eq"
      },
      "source": [
        "# Uni-Mol Binding Pose Prediction Colab\n",
        "\n",
        "This Colab notebook provides an online runnable version of [Uni-Mol](https://github.com/dptech-corp/Uni-Mol/) binding pose prediction (referred to as \"docking\" below) with custom settings.\n",
        "Uni-Mol docking is very fast, typically finishing within dozens of seconds.\n",
        "\n",
        "Please note that this Colab notebook is not a finished product and is provided as an early-access prototype. It is provided for theoretical modeling only and caution should be exercised in its use. \n",
        "\n",
        "**Licenses**\n",
        "\n",
        "This Colab uses the [Uni-Mol model parameters](https://github.com/dptech-corp/Uni-Mol/blob/main/LICENSE), and its outputs are under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode. The Colab is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0).\n",
        "\n",
        "\n",
        "**Citations**\n",
        "\n",
        "Please cite the following papers if you use this notebook:\n",
        " \n",
        "*   Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\" ChemRxiv (2022)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "p6uWJIpRQR6y"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "#@title Install dependencies\n",
        "\n",
        "# Abort on the first failing command so UNIMOL_READY is only created after\n",
        "# every download and install step has succeeded.\n",
        "set -e\n",
        "\n",
        "GIT_REPO='https://github.com/dptech-corp/Uni-Mol'\n",
        "# NOTE(review): the wheel name carries a cp39 tag -- confirm it matches the runtime's Python.\n",
        "UNICORE_WHEEL='unicore-0.0.1+cu116torch1.13.1-cp39-cp39-linux_x86_64.whl'\n",
        "UNICORE_URL=\"https://github.com/dptech-corp/Uni-Core/releases/download/0.0.2/${UNICORE_WHEEL}\"\n",
        "DOCKING_DATA='CASF-2016.tar.gz'\n",
        "DOCKING_DATA_URL=\"https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/${DOCKING_DATA}\"\n",
        "DOCKING_WEIGHT='binding_pose_220908.pt'\n",
        "DOCKING_WEIGHT_URL=\"https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/${DOCKING_WEIGHT}\"\n",
        "\n",
        "if [ ! -f UNIMOL_READY ]; then\n",
        "  # -O fixes the output name, so re-running after a failure overwrites a partial\n",
        "  # download instead of saving a second copy next to it.\n",
        "  wget -q -O \"${UNICORE_WHEEL}\" \"${UNICORE_URL}\"\n",
        "  pip3 -q install \"${UNICORE_WHEEL}\"\n",
        "  rm -rf ./Uni-Mol\n",
        "  git clone -b main \"${GIT_REPO}\"\n",
        "  pip3 install -q ./Uni-Mol/unimol\n",
        "  pip install -q rdkit\n",
        "  pip install -q biopandas\n",
        "  wget -q -O \"${DOCKING_DATA}\" \"${DOCKING_DATA_URL}\"\n",
        "  tar -xzf \"${DOCKING_DATA}\"\n",
        "  wget -q -O \"${DOCKING_WEIGHT}\" \"${DOCKING_WEIGHT_URL}\"\n",
        "  pip install -q py3Dmol\n",
        "  # Marker file: later runs of this cell skip the whole install block.\n",
        "  touch UNIMOL_READY\n",
        "fi\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "242RLQ0JVWns"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import biopandas\n",
        "import lmdb\n",
        "from biopandas.pdb import PandasPdb\n",
        "from rdkit import Chem\n",
        "from rdkit.Chem import AllChem\n",
        "from sklearn.cluster import KMeans\n",
        "from rdkit.Chem import rdMolTransforms\n",
        "from rdkit.Chem.rdMolAlign import AlignMolConformers\n",
        "from unimol.utils.docking_utils import docking_data_pre, ensemble_iterations\n",
        "from tqdm import tqdm\n",
        "import pickle\n",
        "import re\n",
        "import json\n",
        "import copy\n",
        "\n",
        "CASF_PATH = \"CASF-2016\"\n",
        "main_atoms = [\"N\", \"CA\", \"C\", \"O\", \"H\"]\n",
        "\n",
        "\n",
        "def load_from_CASF(pdb_id):\n",
        "    \"\"\"Load the protein structure and pocket residue list of one CASF-2016 entry.\n",
        "\n",
        "    Args:\n",
        "        pdb_id: Entry name; selects ``<pdb_id>_protein.pdb`` under\n",
        "            ``CASF_PATH/casf2016`` and the key looked up in ``casf2016.pocket.json``.\n",
        "\n",
        "    Returns:\n",
        "        ``(pmol, pocket_residues)``: the ``PandasPdb`` object read from the protein\n",
        "        file and the JSON value stored under ``pdb_id``.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If the protein file or the pocket entry cannot be loaded.\n",
        "    \"\"\"\n",
        "    pdb_path = os.path.join(CASF_PATH, \"casf2016\", pdb_id + \"_protein.pdb\")\n",
        "    pocket_path = os.path.join(CASF_PATH, \"casf2016.pocket.json\")\n",
        "    try:\n",
        "        pmol = PandasPdb().read_pdb(pdb_path)\n",
        "        # The context manager closes the JSON file once it has been parsed.\n",
        "        with open(pocket_path) as pocket_file:\n",
        "            pocket_residues = json.load(pocket_file)[pdb_id]\n",
        "    except Exception as err:\n",
        "        # Returning None here would only make the caller fail later while unpacking\n",
        "        # the result, so report the problem and raise a descriptive error instead.\n",
        "        print(\"Currently not support parsing pdb and pocket info from local files.\")\n",
        "        raise ValueError(\n",
        "            \"Failed to load CASF-2016 entry '{}' from '{}'\".format(pdb_id, CASF_PATH)\n",
        "        ) from err\n",
        "    return pmol, pocket_residues\n",
        "\n",
        "\n",
        "def normalize_atoms(atom):\n",
        "    return re.sub(\"\\d+\", \"\", atom)\n",
        "\n",
        "\n",
        "def single_conf_gen(tgt_mol, num_confs=1000, seed=42, removeHs=True):\n",
        "    \"\"\"Embed and MMFF-optimize conformers on a hydrogen-added copy of ``tgt_mol``.\n",
        "\n",
        "    Args:\n",
        "        tgt_mol: RDKit molecule; it is deep-copied, so the argument is left untouched.\n",
        "        num_confs: Number of conformers requested from ``EmbedMultipleConfs``.\n",
        "        seed: Random seed passed to the embedding.\n",
        "        removeHs: If true, hydrogens are removed from the returned molecule.\n",
        "\n",
        "    Returns:\n",
        "        The copied molecule carrying the generated conformers.\n",
        "    \"\"\"\n",
        "    mol = Chem.AddHs(copy.deepcopy(tgt_mol))\n",
        "    conf_ids = AllChem.EmbedMultipleConfs(\n",
        "        mol, numConfs=num_confs, randomSeed=seed, clearConfs=True\n",
        "    )\n",
        "    # Use the conformer IDs returned by the embedding rather than assuming 0..n-1.\n",
        "    for conf_id in conf_ids:\n",
        "        try:\n",
        "            AllChem.MMFFOptimizeMolecule(mol, confId=conf_id)\n",
        "        except Exception:\n",
        "            # Best effort: a conformer that cannot be optimized is kept as embedded.\n",
        "            # Catching Exception (not a bare except) lets Ctrl-C still interrupt.\n",
        "            continue\n",
        "    if removeHs:\n",
        "        mol = Chem.RemoveHs(mol)\n",
        "    return mol\n",
        "\n",
        "\n",
        "def clustering_coords(mol, M=1000, N=100, seed=42, removeHs=True):\n",
        "    \"\"\"Generate conformers, cluster them with k-means and return one per cluster.\n",
        "\n",
        "    Args:\n",
        "        mol: RDKit molecule handed to ``single_conf_gen``.\n",
        "        M: Number of conformers requested.\n",
        "        N: Number of k-means clusters, and length of the returned list.\n",
        "        seed: Random seed for conformer generation and for k-means.\n",
        "        removeHs: Forwarded to ``single_conf_gen``.\n",
        "\n",
        "    Returns:\n",
        "        List of ``N`` float32 coordinate arrays: for each cluster label, the first\n",
        "        conformer that was assigned to it.\n",
        "    \"\"\"\n",
        "    conf_mol = single_conf_gen(mol, num_confs=M, seed=seed, removeHs=removeHs)\n",
        "    heavy_atom_ids = [\n",
        "        atom.GetIdx() for atom in conf_mol.GetAtoms() if atom.GetAtomicNum() != 1\n",
        "    ]\n",
        "    # Hydrogens are left out of the alignment.\n",
        "    AlignMolConformers(conf_mol, atomIds=heavy_atom_ids)\n",
        "    conformer_coords = [\n",
        "        conformer.GetPositions().astype(np.float32)\n",
        "        for conformer in conf_mol.GetConformers()\n",
        "    ]\n",
        "    num_conformers = len(conformer_coords)\n",
        "\n",
        "    # Hydrogens are left out of the clustering features as well.\n",
        "    features = np.array(conformer_coords)[:, heavy_atom_ids].reshape(\n",
        "        num_conformers, -1\n",
        "    )\n",
        "    kmeans = KMeans(n_clusters=N, random_state=seed)\n",
        "    labels = kmeans.fit_predict(features).tolist()\n",
        "    return [conformer_coords[labels.index(cluster)] for cluster in range(N)]\n",
        "\n",
        "\n",
        "def parser(pdb_id, smiles, seed=42):\n",
        "    \"\"\"Build the pickled docking record for one pocket/ligand pair.\n",
        "\n",
        "    Args:\n",
        "        pdb_id: CASF-2016 entry passed to ``load_from_CASF``; also stored as ``pocket``.\n",
        "        smiles: Ligand SMILES string.\n",
        "        seed: Random seed for the RDKit embedding and the conformer clustering.\n",
        "\n",
        "    Returns:\n",
        "        ``bytes`` from ``pickle.dumps`` of a dict holding ligand atoms, clustered\n",
        "        conformer coordinates, pocket atoms/coordinates and bookkeeping fields.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If ``smiles`` cannot be parsed or no 3D conformer can be embedded.\n",
        "    \"\"\"\n",
        "    pmol, pocket_residues = load_from_CASF(pdb_id)\n",
        "    pname = pdb_id\n",
        "    pro_atom = pmol.df[\"ATOM\"]\n",
        "    pro_hetatm = pmol.df[\"HETATM\"]\n",
        "\n",
        "    # Residue key = chain ID + residue number; it is matched against pocket_residues.\n",
        "    for frame in (pro_atom, pro_hetatm):\n",
        "        frame[\"ID\"] = frame[\"chain_id\"].astype(str) + frame[\n",
        "            \"residue_number\"\n",
        "        ].astype(str)\n",
        "\n",
        "    pocket = pd.concat(\n",
        "        [\n",
        "            pro_atom[pro_atom[\"ID\"].isin(pocket_residues)],\n",
        "            pro_hetatm[pro_hetatm[\"ID\"].isin(pocket_residues)],\n",
        "        ],\n",
        "        axis=0,\n",
        "        ignore_index=True,\n",
        "    )\n",
        "\n",
        "    # Strip digits from atom names and drop atoms whose name consisted of digits only.\n",
        "    pocket[\"normalize_atom\"] = pocket[\"atom_name\"].map(normalize_atoms)\n",
        "    pocket = pocket[pocket[\"normalize_atom\"] != \"\"]\n",
        "    # Reuse the columns computed above instead of deriving them a second time.\n",
        "    patoms = pocket[\"normalize_atom\"].values.tolist()\n",
        "    pcoords = [pocket[[\"x_coord\", \"y_coord\", \"z_coord\"]].values]\n",
        "    # 0 when the normalized name is listed in main_atoms, 1 otherwise.\n",
        "    side = [0 if a in main_atoms else 1 for a in patoms]\n",
        "    residues = pocket[\"ID\"].values.tolist()\n",
        "\n",
        "    # generate ligand conformation\n",
        "    M, N = 100, 10\n",
        "    mol = Chem.MolFromSmiles(smiles)\n",
        "    if mol is None:\n",
        "        # MolFromSmiles signals a parse failure by returning None.\n",
        "        raise ValueError(\"Cannot parse SMILES: {!r}\".format(smiles))\n",
        "    mol = Chem.AddHs(mol)\n",
        "    # EmbedMolecule returns the new conformer ID, or -1 when embedding failed.\n",
        "    if AllChem.EmbedMolecule(mol, randomSeed=seed) < 0:\n",
        "        raise ValueError(\n",
        "            \"Cannot generate a 3D conformer for SMILES: {!r}\".format(smiles)\n",
        "        )\n",
        "    latoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
        "    holo_coordinates = [mol.GetConformer().GetPositions().astype(np.float32)]\n",
        "    holo_mol = mol\n",
        "    coordinate_list = clustering_coords(mol, M=M, N=N, seed=seed, removeHs=False)\n",
        "    mol_list = [mol] * N\n",
        "\n",
        "    return pickle.dumps(\n",
        "        {\n",
        "            \"atoms\": latoms,\n",
        "            \"coordinates\": coordinate_list,\n",
        "            \"mol_list\": mol_list,\n",
        "            \"pocket_atoms\": patoms,\n",
        "            \"pocket_coordinates\": pcoords,\n",
        "            \"side\": side,\n",
        "            \"residue\": residues,\n",
        "            \"holo_coordinates\": holo_coordinates,\n",
        "            \"holo_mol\": holo_mol,\n",
        "            \"holo_pocket_coordinates\": pcoords,\n",
        "            \"smi\": smiles,\n",
        "            \"pocket\": pname,\n",
        "        },\n",
        "        protocol=-1,\n",
        "    )\n",
        "\n",
        "\n",
        "def write_lmdb(pdb_id, smiles_list, seed=42, result_dir=\"./results\"):\n",
        "    \"\"\"Write one pickled docking record per SMILES to ``<result_dir>/<pdb_id>.lmdb``.\n",
        "\n",
        "    Args:\n",
        "        pdb_id: CASF-2016 entry; also the base name of the LMDB file.\n",
        "        smiles_list: Ligand SMILES strings; record ``i`` is stored under the\n",
        "            ASCII-encoded key ``str(i)``.\n",
        "        seed: Random seed forwarded to ``parser``.\n",
        "        result_dir: Output directory, created if missing.\n",
        "    \"\"\"\n",
        "    os.makedirs(result_dir, exist_ok=True)\n",
        "    outputfilename = os.path.join(result_dir, pdb_id + \".lmdb\")\n",
        "    # Start from a fresh database file; it is fine if none exists yet.\n",
        "    try:\n",
        "        os.remove(outputfilename)\n",
        "    except OSError:\n",
        "        pass\n",
        "    env_new = lmdb.open(\n",
        "        outputfilename,\n",
        "        subdir=False,\n",
        "        readonly=False,\n",
        "        lock=False,\n",
        "        readahead=False,\n",
        "        meminit=False,\n",
        "        max_readers=1,\n",
        "        map_size=int(10e9),\n",
        "    )\n",
        "    try:\n",
        "        # A single write transaction covers all records: it is committed when the\n",
        "        # with-block exits normally and aborted if an exception escapes. Opening a\n",
        "        # new transaction per SMILES and committing only the last one would leave\n",
        "        # the earlier ones uncommitted, and an empty list would have no transaction\n",
        "        # to commit at all.\n",
        "        with env_new.begin(write=True) as txn_write:\n",
        "            for i, smiles in enumerate(smiles_list):\n",
        "                inner_output = parser(pdb_id, smiles, seed=seed)\n",
        "                txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n",
        "    finally:\n",
        "        env_new.close()\n",
        "\n",
        "\n",
        "# @title Run Uni-Mol Binding Pose Prediction\n",
        "\n",
        "# @markdown Currently this script only supports the CASF-2016 dataset with the given pocket residues.\n",
        "\n",
        "# @markdown You can input multiple SMILES, separated by ','.\n",
        "\n",
        "# @markdown If no SMILES is given, the ligand of the complex is used.\n",
        "\n",
        "pdb_id = \"4ty7\"  # @param {type:\"string\"}\n",
        "pdb_id = pdb_id.strip().lower()\n",
        "# Entry names are the first four characters of the files under CASF-2016/casf2016.\n",
        "casf_collect = sorted(\n",
        "    {item[:4] for item in os.listdir(os.path.join(CASF_PATH, \"casf2016\"))}\n",
        ")\n",
        "if pdb_id not in casf_collect:\n",
        "    warning_str = \"{} is not in CASF-2016 dataset, Please select from \\n\".format(pdb_id)\n",
        "    # List every available entry, 20 per line, however many entries there are.\n",
        "    for start in range(0, len(casf_collect), 20):\n",
        "        warning_str += \"{}\\n\".format(\",\".join(casf_collect[start : start + 20]))\n",
        "    raise Exception(warning_str)\n",
        "supp = Chem.SDMolSupplier(os.path.join(CASF_PATH, \"casf2016\", pdb_id + \"_ligand.sdf\"))\n",
        "# First molecule of the ligand file that RDKit could read.\n",
        "mol = [mol for mol in supp if mol][0]\n",
        "ori_smiles = Chem.MolToSmiles(mol)\n",
        "smiles = \"\"  # @param {type:\"string\"}\n",
        "seed = 42  # @param {type:\"number\"}\n",
        "data_path = \"./CASF-2016\"\n",
        "results_path = \"./results/\"\n",
        "# NOTE(review): the prediction file read later is named \"content_<pdb_id>.out.pkl\";\n",
        "# that prefix presumably comes from this path's parent directory -- confirm before\n",
        "# moving the weights.\n",
        "weight_path = \"/content/binding_pose_220908.pt\"\n",
        "batch_size = 8\n",
        "dist_threshold = 8.0\n",
        "recycling = 3\n",
        "# Drop blank entries and surrounding whitespace so that \"\", \" \" or a trailing comma\n",
        "# never produces an empty SMILES.\n",
        "smiles_list = [item.strip() for item in smiles.split(\",\") if item.strip()]\n",
        "if not smiles_list:\n",
        "    print(\"No other smiles inputs\")\n",
        "    smiles_list = [ori_smiles]\n",
        "else:\n",
        "    print(\"Docking with smiles: {}\".format(smiles))\n",
        "\n",
        "write_lmdb(pdb_id, smiles_list, seed=seed, result_dir=data_path)\n",
        "\n",
        "# Run Uni-Mol's infer.py through the shell. IPython replaces each $name below with\n",
        "# the value of the Python variable of that name defined earlier in this cell.\n",
        "!python ./Uni-Mol/unimol/unimol/infer.py --user-dir ./Uni-Mol/unimol/unimol $data_path --valid-subset $pdb_id \\\n",
        "       --results-path $results_path \\\n",
        "       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
        "       --task docking_pose --loss docking_pose --arch docking_pose \\\n",
        "       --path $weight_path \\\n",
        "       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
        "       --dist-threshold $dist_threshold --recycling $recycling \\\n",
        "       --log-interval 50 --log-format simple\n",
        "\n",
        "def generate_docking_input(\n",
        "    predict_file, reference_file, tta_times=10, output_dir=\"./results\"\n",
        "):\n",
        "    \"\"\"Write one ``<pocket>.<i>.pkl`` docking input per ensemble entry.\n",
        "\n",
        "    Args:\n",
        "        predict_file: Prediction file passed to ``docking_data_pre``.\n",
        "        reference_file: Reference LMDB file passed to ``docking_data_pre``.\n",
        "        tta_times: Forwarded to ``ensemble_iterations``.\n",
        "        output_dir: Directory that receives the pickle files.\n",
        "    \"\"\"\n",
        "    (\n",
        "        mol_list,\n",
        "        smi_list,\n",
        "        pocket_list,\n",
        "        pocket_coords_list,\n",
        "        distance_predict_list,\n",
        "        holo_distance_predict_list,\n",
        "        holo_coords_list,\n",
        "        holo_center_coords_list,\n",
        "    ) = docking_data_pre(reference_file, predict_file)\n",
        "    # Named ``ensemble`` so the builtin ``iter`` is not shadowed.\n",
        "    ensemble = ensemble_iterations(\n",
        "        mol_list,\n",
        "        smi_list,\n",
        "        pocket_list,\n",
        "        pocket_coords_list,\n",
        "        distance_predict_list,\n",
        "        holo_distance_predict_list,\n",
        "        holo_coords_list,\n",
        "        holo_center_coords_list,\n",
        "        tta_times=tta_times,\n",
        "    )\n",
        "    for i, content in enumerate(ensemble):\n",
        "        # The fourth field of each entry is used as the pocket name in the file name.\n",
        "        pocket = content[3]\n",
        "        output_name = os.path.join(output_dir, \"{}.{}.pkl\".format(pocket, i))\n",
        "        # Remove a stale file from an earlier run; a missing file is not an error.\n",
        "        try:\n",
        "            os.remove(output_name)\n",
        "        except OSError:\n",
        "            pass\n",
        "        pd.to_pickle(content, output_name)\n",
        "\n",
        "\n",
        "# Prediction file written by the inference command above.\n",
        "# NOTE(review): the \"content_\" prefix presumably derives from the parent directory\n",
        "# of weight_path -- confirm before changing either.\n",
        "predict_file = os.path.join(results_path, \"content_\" + pdb_id + \".out.pkl\")\n",
        "reference_file = os.path.join(data_path, pdb_id + \".lmdb\")\n",
        "generate_docking_input(\n",
        "    predict_file, reference_file, tta_times=10, output_dir=results_path\n",
        ")\n",
        "failed_smiles = []\n",
        "for i, smiles in enumerate(smiles_list):\n",
        "    print(\"Docking {}\".format(smiles))\n",
        "    input_path = os.path.join(results_path, \"{}.{}.pkl\".format(pdb_id, i))\n",
        "    ligand_path = os.path.join(results_path, \"docking.{}.{}.sdf\".format(pdb_id, i))\n",
        "    cmd = \"python ./Uni-Mol/unimol/unimol/utils/coordinate_model.py --input {} --output-ligand {}\".format(\n",
        "        input_path, ligand_path\n",
        "    )\n",
        "    # os.system returns the command's exit status; remember failures instead of\n",
        "    # ignoring them, but still try the remaining ligands.\n",
        "    if os.system(cmd) != 0:\n",
        "        failed_smiles.append(smiles)\n",
        "if failed_smiles:\n",
        "    # Fail here with a clear message rather than later on a missing SDF file.\n",
        "    raise RuntimeError(\n",
        "        \"coordinate_model.py failed for: {}\".format(\", \".join(failed_smiles))\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ilKA-z6h_2oO"
      },
      "outputs": [],
      "source": [
        "#@title Visualization\n",
        "\n",
        "#@markdown Note: The first figure shows the result of the Uni-Mol prediction, and the second one shows the difference between the Uni-Mol prediction and the ground-truth ligand in the complex.\n",
        "\n",
        "#@markdown Note: We only visualize the first ligand when multiple SMILES are provided.\n",
        "\n",
        "import py3Dmol\n",
        "import matplotlib.pyplot as plt\n",
        "pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')\n",
        "ligand_path = os.path.join(results_path, \"docking.{}.{}.sdf\".format(pdb_id,0))\n",
        "gt_ligand_path = os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf')\n",
        "\n",
        "\n",
        "def _read_text(path):\n",
        "    \"\"\"Return the text content of ``path``; the file is closed before returning.\"\"\"\n",
        "    with open(path, 'r') as handle:\n",
        "        return handle.read()\n",
        "\n",
        "\n",
        "# The predicted ligand appears in both figures, so read it once.\n",
        "predicted_sdf = _read_text(ligand_path)\n",
        "\n",
        "# Figure 1: protein (cartoon + surface) with the predicted ligand.\n",
        "view = py3Dmol.view()\n",
        "view.removeAllModels()\n",
        "view.addModel(_read_text(pdb_path),format='pdb')\n",
        "view.setStyle({'cartoon': {'arrows':True, 'tubes':False, 'style':'oval', 'color':'white'}})\n",
        "view.addSurface(py3Dmol.VDW,{'opacity':0.5,'color':'white'})\n",
        "\n",
        "view.addModel(predicted_sdf,format='sdf')\n",
        "ref_m = view.getModel()\n",
        "ref_m.setStyle({},{'stick':{'colorscheme':'greenCarbon','radius':0.2}})\n",
        "\n",
        "view.zoomTo(viewer=(100,0))\n",
        "view.show()\n",
        "\n",
        "# Figure 2: predicted ligand (green) together with the ligand from the complex (red).\n",
        "view.removeAllModels()\n",
        "view.addModel(predicted_sdf,format='sdf')\n",
        "ref_m = view.getModel()\n",
        "ref_m.setStyle({},{'stick':{'colorscheme':'greenCarbon','radius':0.2}})\n",
        "\n",
        "view.addModel(_read_text(gt_ligand_path),format='sdf')\n",
        "ref_m = view.getModel()\n",
        "ref_m.setStyle({},{'stick':{'colorscheme':'redCarbon','radius':0.2}})\n",
        "\n",
        "view.zoomTo(viewer=(100,0))\n",
        "view.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "BetxYhrqB1SD"
      },
      "outputs": [],
      "source": [
        "#@title Download the prediction\n",
        "#@markdown **Contents of the zip file**:\n",
        "#@markdown 1. PDB formatted structures\n",
        "#@markdown 2. Docking ligand SDF files\n",
        "#@markdown 3. Target ligand SDF files.\n",
        "\n",
        "from google.colab import files\n",
        "# Collect the protein structure, every docked ligand and the reference ligand.\n",
        "pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')\n",
        "gt_ligand_path = os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf')\n",
        "docked_paths = [\n",
        "  os.path.join(results_path, \"docking.{}.{}.sdf\".format(pdb_id,idx))\n",
        "  for idx in range(len(smiles_list))\n",
        "]\n",
        "file_lists = [pdb_path, *docked_paths, gt_ligand_path]\n",
        "\n",
        "# -j stores the files without their directory components.\n",
        "zip_name = \"unimol.docking.\"+pdb_id+\".zip\"\n",
        "!zip -j {zip_name} {\" \".join(file_lists)}\n",
        "files.download(zip_name)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3.8.10 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    },
    "vscode": {
      "interpreter": {
        "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
