{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "160b3cdf-0d7f-44b8-a8f2-158690aedf93",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,re\n",
    "import argparse, json\n",
    "import copy\n",
    "import random\n",
    "import pickle\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset\n",
    "from tqdm import tqdm\n",
    "from Bio.PDB.PDBParser import PDBParser\n",
    "from Bio.PDB.Polypeptide import one_to_index\n",
    "from Bio.PDB import Selection\n",
    "from Bio import SeqIO\n",
    "from Bio.PDB.Residue import Residue\n",
    "from easydict import EasyDict\n",
    "import enum\n",
    "from collections import OrderedDict\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import scipy.stats\n",
    "from torch.utils import data as torch_data\n",
    "from collections import defaultdict \n",
    "import lmdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2205dcd4-57c7-42d9-a9b8-1e24be7b3bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def download(url, path, save_file=None, md5=None):\n",
    "\n",
    "    from six.moves.urllib.request import urlretrieve\n",
    "\n",
    "    if save_file is None:\n",
    "        save_file = os.path.basename(url)\n",
    "        if \"?\" in save_file:\n",
    "            save_file = save_file[:save_file.find(\"?\")]\n",
    "    save_file = os.path.join(path, save_file)\n",
    "\n",
    "    if not os.path.exists(save_file) or compute_md5(save_file) != md5:\n",
    "        urlretrieve(url, save_file)\n",
    "    return save_file\n",
    "\n",
    "def compute_md5(file_name, chunk_size=65536):\n",
    "    import hashlib\n",
    "\n",
    "    md5 = hashlib.md5()\n",
    "    with open(file_name, \"rb\") as fin:\n",
    "        chunk = fin.read(chunk_size)\n",
    "        while chunk:\n",
    "            md5.update(chunk)\n",
    "            chunk = fin.read(chunk_size)\n",
    "    return md5.hexdigest()\n",
    "\n",
    "def extract(zip_file, member=None):\n",
    "\n",
    "    import gzip\n",
    "    import shutil\n",
    "    import zipfile\n",
    "    import tarfile\n",
    "\n",
    "    zip_name, extension = os.path.splitext(zip_file)\n",
    "    if zip_name.endswith(\".tar\"):\n",
    "        extension = \".tar\" + extension\n",
    "        zip_name = zip_name[:-4]\n",
    "    save_path = os.path.dirname(zip_file)\n",
    "\n",
    "    if extension == \".gz\":\n",
    "        member = os.path.basename(zip_name)\n",
    "        members = [member]\n",
    "        save_files = [os.path.join(save_path, member)]\n",
    "        for _member, save_file in zip(members, save_files):\n",
    "            with open(zip_file, \"rb\") as fin:\n",
    "                fin.seek(-4, 2)\n",
    "                file_size = struct.unpack(\"<I\", fin.read())[0]\n",
    "            with gzip.open(zip_file, \"rb\") as fin:\n",
    "                if not os.path.exists(save_file) or file_size != os.path.getsize(save_file):\n",
    "                    logger.info(\"Extracting %s to %s\" % (zip_file, save_file))\n",
    "                    with open(save_file, \"wb\") as fout:\n",
    "                        shutil.copyfileobj(fin, fout)\n",
    "    elif extension in [\".tar.gz\", \".tgz\", \".tar\"]:\n",
    "        tar = tarfile.open(zip_file, \"r\")\n",
    "        if member is not None:\n",
    "            members = [member]\n",
    "            save_files = [os.path.join(save_path, os.path.basename(member))]\n",
    "        else:\n",
    "            members = tar.getnames()\n",
    "            save_files = [os.path.join(save_path, _member) for _member in members]\n",
    "        for _member, save_file in zip(members, save_files):\n",
    "            if tar.getmember(_member).isdir():\n",
    "                os.makedirs(save_file, exist_ok=True)\n",
    "                continue\n",
    "            os.makedirs(os.path.dirname(save_file), exist_ok=True)\n",
    "            if not os.path.exists(save_file) or tar.getmember(_member).size != os.path.getsize(save_file):\n",
    "                with tar.extractfile(_member) as fin, open(save_file, \"wb\") as fout:\n",
    "                    shutil.copyfileobj(fin, fout)\n",
    "    elif extension == \".zip\":\n",
    "        zipped = zipfile.ZipFile(zip_file)\n",
    "        if member is not None:\n",
    "            members = [member]\n",
    "            save_files = [os.path.join(save_path, os.path.basename(member))]\n",
    "        else:\n",
    "            members = zipped.namelist()\n",
    "            save_files = [os.path.join(save_path, _member) for _member in members]\n",
    "        for _member, save_file in zip(members, save_files):\n",
    "            if zipped.getinfo(_member).is_dir():\n",
    "                os.makedirs(save_file, exist_ok=True)\n",
    "                continue\n",
    "            os.makedirs(os.path.dirname(save_file), exist_ok=True)\n",
    "            if not os.path.exists(save_file) or zipped.getinfo(_member).file_size != os.path.getsize(save_file):\n",
    "                with zipped.open(_member, \"r\") as fin, open(save_file, \"wb\") as fout:\n",
    "                    shutil.copyfileobj(fin, fout)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown file extension `%s`\" % extension)\n",
    "\n",
    "    if len(save_files) == 1:\n",
    "        return save_files[0]\n",
    "    else:\n",
    "        return save_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5deb177f-67bb-41e4-b8f3-ffdc758c584c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class HumanPPI(Dataset):\n",
    "    \n",
    "    url = \"https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/ppidata/human_ppi.zip\"\n",
    "    md5 = \"89885545ebc2c11d774c342910230e20\"\n",
    "    \n",
    "    splits = [\"train\", \"valid\", \"test\", \"cross_species_test\"]\n",
    "    target_fields = [\"interaction\"]\n",
    "\n",
    "    def __init__(self, path, split='train', verbose=1):\n",
    "        \n",
    "        lmdb_file = os.path.join(path, f'HumanPPI/normal/{split}/')\n",
    "        self.load_lmdb(lmdb_file, sequence_field=[\"primary_1\", \"primary_2\"], target_fields=self.target_fields,\n",
    "                        verbose=verbose)\n",
    "\n",
    "    def load_lmdb(self, lmdb_file, sequence_field=\"primary\", target_fields=None, number_field=\"num_examples\",\n",
    "                   transform=None, lazy=False, verbose=0, **kwargs):\n",
    "  \n",
    "        \n",
    "        target_fields = set(target_fields)\n",
    "    \n",
    "        sequences = []\n",
    "        num_samples = 0\n",
    "        targets = defaultdict(list)\n",
    "        \n",
    "        self.env = lmdb.open(lmdb_file, lock=False, map_size=10995116277760)\n",
    "        self.operator = self.env.begin()\n",
    "\n",
    "    def _get(self, key: str or int):\n",
    "        value = self.operator.get(str(key).encode())\n",
    "        if value is not None:\n",
    "            value = value.decode()\n",
    "        return value\n",
    "\n",
    "    def __len__(self):\n",
    "        return int(self._get(\"length\"))\n",
    "    \n",
    "    def __getitem__(self, index):   \n",
    "        entry = json.loads(self._get(index))\n",
    "        seq_1, seq_2 = entry['seq_1'], entry['seq_2']\n",
    "        return seq_1, seq_2, int(entry[\"label\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9d1a0822-a126-4f69-b035-9d823378defc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the train, valid and test splits found under ./dataset/.\n",
    "train = HumanPPI('./dataset/', split='train')\n",
    "val = HumanPPI('./dataset/', split='valid')\n",
    "test = HumanPPI('./dataset/', split='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7adbeedf-15df-43a1-8834-ae5dd662e175",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_csv(dataset, name):\n",
    "    seq1 = []\n",
    "    seq2 = []\n",
    "    targets = []\n",
    "    for i in range(len(dataset)):\n",
    "        s1, s2, t = dataset[i]\n",
    "        seq1.append(s1)\n",
    "        seq2.append(s2)\n",
    "        targets.append(t)\n",
    "    df = pd.DataFrame({'sequence_1': seq1, 'sequence_2': seq2, 'target': targets})\n",
    "    df.to_csv(f'./processed_data_{name}.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6e6a467a-d116-4d72-a9c5-d45cb137680f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one CSV per split into the current working directory.\n",
    "convert_to_csv(dataset=train, name='train')\n",
    "convert_to_csv(dataset=val, name='validation')\n",
    "convert_to_csv(dataset=test, name='test')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
