{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1dfc6462",
   "metadata": {},
   "source": [
    "### Input data:\n",
    "\n",
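    "1. Download the `.lmdb` dataset from this [link](https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/yeast_ppi.zip).\n",
    "2. Unzip it so that the three splits are available at `./dataset/yeast_ppi/yeast_ppi_train.lmdb`, `./dataset/yeast_ppi/yeast_ppi_valid.lmdb` and `./dataset/yeast_ppi/yeast_ppi_test.lmdb`; these are the paths the cells below read from."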
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160b3cdf-0d7f-44b8-a8f2-158690aedf93",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "import argparse, json\n",
    "import copy\n",
    "import random\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "from tqdm import tqdm\n",
    "#from tqdm.notebook import tqdm\n",
    "from Bio.PDB.PDBParser import PDBParser\n",
    "from Bio.PDB.Polypeptide import one_to_index\n",
    "from Bio.PDB import Selection\n",
    "from Bio import SeqIO\n",
    "from Bio.PDB.Residue import Residue\n",
    "from easydict import EasyDict\n",
    "import enum\n",
    "import gzip\n",
    "from Bio import SeqIO\n",
    "from collections import OrderedDict\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import scipy.stats\n",
    "from torch.utils import data as torch_data\n",
    "from collections import defaultdict \n",
    "import lmdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5deb177f-67bb-41e4-b8f3-ffdc758c584c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PPI(Dataset):\n",
    "    \"\"\"Protein-pair dataset read eagerly from one LMDB split.\n",
    "\n",
    "    Each sample is the pair of sequences stored under ``primary_1`` and\n",
    "    ``primary_2`` plus the value stored under ``target_field``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, path, name, target_field, split='train', verbose=1):\n",
    "        \"\"\"Load ``{path}/{name}/{name}_{split}.lmdb`` into memory.\n",
    "\n",
    "        Args:\n",
    "            path: Directory that contains the ``name`` sub-directory.\n",
    "            name: Dataset name, used as sub-directory and as file prefix.\n",
    "            target_field: Key of the label inside every stored record.\n",
    "            split: Split suffix used in the LMDB file name.\n",
    "            verbose: Forwarded to ``load_lmdbs`` (not used there).\n",
    "        \"\"\"\n",
    "        self.target_field = target_field\n",
    "        lmdb_file = os.path.join(path, f'{name}/{name}_{split}.lmdb/')\n",
    "        self.load_lmdbs([lmdb_file], sequence_field=[\"primary_1\", \"primary_2\"], target_field=self.target_field,\n",
    "                        verbose=verbose)\n",
    "\n",
    "    def load_lmdbs(self, lmdb_files, sequence_field=\"primary\", target_field=None, number_field=\"num_examples\",\n",
    "                   transform=None, lazy=False, verbose=0, **kwargs):\n",
    "        \"\"\"Read every record of the given LMDB files into Python lists.\n",
    "\n",
    "        Sets ``self.sequences`` (one list of sequence fields per record),\n",
    "        ``self.targets`` (one label per record) and ``self.num_samples``\n",
    "        (record count per LMDB file).\n",
    "\n",
    "        Args:\n",
    "            lmdb_files: Iterable of LMDB paths, read in order.\n",
    "            sequence_field: Record key, or list of record keys, to collect.\n",
    "            target_field: Record key holding the label.\n",
    "            number_field: Key of the pickled record count in each LMDB.\n",
    "            transform, lazy, verbose, **kwargs: Accepted but unused here.\n",
    "        \"\"\"\n",
    "        # A bare string would otherwise be iterated character by character.\n",
    "        if isinstance(sequence_field, str):\n",
    "            sequence_field = [sequence_field]\n",
    "\n",
    "        targets = []\n",
    "        sequences = []\n",
    "        num_samples = []\n",
    "        for lmdb_file in lmdb_files:\n",
    "            env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False)\n",
    "            try:\n",
    "                with env.begin(write=False) as txn:\n",
    "                    # SECURITY: pickle.loads can execute arbitrary code; only\n",
    "                    # open LMDB files that come from a trusted source.\n",
    "                    num_sample = pickle.loads(txn.get(number_field.encode()))\n",
    "                    for i in range(num_sample):\n",
    "                        # Records are keyed by their decimal index.\n",
    "                        item = pickle.loads(txn.get(str(i).encode()))\n",
    "                        sequences.append([item[field] for field in sequence_field])\n",
    "                        target_value = item[target_field]\n",
    "                        # Unwrap single-element arrays into a plain Python scalar.\n",
    "                        if isinstance(target_value, np.ndarray) and target_value.size == 1:\n",
    "                            target_value = target_value.item()\n",
    "                        targets.append(target_value)\n",
    "                    num_samples.append(num_sample)\n",
    "            finally:\n",
    "                # Release the environment even if a record fails to load.\n",
    "                env.close()\n",
    "\n",
    "        # Every declared record of every file must have produced a target.\n",
    "        assert sum(num_samples) == len(targets)\n",
    "        self.sequences = sequences\n",
    "        self.targets = targets\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of loaded samples across all LMDB files.\"\"\"\n",
    "        return len(self.targets)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(first_sequence, second_sequence, target)`` for ``index``.\"\"\"\n",
    "        return self.sequences[index][0], self.sequences[index][1], self.targets[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dc897af9-b9a6-4054-966b-919848dbfe4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the three yeast PPI splits; each one reads ./dataset/yeast_ppi/yeast_ppi_<split>.lmdb.\n",
    "train = PPI(path='./dataset/', name='yeast_ppi', target_field='interaction', split='train')\n",
    "valid = PPI(path='./dataset/', name='yeast_ppi', target_field='interaction', split='valid')\n",
    "test = PPI(path='./dataset/', name='yeast_ppi', target_field='interaction', split='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5011d449-69b8-41ab-a522-123918f567f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_csv(dataset, name):\n",
    "    \"\"\"Dump a paired-sequence dataset to ``./processed_data_{name}.csv``.\n",
    "\n",
    "    Items are read by integer index from ``0`` to ``len(dataset) - 1`` and\n",
    "    each one must unpack into ``(sequence_1, sequence_2, target)``. The frame\n",
    "    is written with ``DataFrame.to_csv`` defaults, so the row index ends up\n",
    "    as the first, unnamed column of the file.\n",
    "    \"\"\"\n",
    "    columns = {'sequence_1': [], 'sequence_2': [], 'target': []}\n",
    "    for position in range(len(dataset)):\n",
    "        first, second, label = dataset[position]\n",
    "        columns['sequence_1'].append(first)\n",
    "        columns['sequence_2'].append(second)\n",
    "        columns['target'].append(label)\n",
    "    frame = pd.DataFrame(columns)\n",
    "    frame.to_csv(f'./processed_data_{name}.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "514e444d-4737-4e20-bc68-1d5e725e9a13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one CSV per split into the current working directory.\n",
    "convert_to_csv(dataset=train, name='train')\n",
    "convert_to_csv(dataset=valid, name='validation')\n",
    "convert_to_csv(dataset=test, name='test')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
