{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d10b6e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import requests\n",
    "from io import StringIO\n",
    "from Bio.PDB import PDBParser, PDBIO, Select\n",
    "from tape.datasets import LMDBDataset\n",
    "from collections import Counter\n",
    "from functools import partial\n",
    "from Bio.PDB import PDBParser, DSSP\n",
    "from biotite.structure.io.pdb import PDBFile\n",
    "from pathlib import Path\n",
    "from foldingdiff.datasets import extract_pdb_code_and_chain\n",
    "os.chdir(Path.cwd().parents[0])\n",
    "\n",
    "class ChainSelect(Select):\n",
    "    \"\"\"\n",
    "    A custom selection class for PDBIO that only writes out the specified chain.\n",
    "    \"\"\"\n",
    "    def __init__(self, chain_id):\n",
    "        self.chain_id = chain_id\n",
    "\n",
    "    def accept_chain(self, chain_obj):\n",
    "        # Only accept the chain with id matching the desired chain.\n",
    "        if chain_obj.get_id() == self.chain_id:\n",
    "            return 1\n",
    "        else:\n",
    "            return 0\n",
    "\n",
    "def download_and_filter_pdb(pdb_code, chain, download_dir=\"pdb_files\"):\n",
    "    \"\"\"\n",
    "    Downloads the full PDB file for the given pdb_code, parses it with Biopython, and writes\n",
    "    out only the structure corresponding to the specified chain.\n",
    "    \"\"\"\n",
    "    os.makedirs(download_dir, exist_ok=True)\n",
    "    filename = os.path.join(download_dir, f\"{pdb_code}_{chain}.pdb\")\n",
    "    url = f\"https://files.rcsb.org/download/{pdb_code}.pdb\"\n",
    "    response = requests.get(url)\n",
    "    if response.status_code == 200:\n",
    "        pdb_text = response.text\n",
    "        # Parse the PDB content using a StringIO stream\n",
    "        parser = PDBParser(QUIET=True)\n",
    "        structure = parser.get_structure(pdb_code, StringIO(pdb_text))\n",
    "        \n",
    "        # Prepare the PDB writer with our custom ChainSelect\n",
    "        io = PDBIO()\n",
    "        io.set_structure(structure)        \n",
    "        io.save(filename, select=ChainSelect(chain))\n",
    "        # print(f\"Downloaded and filtered {pdb_code} chain {chain} successfully.\")\n",
    "    else:\n",
    "        print(f\"Error: Could not download {pdb_code}; status code {response.status_code}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b38e289",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_sample(sample, download_dir):\n",
    "    \"\"\"Process one dataset sample: extract id, get the PDB code and download the file.\"\"\"\n",
    "    dataset_id = sample['id']\n",
    "    # print(f\"Processing: {dataset_id}\")\n",
    "    pdb_code, chain = extract_pdb_code_and_chain(dataset_id)\n",
    "    download_and_filter_pdb(pdb_code, chain, download_dir)\n",
    "    return dataset_id  # You may return any result you need\n",
    "\n",
    "# Set an appropriate number of worker threads (adjust max_workers as needed)\n",
    "max_workers = 100\n",
    "\n",
    "def process_dataset(dataset_ids, download_dir):\n",
    "    # If LMDBDataset isn’t a list (and doesn’t have __len__), consider converting it to a list first \n",
    "    # so that tqdm knows the total number of items. For example:\n",
    "    dataset_ids = list(dataset_ids)\n",
    "    process_sample_partial = partial(process_sample, download_dir=download_dir)\n",
    "    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "        # Use executor.map to apply process_sample to each sample in the dataset\n",
    "        # executor.map returns an iterator that produces results in order\n",
    "        results = list(tqdm(executor.map(process_sample_partial, dataset_ids), total=len(dataset_ids)))\n",
    "\n",
    "    # Optionally, process or log the collected results\n",
    "    print(\"Finished processing samples.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884155b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = LMDBDataset(f'data/remote_homology/remote_homology_train.lmdb')\n",
    "counts = Counter((s['fold_label'] for s in train))\n",
    "keep = set([k for k in counts if counts[k]>50])\n",
    "print(len(keep))\n",
    "for suffix in ['train','valid','test_family_holdout','test_fold_holdout','test_superfamily_holdout']:\n",
    "    dataset_ids = LMDBDataset(f'data/remote_homology/remote_homology_{suffix}.lmdb')\n",
    "    dataset_ids = [s for s in dataset_ids if s['fold_label'] in keep]\n",
    "    process_dataset(dataset_ids, f'data/remote_homology/{suffix}_pdbs')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee1c6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "bad = []\n",
    "for suffix in ['train','valid','test_family_holdout','test_fold_holdout','test_superfamily_holdout']:\n",
    "    folder = f'data/remote_homology/{suffix}_pdbs'\n",
    "    for f in tqdm(os.listdir(folder)):\n",
    "        fname = os.path.join(folder, f)\n",
    "        if '1JBA_A' not in f:\n",
    "            continue\n",
    "        with open(str(fname), \"rt\") as f:\n",
    "            source = PDBFile.read(f)\n",
    "        source_struct = source.get_structure(model=1)        \n",
    "        parser = PDBParser(QUIET=True)\n",
    "        structure = parser.get_structure(Path(fname).stem, fname)        \n",
    "        model = structure[0]  # assuming you want the first model\n",
    "        print(model)\n",
    "        # try:\n",
    "        #     dssp = DSSP(model, fname)        \n",
    "        # except:\n",
    "        #     print(fname)\n",
    "        #     bad.append(fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "966e3a04",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf45312",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "from esm.models.esm3 import ESM3\n",
    "from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig\n",
    "\n",
    "# Will instruct you how to get an API key from huggingface hub, make one with \"Read\" permission.\n",
    "# login()\n",
    "\n",
    "# This will download the model weights and instantiate the model on your machine.\n",
    "model: ESM3InferenceClient = ESM3.from_pretrained(\"esm3-open\").to(\"cuda\") # or \"cpu\"\n",
    "\n",
    "# Generate a completion for a partial Carbonic Anhydrase (2vvb)\n",
    "prompt = \"___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________\"\n",
    "protein = ESMProtein(sequence=prompt)\n",
    "# Generate the sequence, then the structure. This will iteratively unmask the sequence track.\n",
    "protein = model.generate(protein, GenerationConfig(track=\"sequence\", num_steps=8, temperature=0.7))\n",
    "# We can show the predicted structure for the generated sequence.\n",
    "protein = model.generate(protein, GenerationConfig(track=\"structure\", num_steps=8))\n",
    "protein.to_pdb(\"./generation.pdb\")\n",
    "# Then we can do a round trip design by inverse folding the sequence and recomputing the structure\n",
    "protein.sequence = None\n",
    "protein = model.generate(protein, GenerationConfig(track=\"sequence\", num_steps=8))\n",
    "protein.coordinates = None\n",
    "protein = model.generate(protein, GenerationConfig(track=\"structure\", num_steps=8))\n",
    "protein.to_pdb(\"./round_tripped.pdb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce85ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from esm.models.esmc import ESMC\n",
    "from esm.sdk.api import ESMProtein, LogitsConfig\n",
    "\n",
    "protein = ESMProtein(sequence=\"AAAAA\")\n",
    "client = ESMC.from_pretrained(\"esmc_300m\").to(\"cuda\") # or \"cpu\"\n",
    "protein_tensor = client.encode(protein)\n",
    "logits_output = client.logits(\n",
    "   protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)\n",
    ")\n",
    "print(logits_output.logits, logits_output.embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67abc820",
   "metadata": {},
   "outputs": [],
   "source": [
    "from foldingdiff.angles_and_coords import *\n",
    "\n",
    "canonical_distances_and_dihedrals('data/remote_homology/test_superfamily_holdout_pdbs/1UX8_A.pdb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f081f10a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from foldingdiff.bpe import MyDataset\n",
    "train, valid, test_datasets = pickle.load(open('homo_datasets_with_test.pkl', 'rb'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6efa60cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_datasets = list(zip(('test_family', 'test_fold', 'test_superfamily'), test_datasets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "378c78f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump((train, valid,test_datasets), open('homo_datasets_with_test.pkl', 'wb+'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfde6748",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle \n",
    "\n",
    "for f in ['1749673452.5647202', '1749673452.650502', '1749673453.07734', '1749696458.6950917', '1749696458.6925066', '1749696458.6264317']:\n",
    "    i = 0\n",
    "    while True:\n",
    "        path = f'ckpts/{f}/feats_100_{i}.pkl'\n",
    "        if not os.path.exists(path):\n",
    "            break\n",
    "        feats = pickle.load(open(path, 'rb'))\n",
    "        for prot_id, feat in feats.items():\n",
    "            feat['fp'] = {key: val.cpu().numpy() for (key, val) in feat.items()}\n",
    "            if 'foldseek' in feat:\n",
    "                feat['foldseek'] = {key: val.cpu().numpy() for (key, val) in feat.items()}\n",
    "            feats[prot_id] = feat\n",
    "        pickle.dump(feats, open(path, \"wb+\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97dbe7cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a csv\n",
    "import pickle\n",
    "import pandas as pd\n",
    "# for name in [\"conserved-site-prediction\",\"CatBio\",\"BindBio\",\"CatInt\",\"repeat-motif-prediction\"]:\n",
    "name = \"remote-homology-detection\"\n",
    "path = f\"data/struct_token_bench/processed_pickles/{name}.pkl\"\n",
    "train, val, test = pickle.load(open(path, 'rb'))\n",
    "if \"homo\" in name:\n",
    "    keys = [\"pdb_id\", \"chain_id\", \"fold_label\"]\n",
    "else:\n",
    "    keys = [\"pdb_id\", \"chain_id\", \"label_type\", \"residue_label\"]\n",
    "\n",
    "\n",
    "def fill_keys(item):\n",
    "    assert \"id\" in item\n",
    "    pdb_id, chain_id = extract_pdb_code_and_chain(item[\"id\"])\n",
    "    item[\"pdb_id\"] = pdb_id\n",
    "    item[\"chain_id\"] = chain_id    \n",
    "\n",
    "rows = []\n",
    "for i, item in enumerate(train):\n",
    "    if \"homo\" in name:\n",
    "        fill_keys(item)\n",
    "    rows.append({\"split\": \"train\",\n",
    "                \"idx\": i} | {key: item[key] for key in keys})\n",
    "for i, item in enumerate(val):\n",
    "    if \"homo\" in name:\n",
    "        fill_keys(item)\n",
    "    rows.append({\"split\": \"valid\",\n",
    "                \"idx\": i} | {key: item[key] for key in keys})    \n",
    "for k, test_k in test:\n",
    "    for i, item in enumerate(test_k):\n",
    "        if \"homo\" in name:\n",
    "            fill_keys(item)\n",
    "        rows.append({\"split\": k,\n",
    "                    \"idx\": i} | {key: item[key] for key in keys})                    \n",
    "df = pd.DataFrame(rows)\n",
    "df.to_csv(f\"data/struct_token_bench/processed_csvs/{name}.csv\", index=False)\n",
    "print(f\"saved {name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8fde623",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
