{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3ddae686",
   "metadata": {},
   "source": [
    "### Input data\n",
    "\n",
    "1. Create an account on the [PDB-Bind portal](https://www.pdbbind-plus.org.cn) and download the protein-protein binding affinity dataset. \n",
    "2. After extracting the downloaded archive, this notebook reads the index file `PP/index/INDEX_general_PP.2020` and the structure files `PP/<pdbcode>.ent.pdb`, with paths relative to the notebook's working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7b770612-301b-4341-af44-42e62d208f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,re\n",
    "import argparse, json\n",
    "import copy\n",
    "import random\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "from tqdm import tqdm\n",
    "#from tqdm.notebook import tqdm\n",
    "from Bio.PDB.PDBParser import PDBParser\n",
    "from Bio.PDB.Polypeptide import one_to_index\n",
    "from Bio.PDB import Selection\n",
    "from Bio import SeqIO\n",
    "from Bio.PDB.Residue import Residue\n",
    "from easydict import EasyDict\n",
    "import enum\n",
    "import gzip\n",
    "from Bio import SeqIO\n",
    "\n",
    "from collections import OrderedDict\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import scipy.stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1533028d-772f-4414-9498-6de9912e114c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the PDBbind protein-protein index into (complex, -log10 of the affinity in molar) pairs.\n",
    "complexes = []\n",
    "resolution = []\n",
    "units = []\n",
    "log_kds_M = []\n",
    "\n",
    "dissociation_units_to_moles = {\n",
    "    'fM': 1e-15,  # femtomolar to molar\n",
    "    'pM': 1e-12,  # picomolar to molar\n",
    "    'nM': 1e-9,   # nanomolar to molar\n",
    "    'uM': 1e-6,   # micromolar to molar\n",
    "    'mM': 1e-3    # millimolar to molar\n",
    "}\n",
    "\n",
    "# <measure><relation><value><unit>, e.g. 'Kd=0.9nM', 'Ki<10uM' or 'IC50>=1.2mM'.\n",
    "# Two-character relations come first so '<=' is not consumed as '<' followed by '='.\n",
    "affinity_pattern = re.compile(\n",
    "    '^(Kd|Ki|IC50)(<=|>=|=|~|<|>)([0-9.eE+-]+)('\n",
    "    + '|'.join(dissociation_units_to_moles)\n",
    "    + ')$'\n",
    ")\n",
    "\n",
    "\n",
    "def parse_affinity(field):\n",
    "    \"\"\"Convert one affinity field of the index into -log10(value in molar).\n",
    "\n",
    "    Kd, Ki and IC50 are all accepted and the relation sign (=, ~, <, >, <=, >=)\n",
    "    is ignored, i.e. bounds are taken at face value.\n",
    "\n",
    "    Returns:\n",
    "        The -log10 value as a float, or None when the field does not match\n",
    "        the expected pattern or the value is not a positive number.\n",
    "    \"\"\"\n",
    "    match = affinity_pattern.match(field.strip())\n",
    "    if match is None:\n",
    "        return None\n",
    "    val, unit = match.group(3), match.group(4)\n",
    "    try:\n",
    "        kd = float(val) * dissociation_units_to_moles[unit]\n",
    "    except ValueError:\n",
    "        return None\n",
    "    if kd <= 0:\n",
    "        # log10 is undefined for non-positive values\n",
    "        return None\n",
    "    return -math.log10(kd)\n",
    "\n",
    "\n",
    "skipped_lines = []\n",
    "with open('PP/index/INDEX_general_PP.2020') as file:\n",
    "    for line in file:\n",
    "        if line.startswith('#') or not line.strip():\n",
    "            continue\n",
    "        # Split on any run of whitespace so column padding cannot shift the affinity field.\n",
    "        fields = line.split()\n",
    "        log_kd = parse_affinity(fields[3]) if len(fields) > 3 else None\n",
    "        if log_kd is None:\n",
    "            # Skip (and report below) instead of reusing a stale value or stopping early.\n",
    "            skipped_lines.append(line.rstrip())\n",
    "            continue\n",
    "        complexes.append(fields[0])\n",
    "        log_kds_M.append(log_kd)\n",
    "\n",
    "if skipped_lines:\n",
    "    print('Skipped {} unparseable line(s):'.format(len(skipped_lines)))\n",
    "    for skipped in skipped_lines:\n",
    "        print(skipped)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ec2a3424-899b-4aef-9576-e3e8ccb19ba5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the parsed (complex, affinity) pairs. The column is named '-logKd' although the\n",
    "# parsing cell fills it from Kd, Ki and IC50 fields alike. index=False keeps the row index out of the CSV.\n",
    "extracted_df = pd.DataFrame({'complex': complexes, '-logKd': log_kds_M})\n",
    "extracted_df.to_csv('extracted_affinities.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6d42116-8623-41cb-86ea-1659ed0e4b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three-letter code of a non-standard residue -> three-letter code of the standard residue\n",
    "# it is treated as. Consulted by AA._missing_ / AA.is_aa and by parse_biopython_structure.\n",
    "non_standard_residue_substitutions = {\n",
    "    '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG',\n",
    "    'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA',\n",
    "    'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS',\n",
    "    'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS',\n",
    "    'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA',\n",
    "    'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP',\n",
    "    'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY',\n",
    "    'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO',\n",
    "    'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN',\n",
    "    'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU',\n",
    "    'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE',\n",
    "    'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS',\n",
    "    'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA',\n",
    "    'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR'\n",
    "}\n",
    "\n",
    "\n",
    "# Number of atom slots reserved per residue.\n",
    "max_num_heavyatoms = 15\n",
    "max_num_hydrogens = 16\n",
    "max_num_allatoms = max_num_heavyatoms + max_num_hydrogens\n",
    "\n",
    "# One-letter residue symbol -> integer index; the position in the string is the index,\n",
    "# with 'X' in the last slot (20).\n",
    "ressymb_to_resindex = {\n",
    "    symbol: index for index, symbol in enumerate('ACDEFGHIKLMNPQRSTVWYX')\n",
    "}\n",
    "\n",
    "# Inverse lookup: integer index -> one-letter symbol.\n",
    "resindex_to_ressymb = dict(zip(ressymb_to_resindex.values(), ressymb_to_resindex.keys()))\n",
    "\n",
    "class BBHeavyAtom(enum.IntEnum):\n",
    "    \"\"\"Slot indices of the backbone heavy atoms, CB and the terminal OXT.\"\"\"\n",
    "    N = 0\n",
    "    CA = 1\n",
    "    C = 2\n",
    "    O = 3\n",
    "    CB = 4\n",
    "    OXT = 14\n",
    "\n",
    "def _get_residue_heavyatom_info(res: Residue):\n",
    "    \"\"\"Collect coordinates, presence mask and B-factors of a residue's heavy atoms.\n",
    "\n",
    "    Atoms are placed in the slot order that ``restype_to_heavyatom_names`` defines\n",
    "    for the residue's type; slots whose atom is unused or absent from ``res`` keep\n",
    "    their zero / False initial value.\n",
    "\n",
    "    Returns:\n",
    "        (pos_heavyatom, mask_heavyatom, bfactor_heavyatom) with shapes\n",
    "        (max_num_heavyatoms, 3), (max_num_heavyatoms,) and (max_num_heavyatoms,).\n",
    "    \"\"\"\n",
    "    pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)\n",
    "    mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)\n",
    "    bfactor_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.float)\n",
    "\n",
    "    slot_names = restype_to_heavyatom_names[AA(res.get_resname())]\n",
    "    for slot, name in enumerate(slot_names):\n",
    "        # '' marks an unused slot; a named atom may still be missing from the residue\n",
    "        if name == '' or name not in res:\n",
    "            continue\n",
    "        atom = res[name]\n",
    "        pos_heavyatom[slot] = torch.tensor(atom.get_coord().tolist(), dtype=pos_heavyatom.dtype)\n",
    "        mask_heavyatom[slot] = True\n",
    "        bfactor_heavyatom[slot] = atom.get_bfactor()\n",
    "    return pos_heavyatom, mask_heavyatom, bfactor_heavyatom\n",
    "\n",
    "class AA(enum.IntEnum):\n",
    "    \"\"\"Integer codes of the 20 standard amino acids plus UNK.\n",
    "\n",
    "    Besides the integer value, ``AA(...)`` accepts a three-letter residue name\n",
    "    (non-standard names are first mapped through\n",
    "    ``non_standard_residue_substitutions``) or a one-letter symbol listed in\n",
    "    ``ressymb_to_resindex``.\n",
    "    \"\"\"\n",
    "    ALA = 0\n",
    "    CYS = 1\n",
    "    ASP = 2\n",
    "    GLU = 3\n",
    "    PHE = 4\n",
    "    GLY = 5\n",
    "    HIS = 6\n",
    "    ILE = 7\n",
    "    LYS = 8\n",
    "    LEU = 9\n",
    "    MET = 10\n",
    "    ASN = 11\n",
    "    PRO = 12\n",
    "    GLN = 13\n",
    "    ARG = 14\n",
    "    SER = 15\n",
    "    THR = 16\n",
    "    VAL = 17\n",
    "    TRP = 18\n",
    "    TYR = 19\n",
    "    UNK = 20\n",
    "\n",
    "    @classmethod\n",
    "    def _missing_(cls, value):\n",
    "        \"\"\"Resolve residue names / symbols that are not plain member values.\"\"\"\n",
    "        if isinstance(value, str):\n",
    "            if len(value) == 3:\n",
    "                # three-letter name, after folding non-standard residues\n",
    "                name = non_standard_residue_substitutions.get(value, value)\n",
    "                if name in cls._member_names_:\n",
    "                    return getattr(cls, name)\n",
    "            elif len(value) == 1 and value in ressymb_to_resindex:\n",
    "                # one-letter symbol\n",
    "                return cls(ressymb_to_resindex[value])\n",
    "\n",
    "        return super()._missing_(value)\n",
    "\n",
    "    def __str__(self):\n",
    "        \"\"\"Render a member as its three-letter name.\"\"\"\n",
    "        return self.name\n",
    "\n",
    "    @classmethod\n",
    "    def is_aa(cls, value):\n",
    "        \"\"\"Return True if value is a known one-letter symbol, non-standard name or member name.\"\"\"\n",
    "        return any(\n",
    "            value in known\n",
    "            for known in (ressymb_to_resindex, non_standard_residue_substitutions, cls._member_names_)\n",
    "        )\n",
    "\n",
    "# For each residue type, the heavy-atom name stored in each of the max_num_heavyatoms slots.\n",
    "# '' marks an unused slot; the last slot holds OXT for every type except UNK, whose row is all empty.\n",
    "restype_to_heavyatom_names = {\n",
    "    AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    '', 'OXT'],\n",
    "    AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.GLY: ['N', 'CA', 'C', 'O', '',   '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    '', 'OXT'],\n",
    "    AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    '', 'OXT'],\n",
    "    AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'],\n",
    "    AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    '', 'OXT'],\n",
    "    AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.UNK: ['',  '',   '',  '',  '',   '',    '',    '',    '',    '',    '',    '',    '',    '',    ''],\n",
    "}\n",
    "# Sanity check: every row must fill exactly max_num_heavyatoms slots.\n",
    "for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms\n",
    "\n",
    "# Three-letter residue name -> one-letter symbol for the 20 standard amino acids.\n",
    "amino_acids = {\n",
    "    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',\n",
    "    'CYS': 'C', 'GLU': 'E', 'GLN': 'Q', 'GLY': 'G',\n",
    "    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',\n",
    "    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',\n",
    "    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'\n",
    "}\n",
    "\n",
    "\n",
    "\n",
    "def parse_biopython_structure(entity, unknown_threshold=1.0):\n",
    "    \"\"\"Extract the one-letter sequence and chain assignment of every usable residue.\n",
    "\n",
    "    Chains are visited in order of chain id and residues in order of\n",
    "    (resseq, icode). A residue is kept only if ``AA.is_aa`` recognises its name,\n",
    "    it has N, CA and C atoms, and it does not resolve to ``AA.UNK``.\n",
    "\n",
    "    Args:\n",
    "        entity: Biopython entity (e.g. a model) to unfold into chains.\n",
    "        unknown_threshold: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        EasyDict with three parallel per-residue lists: ``chain_id`` (chain\n",
    "        identifier), ``chain_nb`` (0-based position of the chain in the sorted\n",
    "        chain list) and ``seq`` (one-letter residue symbol).\n",
    "    \"\"\"\n",
    "    chains = Selection.unfold_entities(entity, 'C')\n",
    "    chains.sort(key=lambda c: c.get_id())\n",
    "    data = EasyDict({\n",
    "        'chain_id': [], 'chain_nb': [], 'seq': []\n",
    "    })\n",
    "\n",
    "    for chain_nb, chain in enumerate(chains):\n",
    "        # chain.atom_to_internal_coordinates() used to be called here. Nothing in this\n",
    "        # function reads internal coordinates, and that call is where the RuntimeWarnings\n",
    "        # from Bio/PDB/internal_coords.py come from, so it is no longer made.\n",
    "        residues = Selection.unfold_entities(chain, 'R')\n",
    "        residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2]))   # Sort residues by resseq-icode\n",
    "        for res in residues:\n",
    "            resname = res.get_resname()\n",
    "\n",
    "            if not AA.is_aa(resname):\n",
    "                continue\n",
    "            if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')):\n",
    "                continue\n",
    "            restype = AA(resname)\n",
    "            if restype == AA.UNK:\n",
    "                continue\n",
    "\n",
    "            # Chain info\n",
    "            data.chain_id.append(chain.get_id())\n",
    "            data.chain_nb.append(chain_nb)\n",
    "\n",
    "            # restype already has non-standard names folded onto a standard residue, so one\n",
    "            # index lookup replaces the former try / bare-except pair of dictionary lookups.\n",
    "            data.seq.append(resindex_to_ressymb[int(restype)])\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e105ca50-4355-47f7-815a-91f6885c5271",
   "metadata": {},
   "outputs": [],
   "source": [
    "def standardize(x):\n",
    "    \"\"\"Z-score ``x`` along axis 0: subtract the mean and divide by the standard deviation.\n",
    "\n",
    "    Where the standard deviation is zero (constant or single-element input) the\n",
    "    centred values are returned undivided, i.e. zeros, instead of NaN/inf.\n",
    "\n",
    "    Args:\n",
    "        x: array-like exposing ``mean(axis=0)`` and ``std(axis=0)`` (e.g. a numpy array).\n",
    "\n",
    "    Returns:\n",
    "        The standardized values, same shape as ``x``.\n",
    "    \"\"\"\n",
    "    centered = x - x.mean(axis=0)\n",
    "    scale = x.std(axis=0)\n",
    "    # Dividing by 1 where the spread is zero leaves those (all-zero) centred values untouched.\n",
    "    safe_scale = np.where(scale == 0, 1.0, scale)\n",
    "    return centered / safe_scale\n",
    "\n",
    "def get_sequences_by_chain(chain_ids, amino_acids):\n",
    "    \"\"\"Split a flat residue sequence into one string per chain.\n",
    "\n",
    "    Args:\n",
    "        chain_ids: per-residue chain indices; passed to ``torch.max``, so an\n",
    "            integer tensor is expected.\n",
    "        amino_acids: per-residue symbols, aligned with ``chain_ids``.\n",
    "\n",
    "    Returns:\n",
    "        List of strings; entry ``k`` concatenates the symbols whose chain index\n",
    "        is ``k``, for every ``k`` from 0 to the largest index present.\n",
    "    \"\"\"\n",
    "    # One bucket per chain index, including indices that never occur\n",
    "    per_chain = [[] for _ in range(torch.max(chain_ids)+1)]\n",
    "\n",
    "    for chain_index, symbol in zip(chain_ids, amino_acids):\n",
    "        per_chain[chain_index].append(symbol)\n",
    "\n",
    "    return list(map(''.join, per_chain))\n",
    "\n",
    "class PDBBindDataset(Dataset):\n",
    "    \"\"\"One cross-validation split of PDBbind complexes, served as sequence/chain records.\n",
    "\n",
    "    Affinity entries are read from a CSV with ``complex`` and ``-logKd`` columns.\n",
    "    Complexes are shuffled with ``split_seed`` and cut into ``num_cvfolds`` folds;\n",
    "    fold ``cvfold_index`` is the validation split and the remaining folds form the\n",
    "    training split. Parsed entries and structures are cached as pickles in\n",
    "    ``cache_dir``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, \n",
    "        csv_path, \n",
    "        pdb_dir, \n",
    "        cache_dir,\n",
    "        cvfold_index=0, \n",
    "        num_cvfolds=3, \n",
    "        split='train', \n",
    "        split_seed=2023,\n",
    "        reset=False\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            csv_path: CSV of affinity entries (``complex`` and ``-logKd`` columns).\n",
    "            pdb_dir: directory holding the ``<pdbcode>.ent.pdb`` structure files.\n",
    "            cache_dir: directory for the pickle caches; created if missing.\n",
    "            cvfold_index: index of the fold held out as validation.\n",
    "            num_cvfolds: number of folds.\n",
    "            split: 'train' or 'val'.\n",
    "            split_seed: seed of the complex shuffle.\n",
    "            reset: rebuild both caches even if they already exist.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.csv_path = csv_path\n",
    "        self.pdb_dir = pdb_dir\n",
    "        self.cache_dir = cache_dir\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "        self.cvfold_index = cvfold_index\n",
    "        self.num_cvfolds = num_cvfolds\n",
    "        assert split in ('train', 'val')\n",
    "        self.split = split\n",
    "        self.split_seed = split_seed\n",
    "\n",
    "        self.entries_cache = os.path.join(cache_dir, 'entries.pkl')\n",
    "        self.entries = None\n",
    "        self.entries_full = None\n",
    "        self._load_entries(reset)\n",
    "\n",
    "        self.structures_cache = os.path.join(cache_dir, 'structures.pkl')\n",
    "        self.structures = None\n",
    "        self._load_structures(reset)\n",
    "\n",
    "    def _load_entries(self, reset):\n",
    "        \"\"\"Load all entries (from cache or CSV) and select those of this split.\"\"\"\n",
    "        if not os.path.exists(self.entries_cache) or reset:\n",
    "            self.entries_full = self._preprocess_entries()\n",
    "        else:\n",
    "            # NOTE: pickle.load can execute arbitrary code; only use caches you created yourself.\n",
    "            with open(self.entries_cache, 'rb') as f:\n",
    "                self.entries_full = pickle.load(f)\n",
    "\n",
    "        # Group rows by complex so a complex never straddles train and val.\n",
    "        complex_to_entries = {}\n",
    "        for _, e in self.entries_full.iterrows():\n",
    "            complex_to_entries.setdefault(e['complex'], []).append(e)\n",
    "\n",
    "        # Sort before the seeded shuffle so the split does not depend on row order.\n",
    "        complex_list = sorted(complex_to_entries.keys())\n",
    "        random.Random(self.split_seed).shuffle(complex_list)\n",
    "\n",
    "        split_size = math.ceil(len(complex_list) / self.num_cvfolds)\n",
    "        complex_splits = [\n",
    "            complex_list[i*split_size : (i+1)*split_size] \n",
    "            for i in range(self.num_cvfolds)\n",
    "        ]\n",
    "\n",
    "        val_split = complex_splits.pop(self.cvfold_index)\n",
    "        train_split = sum(complex_splits, [])\n",
    "        if self.split == 'val':\n",
    "            complexes_this = val_split\n",
    "        else:\n",
    "            complexes_this = train_split\n",
    "\n",
    "        entries = []\n",
    "        for cplx in complexes_this:\n",
    "            entries += complex_to_entries[cplx]\n",
    "\n",
    "        # Previously left at None although declared in __init__.\n",
    "        self.entries = entries\n",
    "        # NOTE(review): targets are standardized with this split's own mean/std, so the\n",
    "        # train and val splits are scaled with different statistics.\n",
    "        self.entries_val = standardize(np.array([float(e['-logKd']) for e in entries]))\n",
    "        self.entries_complex = [e['complex'] for e in entries]\n",
    "\n",
    "    def _preprocess_entries(self):\n",
    "        \"\"\"Read the CSV, write it to the entries cache and return the DataFrame.\"\"\"\n",
    "        entries = pd.read_csv(filepath_or_buffer=self.csv_path)\n",
    "        with open(self.entries_cache, 'wb') as f:\n",
    "            pickle.dump(entries, f)\n",
    "        return entries\n",
    "\n",
    "    def _load_structures(self, reset):\n",
    "        \"\"\"Load the parsed structures from cache, or parse them if missing or reset.\"\"\"\n",
    "        if not os.path.exists(self.structures_cache) or reset:\n",
    "            self.structures = self._preprocess_structures()\n",
    "        else:\n",
    "            # NOTE: pickle.load can execute arbitrary code; only use caches you created yourself.\n",
    "            with open(self.structures_cache, 'rb') as f:\n",
    "                self.structures = pickle.load(f)\n",
    "\n",
    "    def _preprocess_structures(self):\n",
    "        \"\"\"Parse the first model of every complex in entries_full and cache the result.\n",
    "\n",
    "        Returns:\n",
    "            Dict mapping PDB code to the output of ``parse_biopython_structure``.\n",
    "        \"\"\"\n",
    "        structures = {}\n",
    "        pdbcodes = self.entries_full['complex'].unique().tolist()\n",
    "\n",
    "        # The parser is created once and reused for every file.\n",
    "        parser = PDBParser(QUIET=True)\n",
    "        for pdbcode in tqdm(pdbcodes, desc='Structures'):\n",
    "            pdb_path = os.path.join(self.pdb_dir, '{}.ent.pdb'.format(pdbcode.lower()))\n",
    "            model = parser.get_structure(None, pdb_path)[0]\n",
    "            structures[pdbcode] = parse_biopython_structure(model)\n",
    "\n",
    "        with open(self.structures_cache, 'wb') as f:\n",
    "            pickle.dump(structures, f)\n",
    "        return structures\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of affinity entries in this split.\"\"\"\n",
    "        return len(self.entries_val)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return the record of entry ``index``.\n",
    "\n",
    "        The record has ``seq`` (sequence string), ``chain_nb`` (tensor of\n",
    "        per-residue chain indices) and ``ddG`` (the standardized -logKd target).\n",
    "        \"\"\"\n",
    "        target = self.entries_val[index]\n",
    "        pdbcode = self.entries_complex[index]\n",
    "        # Deep copy so the cached structure is not mutated by the edits below.\n",
    "        data = copy.deepcopy(self.structures[pdbcode])\n",
    "\n",
    "        data['seq'] = ''.join(data['seq'])\n",
    "        data['chain_nb'] = torch.tensor(data['chain_nb'])\n",
    "        # The key stays 'ddG' for existing consumers; the value is the standardized -logKd.\n",
    "        data['ddG'] = target\n",
    "\n",
    "        del data['chain_id']\n",
    "\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0f1d73e8-e77f-4209-9cb1-4b59ce681b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the extracted affinities and list every complex (PDB code) once, in order of first appearance.\n",
    "entries = pd.read_csv('extracted_affinities.csv')\n",
    "pdbcodes = entries['complex'].unique().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "442a65e0-7bd9-43f9-b5ea-b39b0de0e041",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Structures:   4%|█                           | 109/2852 [00:50<31:30,  1.45it/s]/data/cb/scratch/varun/miniconda3/envs/esmfold/lib/python3.7/site-packages/Bio/PDB/internal_coords.py:757: RuntimeWarning: invalid value encountered in true_divide\n",
      "  numpy.arccos(((a0a1 * a0a1) + sqr_a1a2 - (a0a2 * a0a2)) / (2 * a0a1 * a1a2))\n",
      "/data/cb/scratch/varun/miniconda3/envs/esmfold/lib/python3.7/site-packages/Bio/PDB/internal_coords.py:761: RuntimeWarning: invalid value encountered in true_divide\n",
      "  numpy.arccos((sqr_a1a2 + (a2a3 * a2a3) - (a1a3 * a1a3)) / (2 * a1a2 * a2a3))\n",
      "Structures: 100%|███████████████████████████| 2852/2852 [38:35<00:00,  1.23it/s]\n"
     ]
    }
   ],
   "source": [
    "# Parse the first model of every complex's PDB file under PP/ and keep the per-residue\n",
    "# sequence / chain lists returned by parse_biopython_structure, keyed by PDB code.\n",
    "sequences = {}\n",
    "missing_num = 0  # never updated in this cell\n",
    "for pdbcode in tqdm(pdbcodes, desc='Structures'):\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    # File names use the lower-cased PDB code: PP/<pdbcode>.ent.pdb\n",
    "    pdb_path = os.path.join('PP/', '{}.ent.pdb'.format(pdbcode.lower()))\n",
    "    model = parser.get_structure(None, pdb_path)[0]\n",
    "    data = parse_biopython_structure(model)\n",
    "    sequences[pdbcode] = data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "5df539cb-9029-403a-a71c-3c9d5e45b99b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every affinity entry into (sequence, per-residue chain-index string, target).\n",
    "MAX_CHAIN_INDEX = 5   # complexes with a chain index above this (more than 6 chains) are dropped\n",
    "\n",
    "seqs = []\n",
    "all_chain_ids = []\n",
    "targets = []\n",
    "\n",
    "for _, row in entries.iterrows():\n",
    "    seq_data = sequences[row['complex']]\n",
    "    chain_nb = seq_data['chain_nb']\n",
    "\n",
    "    # No residue survived parsing: nothing to emit, and np.max raises on an empty list.\n",
    "    if len(chain_nb) == 0:\n",
    "        continue\n",
    "    if np.max(chain_nb) > MAX_CHAIN_INDEX:\n",
    "        continue\n",
    "\n",
    "    # One digit per residue; the filter above guarantees every index is a single digit.\n",
    "    chain_ids = ''.join(str(nb) for nb in chain_nb)\n",
    "    seq = ''.join(seq_data['seq'])\n",
    "\n",
    "    assert len(seq) == len(chain_ids)\n",
    "\n",
    "    seqs.append(seq)\n",
    "    all_chain_ids.append(chain_ids)\n",
    "    targets.append(row['-logKd'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "4924b530-9b78-4246-8369-a119c810de75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per kept affinity entry: sequence, per-residue chain-index string and the -logKd target.\n",
    "df = pd.DataFrame({'seq': seqs, 'chain_ids': all_chain_ids, 'target': targets})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "8530a10e-cb1a-46a7-bf89-cb0cbb7e26f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# index=False: the positional row index carries no information and would otherwise be\n",
    "# read back as an extra 'Unnamed: 0' column.\n",
    "df.to_csv('./processed_data.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "599a669b-6c6d-4faf-afc2-61bb8e1d2a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Force string dtype: chain_ids consists of digits only, so type inference could read it as\n",
    "# a number and drop leading zeros.\n",
    "x = pd.read_csv('./processed_data.csv', dtype={'seq': str, 'chain_ids': str})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c8dc317b-cb85-46e6-97aa-f618620abf76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a single row per distinct sequence (pandas keeps the first occurrence by default), so\n",
    "# other rows with the same sequence but a different target are discarded.\n",
    "x = x.drop_duplicates(subset=['seq'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "de2e212f-d4a6-4340-aed9-77bc7aa3e8fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite the file with the de-duplicated rows; index=False avoids adding another\n",
    "# unnamed index column on every rewrite.\n",
    "x.to_csv('./processed_data.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12e5e5a-fe74-4cb6-9ac7-2d97af6dff01",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
