{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8b7e9e33",
   "metadata": {},
   "source": [
    "## Input data: \n",
    "\n",
    "1. Download the PDB cleaned files and the csv file from [SKEMPI v2](https://life.bsc.es/pid/skempi2/database/index).  "
   ]
  },
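  {
   "cell_type": "markdown",
   "id": "download-sketch-md",
   "metadata": {},
   "source": [
    "A minimal download sketch, assuming the file names used on the SKEMPI v2 download page (`skempi_v2.csv` and a cleaned-PDB archive); check the page and adjust the URLs if they differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "download-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: fetch the SKEMPI v2 CSV and the cleaned-PDB archive.\n",
    "# The file names below are assumptions; adjust them to match the download page.\n",
    "import os\n",
    "import tarfile\n",
    "import urllib.request\n",
    "\n",
    "BASE = 'https://life.bsc.es/pid/skempi2/database/download'\n",
    "\n",
    "if not os.path.exists('skempi_v2.csv'):\n",
    "    urllib.request.urlretrieve(f'{BASE}/skempi_v2.csv', 'skempi_v2.csv')\n",
    "\n",
    "if not os.path.isdir('PDBs'):\n",
    "    urllib.request.urlretrieve(f'{BASE}/SKEMPI2_PDBs.tgz', 'SKEMPI2_PDBs.tgz')\n",
    "    with tarfile.open('SKEMPI2_PDBs.tgz') as tar:\n",
    "        tar.extractall('.')  # expected to unpack into a PDBs/ directory"
   ]
  },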
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4ec23f6e-3704-4511-a42c-28baef0f2244",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "import argparse, json\n",
    "import copy\n",
    "import random\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "from tqdm import tqdm\n",
    "#from tqdm.notebook import tqdm\n",
    "from Bio.PDB.PDBParser import PDBParser\n",
    "from Bio.PDB.Polypeptide import one_to_index\n",
    "from Bio.PDB import Selection\n",
    "from Bio import SeqIO\n",
    "from Bio.PDB.Residue import Residue\n",
    "from easydict import EasyDict\n",
    "import enum\n",
    "\n",
    "from Bio import SeqIO\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import scipy.stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6cafe25-c1ce-4135-bdc8-2441aa31a97c",
   "metadata": {},
   "outputs": [],
   "source": [
    "non_standard_residue_substitutions = {\n",
    "    '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG',\n",
    "    'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA',\n",
    "    'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS',\n",
    "    'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS',\n",
    "    'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA',\n",
    "    'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP',\n",
    "    'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY',\n",
    "    'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO',\n",
    "    'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN',\n",
    "    'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU',\n",
    "    'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE',\n",
    "    'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS',\n",
    "    'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA',\n",
    "    'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR'\n",
    "}\n",
    "\n",
    "\n",
    "max_num_heavyatoms = 15\n",
    "max_num_hydrogens = 16\n",
    "max_num_allatoms = max_num_heavyatoms + max_num_hydrogens\n",
    "\n",
    "ressymb_to_resindex = {\n",
    "    'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,\n",
    "    'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,\n",
    "    'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,\n",
    "    'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,\n",
    "    'X': 20,\n",
    "}\n",
    "\n",
    "resindex_to_ressymb = {v: k for k, v in ressymb_to_resindex.items()}\n",
    "\n",
    "class BBHeavyAtom(enum.IntEnum):\n",
    "    N = 0; CA = 1; C = 2; O = 3; CB = 4; OXT=14;\n",
    "\n",
    "def _get_residue_heavyatom_info(res: Residue):\n",
    "    pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)\n",
    "    mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)\n",
    "    bfactor_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.float)\n",
    "    restype = AA(res.get_resname())\n",
    "    for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]):\n",
    "        if atom_name == '': continue\n",
    "        if atom_name in res:\n",
    "            pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype)\n",
    "            mask_heavyatom[idx] = True\n",
    "            bfactor_heavyatom[idx] = res[atom_name].get_bfactor()\n",
    "    return pos_heavyatom, mask_heavyatom, bfactor_heavyatom\n",
    "\n",
    "class AA(enum.IntEnum):\n",
    "    ALA = 0; CYS = 1; ASP = 2; GLU = 3; PHE = 4\n",
    "    GLY = 5; HIS = 6; ILE = 7; LYS = 8; LEU = 9\n",
    "    MET = 10; ASN = 11; PRO = 12; GLN = 13; ARG = 14\n",
    "    SER = 15; THR = 16; VAL = 17; TRP = 18; TYR = 19\n",
    "    UNK = 20\n",
    "\n",
    "    @classmethod\n",
    "    def _missing_(cls, value):\n",
    "        if isinstance(value, str) and len(value) == 3:      # three representation\n",
    "            if value in non_standard_residue_substitutions:\n",
    "                value = non_standard_residue_substitutions[value]\n",
    "            if value in cls._member_names_:\n",
    "                return getattr(cls, value)\n",
    "        elif isinstance(value, str) and len(value) == 1:    # one representation\n",
    "            if value in ressymb_to_resindex:\n",
    "                return cls(ressymb_to_resindex[value])\n",
    "\n",
    "        return super()._missing_(value)\n",
    "\n",
    "    def __str__(self):\n",
    "        return self.name\n",
    "\n",
    "    @classmethod\n",
    "    def is_aa(cls, value):\n",
    "        return (value in ressymb_to_resindex) or \\\n",
    "            (value in non_standard_residue_substitutions) or \\\n",
    "            (value in cls._member_names_)\n",
    "\n",
    "restype_to_heavyatom_names = {\n",
    "    AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    '', 'OXT'],\n",
    "    AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.GLY: ['N', 'CA', 'C', 'O', '',   '',    '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    '', 'OXT'],\n",
    "    AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    '', 'OXT'],\n",
    "    AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'],\n",
    "    AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    '', 'OXT'],\n",
    "    AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    '', 'OXT'],\n",
    "    AA.UNK: ['',  '',   '',  '',  '',   '',    '',    '',    '',    '',    '',    '',    '',    '',    ''],\n",
    "}\n",
    "for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms"
   ]
  },
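  {
   "cell_type": "markdown",
   "id": "aa-sanity-md",
   "metadata": {},
   "source": [
    "An optional sanity check that the residue mappings above agree: three-letter names, non-standard substitutions, one-letter codes, and indices should all resolve to the same residue types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa-sanity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks for the AA enum and the residue tables defined above\n",
    "assert AA('ALA') == AA.ALA == AA('A') == AA(0)\n",
    "assert AA('MSE') == AA.MET   # non-standard residue resolved via the substitution table\n",
    "assert not AA.is_aa('HOH')   # water is not an amino acid\n",
    "assert resindex_to_ressymb[int(AA.LYS)] == 'K'\n",
    "print('residue mappings look consistent')"
   ]
  },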
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9658b54a-842d-4875-949a-d69c1dc63f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "amino_acids = {\n",
    "    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',\n",
    "    'CYS': 'C', 'GLU': 'E', 'GLN': 'Q', 'GLY': 'G',\n",
    "    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',\n",
    "    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',\n",
    "    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'\n",
    "}\n",
    "\n",
    "\n",
    "def parse_biopython_structure(entity, unknown_threshold=1.0):\n",
    "    chains = Selection.unfold_entities(entity, 'C')\n",
    "    chains.sort(key=lambda c: c.get_id())\n",
    "    data = EasyDict({\n",
    "        'chain_id': [], 'chain_nb': [],\n",
    "        'resseq': [], 'icode': [], 'res_nb': [],\n",
    "        'aa': [],\n",
    "        'pos_heavyatom': [], 'mask_heavyatom': [],\n",
    "        'bfactor_heavyatom': [], 'seq': []\n",
    "    })\n",
    "    \n",
    "    tensor_types = {\n",
    "        'chain_nb': torch.LongTensor,\n",
    "        'resseq': torch.LongTensor,\n",
    "        'res_nb': torch.LongTensor,\n",
    "        'aa': torch.LongTensor,\n",
    "        'pos_heavyatom': torch.stack,\n",
    "        'mask_heavyatom': torch.stack,\n",
    "        'bfactor_heavyatom': torch.stack,\n",
    "    }\n",
    "\n",
    "    count_aa, count_unk = 0, 0\n",
    "\n",
    "    for i, chain in enumerate(chains):\n",
    "        chain.atom_to_internal_coordinates()\n",
    "        seq_this = 0   # Renumbering residues\n",
    "        residues = Selection.unfold_entities(chain, 'R')\n",
    "        residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2]))   # Sort residues by resseq-icode\n",
    "        for _, res in enumerate(residues):\n",
    "            resname = res.get_resname()\n",
    "\n",
    "            \n",
    "            if not AA.is_aa(resname): continue\n",
    "            if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue\n",
    "            restype = AA(resname)\n",
    "            \n",
    "            count_aa += 1\n",
    "            if restype == AA.UNK: \n",
    "                count_unk += 1\n",
    "                continue\n",
    "\n",
    "            # Chain info\n",
    "            data.chain_id.append(chain.get_id())\n",
    "            data.chain_nb.append(i)\n",
    "\n",
    "            try:\n",
    "                data['seq'] += amino_acids[resname]\n",
    "            except:\n",
    "                data['seq'] += amino_acids[non_standard_residue_substitutions[resname]]\n",
    "\n",
    "            # Residue types\n",
    "            data.aa.append(restype) # Will be automatically cast to torch.long\n",
    "\n",
    "            # Heavy atoms\n",
    "            pos_heavyatom, mask_heavyatom, bfactor_heavyatom = _get_residue_heavyatom_info(res)\n",
    "            data.pos_heavyatom.append(pos_heavyatom)\n",
    "            data.mask_heavyatom.append(mask_heavyatom)\n",
    "            data.bfactor_heavyatom.append(bfactor_heavyatom)\n",
    "\n",
    "            # Sequential number\n",
    "            resseq_this = int(res.get_id()[1])\n",
    "            icode_this = res.get_id()[2]\n",
    "            if seq_this == 0:\n",
    "                seq_this = 1\n",
    "            else:\n",
    "                d_CA_CA = torch.linalg.norm(data.pos_heavyatom[-2][BBHeavyAtom.CA] - data.pos_heavyatom[-1][BBHeavyAtom.CA], ord=2).item()\n",
    "                if d_CA_CA <= 4.0:\n",
    "                    seq_this += 1\n",
    "                else:\n",
    "                    d_resseq = resseq_this - data.resseq[-1]\n",
    "                    seq_this += max(2, d_resseq)\n",
    "\n",
    "            data.resseq.append(resseq_this)\n",
    "            data.icode.append(icode_this)\n",
    "            data.res_nb.append(seq_this)\n",
    "\n",
    "    if len(data.aa) == 0:\n",
    "        return None, None\n",
    "\n",
    "    if (count_unk / count_aa) >= unknown_threshold:\n",
    "        return None, None\n",
    "\n",
    "    seq_map = {}\n",
    "    for i, (chain_id, resseq, icode) in enumerate(zip(data.chain_id, data.resseq, data.icode)):\n",
    "        seq_map[(chain_id, resseq, icode)] = i\n",
    "\n",
    "    for key, convert_fn in tensor_types.items():\n",
    "        data[key] = convert_fn(data[key])\n",
    "\n",
    "    return data, seq_map"
   ]
  },
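  {
   "cell_type": "markdown",
   "id": "parse-example-md",
   "metadata": {},
   "source": [
    "A hedged usage sketch for `parse_biopython_structure`, assuming a cleaned structure such as `PDBs/1CSE.pdb` is already on disk (any PDB file from the SKEMPI set would do); it just inspects the per-residue tensors returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "parse-example",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged example: parse one cleaned PDB and inspect the per-residue tensors.\n",
    "# 'PDBs/1CSE.pdb' is an assumed path; substitute any structure from the SKEMPI set.\n",
    "example_path = 'PDBs/1CSE.pdb'\n",
    "if os.path.exists(example_path):\n",
    "    model = PDBParser(QUIET=True).get_structure(None, example_path)[0]\n",
    "    data, seq_map = parse_biopython_structure(model)\n",
    "    if data is not None:\n",
    "        print('chains:', sorted(set(data.chain_id)))\n",
    "        print('residues:', len(data.aa))\n",
    "        print('pos_heavyatom shape:', tuple(data.pos_heavyatom.shape))  # (num_res, 15, 3)\n",
    "        print('sequence (first 60 aa):', ''.join(data.seq[:60]))"
   ]
  },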
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a2dce6f2-f9a6-4c82-a8b0-4ef4bc89b51f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_skempi_entries(csv_path, pdb_dir, block_list={'1KBH'}):\n",
    "    df = pd.read_csv(csv_path, sep=';')\n",
    "    df['dG_wt'] =  (8.314/4184)*(273.15 + 25.0) * np.log(df['Affinity_wt_parsed'])\n",
    "    df['dG_mut'] =  (8.314/4184)*(273.15 + 25.0) * np.log(df['Affinity_mut_parsed'])\n",
    "    df['ddG'] = df['dG_mut'] - df['dG_wt']\n",
    "\n",
    "    def _parse_mut(mut_name):\n",
    "        wt_type, mutchain, mt_type = mut_name[0], mut_name[1], mut_name[-1]\n",
    "        mutseq = int(mut_name[2:-1])\n",
    "        return {\n",
    "            'wt': wt_type,\n",
    "            'mt': mt_type,\n",
    "            'chain': mutchain,\n",
    "            'resseq': mutseq,\n",
    "            'icode': ' ',\n",
    "            'name': mut_name\n",
    "        }\n",
    "\n",
    "    entries = []\n",
    "    for i, row in df.iterrows():\n",
    "        pdbcode, group1, group2 = row['#Pdb'].split('_')\n",
    "        if pdbcode in block_list:\n",
    "            continue\n",
    "        mut_str = row['Mutation(s)_cleaned']\n",
    "        muts = list(map(_parse_mut, row['Mutation(s)_cleaned'].split(',')))\n",
    "        if muts[0]['chain'] in group1:\n",
    "            group_ligand, group_receptor = group1, group2\n",
    "        else:\n",
    "            group_ligand, group_receptor = group2, group1\n",
    "\n",
    "        pdb_path = os.path.join(pdb_dir, '{}.pdb'.format(pdbcode.upper()))\n",
    "        if not os.path.exists(pdb_path):\n",
    "            continue\n",
    "\n",
    "        if not np.isfinite(row['ddG']):\n",
    "            continue\n",
    "\n",
    "        entry = {\n",
    "            'id': i,\n",
    "            'complex': row['#Pdb'],\n",
    "            'mutstr': mut_str,\n",
    "            'num_muts': len(muts),\n",
    "            'pdbcode': pdbcode,\n",
    "            'group_ligand': list(group_ligand),\n",
    "            'group_receptor': list(group_receptor),\n",
    "            'mutations': muts,\n",
    "            'ddG': np.float32(row['ddG']),\n",
    "            'pdb_path': pdb_path,\n",
    "        }\n",
    "        entries.append(entry)\n",
    "\n",
    "    return entries\n",
    "\n",
    "\n",
    "class SkempiDataset(Dataset):\n",
    "\n",
    "    def __init__(\n",
    "        self, \n",
    "        csv_path, \n",
    "        pdb_dir, \n",
    "        cache_dir,\n",
    "        cvfold_index=0, \n",
    "        num_cvfolds=3, \n",
    "        split='train', \n",
    "        split_seed=2023,\n",
    "        transform=None, \n",
    "        blocklist=frozenset({'1KBH'}), \n",
    "        reset=False\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.csv_path = csv_path\n",
    "        self.pdb_dir = pdb_dir\n",
    "        self.cache_dir = cache_dir\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "        self.blocklist = blocklist\n",
    "        self.transform = transform\n",
    "        self.cvfold_index = cvfold_index\n",
    "        self.num_cvfolds = num_cvfolds\n",
    "        assert split in ('train', 'val')\n",
    "        self.split = split\n",
    "        self.split_seed = split_seed\n",
    "\n",
    "        self.entries_cache = os.path.join(cache_dir, 'entries.pkl')\n",
    "        self.entries = None\n",
    "        self.entries_full = None\n",
    "        self._load_entries(reset)\n",
    "\n",
    "        self.structures_cache = os.path.join(cache_dir, 'structures.pkl')\n",
    "        self.structures = None\n",
    "        self._load_structures(reset)\n",
    "\n",
    "    def _load_entries(self, reset):\n",
    "        if not os.path.exists(self.entries_cache) or reset:\n",
    "            self.entries_full = self._preprocess_entries()\n",
    "        else:\n",
    "            with open(self.entries_cache, 'rb') as f:\n",
    "                self.entries_full = pickle.load(f)\n",
    "\n",
    "        complex_to_entries = {}\n",
    "        for e in self.entries_full:\n",
    "            if e['complex'] not in complex_to_entries:\n",
    "                complex_to_entries[e['complex']] = []\n",
    "            complex_to_entries[e['complex']].append(e)\n",
    "\n",
    "        complex_list = sorted(complex_to_entries.keys())\n",
    "        random.Random(self.split_seed).shuffle(complex_list)\n",
    "\n",
    "        self.complex_list = complex_list\n",
    "\n",
    "        split_size = math.ceil(len(complex_list) / self.num_cvfolds)\n",
    "        complex_splits = [\n",
    "            complex_list[i*split_size : (i+1)*split_size] \n",
    "            for i in range(self.num_cvfolds)\n",
    "        ]\n",
    "\n",
    "        val_split = complex_splits.pop(self.cvfold_index)\n",
    "        train_split = sum(complex_splits, [])\n",
    "        if self.split == 'val':\n",
    "            complexes_this = val_split\n",
    "        else:\n",
    "            complexes_this = train_split\n",
    "\n",
    "        entries = []\n",
    "        for cplx in complexes_this:\n",
    "            entries += complex_to_entries[cplx]\n",
    "        self.entries = entries\n",
    "        \n",
    "    def _preprocess_entries(self):\n",
    "        entries = load_skempi_entries(self.csv_path, self.pdb_dir, self.blocklist)\n",
    "        with open(self.entries_cache, 'wb') as f:\n",
    "            pickle.dump(entries, f)\n",
    "        return entries\n",
    "\n",
    "    def _load_structures(self, reset):\n",
    "        if not os.path.exists(self.structures_cache) or reset:\n",
    "            self.structures = self._preprocess_structures()\n",
    "        else:\n",
    "            with open(self.structures_cache, 'rb') as f:\n",
    "                self.structures = pickle.load(f)\n",
    "\n",
    "    def _preprocess_structures(self):\n",
    "        structures = {}\n",
    "        pdbcodes = list(set([e['pdbcode'] for e in self.entries_full]))\n",
    "        for pdbcode in tqdm(pdbcodes, desc='Structures'):\n",
    "            parser = PDBParser(QUIET=True)\n",
    "            pdb_path = os.path.join(self.pdb_dir, '{}.pdb'.format(pdbcode.upper()))\n",
    "            model = parser.get_structure(None, pdb_path)[0]\n",
    "            data, seq_map = parse_biopython_structure(model)\n",
    "            structures[pdbcode] = (data, seq_map)\n",
    "        with open(self.structures_cache, 'wb') as f:\n",
    "            pickle.dump(structures, f)\n",
    "        return structures\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.entries)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        entry = self.entries[index]\n",
    "        data, seq_map = copy.deepcopy( self.structures[entry['pdbcode']] )\n",
    "        \n",
    "        keys = {'id', 'complex', 'mutstr', 'num_muts', 'pdbcode', 'ddG'}\n",
    "        for k in keys:\n",
    "            data[k] = entry[k]\n",
    "\n",
    "        group_id = []\n",
    "        for ch in data['chain_id']:\n",
    "            if ch in entry['group_ligand']:\n",
    "                group_id.append(1)\n",
    "            elif ch in entry['group_receptor']:\n",
    "                group_id.append(2)\n",
    "            else:\n",
    "                group_id.append(0)\n",
    "        data['group_id'] = torch.LongTensor(group_id)\n",
    "\n",
    "        aa_mut = data['aa'].clone()\n",
    "        for mut in entry['mutations']:\n",
    "            ch_rs_ic = (mut['chain'], mut['resseq'], mut['icode'])\n",
    "            if ch_rs_ic not in seq_map: continue\n",
    "            aa_mut[seq_map[ch_rs_ic]] = one_to_index(mut['mt'])\n",
    "        data['aa_mut'] = aa_mut\n",
    "        data['mut_flag'] = (data['aa'] != data['aa_mut'])\n",
    "        \n",
    "        final_data = {}\n",
    "\n",
    "        assert True in data['mut_flag']\n",
    "\n",
    "        final_data['aa_seq'] = ''.join([resindex_to_ressymb[num.item()] for num in data['aa']])\n",
    "        final_data['aa_mut_seq'] = ''.join([resindex_to_ressymb[num.item()] for num in data['aa_mut']])\n",
    "        final_data['group_id'] = data['group_id'] - 1\n",
    "        final_data['chain_nb'] = data['chain_nb']\n",
    "        final_data['ddG'] = data['ddG']\n",
    "        final_data['complex'] = data['complex']\n",
    "\n",
    "        return final_data"
   ]
  },
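  {
   "cell_type": "markdown",
   "id": "ddg-formula-md",
   "metadata": {},
   "source": [
    "`load_skempi_entries` converts the parsed affinities into binding free energies with ΔG = RT·ln(K_d) in kcal/mol at T = 298.15 K, and then ΔΔG = ΔG_mut − ΔG_wt. A small worked check of that conversion, using hypothetical K_d values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddg-formula",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Worked check of the ddG conversion used in load_skempi_entries:\n",
    "# dG = (R / 4184) * T * ln(Kd) in kcal/mol, with R = 8.314 J/(mol*K), T = 298.15 K.\n",
    "R_kcal = 8.314 / 4184               # gas constant in kcal/(mol*K)\n",
    "T = 273.15 + 25.0                   # 25 degC in Kelvin\n",
    "\n",
    "dG_wt = R_kcal * T * np.log(1e-9)   # hypothetical wild-type Kd = 1 nM\n",
    "dG_mut = R_kcal * T * np.log(1e-7)  # hypothetical mutant Kd = 100 nM\n",
    "ddG = dG_mut - dG_wt                # weaker binding => positive ddG\n",
    "print(f'dG_wt  = {dG_wt:.2f} kcal/mol')\n",
    "print(f'dG_mut = {dG_mut:.2f} kcal/mol')\n",
    "print(f'ddG    = {ddG:.2f} kcal/mol')  # ~ +2.73 kcal/mol"
   ]
  },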
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2e3e880b-9824-4654-b5f1-fa8bf3cdeb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = SkempiDataset('skempi_v2.csv', 'PDBs', 'cache/')\n",
    "val_dataset = SkempiDataset('skempi_v2.csv', 'PDBs', 'cache/', split='val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "449cd8dd-b537-41ea-b61b-7326a18d8f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "seq1 = []\n",
    "seq_mut1 = []\n",
    "seq2 = []\n",
    "seq_mut2 = []\n",
    "targets = []\n",
    "complexes = []\n",
    "\n",
    "for x in train_dataset:\n",
    "    chain_ids = x['group_id'].tolist()\n",
    "    seq = x['aa_seq']\n",
    "    seq_mut = x['aa_mut_seq']\n",
    "    seq1.append(''.join([s for c,s in zip(chain_ids, seq) if c == 0]))\n",
    "    seq2.append(''.join([s for c,s in zip(chain_ids, seq) if c == 1]))\n",
    "    seq_mut1.append(''.join([s for c,s in zip(chain_ids, seq_mut) if c == 0]))\n",
    "    seq_mut2.append(''.join([s for c,s in zip(chain_ids, seq_mut) if c == 1]))\n",
    "    targets.append(x['ddG'])\n",
    "    complexes.append(x['complex'])\n",
    "\n",
    "    if seq1 == seq2:\n",
    "        if seq_mut1 == seq_mut2:\n",
    "            print('no')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "435a729c-3609-4c4e-aad9-d3a3351c754f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in val_dataset:\n",
    "    chain_ids = x['group_id'].tolist()\n",
    "    seq = x['aa_seq']\n",
    "    seq_mut = x['aa_mut_seq']\n",
    "    seq1.append(''.join([s for c,s in zip(chain_ids, seq) if c == 0]))\n",
    "    seq2.append(''.join([s for c,s in zip(chain_ids, seq) if c == 1]))\n",
    "    seq_mut1.append(''.join([s for c,s in zip(chain_ids, seq_mut) if c == 0]))\n",
    "    seq_mut2.append(''.join([s for c,s in zip(chain_ids, seq_mut) if c == 1]))\n",
    "    targets.append(x['ddG'])\n",
    "    complexes.append(x['complex'])\n",
    "\n",
    "    if seq1 == seq2:\n",
    "        if seq_mut1 == seq_mut2:\n",
    "            print('no')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ac275b9f-8742-4468-ab9c-c5f61b21b291",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({'seq1': seq1, \n",
    "                   'seq2': seq2, \n",
    "                   'seq1_mut': seq_mut1, \n",
    "                   'seq2_mut': seq_mut2, \n",
    "                   'target': targets, \n",
    "                   'complex': complexes})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "825f62cc-447a-4905-93c3-7f2ddfe0336c",
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_complexes = df['complex'].unique()\n",
    "\n",
    "# Step 2: Shuffle the unique complexes\n",
    "np.random.shuffle(unique_complexes)\n",
    "\n",
    "# Step 3: Split into 3 folds\n",
    "folds = np.array_split(unique_complexes, 3)\n",
    "\n",
    "for i in range(3):\n",
    "\n",
    "    all_i = [0,1,2]\n",
    "    test_i = i\n",
    "\n",
    "    train_i = all_i[:test_i] + all_i[test_i+1 :]\n",
    "    \n",
    "    train_complexes = np.concatenate([folds[train_i[0]], folds[train_i[1]]])  \n",
    "    test_complexes = folds[test_i]  \n",
    "    \n",
    "    # Step 5: Create a new column 'split' to assign 'train' or 'test'\n",
    "    df[f'split_{i}'] = np.where(df['complex'].isin(train_complexes), 'train', 'test')"
   ]
  },
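  {
   "cell_type": "markdown",
   "id": "fold-check-md",
   "metadata": {},
   "source": [
    "A quick sanity check that the complex-level split keeps train and test disjoint for every fold, i.e. no complex appears on both sides:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fold-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that no complex is shared between the train and test sides of any fold\n",
    "for i in range(3):\n",
    "    train_cplx = set(df.loc[df[f'split_{i}'] == 'train', 'complex'])\n",
    "    test_cplx = set(df.loc[df[f'split_{i}'] == 'test', 'complex'])\n",
    "    assert not (train_cplx & test_cplx), f'fold {i}: train/test complexes overlap'\n",
    "    print(f'fold {i}: {len(train_cplx)} train complexes, {len(test_cplx)} test complexes')"
   ]
  },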
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "364b5082-a74c-441f-afcf-6cf0ea3d1996",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv('./processed_data.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
