{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "import jsonlines\n",
    "import ray\n",
    "import mmap\n",
    "import time\n",
    "import sys\n",
    "from ray.util.queue import Queue\n",
    "sys.path.append(\".\")\n",
    "from utils.misc import execute\n",
    "from utils.ray_tools import ProgressBar\n",
    "from tqdm import tqdm\n",
    "import pathlib\n",
    "import glob\n",
    "import subprocess\n",
    "import os\n",
    "import traceback\n",
    "\n",
    "# Thresholds an MSA hit must reach (both, inclusive) to be superposed and saved.\n",
    "TMscore_threshold=0.4\n",
    "Match_rate_threshold=0.4\n",
    "\n",
    "# Input locations (placeholders - set these before running).\n",
    "PDBBind_dir='/path/to/dir'   # one sub-directory per PDBBind instance, named by PDB id\n",
    "MSA_dir=\"/path/to/dir\"       # MSA files named <pdb_id><chain_id>.fasta\n",
    "AF2DB_dir=\"/path/to/dir\"     # structures named <MSA_id>.pdb\n",
    "\n",
    "# os.path.join supplies the separator that plain concatenation dropped when PDBBind_dir\n",
    "# has no trailing '/' ('/path/to/dir' + '*/' globbed '/path/to/dir*/', i.e. the directory\n",
    "# itself and its siblings instead of its sub-directories). The trailing '/' in the pattern\n",
    "# keeps a trailing '/' on every result, which split(\"/\")[-2] below relies on.\n",
    "# sorted() makes the processing order reproducible.\n",
    "PDBBind_instance_dirs = sorted(glob.glob(os.path.join(PDBBind_dir, '*/')))\n",
    "print('Number of PDBBind instances: {}'.format(len(PDBBind_instance_dirs)))\n",
    "uncompleted_jobs=[]\n",
    "\n",
    "# Keep only instances that have a FASTA file, an MSA file and a pocket-position file\n",
    "# (the next cell opens the pocket-position file unconditionally).\n",
    "for PDBBind_instance_dir in PDBBind_instance_dirs:\n",
    "    pdb_id=PDBBind_instance_dir.split(\"/\")[-2]\n",
    "    fasta_files=glob.glob(PDBBind_instance_dir + '/*.fasta')\n",
    "    if len(fasta_files)==0:\n",
    "        continue\n",
    "    # chain id = last character of the FASTA base name (text before the first '.')\n",
    "    chain_id=fasta_files[0].split(\"/\")[-1].split(\".\")[0][-1]\n",
    "\n",
    "    MSA_file=MSA_dir+f\"/{pdb_id}{chain_id}.fasta\"\n",
    "    if not os.path.exists(MSA_file):\n",
    "        continue\n",
    "    pocket_position_file=PDBBind_instance_dir + pdb_id + chain_id + '_pocket_position.txt'\n",
    "    if not os.path.exists(pocket_position_file):\n",
    "        continue\n",
    "    uncompleted_jobs.append(PDBBind_instance_dir)\n",
    "print(\"uncompleted jobs:\",len(uncompleted_jobs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Bio.PDB import *\n",
    "import numpy as np\n",
    "from rdkit import Chem\n",
    "\n",
    "def get_MSA_ids(PDBBind_instance_dir):\n",
    "    file_list=glob.glob(PDBBind_instance_dir + '/rotation_matrix/*_TMscore.txt')\n",
    "    MSA_ids=[]\n",
    "    for file in file_list:\n",
    "        MSA_ids.append(file.split(\"/\")[-1].split(\"_\")[0])\n",
    "    return MSA_ids\n",
    "\n",
    "def calc_match_rate(pocket_position,Aligned_seq):\n",
    "    total_cnt=0\n",
    "    match_cnt=0\n",
    "    for i in range(len(Aligned_seq)):\n",
    "        if pocket_position[i]!=\"-\":\n",
    "            total_cnt+=1\n",
    "            if pocket_position[i]==Aligned_seq[i]:\n",
    "                match_cnt+=1\n",
    "    return match_cnt/total_cnt\n",
    "\n",
    "def get_rotate_matrix(rotate_matrix_file):\n",
    "    with open(rotate_matrix_file,\"r\") as f:\n",
    "        data=f.readlines()\n",
    "    u=[]\n",
    "    t=[]\n",
    "    for i in range(2,5):\n",
    "        line=data[i].split(\" \")\n",
    "        line_float=[float(x) for x in line if x!=\"\"]\n",
    "        t.append(line_float[1])\n",
    "        u.append(line_float[2:])\n",
    "    u=np.array(u)\n",
    "    t=np.array(t)\n",
    "    return u,t\n",
    "    \n",
    "\n",
    "# Superpose the top MSA hits of every instance and save each as a protein PDB and a pocket PDB.\n",
    "# Distance cutoff (Angstrom) for keeping pocket atoms around the ligand.\n",
    "# NOTE(review): the previous comment said 6 A while the code used 10; 10 is kept - confirm.\n",
    "POCKET_CUTOFF=10.0\n",
    "\n",
    "for PDBBind_instance_dir in uncompleted_jobs:\n",
    "    print(\"PDBBind_instance_dir:\",PDBBind_instance_dir)\n",
    "    pdb_id=PDBBind_instance_dir.split(\"/\")[-2]\n",
    "    fasta_dir=glob.glob(PDBBind_instance_dir + '/*.fasta')\n",
    "    chain_id=fasta_dir[0].split(\"/\")[-1].split(\".\")[0][-1]\n",
    "    pocket_position_file=PDBBind_instance_dir + pdb_id +chain_id+ '_pocket_position.txt'\n",
    "    mol2_file=PDBBind_instance_dir + pdb_id + '_ligand.mol2'\n",
    "    with open(pocket_position_file,\"r\") as f:\n",
    "        data=f.readlines()\n",
    "    pocket_position=data[0].strip()\n",
    "\n",
    "    extend_dir=PDBBind_instance_dir+\"/extend/\"\n",
    "    os.makedirs(extend_dir,exist_ok=True)\n",
    "\n",
    "    MSA_ids=get_MSA_ids(PDBBind_instance_dir)\n",
    "\n",
    "    ligand = Chem.MolFromMol2File(mol2_file)\n",
    "    if ligand is None:\n",
    "        # RDKit returns None instead of raising when the mol2 file cannot be parsed\n",
    "        print(\"skip, cannot parse ligand:\",mol2_file)\n",
    "        continue\n",
    "    ligand_coords = ligand.GetConformer().GetPositions()\n",
    "\n",
    "    for MSA_id in MSA_ids[:10]:\n",
    "        # TM-score = value after the last ':' on the first line; aligned sequence = line 5\n",
    "        TMscore_file=PDBBind_instance_dir+f\"/rotation_matrix/{MSA_id}_TMscore.txt\"\n",
    "        with open(TMscore_file,\"r\") as f:\n",
    "            data=f.readlines()\n",
    "        TMscore=float(data[0].split(\":\")[-1])\n",
    "        Aligned_seq=data[4].strip()\n",
    "        Match_rate=calc_match_rate(pocket_position,Aligned_seq)\n",
    "\n",
    "        # Only hits passing both thresholds are processed (a bare `continue` inside the\n",
    "        # accept branch previously made all of the code below unreachable).\n",
    "        if not (TMscore>=TMscore_threshold and Match_rate>=Match_rate_threshold):\n",
    "            continue\n",
    "        if TMscore<0.5 and Match_rate<0.5:\n",
    "            # log accepted hits that score below 0.5 on both metrics\n",
    "            print(MSA_id,\" TMscore:\",TMscore,\"Match_rate:\",Match_rate)\n",
    "        print(\"MSA_id:\",MSA_id)\n",
    "\n",
    "        extend_instance_dir=extend_dir+MSA_id+\"/\"\n",
    "        os.makedirs(extend_instance_dir,exist_ok=True)\n",
    "\n",
    "        # Load the hit's structure; only the first chain of model 0 is transformed and trimmed.\n",
    "        MSA_pdb_file=AF2DB_dir+f\"/{MSA_id}\"+\".pdb\"\n",
    "        structure = PDBParser().get_structure(MSA_id, MSA_pdb_file)\n",
    "        MSA_chain = next(iter(structure[0]))\n",
    "\n",
    "        # Apply x' = u @ x + t from rotation_matrix/<MSA_id>.txt\n",
    "        # (presumably superposing the hit onto this instance's chain - confirm).\n",
    "        u,t=get_rotate_matrix(PDBBind_instance_dir+f\"/rotation_matrix/{MSA_id}.txt\")\n",
    "        for atom in MSA_chain.get_atoms():\n",
    "            atom.set_coord(np.dot(u,atom.get_coord())+t)\n",
    "\n",
    "        io = PDBIO()\n",
    "        io.set_structure(structure)\n",
    "        io.save(extend_instance_dir+f\"{MSA_id}\"+\"_protein.pdb\")\n",
    "\n",
    "        # Pocket: detach every atom farther than POCKET_CUTOFF from all ligand atoms.\n",
    "        # The distance to every ligand atom is computed in one vectorised call per atom.\n",
    "        for residue in MSA_chain:\n",
    "            far_atom_ids=[\n",
    "                atom.id for atom in residue\n",
    "                if not (np.linalg.norm(ligand_coords-atom.get_coord(),axis=1)<=POCKET_CUTOFF).any()\n",
    "            ]\n",
    "            for atom_id in far_atom_ids:\n",
    "                residue.detach_child(atom_id)\n",
    "        io = PDBIO()\n",
    "        io.set_structure(structure)\n",
    "        io.save(extend_instance_dir+f\"{MSA_id}\"+\"_pocket.pdb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# count\n",
    "from Bio.PDB import *\n",
    "import numpy as np\n",
    "from rdkit import Chem\n",
    "\n",
    "def get_MSA_ids(PDBBind_instance_dir):\n",
    "    file_list=glob.glob(PDBBind_instance_dir + '/rotation_matrix/*_TMscore.txt')\n",
    "    MSA_ids=[]\n",
    "    for file in file_list:\n",
    "        MSA_ids.append(file.split(\"/\")[-1].split(\"_\")[0])\n",
    "    return MSA_ids\n",
    "\n",
    "def calc_match_rate(pocket_position,Aligned_seq):\n",
    "    total_cnt=0\n",
    "    match_cnt=0\n",
    "    for i in range(len(Aligned_seq)):\n",
    "        if pocket_position[i]!=\"-\":\n",
    "            total_cnt+=1\n",
    "            if pocket_position[i]==Aligned_seq[i]:\n",
    "                match_cnt+=1\n",
    "    return match_cnt/total_cnt\n",
    "\n",
    "def get_rotate_matrix(rotate_matrix_file):\n",
    "    with open(rotate_matrix_file,\"r\") as f:\n",
    "        data=f.readlines()\n",
    "    u=[]\n",
    "    t=[]\n",
    "    for i in range(2,5):\n",
    "        line=data[i].split(\" \")\n",
    "        line_float=[float(x) for x in line if x!=\"\"]\n",
    "        t.append(line_float[1])\n",
    "        u.append(line_float[2:])\n",
    "    u=np.array(u)\n",
    "    t=np.array(t)\n",
    "    return u,t\n",
    "    \n",
    "# Count the entries each instance's extend/ directory ended up with (the processing\n",
    "# cell creates one sub-directory there per accepted MSA hit).\n",
    "total_cnt=0\n",
    "MSA_cnt=[]\n",
    "# print only the size - the full path list floods the notebook output\n",
    "print(\"uncompleted jobs:\",len(uncompleted_jobs))\n",
    "for PDBBind_instance_dir in tqdm(uncompleted_jobs):\n",
    "    extend_dir=PDBBind_instance_dir+\"/extend/\"\n",
    "    extend_num=len(glob.glob(extend_dir+\"*\"))\n",
    "    total_cnt+=extend_num\n",
    "    MSA_cnt.append(extend_num)\n",
    "print(\"total_cnt:\",total_cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of instances with more than 0 / 3 / 5 entries in their extend/ directory.\n",
    "gt0_cnt=sum(1 for n_extend in MSA_cnt if n_extend>0)\n",
    "gt3_cnt=sum(1 for n_extend in MSA_cnt if n_extend>3)\n",
    "gt5_cnt=sum(1 for n_extend in MSA_cnt if n_extend>5)\n",
    "print(\"gt0_cnt:\",gt0_cnt)\n",
    "print(\"gt3_cnt:\",gt3_cnt)\n",
    "print(\"gt5_cnt:\",gt5_cnt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
