{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fea02ca",
   "metadata": {},
   "outputs": [],
   "source": [
     "import json\n",
     "import sys\n",
     "\n",
     "import faiss\n",
     "import h5py\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "from tqdm import tqdm\n",
     "\n",
     "# Project-local evaluation helpers are imported from the igloo/ directory\n",
     "sys.path.append(\"igloo\")\n",
     "from evals.metrics import dihedral_distance\n",
     "from evals.align_loops import kabsch_numpy\n",
    "\n",
     "# Directory holding the precomputed per-model embedding files (.npy / .h5 / .jsonl / .parquet)\n",
     "EMBEDDING_DIR = \"benchmarking_data/paratope_binning/\"\n",
     "# CDR loop to benchmark: first character is the chain id (e.g. 'H'), second the CDR number (e.g. '3')\n",
     "LOOP_TYPE = \"H3\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb7d9fbf",
   "metadata": {},
   "source": [
    "### Save Loop Type Sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9c9644",
   "metadata": {},
   "outputs": [],
   "source": [
     "# One-time preprocessing, kept commented out for provenance:\n",
     "# collect the numeric ids of the LOOP_TYPE loops listed in data/test_loop_len_all_seed_42.jsonl\n",
     "# (loop_id appears to be '<id>_..._<LOOP_TYPE>' - confirm). The next cell uses these ids for the 'test' column.\n",
     "# select_loops = []\n",
     "# with open(f\"data/test_loop_len_all_seed_42.jsonl\", \"r\") as f:\n",
     "#     for line in f:\n",
     "#         item = json.loads(line)\n",
     "#         if item['loop_id'].endswith(LOOP_TYPE):\n",
     "#             select_loops.append(item)\n",
     "# select_loops_ids = [int(item['loop_id'].split(\"_\")[0]) for item in select_loops]\n",
     "# print(f\"Number of test {LOOP_TYPE} loops: {len(select_loops_ids)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66b4885d",
   "metadata": {},
   "outputs": [],
   "source": [
     "# One-time preprocessing, kept commented out for provenance: builds data/sabdab_{LOOP_TYPE}_loops.csv\n",
     "# (loop_id, start, end, sequence, loop_sequence, test), the loop table loaded further below.\n",
     "# Requires select_loops_ids from the previous cell.\n",
     "# # for the files with resolution 3.5 suffix, the sabdab_id is wrong since it was numbered after filtering out for resolution\n",
     "# # fix it by using ab_fname for id mapping\n",
     "# raw_df = pd.read_parquet(\"preprocessed_data/sabdab_2025-05-06-paired.parquet\")\n",
     "# raw_df['sabdab_id'] = range(len(raw_df))\n",
     "# ab_fname_to_id = {fname: sabdab_id for fname, sabdab_id in zip(raw_df['ab_fname'], raw_df['sabdab_id'])}\n",
     "\n",
     "# data_df = pd.read_parquet(\"preprocessed_data/sabdab_2025-05-06-paired_chains_resolution_3.5.parquet\")\n",
     "# data_df['sabdab_id'] = data_df['ab_fname'].map(ab_fname_to_id)\n",
     "# data_df = data_df[data_df['chain_id'] == LOOP_TYPE[0]]\n",
     "# data_df = data_df[['sabdab_id', f'CDR{LOOP_TYPE[1]}_start', f'CDR{LOOP_TYPE[1]}_end', 'sequence']]\n",
     "# data_df = data_df.rename(columns={f'CDR{LOOP_TYPE[1]}_start': 'start', f'CDR{LOOP_TYPE[1]}_end': 'end', 'sabdab_id': 'loop_id'})\n",
     "# data_df['loop_sequence'] = data_df.apply(lambda row: row['sequence'][row['start']:row['end']], axis=1)\n",
     "# data_df['test'] = data_df['loop_id'].isin(select_loops_ids)\n",
     "\n",
     "# data_df.to_csv(f\"data/sabdab_{LOOP_TYPE}_loops.csv\", index=False)\n",
     "# print(f\"Saved {LOOP_TYPE} loops ({len(data_df)}) to data/sabdab_{LOOP_TYPE}_loops.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1ba8759",
   "metadata": {},
   "source": [
    "### Add 3di information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb126d4",
   "metadata": {},
   "outputs": [],
   "source": [
     "# One-time preprocessing, kept commented out for provenance: attaches the per-chain 3di sequence\n",
     "# from sabdab/sabdab_db_ss.fasta to each loop and writes data/sabdab_{LOOP_TYPE}_loops_with_3di.csv.\n",
     "# Fasta entries are looked up by '<pdb name>_<chain>' first, falling back to '<pdb name>'.\n",
     "# fasta_3di = \"sabdab/sabdab_db_ss.fasta\"\n",
     "# seqeunces_3di = {}\n",
     "# with open(fasta_3di, \"r\") as f:\n",
     "#     for line in f:\n",
     "#         if line.startswith(\">\"):\n",
     "#             fname = line.strip().replace(\">\", \"\")\n",
     "#             seqeunces_3di[fname] = \"\"\n",
     "#         else:\n",
     "#             seqeunces_3di[fname] += line.strip()\n",
     "\n",
     "# raw_df = pd.read_parquet(\"preprocessed_data/sabdab_2025-05-06-paired.parquet\")\n",
     "# raw_df['sabdab_id'] = range(len(raw_df))\n",
     "# ab_fname_to_id = {fname: sabdab_id for fname, sabdab_id in zip(raw_df['ab_fname'], raw_df['sabdab_id'])}\n",
     "# id_to_ab_fname = {sabdab_id: fname for fname, sabdab_id in ab_fname_to_id.items()}\n",
     "\n",
     "# data_df = pd.read_csv(f\"data/sabdab_{LOOP_TYPE}_loops.csv\")\n",
     "# data_df['ab_fname'] = data_df['loop_id'].map(id_to_ab_fname)\n",
     "# data_df['ab_fname_chain'] = data_df['ab_fname'].str.replace(\".pdb\", f\"_{LOOP_TYPE[0]}\")\n",
     "# data_df.loc[~data_df['ab_fname_chain'].isin(seqeunces_3di), 'ab_fname_chain'] = data_df['ab_fname'].str.replace(\".pdb\", \"\")\n",
     "# assert data_df['ab_fname_chain'].isin(seqeunces_3di).all(), \"Some ab_fname_chain are not in the 3di sequences\"\n",
     "# data_df['3di_sequence'] = data_df['ab_fname_chain'].map(seqeunces_3di)\n",
     "# data_df.to_csv(f\"data/sabdab_{LOOP_TYPE}_loops_with_3di.csv\", index=False)\n",
     "# print(f\"Saved {LOOP_TYPE} loops with 3di sequences ({len(data_df)}) to data/sabdab_{LOOP_TYPE}_loops_with_3di.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5925f54",
   "metadata": {},
   "source": [
    "### Add angle information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccd388e7",
   "metadata": {},
   "outputs": [],
   "source": [
     "# One-time preprocessing, kept commented out for provenance: joins phi/psi/omega and the loop/stem\n",
     "# C-alpha coordinates onto the loop table and writes data/sabdab_{LOOP_TYPE}_loops_with_angles.parquet.\n",
     "# loop_df = pd.read_parquet(\"preprocessed_data/sabdab_2025-05-06-paired_loops.parquet\")\n",
     "# loop_df = loop_df[loop_df['loop_type'] == LOOP_TYPE]\n",
     "\n",
     "# data_df = pd.read_csv(f\"data/sabdab_{LOOP_TYPE}_loops.csv\")\n",
     "# data_df_with_angles = data_df.merge(loop_df[['sabdab_id', 'phi', 'psi', 'omega', 'c_alpha_atoms', 'stem_c_alpha_atoms']], left_on='loop_id', right_on='sabdab_id', how='inner')\n",
     "# data_df_with_angles.rename(columns={'c_alpha_atoms': 'loop_c_alpha_atoms'}, inplace=True)\n",
     "# data_df_with_angles[['loop_id', 'loop_sequence', 'phi', 'psi', 'omega', 'loop_c_alpha_atoms', 'stem_c_alpha_atoms']].to_parquet(f\"data/sabdab_{LOOP_TYPE}_loops_with_angles.parquet\", index=False)\n",
     "# print(f\"Saved {LOOP_TYPE} loops with angles ({len(data_df_with_angles)}) to data/sabdab_{LOOP_TYPE}_loops_with_angles.parquet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a597df1d",
   "metadata": {},
   "source": [
     "# Retrieval setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16c965fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "loop_df = pd.read_parquet(\"preprocessed_data/sabdab_2025-05-06-paired_loops.parquet\")\n",
    "loop_df = loop_df[loop_df['loop_type'] == LOOP_TYPE]\n",
    "\n",
    "loop_df['angles'] = loop_df.apply(lambda row: np.stack([row['phi'], row['psi'], row['omega']], axis=1), axis=1)\n",
    "loop_id_to_angles = {row['sabdab_id']: row['angles'] for _, row in loop_df.iterrows()}\n",
    "loop_id_to_calpha = {row['sabdab_id']: np.stack(row['c_alpha_atoms'].tolist() + row['c_atoms'].tolist() + row['n_atoms'].tolist()) for _, row in loop_df.iterrows()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "711b1a75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loop_id</th>\n",
       "      <th>start</th>\n",
       "      <th>end</th>\n",
       "      <th>sequence</th>\n",
       "      <th>loop_sequence</th>\n",
       "      <th>test</th>\n",
       "      <th>loop_len</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>96</td>\n",
       "      <td>105</td>\n",
       "      <td>EVQLQQPGPELVKPGASVKVSCKASGYSFTDHNMYWVKQSHGKSLE...</td>\n",
       "      <td>YIGSFYFVY</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>96</td>\n",
       "      <td>105</td>\n",
       "      <td>EVQLQQPGPELVKPGASVKVSCKASGYSFTDHNMYWVKQSHGKSLE...</td>\n",
       "      <td>YIGSFYFVY</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>96</td>\n",
       "      <td>107</td>\n",
       "      <td>QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...</td>\n",
       "      <td>ARDSGSGRFDP</td>\n",
       "      <td>False</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>96</td>\n",
       "      <td>107</td>\n",
       "      <td>QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...</td>\n",
       "      <td>ARDSGSGRFDP</td>\n",
       "      <td>False</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>96</td>\n",
       "      <td>107</td>\n",
       "      <td>QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...</td>\n",
       "      <td>ARDSGSGRFDP</td>\n",
       "      <td>False</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15282</th>\n",
       "      <td>18823</td>\n",
       "      <td>95</td>\n",
       "      <td>98</td>\n",
       "      <td>QVQLRESGPSLVKPSQTLSLTCTASGLSLSDKAVGWVRRAPTKALE...</td>\n",
       "      <td>ATV</td>\n",
       "      <td>False</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15283</th>\n",
       "      <td>18827</td>\n",
       "      <td>94</td>\n",
       "      <td>106</td>\n",
       "      <td>VQLVESGGGLVQPGGSLRLSCAASEFIVSANYMSWVRQAPGKGLEW...</td>\n",
       "      <td>ARFLPTYDYFDY</td>\n",
       "      <td>False</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15284</th>\n",
       "      <td>18829</td>\n",
       "      <td>96</td>\n",
       "      <td>107</td>\n",
       "      <td>QVQFQQSGAELVKPGASVKLSCKASGYTFTSYLMHWIKQRPGRGLE...</td>\n",
       "      <td>ARYAYCRPMDY</td>\n",
       "      <td>False</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15285</th>\n",
       "      <td>18830</td>\n",
       "      <td>96</td>\n",
       "      <td>107</td>\n",
       "      <td>QVQFQQSGAELVKPGASVKLSCKASGYTFTSYLMHWIKQRPGRGLE...</td>\n",
       "      <td>ARYAYCRPMDY</td>\n",
       "      <td>False</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15286</th>\n",
       "      <td>18831</td>\n",
       "      <td>96</td>\n",
       "      <td>119</td>\n",
       "      <td>EVQVVESGGGVVQPGRSLRLSCTASGFTFSNFAMGWVRQAPGKGLE...</td>\n",
       "      <td>AKDVGDYKSPIQDPRAMVGAFDL</td>\n",
       "      <td>False</td>\n",
       "      <td>23</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15287 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       loop_id  start  end                                           sequence  \\\n",
       "0            0     96  105  EVQLQQPGPELVKPGASVKVSCKASGYSFTDHNMYWVKQSHGKSLE...   \n",
       "1            1     96  105  EVQLQQPGPELVKPGASVKVSCKASGYSFTDHNMYWVKQSHGKSLE...   \n",
       "2            2     96  107  QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...   \n",
       "3            3     96  107  QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...   \n",
       "4            4     96  107  QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...   \n",
       "...        ...    ...  ...                                                ...   \n",
       "15282    18823     95   98  QVQLRESGPSLVKPSQTLSLTCTASGLSLSDKAVGWVRRAPTKALE...   \n",
       "15283    18827     94  106  VQLVESGGGLVQPGGSLRLSCAASEFIVSANYMSWVRQAPGKGLEW...   \n",
       "15284    18829     96  107  QVQFQQSGAELVKPGASVKLSCKASGYTFTSYLMHWIKQRPGRGLE...   \n",
       "15285    18830     96  107  QVQFQQSGAELVKPGASVKLSCKASGYTFTSYLMHWIKQRPGRGLE...   \n",
       "15286    18831     96  119  EVQVVESGGGVVQPGRSLRLSCTASGFTFSNFAMGWVRQAPGKGLE...   \n",
       "\n",
       "                 loop_sequence   test  loop_len  \n",
       "0                    YIGSFYFVY  False         9  \n",
       "1                    YIGSFYFVY  False         9  \n",
       "2                  ARDSGSGRFDP  False        11  \n",
       "3                  ARDSGSGRFDP  False        11  \n",
       "4                  ARDSGSGRFDP  False        11  \n",
       "...                        ...    ...       ...  \n",
       "15282                      ATV  False         3  \n",
       "15283             ARFLPTYDYFDY  False        12  \n",
       "15284              ARYAYCRPMDY  False        11  \n",
       "15285              ARYAYCRPMDY  False        11  \n",
       "15286  AKDVGDYKSPIQDPRAMVGAFDL  False        23  \n",
       "\n",
       "[15287 rows x 7 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_df = pd.read_csv(f\"data/sabdab_{LOOP_TYPE}_loops.csv\")\n",
    "data_df['loop_len'] = data_df['end'] - data_df['start']\n",
    "data_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "99610a2b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15281"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_loop_ids = []\n",
    "for _, row in data_df.iterrows():\n",
    "    if not row['loop_id'] in loop_id_to_angles:\n",
    "        continue\n",
    "    if row['loop_len'] == loop_id_to_angles[row['loop_id']].shape[0]:\n",
    "        valid_loop_ids.append(row['loop_id'])\n",
    "data_df['valid_loop'] = data_df['loop_id'].isin(valid_loop_ids)\n",
    "len(valid_loop_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8c81315d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "\n",
    "prostt5_embeddings = []\n",
    "with h5py.File(f\"{EMBEDDING_DIR}/prostt5_embeddings_{LOOP_TYPE}.h5\", \"r\") as f:\n",
    "    for loop_id in data_df['loop_id'].values:\n",
    "        if str(loop_id) not in f.keys():\n",
    "            print(f\"Loop ID {loop_id} not found in embeddings file.\")\n",
    "            prostt5_embeddings.append(np.zeros((1024,)))\n",
    "        else:\n",
    "            prostt5_embeddings.append(f[str(loop_id)][:])\n",
    "prostt5_embeddings = np.array(prostt5_embeddings)\n",
    "\n",
    "prostt5_3di_embeddings = []\n",
    "with h5py.File(f\"{EMBEDDING_DIR}/prostt5_3di_embeddings_{LOOP_TYPE}.h5\", \"r\") as f:\n",
    "    for loop_id in data_df['loop_id'].values:\n",
    "        if str(loop_id) not in f.keys():\n",
    "            print(f\"Loop ID {loop_id} not found in embeddings file.\")\n",
    "            prostt5_3di_embeddings.append(np.zeros((1024,)))\n",
    "        else:\n",
    "            prostt5_3di_embeddings.append(f[str(loop_id)][:])\n",
    "prostt5_3di_embeddings = np.array(prostt5_3di_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fa916dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "igloo_embeddings = np.load(f\"{EMBEDDING_DIR}/IgLoo_sabdab_{LOOP_TYPE}.npy\")\n",
    "ablang2_embeddings = np.load(f\"{EMBEDDING_DIR}/ablang2_embeddings_{LOOP_TYPE}.npy\")\n",
    "esm2_embeddings = np.load(f\"{EMBEDDING_DIR}/esm2_embeddings_{LOOP_TYPE}.npy\")\n",
    "emc_embeddings = np.load(f\"{EMBEDDING_DIR}/esmc_embeddings_{LOOP_TYPE}.npy\")\n",
    "igbert_embeddings = np.load(f\"{EMBEDDING_DIR}/igbert_embeddings_{LOOP_TYPE}.npy\")\n",
    "saprot_embeddings = np.load(f\"{EMBEDDING_DIR}/saprot_3di_embeddings_{LOOP_TYPE}.npy\")\n",
    "foldseek3di_embeddings = np.load(f\"{EMBEDDING_DIR}/foldseek3di_embeddings_{LOOP_TYPE}.npy\")\n",
    "aminoaseed_embeddings = np.load(f\"{EMBEDDING_DIR}/aminoaseed_embeddings_{LOOP_TYPE}.npy\")\n",
    "proteinmpnn_embeddings = np.load(f\"{EMBEDDING_DIR}/proteinmpnn_embeddings_{LOOP_TYPE}.npy\")\n",
    "mif_embeddings = np.load(f\"{EMBEDDING_DIR}/mif_embeddings_{LOOP_TYPE}.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf0e0d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Missing igloo angle embedding for loop_id: 15905\n",
      "Missing igloo angle embedding for loop_id: 17275\n",
      "Missing igloo angle embedding for loop_id: 15905\n",
      "Missing igloo angle embedding for loop_id: 17275\n",
      "Missing igloo angle embedding for loop_id: 15905\n",
      "Missing igloo angle embedding for loop_id: 17275\n",
      "Missing igloo angle embedding for loop_id: 15905\n",
      "Missing igloo angle embedding for loop_id: 17275\n",
      "Missing igloo angle embedding for loop_id: 15905\n",
      "Missing igloo angle embedding for loop_id: 17275\n",
      "Missing igloo angle embedding for loop_id: 15905\n",
      "Missing igloo angle embedding for loop_id: 17275\n"
     ]
    }
   ],
   "source": [
    "# load igloo with angle embeddings, some of the angles are missing so we have to remap them\n",
    "\n",
    "def get_igloo_angle_embeddings(fname):\n",
    "    if fname.endswith(\".jsonl\"):\n",
    "        igloo_angle_embeddings_raw = {}\n",
    "        with open(fname, \"r\") as f:\n",
    "            for line in f:\n",
    "                item = json.loads(line)\n",
    "                igloo_angle_embeddings_raw[item['id']] = item['encoded']\n",
    "    else:\n",
    "        igloo_embeddings_df = pd.read_parquet(fname)\n",
    "        igloo_angle_embeddings_raw = {row['loop_id']: row['encoded'] for _, row in igloo_embeddings_df.iterrows()}\n",
    "\n",
    "    igloo_angle_embeddings = []\n",
    "    for loop_id in data_df['loop_id']:\n",
    "        if loop_id in igloo_angle_embeddings_raw:\n",
    "            igloo_angle_embeddings.append(np.array(igloo_angle_embeddings_raw[loop_id]))\n",
    "        else:\n",
    "            igloo_angle_embeddings.append(np.zeros(128))\n",
    "            print(f\"Missing igloo angle embedding for loop_id: {loop_id}\")\n",
    "    igloo_angle_embeddings = np.stack(igloo_angle_embeddings)\n",
    "    return igloo_angle_embeddings\n",
    "\n",
    "igloo_angle_embeddings = get_igloo_angle_embeddings(f\"benchmarking_data/paratope_binning/IgLoo_sabdab_{LOOP_TYPE}_with_angles.jsonl\")\n",
    "igloo_no_dihedral_loss_embeddings = get_igloo_angle_embeddings(f\"benchmarking_data/paratope_binning/Igloo_ablation_no_dihedral_loss_sabdab_{LOOP_TYPE}.parquet\")\n",
    "igloo_no_sequence_embeddings = get_igloo_angle_embeddings(f\"benchmarking_data/paratope_binning/Igloo_ablation_no_sequence_sabdab_{LOOP_TYPE}.parquet\")\n",
    "igloo_no_dihedrals_embeddings = get_igloo_angle_embeddings(f\"benchmarking_data/paratope_binning/Igloo_ablation_no_dihedrals_sabdab_{LOOP_TYPE}.parquet\")\n",
    "igloo_tol1_embeddings = get_igloo_angle_embeddings(f\"benchmarking_data/paratope_binning/Igloo_ablation_tol1_v222_epoch30_sabdab_{LOOP_TYPE}.parquet\")\n",
    "igloo_no_dihedral_threshold_embeddings = get_igloo_angle_embeddings(f\"benchmarking_data/paratope_binning/Igloo_ablation_no_dihedral_threshold_sabdab_{LOOP_TYPE}.parquet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "be86af28",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = {\n",
    "    'igloo_angle': igloo_angle_embeddings,\n",
    "    'igloo_no_dihedral_loss_embeddings': igloo_no_dihedral_loss_embeddings,\n",
    "    'igloo_no_sequence_embeddings': igloo_no_sequence_embeddings,\n",
    "    'igloo_no_dihedrals_embeddings': igloo_no_dihedrals_embeddings,\n",
    "    'igloo_tol1_embeddings': igloo_tol1_embeddings,\n",
    "    'igloo_no_dihedral_threshold_embeddings': igloo_no_dihedral_threshold_embeddings,\n",
    "    'ablang2': ablang2_embeddings,\n",
    "    'esm2': esm2_embeddings,\n",
    "    'emc': emc_embeddings,\n",
    "    'igbert': igbert_embeddings,\n",
    "    'prostt5': prostt5_embeddings,\n",
    "    'prostt5_3di': prostt5_3di_embeddings,\n",
    "    'saprot_3di': saprot_embeddings,\n",
    "    'foldseek': foldseek3di_embeddings,\n",
    "    'aminoaseed': aminoaseed_embeddings,\n",
    "    'proteinmpnn': proteinmpnn_embeddings,\n",
    "    'mif': mif_embeddings,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27132287",
   "metadata": {},
   "source": [
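     "# Run retrieval\n",
     "\n",
     "For each loop length present in the test split, every valid test loop is used as a query against the valid non-test loops of the same length. A retrieved loop counts as relevant when its dihedral distance to the query is below 0.47. For every embedding, precision, recall and hit rate are computed over the top-k neighbours by cosine similarity, together with the backbone RMSD between each query and its retrieved loops."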
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6d441ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "\n",
    "for knn in [20]: # 1, 5, 10, \n",
    "    for LOOP_LEN in data_df[data_df['test']]['loop_len'].value_counts().index.tolist():\n",
    "        test_mask = (data_df['loop_len'] == LOOP_LEN) & (data_df['test']) & (data_df['valid_loop'])\n",
    "        dataset_mask = (data_df['loop_len'] == LOOP_LEN) & (~data_df['test']) & (data_df['valid_loop'])\n",
    "        original_indices = np.where(dataset_mask)[0]\n",
    "        test_indices = np.where(test_mask)[0]\n",
    "\n",
    "        if np.sum(original_indices) <= knn:\n",
    "            print(f\"Not enough loops for knn={knn} and loop_len={LOOP_LEN}. Skipping...\")\n",
    "            continue\n",
    "\n",
    "        test_angles = [loop_id_to_angles[data_df['loop_id'][test_indices[i]]] for i in range(len(test_indices))]\n",
    "        test_angles = np.stack(test_angles, axis=0)\n",
    "\n",
    "        dataset_angles = [loop_id_to_angles[data_df['loop_id'][original_indices[i]]] for i in range(len(original_indices))]\n",
    "        dataset_angles = np.stack(dataset_angles, axis=0)\n",
    "\n",
    "        D_all = dihedral_distance(test_angles, dataset_angles)\n",
    "        D_all_bin = D_all < 0.47\n",
    "\n",
    "        valid_values = (D_all_bin.sum(axis=1) != 0) # filter out cases where no loops with similar dihedrals are found \n",
    "        for embedding_name, embeddings_dataset in embeddings.items():\n",
    "            query_embeddings = embeddings_dataset[test_mask]\n",
    "            embeddings_dataset_ = embeddings_dataset[dataset_mask]\n",
    "            \n",
    "            # Use cosine similarity\n",
    "            query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)\n",
    "            embeddings_dataset_ = embeddings_dataset_ / np.linalg.norm(embeddings_dataset_, axis=1, keepdims=True)\n",
    "            index = faiss.IndexFlatIP(embeddings_dataset_.shape[1])  # Use inner product for cosine similarity\n",
    "            index.add(embeddings_dataset_)\n",
    "            _, retrieved_indices = index.search(query_embeddings, knn)\n",
    "\n",
    "            # Use L2\n",
    "            # index = faiss.IndexFlatL2(embeddings_dataset_.shape[1])\n",
    "            # index.add(embeddings_dataset_)\n",
    "            # _, retrieved_indices = index.search(query_embeddings, knn)\n",
    "\n",
    "            precision = D_all_bin[np.arange(D_all_bin.shape[0])[:, None], retrieved_indices][valid_values].mean()\n",
    "            recall = np.mean(D_all_bin[np.arange(D_all_bin.shape[0])[:, None], retrieved_indices].sum(axis=1)[valid_values] / D_all_bin.sum(axis=1)[valid_values])\n",
    "            hits = np.mean((D_all_bin[np.arange(D_all_bin.shape[0])[:, None], retrieved_indices].sum(axis=1) > 0)[valid_values])\n",
    "\n",
    "            # RMSD\n",
    "            rmsd_values = []\n",
    "            for i in range(len(test_indices)):\n",
    "                test_calpha = loop_id_to_calpha[data_df['loop_id'][test_indices[i]]]\n",
    "                centroid1 = np.mean(test_calpha, axis=0)\n",
    "                test_calpha_centered = test_calpha - centroid1\n",
    "\n",
    "                for j in range(len(retrieved_indices[i])):\n",
    "                    retrieved_calpha = loop_id_to_calpha[data_df['loop_id'][original_indices[retrieved_indices[i][j]]]]\n",
    "                    centroid2 = np.mean(retrieved_calpha, axis=0)\n",
    "                    retrieved_calpha_centered = retrieved_calpha - centroid2\n",
    "\n",
    "                    _, _, rmsd = kabsch_numpy(test_calpha_centered, retrieved_calpha_centered)\n",
    "                    rmsd_values.append(rmsd)\n",
    "\n",
    "            results.append({\n",
    "                'embedding': embedding_name,\n",
    "                'precision': precision,\n",
    "                'recall': recall,\n",
    "                'hits': hits,\n",
    "                'knn': knn,\n",
    "                'loop_len': LOOP_LEN,\n",
    "                'rmsd': np.mean(rmsd_values),\n",
    "                'rmsd_precision': np.mean(np.array(rmsd_values) < 1.0),  # consider RMSD < 2.0 as a hit\n",
    "            })\n",
    "results = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2f719c05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>hits</th>\n",
       "      <th>loop_len</th>\n",
       "      <th>rmsd</th>\n",
       "      <th>rmsd_precision</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>knn</th>\n",
       "      <th>embedding</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"17\" valign=\"top\">20</th>\n",
       "      <th>ablang2</th>\n",
       "      <td>0.222329</td>\n",
       "      <td>0.223650</td>\n",
       "      <td>0.697239</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.676202</td>\n",
       "      <td>0.173290</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>aminoaseed</th>\n",
       "      <td>0.379036</td>\n",
       "      <td>0.283963</td>\n",
       "      <td>0.844974</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.243870</td>\n",
       "      <td>0.292175</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>emc</th>\n",
       "      <td>0.208441</td>\n",
       "      <td>0.207271</td>\n",
       "      <td>0.677094</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.650888</td>\n",
       "      <td>0.189960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>esm2</th>\n",
       "      <td>0.236538</td>\n",
       "      <td>0.181721</td>\n",
       "      <td>0.645844</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.616186</td>\n",
       "      <td>0.206248</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>foldseek</th>\n",
       "      <td>0.361770</td>\n",
       "      <td>0.313618</td>\n",
       "      <td>0.857018</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.267806</td>\n",
       "      <td>0.281100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igbert</th>\n",
       "      <td>0.215860</td>\n",
       "      <td>0.218832</td>\n",
       "      <td>0.704361</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.639347</td>\n",
       "      <td>0.181773</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igloo_angle</th>\n",
       "      <td>0.401525</td>\n",
       "      <td>0.336070</td>\n",
       "      <td>0.883748</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.450542</td>\n",
       "      <td>0.277825</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igloo_no_dihedral_loss_embeddings</th>\n",
       "      <td>0.334637</td>\n",
       "      <td>0.208014</td>\n",
       "      <td>0.788578</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.558965</td>\n",
       "      <td>0.241513</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igloo_no_dihedral_threshold_embeddings</th>\n",
       "      <td>0.416734</td>\n",
       "      <td>0.294020</td>\n",
       "      <td>0.856581</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.463137</td>\n",
       "      <td>0.279402</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igloo_no_dihedrals_embeddings</th>\n",
       "      <td>0.216782</td>\n",
       "      <td>0.194406</td>\n",
       "      <td>0.665094</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.787291</td>\n",
       "      <td>0.193400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igloo_no_sequence_embeddings</th>\n",
       "      <td>0.356375</td>\n",
       "      <td>0.283101</td>\n",
       "      <td>0.851591</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.618416</td>\n",
       "      <td>0.245031</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>igloo_tol1_embeddings</th>\n",
       "      <td>0.408048</td>\n",
       "      <td>0.306059</td>\n",
       "      <td>0.852995</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.441876</td>\n",
       "      <td>0.280181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mif</th>\n",
       "      <td>0.297578</td>\n",
       "      <td>0.282813</td>\n",
       "      <td>0.795659</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.593146</td>\n",
       "      <td>0.231096</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>prostt5</th>\n",
       "      <td>0.232774</td>\n",
       "      <td>0.253838</td>\n",
       "      <td>0.734838</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.656494</td>\n",
       "      <td>0.200443</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>prostt5_3di</th>\n",
       "      <td>0.359122</td>\n",
       "      <td>0.279899</td>\n",
       "      <td>0.846808</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.293359</td>\n",
       "      <td>0.275989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>proteinmpnn</th>\n",
       "      <td>0.371910</td>\n",
       "      <td>0.277319</td>\n",
       "      <td>0.851433</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.343534</td>\n",
       "      <td>0.285716</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>saprot_3di</th>\n",
       "      <td>0.247545</td>\n",
       "      <td>0.252519</td>\n",
       "      <td>0.790519</td>\n",
       "      <td>15.5</td>\n",
       "      <td>2.603662</td>\n",
       "      <td>0.218167</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            precision    recall      hits  \\\n",
       "knn embedding                                                               \n",
       "20  ablang2                                  0.222329  0.223650  0.697239   \n",
       "    aminoaseed                               0.379036  0.283963  0.844974   \n",
       "    emc                                      0.208441  0.207271  0.677094   \n",
       "    esm2                                     0.236538  0.181721  0.645844   \n",
       "    foldseek                                 0.361770  0.313618  0.857018   \n",
       "    igbert                                   0.215860  0.218832  0.704361   \n",
       "    igloo_angle                              0.401525  0.336070  0.883748   \n",
       "    igloo_no_dihedral_loss_embeddings        0.334637  0.208014  0.788578   \n",
       "    igloo_no_dihedral_threshold_embeddings   0.416734  0.294020  0.856581   \n",
       "    igloo_no_dihedrals_embeddings            0.216782  0.194406  0.665094   \n",
       "    igloo_no_sequence_embeddings             0.356375  0.283101  0.851591   \n",
       "    igloo_tol1_embeddings                    0.408048  0.306059  0.852995   \n",
       "    mif                                      0.297578  0.282813  0.795659   \n",
       "    prostt5                                  0.232774  0.253838  0.734838   \n",
       "    prostt5_3di                              0.359122  0.279899  0.846808   \n",
       "    proteinmpnn                              0.371910  0.277319  0.851433   \n",
       "    saprot_3di                               0.247545  0.252519  0.790519   \n",
       "\n",
       "                                            loop_len      rmsd  rmsd_precision  \n",
       "knn embedding                                                                   \n",
       "20  ablang2                                     15.5  2.676202        0.173290  \n",
       "    aminoaseed                                  15.5  2.243870        0.292175  \n",
       "    emc                                         15.5  2.650888        0.189960  \n",
       "    esm2                                        15.5  2.616186        0.206248  \n",
       "    foldseek                                    15.5  2.267806        0.281100  \n",
       "    igbert                                      15.5  2.639347        0.181773  \n",
       "    igloo_angle                                 15.5  2.450542        0.277825  \n",
       "    igloo_no_dihedral_loss_embeddings           15.5  2.558965        0.241513  \n",
       "    igloo_no_dihedral_threshold_embeddings      15.5  2.463137        0.279402  \n",
       "    igloo_no_dihedrals_embeddings               15.5  2.787291        0.193400  \n",
       "    igloo_no_sequence_embeddings                15.5  2.618416        0.245031  \n",
       "    igloo_tol1_embeddings                       15.5  2.441876        0.280181  \n",
       "    mif                                         15.5  2.593146        0.231096  \n",
       "    prostt5                                     15.5  2.656494        0.200443  \n",
       "    prostt5_3di                                 15.5  2.293359        0.275989  \n",
       "    proteinmpnn                                 15.5  2.343534        0.285716  \n",
       "    saprot_3di                                  15.5  2.603662        0.218167  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results.groupby(['knn', 'embedding']).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a15bf400",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
