{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import os\n",
    "import seaborn as sns\n",
    "\n",
    "# Seaborn theme shared by every figure (alternative contexts: \"paper\", \"talk\", \"poster\")\n",
    "sns.set_theme(context='notebook', style='ticks', palette='bright', color_codes=True)\n",
    "\n",
    "# Font sizes referenced by the rcParams below\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 30\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": False,\n",
    "    \"font.family\": \"sans-serif\",\n",
    "    \"font.serif\": [\"Arial\"],\n",
    "    \"font.size\": MEDIUM_SIZE,\n",
    "    # titles, axis/figure labels and legends all use the medium size\n",
    "    **{key: MEDIUM_SIZE for key in (\"axes.titlesize\", \"axes.labelsize\", \"figure.labelsize\",\n",
    "                                    \"figure.titlesize\", \"legend.fontsize\")},\n",
    "    # tick labels use the small size\n",
    "    **{key: SMALL_SIZE for key in (\"xtick.labelsize\", \"ytick.labelsize\")},\n",
    "})\n",
    "\n",
    "# Palette written as 0-255 RGB triplets, then rescaled to the 0-1 range\n",
    "color_ = [(64, 83, 211), (0, 178, 93), (181, 29, 20), (221, 179, 16), (0, 190, 255), (251, 73, 176), (202, 202, 202)]\n",
    "color = [tuple(channel/255 for channel in rgb) for rgb in color_]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataframe of generated sequences from its pickle file\n",
    "name_data = \"check-131k_gen_seqs_full\" # check-131k\n",
    "# Accept \"y\" regardless of case or surrounding whitespace; any other answer means \"no\"\n",
    "fim_generation = input(\"Are sequences generated using FIM? (y/n)\").strip().lower() == \"y\"\n",
    "\n",
    "# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source\n",
    "df = pd.read_pickle(f\"figures/generated_sequences/dataframe_{name_data}.pkl\")\n",
    "families = df[\"family_id\"].unique()\n",
    "\n",
    "with open(\"figures/generated_sequences/all_structures_representatives.pkl\", \"rb\") as f:\n",
    "    structures_representatives = pickle.load(f)\n",
    "\n",
    "if not fim_generation:\n",
    "    # The baseline dictionaries are only loaded for non-FIM runs\n",
    "    with open(f\"figures/generated_sequences/all_hamming_ctx_{name_data}.pkl\", \"rb\") as f:\n",
    "        all_hamming_ctx = pickle.load(f)\n",
    "    with open(f\"figures/generated_sequences/all_hmmer_ctx_{name_data}.pkl\", \"rb\") as f:\n",
    "        all_hmmer_ctx = pickle.load(f)\n",
    "    with open(f\"figures/generated_sequences/all_structures_ctx_{name_data}.pkl\", \"rb\") as f:\n",
    "        all_structures_ctx = pickle.load(f)\n",
    "\n",
    "    # These checks must stay inside the branch: the three dictionaries do not exist in FIM mode,\n",
    "    # so running them unconditionally raises NameError.\n",
    "    assert all_structures_ctx.keys() == all_hamming_ctx.keys() == all_hmmer_ctx.keys()\n",
    "    assert list(all_structures_ctx.keys()) == list(df[\"family_id\"].unique())\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Bio import SeqIO\n",
    "def read_msa_unaligned(filename: str):\n",
    "    \"\"\"Parse a FASTA-formatted MSA file into (description, sequence) pairs.\n",
    "\n",
    "    Every sequence has its '.', '-' and '*' characters removed and is upper-cased.\n",
    "    \"\"\"\n",
    "    drop_chars = str.maketrans(\"\", \"\", \".-*\")\n",
    "    sequences = []\n",
    "    for record in SeqIO.parse(filename, \"fasta\"):\n",
    "        sequences.append((record.description, str(record.seq).translate(drop_chars).upper()))\n",
    "    return sequences\n",
    "# Lengths of the unaligned natural sequences, keyed by family id\n",
    "lengths_natural = {\n",
    "    fam: [len(seq) for _, seq in read_msa_unaligned(f\"figures/pdb_structures/msas/{fam}.a3m\")]\n",
    "    for fam in families\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hamming_hmmer(fig,\n",
    "                       axs,\n",
    "                       family_id,\n",
    "                       df_family,\n",
    "                       lengths_nat,\n",
    "                       hamming_ctx,\n",
    "                       hmmer_ctx,\n",
    "                       structures_ctx,\n",
    "                       last_row=False):\n",
    "    \"\"\"Draw one row of six panels comparing generated and natural sequences of a family.\n",
    "\n",
    "    Panels 0-2 are histograms (Hamming distances, HMMER scores, mean pLDDT); panels 3-5 are\n",
    "    scatter plots of perplexity against min Hamming distance, HMMER score and mean pLDDT,\n",
    "    coloured by the length of the generated sequence.\n",
    "\n",
    "    Args:\n",
    "        fig: figure owning `axs`; receives the shared legend and the colorbar.\n",
    "        axs: indexable collection of at least six axes.\n",
    "        family_id: text prepended to the y-label of the first panel.\n",
    "        df_family: dataframe with the \"hamming\", \"generated_sequence\", \"score_gen\",\n",
    "            \"perplexity\" and \"mean_plddt_gen\" columns.\n",
    "        lengths_nat: unused; kept so existing callers keep working.\n",
    "        hamming_ctx: iterable of iterables of Hamming distances (plotted as \"Natural\").\n",
    "        hmmer_ctx: mapping whose \"score\" entry holds the HMMER scores plotted as \"Natural\".\n",
    "        structures_ctx: mapping whose values expose a \"mean_plddt\" entry.\n",
    "        last_row: when True, the x-axis labels are written.\n",
    "\n",
    "    Returns:\n",
    "        The `(fig, axs)` pair that was passed in.\n",
    "    \"\"\"\n",
    "    lst_all_ham = [el for ham in df_family[\"hamming\"] for el in ham]\n",
    "    lst_min_ham = [min(ham) for ham in df_family[\"hamming\"]]\n",
    "    lst_all_ham_ctx = [el for ham in hamming_ctx for el in ham]\n",
    "    lst_min_ham_ctx = [min(ham) for ham in hamming_ctx]\n",
    "    lst_seq_len = [len(seq) for seq in df_family[\"generated_sequence\"]]\n",
    "    plddts_nat = np.array([entry[\"mean_plddt\"] for entry in structures_ctx.values()])\n",
    "\n",
    "    axs[0].set_ylabel(family_id+\"\\nDensity\")\n",
    "\n",
    "    # Hamming distances\n",
    "    axs[0].hist(lst_all_ham, bins = 40, alpha=0.7, density=True, color=color[0], label=\"Generated\")\n",
    "    axs[0].hist(lst_all_ham_ctx, bins = 40, alpha=0.5, density=True, color=color[1], label=\"Natural\")\n",
    "    if last_row:\n",
    "        axs[0].set_xlabel(\"Hamming distance\")\n",
    "\n",
    "    # HMMER scores\n",
    "    axs[1].hist(df_family[\"score_gen\"], bins=40, alpha=0.7, density=True, color=color[0], label=\"Generated\")\n",
    "    axs[1].hist(hmmer_ctx[\"score\"], bins=40, alpha=0.5, density=True, color=color[1], label=\"Natural\")\n",
    "    if last_row:\n",
    "        axs[1].set_xlabel(\"HMMER score\")\n",
    "\n",
    "    # pLDDT: only entries with a strictly positive mean pLDDT enter the scatter and the top-10% subset\n",
    "    perplexities, plddts = df_family[\"perplexity\"].to_numpy(), df_family[\"mean_plddt_gen\"].to_numpy()\n",
    "    positive = plddts > 0\n",
    "    plddts, perplexities, lens = plddts[positive], perplexities[positive], np.array(lst_seq_len)[positive]\n",
    "    mins, maxs = min(plddts.min(), min(plddts_nat)), max(plddts.max(), max(plddts_nat))\n",
    "    bins = np.linspace(mins, maxs, 40)\n",
    "    # The tenth of the retained sequences with the lowest perplexity\n",
    "    ind = np.argsort(perplexities)[:len(plddts)//10]\n",
    "    plddts_new = plddts[ind]\n",
    "    axs[2].hist(df_family[\"mean_plddt_gen\"], bins=bins, alpha=0.7, density=True, color=color[0], label=\"Generated\")\n",
    "    axs[2].hist(plddts_new, bins=bins, alpha=0.5, density=True, color=color[2], label=\"Generated (top 10%)\")\n",
    "    axs[2].hist(plddts_nat, bins=bins, alpha=0.5, density=True, color=color[1], label=\"Natural\")\n",
    "    if last_row:\n",
    "        axs[2].set_xlabel(\"Mean pLDDT\")\n",
    "\n",
    "    # Perplexity vs min Hamming (dashed lines mark the medians)\n",
    "    im = axs[3].scatter(lst_min_ham, df_family[\"perplexity\"].to_list(), c=lst_seq_len, cmap=\"viridis\", alpha=0.8)\n",
    "    axs[3].axvline(np.median(lst_min_ham_ctx), color=color[1], linestyle=\"--\", linewidth=1.5)\n",
    "    axs[3].axvline(np.median(lst_min_ham), color=color[0], linestyle=\"--\", linewidth=1.5)\n",
    "    if last_row:\n",
    "        axs[3].set_xlabel(\"Min Hamming\")\n",
    "    axs[3].set_ylabel(\"Perplexity\")\n",
    "\n",
    "    # Perplexity vs HMMER score\n",
    "    im = axs[4].scatter(df_family[\"score_gen\"].to_list(), df_family[\"perplexity\"].to_list(), c=lst_seq_len, cmap=\"viridis\", alpha=0.8)\n",
    "    axs[4].axvline(np.median(hmmer_ctx[\"score\"]), color=color[1], linestyle=\"--\", linewidth=1.5)\n",
    "    axs[4].axvline(np.median(df_family[\"score_gen\"]), color=color[0], linestyle=\"--\", linewidth=1.5)\n",
    "    if last_row:\n",
    "        axs[4].set_xlabel(\"HMMER score\")\n",
    "\n",
    "    # Perplexity vs pLDDT\n",
    "    im = axs[5].scatter(plddts,perplexities, c=lens, cmap=\"viridis\", alpha=0.8)\n",
    "    axs[5].axvline(np.median(plddts_nat), color=color[1], linestyle=\"--\", linewidth=1.5)\n",
    "    axs[5].axvline(np.median(df_family[\"mean_plddt_gen\"]), color=color[0], linestyle=\"--\", linewidth=1.5)\n",
    "    axs[5].axvline(np.median(plddts_new), color=color[2], linestyle=\"--\", linewidth=1.5)\n",
    "    if last_row:\n",
    "        axs[5].set_xlabel(\"Mean pLDDT\")\n",
    "\n",
    "    # This function is called once per family on the same figure: add the shared legend only\n",
    "    # the first time, otherwise identical legends pile up on top of each other.\n",
    "    if not fig.legends:\n",
    "        handles, labels = axs[2].get_legend_handles_labels()\n",
    "        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.04), ncol=3, fancybox=False)\n",
    "    fig.colorbar(im, ax=axs[5], label=\"Sequence length\")\n",
    "    return fig, axs\n",
    "\n",
    "def compare_generation_parameters(fig,\n",
    "                                  axs,\n",
    "                                  family_id,\n",
    "                                  df_family,\n",
    "                                  lengths_nat,\n",
    "                                  hamming_ctx,\n",
    "                                  hmmer_ctx):\n",
    "    \"\"\"Scatter perplexity against length, min Hamming distance and HMMER score per sampling setting.\n",
    "\n",
    "    One row of three panels is filled for every distinct \"n_seqs_ctx\" value; inside a row each\n",
    "    distinct (temperature, top_k, top_p) combination gets its own colour.\n",
    "\n",
    "    Args:\n",
    "        fig: figure owning `axs`; receives the legend and the shared y-label.\n",
    "        axs: grid of axes with three columns (a single squeezed row is accepted too).\n",
    "        family_id: unused; kept so existing callers keep working.\n",
    "        df_family: dataframe with the \"temperature\", \"top_k\", \"top_p\", \"n_seqs_ctx\", \"hamming\",\n",
    "            \"generated_sequence\", \"perplexity\" and \"score_gen\" columns. It is not modified.\n",
    "        lengths_nat: sequence lengths whose median is drawn as the \"Natural\" reference line.\n",
    "        hamming_ctx: iterable of iterables of Hamming distances; the median of the minima is drawn.\n",
    "        hmmer_ctx: mapping whose \"score\" entry provides the reference HMMER scores.\n",
    "\n",
    "    Returns:\n",
    "        The `(fig, axs)` pair that was passed in.\n",
    "    \"\"\"\n",
    "    # Build a new frame instead of adding a column in place: `df_family` is typically a filtered\n",
    "    # selection of a larger dataframe, and writing to it triggers SettingWithCopyWarning.\n",
    "    df_family = df_family.assign(generation_parameters=list(zip(df_family[\"temperature\"].to_list(),\n",
    "                                                                df_family[\"top_k\"].to_list(),\n",
    "                                                                df_family[\"top_p\"].to_list())))\n",
    "    # plt.subplots squeezes a one-row grid to 1-D; restore the (row, column) layout for indexing\n",
    "    grid = np.atleast_2d(axs)\n",
    "    gen_par = df_family[\"generation_parameters\"].unique()\n",
    "    color_p = color if len(color)>=len(gen_par) else sns.color_palette(\"viridis\", len(gen_par))\n",
    "    ctx_len = df_family[\"n_seqs_ctx\"].unique()\n",
    "    lst_min_ham_ctx = [min(ham) for ham in hamming_ctx]\n",
    "    for j, ctx_l in enumerate(ctx_len):\n",
    "        df_ctx = df_family[df_family[\"n_seqs_ctx\"] == ctx_l]\n",
    "        ax = grid[j]\n",
    "        for i, gen_p in enumerate(gen_par):\n",
    "            df_gen_p = df_ctx[df_ctx[\"generation_parameters\"] == gen_p]\n",
    "            lst_min_ham = [min(ham) for ham in df_gen_p[\"hamming\"]]\n",
    "            lst_seq_len = [len(seq) for seq in df_gen_p[\"generated_sequence\"]]\n",
    "            perplexity = df_gen_p[\"perplexity\"].to_list()\n",
    "            # Perplexity vs Sequence lengths\n",
    "            ax[0].scatter(lst_seq_len, perplexity, color=color_p[i], label=gen_p)\n",
    "            # Perplexity vs min Hamming\n",
    "            ax[1].scatter(lst_min_ham, perplexity, color=color_p[i], label=gen_p)\n",
    "            # Perplexity vs HMMER score\n",
    "            ax[2].scatter(df_gen_p[\"score_gen\"], perplexity, color=color_p[i], label=gen_p)\n",
    "\n",
    "        # Dashed black lines: medians of the reference values\n",
    "        ax[0].axvline(np.median(lengths_nat), color=\"k\", linestyle=\"--\", linewidth=1.5, label=\"Natural\")\n",
    "        ax[0].set_ylabel(f\"Seqs in ctx: {ctx_l}\")\n",
    "        ax[1].axvline(np.median(lst_min_ham_ctx), color=\"k\", linestyle=\"--\", linewidth=1.5)\n",
    "        ax[2].axvline(np.median(hmmer_ctx[\"score\"]), color=\"k\", linestyle=\"--\", linewidth=1.5)\n",
    "    grid[0, 0].set_title(\"Perplexity vs Sequence length\")\n",
    "    grid[0, 1].set_title(\"Perplexity vs Hamming\")\n",
    "    grid[0, 2].set_title(\"Perplexity vs HMMER score\")\n",
    "    grid[-1,0].set_xlabel(\"Sequence length\")\n",
    "    grid[-1,1].set_xlabel(\"Hamming distance to\\nclosest natural sequence\")\n",
    "    grid[-1,2].set_xlabel(\"HMMER score\")\n",
    "\n",
    "    # Unique legend for the entire figure on the top of all subplots\n",
    "    handles, labels = grid[0, 0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=5, fancybox=False)\n",
    "    fig.supylabel(\"Perplexity\")\n",
    "    return fig, axs\n",
    "\n",
    "def plot_comparison(fig,\n",
    "                       axs,\n",
    "                       family_id,\n",
    "                       df_family,\n",
    "                       lengths_nat,\n",
    "                       hamming_ctx,\n",
    "                       hmmer_ctx,\n",
    "                       structures_ctx):\n",
    "    \"\"\"Add one family's generated-vs-natural summary point (median with std error bars) to five panels.\n",
    "\n",
    "    Panels: sequence length, min Hamming distance, rescaled HMMER score, mean pLDDT and pTM.\n",
    "    The generated statistics use the lowest-perplexity sequences among those whose mean pLDDT\n",
    "    is strictly positive: the first 100 for Hamming/pLDDT/pTM and as many as there are natural\n",
    "    HMMER scores for the HMMER panel. Sequence lengths use every generated sequence.\n",
    "\n",
    "    Args:\n",
    "        fig: figure owning `axs`; receives the shared x/y labels.\n",
    "        axs: indexable collection of at least five axes.\n",
    "        family_id: family label, printed and used as the series label.\n",
    "        df_family: dataframe with the \"generated_sequence\", \"perplexity\", \"mean_plddt_gen\",\n",
    "            \"ptm_gen\", \"score_gen\" and \"min_hamming\" columns.\n",
    "        lengths_nat: lengths of the natural sequences.\n",
    "        hamming_ctx: iterable of iterables of Hamming distances between natural sequences.\n",
    "        hmmer_ctx: mapping whose \"score\" entry holds the natural HMMER scores.\n",
    "        structures_ctx: mapping whose values expose \"mean_plddt\" and \"ptm\" entries.\n",
    "\n",
    "    Returns:\n",
    "        `(fig, axs, corrs)` where `corrs` lists the Pearson correlation of perplexity with\n",
    "        min Hamming distance, HMMER score, mean pLDDT and pTM, in that order.\n",
    "    \"\"\"\n",
    "    lst_min_ham_ctx = [min(ham) for ham in hamming_ctx]\n",
    "    lst_seq_len = [len(seq) for seq in df_family[\"generated_sequence\"]]\n",
    "\n",
    "    perplexities = df_family[\"perplexity\"].to_numpy()\n",
    "    plddts_all = df_family[\"mean_plddt_gen\"].to_numpy()\n",
    "    ptm_all = df_family[\"ptm_gen\"].to_numpy()\n",
    "    hmmer_all = df_family[\"score_gen\"].to_numpy()\n",
    "    min_ham_all = df_family[\"min_hamming\"].to_numpy()\n",
    "    plddts_nat = np.array([entry[\"mean_plddt\"] for entry in structures_ctx.values()])\n",
    "    ptm_nat = np.array([entry[\"ptm\"] for entry in structures_ctx.values()])\n",
    "\n",
    "    # Correlations with sequence-level scores use every generated sequence\n",
    "    corr2 = np.corrcoef(perplexities, hmmer_all)[0,1]\n",
    "    corr3 = np.corrcoef(perplexities, min_ham_all)[0,1]\n",
    "\n",
    "    # Correlations with structural scores only use rows with a strictly positive mean pLDDT\n",
    "    valid = np.flatnonzero(plddts_all > 0)\n",
    "    corr0 = np.corrcoef(perplexities[valid], plddts_all[valid])[0,1]\n",
    "    corr1 = np.corrcoef(perplexities[valid], ptm_all[valid])[0,1]\n",
    "\n",
    "    # Rank the valid rows by perplexity while keeping positions in the ORIGINAL row order, so the\n",
    "    # same indices can be applied to every per-row array. (Ranking positions of the filtered array\n",
    "    # and applying them to unfiltered arrays would select the wrong rows.)\n",
    "    ranked = valid[np.argsort(perplexities[valid])]\n",
    "    top = ranked[:100]\n",
    "    top_hmmer = ranked[:len(hmmer_ctx[\"score\"])]\n",
    "    plddts_gen, ptm_gen = plddts_all[top], ptm_all[top]\n",
    "    hmmer_gen, ham_gen, lens_gen = hmmer_all[top_hmmer], min_ham_all[top], np.array(lst_seq_len)\n",
    "\n",
    "    # One LaTeX table row per family\n",
    "    print(family_id, f\" & ${round(corr3,2)}$ & \", f\"${round(corr2,2)}$ & \", f\"${round(corr0,2)}$ & \", f\"${round(corr1,2)}$ \\\\\\\\\")\n",
    "    corrs = [corr3, corr2, corr0, corr1]\n",
    "    # Sequence lengths\n",
    "    axs[0].errorbar(np.median(lens_gen), np.median(lengths_nat), xerr=np.std(lens_gen), yerr=np.std(lengths_nat), fmt='o', color=color[0], label=family_id, capsize=2, elinewidth=0.8, markersize=4)\n",
    "    axs[0].set_title(\"Sequence length\")\n",
    "\n",
    "    # Min Hamming distances\n",
    "    axs[1].errorbar(np.median(ham_gen), np.median(lst_min_ham_ctx), xerr=np.std(ham_gen), yerr=np.std(lst_min_ham_ctx), fmt='o', color=color[1], label=family_id, capsize=2, elinewidth=0.8, markersize=4)\n",
    "    axs[1].set_title(\"Min Hamming\")\n",
    "\n",
    "    # HMMER scores, min-max rescaled over the generated and natural scores together\n",
    "    mmin, mmax = min(min(hmmer_gen), min(hmmer_ctx[\"score\"])), max(max(hmmer_gen), max(hmmer_ctx[\"score\"]))\n",
    "    span = (mmax - mmin) or 1.0  # all scores identical: avoid dividing by zero\n",
    "    rescaled_hmmer_gen = (hmmer_gen - mmin) / span\n",
    "    rescaled_hmmer_ctx = (np.array(hmmer_ctx[\"score\"]) - mmin) / span\n",
    "    axs[2].errorbar(np.median(rescaled_hmmer_gen), np.median(rescaled_hmmer_ctx), xerr=np.std(rescaled_hmmer_gen), yerr=np.std(rescaled_hmmer_ctx), fmt='o', color=color[2], label=family_id, capsize=2, elinewidth=0.8, markersize=4)\n",
    "    axs[2].set_title(\"HMMER score\")\n",
    "\n",
    "    # pLDDTs\n",
    "    axs[3].errorbar(np.median(plddts_gen), np.median(plddts_nat), xerr=np.std(plddts_gen), yerr=np.std(plddts_nat), fmt='o', color=color[3], label=family_id, capsize=2, elinewidth=0.8, markersize=4)\n",
    "    axs[3].set_title(\"pLDDT\")\n",
    "\n",
    "    # pTMs\n",
    "    axs[4].errorbar(np.median(ptm_gen), np.median(ptm_nat), xerr=np.std(ptm_gen), yerr=np.std(ptm_nat), fmt='o', color=color[4], label=family_id, capsize=2, elinewidth=0.8, markersize=4)\n",
    "    axs[4].set_title(\"pTM\")\n",
    "\n",
    "    fig.supxlabel(\"Generated\")\n",
    "    fig.supylabel(\"Natural\")\n",
    "    return fig, axs, corrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_scores_fim(fig,\n",
    "                    axs,\n",
    "                    family_id,\n",
    "                    df_family):\n",
    "    # pLDDT masked parts\n",
    "    im = axs[0].scatter(df_family[\"masked_plddt_orig\"].to_list(), df_family[\"masked_plddt_gen\"].to_list(), c=df_family[\"perplexity\"].to_list(), cmap=\"viridis\")\n",
    "    xvals, yvals = df_family[\"masked_plddt_orig\"].to_list(), df_family[\"masked_plddt_gen\"].to_list()\n",
    "    mins, maxs = min(min(xvals),min(yvals)), max(max(xvals),max(yvals))\n",
    "    axs[0].plot([mins,maxs],[mins,maxs], \"k--\")\n",
    "    axs[0].set_xlabel(\"Masked pLDDT original\")\n",
    "    axs[0].set_ylabel(\"Family ID: \"+family_id+\"\\nMasked pLDDT generated\")\n",
    "    \n",
    "    fig.colorbar(im, ax=axs[0], label=\"Perplexity\\n\")\n",
    "    # col_par = (df_family[\"fim_distance\"]).to_list()\n",
    "    col_par = df_family[\"perplexity\"].to_list()\n",
    "    # HMMER scores\n",
    "    im = axs[1].scatter(df_family[\"fim_size\"].to_list(),(df_family[\"score_gen\"]-df_family[\"score_orig\"]).to_list(), c=col_par, cmap=\"viridis\")\n",
    "    axs[1].set_ylabel(\"HMMER score difference\")\n",
    "    axs[1].set_xlabel(\"FIM size\")\n",
    "    \n",
    "    # pTM scores\n",
    "    im = axs[2].scatter(df_family[\"fim_size\"].to_list(),(df_family[\"ptm_gen\"]-df_family[\"ptm_orig\"]).to_list(), c=col_par, cmap=\"viridis\")\n",
    "    axs[2].set_ylabel(\"pTM difference\")\n",
    "    axs[2].set_xlabel(\"FIM size\")\n",
    "    \n",
    "    # Perplexity vs FIM distance relative to original\n",
    "    im = axs[3].scatter(df_family[\"fim_size\"].to_list(), df_family[\"perplexity\"].to_list(), c=col_par, cmap=\"viridis\")\n",
    "    axs[3].set_xlabel(\"FIM size\")\n",
    "    axs[3].set_ylabel(\"Perplexity\")\n",
    "    \n",
    "    # Perplexity vs TMscore\n",
    "    im = axs[4].scatter(df_family[\"fim_size\"].to_list(), df_family[\"tmscore_orig_gen\"].to_list(), c=col_par, cmap=\"viridis\")\n",
    "    axs[4].set_ylabel(\"TMscore between orig. and gen.\")\n",
    "    axs[4].set_xlabel(\"FIM size\")\n",
    "       \n",
    "    fig.colorbar(im, ax=axs[4], label=\"Perplexity\")\n",
    "    return fig, axs\n",
    "\n",
    "def plot_scores_hist_fim(fig,\n",
    "                        axs,\n",
    "                        family_id,\n",
    "                        df_family):\n",
    "    perplexity = df_family[\"perplexity\"].to_numpy()\n",
    "    inds = np.argsort(perplexity)[:len(perplexity)//5]\n",
    "    inds1 = np.argwhere(df_family[\"fim_distance\"].to_numpy()>0)\n",
    "    # take all indices that are both in inds and inds1\n",
    "    inds = np.intersect1d(inds, inds1)\n",
    "    bins = np.concatenate([-np.linspace(0,1,20)[::-1],np.linspace(0,1,20)[1:]])\n",
    "    # pLDDT masked parts\n",
    "    full_max = max(-(df_family[\"masked_plddt_gen\"]-df_family[\"masked_plddt_orig\"]).min(), (df_family[\"masked_plddt_gen\"]-df_family[\"masked_plddt_orig\"]).max())\n",
    "    im = axs[0].hist((df_family[\"masked_plddt_gen\"]-df_family[\"masked_plddt_orig\"]).to_numpy(), bins=bins*full_max)\n",
    "    im = axs[0].hist((df_family[\"masked_plddt_gen\"]-df_family[\"masked_plddt_orig\"]).to_numpy()[inds], bins=im[1])\n",
    "    f0 = round(np.sum((df_family['masked_plddt_gen']-df_family['masked_plddt_orig']).to_numpy()>0)/len(df_family),2)\n",
    "    f1 = round(np.sum((df_family['masked_plddt_gen']-df_family['masked_plddt_orig']).to_numpy()[inds]>0)/len(inds),2)\n",
    "    axs[0].text(0.6, 0.8, f\"F0: {f0}\\nF1: {f1}\", transform=axs[0].transAxes)\n",
    "    axs[0].set_xlabel(\"Masked pLDDT difference\")\n",
    "    axs[0].set_ylabel(\"Family ID: \"+family_id)\n",
    "    \n",
    "    # HMMER scores\n",
    "    full_max = max(-(df_family[\"score_gen\"]-df_family[\"score_orig\"]).min(), (df_family[\"score_gen\"]-df_family[\"score_orig\"]).max())\n",
    "    im = axs[1].hist((df_family[\"score_gen\"]-df_family[\"score_orig\"]).to_numpy(), bins=bins*full_max)\n",
    "    im = axs[1].hist((df_family[\"score_gen\"]-df_family[\"score_orig\"]).to_numpy()[inds], bins=im[1])\n",
    "    f0 = round(np.sum((df_family['score_gen']-df_family['score_orig']).to_numpy()>0)/len(df_family),2)\n",
    "    f1 = round(np.sum((df_family['score_gen']-df_family['score_orig']).to_numpy()[inds]>0)/len(inds),2)\n",
    "    axs[1].text(0.6, 0.8, f\"F0: {f0}\\nF1: {f1}\", transform=axs[1].transAxes)\n",
    "    axs[1].set_xlabel(\"HMMER score difference\")\n",
    "    \n",
    "    # pTM scores\n",
    "    full_max = max(-(df_family[\"ptm_gen\"]-df_family[\"ptm_orig\"]).min(), (df_family[\"ptm_gen\"]-df_family[\"ptm_orig\"]).max())\n",
    "    im = axs[2].hist((df_family[\"ptm_gen\"]-df_family[\"ptm_orig\"]).to_numpy(), bins=bins*full_max)\n",
    "    im = axs[2].hist((df_family[\"ptm_gen\"]-df_family[\"ptm_orig\"]).to_numpy()[inds], bins=im[1])\n",
    "    f0 = round(np.sum((df_family['ptm_gen']-df_family['ptm_orig']).to_numpy()>0)/len(df_family),2)\n",
    "    f1 = round(np.sum((df_family['ptm_gen']-df_family['ptm_orig']).to_numpy()[inds]>0)/len(inds),2)\n",
    "    axs[2].text(0.6, 0.8, f\"F0: {f0}\\nF1: {f1}\", transform=axs[2].transAxes)\n",
    "    axs[2].set_xlabel(\"pTM difference\")\n",
    "    \n",
    "    # Perplexity vs FIM distance relative to original\n",
    "    im = axs[3].hist(df_family[\"fim_size\"].to_numpy(), bins=40)\n",
    "    im = axs[3].hist(df_family[\"fim_size\"].to_numpy()[inds], bins=im[1])\n",
    "    axs[3].set_xlabel(\"FIM size\")\n",
    "    \n",
    "    # Perplexity vs TMscore\n",
    "    im = axs[4].hist(df_family[\"tmscore_orig_gen\"].to_numpy(), bins=40)\n",
    "    im = axs[4].hist(df_family[\"tmscore_orig_gen\"].to_numpy()[inds], bins=im[1])\n",
    "    axs[4].set_xlabel(\"TMscore between orig. and gen.\")\n",
    "\n",
    "    return fig, axs\n",
    "\n",
    "def compare_generation_parameters_fim(fig,\n",
    "                                      axs,\n",
    "                                      family_id,\n",
    "                                      df_family):\n",
    "    \"\"\"Compare sampling settings for FIM-generated sequences of one family.\n",
    "\n",
    "    One row of five panels is filled for every distinct \"n_seqs_ctx\" value; inside a row each\n",
    "    distinct (temperature, top_k, top_p) combination gets its own colour.\n",
    "\n",
    "    Args:\n",
    "        fig: figure owning `axs`; receives the legend.\n",
    "        axs: grid of axes with five columns (a single squeezed row is accepted too).\n",
    "        family_id: unused; kept so existing callers keep working.\n",
    "        df_family: dataframe with the \"temperature\", \"top_k\", \"top_p\", \"n_seqs_ctx\",\n",
    "            \"masked_plddt_orig\", \"masked_plddt_gen\", \"score_gen\", \"score_orig\", \"ptm_gen\",\n",
    "            \"ptm_orig\", \"fim_distance\", \"tmscore_orig_gen\" and \"perplexity\" columns.\n",
    "            It is not modified.\n",
    "\n",
    "    Returns:\n",
    "        The `(fig, axs)` pair that was passed in.\n",
    "    \"\"\"\n",
    "    # Build a new frame instead of adding a column in place: `df_family` is typically a filtered\n",
    "    # selection of a larger dataframe, and writing to it triggers SettingWithCopyWarning.\n",
    "    df_family = df_family.assign(generation_parameters=list(zip(df_family[\"temperature\"].to_list(),\n",
    "                                                                df_family[\"top_k\"].to_list(),\n",
    "                                                                df_family[\"top_p\"].to_list())))\n",
    "    # plt.subplots squeezes a one-row grid to 1-D; restore the (row, column) layout for indexing\n",
    "    grid = np.atleast_2d(axs)\n",
    "    gen_par = df_family[\"generation_parameters\"].unique()\n",
    "    color_p = sns.color_palette(\"viridis\", len(gen_par))\n",
    "    ctx_len = df_family[\"n_seqs_ctx\"].unique()\n",
    "    for j, ctx_l in enumerate(ctx_len):\n",
    "        df_ctx = df_family[df_family[\"n_seqs_ctx\"] == ctx_l]\n",
    "        ax = grid[j]\n",
    "        for i, gen_p in enumerate(gen_par):\n",
    "            df_gen_p = df_ctx[df_ctx[\"generation_parameters\"] == gen_p]\n",
    "            perplexity = df_gen_p[\"perplexity\"].to_list()\n",
    "            # pLDDT masked parts\n",
    "            xvals, yvals = df_gen_p[\"masked_plddt_orig\"].to_list(), df_gen_p[\"masked_plddt_gen\"].to_list()\n",
    "            ax[0].scatter(xvals, yvals, color=color_p[i], label=gen_p, alpha=0.8)\n",
    "            # A parameter combination may be absent for this context size; min()/max() of an\n",
    "            # empty list would raise, so the identity line is only drawn when there is data.\n",
    "            if xvals:\n",
    "                mins, maxs = min(min(xvals),min(yvals)), max(max(xvals),max(yvals))\n",
    "                ax[0].plot([mins,maxs],[mins,maxs], \"k--\")\n",
    "            # HMMER score difference vs perplexity\n",
    "            ax[1].scatter((df_gen_p[\"score_gen\"]-df_gen_p[\"score_orig\"]).to_list(), perplexity, color=color_p[i], label=gen_p, alpha=0.8)\n",
    "            # pTM difference vs perplexity\n",
    "            ax[2].scatter((df_gen_p[\"ptm_gen\"]-df_gen_p[\"ptm_orig\"]).to_list(), perplexity, color=color_p[i], label=gen_p, alpha=0.8)\n",
    "            # FIM distance relative to original vs perplexity\n",
    "            ax[3].scatter(df_gen_p[\"fim_distance\"].to_list(), perplexity, color=color_p[i], label=gen_p, alpha=0.8)\n",
    "            # TMscore vs perplexity\n",
    "            ax[4].scatter(df_gen_p[\"tmscore_orig_gen\"].to_list(), perplexity, color=color_p[i], label=gen_p, alpha=0.8)\n",
    "        ax[0].set_ylabel(f\"Seqs in ctx: {ctx_l}\"+\"\\nMasked pLDDT generated\")\n",
    "        ax[1].set_ylabel(\"Perplexity\")\n",
    "    grid[-1,0].set_xlabel(\"Masked pLDDT original\")\n",
    "    grid[-1,1].set_xlabel(\"HMMER score difference\")\n",
    "    grid[-1,2].set_xlabel(\"pTM difference\")\n",
    "    grid[-1,3].set_xlabel(\"FIM distance relative to original\")\n",
    "    grid[-1,4].set_xlabel(\"TMscore between orig. and gen.\")\n",
    "\n",
    "    # Unique legend for the entire figure on the top of all subplots\n",
    "    handles, labels = grid[0, 0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=6, fancybox=False)\n",
    "    return fig, axs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AR generated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 5, figsize=(13, 3), constrained_layout=True)\n",
    "for ax in axs:\n",
    "    ax.set_box_aspect(1)\n",
    "\n",
    "# One summary point per family in every panel; the correlations are kept for the next cell\n",
    "all_corr = []\n",
    "for family_id in families:\n",
    "    df_family = df[df[\"family_id\"] == family_id]\n",
    "    lengths_nat = lengths_natural[family_id]\n",
    "    hamming_ctx = all_hamming_ctx[family_id]\n",
    "    hmmer_ctx = all_hmmer_ctx[family_id]\n",
    "    structures_ctx = all_structures_ctx[family_id]\n",
    "    _, _, corrs = plot_comparison(fig, axs, family_id, df_family, lengths_nat, hamming_ctx, hmmer_ctx, structures_ctx)\n",
    "    all_corr.append(np.array(corrs))\n",
    "all_corr = np.array(all_corr)\n",
    "\n",
    "# Use the same range on both axes of each panel and draw the identity line across it.\n",
    "# The line is derived from the axis limits rather than from hard-coded, dataset-specific ranges.\n",
    "for ax in axs:\n",
    "    ymin, ymax = ax.get_ylim()\n",
    "    xmin, xmax = ax.get_xlim()\n",
    "    mmin, mmax = min(xmin,ymin), max(xmax,ymax)\n",
    "    ax.set_xlim(mmin, mmax)\n",
    "    ax.set_ylim(mmin, mmax)\n",
    "    ax.set_aspect('equal', adjustable='box')\n",
    "    ax.plot([mmin, mmax], [mmin, mmax], \"k--\")\n",
    "\n",
    "plt.show()\n",
    "fig.savefig(f\"figures/generated_sequences/comparison_{name_data}.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the families by their mean absolute correlation\n",
    "inds = np.argsort(abs(all_corr).mean(1))\n",
    "positions = np.arange(len(families))\n",
    "ordered_names = [families[i] for i in inds]\n",
    "\n",
    "# (values, marker, colour, label, marker size); every column except Hamming is sign-flipped\n",
    "series = [\n",
    "    (all_corr[inds,0], \"s\", color[1], \"Hamming\", 8),\n",
    "    (-all_corr[inds,1], \"^\", color[2], \"HMMER\", 8),\n",
    "    (-all_corr[inds,2], \"v\", color[3], \"pLDDT\", 8),\n",
    "    (-all_corr[inds,3], \"o\", color[4], \"pTM\", 8),\n",
    "    (abs(all_corr[inds,:]).mean(1), \"*\", \"k\", \"Mean\", 15),\n",
    "]\n",
    "\n",
    "# Families along the y-axis (this layout is saved to disk)\n",
    "fig,axs = plt.subplots(1,1,figsize=(5,8))\n",
    "for values, marker, col, label, size in series:\n",
    "    plt.plot(values, positions, marker, color=col, label=label, markersize=size)\n",
    "plt.yticks(positions, ordered_names)\n",
    "plt.xlabel(\"Pearson Correlation\")\n",
    "plt.legend(frameon=False)\n",
    "fig.savefig(f\"figures/generated_sequences/comparison_{name_data}_correlation.pdf\", bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Same data with the axes swapped (displayed only)\n",
    "fig, axs = plt.subplots(1, 1, figsize=(12, 4), constrained_layout=True)\n",
    "for values, marker, col, label, size in series:\n",
    "    plt.plot(positions, values, marker, color=col, label=label, markersize=size)\n",
    "plt.xticks(positions, ordered_names, rotation=90)\n",
    "plt.ylabel(\"Pearson Correlation\")\n",
    "plt.legend(frameon=False, bbox_to_anchor=(1.02, -0.05), loc='lower right', ncol=3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# squeeze=False keeps `axs` two-dimensional, so `axs[i]` is a row of axes even with a single family\n",
    "fig, axs = plt.subplots(len(families), 6, figsize=(18, 2.5*len(families)), constrained_layout=True, squeeze=False)\n",
    "\n",
    "for i, family_id in enumerate(families):\n",
    "    df_family = df[df[\"family_id\"] == family_id]\n",
    "    lengths_nat = lengths_natural[family_id]\n",
    "    hamming_ctx = all_hamming_ctx[family_id]\n",
    "    hmmer_ctx = all_hmmer_ctx[family_id]\n",
    "    structures_ctx = all_structures_ctx[family_id]\n",
    "\n",
    "    tmp_axs = axs[i]\n",
    "    # x-axis labels are only written on the bottom row\n",
    "    fig, tmp_axs = plot_hamming_hmmer(fig, tmp_axs, family_id, df_family, lengths_nat, hamming_ctx, hmmer_ctx, structures_ctx, last_row=(i == len(families)-1))\n",
    "plt.show()\n",
    "fig.savefig(f\"figures/generated_sequences/scatter_{name_data}.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx_len = df[\"n_seqs_ctx\"].unique()\n",
    "for i, family_id in enumerate(families):\n",
    "    # squeeze=False keeps `axs` two-dimensional even when there is a single context size\n",
    "    fig, axs = plt.subplots(len(ctx_len), 3, figsize=(15, 5*len(ctx_len)), sharex=\"col\", sharey=True, constrained_layout=True, squeeze=False)\n",
    "    # Explicit copy: the plotting helper may add a column, which must not touch `df`\n",
    "    df_family = df[df[\"family_id\"] == family_id].copy()\n",
    "    lengths_nat = lengths_natural[family_id]\n",
    "    hamming_ctx = all_hamming_ctx[family_id]\n",
    "    hmmer_ctx = all_hmmer_ctx[family_id]\n",
    "    fig, axs = compare_generation_parameters(fig, axs, family_id, df_family, lengths_nat, hamming_ctx, hmmer_ctx)\n",
    "    plt.show()\n",
    "    # One figure per family: release it so figures do not accumulate in memory\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FIM generated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squeeze=False keeps `axs` two-dimensional, so `axs[i]` is a row of axes even with a single family\n",
    "fig, axs = plt.subplots(len(families), 5, figsize=(25, 5*len(families)), constrained_layout=True, squeeze=False)\n",
    "\n",
    "for i, family_id in enumerate(families):\n",
    "    df_family = df[df[\"family_id\"] == family_id]\n",
    "\n",
    "    tmp_axs = axs[i]\n",
    "    fig, tmp_axs = plot_scores_fim(fig, tmp_axs, family_id, df_family)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# squeeze=False keeps `axs` two-dimensional, so `axs[i]` is a row of axes even with a single family\n",
    "fig, axs = plt.subplots(len(families), 5, figsize=(25, 5*len(families)), constrained_layout=True, squeeze=False)\n",
    "\n",
    "for i, family_id in enumerate(families):\n",
    "    df_family = df[df[\"family_id\"] == family_id]\n",
    "\n",
    "    tmp_axs = axs[i]\n",
    "    fig, tmp_axs = plot_scores_hist_fim(fig, tmp_axs, family_id, df_family)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx_l = df[\"n_seqs_ctx\"].unique()\n",
    "for i, family_id in enumerate(families):\n",
    "    # squeeze=False keeps `axs` two-dimensional even when there is a single context size\n",
    "    fig, axs = plt.subplots(len(ctx_l), 5, figsize=(25, 5*len(ctx_l)), sharex=\"col\", sharey=\"col\", constrained_layout=True, squeeze=False)\n",
    "    # Explicit copy: the plotting helper may add a column, which must not touch `df`\n",
    "    df_family = df[df[\"family_id\"] == family_id].copy()\n",
    "    fig, axs = compare_generation_parameters_fim(fig, axs, family_id, df_family)\n",
    "    plt.show()\n",
    "    # One figure per family: release it so figures do not accumulate in memory\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ProtMamba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
