{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08061c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import os\n",
    "from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, silhouette_score\n",
    "from collections import Counter\n",
    "import numpy as np\n",
    "\n",
    "import sys\n",
    "igloo_path = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "print(f\"Adding {igloo_path} to sys.path\")\n",
    "sys.path.append(igloo_path)\n",
    "from model.vqvae import VQVAE\n",
    "from dataset import LoopSequenceDataset\n",
    "from dataset import Alphabet, proteinseq_toks\n",
    "from evals.metrics import eval_clusters, dihedral_distance_pairwise, eval_clusters_length_independent#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "648002b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-loop table and build loop_id -> cluster lookups for three cluster-assignment columns.\n",
    "loop_df = pd.read_parquet(\"preprocessed_data/sabdab_2025-05-06-paired_loops_with_sequence_id.parquet\")\n",
    "\n",
    "# Key each loop as \"<sabdab_id>_<loop_type>\". Zipping the two columns formats the same values\n",
    "# as the previous row-wise DataFrame.apply(axis=1), without constructing a Series per row.\n",
    "loop_df['loop_id'] = [\n",
    "    f\"{sabdab_id}_{loop_type}\"\n",
    "    for sabdab_id, loop_type in zip(loop_df['sabdab_id'], loop_df['loop_type'])\n",
    "]\n",
    "\n",
    "# Index only the needed columns by loop_id once instead of re-indexing the full frame per lookup.\n",
    "# NOTE(review): if loop_id is not unique, to_dict() keeps only the last row per id -- confirm uniqueness.\n",
    "cluster_columns = ['assigned_cluster', 'assigned_cluster_D=0.1', 'assigned_cluster_D=0.61']\n",
    "clusters_by_loop_id = loop_df[['loop_id'] + cluster_columns].set_index('loop_id')\n",
    "\n",
    "loop_to_canonical = clusters_by_loop_id['assigned_cluster'].to_dict()\n",
    "loop_to_canonical_strict = clusters_by_loop_id['assigned_cluster_D=0.1'].to_dict()\n",
    "loop_to_canonical_ssc_comparison = clusters_by_loop_id['assigned_cluster_D=0.61'].to_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42cda11b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Codebook size: 8192\n",
      "Number of parameters: 1932703\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "VQVAE(\n",
       "  (encoder): LoopTransformer(\n",
       "    (embed_tokens): Embedding(25, 128, padding_idx=3)\n",
       "    (dihedral_projection): Linear(in_features=6, out_features=128, bias=True)\n",
       "    (layers): ModuleList(\n",
       "      (0-3): 4 x TransformerLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (k_proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "          (v_proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "          (q_proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "          (out_proj): Linear(in_features=128, out_features=128, bias=True)\n",
       "          (rot_emb): RotaryEmbedding()\n",
       "        )\n",
       "        (self_attn_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "        (fc1): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (fc2): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (final_layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (emb_layer_norm_after): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "    (lm_head): RobertaLMHead(\n",
       "      (dense): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (dihedral_decoder): Sequential(\n",
       "      (0): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (1): ReLU()\n",
       "      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (3): Linear(in_features=128, out_features=6, bias=True)\n",
       "    )\n",
       "    (context_encoder): Linear(in_features=408, out_features=128, bias=True)\n",
       "  )\n",
       "  (quantizer): VectorQuantize(\n",
       "    (project_in): Identity()\n",
       "    (project_out): Identity()\n",
       "    (_codebook): EuclideanCodebook()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pretrained config and weights, given as paths relative to the notebook's working directory.\n",
    "model_config = \"../../checkpoints/igloo_config.json\"\n",
    "model_ckpt = \"../../checkpoints/igloo_weights.pt\"\n",
    "\n",
    "model = VQVAE.load_from_config_and_weights(\n",
    "    model_config,\n",
    "    model_ckpt,\n",
    "    strict=False,  # NOTE(review): presumably tolerates missing/unexpected checkpoint keys -- confirm in VQVAE\n",
    ")\n",
    "\n",
    "print(\"Codebook size:\", model.codebook_size)\n",
    "# Only parameters that require gradients are counted.\n",
    "print(\n",
    "    \"Number of parameters:\",\n",
    "    sum(param.numel() for param in model.parameters() if param.requires_grad),\n",
    ")\n",
    "\n",
    "# Switch to evaluation mode; as the cell's last expression this also displays the module repr.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb2a53b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset size: test=7442, train=108167, val=7090\n",
      "Number of unique sequences: test=2041, train=17820, val=1899\n",
      "Test, val overlap: 0 sequences\n",
      "Test, train overlap: 0 sequences\n",
      "Train, val overlap: 0 sequences\n"
     ]
    }
   ],
   "source": [
    "# Build the test/train/val datasets and loaders, then check that loop sequences do not leak across splits.\n",
    "USE_CONTEXT = False\n",
    "USE_H3 = True\n",
    "MAX_LENGTH = 36  # max_length passed to every LoopSequenceDataset\n",
    "BATCH_SIZE = 32  # batch size shared by all three DataLoaders\n",
    "\n",
    "dataset_path_suffix = \".jsonl\" if USE_H3 else \"_no_H3.jsonl\"\n",
    "context_path = \"preprocessed_data/sabdab_2025-05-06-paired_chains_lobster_24M_representations.parquet\"\n",
    "\n",
    "\n",
    "def _load_split(split):\n",
    "    \"\"\"Build the dataset and its unshuffled DataLoader for one split.\n",
    "\n",
    "    Args:\n",
    "        split: Split name used in the data file name (\"test\", \"train\" or \"val\").\n",
    "\n",
    "    Returns:\n",
    "        A (dataset, dataloader) tuple; the loader uses the dataset's own collate_fn.\n",
    "    \"\"\"\n",
    "    split_dataset = LoopSequenceDataset(\n",
    "        f\"data/{split}_loop_len_all_seed_42{dataset_path_suffix}\",\n",
    "        max_length=MAX_LENGTH,\n",
    "        # The context file is only handed to the dataset when USE_CONTEXT is enabled.\n",
    "        context_path=context_path if USE_CONTEXT else None,\n",
    "    )\n",
    "    split_dataloader = DataLoader(\n",
    "        split_dataset,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        shuffle=False,  # keep dataset order so per-batch outputs line up with dataset entries\n",
    "        collate_fn=split_dataset.collate_fn,\n",
    "    )\n",
    "    return split_dataset, split_dataloader\n",
    "\n",
    "\n",
    "test_dataset, test_dataloader = _load_split(\"test\")\n",
    "train_dataset, train_dataloader = _load_split(\"train\")\n",
    "val_dataset, val_dataloader = _load_split(\"val\")\n",
    "print(f\"Dataset size: test={len(test_dataset)}, train={len(train_dataset)}, val={len(val_dataset)}\")\n",
    "\n",
    "alphabet = Alphabet(standard_toks=proteinseq_toks)\n",
    "\n",
    "# Unique loop sequences per split; the intersections below report how many are shared between splits.\n",
    "test_sequences = {x['loop_sequence'] for x in test_dataset.data}\n",
    "train_sequences = {x['loop_sequence'] for x in train_dataset.data}\n",
    "val_sequences = {x['loop_sequence'] for x in val_dataset.data}\n",
    "print(f\"Number of unique sequences: test={len(test_sequences)}, train={len(train_sequences)}, val={len(val_sequences)}\")\n",
    "print(f\"Test, val overlap: {len(test_sequences.intersection(val_sequences))} sequences\")\n",
    "print(f\"Test, train overlap: {len(test_sequences.intersection(train_sequences))} sequences\")\n",
    "print(f\"Train, val overlap: {len(train_sequences.intersection(val_sequences))} sequences\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "73072643",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 233/233 [00:17<00:00, 13.38it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the model over the test split and gather output.quantized_indices from every batch.\n",
    "dataset = test_dataset\n",
    "dataloader = test_dataloader\n",
    "# NOTE(review): presumably switches the dataset into an inference mode -- confirm in LoopSequenceDataset\n",
    "dataset.inference = True\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "all_quantized_indices = []\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(dataloader, total=len(dataloader)):\n",
    "        # Move tensor entries onto the model's device; the 'id' entry is left untouched.\n",
    "        for key in batch:\n",
    "            if isinstance(batch[key], torch.Tensor) and key != 'id':\n",
    "                batch[key] = batch[key].to(device)\n",
    "\n",
    "        output = model(batch, val=True)\n",
    "        # Keep the indices on the CPU so results do not accumulate on the accelerator.\n",
    "        all_quantized_indices.append(output.quantized_indices.detach().cpu())\n",
    "\n",
    "# Concatenate the per-batch index tensors along dim 0 into a single tensor.\n",
    "all_quantized_indices = torch.cat(all_quantized_indices, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6a62656b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 4/233 [00:00<00:29,  7.87it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 233/233 [00:21<00:00, 10.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Masked AA Recovery: 0.5307\n",
      "Masked AA Recovery for loop type L1: 0.6749\n",
      "Masked AA Recovery for loop type L2: 0.5632\n",
      "Masked AA Recovery for loop type L3: 0.4989\n",
      "Masked AA Recovery for loop type L4: 0.6187\n",
      "Masked AA Recovery for loop type H1: 0.6249\n",
      "Masked AA Recovery for loop type H2: 0.5025\n",
      "Masked AA Recovery for loop type H3: 0.4059\n",
      "Masked AA Recovery for loop type H4: 0.5176\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAASntJREFUeJzt3XlcFWX///H3AeQAoriguCEImrtiGG6pmZRllraYW6Gk5q1ZGnWXlGmahpZbd3pr7pWaWHbb4p1lFC1qmkumd6ZfTcUNxVRwKVC4fn/08+QJUA6CB6fX8/GYx8O55pqZz5wDnLcz18yxGWOMAAAALMLD3QUAAAAUJcINAACwFMINAACwFMINAACwFMINAACwFMINAACwFMINAACwFMINAACwFMINAACwFMINkI/k5GTZbDa99957xbqf0NBQ9evXr1j3AffZt2+fbDabJk2a5O5SgL8Nwg1KpIULF8pms8lms+nbb7/NtdwYo+DgYNlsNnXp0sUNFbrPqVOn5OPjI5vNph07dlyx/44dO2Sz2eTj46NTp04VeD8vvvii4z2w2WwqVaqUQkND9cQTT7i0HVwbF39nNm7c6O5SCuTSn63LTcnJye4uFdchL3cXAFyOj4+PlixZoptvvtmp/auvvtLBgwdlt9vdVJn7vPvuu7LZbKpSpYoWL16scePGXbb/okWLVKVKFZ08eVLvvfeeBgwY4NL+Zs6cKX9/f509e1ZJSUl6/fXXtXnz5jxDJ1BQb7/9ttP8W2+9pdWrV+dqr1+//rUsCxZBuEGJ1rlzZ7377rv617/+JS+vP39clyxZosjISB0/ftyN1bnHokWL1LlzZ4WEhGjJkiWXDTfGGC1ZskS9e/fW3r17tXjxYpfDzQMPPKDAwEBJ0qBBg9SzZ08lJiZqw4YNioqKuqpjKekuXLignJwceXt7u7sUy3nooYec5r/77jutXr06VztQGFyWQonWq1cv/frrr1q9erWjLSsrS++995569+6d5zqTJk1S69atVbFiRfn6+ioyMjLPcTOrV6/WzTffrHLlysnf319169bVc889d9l6MjMz1aVLFwUEBGjt2rWSpJycHE2bNk0NGzaUj4+PgoKCNGjQIJ08edJpXWOMxo0bpxo1asjPz08dOnTQ//73P5dej5SUFH3zzTfq2bOnevbsqb179zrqyMuaNWu0b98+R/+vv/5aBw8edGmff9W2bVtJ0p49e5za169frzvuuEMBAQHy8/NT+/bttWbNmlzrHzp0SP3791e1atVkt9tVq1YtDR48WFlZWY4+v/zyi7p3764KFSrIz89PLVu21MqVKx3Ljx49Ki8vL40ZMybX9nfu3Cmbzabp06c72k6dOqXhw4crODhYdrtdtWvX1sSJE5WTk+Poc+nYmGnTpik8PFx2u10bNmxQ6dKlNWzYsFz7OnjwoDw9PZWQkFCg127q1KkKCQmRr6+v2rdvr+3btzuWLViwQDabTVu2bMm13ssvvyxPT08dOnSoQPu5nC1btujOO+9U2bJl5e/vr44dO+q7777L1e9K74H057i0xMREPffcc6pSpYpKly6te+65RwcOHLiqOvv27avAwECdP38+17Lbb79ddevWdczbbDYNHTpUixcvVt26deXj46PIyEh9/fXXudY9dOiQHnnkEQUFBclut6thw4aaP3/+VdWKEsgAJdCCBQuMJPP999+b1q1bm4cfftixbMWKFcbDw8McOnTIhISEmLvuustp3Ro1apghQ4aY6dOnmylTppioqCgjyXz88ceOPtu3bzfe3t6mefPm5rXXXjOzZs0yTz/9tGnXrp2jz5dffmkkmXfffdcYY8y5c+fMbbfdZsqXL282bNjg6DdgwADj5eVlBg4caGbNmmWeffZZU7p0aXPTTTeZrKwsR7+RI0caSaZz585m+vTp5pFHHjHVqlUzgYGBpm/fvgV6XSZMmGD8/f3NuXPnjDHGhIeH
myFDhuTb/x//+IcJDw931O/v729eeeWVAu1r9OjRRpJJS0tzan/66aeNJPPJJ5842pKSkoy3t7dp1aqVmTx5spk6dapp0qSJ8fb2NuvXr3f0O3TokKlWrZrx8/Mzw4cPN7NmzTIvvPCCqV+/vjl58qQxxpjU1FQTFBRkypQpY55//nkzZcoU07RpU+Ph4WHef/99x7ZuvfVW06BBg1x1jxkzxnh6eprU1FRjjDFnz541TZo0MRUrVjTPPfecmTVrlomJiTE2m80MGzbMsd7evXuNJNOgQQMTFhZmJkyYYKZOnWr2799v+vTpY4KCgsyFCxec9vXKK68Ym81m9u/fn+/reHG7jRs3NqGhoWbixIlmzJgxpkKFCqZSpUqOOjMyMoyvr6956qmncm2jQYMG5tZbb813H8Y4/87kZ/v27aZ06dKmatWq5qWXXjITJkwwtWrVMna73Xz33XeOfgV9Dy7+jjRu3Ng0adLETJkyxYwYMcL4+PiYG264wfFzWhCPPfaYufQjafXq1UaS+eijj5z6HTlyxHh6epqxY8c62iSZRo0amcDAQDN27FgzceJEExISYnx9fc22bducjqtGjRomODjYjB071sycOdPcc889RpKZOnVqgWtFyUe4QYl06R/q6dOnmzJlyjj+UHbv3t106NDBGGPyDDd//YOalZVlGjVq5PThMHXq1Dw/uC91abg5ffq0ad++vQkMDDRbtmxx9Pnmm2+MJLN48WKndVetWuXUfuzYMePt7W3uuusuk5OT4+j33HPPGUkFDjeNGzc2ffr0cVo/MDDQnD9/PlffrKwsU7FiRfP888872nr37m2aNm1aoH1dDDc7d+40aWlpZt++fWb+/PnG19fXVKpUyZw9e9YYY0xOTo6pU6eO6dSpk9OxnTt3ztSqVcvcdtttjraYmBjj4eGR5wfwxXWHDx9uJJlvvvnGsez06dOmVq1aJjQ01GRnZxtjjHnjjTeMJKcPL2NyB4GXXnrJlC5d2uzatcup34gRI4ynp6dJSUkxxvwZQsqWLWuOHTvm1PfTTz/NFeiMMaZJkyamffv2l30dL27X19fXHDx40NG+fv16I8k8+eSTjrZevXqZatWqOY7RGGM2b95sJJkFCxZcdj8FCTfdunUz3t7eZs+ePY62w4cPmzJlyjgF+4K+Bxd/R6pXr24yMjIcfZctW2Ykmddee+2yNV/qr+EmOzvb1KhRw/To0cOp35QpU4zNZjO//PKLo02SkWQ2btzoaNu/f7/x8fEx9957r6Otf//+pmrVqub48eNO2+zZs6cJCAhwKYyhZCPcoES69A/1sWPHjJeXl1m2bJnjf7dz5swxxuQdbi514sQJk5aWZgYPHmzKlSuXa/tz5851+iC51MU/3HPnzjWtWrUyQUFBZvv27U59nnjiCRMQEGCOHTtm0tLSnCZ/f38zYMAAY4wxS5YsMZLMqlWrnNY/duxYgcPN1q1bc52B2rZtW662iz744AMjyanmjz76KFdbfi6Gm79OjRs3dvoQufjh++abb+Z6DQYMGGDsdrvJzs422dnZpmzZsqZr166X3e8NN9xgoqKicrUnJCQ4hZm0tDTj5eVlRo4cmev1eOONNxxtTZo0MXfccUeu2j7//HMjySxatMgY82cIiY2NzbXv7OxsU61aNfPQQw/l2tfFn8X8XNxur169ci1r0aKFqVu3rmP+k08+MZLM559/7mh76qmnjK+vr1N4yMuVws2FCxeMn5+fefDBB3MtGzRokPHw8DDp6enGmIK/Bxd/R+Lj45365eTkmKpVq5pOnTpdtuZL/TXcGGPMs88+m+vYIyMjTZs2bZz6STKtWrXKtc0ePXoYPz8/c+HCBZOTk2PKlStnHn300Vw/Cxdfu2+//bbA9aJkY8wNSrxKlSopOjpaS5Ys0fvvv6/s7Gw98MAD+fb/+OOP1bJlS/n4+KhChQqqVKmSZs6cqfT0dEefHj16qE2bNhowYICCgoLUs2dPLVu2
zGkMxkXDhw/X999/r88//1wNGzZ0WvZ///d/Sk9PV+XKlVWpUiWn6cyZMzp27Jgkaf/+/ZKkOnXq5Dq28uXLF+h1WLRokUqXLq2wsDDt3r1bu3fvlo+Pj0JDQ7V48eI8+9eqVUt2u93RPzw8XH5+fnn2z8/y5cu1evVqLVmyRC1bttSxY8fk6+vr9BpIf4yR+OtrMHfuXGVmZio9PV1paWnKyMhQo0aNLru//fv3O42nuOjiXTMXX8vAwEB17NhRy5Ytc/RJTEyUl5eX7rvvPqf6Vq1alau26OhoSXK8RxfVqlUr1749PDzUp08frVixQufOnZMkLV68WD4+Purevftlj+eiv773knTDDTdo3759jvnbbrtNVatWdbw/OTk5euedd9S1a1eVKVOmQPvJT1pams6dO5fva5uTk+MYJ1PQ9+Civx6bzWZT7dq1nY6tMGJiYvTbb7/pP//5j6Q/xlNt2rRJDz/8cK6++b2+586dU1pamtLS0nTq1CnNnj07189CbGyspNw/C7h+cbcUrgu9e/fWwIEDlZqaqjvvvFPlypXLs98333yje+65R+3atdO///1vVa1aVaVKldKCBQu0ZMkSRz9fX199/fXX+vLLL7Vy5UqtWrVKiYmJuvXWW/XZZ5/J09PT0bdr165aunSpJkyYoLfeekseHn/+nyAnJ0eVK1fONyxUqlSpSI7fGKN33nlHZ8+eVYMGDXItP3bsmM6cOSN/f39JUkZGhj766CP9/vvvef7RX7JkicaPHy+bzXbFfbdr185xt9Tdd9+txo0bq0+fPtq0aZM8PDwcgfDVV19VREREntvw9/fXiRMnCnq4BdazZ0/Fxsbqhx9+UEREhJYtW6aOHTs66pX+eI9uu+02PfPMM3lu44YbbnCavzS4XSomJkavvvqqVqxYoV69emnJkiWOweVFxdPTU71799acOXP073//W2vWrNHhw4f/tncQNWjQQJGRkVq0aJFiYmK0aNEieXt768EHH3R5Wxd/Th966CH17ds3zz5NmjS5qnpRchBucF249957NWjQIH333XdKTEzMt9/y5cvl4+OjTz/91OkZOAsWLMjV18PDQx07dlTHjh01ZcoUvfzyy3r++ef15ZdfOv5XL0ndunXT7bffrn79+qlMmTKaOXOmY1l4eLg+//xztWnTJt8PRUkKCQmR9MdZhLCwMEd7Wlparruq8nLxuT5jx47N9dyPkydP6tFHH9WKFSscH4Lvv/++fv/9d82cOdPpg17643+/I0eO1Jo1a3I9P+hK/P39NXr0aMXGxmrZsmXq2bOnwsPDJUlly5Z1et3+qlKlSipbtqzTHUJ5CQkJ0c6dO3O1//zzz47lF3Xr1k2DBg1y/Ezs2rVL8fHxTuuFh4frzJkzl62tIBo1aqRmzZpp8eLFqlGjhlJSUvT6668XeP2LZ7gutWvXLoWGhjq1xcTEaPLkyfroo4/0ySefqFKlSurUqdNV1S798fr7+fnl+9p6eHgoODhYkmvvgZT72Iwx2r17d5GEhZiYGMXFxenIkSNasmSJ7rrrrjzPdub3+vr5+Tn+k1GmTBllZ2df9c8CrgPuvi4G5CWv8QMLFy40L774otOgv7+OuYmLizN+fn6Owa7G/DHmwc/Pz+l6/q+//pprnytXrnQav/LXu6Vef/11I8k888wzjnWSk5PzHHNgjDHnz5933AF07NgxU6pUqUIPKO7fv78pXbq0+e233/JcXqdOHXPHHXc45jt27GjCwsLy7Pv7778bf39/849//OOy+8zvbqmsrCxTo0YNExERYYz5YzxKeHi4qVOnjjl9+nSu7Vw6ONeVAcVr1651LDtz5owJCwtzGsx60d13323CwsLMs88+a7y9vR2v+UUvvvhinuOdjDHm5MmTjsHYF8fGvPrqq/m+JlOmTDFeXl7m3nvvNRUrVnS6Gy4/VxpQPHz48FzrNGnSxNx+++2mbNmy5vHHH7/iPowp+IBiu91u9u7d62hL
TU01ZcuWzXNA8ZXegysNKJ42bVqBajcm7zE3xhjHmLvu3bsbSWb58uW5+uj/jwfbtGmToy0lJcX4+PiYbt26Odr69etnvL29cw1Cv7gfWAfhBiVSQf5QG5M73CQlJRlJpm3btmbmzJlmzJgxpnLlyqZJkyZOfziHDRtmmjVrZkaOHGnmzJljxo8fb6pXr25q1KhhTp06ZYzJHW6MMWb8+PFGkhk/fryjbdCgQUaSufPOO83UqVPN9OnTzbBhw0y1atWc1o2Pj3e6Fbx///4FuhX8999/N+XKlXP6I/1XTz31lPHy8jJHjx41hw4dMh4eHnl+aF50//33X/HDOb9wY4wxr776qtPdQ19++aXx8fExNWvWNKNHjzazZ882o0ePNu3atTNdunRxrHfw4EFTpUoVx63gb7zxhnnxxRdNw4YNc90KHhAQYF544QUzdepUExERYWw2m9NtyBctWrTISDJlypQxd999d67lZ8+eNTfeeKPx8vIyAwYMMDNnzjSTJk0yffv2NaVLl3YcX0HCTWpqqvHy8jKSzODBg/Ptd6m8bgUfO3asqVChgqlYsaI5fPhwrnUmTZrk+MC+9Fb6y7n4OzN48GDz0ksv5ZoyMjIct4JXr17djB8/3kycONGEhYXleyv4ld6Dv94KPnXqVMet4LVr13b6T8aV5BdujDGmS5cuRpIpV66c+f3333MtVz63gvv4+JitW7c6HVdISIjx8/Mzw4YNM2+88YZJSEgw3bt3N+XLly9wrSj5CDcokQobbowxZt68eaZOnTrGbrebevXqmQULFjg+qC9KSkoyXbt2NdWqVTPe3t6mWrVqplevXk63C+cVbowx5plnnjGSzPTp0x1ts2fPNpGRkcbX19eUKVPGNG7c2DzzzDNOH1zZ2dlmzJgxpmrVqsbX19fccsstZvv27SYkJOSy4Wb58uVGkpk3b16+fS6eQXrttdfM5MmTjSSTlJSUb/+FCxcaSeaDDz7It8/lwk16eroJCAhwug16y5Yt5r777jMVK1Y0drvdhISEmAcffDBXHfv37zcxMTGmUqVKxm63m7CwMPPYY4+ZzMxMR589e/aYBx54wJQrV874+PiYqKioPO8IM+bP58Pokjuf/ur06dMmPj7e1K5d23h7e5vAwEDTunVrM2nSJEfAK0i4McaYzp075zqrcTmXbnfy5MkmODjY2O1207ZtW6cP3ktdfJbLDTfcUKB9GPPn70x+04EDB4wxf9zd1qlTJ+Pv72/8/PxMhw4d8jyWgrwHF39H3nnnHRMfH28qV65sfH19zV133XXZZ//k5XLh5uKZoEcffTTP5ZLMY489ZhYtWuT43W/WrJn58ssvc/U9evSoeeyxx0xwcLApVaqUqVKliunYsaOZPXu2S/WiZLMZY8xVX9sCgL+Je++9V9u2bdPu3buLbR/Hjx9X1apVNWrUKL3wwgvFtp+rlZycrA4dOujdd9+97B2MV+uDDz5Qt27d9PXXXzuekH0pm82mxx57zOmp1Ph741ZwACigI0eOaOXKlXneilyUFi5cqOzs7GLfz/Vizpw5CgsLc3kAPP6+uFsKAK5g7969WrNmjebOnatSpUpp0KBBxbKfL774Qj/99JPGjx+vbt265bqT6u9m6dKl+vHHH7Vy5Uq99tprBXp0ASARbgDgir766ivFxsaqZs2aevPNN1WlSpVi2c/YsWO1du1atWnTxqXbzK2qV69e8vf3V//+/TVkyBB3l4PrCGNuAACApTDmBgAAWArhBgAAWMrfbsxNTk6ODh8+rDJlyjA4DQCA64QxRqdPn1a1atWcvuMvv85uN336dBMSEmLsdruJioq67BM527dvn+cDqjp37lygfR04cOCyD7piYmJiYmJiKrnTxQdSXo7bz9wkJiYqLi5Os2bNUosWLTRt2jR16tRJO3fuVOXKlXP1f//995WVleWY//XXX9W0aVN17969QPsrU6aMJOnAgQMqW7Zs0RwEAAAoVhkZGQoO
DnZ8jl+O2++WatGihW666SbHkyVzcnIUHBysxx9/XCNGjLji+tOmTdOoUaN05MgRlS5d+or9MzIyFBAQoPT0dMINAADXCVc+v906oDgrK0ubNm1y+vp5Dw8PRUdHa926dQXaxrx589SzZ88CBRsAAGB9br0sdfz4cWVnZysoKMipPSgoSD///PMV19+wYYO2b9+uefPm5dsnMzNTmZmZjvmMjIzCFwwAAEq86/pW8Hnz5qlx48aKiorKt09CQoICAgIcU3Bw8DWsEAAAXGtuDTeBgYHy9PTU0aNHndqPHj16xcebnz17VkuXLlX//v0v2y8+Pl7p6emO6cCBA1ddNwAAKLncGm68vb0VGRmppKQkR1tOTo6SkpLUqlWry6777rvvKjMzUw899NBl+9ntdpUtW9ZpAgAA1uX2W8Hj4uLUt29fNW/eXFFRUZo2bZrOnj2r2NhYSVJMTIyqV6+uhIQEp/XmzZunbt26qWLFiu4oGwAAlFBuDzc9evRQWlqaRo0apdTUVEVERGjVqlWOQcYpKSm5nkS4c+dOffvtt/rss8/cUTIAACjB3P6cm2uN59wAAHD9uW6ecwMAAFDUCDcAAMBSCDcAAMBSCDcAAMBSCDcAAMBSCDcAAMBSCDcAAMBS3P4QP6sJHbHS3SUUyr4Jd7m7BAAAigRnbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKUQbgAAgKV4ubsAALgWQkesdHcJhbJvwl3uLgG47nDmBgAAWArhBgAAWAqXpYC/OS7XALAaztwAAABLIdwAAABLIdwAAABLIdwAAABLIdwAAABLcXu4mTFjhkJDQ+Xj46MWLVpow4YNl+1/6tQpPfbYY6patarsdrtuuOEG/fe//71G1QIAgJLOrbeCJyYmKi4uTrNmzVKLFi00bdo0derUSTt37lTlypVz9c/KytJtt92mypUr67333lP16tW1f/9+lStX7toXDwAASiS3hpspU6Zo4MCBio2NlSTNmjVLK1eu1Pz58zVixIhc/efPn68TJ05o7dq1KlWqlCQpNDT0WpYMAABKOLddlsrKytKmTZsUHR39ZzEeHoqOjta6devyXOfDDz9Uq1at9NhjjykoKEiNGjXSyy+/rOzs7Hz3k5mZqYyMDKcJAABYl9vCzfHjx5Wdna2goCCn9qCgIKWmpua5zi+//KL33ntP2dnZ+u9//6sXXnhBkydP1rhx4/LdT0JCggICAhxTcHBwkR4HAAAoWdw+oNgVOTk5qly5smbPnq3IyEj16NFDzz//vGbNmpXvOvHx8UpPT3dMBw4cuIYVAwCAa81tY24CAwPl6empo0ePOrUfPXpUVapUyXOdqlWrqlSpUvL09HS01a9fX6mpqcrKypK3t3eudex2u+x2e9EWDwAASiy3nbnx9vZWZGSkkpKSHG05OTlKSkpSq1at8lynTZs22r17t3Jychxtu3btUtWqVfMMNgAA4O/HrZel4uLiNGfOHL355pvasWOHBg8erLNnzzrunoqJiVF8fLyj/+DBg3XixAkNGzZMu3bt0sqVK/Xyyy/rsccec9chAACAEsatt4L36NFDaWlpGjVqlFJTUxUREaFVq1Y5BhmnpKTIw+PP/BUcHKxPP/1UTz75pJo0aaLq1atr2LBhevbZZ911CAAAoIRxa7iRpKFDh2ro0KF5LktOTs7V1qpVK3333XfFXBUAALheXVd3SwEAAFwJ4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFgK4QYAAFiK2784E9en0BEr3V1CoeybcJe7SwAAFDPO
3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEsh3AAAAEvxcncBAAAgt9ARK91dQqHtm3CXW/fPmRsAAGAphBsAAGAphBsAAGAphBsAAGApJSLczJgxQ6GhofLx8VGLFi20YcOGfPsuXLhQNpvNafLx8bmG1QIAgJLM7eEmMTFRcXFxGj16tDZv3qymTZuqU6dOOnbsWL7rlC1bVkeOHHFM+/fvv4YVAwCAkszt4WbKlCkaOHCgYmNj1aBBA82aNUt+fn6aP39+vuvYbDZVqVLFMQUFBV3DigEAQEnm1ufcZGVladOmTYqPj3e0eXh4KDo6WuvWrct3vTNnzigkJEQ5OTm68cYb9fLLL6thw4Z59s3MzFRmZqZjPiMjo+gOAABKmOv12Sjufi4KrMWtZ26OHz+u7OzsXGdegoKClJqamuc6devW1fz58/XBBx9o0aJFysnJUevWrXXw4ME8+yckJCggIMAxBQcHF/lxAACAkuO6e0Jxq1at1KpVK8d869atVb9+fb3xxht66aWXcvWPj49XXFycYz4jI4OAgwK5Xv8HLPG/YAB/b24NN4GBgfL09NTRo0ed2o8ePaoqVaoUaBulSpVSs2bNtHv37jyX2+122e32q64VAABcH9x6Wcrb21uRkZFKSkpytOXk5CgpKcnp7MzlZGdna9u2bapatWpxlQkAAK4jbr8sFRcXp759+6p58+aKiorStGnTdPbsWcXGxkqSYmJiVL16dSUkJEiSxo4dq5YtW6p27do6deqUXn31Ve3fv18DBgxw52EAAIASwuVw07dvX/Xv31/t2rUrkgJ69OihtLQ0jRo1SqmpqYqIiNCqVascg4xTUlLk4fHnCaaTJ09q4MCBSk1NVfny5RUZGam1a9eqQYMGRVIPAAC4vrkcbtLT0xUdHa2QkBDFxsaqb9++ql69+lUVMXToUA0dOjTPZcnJyU7zU6dO1dSpU69qfwAAwLpcHnOzYsUKHTp0SIMHD1ZiYqJCQ0N155136r333tP58+eLo0YAAIACK9SA4kqVKikuLk5bt27V+vXrVbt2bT388MOqVq2annzySf3f//1fUdcJAABQIFd1t9SRI0e0evVqrV69Wp6enurcubO2bdumBg0acOkIAAC4hcvh5vz581q+fLm6dOmikJAQvfvuuxo+fLgOHz6sN998U59//rmWLVumsWPHFke9AAAAl+XygOKqVasqJydHvXr10oYNGxQREZGrT4cOHVSuXLkiKA8AAMA1LoebqVOnqnv37vLx8cm3T7ly5bR3796rKgwAAKAwXLosdf78ecXGxub7VQcAAADu5lK4KVWqlGrWrKns7OziqgcAAOCquDyg+Pnnn9dzzz2nEydOFEc9AAAAV8XlMTfTp0/X7t27Va1aNYWEhKh06dJOyzdv3lxkxQEAALjK5XDTrVu3YigDAACgaLgcbkaPHl0cdQAAABSJQj2h+NSpU5o7d67i4+MdY282b96sQ4cOFWlxAAAArnL5zM2PP/6o6OhoBQQEaN++fRo4cKAqVKig999/XykpKXrrrbeKo04AAIACcfnMTVxcnPr166f/+7//c3qQX+fOnfX1118XaXEAAACucjncfP/99xo0aFCu9urVqys1NbVIigIAACgsl8ON3W5XRkZGrvZdu3apUqVKRVIUAABAYbkcbu655x6NHTtW58+flyTZbDalpKTo2Wef1f3331/kBQIAALjC5XAzefJknTlzRpUrV9Zvv/2m9u3bq3bt2ipTpozGjx9fHDUCAAAUmMt3SwUEBGj16tX69ttv9eOPP+rMmTO68cYbFR0dXRz1AQAAuMTlcHPgwAEFBwfr5ptv1s0331wcNQEAABSa
y5elQkND1b59e82ZM0cnT54sjpoAAAAKzeVws3HjRkVFRWns2LGqWrWqunXrpvfee0+ZmZnFUR8AAIBLXA43zZo106uvvqqUlBR98sknqlSpkh599FEFBQXpkUceKY4aAQAACqxQ3y0l/XELeIcOHTRnzhx9/vnnqlWrlt58882irA0AAMBlhQ43Bw8e1CuvvKKIiAhFRUXJ399fM2bMKMraAAAAXOby3VJvvPGGlixZojVr1qhevXrq06ePPvjgA4WEhBRHfQAAAC5xOdyMGzdOvXr10r/+9S81bdq0OGoCAAAoNJfDTUpKimw2W3HUAgAAcNVcDjc2m02nTp3SvHnztGPHDklSgwYN1L9/fwUEBBR5gQAAAK4o1HNuwsPDNXXqVJ04cUInTpzQ1KlTFR4ers2bNxdHjQAAAAXm8pmbJ598Uvfcc4/mzJkjL68/Vr9w4YIGDBig4cOH6+uvvy7yIgEAAArK5XCzceNGp2AjSV5eXnrmmWfUvHnzIi0OAADAVS5flipbtqxSUlJytR84cEBlypQpkqIAAAAKy+Vw06NHD/Xv31+JiYk6cOCADhw4oKVLl2rAgAHq1atXcdQIAABQYC5flpo0aZJsNptiYmJ04cIFSVKpUqU0ePBgTZgwocgLBAAAcIXL4cbb21uvvfaaEhIStGfPHklSeHi4/Pz8irw4AAAAV7kcbtLT05Wdna0KFSqocePGjvYTJ07Iy8tLZcuWLdICAQAAXOHymJuePXtq6dKludqXLVumnj17FklRAAAAheVyuFm/fr06dOiQq/2WW27R+vXri6QoAACAwnI53GRmZjoGEl/q/Pnz+u2334qkKAAAgMJyOdxERUVp9uzZudpnzZqlyMjIQhUxY8YMhYaGysfHRy1atNCGDRsKtN7SpUtls9nUrVu3Qu0XAABYj8sDiseNG6fo6Ght3bpVHTt2lCQlJSXp+++/12effeZyAYmJiYqLi9OsWbPUokULTZs2TZ06ddLOnTtVuXLlfNfbt2+fnn76abVt29blfQIArm+hI1a6u4RC2TfhLneX8Lfg8pmbNm3aaN26dapRo4aWLVumjz76SLVr19aPP/5YqKAxZcoUDRw4ULGxsWrQoIFmzZolPz8/zZ8/P991srOz1adPH40ZM0ZhYWEu7xMAAFiXy2duJCkiIkJLliy56p1nZWVp06ZNio+Pd7R5eHgoOjpa69aty3e9sWPHqnLlyurfv7+++eaby+4jMzNTmZmZjvmMjIyrrhsAAJRcLp+5kaQ9e/Zo5MiR6t27t44dOyZJ+uSTT/S///3Ppe0cP35c2dnZCgoKcmoPCgpSampqnut8++23mjdvnubMmVOgfSQkJCggIMAxBQcHu1QjAAC4vrgcbr766is1btxY69ev1/Lly3XmzBlJ0tatWzV69OgiL/BSp0+f1sMPP6w5c+YoMDCwQOvEx8crPT3dMR04cKBYawQAAO7l8mWpESNGaNy4cYqLi3P6FvBbb71V06dPd2lbgYGB8vT01NGjR53ajx49qipVquTqv2fPHu3bt0933323oy0nJ0eS5OXlpZ07dyo8PNxpHbvdLrvd7lJdAADg+uXymZtt27bp3nvvzdVeuXJlHT9+3KVteXt7KzIyUklJSY62nJwcJSUlqVWrVrn616tXT9u2bdMPP/zgmO655x516NBBP/zwA5ecAACA62duypUrpyNHjqhWrVpO7Vu2bFH16tVdLiAuLk59+/ZV8+bNFRUVpWnTpuns2bOKjY2VJMXExKh69epKSEiQj4+PGjVqlKseSbnaAQDA35PL4aZnz5569tln9e6778pmsyknJ0dr1qzR008/rZiYGJcL6NGjh9LS0jRq1CilpqYqIiJCq1atcgwyTklJkYdHocY9AwCAvyGXw83LL7+sxx57TMHBwcrOzlaDBg2UnZ2t3r176/nnny9UEUOHDtXQoUPzXJacnHzZdRcuXFiofQIAAGtyOdx4e3trzpw5GjVq
lLZt26YzZ86oWbNmqlOnTnHUBwAA4JJCPcRPkoKDg50G8L7//vt68cUX9eOPPxZJYQAAAIXh0mCWN954Qw888IB69+6t9evXS5K++OILNWvWTA8//LDatGlTLEUCAAAUVIHDzYQJE/T4449r3759+vDDD3Xrrbfq5ZdfVp8+fdSjRw8dPHhQM2fOLM5aAQAArqjAl6UWLFigOXPmqG/fvvrmm2/Uvn17rV27Vrt371bp0qWLs0YAAIACK/CZm5SUFN16662SpLZt26pUqVIaM2YMwQYAAJQoBQ43mZmZ8vHxccx7e3urQoUKxVIUAABAYbl0t9QLL7wgPz8/SVJWVpbGjRungIAApz5TpkwpuuoAAABcVOBw065dO+3cudMx37p1a/3yyy9OfWw2W9FVBgAAUAgFDjdXelIwAABAScCXNgEAAEsh3AAAAEsh3AAAAEsh3AAAAEsp0nCzffv2otwcAACAy6463Jw+fVqzZ89WVFSUmjZtWhQ1AQAAFFqhw83XX3+tvn37qmrVqpo0aZJuvfVWfffdd0VZGwAAgMtcekJxamqqFi5cqHnz5ikjI0MPPvigMjMztWLFCjVo0KC4agQAACiwAp+5ufvuu1W3bl39+OOPmjZtmg4fPqzXX3+9OGsDAABwWYHP3HzyySd64oknNHjwYNWpU6c4awIAACi0Ap+5+fbbb3X69GlFRkaqRYsWmj59uo4fP16ctQEAALiswOGmZcuWmjNnjo4cOaJBgwZp6dKlqlatmnJycrR69WqdPn26OOsEAAAoEJfvlipdurQeeeQRffvtt9q2bZueeuopTZgwQZUrV9Y999xTHDUCAAAU2FU956Zu3bp65ZVXdPDgQb3zzjtFVRMAAEChuXQreH48PT11zz33yMurSDYHAABQaFedRnbv3q358+dr4cKFSktL0/nz54uiLgAAgEIp1GWp3377TW+99ZbatWununXrau3atRo1apQOHjxY1PUBAAC4xKUzN99//73mzp2rpUuXKjw8XH369NHatWv173//mycUAwCAEqHA4aZJkybKyMhQ7969tXbtWjVs2FCSNGLEiGIrDgAAwFUFviy1c+dOtWvXTh06dOAsDQAAKLEKHG5++eUX1a1bV4MHD1aNGjX09NNPa8uWLbLZbMVZHwAAgEsKHG6qV6+u559/Xrt379bbb7+t1NRUtWnTRhcuXNDChQu1a9eu4qwTAACgQAp1t9Stt96qRYsW6ciRI5o+fbq++OIL1atXT02aNCnq+gAAAFxyVU8oDggI0JAhQ7Rx40Zt3rxZt9xySxGVBQAAUDhXFW4uFRERoX/9619FtTkAAIBCKbJwAwAAUBIQbgAAgKUQbgAAgKUQbgAAgKUU6OsXXBko/MQTTxS6GAAAgKtVoHAzdepUp/m0tDSdO3dO5cqVkySdOnVKfn5+qly5MuEGAAC4VYEuS+3du9cxjR8/XhEREdqxY4dOnDihEydOaMeOHbrxxhv10ksvFaqIGTNmKDQ0VD4+PmrRooU2bNiQb9/3339fzZs3V7ly5VS6dGlFRETo7bffLtR+AQCA9bg85uaFF17Q66+/rrp16zra6tatq6lTp2rkyJEuF5CYmKi4uDiNHj1amzdvVtOmTdWpUycdO3Ysz/4VKlTQ888/r3Xr1unHH39UbGysYmNj9emnn7q8bwAAYD0uh5sjR47owoULudqzs7N19OhRlwuYMmWKBg4cqNjYWDVo0ECzZs2Sn5+f5s+fn2f/W265Rffee6/q16+v8PBwDRs2TE2aNNG3337r8r4BAID1uBxuOnbsqEGDBmnz5s2Otk2bNmnw4MGKjo52aVtZWVnatGmT03oeHh6Kjo7WunXrrri+MUZJSUnauXOn2rVrl2efzMxMZWRkOE0AAMC6XA438+fPV5UqVdS8eXPZ7XbZ7XZFRUUpKChIc+fOdWlbx48fV3Z2toKCgpzag4KClJqamu966enp8vf3l7e3t+666y69/vrr
uu222/Lsm5CQoICAAMcUHBzsUo0AAOD6UqC7pS5VqVIl/fe//9WuXbv0888/S5Lq1aunG264ociLy0+ZMmX0ww8/6MyZM0pKSlJcXJzCwsLy/OLO+Ph4xcXFOeYzMjIIOAAAWJjL4eai0NBQGWMUHh4uL6/CbSYwMFCenp65xuocPXpUVapUyXc9Dw8P1a5dW5Icd24lJCTkGW4unl0CAAB/Dy5fljp37pz69+8vPz8/NWzYUCkpKZKkxx9/XBMmTHBpW97e3oqMjFRSUpKjLScnR0lJSWrVqlWBt5OTk6PMzEyX9g0AAKzJ5XATHx+vrVu3Kjk5WT4+Po726OhoJSYmulxAXFyc5syZozfffFM7duzQ4MGDdfbsWcXGxkqSYmJiFB8f7+ifkJCg1atX65dfftGOHTs0efJkvf3223rooYdc3jcAALAel68nrVixQomJiWrZsqVsNpujvWHDhtqzZ4/LBfTo0UNpaWkaNWqUUlNTFRERoVWrVjkGGaekpMjD488MdvbsWQ0ZMkQHDx6Ur6+v6tWrp0WLFqlHjx4u7xsAAFiPy+EmLS1NlStXztV+9uxZp7DjiqFDh2ro0KF5LktOTnaaHzdunMaNG1eo/QAAAOtz+bJU8+bNtXLlSsf8xUAzd+5cl8bJAAAAFAeXz9y8/PLLuvPOO/XTTz/pwoULeu211/TTTz9p7dq1+uqrr4qjRgAAgAJz+czNzTffrB9++EEXLlxQ48aN9dlnn6ly5cpat26dIiMji6NGAACAAnP5zM327dvVqFEjzZkzJ9eyFStWqFu3bkVRFwAAQKG4fOamU6dO2rt3b6725cuXq0+fPkVSFAAAQGG5HG4GDBig6Ohop+9+SkxMVExMjBYuXFiUtQEAALjM5ctSY8aM0YkTJxQdHa2vv/5aq1at0oABA/T222/r/vvvL44aAQAACqxQXwr1+uuvq0+fPmrZsqUOHTqkd955R127di3q2gAAAFxWoHDz4Ycf5mq777779M0336hXr16y2WyOPvfcc0/RVggAAOCCAoWby90BNX/+fM2fP1/SHw/0y87OLpLCAAAACqNA4SYnJ6e46wAAACgSLt8tlZdTp04VxWYAAACumsvhZuLEiUpMTHTMd+/eXRUqVFD16tW1devWIi0OAADAVS6Hm1mzZik4OFiStHr1an3++edatWqV7rzzTv3zn/8s8gIBAABc4fKt4KmpqY5w8/HHH+vBBx/U7bffrtDQULVo0aLICwQAAHCFy2duypcvrwMHDkiSVq1apejoaEmSMYY7pQAAgNu5fObmvvvuU+/evVWnTh39+uuvuvPOOyVJW7ZsUe3atYu8QAAAAFe4HG6mTp2q0NBQHThwQK+88or8/f0lSUeOHNGQIUOKvEAAAABXuBxuSpUqpaeffjpX+5NPPlkkBQEAAFyNQn23lCT99NNPSklJUVZWllM7X78AAADcyeVw88svv+jee+/Vtm3bZLPZZIyR9MdXL0hiUDEAAHArl++WGjZsmGrVqqVjx47Jz89P//vf//T111+refPmSk5OLoYSAQAACs7lMzfr1q3TF198ocDAQHl4eMjDw0M333yzEhIS9MQTT2jLli3FUScAAECBuHzmJjs7W2XKlJEkBQYG6vDhw5KkkJAQ7dy5s2irAwAAcJHLZ24aNWqkrVu3qlatWmrRooVeeeUVeXt7a/bs2QoLCyuOGgEAAArM5XAzcuRInT17VpI0duxYdenSRW3btlXFihWdvlATAADAHVwON506dXL8u3bt2vr555914sQJlS9f3nHHFAAAgLsU+jk3l6pQoUJRbAYAAOCqFTjcPPLIIwXqN3/+/EIXAwAAcLUKHG4WLlyokJAQNWvWzPHgPgAAgJKmwOFm8ODBeuedd7R3717FxsbqoYce4nIUAAAocQr8nJsZM2boyJEjeuaZZ/TRRx8pODhYDz74oD799FPO5AAAgBLDpYf42e129erVS6tXr9ZPP/2khg0b
asiQIQoNDdWZM2eKq0YAAIACc/kJxY4VPTwcX5zJl2UCAICSwqVwk5mZqXfeeUe33XabbrjhBm3btk3Tp09XSkqK/P39i6tGAACAAivwgOIhQ4Zo6dKlCg4O1iOPPKJ33nlHgYGBxVkbAACAywocbmbNmqWaNWsqLCxMX331lb766qs8+73//vtFVhwAAICrChxuYmJi+HoFAABQ4rn0ED8AAICSrtB3SwEAAJREhBsAAGApJSLczJgxQ6GhofLx8VGLFi20YcOGfPvOmTNHbdu2Vfny5VW+fHlFR0dftj8AAPh7cXu4SUxMVFxcnEaPHq3NmzeradOm6tSpk44dO5Zn/+TkZPXq1Utffvml1q1bp+DgYN1+++06dOjQNa4cAACURG4PN1OmTNHAgQMVGxurBg0aaNasWfLz89P8+fPz7L948WINGTJEERERqlevnubOnaucnBwlJSVd48oBAEBJ5NZwk5WVpU2bNik6OtrR5uHhoejoaK1bt65A2zh37pzOnz+f7zeUZ2ZmKiMjw2kCAADW5dZwc/z4cWVnZysoKMipPSgoSKmpqQXaxrPPPqtq1ao5BaRLJSQkKCAgwDEFBwdfdd0AAKDkcvtlqasxYcIELV26VP/5z3/k4+OTZ5/4+Hilp6c7pgMHDlzjKgEAwLVU4If4FYfAwEB5enrq6NGjTu1Hjx5VlSpVLrvupEmTNGHCBH3++edq0qRJvv3sdrvsdnuR1AsAAEo+t5658fb2VmRkpNNg4IuDg1u1apXveq+88opeeuklrVq1Ss2bN78WpQIAgOuEW8/cSFJcXJz69u2r5s2bKyoqStOmTdPZs2cVGxsr6Y/vtKpevboSEhIkSRMnTtSoUaO0ZMkShYaGOsbm+Pv7y9/f323HAQAASga3h5sePXooLS1No0aNUmpqqiIiIrRq1SrHIOOUlBR5ePx5gmnmzJnKysrSAw884LSd0aNH68UXX7yWpQMAgBLI7eFGkoYOHaqhQ4fmuSw5Odlpft++fcVfEAAAuG5d13dLAQAA/BXhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWArhBgAAWIrbw82MGTMUGhoqHx8ftWjRQhs2bMi37//+9z/df//9Cg0Nlc1m07Rp065doQAA4Lrg1nCTmJiouLg4jR49Wps3b1bTpk3VqVMnHTt2LM/+586dU1hYmCZMmKAqVapc42oBAMD1wK3hZsqUKRo4cKBiY2PVoEEDzZo1S35+fpo/f36e/W+66Sa9+uqr6tmzp+x2+zWuFgAAXA/cFm6ysrK0adMmRUdH/1mMh4eio6O1bt06d5UFAACuc17u2vHx48eVnZ2toKAgp/agoCD9/PPPRbafzMxMZWZmOuYzMjKKbNsAAKDkcfuA4uKWkJCggIAAxxQcHOzukgAAQDFyW7gJDAyUp6enjh496tR+9OjRIh0sHB8fr/T0dMd04MCBIts2AAAoedwWbry9vRUZGamkpCRHW05OjpKSktSqVasi24/dblfZsmWdJgAAYF1uG3MjSXFxcerbt6+aN2+uqKgoTZs2TWfPnlVsbKwkKSYmRtWrV1dCQoKkPwYh//TTT45/Hzp0SD/88IP8/f1Vu3Zttx0HAAAoOdwabnr06KG0tDSNGjVKqampioiI0KpVqxyDjFNSUuTh8efJpcOHD6tZs2aO+UmTJmnSpElq3769kpOTr3X5AACgBHJruJGkoUOH
aujQoXku+2tgCQ0NlTHmGlQFAACuV5a/WwoAAPy9EG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAIClEG4AAICllIhwM2PGDIWGhsrHx0ctWrTQhg0bLtv/3XffVb169eTj46PGjRvrv//97zWqFAAAlHRuDzeJiYmKi4vT6NGjtXnzZjVt2lSdOnXSsWPH8uy/du1a9erVS/3799eWLVvUrVs3devWTdu3b7/GlQMAgJLI7eFmypQpGjhwoGJjY9WgQQPNmjVLfn5+mj9/fp79X3vtNd1xxx365z//qfr16+ull17SjTfeqOnTp1/jygEAQEnk1nCTlZWlTZs2KTo62tHm4eGh6OhorVu3Ls911q1b59Rfkjp16pRvfwAA8Pfi5c6dHz9+XNnZ2QoKCnJqDwoK0s8//5znOqmpqXn2T01NzbN/ZmamMjMzHfPp6emSpIyMjKspPV85meeKZbvFzdXX4+9wnNfrMUp/j+PkZzZvHGfJ9nf43ZSK5zP24jaNMVfs69Zwcy0kJCRozJgxudqDg4PdUE3JFTDN3RVcGxyndfwdjlHiOK2G47x6p0+fVkBAwGX7uDXcBAYGytPTU0ePHnVqP3r0qKpUqZLnOlWqVHGpf3x8vOLi4hzzOTk5OnHihCpWrCibzXaVR3DtZGRkKDg4WAcOHFDZsmXdXU6x4Tit4+9wjBLHaTUcZ8lljNHp06dVrVq1K/Z1a7jx9vZWZGSkkpKS1K1bN0l/hI+kpCQNHTo0z3VatWqlpKQkDR8+3NG2evVqtWrVKs/+drtddrvdqa1cuXJFUb5blC1b9rr5QbwaHKd1/B2OUeI4rYbjLJmudMbmIrdfloqLi1Pfvn3VvHlzRUVFadq0aTp79qxiY2MlSTExMapevboSEhIkScOGDVP79u01efJk3XXXXVq6dKk2btyo2bNnu/MwAABACeH2cNOjRw+lpaVp1KhRSk1NVUREhFatWuUYNJySkiIPjz9v6mrdurWWLFmikSNH6rnnnlOdOnW0YsUKNWrUyF2HAAAAShC3hxtJGjp0aL6XoZKTk3O1de/eXd27dy/mqkoWu92u0aNH57rEZjUcp3X8HY5R4jithuO0BpspyD1VAAAA1wm3P6EYAACgKBFuAACApRBuAACApRBuAACApRBuSph+/fo5Hmj4V7Nnz9Ytt9yismXLymaz6dSpU9e0tqKU33GeOHFCjz/+uOrWrStfX1/VrFlTTzzxhOM7wa43l3s/Bw0apPDwcPn6+qpSpUrq2rVrvt+pVpJd7hgvMsbozjvvlM1m04oVK65JXUUpv2NMTk52/C7+/vvv6tevnxo3biwvL68rviYlUUGOMzk5WV27dlXVqlVVunRpRUREaPHixde+2KtQkOPcuXOnOnTooKCgIPn4+CgsLEwjR47U+fPnr33BhVCQY7zU7t27VaZMmev6IbeXItxcR86dO6c77rhDzz33nLtLKTaHDx/W4cOHNWnSJG3fvl0LFy7UqlWr1L9/f3eXVuQiIyO1YMEC7dixQ59++qmMMbr99tuVnZ3t7tKK3LRp066rrzspjOzsbPn6+uqJJ55QdHS0u8spNmvXrlWTJk20fPly/fjjj4qNjVVMTIw+/vhjd5dWpEqVKqWYmBh99tln2rlzp6ZNm6Y5c+Zo9OjR7i6tyJ0/f169evVS27Zt3V1KkSkRz7lBwVz8yom8nv1jFY0aNdLy5csd8+Hh
4Ro/frweeughXbhwQV5e1vmRffTRRx3/Dg0N1bhx49S0aVPt27dP4eHhbqysaP3www+aPHmyNm7cqKpVq7q7nGJTunRpzZw5U5K0Zs2a6/rM6uX89T9Xw4YN02effab3339fXbp0cVNVRS8sLExhYWGO+ZCQECUnJ+ubb75xY1XFY+TIkapXr546duyotWvXurucIsGZG5R46enpKlu2rKWCzV+dPXtWCxYsUK1atSz1jfXnzp1T7969NWPGjHy/3BbXv/T0dFWoUMHdZRSr3bt3a9WqVWrfvr27SylSX3zxhd59913NmDHD3aUUKet+WsASjh8/rpdeesnpLIeV/Pvf/9Yzzzyjs2fPqm7dulq9erW8vb3dXVaRefLJJ9W6dWt17drV3aVctY8//lj+/v5ObVa8hOjqcS5btkzff/+93njjjeIurUgV9Dhbt26tzZs3KzMzU48++qjGjh17rUq8alc6xl9//VX9+vXTokWLrqsvzywIwg1KrIyMDN11111q0KCBXnzxRXeXUyz69Omj2267TUeOHNGkSZP04IMPas2aNfLx8XF3aVftww8/1BdffKEtW7a4u5Qi0aFDB8dlp4vWr1+vhx56yE0VFQ9XjvPLL79UbGys5syZo4YNG16rEotEQY8zMTFRp0+f1tatW/XPf/5TkyZN0jPPPHMtSy20Kx3jwIED1bt3b7Vr184d5RUrwg1KpNOnT+uOO+5QmTJl9J///EelSpVyd0nFIiAgQAEBAapTp45atmyp8uXL6z//+Y969erl7tKu2hdffKE9e/bkuvvi/vvvV9u2ba+7sWOlS5dW7dq1ndoOHjzopmqKT0GP86uvvtLdd9+tqVOnKiYm5lqVV2QKepwXLxM3aNBA2dnZevTRR/XUU0/J09PzmtR5Na50jF988YU+/PBDTZo0SdIfdzXm5OTIy8tLs2fP1iOPPHJN6y1KhBuUOBkZGerUqZPsdrs+/PBDS5zFKAhjjIwxyszMdHcpRWLEiBEaMGCAU1vjxo01depU3X333W6qCkUhOTlZXbp00cSJEy17yTgvOTk5On/+vHJycq6LcHMl69atc7pM9cEHH2jixIlau3atqlev7sbKrh7hpgRKT0/XDz/84NRWsWJFlSpVSqmpqdq9e7ckadu2bSpTpoxq1qx5XQ7my+s4y5cvrx49eujcuXNatGiRMjIylJGRIUmqVKnSdfkHJa/jTE9P19q1a3X77berUqVKOnjwoCZMmCBfX1917tzZPYVehfx+Zhs1apSrb82aNVWrVq1rVNm19dNPPykrK0snTpzQ6dOnHa9JRESEW+sqSl9++aW6dOmiYcOG6f7771dqaqokydvb+7r8O5SfxYsXq1SpUmrcuLHsdrs2btyo+Ph49ejRwzJnkuvXr+80v3HjRnl4eOT5e3u9IdyUQMnJyWrWrJlTW//+/VWjRg2NGTPG0XbxOumCBQvUr1+/a1likcjrOMPDw7Vnzx5JynU6de/evQoNDb1W5RWZvI4zNjZWqampmjZtmk6ePKmgoCC1a9dOa9euVeXKld1UaeHl9zM7d+5cN1XkHp07d9b+/fsd8xdfE2OMu0oqcm+++abOnTunhIQEJSQkONrbt29/3V1qvBwvLy9NnDhRu3btkjFGISEhGjp0qJ588kl3l4YCsBkr/dYBAIC/PZ5zAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwAwAALIVwA6BI9OvXT926dXN3GQ7Jycmy2WyXnaz0RF0Af+LrFwBYUuvWrXXkyBHH/LBhw5SRkaEFCxY42qz0XUgA/sSZGwDXxFdffaWoqCjZ7XZVrVpVI0aM0IULFxzLMzMz9cQTT6hy5cry8fHRzTffrO+//96x/OKZmJUrV6pJkyby8fFRy5YttX379jz35+3trSpVqjgmX19f2e12ValSRbt27VJwcLBOnDjhtM7w4cPVtm1bSdLChQtVrlw5rVix
QnXq1JGPj486deqkAwcOOK3zwQcf6MYbb5SPj4/CwsI0ZswYp+MCcO0RbgAUu0OHDqlz58666aabtHXrVs2cOVPz5s3TuHHjHH2eeeYZLV++XG+++aY2b96s2rVrq1OnTrkCyD//+U9NnjxZ33//vSpVqqS7775b58+fd6medu3aKSwsTG+//baj7fz581q8eLEeeeQRR9u5c+c0fvx4vfXWW1qzZo1OnTqlnj17OpZ/8803iomJ0bBhw/TTTz/pjTfe0MKFCzV+/HhXXyIARckAQBHo27ev6dq1a57LnnvuOVO3bl2Tk5PjaJsxY4bx9/c32dnZ5syZM6ZUqVJm8eLFjuVZWVmmWrVq5pVXXjHGGPPll18aSWbp0qWOPr/++qvx9fU1iYmJLtc3ceJEU79+fcf88uXLjb+/vzlz5owxxpgFCxYYSea7775z9NmxY4eRZNavX2+MMaZjx47m5ZdfdtrP22+/bapWrXrFegAUH87cACh2O3bsUKtWrWSz2Rxtbdq00ZkzZ3Tw4EHt2bNH58+fV5s2bRzLS5UqpaioKO3YscNpW61atXL8u0KFCqpbt26uPgXRr18/7d69W999952kPy5DPfjggypdurSjj5eXl2666SbHfL169VSuXDnH/rZu3aqxY8fK39/fMQ0cOFBHjhzRuXPnXK4JQNFgQDGAv6XKlSvr7rvv1oIFC1SrVi198sknLt89debMGY0ZM0b33XdfrmU+Pj5FVCkAVxFuABS7+vXra/ny5TLGOM7erFmzRmXKlFGNGjVUsWJFeXt7a82aNQoJCZH0xxiY77//XsOHD3fa1nfffaeaNWtKkk6ePKldu3apfv36haprwIAB6tWrl2rUqKHw8HCnM0eSdOHCBW3cuFFRUVGSpJ07d+rUqVOO/d14443auXOnateuXaj9AygehBsARSY9PV0//PCDU1vFihU1ZMgQTZs2TY8//riGDh2qnTt3avTo0YqLi5OHh4dKly6twYMH65///KcqVKigmjVr6pVXXtG5c+fUv39/p+2NHTtWFStWVFBQkJ5//nkFBgYW+vk6nTp1UtmyZTVu3DiNHTs21/JSpUrp8ccf17/+9S95eXlp6NChatmypSPsjBo1Sl26dFHNmjX1wAMPyMPDQ1u3btX27dudBksDuMbcPegHgDX07dvXSMo19e/f3xhjTHJysrnpppuMt7e3qVKlinn22WfN+fPnHev/9ttv5vHHHzeBgYHGbrebNm3amA0bNjiWXxxQ/NFHH5mGDRsab29vExUVZbZu3Vrg+vIa8PzCCy8YT09Pc/jwYaf2BQsWmICAALN8+XITFhZm7Ha7iY6ONvv373fqt2rVKtO6dWvj6+trypYta6Kioszs2bML+rIBKAY2Y4xxZ7gCgIJITk5Whw4ddPLkSZUrV67Ittu/f3+lpaXpww8/dGpfuHChhg8frlOnThXZvgBcG1yWAvC3lJ6erm3btmnJkiW5gg2A6xvhBsDfUteuXbVhwwb94x//0G233ebucgAUIS5LAQAAS+EhfgAAwFIINwAAwFIINwAAwFIINwAAwFIINwAAwFIINwAAwFIINwAAwFIINwAAwFIINwAAwFL+HwAozFy11XNpAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 233/233 [00:21<00:00, 10.92it/s]\n"
     ]
    }
   ],
   "source": [
    "# Ablation passes: run the model with one input modality hidden at a time.\n",
    "#   pass 1 - sequence hidden: masked amino-acid recovery, overall and separated by loop type\n",
    "#   pass 2 - dihedrals hidden: quantized indices obtained without dihedral information\n",
    "# also be careful about masked aas, they need to be discarded\n",
    "model.to('cpu')\n",
    "\n",
    "# Not referenced in this cell; one assignment is kept in case another cell reads `loss_fn`.\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "# Vocabulary indices of the standard tokens, looked up once rather than once per batch.\n",
    "standard_tok_indices = [dataset.alphabet.get_idx(tok) for tok in dataset.alphabet.standard_toks]\n",
    "\n",
    "\n",
    "def standard_token_mask(sequence):\n",
    "    \"\"\"Return a bool tensor shaped like `sequence`, True wherever it holds a standard token.\"\"\"\n",
    "    mask = torch.zeros_like(sequence, dtype=torch.bool)\n",
    "    for tok_idx in standard_tok_indices:\n",
    "        mask |= (sequence == tok_idx)\n",
    "    return mask\n",
    "\n",
    "\n",
    "# repeat but with no sequence information\n",
    "gt_aa_loop_types = defaultdict(list)\n",
    "pred_aa_loop_types = defaultdict(list)\n",
    "gt_aa_list = []\n",
    "pred_aa_list = []\n",
    "all_quantized_indices_dihedral_only = []\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(dataloader, total=len(dataloader)):\n",
    "        # mask out sequence information\n",
    "        batch_mask = standard_token_mask(batch['sequence'])\n",
    "        batch['sequence'][batch_mask] = dataset.alphabet.mask_idx\n",
    "        output = model(batch, val=True)\n",
    "\n",
    "        # Convert once so the overall and per-loop-type lists hold the same (numpy) type.\n",
    "        true_aa = output.true_aa.numpy()\n",
    "        pred_aa = output.pred_aa.argmax(dim=-1).numpy()\n",
    "\n",
    "        # get masked aa recovery by loop type\n",
    "        # NOTE(review): assumes true_aa / pred_aa list the masked positions sample by sample,\n",
    "        # in batch order, so consecutive slices of length num_masked belong to sample i - confirm.\n",
    "        curr_idx = 0\n",
    "        for i in range(batch_mask.shape[0]):\n",
    "            loop_type = batch['id'][i].split('_')[1]\n",
    "            num_masked = batch_mask[i].sum().item()\n",
    "            gt_aa_loop_types[loop_type].append(true_aa[curr_idx:curr_idx + num_masked])\n",
    "            pred_aa_loop_types[loop_type].append(pred_aa[curr_idx:curr_idx + num_masked])\n",
    "            curr_idx += num_masked\n",
    "\n",
    "        gt_aa_list.append(true_aa)\n",
    "        pred_aa_list.append(pred_aa)\n",
    "        all_quantized_indices_dihedral_only.append(output.quantized_indices)\n",
    "all_quantized_indices_dihedral_only = torch.cat(all_quantized_indices_dihedral_only, dim=0)\n",
    "\n",
    "masked_aa_recovery = np.mean(np.concatenate(pred_aa_list) == np.concatenate(gt_aa_list))\n",
    "print(f\"Masked AA Recovery: {masked_aa_recovery:.4g}\")\n",
    "\n",
    "x = ['L1', 'L2', 'L3', 'L4', 'H1', 'H2', 'H3', 'H4']\n",
    "y = []\n",
    "for loop_type in x:\n",
    "    if loop_type not in gt_aa_loop_types:\n",
    "        # Keep a zero-height bar so the x axis always shows all eight loop types.\n",
    "        y.append(0.0)\n",
    "        print(f\"Loop type {loop_type} not found in dataset, skipping.\")\n",
    "        continue\n",
    "    # The per-type lists are replaced by their concatenated arrays.\n",
    "    gt_aa_loop_types[loop_type] = np.concatenate(gt_aa_loop_types[loop_type])\n",
    "    pred_aa_loop_types[loop_type] = np.concatenate(pred_aa_loop_types[loop_type])\n",
    "    # NOTE: this reuses the name `masked_aa_recovery`, so after the loop it holds the value\n",
    "    # of the last loop type found, not the overall recovery printed above.\n",
    "    masked_aa_recovery = np.mean(pred_aa_loop_types[loop_type] == gt_aa_loop_types[loop_type])\n",
    "    y.append(masked_aa_recovery)\n",
    "    print(f\"Masked AA Recovery for loop type {loop_type}: {masked_aa_recovery:.4g}\")\n",
    "\n",
    "plt.bar(x, y)\n",
    "plt.xlabel(\"Loop Type\")\n",
    "plt.ylabel(\"Masked AA Recovery\")\n",
    "plt.title(\"Masked AA Recovery by Loop Type\")\n",
    "plt.show()\n",
    "\n",
    "# repeat but with no dihedral information\n",
    "all_quantized_indices_sequence_only = []\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(dataloader, total=len(dataloader)):\n",
    "        # mask out dihedral information\n",
    "        batch_mask = standard_token_mask(batch['sequence'])\n",
    "        batch['angles_mask'] = batch_mask # True = masked out\n",
    "        output = model(batch, val=True)\n",
    "        all_quantized_indices_sequence_only.append(output.quantized_indices)\n",
    "all_quantized_indices_sequence_only = torch.cat(all_quantized_indices_sequence_only, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e2f87924",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Codebook size: 8192 num used: 790\n"
     ]
    }
   ],
   "source": [
    "# Distinct quantized indices that occur in `all_quantized_indices`, with their frequencies.\n",
    "quantized_np = all_quantized_indices.numpy()\n",
    "unique, counts = np.unique(quantized_np, return_counts=True)\n",
    "num_codes_used = unique.shape[0]\n",
    "print(\"Codebook size:\", model.codebook_size, \"num used:\", num_codes_used)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c1d3ab9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of canonical clusters: 134\n",
      "Number of VQVAE clusters: 790\n"
     ]
    }
   ],
   "source": [
    "# Look up each loop's canonical cluster label in the three loop_id -> cluster dictionaries.\n",
    "loop_ids = [x['loop_id'] for x in dataset.data]\n",
    "ground_truth = [loop_to_canonical[loop_id] for loop_id in loop_ids]\n",
    "ground_truth_strict = [loop_to_canonical_strict[loop_id] for loop_id in loop_ids]\n",
    "ground_truth_ssc_comparison = [loop_to_canonical_ssc_comparison[loop_id] for loop_id in loop_ids]\n",
    "\n",
    "# Sort before enumerating: a bare set of strings iterates in an order that can change between\n",
    "# kernel restarts (string hashing is randomized), which made the label -> integer mapping unstable.\n",
    "# key=str keeps the sort well-defined even if a label is not a string.\n",
    "clusters_to_indices = {cluster: i for i, cluster in enumerate(sorted(set(ground_truth), key=str))}\n",
    "\n",
    "# One row per loop; the list-valued columns are all built in dataset order.\n",
    "results_df = pd.DataFrame({\n",
    "    'loop_id': loop_ids,\n",
    "    'loop_type': [loop_id.split('_')[1] for loop_id in loop_ids],\n",
    "    'loop_length': [len(x['loop_sequence']) for x in dataset.data],\n",
    "    'canonical_cluster': ground_truth,\n",
    "    'canonical_cluster_strict': ground_truth_strict,\n",
    "    'canonical_cluster_ssc_comparison': ground_truth_ssc_comparison,\n",
    "    'quantized_index': all_quantized_indices.tolist(),\n",
    "    'quantized_index_dihedral_only': all_quantized_indices_dihedral_only.tolist(),\n",
    "    'quantized_index_sequence_only': all_quantized_indices_sequence_only.tolist(),\n",
    "})\n",
    "\n",
    "print(\"Number of canonical clusters:\", len(set(ground_truth)))\n",
    "print(\"Number of VQVAE clusters:\", len(np.unique(all_quantized_indices.numpy())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "232a8366",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weighted loop type purity: 0.9807\n",
      "Weighted loop length purity: 0.9604\n",
      "Number of clusters with multiple loop lengths: 128 out of 790\n",
      "Clusters with multiple loop lengths that have H3: 101 out of 346\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>quantized_index</th>\n",
       "      <th>cluster_size</th>\n",
       "      <th>loop_type</th>\n",
       "      <th>loop_length</th>\n",
       "      <th>num_loop_lengths</th>\n",
       "      <th>has_H3</th>\n",
       "      <th>loop_type_purity</th>\n",
       "      <th>loop_length_purity</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>198</th>\n",
       "      <td>2044</td>\n",
       "      <td>26</td>\n",
       "      <td>{'H3': 26}</td>\n",
       "      <td>{16: 8, 19: 1, 14: 9, 20: 1, 12: 2, 10: 2, 9: 3}</td>\n",
       "      <td>7</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.346154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>522</th>\n",
       "      <td>5363</td>\n",
       "      <td>8</td>\n",
       "      <td>{'H3': 4, 'H2': 3, 'L3': 1}</td>\n",
       "      <td>{13: 3, 10: 2, 16: 1, 14: 1, 9: 1}</td>\n",
       "      <td>5</td>\n",
       "      <td>True</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.375000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>755</th>\n",
       "      <td>7809</td>\n",
       "      <td>11</td>\n",
       "      <td>{'L3': 1, 'H3': 10}</td>\n",
       "      <td>{11: 1, 16: 8, 12: 1, 19: 1}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>0.909091</td>\n",
       "      <td>0.727273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>102</th>\n",
       "      <td>1031</td>\n",
       "      <td>36</td>\n",
       "      <td>{'H3': 36}</td>\n",
       "      <td>{20: 1, 14: 8, 15: 25, 16: 2}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.694444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>521</th>\n",
       "      <td>5351</td>\n",
       "      <td>8</td>\n",
       "      <td>{'L1': 8}</td>\n",
       "      <td>{12: 2, 13: 1, 15: 3, 11: 2}</td>\n",
       "      <td>4</td>\n",
       "      <td>False</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.375000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>259</th>\n",
       "      <td>2687</td>\n",
       "      <td>5</td>\n",
       "      <td>{'H3': 5}</td>\n",
       "      <td>{14: 1, 11: 2, 15: 1, 12: 1}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.400000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>166</th>\n",
       "      <td>1719</td>\n",
       "      <td>7</td>\n",
       "      <td>{'L1': 1, 'H3': 4, 'L3': 2}</td>\n",
       "      <td>{14: 3, 10: 1, 13: 2, 11: 1}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>0.571429</td>\n",
       "      <td>0.428571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>620</th>\n",
       "      <td>6260</td>\n",
       "      <td>4</td>\n",
       "      <td>{'H3': 3, 'L3': 1}</td>\n",
       "      <td>{19: 1, 12: 1, 18: 1, 11: 1}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>0.750000</td>\n",
       "      <td>0.250000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>364</td>\n",
       "      <td>4</td>\n",
       "      <td>{'L1': 2, 'L3': 2}</td>\n",
       "      <td>{13: 1, 10: 1, 14: 1, 12: 1}</td>\n",
       "      <td>4</td>\n",
       "      <td>False</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.250000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>271</th>\n",
       "      <td>2824</td>\n",
       "      <td>13</td>\n",
       "      <td>{'H3': 13}</td>\n",
       "      <td>{19: 1, 13: 2, 15: 8, 16: 2}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.615385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>538</th>\n",
       "      <td>5523</td>\n",
       "      <td>14</td>\n",
       "      <td>{'H3': 10, 'H2': 4}</td>\n",
       "      <td>{14: 7, 15: 2, 16: 1, 11: 4}</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>0.714286</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>132</th>\n",
       "      <td>1238</td>\n",
       "      <td>9</td>\n",
       "      <td>{'H3': 3, 'L3': 6}</td>\n",
       "      <td>{18: 1, 12: 6, 19: 2}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.666667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>377</th>\n",
       "      <td>3818</td>\n",
       "      <td>9</td>\n",
       "      <td>{'H3': 9}</td>\n",
       "      <td>{15: 5, 14: 3, 16: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.555556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>758</th>\n",
       "      <td>7829</td>\n",
       "      <td>9</td>\n",
       "      <td>{'L3': 9}</td>\n",
       "      <td>{10: 2, 9: 6, 11: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>False</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.666667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>623</th>\n",
       "      <td>6264</td>\n",
       "      <td>3</td>\n",
       "      <td>{'H3': 3}</td>\n",
       "      <td>{10: 1, 13: 1, 16: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>469</th>\n",
       "      <td>4835</td>\n",
       "      <td>9</td>\n",
       "      <td>{'H3': 9}</td>\n",
       "      <td>{20: 7, 16: 1, 23: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.777778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>1436</td>\n",
       "      <td>3</td>\n",
       "      <td>{'H3': 3}</td>\n",
       "      <td>{15: 1, 19: 1, 14: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>383</th>\n",
       "      <td>3869</td>\n",
       "      <td>3</td>\n",
       "      <td>{'H3': 2, 'H4': 1}</td>\n",
       "      <td>{9: 1, 8: 1, 7: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>0.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>677</th>\n",
       "      <td>6905</td>\n",
       "      <td>14</td>\n",
       "      <td>{'H3': 14}</td>\n",
       "      <td>{15: 9, 14: 4, 16: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.642857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>978</td>\n",
       "      <td>5</td>\n",
       "      <td>{'H3': 3, 'L3': 2}</td>\n",
       "      <td>{14: 2, 13: 2, 12: 1}</td>\n",
       "      <td>3</td>\n",
       "      <td>True</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>0.400000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     quantized_index  cluster_size                    loop_type  \\\n",
       "198             2044            26                   {'H3': 26}   \n",
       "522             5363             8  {'H3': 4, 'H2': 3, 'L3': 1}   \n",
       "755             7809            11          {'L3': 1, 'H3': 10}   \n",
       "102             1031            36                   {'H3': 36}   \n",
       "521             5351             8                    {'L1': 8}   \n",
       "259             2687             5                    {'H3': 5}   \n",
       "166             1719             7  {'L1': 1, 'H3': 4, 'L3': 2}   \n",
       "620             6260             4           {'H3': 3, 'L3': 1}   \n",
       "38               364             4           {'L1': 2, 'L3': 2}   \n",
       "271             2824            13                   {'H3': 13}   \n",
       "538             5523            14          {'H3': 10, 'H2': 4}   \n",
       "132             1238             9           {'H3': 3, 'L3': 6}   \n",
       "377             3818             9                    {'H3': 9}   \n",
       "758             7829             9                    {'L3': 9}   \n",
       "623             6264             3                    {'H3': 3}   \n",
       "469             4835             9                    {'H3': 9}   \n",
       "147             1436             3                    {'H3': 3}   \n",
       "383             3869             3           {'H3': 2, 'H4': 1}   \n",
       "677             6905            14                   {'H3': 14}   \n",
       "95               978             5           {'H3': 3, 'L3': 2}   \n",
       "\n",
       "                                          loop_length  num_loop_lengths  \\\n",
       "198  {16: 8, 19: 1, 14: 9, 20: 1, 12: 2, 10: 2, 9: 3}                 7   \n",
       "522                {13: 3, 10: 2, 16: 1, 14: 1, 9: 1}                 5   \n",
       "755                      {11: 1, 16: 8, 12: 1, 19: 1}                 4   \n",
       "102                     {20: 1, 14: 8, 15: 25, 16: 2}                 4   \n",
       "521                      {12: 2, 13: 1, 15: 3, 11: 2}                 4   \n",
       "259                      {14: 1, 11: 2, 15: 1, 12: 1}                 4   \n",
       "166                      {14: 3, 10: 1, 13: 2, 11: 1}                 4   \n",
       "620                      {19: 1, 12: 1, 18: 1, 11: 1}                 4   \n",
       "38                       {13: 1, 10: 1, 14: 1, 12: 1}                 4   \n",
       "271                      {19: 1, 13: 2, 15: 8, 16: 2}                 4   \n",
       "538                      {14: 7, 15: 2, 16: 1, 11: 4}                 4   \n",
       "132                             {18: 1, 12: 6, 19: 2}                 3   \n",
       "377                             {15: 5, 14: 3, 16: 1}                 3   \n",
       "758                              {10: 2, 9: 6, 11: 1}                 3   \n",
       "623                             {10: 1, 13: 1, 16: 1}                 3   \n",
       "469                             {20: 7, 16: 1, 23: 1}                 3   \n",
       "147                             {15: 1, 19: 1, 14: 1}                 3   \n",
       "383                                {9: 1, 8: 1, 7: 1}                 3   \n",
       "677                             {15: 9, 14: 4, 16: 1}                 3   \n",
       "95                              {14: 2, 13: 2, 12: 1}                 3   \n",
       "\n",
       "     has_H3  loop_type_purity  loop_length_purity  \n",
       "198    True          1.000000            0.346154  \n",
       "522    True          0.500000            0.375000  \n",
       "755    True          0.909091            0.727273  \n",
       "102    True          1.000000            0.694444  \n",
       "521   False          1.000000            0.375000  \n",
       "259    True          1.000000            0.400000  \n",
       "166    True          0.571429            0.428571  \n",
       "620    True          0.750000            0.250000  \n",
       "38    False          0.500000            0.250000  \n",
       "271    True          1.000000            0.615385  \n",
       "538    True          0.714286            0.500000  \n",
       "132    True          0.666667            0.666667  \n",
       "377    True          1.000000            0.555556  \n",
       "758   False          1.000000            0.666667  \n",
       "623    True          1.000000            0.333333  \n",
       "469    True          1.000000            0.777778  \n",
       "147    True          1.000000            0.333333  \n",
       "383    True          0.666667            0.333333  \n",
       "677    True          1.000000            0.642857  \n",
       "95     True          0.600000            0.400000  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per quantized index: member count plus Counters of loop types and loop lengths,\n",
    "# ordered so the indices spanning the most distinct loop lengths come first.\n",
    "lengths_per_cluster = (\n",
    "    results_df\n",
    "    .groupby('quantized_index')\n",
    "    .agg(\n",
    "        cluster_size=('loop_length', 'count'),\n",
    "        loop_type=('loop_type', Counter),\n",
    "        loop_length=('loop_length', Counter),\n",
    "        num_loop_lengths=('loop_length', 'nunique'),\n",
    "    )\n",
    "    .reset_index()\n",
    "    .sort_values('num_loop_lengths', ascending=False)\n",
    ")\n",
    "lengths_per_cluster['has_H3'] = lengths_per_cluster['loop_type'].apply(lambda type_counts: 'H3' in type_counts)\n",
    "\n",
    "\n",
    "def majority_fraction(label_counts):\n",
    "    \"\"\"Fraction of the counted items that carry the most frequent label.\"\"\"\n",
    "    return max(label_counts.values()) / sum(label_counts.values())\n",
    "\n",
    "\n",
    "cluster_sizes = lengths_per_cluster['cluster_size']\n",
    "total_loops = cluster_sizes.sum()\n",
    "\n",
    "# Purity per cluster, then the cluster-size-weighted mean over all clusters.\n",
    "lengths_per_cluster['loop_type_purity'] = lengths_per_cluster['loop_type'].apply(majority_fraction)\n",
    "loop_type_purity = (lengths_per_cluster['loop_type_purity'] * cluster_sizes).sum() / total_loops\n",
    "print(f\"Weighted loop type purity: {loop_type_purity:.4g}\")\n",
    "\n",
    "lengths_per_cluster['loop_length_purity'] = lengths_per_cluster['loop_length'].apply(majority_fraction)\n",
    "loop_length_purity = (lengths_per_cluster['loop_length_purity'] * cluster_sizes).sum() / total_loops\n",
    "print(f\"Weighted loop length purity: {loop_length_purity:.4g}\")\n",
    "\n",
    "mixed_length = lengths_per_cluster['num_loop_lengths'] > 1\n",
    "contains_h3 = lengths_per_cluster['has_H3']\n",
    "print(\"Number of clusters with multiple loop lengths:\", int(mixed_length.sum()), \"out of\", len(lengths_per_cluster))\n",
    "\n",
    "clusters_of_diff_length_with_H3 = int((mixed_length & contains_h3).sum())\n",
    "print(f\"Clusters with multiple loop lengths that have H3: {clusters_of_diff_length_with_H3} out of {int(contains_h3.sum())}\")\n",
    "\n",
    "lengths_per_cluster.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54f1fad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ground_truth_indices = []\n",
    "# test_indicies = []\n",
    "# dataset_indices = []\n",
    "# for i, (gt, test) in enumerate(zip(ground_truth, all_quantized_indices)):\n",
    "#     if gt.endswith(\"-*\"):\n",
    "#         continue\n",
    "#     ground_truth_indices.append(clusters_to_indices[gt])\n",
    "#     test_indicies.append(test.item())\n",
    "#     dataset_indices.append(i)\n",
    "\n",
    "# nmi = normalized_mutual_info_score(ground_truth_indices, test_indicies)\n",
    "# print(f\"NMI: {nmi:.3f}\")\n",
    "# ari = adjusted_rand_score(ground_truth_indices, test_indicies)\n",
    "# print(f\"ARI: {ari:.3f}\")\n",
    "\n",
    "# all_angles = np.array([dataset[i]['angles'] for i in range(len(dataset))])\n",
    "# all_tokens = np.array([dataset[i]['sequence'] for i in range(len(dataset))])\n",
    "# all_loop_coords = np.array([dataset[i]['loop_c_alpha_coords'] for i in range(len(dataset))])\n",
    "# all_stem_coords = np.array([dataset[i]['stem_c_alpha_coords'] for i in range(len(dataset))])\n",
    "\n",
    "# special_tokens_mask = (\n",
    "#     (all_tokens == alphabet.cls_idx) | (all_tokens == alphabet.eos_idx) | (all_tokens == alphabet.padding_idx)\n",
    "# )\n",
    "\n",
    "# all_angles_with_canonical = all_angles[dataset_indices]\n",
    "\n",
    "# print(\"\\nFor Kelow Clusters\")\n",
    "# correct, angle1, angle2, angle3 = eval_clusters_length_independent(\n",
    "#     all_angles[dataset_indices], all_loop_coords[dataset_indices], all_stem_coords[dataset_indices],\n",
    "#     np.array(ground_truth_indices), ~special_tokens_mask[dataset_indices])\n",
    "# print(f\"Proportion of pairs in a cluster that are within 0.47 radians of each other: {correct:4f}\")\n",
    "# print(f\"Angle variance of the clusters: {angle1:4f}, {angle2:4f}, {angle3:4f}\")\n",
    "\n",
    "# print(\"\\nFor VQVAE Clusters\")\n",
    "# correct, angle1, angle2, angle3 = eval_clusters_length_independent(\n",
    "#     all_angles[dataset_indices], all_loop_coords[dataset_indices], all_stem_coords[dataset_indices],\n",
    "#     np.array(test_indicies), ~special_tokens_mask[dataset_indices])\n",
    "# print(f\"Proportion of pairs in a cluster that are within 0.47 radians of each other: {correct:4f}\")\n",
    "# print(f\"Angle variance of the clusters: {angle1:4f}, {angle2:4f}, {angle3:4f}\")\n",
    "\n",
    "# print(\"\\nFor VQVAE Clusters (including noise)\")\n",
    "# correct, angle1, angle2, angle3 = eval_clusters_length_independent(\n",
    "#     all_angles, all_loop_coords, all_stem_coords,\n",
    "#     all_quantized_indices.numpy(), ~special_tokens_mask)\n",
    "# print(f\"Proportion of pairs in a cluster that are within 0.47 radians of each other: {correct:4f}\")\n",
    "# print(f\"Angle variance of the clusters: {angle1:4f}, {angle2:4f}, {angle3:4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7417e54e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is quite intensive to run, can crash kernel\n",
    "# dataset_dihedral_distance = dihedral_distance_pairwise(all_angles, mask=~special_tokens_mask)\n",
    "# sil_score = silhouette_score(dataset_dihedral_distance, all_quantized_indices, metric='precomputed')\n",
    "# print(f\"Silhouette Score: {sil_score:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6856de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# proportion_within_threshold_kelow, circular_variance_kelow = eval_clusters(np.array([dataset[i]['angles'] for i in dataset_indices]), np.array(ground_truth_indices), return_mean=False)\n",
    "# proportion_within_threshold, circular_variance = eval_clusters(np.array([dataset[i]['angles'] for i in range(len(dataset))]), all_quantized_indices.numpy(), return_mean=False)\n",
    "# clusters, counts = np.unique(all_quantized_indices.numpy(), return_counts=True)\n",
    "# clusters_var_df = pd.DataFrame({\n",
    "#     'cluster': clusters,\n",
    "#     'cluster_size': counts,\n",
    "#     'proportion_within_threshold': proportion_within_threshold,\n",
    "#     'phi': circular_variance[:, 0],\n",
    "#     'psi': circular_variance[:, 1],\n",
    "#     'omega': circular_variance[:, 2],\n",
    "# })\n",
    "# clusters_var_df.sort_values(by='proportion_within_threshold', ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f576d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(16, 4))\n",
    "# bins = np.linspace(0, 1, 50)\n",
    "# axes[0].hist(proportion_within_threshold_kelow, bins=bins, alpha=0.5, label='Kelow Clusters', density=True)\n",
    "# axes[0].hist(proportion_within_threshold, bins=bins, alpha=0.5, label='VQVAE Clusters', density=True)\n",
    "# axes[0].set_xlabel('Proportion of pairs within threshold')\n",
    "# axes[0].set_ylabel('Frequency')\n",
    "# axes[0].set_title('Distribution of Proportion of Pairs \\n Within Threshold across Clusters')\n",
    "# axes[0].legend()\n",
    "\n",
    "# angle_bins = np.linspace(0, 0.5, 50)\n",
    "# axes[1].hist(circular_variance_kelow[:, 0], bins=angle_bins, alpha=0.5, label='Kelow Clusters', density=True)\n",
    "# axes[1].hist(circular_variance[:, 0], bins=angle_bins, alpha=0.5, label='VQVAE Clusters', density=True)\n",
    "# axes[1].set_xlabel('Circular Variance (radians)')\n",
    "# axes[1].set_ylabel('Frequency')\n",
    "# axes[1].set_title('Distribution of Circular Variance \\n (phi) across Clusters')\n",
    "# axes[2].hist(circular_variance_kelow[:, 1], bins=angle_bins, alpha=0.5, label='Kelow Clusters', density=True)\n",
    "# axes[2].hist(circular_variance[:, 1], bins=angle_bins, alpha=0.5, label='VQVAE Clusters', density=True)\n",
    "# axes[2].set_xlabel('Circular Variance (radians)')\n",
    "# axes[2].set_ylabel('Frequency')\n",
    "# axes[2].set_title('Distribution of Circular Variance \\n (psi) across Clusters')\n",
    "# axes[3].hist(circular_variance_kelow[:, 2], bins=angle_bins, alpha=0.5, label='Kelow Clusters', density=True)\n",
    "# axes[3].hist(circular_variance[:, 2], bins=angle_bins, alpha=0.5, label='VQVAE Clusters', density=True)\n",
    "# axes[3].set_xlabel('Circular Variance (radians)')\n",
    "# axes[3].set_ylabel('Frequency')\n",
    "# axes[3].set_title('Distribution of Circular Variance \\n (omega) across Clusters')\n",
    "\n",
    "\n",
    "# _, cluster_counts = np.unique(np.array(ground_truth_indices, dtype=np.int64, copy=True), return_counts=True)\n",
    "# _, vqvae_counts = np.unique(np.array(all_quantized_indices, dtype=np.int64, copy=True), return_counts=True)\n",
    "# count_bins = np.linspace(0, max(max(cluster_counts), max(vqvae_counts)), 50)\n",
    "# axes[4].hist(cluster_counts, bins=count_bins, alpha=0.5, label='Kelow Clusters', density=True)\n",
    "# axes[4].hist(vqvae_counts, bins=count_bins, alpha=0.5, label='VQVAE Clusters', density=True)\n",
    "# axes[4].set_xlabel('Cluster Size')\n",
    "# axes[4].set_ylabel('Frequency')\n",
    "# axes[4].set_title('Distribution of Cluster Sizes')\n",
    "# axes[4].legend()\n",
    "\n",
    "# plt.tight_layout()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62a12be9",
   "metadata": {},
   "source": [
    "## How pure are the VQVAE clusters compared to canonical clusters?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b0997f80",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numuber of unique canonical clusters: 134\n",
      "Num noise clusters: 64\n"
     ]
    }
   ],
   "source": [
    "# Canonical cluster labels ending in \"-*\" are reported separately as noise clusters.\n",
    "is_noise = results_df['canonical_cluster'].str.endswith(\"-*\")\n",
    "print(\"Number of unique canonical clusters:\", results_df['canonical_cluster'].nunique())\n",
    "print(\"Num noise clusters:\", results_df.loc[is_noise, 'canonical_cluster'].nunique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dc32530d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Purity of VQVAE clusters wrt canonical clusters (D=0.61): 0.869\n",
      "Purity of VQVAE clusters wrt canonical clusters (D=0.47): 0.873\n",
      "Purity of VQVAE clusters wrt canonical clusters (D=0.1): 0.888\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loop_type</th>\n",
       "      <th>prop_correct (D=0.61)</th>\n",
       "      <th>prop_correct (D=0.47)</th>\n",
       "      <th>prop_correct (D=0.1)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1</td>\n",
       "      <td>0.837381</td>\n",
       "      <td>0.855670</td>\n",
       "      <td>0.930470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H2</td>\n",
       "      <td>0.912534</td>\n",
       "      <td>0.912844</td>\n",
       "      <td>0.906780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H3</td>\n",
       "      <td>0.748031</td>\n",
       "      <td>0.705882</td>\n",
       "      <td>0.586207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H4</td>\n",
       "      <td>0.991254</td>\n",
       "      <td>0.997006</td>\n",
       "      <td>0.996575</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>L1</td>\n",
       "      <td>0.883316</td>\n",
       "      <td>0.883495</td>\n",
       "      <td>0.879643</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>L2</td>\n",
       "      <td>0.996933</td>\n",
       "      <td>0.972222</td>\n",
       "      <td>0.750000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>L3</td>\n",
       "      <td>0.790373</td>\n",
       "      <td>0.796249</td>\n",
       "      <td>0.835928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>L4</td>\n",
       "      <td>0.892857</td>\n",
       "      <td>0.892157</td>\n",
       "      <td>0.895833</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  loop_type  prop_correct (D=0.61)  prop_correct (D=0.47)  \\\n",
       "0        H1               0.837381               0.855670   \n",
       "1        H2               0.912534               0.912844   \n",
       "2        H3               0.748031               0.705882   \n",
       "3        H4               0.991254               0.997006   \n",
       "4        L1               0.883316               0.883495   \n",
       "5        L2               0.996933               0.972222   \n",
       "6        L3               0.790373               0.796249   \n",
       "7        L4               0.892857               0.892157   \n",
       "\n",
       "   prop_correct (D=0.1)  \n",
       "0              0.930470  \n",
       "1              0.906780  \n",
       "2              0.586207  \n",
       "3              0.996575  \n",
       "4              0.879643  \n",
       "5              0.750000  \n",
       "6              0.835928  \n",
       "7              0.895833  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Purity of the VQVAE codes (`quantized_index`) against the canonical clusters,\n",
    "# repeated for the canonical clustering at each dihedral cutoff D:\n",
    "#   1. label every code with the canonical cluster that is most frequent among its rows\n",
    "#      (all rows take part in this vote, noise clusters included);\n",
    "#   2. a row is correct when its own canonical cluster equals the label of its code;\n",
    "#   3. report the fraction of correct rows, overall and per loop type, ignoring rows\n",
    "#      whose canonical cluster ends in '-*' (the noise clusters).\n",
    "agg_results = None\n",
    "for dihedral_cutoff, canonical_cluster_key in [(0.61, 'canonical_cluster_ssc_comparison'), (0.47, 'canonical_cluster'), (0.1, 'canonical_cluster_strict')]:\n",
    "    vqvae_label_column = f'{canonical_cluster_key}_vqvae'\n",
    "\n",
    "    # One row per code: its size, how many canonical clusters it mixes, and their counts.\n",
    "    vqvae_to_canonical = (\n",
    "        results_df\n",
    "        .groupby('quantized_index')\n",
    "        .agg(\n",
    "            cluster_size=('quantized_index', 'size'),\n",
    "            canonical_cluster_nunique=(canonical_cluster_key, 'nunique'),\n",
    "            canonical_cluster_set=(canonical_cluster_key, Counter),\n",
    "        )\n",
    "        .reset_index()\n",
    "        .sort_values('canonical_cluster_nunique', ascending=False)\n",
    "    )\n",
    "    vqvae_to_canonical['most_common_canonical_cluster'] = vqvae_to_canonical['canonical_cluster_set'].apply(lambda x: x.most_common(1)[0][0] if x else None)\n",
    "    # Replace each Counter by a (cluster, count) list sorted by decreasing count for inspection.\n",
    "    vqvae_to_canonical['canonical_cluster_set'] = vqvae_to_canonical['canonical_cluster_set'].apply(lambda x: sorted(x.items(), key=lambda y: y[1], reverse=True))\n",
    "    vqvae_to_canonical_map = vqvae_to_canonical.set_index('quantized_index')['most_common_canonical_cluster'].to_dict()\n",
    "\n",
    "    results_df[vqvae_label_column] = results_df['quantized_index'].map(vqvae_to_canonical_map)\n",
    "    results_df['correct_assignment'] = results_df[canonical_cluster_key] == results_df[vqvae_label_column]\n",
    "\n",
    "    # Noise mask, computed once and shared by the overall and the per-loop-type scores.\n",
    "    mask = ~results_df[canonical_cluster_key].str.endswith(\"-*\")\n",
    "    purity = results_df.loc[mask, 'correct_assignment'].mean()\n",
    "    print(f\"Purity of VQVAE clusters wrt canonical clusters (D={dihedral_cutoff}): {purity:.3f}\")\n",
    "\n",
    "    agg_results_ = (\n",
    "        results_df[mask]\n",
    "        .groupby('loop_type')['correct_assignment']\n",
    "        .mean()\n",
    "        .rename(f'prop_correct (D={dihedral_cutoff})')\n",
    "        .reset_index()\n",
    "    )\n",
    "    # Outer join: a loop type without non-noise rows at one cutoff stays in the table\n",
    "    # as NaN instead of silently disappearing from every column.\n",
    "    agg_results = agg_results_ if agg_results is None else agg_results.merge(agg_results_, on='loop_type', how='outer')\n",
    "agg_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0d77e80",
   "metadata": {},
   "source": [
    "### How pure are the VQVAE clusters if we use dihedrals only?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "04816e77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Purity of VQVAE clusters wrt canonical clusters (D=0.61): 0.858\n",
      "Purity of VQVAE clusters wrt canonical clusters (D=0.47): 0.859\n",
      "Purity of VQVAE clusters wrt canonical clusters (D=0.1): 0.874\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loop_type</th>\n",
       "      <th>prop_correct (D=0.61)</th>\n",
       "      <th>prop_correct (D=0.47)</th>\n",
       "      <th>prop_correct (D=0.1)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1</td>\n",
       "      <td>0.845829</td>\n",
       "      <td>0.847652</td>\n",
       "      <td>0.934560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H2</td>\n",
       "      <td>0.905320</td>\n",
       "      <td>0.909174</td>\n",
       "      <td>0.920097</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H3</td>\n",
       "      <td>0.730315</td>\n",
       "      <td>0.712418</td>\n",
       "      <td>0.724138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H4</td>\n",
       "      <td>0.962099</td>\n",
       "      <td>0.946108</td>\n",
       "      <td>0.982877</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>L1</td>\n",
       "      <td>0.871034</td>\n",
       "      <td>0.851133</td>\n",
       "      <td>0.857355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>L2</td>\n",
       "      <td>0.992331</td>\n",
       "      <td>0.994792</td>\n",
       "      <td>0.750000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>L3</td>\n",
       "      <td>0.771739</td>\n",
       "      <td>0.773231</td>\n",
       "      <td>0.780838</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>L4</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.872549</td>\n",
       "      <td>0.864583</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  loop_type  prop_correct (D=0.61)  prop_correct (D=0.47)  \\\n",
       "0        H1               0.845829               0.847652   \n",
       "1        H2               0.905320               0.909174   \n",
       "2        H3               0.730315               0.712418   \n",
       "3        H4               0.962099               0.946108   \n",
       "4        L1               0.871034               0.851133   \n",
       "5        L2               0.992331               0.994792   \n",
       "6        L3               0.771739               0.773231   \n",
       "7        L4               0.875000               0.872549   \n",
       "\n",
       "   prop_correct (D=0.1)  \n",
       "0              0.934560  \n",
       "1              0.920097  \n",
       "2              0.724138  \n",
       "3              0.982877  \n",
       "4              0.857355  \n",
       "5              0.750000  \n",
       "6              0.780838  \n",
       "7              0.864583  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Purity of the codes in `quantized_index_dihedral_only` against the canonical clusters,\n",
    "# repeated for the canonical clustering at each dihedral cutoff D:\n",
    "#   1. label every code with the canonical cluster that is most frequent among its rows\n",
    "#      (all rows take part in this vote, noise clusters included);\n",
    "#   2. a row is correct when its own canonical cluster equals the label of its code;\n",
    "#   3. report the fraction of correct rows, overall and per loop type, ignoring rows\n",
    "#      whose canonical cluster ends in '-*' (the noise clusters).\n",
    "# NOTE: this overwrites the `<key>_vqvae` and `correct_assignment` columns of results_df.\n",
    "agg_results = None\n",
    "for dihedral_cutoff, canonical_cluster_key in [(0.61, 'canonical_cluster_ssc_comparison'), (0.47, 'canonical_cluster'), (0.1, 'canonical_cluster_strict')]:\n",
    "    vqvae_label_column = f'{canonical_cluster_key}_vqvae'\n",
    "\n",
    "    # One row per code: its size, how many canonical clusters it mixes, and their counts.\n",
    "    vqvae_to_canonical = (\n",
    "        results_df\n",
    "        .groupby('quantized_index_dihedral_only')\n",
    "        .agg(\n",
    "            cluster_size=('quantized_index_dihedral_only', 'size'),\n",
    "            canonical_cluster_nunique=(canonical_cluster_key, 'nunique'),\n",
    "            canonical_cluster_set=(canonical_cluster_key, Counter),\n",
    "        )\n",
    "        .reset_index()\n",
    "        .sort_values('canonical_cluster_nunique', ascending=False)\n",
    "    )\n",
    "    vqvae_to_canonical['most_common_canonical_cluster'] = vqvae_to_canonical['canonical_cluster_set'].apply(lambda x: x.most_common(1)[0][0] if x else None)\n",
    "    # Replace each Counter by a (cluster, count) list sorted by decreasing count for inspection.\n",
    "    vqvae_to_canonical['canonical_cluster_set'] = vqvae_to_canonical['canonical_cluster_set'].apply(lambda x: sorted(x.items(), key=lambda y: y[1], reverse=True))\n",
    "    vqvae_to_canonical_map = vqvae_to_canonical.set_index('quantized_index_dihedral_only')['most_common_canonical_cluster'].to_dict()\n",
    "\n",
    "    results_df[vqvae_label_column] = results_df['quantized_index_dihedral_only'].map(vqvae_to_canonical_map)\n",
    "    results_df['correct_assignment'] = results_df[canonical_cluster_key] == results_df[vqvae_label_column]\n",
    "\n",
    "    # Noise mask, computed once and shared by the overall and the per-loop-type scores.\n",
    "    mask = ~results_df[canonical_cluster_key].str.endswith(\"-*\")\n",
    "    purity = results_df.loc[mask, 'correct_assignment'].mean()\n",
    "    print(f\"Purity of VQVAE clusters wrt canonical clusters (D={dihedral_cutoff}): {purity:.3f}\")\n",
    "\n",
    "    agg_results_ = (\n",
    "        results_df[mask]\n",
    "        .groupby('loop_type')['correct_assignment']\n",
    "        .mean()\n",
    "        .rename(f'prop_correct (D={dihedral_cutoff})')\n",
    "        .reset_index()\n",
    "    )\n",
    "    # Outer join: a loop type without non-noise rows at one cutoff stays in the table\n",
    "    # as NaN instead of silently disappearing from every column.\n",
    "    agg_results = agg_results_ if agg_results is None else agg_results.merge(agg_results_, on='loop_type', how='outer')\n",
    "agg_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bea31137",
   "metadata": {},
   "source": [
    "### How pure are the VQVAE clusters if we use sequence only?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "56fd1dae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Purity of VQVAE clusters wrt canonical clusters (D=0.61): 0.855\n",
      "Purity of VQVAE clusters wrt canonical clusters (D=0.47): 0.865\n",
      "Purity of VQVAE clusters wrt canonical clusters (D=0.1): 0.862\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loop_type</th>\n",
       "      <th>prop_correct (D=0.61)</th>\n",
       "      <th>prop_correct (D=0.47)</th>\n",
       "      <th>prop_correct (D=0.1)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>H1</td>\n",
       "      <td>0.819430</td>\n",
       "      <td>0.838488</td>\n",
       "      <td>0.856851</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>H2</td>\n",
       "      <td>0.902615</td>\n",
       "      <td>0.908257</td>\n",
       "      <td>0.898305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>H3</td>\n",
       "      <td>0.596457</td>\n",
       "      <td>0.552288</td>\n",
       "      <td>0.224138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>H4</td>\n",
       "      <td>0.997085</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.996575</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>L1</td>\n",
       "      <td>0.889458</td>\n",
       "      <td>0.889968</td>\n",
       "      <td>0.875186</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>L2</td>\n",
       "      <td>0.995399</td>\n",
       "      <td>0.994792</td>\n",
       "      <td>0.125000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>L3</td>\n",
       "      <td>0.807453</td>\n",
       "      <td>0.804774</td>\n",
       "      <td>0.827545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>L4</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.872549</td>\n",
       "      <td>0.875000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  loop_type  prop_correct (D=0.61)  prop_correct (D=0.47)  \\\n",
       "0        H1               0.819430               0.838488   \n",
       "1        H2               0.902615               0.908257   \n",
       "2        H3               0.596457               0.552288   \n",
       "3        H4               0.997085               1.000000   \n",
       "4        L1               0.889458               0.889968   \n",
       "5        L2               0.995399               0.994792   \n",
       "6        L3               0.807453               0.804774   \n",
       "7        L4               0.875000               0.872549   \n",
       "\n",
       "   prop_correct (D=0.1)  \n",
       "0              0.856851  \n",
       "1              0.898305  \n",
       "2              0.224138  \n",
       "3              0.996575  \n",
       "4              0.875186  \n",
       "5              0.125000  \n",
       "6              0.827545  \n",
       "7              0.875000  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Purity of the codes in `quantized_index_sequence_only` against the canonical clusters,\n",
    "# repeated for the canonical clustering at each dihedral cutoff D:\n",
    "#   1. label every code with the canonical cluster that is most frequent among its rows\n",
    "#      (all rows take part in this vote, noise clusters included);\n",
    "#   2. a row is correct when its own canonical cluster equals the label of its code;\n",
    "#   3. report the fraction of correct rows, overall and per loop type, ignoring rows\n",
    "#      whose canonical cluster ends in '-*' (the noise clusters).\n",
    "# NOTE: this overwrites the `<key>_vqvae` and `correct_assignment` columns of results_df.\n",
    "agg_results = None\n",
    "for dihedral_cutoff, canonical_cluster_key in [(0.61, 'canonical_cluster_ssc_comparison'), (0.47, 'canonical_cluster'), (0.1, 'canonical_cluster_strict')]:\n",
    "    vqvae_label_column = f'{canonical_cluster_key}_vqvae'\n",
    "\n",
    "    # One row per code: its size, how many canonical clusters it mixes, and their counts.\n",
    "    vqvae_to_canonical = (\n",
    "        results_df\n",
    "        .groupby('quantized_index_sequence_only')\n",
    "        .agg(\n",
    "            cluster_size=('quantized_index_sequence_only', 'size'),\n",
    "            canonical_cluster_nunique=(canonical_cluster_key, 'nunique'),\n",
    "            canonical_cluster_set=(canonical_cluster_key, Counter),\n",
    "        )\n",
    "        .reset_index()\n",
    "        .sort_values('canonical_cluster_nunique', ascending=False)\n",
    "    )\n",
    "    vqvae_to_canonical['most_common_canonical_cluster'] = vqvae_to_canonical['canonical_cluster_set'].apply(lambda x: x.most_common(1)[0][0] if x else None)\n",
    "    # Replace each Counter by a (cluster, count) list sorted by decreasing count for inspection.\n",
    "    vqvae_to_canonical['canonical_cluster_set'] = vqvae_to_canonical['canonical_cluster_set'].apply(lambda x: sorted(x.items(), key=lambda y: y[1], reverse=True))\n",
    "    vqvae_to_canonical_map = vqvae_to_canonical.set_index('quantized_index_sequence_only')['most_common_canonical_cluster'].to_dict()\n",
    "\n",
    "    results_df[vqvae_label_column] = results_df['quantized_index_sequence_only'].map(vqvae_to_canonical_map)\n",
    "    results_df['correct_assignment'] = results_df[canonical_cluster_key] == results_df[vqvae_label_column]\n",
    "\n",
    "    # Noise mask, computed once and shared by the overall and the per-loop-type scores.\n",
    "    mask = ~results_df[canonical_cluster_key].str.endswith(\"-*\")\n",
    "    purity = results_df.loc[mask, 'correct_assignment'].mean()\n",
    "    print(f\"Purity of VQVAE clusters wrt canonical clusters (D={dihedral_cutoff}): {purity:.3f}\")\n",
    "\n",
    "    agg_results_ = (\n",
    "        results_df[mask]\n",
    "        .groupby('loop_type')['correct_assignment']\n",
    "        .mean()\n",
    "        .rename(f'prop_correct (D={dihedral_cutoff})')\n",
    "        .reset_index()\n",
    "    )\n",
    "    # Outer join: a loop type without non-noise rows at one cutoff stays in the table\n",
    "    # as NaN instead of silently disappearing from every column.\n",
    "    agg_results = agg_results_ if agg_results is None else agg_results.merge(agg_results_, on='loop_type', how='outer')\n",
    "agg_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "945da014",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
