{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ProtMamba_ssm.core import *\n",
    "from ProtMamba_ssm.dataloaders import *\n",
    "from ProtMamba_ssm.utils import *\n",
    "from ProtMamba_ssm.modules import *\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Bio import Align\n",
    "\n",
    "def delete_masks(seq):\n",
    "    \"\"\"Return `seq` with every `<mask-1>` ... `<mask-5>` and `<cls>` token removed.\"\"\"\n",
    "    special_tokens = (\"<mask-1>\", \"<mask-2>\", \"<mask-3>\", \"<mask-4>\", \"<mask-5>\", \"<cls>\")\n",
    "    cleaned = seq\n",
    "    for token in special_tokens:\n",
    "        cleaned = cleaned.replace(token, \"\")\n",
    "    return cleaned\n",
    "\n",
    "# Global pairwise aligner shared by `align_sequences`: +1 per match, -1 per mismatch,\n",
    "# -1 for opening a gap and -1 for extending it\n",
    "aligner = Align.PairwiseAligner()\n",
    "for _attr, _value in ((\"mode\", 'global'), (\"match_score\", 1), (\"mismatch_score\", -1),\n",
    "                      (\"open_gap_score\", -1), (\"extend_gap_score\", -1)):\n",
    "    setattr(aligner, _attr, _value)\n",
    "\n",
    "def align_sequences(ref_seq, query_seq, print_alignments=False):\n",
    "    \"\"\"Align `query_seq` to `ref_seq` with the module-level `aligner` and compare the aligned rows.\n",
    "\n",
    "    Only the first alignment returned by the aligner is used.\n",
    "\n",
    "    Returns:\n",
    "        (distance, aligned_ref, aligned_query), where distance is the fraction of positions\n",
    "        at which the two aligned rows differ.\n",
    "    \"\"\"\n",
    "    def hamming_str(s1, s2):\n",
    "        # Fraction of differing positions between two equal-length strings\n",
    "        assert len(s1) == len(s2)\n",
    "        mismatches = np.array(list(s1)) != np.array(list(s2))\n",
    "        return sum(mismatches)/len(s1)\n",
    "    best = aligner.align(ref_seq, query_seq)[0]\n",
    "    if print_alignments:\n",
    "        print(\"Score = %.1f:\" % best.score)\n",
    "        print(best)\n",
    "    aligned_ref, aligned_query = best[0], best[1]\n",
    "    return hamming_str(aligned_ref, aligned_query), aligned_ref, aligned_query\n",
    "\n",
    "# Quick check of `align_sequences`: two different sequences, then a sequence against itself\n",
    "seq1 = \"ACDEFGHIKLMNPQRST\"\n",
    "seq2 = \"ACDEEGHKLMNQRSTVWY\"\n",
    "(align_sequences(seq1, seq2), align_sequences(seq1, seq1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "from Bio import SeqIO\n",
    "import pyhmmer\n",
    "\n",
    "alphabet = pyhmmer.easel.Alphabet.amino()\n",
    "\n",
    "# Translation table that deletes lowercase letters and the characters \".\" and \"*\" from a string\n",
    "deletekeys = {char: None for char in string.ascii_lowercase + \".*\"}\n",
    "translation = str.maketrans(deletekeys)\n",
    "\n",
    "def remove_insertions(sequence: str) -> str:\n",
    "    \"\"\" Deletes lowercase letters, \".\" and \"*\" from `sequence` using the module-level `translation` table. \"\"\"\n",
    "    stripped = sequence.translate(translation)\n",
    "    return stripped\n",
    "\n",
    "def read_msa(filename: str):\n",
    "    \"\"\" Parses a FASTA-formatted MSA file into a list of (description, sequence) tuples, insertions removed.\"\"\"\n",
    "    entries = []\n",
    "    for record in SeqIO.parse(filename, \"fasta\"):\n",
    "        entries.append((record.description, remove_insertions(str(record.seq))))\n",
    "    return entries\n",
    "\n",
    "def read_msa_unaligned(filename: str):\n",
    "    \"\"\" Parses a FASTA-formatted MSA file into (description, sequence) tuples; \".\", \"-\" and \"*\" are removed and the rest is uppercased.\"\"\"\n",
    "    entries = []\n",
    "    for record in SeqIO.parse(filename, \"fasta\"):\n",
    "        raw = str(record.seq)\n",
    "        for symbol in (\".\", \"-\", \"*\"):\n",
    "            raw = raw.replace(symbol, \"\")\n",
    "        entries.append((record.description, raw.upper()))\n",
    "    return entries\n",
    "\n",
    "def check_msa(msa):\n",
    "    \"\"\" Asserts that no sequence (second element of each entry) occurs more than once in the MSA\"\"\"\n",
    "    unique_seqs = {entry[1] for entry in msa}\n",
    "    assert len(unique_seqs) == len(msa), \"There are repeated sequences in the MSA\"\n",
    "    \n",
    "def make_hmm_from_a3m_msa(msa_filepath, hmm_filename=None):\n",
    "    \"\"\"Fit a profile HMM on the MSA stored at `msa_filepath`.\n",
    "\n",
    "    Sequences are loaded with `read_msa` and named by their position in the file.\n",
    "    If `hmm_filename` is given, the HMM is also written to `{hmm_filename}.hmm`.\n",
    "    Returns the fitted HMM.\n",
    "    \"\"\"\n",
    "    records = read_msa(msa_filepath)\n",
    "    # The duplicate check `check_msa(records)` is currently disabled\n",
    "    text_seqs = []\n",
    "    for idx, (_, seq) in enumerate(records):\n",
    "        text_seqs.append(pyhmmer.easel.TextSequence(name=str(idx).encode(\"utf-8\"), sequence=seq))\n",
    "    # Create digitized MSA block\n",
    "    digital_msa = pyhmmer.easel.TextMSA(name=b\"msa\", sequences=text_seqs).digitize(alphabet)\n",
    "    # Fit HMM\n",
    "    builder = pyhmmer.plan7.Builder(alphabet)\n",
    "    background = pyhmmer.plan7.Background(alphabet)\n",
    "    hmm, _, _ = builder.build_msa(digital_msa, background)\n",
    "    if hmm_filename is None:\n",
    "        return hmm\n",
    "    with open(f\"{hmm_filename}.hmm\", \"wb\") as output_file:\n",
    "        hmm.write(output_file)\n",
    "    return hmm\n",
    "\n",
    "def align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=None, sequences_list=None):\n",
    "    \"\"\"Score sequences against a profile HMM with the pyhmmer search pipeline.\n",
    "\n",
    "    Args:\n",
    "        hmm: HMM to search with (e.g. the output of `make_hmm_from_a3m_msa`).\n",
    "        sequences_path: FASTA/a3m file read with `read_msa_unaligned`; used only when\n",
    "            `sequences_list` is None.\n",
    "        sequences_list: sequence strings to score; takes precedence over `sequences_path`.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each sequence that produced a hit to {\"score\": ..., \"evalue\": ...}.\n",
    "        Sequences without a hit are absent; identical sequences share one entry.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if neither `sequences_list` nor `sequences_path` is given.\n",
    "    \"\"\"\n",
    "    if sequences_list is not None:\n",
    "        seqs = list(sequences_list)\n",
    "    elif sequences_path is not None:\n",
    "        # Load sequences from a3m, keeping only the sequence strings\n",
    "        seqs = [seq for _, seq in read_msa_unaligned(sequences_path)]\n",
    "    else:\n",
    "        raise NotImplementedError(\"Missing sequences to align/score\")\n",
    "    # Name every sequence by its position so that hits can be mapped back to `seqs`\n",
    "    all_seqs = [pyhmmer.easel.TextSequence(name=str(i).encode(\"utf-8\"), sequence=seq) for i, seq in enumerate(seqs)]\n",
    "    # Create digitized Sequence block\n",
    "    seq_block = pyhmmer.easel.TextSequenceBlock(all_seqs).digitize(alphabet)\n",
    "    # Get all hits from the hmm\n",
    "    # NOTE(review): bias_filter=False with F1=F2=F3=1.0 presumably disables the pre-filters so that\n",
    "    # every sequence is scored - confirm against the pyhmmer Pipeline documentation\n",
    "    background = pyhmmer.plan7.Background(alphabet)\n",
    "    pipeline = pyhmmer.plan7.Pipeline(alphabet, background=background, bias_filter=False, F1=1.0, F2=1.0, F3=1.0)\n",
    "    hits = pipeline.search_hmm(hmm, seq_block)\n",
    "    if len(hits) != len(seqs):\n",
    "        print(f\"Number of hits: {len(hits)} is different from the number of sequences in the MSA: {len(seqs)}\")\n",
    "    # Extract hits. Always index `seqs`: the previous code read `msa[i][1]` whenever\n",
    "    # `sequences_path` was set, which returned a single character if both arguments were passed\n",
    "    all_hits = {}\n",
    "    for hit in hits:\n",
    "        seq = seqs[int(hit.name.decode(\"utf-8\"))]\n",
    "        all_hits[seq] = {\"score\": hit.score, \"evalue\": hit.evalue}\n",
    "    return all_hits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, EsmForProteinFolding\n",
    "from Bio.PDB import *\n",
    "import tmscoring\n",
    "\n",
    "pdb_parser = PDBParser()\n",
    "\n",
    "def align_structures_CEalign(path_ref, path_query, key=None):\n",
    "    \"\"\"Align the query structure to the reference with `CEAligner` and save the result.\n",
    "\n",
    "    Args:\n",
    "        path_ref: PDB file of the reference structure.\n",
    "        path_query: PDB file of the structure to align. The output is written next to it\n",
    "            as `<root>_aligned[-key]<ext>`.\n",
    "        key: optional tag appended to the output file name; None or \"\" adds nothing.\n",
    "\n",
    "    Returns:\n",
    "        the `rms` attribute of the aligner after the alignment.\n",
    "    \"\"\"\n",
    "    # A missing key adds no suffix (concatenating the default None used to raise TypeError)\n",
    "    suffix = f\"-{key}\" if key else \"\"\n",
    "    ref_structure = pdb_parser.get_structure(\"reference\", path_ref)\n",
    "    query_structure = pdb_parser.get_structure(\"query\", path_query)\n",
    "\n",
    "    # Only the first model of each structure is used\n",
    "    ce_aligner = cealign.CEAligner()\n",
    "    ce_aligner.set_reference(ref_structure[0])\n",
    "    ce_aligner.align(query_structure[0])\n",
    "    rmsd = ce_aligner.rms\n",
    "    # Save new aligned structure\n",
    "    io = PDBIO()\n",
    "    io.set_structure(query_structure)\n",
    "    # splitext only splits the final extension, so dots elsewhere in the path are preserved\n",
    "    root, ext = os.path.splitext(str(path_query))\n",
    "    io.save(f\"{root}_aligned{suffix}{ext}\")\n",
    "    return rmsd\n",
    "\n",
    "def align_structures_TMscore(path_ref, path_query, key=None):\n",
    "    \"\"\"Align two structures with `tmscoring.TMscoring` and write the aligned result.\n",
    "\n",
    "    Args:\n",
    "        path_ref: PDB file of the reference structure.\n",
    "        path_query: PDB file of the structure to align. The output is written next to it\n",
    "            as `<root>_aligned[-key]<ext>`.\n",
    "        key: optional tag appended to the output file name; None or \"\" adds nothing.\n",
    "\n",
    "    Returns:\n",
    "        (tmscore, rmsd) evaluated at the optimised alignment parameters.\n",
    "    \"\"\"\n",
    "    # A missing key adds no suffix (concatenating the default None used to raise TypeError)\n",
    "    suffix = f\"-{key}\" if key else \"\"\n",
    "    alignment = tmscoring.TMscoring(path_ref, path_query)\n",
    "    # Find the optimal alignment\n",
    "    alignment.optimise()\n",
    "    # Get the TM score:\n",
    "    tmscore = alignment.tmscore(**alignment.get_current_values())\n",
    "    # RMSD of the protein aligned according to TM score\n",
    "    rmsd = alignment.rmsd(**alignment.get_current_values())\n",
    "    # splitext only splits the final extension, so dots elsewhere in the path are preserved\n",
    "    root, ext = os.path.splitext(str(path_query))\n",
    "    alignment.write(outputfile=f\"{root}_aligned{suffix}{ext}\", appended=True)\n",
    "    return tmscore, rmsd\n",
    "\n",
    "def compute_structure(seq, model, struct_path, ref_struct_path, alignment_func=align_structures_TMscore):\n",
    "    \"\"\"Fold `seq` with `model`, write the PDB to `struct_path` and align it to `ref_struct_path`.\n",
    "\n",
    "    Sequences longer than 750 characters, or containing a special token, \".\" or \"-\", are skipped:\n",
    "    a message is printed and every returned value is 0.\n",
    "\n",
    "    `alignment_func` is called as `alignment_func(ref_struct_path, struct_path, key=\"\")` and its\n",
    "    result is unpacked as (tmscore, rmsd).\n",
    "    NOTE(review): `align_structures_CEalign` returns a single value, so it does not fit this contract.\n",
    "\n",
    "    Returns:\n",
    "        (ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore)\n",
    "    \"\"\"\n",
    "    def has_forbidden_token(candidate):\n",
    "        forbidden = list(MASK_TO_ID.keys())+[\"<eos>\", \"<pad>\", \"<unk>\", \"<mask>\", \"<cls>\", \"<null_1>\", \".\" , \"-\"]\n",
    "        return any(token in candidate for token in forbidden)\n",
    "    foldable = len(seq) <= 750 and not has_forbidden_token(seq)\n",
    "    if not foldable:\n",
    "        # NOTE(review): this message is also printed when the sequence is rejected for its tokens\n",
    "        print(f\"Sequence {struct_path} is too long\")\n",
    "        return 0, 0, 0, 0, 0, 0\n",
    "    with torch.no_grad():\n",
    "        output = model.infer([seq])\n",
    "    pdb = model.output_to_pdb(output)\n",
    "    ptm = output[\"ptm\"].item()\n",
    "    pae = output[\"predicted_aligned_error\"].cpu().numpy()\n",
    "    # pLDDT averaged over the atoms flagged in `atom37_atom_exists`: once over dims (1, 2), once over dim 2 only\n",
    "    atom_exists = output[\"atom37_atom_exists\"]\n",
    "    mean_plddt = ((output[\"plddt\"] * atom_exists).sum(dim=(1, 2)) / atom_exists.sum(dim=(1, 2))).item()\n",
    "    pos_plddt = ((output[\"plddt\"] * atom_exists).sum(dim=(2,)) / atom_exists.sum(dim=(2,))).cpu().numpy()\n",
    "    with open(struct_path, \"w\") as f:\n",
    "        f.write(pdb[0])\n",
    "    tmscore, rmsd = alignment_func(ref_struct_path, struct_path, key=\"\")\n",
    "    return ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import generated sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name_data = \"check-131k(13-18)_gen_seqs_full\"\n",
    "# NOTE: pickle can execute arbitrary code while loading; only open files produced by this project\n",
    "with open(f\"figures/generated_sequences/{name_data}.pkl\", \"rb\") as f:\n",
    "    gen_seqs = pickle.load(f)\n",
    "fim_generation = input(\"Is it a FIM generated dataset? (y/n): \") == \"y\"\n",
    "\n",
    "is_fim = True\n",
    "dataset_name = \"encoded_MSAs_test.pkl\"\n",
    "fim_strategy = \"multiple_span\"\n",
    "num_natural = 100\n",
    "# Load the dataset used for training\n",
    "dataset_kwargs = dict(filepath=\"/data1/common/OpenProteinSet/\",\n",
    "                      sample=False,\n",
    "                      max_msa_len=-1,\n",
    "                      max_patches=5,\n",
    "                      mask_fraction=0.2,\n",
    "                      fim_strategy=fim_strategy,\n",
    "                      max_position_embeddings=2048,\n",
    "                      add_position_ids=\"1d\")\n",
    "dataset = Uniclust30_Dataset(dataset_name, **dataset_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pairwise align generated sequences with natural sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if fim_generation:\n",
    "    # FIM mode: count the positions at which the generated infill differs from the original one\n",
    "    for j in tqdm(gen_seqs.keys()):\n",
    "        for key, dict_seqs in gen_seqs[j].items():\n",
    "            for seq in dict_seqs.keys():\n",
    "                entry = gen_seqs[j][key][seq]\n",
    "                new = delete_masks(entry[\"generated_input_fim\"])\n",
    "                orig = delete_masks(entry[\"original_input_fim\"])\n",
    "                n_modified = sum(np.array(list(new)) != np.array(list(orig)))\n",
    "                entry[\"fim_distance/fim_size\"] = (n_modified, len(new))\n",
    "else:\n",
    "    all_hamming_ctx = {}\n",
    "    for j in tqdm(gen_seqs.keys()):\n",
    "        # Select a sample of the dataset to be the input\n",
    "        data = dataset[j]\n",
    "        family_id = dataset.cluster_names[j]\n",
    "        tokens = data[\"input_ids\"][None,:].to(\"cuda\")\n",
    "        pos_ids = data[\"position_ids\"][None,:].to(\"cuda\")\n",
    "        # Decode the context and split it into its individual sequences\n",
    "        all_context = decode_sequence(tokens[0].cpu().numpy())\n",
    "        list_sequences_msa = [reorder_masked_sequence(elem+\"<cls>\") for elem in all_context.split(\"<cls>\")[1:-1]]\n",
    "        # Baseline: distances between a random sample of natural sequences and the rest of the context\n",
    "        # (np.random.choice draws with replacement by default, so fewer than num_natural may be kept)\n",
    "        rd_idxs = np.random.choice(len(list_sequences_msa), num_natural)\n",
    "        sampled_natural = [el for i, el in enumerate(list_sequences_msa) if i in rd_idxs]\n",
    "        all_hamming_ctx[family_id] = [\n",
    "            [align_sequences(ctx_seq, tmp_seq , print_alignments=False)[0]\n",
    "             for ctx_seq in list_sequences_msa if ctx_seq != tmp_seq]\n",
    "            for tmp_seq in sampled_natural\n",
    "        ]\n",
    "\n",
    "        # Distances between every generated sequence and all natural sequences of the context\n",
    "        for key, dict_seqs in gen_seqs[j].items():\n",
    "            for seq in dict_seqs.keys():\n",
    "                distances = [align_sequences(ctx_seq, reorder_masked_sequence(seq), print_alignments=False)[0]\n",
    "                             for ctx_seq in list_sequences_msa]\n",
    "                gen_seqs[j][key][seq][\"hamming\"] = np.array(distances)\n",
    "    with open(f\"figures/generated_sequences/all_hamming_ctx_{name_data}.pkl\", \"wb\") as f:\n",
    "        pickle.dump(all_hamming_ctx, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make dataframe with all sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per generated sequence: sampling parameters, perplexity and mode-specific fields\n",
    "all_gen_seqs = {}\n",
    "for j in tqdm(gen_seqs.keys()):\n",
    "    for params, dict_seqs in gen_seqs[j].items():\n",
    "        n_seqs_ctx , temperature, top_k, top_p = params\n",
    "        for seq, values in dict_seqs.items():\n",
    "            row = {\"family\": j, \"family_id\": dataset.cluster_names[j], \"perplexity\": values[\"perplexity\"],\n",
    "                   \"n_seqs_ctx\": n_seqs_ctx, \"temperature\": temperature, \"top_k\": top_k, \"top_p\": top_p}\n",
    "            if fim_generation:\n",
    "                row[\"original_input\"] = values[\"original_input\"]\n",
    "                row[\"original_input_fim\"] = values[\"original_input_fim\"]\n",
    "                row[\"generated_input_fim\"] = values[\"generated_input_fim\"]\n",
    "                row[\"fim_distance\"], row[\"fim_size\"] = values[\"fim_distance/fim_size\"]\n",
    "                row[\"original_sequence\"] = reorder_masked_sequence(values[\"original_input\"] + values[\"original_input_fim\"])\n",
    "                row[\"generated_sequence\"] = reorder_masked_sequence(values[\"original_input\"] + values[\"generated_input_fim\"])\n",
    "                assert len(row[\"original_sequence\"]) == len(row[\"generated_sequence\"])\n",
    "            else:\n",
    "                row[\"hamming\"] = values[\"hamming\"]\n",
    "                row[\"min_hamming\"] = np.min(values[\"hamming\"])\n",
    "                row[\"generated_sequence\"] = reorder_masked_sequence(seq)\n",
    "                row[\"sequence_length\"] = len(row[\"generated_sequence\"])\n",
    "            all_gen_seqs[seq] = row\n",
    "\n",
    "df = pd.DataFrame.from_dict(all_gen_seqs, orient=\"index\").reset_index(drop=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_pickle(f\"figures/generated_sequences/dataframe_{name_data}.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HMMER scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "families = df[\"family_id\"].unique()\n",
    "all_scores_ctx = {}\n",
    "for family_id in tqdm(families):\n",
    "    msa_filepath = f\"figures/pdb_structures/msas/{family_id}.a3m\"\n",
    "    try:\n",
    "        hmm = make_hmm_from_a3m_msa(msa_filepath)\n",
    "    except FileNotFoundError as err:\n",
    "        # Only a missing file is reported as a missing MSA; any other failure (including\n",
    "        # KeyboardInterrupt) propagates unchanged instead of being relabelled by a bare except\n",
    "        raise FileNotFoundError(f\"Missing MSA of family {family_id}\") from err\n",
    "    # find all df entries with the same family and align them\n",
    "    family_df = df[df[\"family_id\"] == family_id]\n",
    "    sequences = family_df[\"generated_sequence\"].values\n",
    "    scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)\n",
    "    # save the scores in the main df; sequences without a hit get score 0 and evalue 1\n",
    "    for seq in sequences:\n",
    "        hit = scores.get(seq)\n",
    "        is_seq = df[\"generated_sequence\"] == seq\n",
    "        df.loc[is_seq, \"score_gen\"] = hit[\"score\"] if hit is not None else 0\n",
    "        df.loc[is_seq, \"evalue_gen\"] = hit[\"evalue\"] if hit is not None else 1\n",
    "    if fim_generation:\n",
    "        # FIM mode: also score the original (non-infilled) sequences\n",
    "        sequences = family_df[\"original_sequence\"].values\n",
    "        scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_list=sequences)\n",
    "        for seq in sequences:\n",
    "            hit = scores.get(seq)\n",
    "            is_seq = df[\"original_sequence\"] == seq\n",
    "            df.loc[is_seq, \"score_orig\"] = hit[\"score\"] if hit is not None else 0\n",
    "            df.loc[is_seq, \"evalue_orig\"] = hit[\"evalue\"] if hit is not None else 1\n",
    "    else:\n",
    "        # Baseline: score the natural sequences of the MSA against their own HMM\n",
    "        scores = align_and_score_sequences_in_a3m_with_hmm(hmm, sequences_path=msa_filepath)\n",
    "        all_scores_ctx[family_id] = {\"score\": [hit[\"score\"] for hit in scores.values()],\n",
    "                                     \"evalue\": [hit[\"evalue\"] for hit in scores.values()]}\n",
    "if not fim_generation:\n",
    "    with open(f\"figures/generated_sequences/all_hmmer_ctx_{name_data}.pkl\", \"wb\") as f:\n",
    "        pickle.dump(all_scores_ctx, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_pickle(f\"figures/generated_sequences/dataframe_{name_data}.pkl\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structure prediction (ESMFold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the folding model\n",
    "esm_tokenizer = AutoTokenizer.from_pretrained(\"facebook/esmfold_v1\")\n",
    "model = EsmForProteinFolding.from_pretrained(\"facebook/esmfold_v1\", low_cpu_mem_usage=True)\n",
    "\n",
    "model = model.cuda(\"cuda:0\")\n",
    "model.esm = model.esm.half()\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "# model.trunk.set_chunk_size(64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "# Load reference structures OpenFold\n",
    "structures_dir = Path(\"figures/pdb_structures/\")\n",
    "# Map the stem of every pdb file in the directory to its path\n",
    "structures_paths = {}\n",
    "for pdb_path in structures_dir.glob(\"*.pdb\"):\n",
    "    structures_paths[pdb_path.stem] = pdb_path\n",
    "print(\"Reference structures: \", *structures_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if structures of representatives are already computed\n",
    "# NOTE(review): the check looks for `<name>_esmfold` and `<name>_esmfold_aligned`, while the code below\n",
    "# writes `<name>.pdb` - confirm the expected file naming\n",
    "\n",
    "representatives_dir = Path(\"figures/pdb_structures/esmfold/representatives/\")\n",
    "representatives_paths = [path.stem for path in representatives_dir.glob(\"*.pdb\")]\n",
    "\n",
    "missing_representatives = False\n",
    "for name in structures_paths.keys():\n",
    "    has_fold = name + \"_esmfold\" in representatives_paths\n",
    "    has_aligned = name + \"_esmfold_aligned\" in representatives_paths\n",
    "    if not (has_fold and has_aligned):\n",
    "        print(f\"Missing representative structure for {name}\")\n",
    "        missing_representatives = True\n",
    "\n",
    "if missing_representatives and input(\"Do you want to compute the representative structures using esmfold? (y/n): \") == \"y\":\n",
    "    ppb=PPBuilder()\n",
    "\n",
    "    # Sequence and per-residue B-factor/100 (stored as \"plddt\") of chain A of each reference structure\n",
    "    openproteinset = {}\n",
    "    for name, ref_struct_path in structures_paths.items():\n",
    "        chain_a = pdb_parser.get_structure(\"exp\", ref_struct_path)[0][\"A\"]\n",
    "        sequence = \"\".join(str(peptide.get_sequence()) for peptide in ppb.build_peptides(chain_a))\n",
    "        plddt = [residue[\"CA\"].get_bfactor()/100 for residue in chain_a]\n",
    "        openproteinset[name] = {\"sequence\": sequence, \"plddt\": plddt}\n",
    "        assert len(openproteinset[name][\"sequence\"]) == len(openproteinset[name][\"plddt\"])\n",
    "    # Fold every representative and compare it with its reference structure\n",
    "    conf = {}\n",
    "    for name, info in openproteinset.items():\n",
    "        seq = info[\"sequence\"]\n",
    "        ref_struct_path = str(structures_paths[name])\n",
    "        struct_path = f\"figures/pdb_structures/esmfold/representatives/{name}.pdb\"\n",
    "        ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)\n",
    "        conf[name] = dict(ptm=ptm, pae=pae, mean_plddt=mean_plddt, pos_plddt=pos_plddt, rmsd=rmsd, tmscore=tmscore)\n",
    "    with open(f\"figures/generated_sequences/all_structures_representatives.pkl\", \"wb\") as f:\n",
    "        pickle.dump(conf, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = pd.read_pickle(f\"figures/generated_sequences/dataframe_{name_data}.pkl\")\n",
    "# with open(f\"figures/generated_sequences/all_structures_ctx_{name_data}.pkl\", \"rb\") as f:\n",
    "#     all_structures_ctx = pickle.load(f)\n",
    "# families = df[\"family_id\"].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_path = f\"figures/pdb_structures/esmfold/generated_{name_data}/\"\n",
    "natural_dir = \"figures/pdb_structures/esmfold/natural/\"\n",
    "# makedirs(exist_ok=True) creates missing parent folders and lets the cell be re-run or resumed\n",
    "os.makedirs(dir_path, exist_ok=True)\n",
    "all_structures_ctx = {}\n",
    "# `ptm_gen` doubles as the \"already folded\" marker checked in the loop below\n",
    "if \"ptm_gen\" not in df.columns:\n",
    "    df[\"ptm_gen\"] = np.nan\n",
    "\n",
    "for family_id in tqdm(families):\n",
    "    # find all df entries with the same family\n",
    "    print(f\"Family {family_id}\")\n",
    "    family_df = df[df[\"family_id\"] == family_id]\n",
    "    ref_struct_path = str(structures_paths[family_id])\n",
    "    sequences = family_df[\"generated_sequence\"].values\n",
    "    for seq in tqdm(sequences):\n",
    "        is_seq = df[\"generated_sequence\"] == seq\n",
    "        # get index of sequence in dataframe (first row holding this sequence)\n",
    "        indx = df[is_seq].index[0]\n",
    "        struct_path = dir_path+f\"{family_id}_{indx}_gen.pdb\"\n",
    "        # Reset so that a per-position pLDDT left over from another sequence is never reused\n",
    "        pos_plddt = None\n",
    "        # compute the structure unless a positive pTM is already stored for this sequence\n",
    "        if not df.loc[indx, \"ptm_gen\"] > 0:\n",
    "            ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)\n",
    "            df.loc[is_seq, [\"ptm_gen\",\n",
    "                            #  \"pae_gen\",\n",
    "                            \"mean_plddt_gen\",\n",
    "                            \"rmsd_gen\",\n",
    "                            \"tmscore_gen\"]] = ptm, mean_plddt, rmsd, tmscore\n",
    "        # Only update the masked pLDDT when the structure was (re)computed in this iteration\n",
    "        if fim_generation and pos_plddt is not None:\n",
    "            # get plddt values of masked positions\n",
    "            input_orig = df.loc[indx, \"original_input\"]\n",
    "            input_fim_gen = df.loc[indx, \"generated_input_fim\"]\n",
    "            full_gen, inds_masks = reorder_masked_sequence(input_orig+input_fim_gen, return_ids=True)\n",
    "            assert full_gen == seq\n",
    "            # stays 0 when compute_structure skipped the sequence (it then returns 0 instead of an array)\n",
    "            plddt_masked = 0\n",
    "            if isinstance(pos_plddt, np.ndarray):\n",
    "                assert pos_plddt.shape[1] == len(full_gen)\n",
    "                plddt_masked = [el for tup in inds_masks for el in pos_plddt[0,tup[0]:tup[1]]]\n",
    "                assert len(plddt_masked) == df.loc[indx, \"fim_size\"]\n",
    "                plddt_masked = np.mean(plddt_masked)\n",
    "            # column name fixed: it used to be \"generated]_sequence\", which raised KeyError\n",
    "            df.loc[is_seq, [\"masked_plddt_gen\"]] = plddt_masked\n",
    "    if fim_generation:\n",
    "        # NOTE(review): rows sharing the same original sequence all map to the first matching index,\n",
    "        # so they are all compared with that row's generated structure - confirm this is intended\n",
    "        sequences = family_df[\"original_sequence\"].values\n",
    "        for seq in tqdm(sequences):\n",
    "            is_orig = df[\"original_sequence\"] == seq\n",
    "            # get index of sequence in dataframe\n",
    "            indx = df[is_orig].index[0]\n",
    "            struct_path = dir_path+f\"{family_id}_{indx}_orig.pdb\"\n",
    "            # compute the structure\n",
    "            ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)\n",
    "            # get plddt values of masked positions\n",
    "            input_orig = df.loc[indx, \"original_input\"]\n",
    "            input_fim_orig = df.loc[indx, \"original_input_fim\"]\n",
    "            full_orig, inds_masks = reorder_masked_sequence(input_orig+input_fim_orig, return_ids=True)\n",
    "            assert full_orig == seq\n",
    "            plddt_masked = 0\n",
    "            if isinstance(pos_plddt, np.ndarray):\n",
    "                assert pos_plddt.shape[1] == len(full_orig)\n",
    "                plddt_masked = [el for tup in inds_masks for el in pos_plddt[0,tup[0]:tup[1]]]\n",
    "                assert len(plddt_masked) == df.loc[indx, \"fim_size\"]\n",
    "                plddt_masked = np.mean(plddt_masked)\n",
    "            df.loc[is_orig, [\"ptm_orig\",\n",
    "                             # \"pae_orig\",\n",
    "                             \"mean_plddt_orig\",\n",
    "                             \"masked_plddt_orig\",\n",
    "                             \"rmsd_orig\",\n",
    "                             \"tmscore_orig\"]] = ptm, mean_plddt, plddt_masked, rmsd, tmscore\n",
    "            # compare the structure of the original sequence with the one of the generated fim sequence\n",
    "            tmscore, rmsd = 0, 0\n",
    "            if df.loc[indx, \"ptm_gen\"] != 0:\n",
    "                struct_path_gen = dir_path+f\"{family_id}_{indx}_gen.pdb\"\n",
    "                tmscore, rmsd = align_structures_TMscore(struct_path, struct_path_gen, key=\"\")\n",
    "            df.loc[is_orig, [\"tmscore_orig_gen\", \"rmsd_orig_gen\"]] = tmscore, rmsd\n",
    "    else:\n",
    "        msa = read_msa_unaligned(f\"figures/pdb_structures/msas/{family_id}.a3m\")\n",
    "        if family_id not in all_structures_ctx.keys():\n",
    "            all_structures_ctx[family_id] = {}\n",
    "            os.makedirs(natural_dir, exist_ok=True)\n",
    "            # Sample without replacement; cap the sample size so that small MSAs do not raise\n",
    "            subset_seq_ids = np.random.choice(len(msa), min(num_natural, len(msa)), replace=False)\n",
    "            for i in subset_seq_ids:\n",
    "                _, seq = msa[i]\n",
    "                struct_path = natural_dir+f\"{family_id}_{i}.pdb\"\n",
    "                ptm, pae, mean_plddt, pos_plddt, rmsd, tmscore = compute_structure(seq, model, struct_path, ref_struct_path)\n",
    "                all_structures_ctx[family_id][i] = {\"ptm\": ptm, \"mean_plddt\": mean_plddt, \"rmsd\": rmsd, \"tmscore\": tmscore}\n",
    "    # save temporary dataframe after every family so that partial results survive an interruption\n",
    "    df.to_pickle(f\"figures/generated_sequences/dataframe_{name_data}.pkl\")\n",
    "    if not fim_generation:\n",
    "        with open(f\"figures/generated_sequences/all_structures_ctx_{name_data}.pkl\", \"wb\") as f:\n",
    "            pickle.dump(all_structures_ctx, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_pickle(f\"figures/generated_sequences/dataframe_{name_data}.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ProtMamba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
