{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Please note before running:\n",
    "Commit: b962451 (b962451a019e15363bd34b3af9d3a3cd02330947)\n",
    "\n",
    "Workspace path: Uni-Mol\n",
    "\n",
    "Notebook path: Uni-Mol/unimol_posebuster_demo.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from rdkit import Chem, RDLogger\n",
    "from rdkit.Chem import AllChem\n",
    "from tqdm import tqdm\n",
    "RDLogger.DisableLog('rdApp.*')  \n",
    "import warnings\n",
    "warnings.filterwarnings(action='ignore')\n",
    "from multiprocessing import Pool\n",
    "import copy\n",
    "import lmdb\n",
    "from biopandas.pdb import PandasPdb\n",
    "from sklearn.cluster import KMeans\n",
    "from rdkit.Chem.rdMolAlign  import AlignMolConformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing functions for generating the LMDB file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Atom names that get side flag 0 in extract_pose_posebuster; every other name gets 1\n",
    "# NOTE(review): presumably the protein backbone atom names -- confirm\n",
    "main_atoms = ['N', 'CA', 'C', 'O', 'H']\n",
    "# Leading element letters regarded as allowed for pocket atom names\n",
    "allow_pocket_atoms = ['C', 'H', 'N', 'O', 'S']\n",
    "\n",
    "def cal_configs(coords):\n",
    "    \"\"\"Compute the axis-aligned bounding box of a set of pocket coordinates.\n",
    "\n",
    "    Args:\n",
    "        coords: array-like of shape (n_atoms, 3) with x/y/z coordinates.\n",
    "\n",
    "    Returns:\n",
    "        A 7-tuple ``(config, cx, cy, cz, sx, sy, sz)``. ``cx, cy, cz`` is the\n",
    "        box center (midpoint of the per-axis min and max), ``sx, sy, sz`` is\n",
    "        the box extent (per-axis max - min) and ``config`` is a dict holding\n",
    "        the same six values under the keys 'cx', 'cy', 'cz', 'sx', 'sy', 'sz'.\n",
    "    \"\"\"\n",
    "    upper = np.max(coords, axis=0)\n",
    "    lower = np.min(coords, axis=0)\n",
    "\n",
    "    centerx, centery, centerz = list((upper + lower) / 2)\n",
    "    # The size is the full extent of the box (max - min). Using max - mean\n",
    "    # under-estimates it and is inconsistent with the min/max-based center.\n",
    "    sizex, sizey, sizez = list(upper - lower)\n",
    "    config = {'cx': centerx, 'cy': centery, 'cz': centerz,\n",
    "              'sx': sizex, 'sy': sizey, 'sz': sizez}\n",
    "\n",
    "    return config, centerx, centery, centerz, sizex, sizey, sizez\n",
    "\n",
    "\n",
    "def filter_pocketatoms(atom):\n",
    "    \"\"\"Normalize a pocket atom name, or reject it.\n",
    "\n",
    "    Leading ASCII digits are stripped first (e.g. '1HB' -> 'HB').\n",
    "\n",
    "    Args:\n",
    "        atom: atom name string.\n",
    "\n",
    "    Returns:\n",
    "        None when the stripped name is empty, starts with one of the listed\n",
    "        two-letter element symbols, or starts with one of the rejected single\n",
    "        letters; otherwise the name without its leading digits.\n",
    "    \"\"\"\n",
    "    rejected_prefixes = (\n",
    "        'Cd', 'Cs', 'Cn', 'Ce', 'Cm', 'Cf', 'Cl', 'Ca',\n",
    "        'Cr', 'Co', 'Cu', 'Nh', 'Nd', 'Np', 'No', 'Ne', 'Na',\n",
    "        'Ni', 'Nb', 'Os', 'Og', 'Hf', 'Hg', 'Hs', 'Ho', 'He',\n",
    "        'Sr', 'Sn', 'Sb', 'Sg', 'Sm', 'Si', 'Sc', 'Se',\n",
    "    )\n",
    "    rejected_initials = ('Z', 'M', 'P', 'D', 'F', 'K', 'I', 'B')\n",
    "\n",
    "    name = atom\n",
    "    while name and '0' <= name[0] <= '9':\n",
    "        name = name[1:]\n",
    "    # An empty or digit-only name carries no element information; reject it\n",
    "    # instead of failing with IndexError on name[0].\n",
    "    if not name:\n",
    "        return None\n",
    "    if name[:2] in rejected_prefixes or name[0] in rejected_initials:\n",
    "        return None\n",
    "    # Everything else is kept, whether or not its first letter is one of\n",
    "    # C/H/N/O/S (both former branches returned the name).\n",
    "    return name\n",
    "\n",
    "\n",
    "def single_conf_gen(tgt_mol, num_confs=1000, seed=42, removeHs=True):\n",
    "    \"\"\"Embed conformers for a molecule with RDKit and relax them with MMFF.\n",
    "\n",
    "    Args:\n",
    "        tgt_mol: RDKit molecule; it is deep-copied, so the input is not modified.\n",
    "        num_confs: number of conformers requested from EmbedMultipleConfs.\n",
    "        seed: random seed passed to the embedding.\n",
    "        removeHs: if True, hydrogens are removed from the returned molecule.\n",
    "\n",
    "    Returns:\n",
    "        A copy of ``tgt_mol`` carrying the embedded conformers.\n",
    "    \"\"\"\n",
    "    mol = copy.deepcopy(tgt_mol)\n",
    "    mol = Chem.AddHs(mol)\n",
    "    conf_ids = AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, randomSeed=seed, clearConfs=True)\n",
    "    # Optimize exactly the conformers that were embedded, addressed by their ids.\n",
    "    for conf_id in conf_ids:\n",
    "        try:\n",
    "            AllChem.MMFFOptimizeMolecule(mol, confId=conf_id)\n",
    "        except Exception:\n",
    "            # Best effort: keep the un-optimized embedding for this conformer.\n",
    "            # Catching Exception (not a bare except) lets KeyboardInterrupt\n",
    "            # and SystemExit propagate.\n",
    "            continue\n",
    "    if removeHs:\n",
    "        mol = Chem.RemoveHs(mol)\n",
    "    return mol\n",
    "\n",
    "\n",
    "def clustering_coords(mol, M=1000, N=100, seed=42, removeHs=True, method='bonds'):\n",
    "    \"\"\"Generate M conformers and pick N representatives by k-means clustering.\n",
    "\n",
    "    Conformers are aligned and clustered on their non-hydrogen atoms only; the\n",
    "    first conformer assigned to each cluster is returned as its representative.\n",
    "\n",
    "    Args:\n",
    "        mol: RDKit molecule to generate conformers for.\n",
    "        M: number of conformers to generate.\n",
    "        N: number of clusters, i.e. number of coordinate arrays returned.\n",
    "        seed: random seed for conformer generation and KMeans.\n",
    "        removeHs: forwarded to single_conf_gen.\n",
    "        method: conformer generation method; only 'rdkit_MMFF' is implemented.\n",
    "\n",
    "    Returns:\n",
    "        List of N float32 coordinate arrays, one per cluster.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``method`` is not 'rdkit_MMFF'.\n",
    "    \"\"\"\n",
    "    if method != 'rdkit_MMFF':\n",
    "        # NotImplementedError is a RuntimeError subclass, so code that caught the\n",
    "        # RuntimeError produced by the former bare `raise` keeps working, but the\n",
    "        # error now says what went wrong.\n",
    "        raise NotImplementedError('no conformer generation methods:{}'.format(method))\n",
    "    rdkit_mol = single_conf_gen(mol, num_confs=M, seed=seed, removeHs=removeHs)\n",
    "\n",
    "    ### exclude hydrogens for aligning\n",
    "    noHsIds = [atom.GetIdx() for atom in rdkit_mol.GetAtoms() if atom.GetAtomicNum() != 1]\n",
    "    AlignMolConformers(rdkit_mol, atomIds=noHsIds)\n",
    "    rdkit_coords_list = [\n",
    "        conf.GetPositions().astype(np.float32) for conf in rdkit_mol.GetConformers()\n",
    "    ]\n",
    "    sz = len(rdkit_coords_list)\n",
    "\n",
    "    ### exclude hydrogens for clustering\n",
    "    rdkit_coords_flatten = np.array(rdkit_coords_list)[:, noHsIds].reshape(sz, -1)\n",
    "    ids = KMeans(n_clusters=N, random_state=seed).fit_predict(rdkit_coords_flatten).tolist()\n",
    "    # ids.index(i) is the first conformer that was assigned to cluster i\n",
    "    coords_list = [rdkit_coords_list[ids.index(i)] for i in range(N)]\n",
    "    return coords_list\n",
    "\n",
    "\n",
    "def extract_pose_posebuster(content):\n",
    "    \"\"\"Build the pocket and ligand features for one PoseBusters system.\n",
    "\n",
    "    Args:\n",
    "        content: tuple ``(pdbid, ligid, protein_path, ligand_path, index)``.\n",
    "            The protein is read from ``<protein_path>/<pdbid>.pdb`` and the\n",
    "            ligand from ``<ligand_path>/<pdbid>_<ligid>.sdf``; ``index`` is\n",
    "            not used.\n",
    "\n",
    "    Returns:\n",
    "        None when ``pdbid`` is 'index' or 'readme', when the SDF holds no\n",
    "        readable molecule, or when no protein residue lies within the distance\n",
    "        threshold of the ligand. Otherwise the tuple\n",
    "        ``(pname, patoms, pcoords, side, residues, config, ligand)`` where\n",
    "        ``ligand`` is the list\n",
    "        ``[latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol]``.\n",
    "    \"\"\"\n",
    "\n",
    "    pdbid, ligid, protein_path, ligand_path, _ = content\n",
    "\n",
    "    def read_pdb(path, pdbid):\n",
    "        #### protein preparation\n",
    "        pfile = os.path.join(path, pdbid+'.pdb')\n",
    "        return PandasPdb().read_pdb(pfile)\n",
    "\n",
    "    ### totally posebuster data\n",
    "    def read_mol(path, pdbid, ligid):\n",
    "        lsdf = os.path.join(path, f'{pdbid}_{ligid}.sdf')\n",
    "        supp = Chem.SDMolSupplier(lsdf)\n",
    "        mols = [mol for mol in supp if mol]\n",
    "        if len(mols) == 0:\n",
    "            # Report the unreadable file and let the caller skip this system\n",
    "            # instead of crashing with IndexError on mols[0].\n",
    "            print(lsdf)\n",
    "            return None\n",
    "        return mols[0]\n",
    "\n",
    "    # influence pocket size\n",
    "    dist_thres=6\n",
    "    if pdbid == 'index' or pdbid == 'readme':\n",
    "        return None\n",
    "\n",
    "    pmol = read_pdb(protein_path, pdbid)\n",
    "    pname = pdbid\n",
    "    mol = read_mol(ligand_path, pdbid, ligid)\n",
    "    if mol is None:\n",
    "        return None\n",
    "    mol = Chem.RemoveHs(mol)\n",
    "    lcoords = mol.GetConformer().GetPositions().astype(np.float32)\n",
    "\n",
    "    # Collect every (chain, residue number) that has a non-hydrogen atom within\n",
    "    # dist_thres of any ligand heavy atom.\n",
    "    pdf = pmol.df['ATOM']\n",
    "    near_residues = set()\n",
    "    for lcoord in lcoords:\n",
    "        # records must be a tuple; ('ATOM') without the comma is a plain string\n",
    "        pdf['dist'] = pmol.distance(xyz=list(lcoord), records=('ATOM',))\n",
    "        df = pdf[(pdf.dist <= dist_thres) & (pdf.element_symbol != 'H')][['chain_id', 'residue_number']]\n",
    "        near_residues.update(zip(df.chain_id.tolist(), df.residue_number.tolist()))\n",
    "\n",
    "    # Sort the residues so the pocket atom order is reproducible; iterating the\n",
    "    # set directly depends on string hashing and can differ between runs.\n",
    "    patoms, coord_chunks, residues = [], [np.empty((0,3))], []\n",
    "    for chain_id, res_num in sorted(near_residues):\n",
    "        df = pdf[(pdf.chain_id == chain_id) & (pdf.residue_number == res_num)]\n",
    "        patoms += df['atom_name'].tolist()\n",
    "        coord_chunks.append(df[['x_coord','y_coord','z_coord']].to_numpy())\n",
    "        residues += [str(chain_id)+str(res_num)]*len(df)\n",
    "    pcoords = np.concatenate(coord_chunks, axis=0)\n",
    "\n",
    "    if len(pcoords)==0:\n",
    "        print('empty pocket:', pdbid)\n",
    "        return None\n",
    "    # The box is computed on all atoms of the selected residues, before filtering\n",
    "    config,centerx,centery,centerz,sizex,sizey,sizez = cal_configs(pcoords)\n",
    "\n",
    "    # filter out unusual atoms, including metals\n",
    "    atoms, keep_mask, kept_residues = [], [], []\n",
    "    for atom_name, residue in zip(patoms, residues):\n",
    "        output = filter_pocketatoms(atom_name)\n",
    "        keep_mask.append(output is not None)\n",
    "        if output is not None:\n",
    "            atoms.append(output)\n",
    "            kept_residues.append(residue)\n",
    "    coordinates = pcoords[np.array(keep_mask, dtype=bool)].astype(np.float32)\n",
    "    residues = kept_residues\n",
    "\n",
    "    assert len(atoms) == len(residues)\n",
    "    assert len(atoms) == coordinates.shape[0]\n",
    "\n",
    "    patoms = atoms\n",
    "    pcoords = [coordinates]\n",
    "    side = [0 if a in main_atoms else 1 for a in patoms]\n",
    "\n",
    "    smiles = Chem.MolToSmiles(mol)\n",
    "    mol = AllChem.AddHs(mol, addCoords=True)\n",
    "    latoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "    holo_coordinates = [mol.GetConformer().GetPositions().astype(np.float32)]\n",
    "    holo_mol = mol\n",
    "\n",
    "    # Generate M conformers and keep N cluster representatives as input poses\n",
    "    M, N = 100, 10\n",
    "    coordinate_list = clustering_coords(mol, M=M, N=N, seed=42, removeHs=False, method='rdkit_MMFF')\n",
    "    mol_list = [mol]*N\n",
    "    ligand = [latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol]\n",
    "\n",
    "    return pname, patoms, pcoords, side, residues, config, ligand\n",
    "\n",
    "\n",
    "def parser(content):\n",
    "    \"\"\"Serialize one system into the pickled record stored in the LMDB.\n",
    "\n",
    "    Args:\n",
    "        content: tuple forwarded unchanged to extract_pose_posebuster.\n",
    "\n",
    "    Returns:\n",
    "        The pickled feature dict as bytes, or None when\n",
    "        extract_pose_posebuster returns None for this system.\n",
    "    \"\"\"\n",
    "    extracted = extract_pose_posebuster(content)\n",
    "    # Skipped systems come back as None; pass that on so write_lmdb can count\n",
    "    # them as failed instead of this function dying on the tuple unpacking.\n",
    "    if extracted is None:\n",
    "        return None\n",
    "    pname, patoms, pcoords, side, residues, config, ligand = extracted\n",
    "    latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol = ligand\n",
    "    record = {\n",
    "        \"atoms\": latoms,\n",
    "        \"coordinates\": coordinate_list,\n",
    "        \"mol_list\": mol_list,\n",
    "        \"pocket_atoms\": patoms,\n",
    "        \"pocket_coordinates\": pcoords,\n",
    "        \"side\": side,\n",
    "        \"residue\": residues,\n",
    "        \"config\": config,\n",
    "        \"holo_coordinates\": holo_coordinates,\n",
    "        \"holo_mol\": holo_mol,\n",
    "        # the same pocket coordinates are stored under both pocket keys\n",
    "        \"holo_pocket_coordinates\": pcoords,\n",
    "        \"smi\": smiles,\n",
    "        'pocket': pname,\n",
    "        'scaffold': pname,\n",
    "    }\n",
    "    return pickle.dumps(record, protocol=-1)\n",
    "\n",
    "\n",
    "def write_lmdb(protein_path, ligand_path, outpath, meta_info_file, lmdb_name, nthreads=8):\n",
    "    \"\"\"Preprocess every system listed in the meta file and write them to an LMDB.\n",
    "\n",
    "    Args:\n",
    "        protein_path: directory holding the ``<pdb_code>.pdb`` files.\n",
    "        ligand_path: directory holding the ``<pdb_code>_<lig_code>.sdf`` files.\n",
    "        outpath: output directory; created if missing.\n",
    "        meta_info_file: CSV file with the columns 'pdb_code' and 'lig_code'.\n",
    "        lmdb_name: the LMDB is written to ``<outpath>/<lmdb_name>.lmdb``; an\n",
    "            existing file of that name is replaced.\n",
    "        nthreads: number of worker processes in the multiprocessing pool.\n",
    "\n",
    "    Successful records are stored under consecutive ASCII integer keys\n",
    "    (\"0\", \"1\", ...); systems for which ``parser`` returns None are only counted.\n",
    "    \"\"\"\n",
    "    os.makedirs(outpath, exist_ok=True)\n",
    "    df = pd.read_csv(meta_info_file)\n",
    "    pdb_ids = list(df['pdb_code'].values)\n",
    "    lig_ids = list(df['lig_code'].values)\n",
    "    content_list = list(zip(pdb_ids, lig_ids, [protein_path]*len(pdb_ids), [ligand_path]*len(pdb_ids), range(len(pdb_ids))))\n",
    "    outputfilename = os.path.join(outpath, lmdb_name +'.lmdb')\n",
    "    # Start from a clean file. Only a missing file is expected here; other\n",
    "    # errors (e.g. permissions) should surface instead of being swallowed.\n",
    "    try:\n",
    "        os.remove(outputfilename)\n",
    "    except FileNotFoundError:\n",
    "        pass\n",
    "    env_new = lmdb.open(\n",
    "        outputfilename,\n",
    "        subdir=False,\n",
    "        readonly=False,\n",
    "        lock=False,\n",
    "        readahead=False,\n",
    "        meminit=False,\n",
    "        max_readers=1,\n",
    "        map_size=int(100e9),\n",
    "    )\n",
    "    print(\"Start preprocessing data...\")\n",
    "    print(f'Number of systems: {len(pdb_ids)}')\n",
    "    i = 0\n",
    "    failed_num = 0\n",
    "    try:\n",
    "        # The transaction context manager commits on a clean exit and aborts\n",
    "        # if an exception escapes, so a failed run never leaves it open.\n",
    "        with env_new.begin(write=True) as txn_write, Pool(nthreads) as pool:\n",
    "            for inner_output in tqdm(pool.imap(parser, content_list), total=len(content_list)):\n",
    "                if inner_output is None:\n",
    "                    failed_num += 1\n",
    "                    continue\n",
    "                txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n",
    "                i += 1\n",
    "    finally:\n",
    "        env_new.close()\n",
    "    print(f'Total num: {len(pdb_ids)}, Success: {i}, Failed: {failed_num}')\n",
    "    print(\"Done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate `lmdb` from `pdb` and `sdf`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of worker processes used for preprocessing\n",
    "nthreads = 8\n",
    "\n",
    "# Output: the LMDB is written to <outpath>/<lmdb_name>.lmdb\n",
    "outpath = 'posebuster_test'\n",
    "lmdb_name = 'posebuster_428'\n",
    "\n",
    "# Inputs: meta CSV plus the protein (.pdb) and ligand (.sdf) directories\n",
    "meta_info_file = 'eval_sets/posebusters/posebuster_set_meta.csv'\n",
    "protein_path = 'eval_sets/posebusters/proteins'\n",
    "ligand_path = 'eval_sets/posebusters/ligands'\n",
    "\n",
    "write_lmdb(\n",
    "    protein_path=protein_path,\n",
    "    ligand_path=ligand_path,\n",
    "    outpath=outpath,\n",
    "    meta_info_file=meta_info_file,\n",
    "    lmdb_name=lmdb_name,\n",
    "    nthreads=nthreads,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run inference with the public checkpoint\n",
    "The command below is the same as the one in the [README](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#protein-ligand-binding-pose-prediction)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the LMDB and, after the copies below, the two dictionary files\n",
    "data_path=outpath\n",
    "results_path=\"./infer_pose\"  # replace with your results path\n",
    "weight_path=\"./ckp/binding_pose_220908.pt\"  # checkpoint passed to infer.py via --path\n",
    "batch_size=8\n",
    "dist_threshold=8.0\n",
    "recycling=3\n",
    "# LMDB name without the .lmdb suffix, passed to infer.py as --valid-subset\n",
    "valid_subset=lmdb_name\n",
    "# Destination file names for the dictionaries copied into data_path\n",
    "# NOTE(review): presumably the names the docking_pose task looks for -- confirm\n",
    "mol_dict_name='dict_mol.txt'\n",
    "pocket_dict_name='dict_pkt.txt'\n",
    "\n",
    "# Copy the example dictionaries next to the LMDB, then run inference on it\n",
    "!cp ./example_data/molecule/dict.txt $data_path/$mol_dict_name\n",
    "!cp ./example_data/pocket/dict_coarse.txt $data_path/$pocket_dict_name\n",
    "!python ./unimol/infer.py --user-dir ./unimol $data_path --valid-subset $valid_subset \\\n",
    "       --results-path $results_path \\\n",
    "       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
    "       --task docking_pose --loss docking_pose --arch docking_pose \\\n",
    "       --path $weight_path \\\n",
    "       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
    "       --dist-threshold $dist_threshold --recycling $recycling \\\n",
    "       --log-interval 50 --log-format simple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Docking and metric calculation\n",
    "The command below is the same as the one in the [README](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#protein-ligand-binding-pose-prediction)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nthreads=8  # Number of threads passed to docking.py\n",
    "# NOTE(review): the ckp_<name>.out.pkl pattern is assumed to match what infer.py writes -- confirm\n",
    "predict_file=f\"{results_path}/ckp_{lmdb_name}.out.pkl\"  # Inference output file\n",
    "reference_file=f\"{outpath}/{lmdb_name}.lmdb\"  # Reference LMDB file written by write_lmdb\n",
    "output_path=\"./unimol_repro_posebuster428\"  # Docking results path\n",
    "\n",
    "# Run docking on the predictions against the reference LMDB\n",
    "!python ./unimol/utils/docking.py --nthreads $nthreads --predict-file $predict_file --reference-file $reference_file --output-path $output_path"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
