{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "886d609a",
   "metadata": {},
   "source": [
    "### Input data: \n",
    "\n",
    "1. Download all files present on this [link](https://figshare.com/articles/dataset/PPI_prediction_from_sequence_gold_standard_dataset/21591618/3) from Bernett et al., 2023"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6fb5725-0c5d-47a7-9736-214df755d04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,re\n",
    "import argparse, json\n",
    "import copy\n",
    "import random\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "from tqdm import tqdm\n",
    "#from tqdm.notebook import tqdm\n",
    "from Bio.PDB.PDBParser import PDBParser\n",
    "from Bio.PDB.Polypeptide import one_to_index\n",
    "from Bio.PDB import Selection\n",
    "from Bio import SeqIO\n",
    "from Bio.PDB.Residue import Residue\n",
    "from easydict import EasyDict\n",
    "import enum\n",
    "import gzip\n",
    "from Bio import SeqIO\n",
    "from collections import OrderedDict\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import scipy.stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bc243e1c-7cda-4c2e-8bc4-9372e4a6baa4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "59260 out of 59260 data points found\n",
      "163019 out of 163192 data points found\n",
      "52048 out of 52048 data points found\n"
     ]
    }
   ],
   "source": [
    "fasta_dictionary = SeqIO.to_dict(SeqIO.parse('human_swissprot_oneliner.fasta', \"fasta\"))\n",
    "\n",
    "def make_new_dfs(data_name):\n",
    "    data_df_neg = pd.read_csv(f'{data_name}_neg_rr.txt', sep=' ', header=None)\n",
    "    data_df_pos = pd.read_csv(f'{data_name}_pos_rr.txt', sep=' ', header=None)\n",
    "    \n",
    "    data_df = pd.concat([data_df_neg, data_df_pos], ignore_index=True)\n",
    "    labels = [0]*len(data_df_neg) + [1]*len(data_df_pos)\n",
    "    \n",
    "    seq1 = []\n",
    "    seq2 = []\n",
    "    new_labels = []\n",
    "    for index, row in data_df.iterrows():\n",
    "        if row[0] in fasta_dictionary and row[1] in fasta_dictionary:\n",
    "            seq1.append(str(fasta_dictionary[row[0]].seq))\n",
    "            seq2.append(str(fasta_dictionary[row[1]].seq))\n",
    "            new_labels.append(labels[index])\n",
    "    seq_df = pd.DataFrame({'seq1': seq1, 'seq2': seq2, 'labels':new_labels})\n",
    "    seq_df.to_csv(f'{data_name}_seqs.csv', index=False)\n",
    "    print(f'{len(seq_df)} out of {len(data_df)} data points found')\n",
    "\n",
    "make_new_dfs('Intra0')\n",
    "make_new_dfs('Intra1')\n",
    "make_new_dfs('Intra2')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
