{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import lstnn.transformer_main as transformer_main\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.stats as stats\n",
    "plt.rcParams['font.sans-serif'] = \"Arial\"\n",
    "sns.set_style(\"ticks\")\n",
    "from lstnn.dataset import get_dataset\n",
    "import numpy as np\n",
    "import os\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experimental parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_label = 'Transformer'\n",
    "curriculum = 'All'\n",
    "seeds = [2235, 6312, 6068, 9742, 8880, 2197, 669, 6256, 3309, 2541, 8643, 7785, 195, 6914, 29]\n",
    "# seeds = [2235, 6312, 6068, 9742, 8880, 2197, 669, 6256, 3309] #, 2541, 8643, 7785, 195, 6914, 29]\n",
    "# seeds = [2235, 6312, 6068, 9742]\n",
    "#seeds = [6914, 29]\n",
    "device = 'mps'\n",
    "\n",
    "# model params\n",
    "nblocks = [4] #, 5]\n",
    "attnheads = [1] #, 4, 8]\n",
    "# wdecays = [0.0, 0.01, 0.05, 0.1, 0.2] #wdecay = 0.0\n",
    "wdecays = [0.0] #, 0.1] #wdecay = 0.0\n",
    "dropout = 0.0\n",
    "hidden_size = 160 #160\n",
    "learning_rate = 0.0001\n",
    "training_acc_cutoff = 0.0\n",
    "cutoff_length = 0  # how many epochs must the model sustain the accuracy cutoff?\n",
    "last_epoch = 4000\n",
    "checkpoint_freq = 200\n",
    "\n",
    "initializations = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, \n",
    "                   1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]\n",
    "outputdir = '../figures/manuscript_figures/perturbation_analysis/'\n",
    "if not os.path.exists(outputdir):\n",
    "    os.makedirs(outputdir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Robustness to noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inference for epoch 4000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tito/mambaforge/envs/lstnn/lib/python3.9/site-packages/torch/utils/checkpoint.py:542: UserWarning: torch.utils.checkpoint.checkpoint_sequential: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/Users/tito/mambaforge/envs/lstnn/lib/python3.9/site-packages/torch/utils/checkpoint.py:90: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n",
      "/Users/tito/mambaforge/envs/lstnn/lib/python3.9/site-packages/torch/utils/checkpoint.py:542: UserWarning: torch.utils.checkpoint.checkpoint_sequential: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inference for epoch 4000\n",
      "inference for epoch 4000\n",
      "inference for epoch 4000\n",
      "inference for epoch 4000\n",
      "inference for epoch 4000\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# load the 108 fMRI trials\n",
    "validation_file = '../data/nn/puzzle_data_original.csv'\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    get_dataset(validation_file), batch_size=108, shuffle=False)\n",
    "\n",
    "# ann_accuracy = np.zeros((108, len(seeds), nblocks, attnheads)) # may want to add epochs\n",
    "ann_accuracy = {}\n",
    "ann_accuracy['Accuracy'] = []\n",
    "ann_accuracy['Noise'] = []\n",
    "ann_accuracy['Epoch'] = []\n",
    "ann_accuracy['Seed'] = []\n",
    "ann_accuracy['Layers'] = []\n",
    "ann_accuracy['Heads'] = []\n",
    "ann_accuracy['Dropout'] = []\n",
    "ann_accuracy['Decay'] = []\n",
    "ann_accuracy['Puzzle'] = []\n",
    "ann_accuracy['PE'] = []\n",
    "ann_accuracy['Norm'] = []\n",
    "dropout = 0.0\n",
    "# for epoch in np.arange(0, last_epoch+1, checkpoint_freq):\n",
    "for noise_sd in [0.0, 0.02, 0.04, 0.06, 0.08, 0.1]: #[0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]:\n",
    "    for epoch in [4000]:\n",
    "        print('inference for epoch', epoch)\n",
    "        for init in initializations:\n",
    "            pestr = 'learn-' + str(init)\n",
    "            petype = 'learn'\n",
    "            for wdecay in wdecays:\n",
    "                for layer in nblocks:\n",
    "                    for attnhead in [1]: #attnheads:\n",
    "                            resultdir = f\"../results/\"\n",
    "                            modelname = f\"model-{model_label}_\" \\\n",
    "                                        f\"pe-learn-{init}_\" \\\n",
    "                                        f\"nl-{layer}_\" \\\n",
    "                                        f\"do-{dropout}_\" \\\n",
    "                                        f\"wd-{wdecay}_\" \\\n",
    "                                        f\"at-{attnhead}_\" \\\n",
    "                                        f\"hs-{hidden_size}_\" \\\n",
    "                                        f\"curr-{curriculum}_\" \\\n",
    "                                        f\"lr-{learning_rate}_\" \\\n",
    "                                        f\"co-{training_acc_cutoff}_\" \\\n",
    "                                        f\"col-{cutoff_length}/\"\n",
    "                            for seed in seeds:\n",
    "                                try: \n",
    "                                    checkpoint = f\"s-{seed}_\" \\\n",
    "                                                f\"e-{epoch}\" \n",
    "                                    torch.manual_seed(seed)\n",
    "                                    model = transformer_main.Transformer(\n",
    "                                                nblocks=layer,\n",
    "                                                nhead=attnhead,\n",
    "                                                dropout=dropout,\n",
    "                                                embedding_dim=hidden_size,\n",
    "                                                positional_encoding=petype,\n",
    "                                                pe_init=init)\n",
    "                                    model = model.to(device=torch.device('mps'))\n",
    "                                    model.load_state_dict(torch.load(resultdir + modelname + checkpoint +'.pt',map_location=torch.device('mps') ))\n",
    "                                except:\n",
    "                                    continue\n",
    "\n",
    "                                with torch.no_grad():\n",
    "                                    for i, batch in enumerate(dataloader):\n",
    "\n",
    "                                        # get features\n",
    "                                        test_features, test_labels, index = batch[0], batch[1], batch[2]\n",
    "\n",
    "                                        # flatten to accommodate transformer\n",
    "                                        test_features = torch.flatten(test_features,start_dim=1,end_dim=2)\n",
    "                                        test_features = test_features.to(device)\n",
    "                                        test_labels = test_labels.to(device)\n",
    "                                        #model_history = tl.log_forward_pass(model, test_features, vis_opt='none')\n",
    "                                        #model_history = tl.log_forward_pass(model, test_features, vis_opt='rolled', vis_outpath='transformer.svg')\n",
    "\n",
    "                                        # add noise\n",
    "                                        test_features = test_features + torch.empty(test_features.size(),device='mps').normal_(mean=0,std=noise_sd)\n",
    "\n",
    "                                        # calculate norm of learned PEs\n",
    "                                        norm = []\n",
    "                                        for block in model.blocks:\n",
    "                                            norm.append(torch.norm(block.pe.positional_encoding).cpu().item())\n",
    "\n",
    "\n",
    "                                        # Compute prediction and loss\n",
    "                                        out = model(test_features)\n",
    "                                        accuracy = torch.argmax(out, dim=1) == torch.argmax(\n",
    "                                                        test_labels, dim=1)\n",
    "                                        accuracy = accuracy.cpu().numpy() * 1.0\n",
    "                                        ann_accuracy['Accuracy'].extend(accuracy)\n",
    "                                        ann_accuracy['Noise'].extend(np.repeat(noise_sd,len(accuracy)))\n",
    "                                        ann_accuracy['Puzzle'].extend(np.arange(len(accuracy)))\n",
    "                                        ann_accuracy['Seed'].extend(np.repeat(seed,len(accuracy)))\n",
    "                                        ann_accuracy['Epoch'].extend(np.repeat(epoch,len(accuracy)))\n",
    "                                        ann_accuracy['Layers'].extend(np.repeat(layer,len(accuracy)))\n",
    "                                        ann_accuracy['Heads'].extend(np.repeat(attnhead,len(accuracy)))\n",
    "                                        ann_accuracy['Dropout'].extend(np.repeat(dropout,len(accuracy)))\n",
    "                                        ann_accuracy['Decay'].extend(np.repeat(wdecay,len(accuracy)))\n",
    "                                        ann_accuracy['PE'].extend(np.repeat(pestr,len(accuracy)))\n",
    "                                        ann_accuracy['Norm'].extend(np.repeat(np.mean(norm),len(accuracy)))\n",
    "                                    # ann_accuracy[:, s] = accuracy.copy()\n",
    "ann_accuracy = pd.DataFrame(ann_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot norm of learned PEs per init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAG4AAAB3CAYAAADignWoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAQDElEQVR4nO2de0xTdxTHvy0dKAg4FZ9lvgIqoJmCVjcfgJnbmFFxImzQacEITjDifDGMbCDTLJuooHMBJwbnBKMOdUbnNFM3xSkhTCuSQDBAjYJBhaL0efaH6Z3lIW3pbblwPwmh9/be3+/cfHvO797ze1wBERF4OIfQ3gbwWAYvHEfhheMovHAchReOo/DCcRReOI7CC8dReOE4isjeBlhCcnIy7ty5A41Gg+rqaowdOxYAsHTpUkRGRppUxsKFC1FYWNhlWzIzMwEACQkJRvulUikePnwIZ2dnAIBWq8Wnn36KyMhI3LhxA3FxcXjrrbeMzpHJZFi0aJFpFROHqampoaCgILvasGfPHtqzZ0+b/VFRUVRUVMRs19XVUUBAAFVUVFBRURFFRUV1qV5OetzrCA4OxsSJE3Hv3j0cOnQIv/zyC65fv47GxkZ4eHhg586d8PDwwLhx41BeXo7MzEw8evQI1dXVUCgUmD17NlJSUgAABw8exOnTp6HVajF16lQkJSVBJBIhJycH+fn5GDBgANzc3DBp0qRO7fLw8MCoUaNQUVGB/v37d/k6e2QbN3PmTJw/fx4qlQoVFRU4evQozp07B7FYjNOnT7c5vqysDDk5OTh16hT++OMPlJeX49q1aygpKcGxY8dQWFgItVqN/Px83L59G8eOHcPJkyeRm5uLuro6k2ySy+WoqqqCr68vAODOnTtYuHCh0V9tba3J19jjPA4AJk+eDAAYOXIkNm3ahIKCAlRVVaG4uBgjRoxoc/yMGTPg6OgIR0dHjBw5Es+ePcPVq1dRWlqKxYsXAwBUKhVEIhFaWloQGBiIfv36AQDmzZsHvV7frh1btmyBs7Mz9Ho9+vbti7S0NIjFYigUCvj5+SEvL8/ia+yRwvXp0wcA8O+//+KLL75ATEwMPvjgAzg4OIDa6cVycnJiPgsEAhARdDodli1bhujoaABAU1MTBAIBCgoKjMoQiURQq9Xt2rFt2zZIJBJrXhpDjwyVBoqLiyGRSBAREYFRo0bh8uXL0Ol0Jp07ffp0/Prrr1AqldDpdEhMTMTx48cxY8YMXLp0CY2NjVCpVLhw4QLLV9E+PdLjDISEhGD16tWYN28enJyc4Ofnh5qaGpPODQ4ORnl5OcLDw6HT6TBt2jRERkZCJBJBJpNhyZIlcHd3x/Dhwy2yzdDGvUpgYCASExNNOl9A7cUOnm5Pjw6VPRleOI7CC8dReOE4Ci8cR+nRjwOvcuPGDfzwww9wdXVFZWUlhgwZgoyMDJSWliIjIwM6nQ6enp5ITU3FoEGDjHKe6enp+Prrr+Ht7Q25XA4fHx9IJBKcOHECT58+RVZWFry8vJCVlYXz589Dp9MhICAAqamp7F1Ql1LUHKKoqIjefvttUigUREQUFxdHmZmZNHPmTKquriYiouzsbEpISCAioqCgICooKCCil70Q3t7edPv2bdJqtTR37lz67rvviIho9+7dlJ6eTo2NjTR9+nTSarWk1Wpp8+bN9PDhQ9aup1eFSi8vL+aBecKECQCAiRMnwtPTEwAQHh6OoqIi5nhDzhN4md338/ODg4MDhg0bhhkzZgAAxGIxGhsb4erqCm9vb4SFhWHv3r2Ijo7GkCFDWLuWXiVc65zkxYsXIRAImH1EBI1Gw2wbcp4A4OjoaFSWg4NDm/Jzc3ORnJwMvV6PmJgY/PPPP9Y034heJVxrJk2ahNLSUiYNlp+fj2nTpllUVm1tLRYsWAAfHx+sXbsW7777LsrLy61prhG95uakPQYNGoTU1FTEx8dDq9Vi2LBh2LZtm0VlicVihISE
IDQ0FM7Ozhg+fDhCQ0OtbPH/8LlKjtKrQyWX4YXjKLxwHIUXjqPwwnEUzglHRFAqle0O+ulNsC5cQkICamtrUVJSgrCwMEilUuaBNyMjAxEREcwAVFNobm6Gv78/mpub2TLZpuj1elRVVZl9HmvCqdVqrF69GqWlpQCA/fv3Izs7GykpKcjOzsaDBw9QXV2No0ePwsXFBSUlJR2Wo1Qqjf56ClVVVZgzZw7GjBmD7Oxss85lLXOiVquxfPlyHD9+HADQ0tKC/v37o3///rh//z7u3r2LKVOmAAACAgJQWlpqlNQ18OOPPyIrK4stM+0CEeHQoUNYs2YNmpqa4OzsDA8PD7PKMFs4jUaDsrIyqFQqZt/UqVPbHNevXz9MnTqVEe7V0b6Gdsowk6Vv374dhr7Y2FjIZDJmW6lUYs6cOeaa3W2or69HbGwsTp48CeBl70RaWhrmz59vVjlmCxcdHQ1XV1e4u7sz+9oTrjWvZuGFQiFcXFzw6NEjAMCLFy+YId2tMQwN7wmcOXMGMTExqKurg0gkwooVK7Bly5Z2h8V3htnCOTg4YN++fWZX1KdPHzQ0NKChoQEjRozAhAkT8NtvvyEqKgq3bt1CcHCw2WVyBaVSiXXr1jHt2OjRo5GWloawsDCLf5RmCzdr1izk5+djzJgxzD5TPO7zzz9HXFwcBAIBvv32W4jFYowcOZIZHu7v72+uKZzg+vXrkEqlqKyshEAgQHh4ODMMoiuY3TsQFxcHoVBoFCq3b9/eJSPMQalUwt/fH8XFxR2G1+6AWq1Gamoqtm/fDr1ej8GDByMlJQXR0dFGHbQWY+5Yh+XLl1tx5IT5NDU1kbe3NzU1NdnVjtchl8tpypQpBIAA0Pvvv083b94kvV5vtTrMDpWenp7Izc2Fj48Pc8NhSqjsDej1emRmZmLTpk1QqVRwc3PD5s2bER8fD1dXV6vWZdHjQHl5uVG3PC8cUFNTA5lMhosXLwIAJBIJ0tPTERQUBKGQhTyHuS66detWq7m7JXTHUPnzzz+Tu7s7ASAnJyfasGEDPX78mNU6zf4p1NXVWZRb64k0NDTgk08+QWRkJJ49ewYfHx8cOXIE27dvx8CBA1mt2+xQWVtbi5iYGGbKrWGYW2/jwoULkMlkUCgUcHBwgEwmQ0pKCsRisW0MsMRN9Xo91dfXk0ajsXIA6Bx7h8rnz5/TmjVrmDtGT09Pys3NJZVKZVM7zBauqKiI5s2bR0uXLqXg4GD6888/2bCrQ+wp3K1bt2j8+PGMaIsXL6a7d+/a3A4iC4SLiIhgGt76+nr6+OOPrW7U67CHcBqNhtLS0kgkEhEAGjhwIO3evZuam5ttZkNrzG7jhEIh0/AOGjTIaFh3T6SyshKfffYZrl27BgAICgpCeno6pk+fbpQ4tzVmCzd8+HDs2rULAQEBKC4uxrBhw9iwy+4QEQ4cOIC1a9eiubkZLi4uWL9+PdauXWuVJZ26jLkuqtFo6PDhw/TVV19RXl6ezRtlW4TKhw8f0oIFC5i2bPLkyXTmzBnS6XSs1WkuJnvczZs3mc/e3t7w8vKCQCBAaWlpj8qcnDp1CitWrEB9fT3eeOMNrFq1Cl9++SWrU6YswWThTpw40WbflStXoNPpjOaUcZWmpiasW7cOOTk5AICxY8cyfWYiUTecG2OJm9bV1VFsbCytWrWK6uvrrR0FXgsbofLvv/+mMWPGEAASCAQUGRlJFRUVViufDcwWrrCwkN577z0qLCxkw55OsaZwKpWKkpOTSSgUEgAaOnQo7d+/n1paWqxgKbuYLNzjx49p9erVFBsbS3V1dWza9FqsJdzdu3eN+sw+/PBDKi4utpKV7GNyD7hEIoGDg0O7I6y41AOu1+uxd+9ebNy4ES0tLXBzc0NSUhLi4+O7dY96a0xudQ2LRnMZhUKB6Oho/P777wBeLm2YlpaGuXPn2vVh2iIsddVdu3ZZz+/NwNJQmZ+fT2+++SbTZ7Zx40bW+8zYxOL7
3OLiYiv+fNjj6dOnSEhIwOHDhwG8XCYjLS0NixYtanflBK5gsXCGUcjdmUuXLmHZsmWora2FUCiETCbD1q1b26z3z0nMcc+qqiqqra012ieXy60aAjrDlFD54sULSkxMZO4YxWIx/fTTTzZPz7GJycJlZWVRREQEhYWFUUpKCpO3k0qlrBnXHp0JV1JSQr6+voxooaGhdPv2bZvaaAtMHnNy5coVHDlyBAUFBXB3d8eWLVsMHstKJDAXnU6HHTt2YNq0aZDL5RgwYAB27tyJw4cPw8/Pz97mWR2ThdPr9cyMm8TERAgEAmRkZHSL2+jKykrMnj0bSUlJ0Gg0mDNnDgoLC5GYmMiJttgiTHXNY8eOUUhICDU0NBARkVarpfXr15Ovry9b0aBdXg2Ver2eDhw4QC4uLgSAXFxcKCUlhZ48eWJTm+yBWTcnSqWyTZ/U+PHjrWpQZxiEu3//PoWGhjJt2ZQpU+j06dPdqs+MTbq8XuXixYstPlelUlFcXByFh4dTfn6+SecYhPvoo48IAIlEIoqPj6cHDx5YbAcX6fLY6K60cWfPnsWsWbNw5MgRnDt3zmiWa2cEBgZCIpHg4MGD2LlzZ48dQtERJj+AJyUltbtfoVBYXLlcLseSJUsgFArh7e2NyspK+Pj4GB2jVquN3l3T1NQEAIiKisLKlSshFAqhUqnMEr274+Li0qlDmCxcR0v4dWVpv+bmZqN54M+fP29zTEeT97k8D7wzTOn5MFk4SxfgfB3Ozs548eIFgI7ngbeevK/X6/Hs2TOIRCIEBgbi8uXLrHfHGBYMsEVdwEuP6wy7Dqbw9fXFzZs34eXlhbKysnZfCNTe5H03NzdmvZN+/frZrB/NlnV1hl2XhAoJCcG1a9ewZMkS5o1TPKZhV49zcnLC3r177WkCZ+HcImwGHB0dER8fb5M1UGxZl6nwazJzFM56XG+HF46j8MJxFM4Ip1arsWrVKkRERKCgoKDN94YFTa1JR2WyUZe5cEa4jhLSrRc0tQYdlclGXZbCGeHkcjn8/f2NEtLA/wuavvPOO1arq6My2ajLUjgjXEcJacOCptakozLZqMtSOCOcKQnp3gRnhDMkpIkIZWVlGD16tL1NsiucEa51QjovLw8VFRU2qTsnJ8dmdZkKn/LiKJzxOB5jeOE4Ci8cR+GF4yi8cByFF46j8MJxlG641pH1uXHjBtatW8e8naSlpQUymQwhISEIDg5u826b3bt3Y8CAAcx2YmIiMjIy2i37ypUrePz4MebPn4+zZ89i0aJFrF2HEXaduWAjioqKaNOmTcz2kydPKCgoiIiI+d9VampqKCoqyiplmUKv8LjWGN7ZZirBwcG4dOkSpFIpxo8fj3v37kGv12Pfvn24ePEiFAoF6uvrUVZWhpycHKxYsYJF61/Sa4S7evUqpFIpBAIB+vbti2+++Yb5TiqVMp8HDx6M77//vsNyJBIJkpOTkZycjL/++ovZv3LlSlRVVdlENKAXCTdr1izs2LGj3e/y8vJMLmfcuHEAgKFDh9p1hhB/V2kmHU1/EgqFRm+lZJte43Gv49VQCQAbNmzApEmTzCpj4MCBUCqVyMrKwuTJk1FWVsZq2OS7dTgKHyo5Ci8cR+GF4yi8cByFF46j8MJxFF44jsILx1F44TgKLxxH+Q8sECEYkZSITQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 120x130 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tmpdf = ann_accuracy.groupby(['PE','Seed']).mean(numeric_only=True).reset_index()\n",
    "plt.figure(figsize=(1.2,1.3))\n",
    "# ax = sns.boxplot(data=tmpdf,x=\"PE\",hue=\"PE\",y=\"Norm\",fliersize=1,palette='rocket')\n",
    "ax = sns.lineplot(data=tmpdf,x=\"PE\",y=\"Norm\",color='k')\n",
    "# cbar = ax.collections[0].colorbar\n",
    "# # Label the colorbar\n",
    "# cbar.set_label('Accuracy', fontsize=8)  # Change 'Label' to your desired label text\n",
    "# # Change the size of tick labels\n",
    "# cbar.ax.tick_params(labelsize=8)\n",
    "# plt.xticks(np.arange(0, len(initializations)),initializations,rotation=90,fontsize=7);\n",
    "plt.xticks(np.arange(0,len(initializations),10),np.asarray(initializations)[::10],fontsize=6)\n",
    "plt.yticks(fontsize=6);\n",
    "plt.ylabel('L2-Norm', size=7,labelpad=1.5)\n",
    "plt.xlabel('PE init.',fontsize=7)\n",
    "plt.title('Trained PE\\nnorms',fontsize=8)\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'{outputdir}norm_learnablePE.pdf',transparent=True,dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "noise_levels = ann_accuracy.Noise.unique()\n",
    "# models = ann_accuracy.Model.unique()\n",
    "Layers = 4\n",
    "Heads = 1\n",
    "Decay = 0.0\n",
    "noise_performance = np.zeros((len(initializations),len(noise_levels)))\n",
    "\n",
    "i = 0\n",
    "for init in initializations:\n",
    "    pestr = 'learn-' + str(init)\n",
    "    j = 0\n",
    "    for noise in noise_levels:\n",
    "        tmpdf = ann_accuracy.loc[(ann_accuracy.PE==pestr) & (ann_accuracy.Noise==noise)]\n",
    "        noise_performance[i, j] = tmpdf.Accuracy.mean()\n",
    "        j += 1\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQkAAACMCAYAAABvYly7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh/klEQVR4nO3de1RU5foH8O8wXAIZFS31ICcJw0tSEhZ6oOJmqSSgIF5BEgklEVRWC/KWcklTjx4RU7RS1LwDlh4yzdLjBSXTBLmkoiigggrIVYYZ3t8f/pgcmBnejUMgPp+1WEv3vNfZ8MyevZ/9bhFjjIEQQtTQaesBEELaNwoShBCNKEgQQjSiIEEI0YiCBCFEIwoShBCNKEgQQjSiIEEI0YiCBCFEo+c+SBQUFMDKygoeHh4YM2YM3N3dMXnyZFy5ckVQO7GxsTh//rygOs7OzigoKOAun56ejpUrVwIAjh07hrVr1wrqr7UJnQ/Q8nkUFBTA2dlZcD1tOHfuHHx9fVW+1r9/f+52mvvdi4iIgKOjIzw8PJR+ysvLtTIPXrp/a2/tVI8ePfD9998r/r9r1y4sWrQIe/bs4W7jt99+w9ChQ1tjeArXrl3DgwcPAAAuLi5wcXFp1f7+Dh1lHi3V3O9eSEgIPD0922p4AChIqGRra4sVK1YAAPLz87FkyRKUlJRAX18f4eHhsLGxQUREBEpLS3Hr1i14eXnh8uXLWLhwIWJjY/HFF18gODgYQ4cORUFBAaZOnYpffvlFqU5oaCgAYP369cjJyYGenh4iIyMxYMAAXLlyBVFRUaipqUFJSQkCAgIwatQoxMbGorq6GnFxcTA1NUVaWhqWL1+O9PR0REdH49GjRzAxMUFkZCT69OkDX19fDB48GOfPn0dxcTGCg4Ob/MJVVlZi/vz5KCoqwr179zBs2DDExMQgLS0NGzZsgEQiQW5uLnr27Ik1a9aga9eu2LFjB3744QfU1NRAT08PK1euRN++fRVt+vr6IiAgAA4ODgAANzc3xMXF4eDBg/jpp58gl8vx1ltvITIyEklJSYp5xMXFNXmdR0lJCZYsWaI4igkODoazszOKioqwYMEClJeX4/79+3Bzc8PcuXORlJSE5ORklJWV4Z133kFpaSkkEgmysrJw584djB8/HjNnzlT73gBAaWkpAgICcPfuXbzxxhtYsmQJ9PX1FWOqqalBdHQ0srOzIZPJ4OvrC29vb0G/e5rcvHkT4eHhqKmpgVgsxsKFC2FjY8P1fgnGnnP5+fnMyclJ8X+5XM5WrVrF/P39GWOMTZ48mV2+fJkxxtjNmzeZs7Mzq6urY+Hh4SwsLExRz8fHh509e7bJv59sv3EdJycntn79esYYY8ePH2ceHh6MMcaio6PZqVOnGGOMFRQUMGtra8YYY4mJiSw8PFzp31KplDk5ObGLFy8yxhhLSUlhnp6einFERkYyxhjLzMxktra2TeZ/8OBBxRjq6urYBx98wDIyMtjZs2eZtbU1KywsZIwxNnPmTLZt2zZWUVHBfH19WU1NDWOMsXXr1in6cHJyYvn5+Sw5OZnNnTuXMcZYRkYGmzx5MisvL2fDhg1jMpmMyWQyFhERwe7evauYh7rXefZbWFgYO3r0KGOMsQcPHrDhw4ez+/fvs6+//prt37+fMcZYZWUls7GxYQ8ePGCJiYnM2dmZSaVSxX4JCgpicrmcFRUVscGDB7OHDx9qfG/eeOMNduPGDVZfX89CQ0PZ1q1bGWOM9evXjzHG2OrVq9mWLVsYY4xVVVWxMWPGsOzsbLVzYKzp7154eDhzcHBg7u7uip9PP/2UMcbYF198wRISEhS/O5s3b1b7Xj0tOpIAUFxcDA8PDwCAVCqFpaUlIiMjUVVVhUuXLmH+/PmKsnV1dbhz5w4A4M033xTcV+M6Xl5eAAAHBwd8+umnKC8vR3h4OE6cOIH4+Hjk5OSgurpabXt5eXmQSCSwtrYGAIwa
NQqLFy9GRUWFol0AGDhwIMrKyprUHz16NC5cuICtW7ciNzcXJSUliv4sLS1hamqqqP/w4UMYGxtj1apVOHToEPLy8nDq1CkMHDhQqc0RI0Zg5cqVqKysRHJyMjw9PSGRSNCvXz94e3vD0dER/v7+6Nmzp6JOc69rcvLkSVy5cgXr1q0DAMhkMly/fh3+/v44c+YMvv76a1y9ehVSqRQ1NTUAACsrK+jp6SnaePfdd6Gjo4MePXrAxMQEFRUVGt+bIUOGwNzcHMDjI6WkpCT4+fkpjammpgbJyckAHh+xXblyBQMGDFAau7rfvQbqvm689957CAsLQ3p6Ot577z2150i0gYIEmn4vbFBRUQF9fX2l14qKivDSSy8BAAwNDdW2yf7/DnyZTKa0vXEdsVisVEdXVxehoaHo2rUrnJyc4OrqipSUFLX9yOVyiESiJn039GtgYAAATco0SEhIwJEjRzBx4kTY2dnh6tWrirE31G2ozxjD7du34ePjAz8/Pzg4OOCll15CdnZ2kzk6Ozvj8OHD+PXXXzFv3jwAwNatW3HhwgWcPHkS06dPx6pVq5TqqXrd1tZW7dwb1NfXIyEhASYmJgAe/+F169YNy5YtQ2FhIdzc3PD+++8jNTVVMbfG+0HVXDW9Nzo6f53zZ4wp7ceGMa1cuRJWVlYAgAcPHkAikTQZu7rfvebY29sjJSUFx48fR0pKCpKTk7FlyxbB7fB47q9uaCKRSGBubq74NDh//jw8PT2b/OEDj//Y5XI5AMDExAR//vknAODo0aMa+zh06JCiXN++fWFkZITU1FSEhIRg+PDhOHXqFIDHwUAsFjfp28LCAmVlZfjjjz8AACkpKejVq5fiD6Y5qampmDBhAtzc3CCVSpGTk4P6+nq15S9fvgxzc3P4+fnh9ddfx88//6yY95O8vLwQFxeHoUOHolOnTigoKIC7uztee+01zJkzB/b29or3CECzr2sybNgwfPfddwAeH1mNHj0aDx8+RGpqKvz9/TFy5Ejk5eWhuLhY49yEvDcXL15EYWEh6uvrceDAAdjb26scE2MMJSUlGDt2LHJzc7n7bs7nn3+Oo0ePwtPTE4sXL0ZWVpbW2m6MjiSasXLlSixZsgTffvstxGIx1q5dq3SCqoGjoyMWL16MZcuWISAgABEREUhOTsb777+vsf2bN2/Cw8MDRkZGWLZsGQBg9uzZ8PLygkQiwYABA2BmZob8/HxYW1tj/fr1+PLLL2FpaQkA0NfXx5o1axATE4OamhpIJBKsWbOGe35+fn5YtGgRNm7ciC5dusDGxgb5+fno06ePyvL29vbYtWsXnJ2dYWBggLffflvl5WJra2vo6OgoDpXNzMzg6uqKsWPHwsjICKamphg7diyOHDmi8fWioiIEBgZq/LRduHAhPv/8c7i5uYExhpiYGHTv3h0zZszAvHnz0KlTJ/zjH/+AlZUV8vPztfLevPrqq1i0aBGKioowdOhQjBs3TqlucHAwli5dCjc3N8hkMgQFBTX5WsYjNjYWCQkJStsiIyPh7++P8PBw7Ny5E2KxGEuWLBHcNi8RY7QyFdG+GzduIDg4GIcOHVL7VYfXZ599pgig5O9HRxJE67Zs2YJvvvkGq1ateuoAUV1d/VznUbQHLTqSkMlk0NWl+ELI80DQicvr169j3LhxGD58OIqKiuDp6Ynr16+31tgIIe2AoCARGRmJ+fPno2vXrujZsyc++ugjpRwCQkjHIyhIlJeXK6V+uru7K5JTCCEdk6Ag0alTJ9y5c0dxMiotLU0pCYUQ0vEIOnGZlZWFBQsWIC8vD//85z9RXl6O//znP4qUYEJIxyP46oZMJsONGzcgl8thYWGhMrGoLZVM+YirXNf3jLjKsfLapxhNUyJjvveLPWqa1alWPd8urK/ia1P+sGkGpSoizuPQO1nGXOWul3bhaxBAVrn6lPgnZT/kuwQr5XwPb1bzfb0+/WgfV7nqR7e4yrUGxuq4ynFdx/zs
s880vt7SRBepVIrQ0FCUlpbC09MT48ePV3p99uzZCA8Ph5mZWYvaJ4Q8Pa7PAltbW9ja2qKqqgr379/HsGHD8M4776C8vBxPk7CZkpKCd999Fzt37sThw4dRW/v4U1sqlWLWrFm4dOlSi9smhGgH15HE2LFjAQA7d+7E3r17FScuR40a1eTTX4jMzEyMGzcOOjo66NevH3Jzc/Haa69BKpXio48+QmJiYovbJoRoh+BLoKWlpYr/37t3D1VVVS3uvKqqCkZGj88NGBoaKu7VNzY2xttvv93idgkh2iMotzooKAju7u6wsbFBfX09/vjjDyxatKjFnRsZGSnyLGpqamBszHeCq4FUKoVUKm1x/4SQ5gkKEmPGjIGdnR0uXrwIkUiEpUuXonv37i3ufNCgQfjtt99gaWmJ7OxszJ07V1D9+Ph4xMXFKW1LfetfLR4PIaQpQUGi8R9kwzoCwcHBLerc1dUV8+bNQ1JSEjw9PbF9+3Y4Ojri1Vdf5ao/Y8YMTJs2TWmbdEbLxkIIUa3Ft3LW1dXh5MmTT5VIZWBggPXr16t9ffny5Rrr6+vrN8nTKGnxaAghqggKEo2PGD755BP4+/trdUCEkPblqda4rKqqUqwcTQjpmAQdSTg7OytyJBhjKCsrQ0BAQKsMrKUklnzJXSKzbnzlKjjvcpXxpTLDkO+GOFGdgLTsR3zptaIX+K4EiXS1m4reuSvfe9ipshN3m/qcH2+d9JovAwBiGV/6tpEO35+MgW7TlbFVUf+whPZDUJDYvn274t8ikQidO3cWfNmSEPJsEfR1IyYmBr1790bv3r1hamoKY2Nj+Pj4cNWVSqUICgrCxIkTsXfvXo3bf/vtN3h5eWHKlCkoLCwUMkRCiJZxBYng4GC4uLjgf//7n+IBry4uLnBwcOB+joG6+zRUbf/222/x1VdfYe7cuYrnKRBC2gbX143ly5ejrKwMUVFRWLx48V+VdXUVT7Nqjrr7NFRtHzRoECoqKpTStgkhbYMrSNy8eRODBg3C9OnTcfv2baXXbt26xXWfhbr7NFRtb1g/EwB27Nihtk1Vadm0ThYh2sUVJHbt2oXo6GjFA1mfJBKJsG3btmbbUHefhqrtW7duxeHDh1FeXo6lS5ciPj5eZZuq0rIvTxrGMyVCCCeuIBEdHQ1A+eqGUOru01C1vUuXLtDX10fnzp01PlFbVVo2Vs1q8RgJIU1xBQlfX1+NT2LiOZJQd59G4+0GBgaYNm0apkyZAh0dHcyZM0dtm6rSsvkyBgghvLjWuExLS9P4Os/j4f8udUv8uMrpvv1PvgbbKJkKrZBMxar4kqnqH2o3mao8ky/B7WrBi9xtXirlS1a6VsmXJMW7pOgtzvfwf9JDXOVKKzP5Om4FWl3j8skgkJmZibS0NOjo6MDOzk7xdGtCSMckKJlq06ZNmDNnDoqKilBYWIhPPvkE+/bxrQpMCHk2CUrL3r9/P5KTkxVXJmbNmoWJEyfC29u7VQbXEuJ/cOb/d+Ncvl1Pyw9G1hXzlasVsOKW+BFXMdELfDcy6OhyfnZI+b5iGb7Id4eC0V3+r1gviPmS+DpxPthaLOL7SmQk5tt/Yp2OczFe0JFEt27dlJ4mbmhoyH3vhpC07PT0dEyePBleXl44ffq0kCESQrRMUJAwNzfHhAkTsGnTJnz77bfw8/ODRCJBXFxck3yFxoSkZa9duxZxcXH45ptvmiRvEUL+XoKOpc3MzGBmZqbIdLS3t+euy5uWffnyZchkMnzxxRcoKSnB0qVLBU+KEKI9T7UylRC8admMMWRnZ2P16tV48OABVq1ahbVr16psU1VaNt3pQYh2CQoSW7duxVdffYWKigoAjxeeEYlEyM7ObraukLRsCwsLdO/eHd27d0dZWZnaNlWlZWfPcxEyJUJIMwSdk0hISEBSUhKys7ORnZ2NnJwcrgAB/JV+3XCk8Morr6jdLpfLUVJSgtu3b6NbN/UrSM2YMQO///670g8hRLsE
BQlLS0v06NGjRR25urrizJkzGDduHD744ANs374d165da7LdwMAAc+fORWBgIEJDQxEYGKi2TX19fRgbGyv9EEK0S9DXDR8fH7i5uWHw4MEQP3G9mOep4pqWz2+83c7ODnZ2dkKGRghpJYKCRExMDNzc3NC7d+/WGg8hpJ0RFCS6dOnyVFc4CCHPHkFB4s0330RMTAycnJygp/dXim97egK4qEdnvoJd+O4i5E6jfsSZRq3Puca7zlM9EkU1KV/as4jzrlIYcqZ56/OlZRvq8d/obyjmS6M21uUrB/DdLSrWsGTCk3RF2k/LFnGOkYF3znwEBYmsrCwAQE5OjmIb78pUUqkUoaGhKC0thaenJ8aPH69xu0wmg6urK44cOSJkiIQQLWvxczeEaki/njhxIgICAuDh4QEDAwO12/fv34979+61uD9CiHYIChIXLlzA5s2bUV1dDcYY6uvrUVhYiF9//bXZukJWy7awsMC5c+dgZWXV4okRQrRD0BffxYsXY8SIEZDJZJgyZQp69eqF999/n6uukNWyd+zYgUmTJjXbplQqRWVlpdIPIUS7BAUJPT09jBkzBra2tujcuTO+/PJLnDp1iquukLTsrKwsriXx4uPjMWTIEKUfQoh2CQoSBgYGKCsrg7m5OdLT0yEWixV/4M3hTcu+ffs2bty4AV9fX2RnZ2PBggVq26S0bEJan6Ag4efnhzlz5sDZ2RlJSUlwdXWFjY0NV13etGxnZ2ckJydj+/btGDhwIGJiYtS2SWnZhLQ+rtWyn9Rw52d1dTXy8vIwYMAA6LTGNf0WYskRfAX7m/OVq+J8OLy28yRq+JakE1SWM08CpVV85XT4rttLM0q4yuVncOa4ALh4vztXuVvVfOfmK2V8c7n2kG/Jvp9rf+Eqd/dhKlc5QPt5ElpdLftJDc/fMDIywmuvvSa0OiHkGdN+DgEIIe2SlpeCbge6cK6WzbnqMfQ4vx7Uc35r62TI2R7fatCCiDi/Ehlp9zloIl2+w2QhDMV8h/3GnGn1VZzfxPTFfHMRibT/+avtdGtegmZSXV2NFStWYNasWSgvL8fy5cu5r24IWS37+++/x/jx4zFx4kTuRW0IIa1DUJD4/PPP0blzZ9y6dQv6+vqoqqpCRATfiUIhq2Xv3LkTu3btwqpVq7Bx40bhsyKEaI2gIHHt2jXMnDkTYrEYL7zwAqKionD9+nWuupmZmRgyZIhS+rW67Rs2bIBYLIZMJlO625QQ8vcTdE5CR0cHUqlUcYXj/v373HWFpGV369YNUqkUS5cuRVhYmNo2Va2WzXlGghDCSVCQ8Pf3x7Rp03Dv3j1ER0fj559/xuzZs7nqCknLrqurQ0hICNzc3DTe5KVqteycDT5CpkQIaYagIPHhhx9i4MCBSE1NhVwuR3x8PPr3789VtyH92tLSEtnZ2Zg7d67a7cuXL8d7770HT09PjW3OmDED06ZNU96YtkbIlAghzRB0TqKkpAR37tzBlClTUF1djfXr1+PatWtcdXnTsisqKrBv3z78+OOP8PX1RVRUlNo2KS2bkNYn6Eji008/hZ2dHXR0dHDkyBFMnToVS5YswY4dO5qtK2S17PT0dCHDIoS0IkFHEsXFxZg+fTqOHTsGd3d3jBkzRnECkhDSMQleT+LKlSs4duwYHB0dcePGDdTVaTc7jxDSvgj6ujFr1ixERERg/PjxMDc3h5ubm8ZLlG2CM6+C6fJNXaTtuzZb445ZPc7dyJvqzZuyzt0eXzEhdEV8Kcq82fJ19byrZfM1qCfS/qOreVO9GdNuSr+gIOHi4gIXl78eyHvw4MGn6lzoCtqEkL8fV2iaMWMGAMDZ2VkRKJ78aSkhqdqEkLbBdSTRcBnyaZbUV0XICtq0dgUhbYMrSDQ8SdzU1BS7du3C2bNnUVdXh6FDh8LHp+UZjkJStQkhbUPQOYkVK1bg5s2b8PLyAmMMSUlJKCgowMKFC1vUuZBUbVXo3g1CWp+gU+2nT59G
XFwcXFxcMHz4cMTGxuL06dMt7px3Be2G7Y3RkvqEtD5BQUIul0Mm+2sJH5lMBjHv5TIVeFO1DQxUP3yVltQnpPUJ+rrh4eEBHx8fjB49GiKRCAcPHsTo0aNb3LmQVG1V9PX1oa+vr7StbRb4IqTjEhQkAgMDFXeB1tfX45NPPoGjo2MrDY0Q0h5wBYnbt28r/t23b1/07dtX6TVTU1Ptj4wQ0j4wDk5OTszZ2Zk5OTkpfpydnZmVlRUbMGAATxNtpra2lsXGxrLa2tpnutyzMEaa89OXa+u+VeEKEo1VVFSwhQsXMicnJ3bq1KkWdfx3qaioYP369WMVFRXPdLlnYYw056cv19Z9qyL4bqPTp0/D3d0dIpEIP/zwA+zt7VvjAIcQ0k5wn7isrKzEsmXLkJqaiqioKAoOhDwnuI4kTp48CXd3d4jFYhw8eJACBCHPEa4jiY8//hi6uro4deqUUoYl+/8njB87dqzVBvi09PX1ERwc3CSf4lkr9yyMkeb89OXaum9VRIyxZvOPCgsLNb7eu3fvFnVOCGn/uIIEIeT51QprqRFCOhIKEoQQjShIEEI0eq6DRHV1dZNFa4qLizXWycvLw82bN7naz8/P1/h6aWkpMjIy8PDhQ5WvN4ytrKwMly9fRmVlpcpy2dnZXONp7M6dO4qnuzdWW1uLrKwsXLx4Ebm5uajXsDK2VCpFTk4Ofv/9d/z5559N3lOh2vt+Afj2TVvuF23ukw534vLJm9Eae/JGtM2bN+P06dOor6/HoEGDEBYWBl1dXUydOhXbtm1TlPvvf/+LPXv2gDEGKysrFBQUQF9fH+bm5koPSz5w4IBSX4wxbN68GYGBgRgzZoxi+7Rp07BlyxYkJSUhOTkZ1tbWyMzMxMiRI5VWBY+KioKFhQU6deqEAwcO4PXXX0dOTg7efvttBAYGKvX11ltvwdHRESEhIXj55ZfVzj8pKQn79u2Dnp4e3nnnHVy6dAnGxsYwMTFBRESEUrlffvkFr7zyCi5cuICXX34ZlZWV+PDDDzFy5EilNvft24fDhw+jX79+MDQ0RE1NDbKzszFixAhMmjRJUa6j7BeAf9+01X7h3SfcWpTM3Y7Nnj2bOTg4sIiIiCY/T5o4caLi37t372aBgYGstraW+fj4KJWbMGECY+xx/ruTk5Ni+5QpU5TKrVu3jr377rts9erVLDk5mSUlJbGRI0ey5ORkpXK+vr6MMcZ8fHxYXV0dY4wxuVzOJk2apFRu8uTJivJyuVyx3dvbu8mcfXx82NWrV9m8efPY9OnT2Z49e9i1a9dYZWWlyrnU1tYyFxcXtXN5cixyuZwFBQWx+vp6RX11ZZ80btw4pf93lP3CGP++aav9wrtPeAlaT+JZEBsbi+DgYHh5eeGtt95SW05PTw/nz5/HkCFDMGHCBBgaGmLmzJlNDjEbPi08PDywe/duAMCZM2eg1+ghQMHBwZgyZQq+/vprXLp0CQEBAUhOTlb6tAIeHxZv2LABMpkMGRkZePPNN3HmzBnoNnpYUM+ePfHDDz/A1tYWiYmJGDp0KNLS0tCtWzeV83n11Vfx73//G9XV1Thx4gS+++47FBYWIj4+XlHG2NgY8fHxkEqlEIvF+Pnnn2FkZNQkycbIyAiJiYmwsbHB2bNnYWhoiMzMTJX9du3aFbt378aQIUMUa5OmpaWhZ8+eSuU6yn4BhO2bttgvvPuEV4f7ugEAdXV1kEql6NRJ/bK4d+/eRUJCAoKCgtC5c2cAwKVLl7BhwwZs3LhRUa6qqgqHDh3ChAkTFNt27NiBkSNH4sUXX1Tbdnx8PDIyMrB//36l1yorK5Gbm4urV6/CzMwMgwcPRmxsLKZNm6ZYlRx4/N1z7969OH/+PMrLyyGRSGBlZYVJkyZBIpEotbl161Z89NFHzb4vtbW1OHnyJCwsLKCnp4cdO3bAyMgIPj4+6N69u9IY9+7di1u3bqFv377w9vZGeno6Xn75ZfTq1UupTalUih9/
/BGZmZmKVc4HDRoEV1fXJr/k7WW/XL58Gfv27VN6jXe/NLyPPPumrfaLkH3CpUXHH8+ga9euPVflnte+ac5PX66xDnl1Q9WZ3RMnTnTYcm3dtyrHjx/XajnevrXdL2+5tuxbVTlVV4gOHz7M1V5jHe6chKYzux2xXFv3re6qxahRo7RarvFVlb+r38bl2rJv3nLqrhCdO3cOs2bNUtmGJh0uSCQnJ2Pnzp1Ntnt7eytd/uko5dq67+XLlyM9PR3/+te/mpRftmxZhyv3LIzxl19+wa5duwAAe/bswaxZs7Bu3TqwFp5+7HBBgvfMbkcp19Z981616CjlnoUx8l4h4tXhrm7wntntKOXaum+A76pFRyrX3sfIe4WIV4cLEoQQ7eqQVzcIIdpDQYIQohEFCUKIRhQkCCEaUZDoQAoKCmBlZQUPDw+MGTMG7u7umDx5Mq5cuQIAiIiIgKOjIzw8PJR+ysvLldrJyMjAggULNPa1a9cuxbX4zz77TLFY8scff4yioiK19c6dOwdfX18AwIIFC5CRkSF4nunp6Vi5ciUA4NixY1i7dq3gNogALUrmJu1Sfn6+0m3TjDG2c+dONn78eMYYY+Hh4SwxMVHr/To5ObH8/HyusmfPnm1y27dQiYmJLDw8/KnaIPw6XDIVUWZra4sVK1YIqnPu3DnExcVh+/bt8PX1xeDBg3H+/HkUFxcjODgYnp6eWLduHQBAV1cXxcXFCAwMxPbt2+Ht7Y1t27aha9eumD9/PoqKinDv3j0MGzYMMTExSv34+voiODgYf/75JxITEwE8ztG4fv06fvrpJ0ilUkRFRaGmpgYlJSUICAjAqFGjEBsbi+rqasTFxcHU1BRpaWmKbMTo6Gg8evQIJiYmiIyMRJ8+fdTOgfChINGB1dfX48CBA7CxsVFsi42NRUJCguL//fv3bzaI1NTUYPfu3cjKysK0adOU/sCCgoKwb98+bNq0Sem25uPHj2PAgAGIjY2FTCbDhx9+qHZNiqlTp2Lq1KkAgLCwMIwYMQLm5uaIiYnBzJkzYW9vj8LCQowePRqTJ09GSEgI0tLSEBwcjKSkJACPk4zmzJmD1atXw9raGj/++CPmzZunCD6a5kA0oyDRwRQXF8PDwwPA409lS0tLREZGKl4PCQkR/Afi4OAAABg4cCDKysq46owePRoXLlzA1q1bkZubi5KSElRXV2uss2HDBlRVVSE0NBQAEB4ejhMnTiA+Ph45OTka6+fl5UEikcDa2hrA45ueFi9ejIqKihbPgTxGQaKD6dGjB77//nuttmlgYAAAEIlE3HUSEhJw5MgRTJw4EXZ2drh69arGG4yOHj2KQ4cOYc+ePYp+QkND0bVrVzg5OcHV1RUpKSlq68vl8ibjY4xBJpO1eA7kMbq6QZ6aWCyGXC5X2paamooJEybAzc1NsTaFupWdc3JyEBMTg7i4OBgbGyu1ERISguHDh+PUqVMAHgcDsVis+ONvYGFhgbKyMvzxxx8AgJSUFPTq1QsmJiZanOnziY4knjONz0kAQGRkJAYPHtziNl1cXBAYGIhNmzYptvn5+WHRokXYuHEjunTpAhsbG+Tn56NPnz5N6q9YsQIymQxhYWGKYDNv3jzMnj0bXl5ekEgkGDBgAMzMzJCfnw9ra2usX78eX375JSwtLQE8fijumjVrEBMTg5qaGkgkEqxZs6bFcyJ/oRu8CCEa0dcNQohGFCQIIRpRkCCEaERBghCiEQUJQohGFCQIIRpRkCCEaERBghCiEQUJQohGFCQIIRpRkCCEaPR/k7qZ+27/GVoAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 275x150 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(2.75,1.5))\n",
    "ax = sns.heatmap(noise_performance.T, cmap='magma',square=False,cbar=False,cbar_kws={'fraction':0.046},vmin=0.4,vmax=1)\n",
    "ax.invert_yaxis()\n",
    "# cbar = ax.collections[0].colorbar\n",
    "# Label the colorbar\n",
    "# cbar.set_label('Accuracy', fontsize=8)  # Change 'Label' to your desired label text\n",
    "# Change the size of tick labels\n",
    "# cbar.ax.tick_params(labelsize=6)\n",
    "plt.xticks(np.arange(0.5, len(initializations)+0.5),initializations,rotation=90,fontsize=6);\n",
    "plt.yticks(np.arange(0.5,len(noise_levels)+0.5),noise_levels,fontsize=6,rotation=0)\n",
    "plt.xlabel('PE initialization',fontsize=8)\n",
    "plt.ylabel('Noise amplitude',fontsize=8)\n",
    "plt.title('Perturbation analysis, learnable PEs',fontsize=8)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'{outputdir}perturbation_learnablePE_{1}head.pdf',transparent=True,dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAACVCAYAAABhC/5HAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwVklEQVR4nO2deVgT1/rHv+ygoID7guK+UzfArVrBgqIIKopXQQGtRUStK1pxKW7UpValbvjUSt1vESxqf1eoC9aKuKFVURQEAygo+xIISc7vD27mEjIhGSQk4Pk8D8/DZN6ceWcy8z3veecsWoQQAgqFQqlntNXtAIVC+TSh4kOhUNQCFR8KhaIWqPhQKBS1QMWHQqGoBSo+FApFLVDxoVAoaoGKD4VCUQtUfCgUilposOKTnp6O/v37w8XFBS4uLnB2doadnR327dun8Ht2dnYfffyioiIsWrToo8vhyrVr13Ds2DGVHuOrr75CVlbWR5fz/PlzbN++vQ48ql/279+P/fv3QyQSYdGiRSgpKanRvq7uqdpw584deHp6su7r1auX0uVUfZ5cXV0xefJkzJo1C0lJSQCANWvW4IsvvmCeN8lfYWFhrX3XrfU3NYDWrVvjwoULzHZWVhYcHR3h5OSE7t27q/TYBQUFSExMVOkx2Hjy5InKjxEaGlon5WzevBn79++vk7LUgY6ODtzc3BASEoKAgAB1u6Nyqj9Pp0+fxvr163H27FkAwJIlSzB16tQ6O16DFp/qvH//HoQQGBsbAwCOHDmCyMhI6OjoYOTIkVi1ahUAoLy8HN988w1SUlJgYWGBbdu2oXnz5rCzs0NYWBg6duyIO3fuICQkBL/++ivOnTuHEydOQCQSoVu3btixYweCgoKQnZ0NX19fBAYGws/PD/369cPTp09hYGCAH374ARYWFnj27Bm2bduGkpISmJiYYOPGjejWrRtrmaWlpVixYgXy8vIgEomwdOlSjBs3jjm/Fy9e4MyZMwCAtm3bYtKkSdiwYQMSExOhpaWFefPmwdXVVeqa3LlzBwcPHoSJiQmSk5PRpk0b7NmzB6amprhx4wb27NkDkUgECwsLBAUFoWXLlsx1EIlECAgIAJ/Ph46ODgIDAzF48GC551SVuLg4mJubw9zcHABw4sQJ/P777+Dz+dDT08POnTuRmpqK06dP4+jRowCA8PBw3L9/H1u2bMHu3bvx999/QygUYvz48Vi0aBHu3LmDHTt2gBACS0tLBAQEYN26dSgsLMSHDx/g7OyMZcuWoaKiAhs3bsT9+/fRunVraGtrw8/PD7a2tjh27BiioqIgFAphbW2NtWvXQldXF0ePHsXZs2dhbm6OZs2awcrKCgAwevRobNmyBQsXLkSzZs0U3oO5ubnYtGkT0tPTAQD+/v6ws7NDVlYWq6/nz59HREQE8vPzMWrUKOTl5cHExATPnj3D27dvMWPGDPj6+qK4uBjffvstsrKy8P79ewwbNgxbt24FAOTl5WH+/Pl49+4drKyssGnTJujr6zM+8fl8bNmyBYmJiRAKhfD09MT06dMVnouNjQ127Nih0C4tLY31PlEIaaDweDzSr18/MnnyZOLo6EhsbGyIj48P+euvvwghhNy4cYNMmzaNlJaWkoqKCuLr60tOnDhBeDwe6dWrF7l79y4hhJDg4GCydetWQgghY8eOJTwejxBCSFxcHPHw8CCEEGJjY0MKCgoY+6dPnxIej0fGjh3L+NKrVy/y+PFjQgghmzdvJtu3bycCgYC4uLiQ9PR0Qggh9+7dI1OnTpVb5vHjx8m2bdsIIYQ8ffqUBAcHy5z3vn37yL59+wghhOzYsYN89913hBBCcnJyiJ2dHUlMTJSyj4uLIwMHDiQZGRmEEEJ8fX1JWFgY+fDhAxk1ahR58+YNIYSQ0NBQsnjxYqnrsG3bNnL8+HFCCCHXr18noaGhNZ5TVbZs2ULCwsIIIYQUFRURT09PwufzCSGE
7N+/nwQFBRGBQEBGjhxJcnJyCCGEeHt7k7i4OHLu3DnmNxEIBGT+/Pnk2rVrJC4ujgwaNIjk5+cTQgg5evQo+e233wghhBQXF5PBgweTnJwcEhYWRhYvXkxEIhHh8Xhk4MCBJC4ujty6dYssXryYCIVCIhaLSWBgIDlx4gR5/PgxcXBwIEVFRaS0tJRMnjyZucaEELJo0SJy5coVmXOUUPVeWLFiBYmOjmZ+k3HjxpEPHz7I9TU8PJzY2dkRgUBACCEkICCALFy4kIhEIpKVlUU+++wzUlBQQKKioshPP/1ECCGkoqKCODg4kH/++YfExcURKysr8vr1ayIWi8nSpUvJL7/8QgghpGfPnoQQQn744Qdy7NgxQgghJSUlxNXVVeY+qXoOhBAiEonIrl27iI+PD+PXmDFjyOTJk5m/VatWEUII632iDA068pGEiWKxGN9//z1evnyJYcOGAQBu376NSZMmwcjICAAwbdo0REZGYsyYMejcuTOGDh0KAHBxccGaNWtqPI69vT1mzJgBOzs7ODo6om/fvkzNJqFFixYYMGAAAKBPnz64d+8eXr9+jdTUVPj5+TF2ubm5EAgErGXq6OjAx8cHmZmZ+PzzzxXmlOLi4rBlyxYAgLm5Oezt7REfH4/evXtL2fXo0QPt27dnfCsoKMA///yDAQMGwMLCAgDg7u6OI0eOSH1v9OjRWLFiBR4/fozRo0fD09OzxnOqWtumpaXB1tYWAGBsbIxdu3bh4sWLSE1NxV9//YU+ffpAT08Pjo6O+M9//gMHBwekpaXBxsYGJ0+exLNnz3Dnzh0AlTV3UlISPvvsM3Tt2hXNmzcHAPj4+ODvv//G0aNH8fLlSwgEAvD5fNy6dQvTp0+HtrY2OnbsiOHDhwMAbt68iUePHjFNh/Lycujq6qKsrAxffPEFEzE7ODhALBYz59K+fXukpaXV+FtIuHnzJpKSkpjmplAoREpKilxfAaB///7Q09Njyvj888+hra2N1q1bw8zMDEVFRZg0aRIePHiAX375BcnJycjNzUVpaSkAYMiQIbC0tAQAODs74/z585g7d66UT3w+HxEREQCA4uJiJCUlydwn2dnZcHFxAQAIBAL06NEDQUFBzH55zS62+0QZGrT4SNDW1saqVavg6uqK0NBQ+Pr6QiwWQ0tLS8pOKBQy9hIIIdDR0ZHarmoLANu2bcOzZ88QGxuLVatWwd/fH0OGDJEq28DAgPlfS0sLhBCIxWJYWFgw7WhCCLKysqCvr89apouLC/7zn/8gNjYW165dw88//4zLly9L+VuV6udICJHyuybfRCKRzHcrKiqkvjdy5EhcvnwZ169fx+XLlxEREYGAgAC551QVbW1t5oHKzMyEh4cH5s6dizFjxqBVq1ZMvszFxQU7d+6ESCSCk5MTtLS0IBKJsGrVKjg6OgKobFYYGhri8ePHTGUCANu3b0dGRgacnZ3x5Zdf4vbt2zK/p+ScAUAkEmHu3Lnw8fEBUPnSQEtLC+fOnWN+dwDQ1dWFQCBgtnV0dGTKlIdYLMbx48dhZmYGoPKBNjc3l+srAKlzAth/r+PHj+PKlSuYOXMmRowYgZcvXzLfr+l+lvi0c+dO9O/fHwCQk5MDExMTGd+r53yUhe0+UealSIN921UdXV1drF69GocPH0ZWVhaGDRuGqKgo8Pl8CIVChIeHw9raGgCQmprKJG7Dw8MxcuRIAICpqSlevHgBAIiOjgZQWeuOGzcObdu2ha+vL1xcXJCYmAhdXV3WB70qXbt2RUFBAVODR0VFwdfXV26ZR44cQWhoKJycnLBp0ybk5uaiuLhYqkwdHR3muMOGDcO5c+cAVEYfMTExTESniM8++wyPHj0Cj8cDAJw9exY2NjZSNhs3bkR0dDSmTp2KDRs24NmzZ3LPqTqdO3dmosMnT57A0tISc+fOxYABAxATEwORSAQAsLKyQl5eHv79738zta7kvCoqKsDn8+Hl5YVbt27JHOP27dvw
8fHB+PHjkZqaiuzsbIjFYowYMQKXLl1ihDE+Ph5aWloYNmwYIiMjUVxcDJFIhGXLliE8PBzDhw/H1atXUVhYiPLycua3l5CZmYlOnTopdV2HDRuGkydPAqi8zyZNmoSCggK5virL7du34e7uDmdnZwgEAjx//pz5/sOHD5GRkQGxWIzIyEjmfq7uEyEEubm5mDJlCpKTk5U+tiLY7hNlaBSRj4TRo0dj0KBB+PHHH7F9+3YkJibCzc0NQqEQI0aMgKenJ7Kzs2FhYYHDhw8jNTUVPXr0wLJlywAAS5cuxebNm3HgwAGMHj0aQGWt5OvrCw8PDxgaGqJ58+YIDg5Gy5Yt0aFDB8yaNUtuUk5fXx979+7Ftm3bUFZWhiZNmmDXrl1yy9TT08PKlSvh7OwMXV1dLF68WCbJaWtri1WrVsHMzAyLFi3Cpk2bMGnSJIhEInz11VdMolQRLVu2RFBQEPz9/SEUCtGuXTumCSfBx8cHAQEBOHXqFHR0dJhEJts5VcfOzg4nT57E7NmzMXLkSJw+fRp2dnYwMDCAtbU18woXACZOnIgrV64wbyhnzpyJtLQ0uLq6QigUYuLEiRg3bhwjeBK+/vprLF++HE2bNkW7du3Qv39/8Hg8zJgxA8+fP4ezszNatWqFdu3awdDQEDY2Nnjx4gXc3d0hEolgY2OD2bNnQ1dXF97e3nBzc0Pz5s2ZJipQGS09ffoUwcHByMrKwoIFC2qMDgIDA7Fx40Y4OzuDEIKtW7eiRYsWcn1Vlrlz52L9+vU4dOgQmjdvjsGDB4PH46Fz587o3r071q9fj6ysLNja2sLNzU3qu/7+/vjuu+/g7OwMoVCIhQsXok+fPkofW8K+fftw/Phxqc+CgoJY7xOlUCozRKHUgn/961/kw4cP9X7c69evk6tXrxJCCMnPzydjx44lubm5tSorOjqafP/998z2mjVr6sRHCiGNptlF0Ty+/fZbHDx4sN6P27VrVxw5cgQuLi7w8PCAv78/k4Phgkgkwm+//cYk10tLS2Fvb1/X7n6yaBFC53CmUCj1D418KBSKWqDiQ6FQ1AIVHwqFohbqXXxiYmKwdu1amc/37NmD6dOnw9PTE2/evFG6PEIIiouLQVNXFErDol7FZ9euXdi1a5eMUDx58gSJiYn497//jZUrV2Lnzp1Kl1lSUoIhQ4YonPaAQqFoFvXayXDAgAEYNWoUIiMjpT5/8OABRowYAaCy521NPSQFAoFU1/fqPYApFErDoF7Fx9HRUaaXKlApIG3atGG2a2pCHT58GCEhISrxj0Kh1B8aMbzC2NhYqtkkbyAlUNml3tvbm9kuLi7GmDFjanXc1NRU5OfnS31mamrKjBCm1Ez160evHYULGiE+AwcOxIEDBzBnzhwkJCTUOAuhvr6+zAjq2pCTk4PBgwfLDO7T0dFBUlISWrRoIfOdxv6wcRFjtutX07Xjys8//4z9+/dj8eLFzCh0SuNCreITHByMKVOmwMrKCr169cKMGTOgpaX1UfP+KvMA8Xg85OTk4MSJEyguLkZaWhq2bt2KdevWoV+/fuDxeCgtLWXmugG4P2xcHmRVRmDKCiZXMW7RogUePHiA+Ph4LFiwAEeOHIGNjU2dCA9QOY/y69evsX//frWJT2OvbNRNgx9eUVxcjCFDhuD+/fsoLy9Hjx49anyAeDwehgwZCoGgvMZy9fUNcP/+PQCVDyZQObXC06dPpYRKMgK6RYsWsLCwAI/HQ3JyMqZOnSrjh7a2Ns6fP49u3bpxtq0KF0Gpfj0UCWZVMenZs6fCBy4hIQFffPEFrl+/joEDB8q14/ogqzvy4XrtKNzRiGZXXZCeno6ysjKcOHFCSiA6d+4MY2NjJprJyclRKDwAIBCU49mzZ5jjOQflLPaS+XMlGOgbIOpiFCY7T0ZZeRlrmWKxGK6urjA0MMTvUb8rbXv33l0AlQ9Efn6+jFixCVXV6K66YPJ4PPB4PCnBlAhsdfLz85GQkMDY
1gZVN9FUgaojO0ojEp8xY8agvFxaJGQEwsAAUVFR0Nc3UCryAcAqPGyUC8qRmpqKsvIytDNqDwMd9rxUuUiAt/xMxra9YRvoa7PbCsQCZJZl1SiCwP+EykDfAPf+G61ZD7WWEbbq16MmEVywYIGMbVURlCCZl6fq/DzVhao2D7KyzS6uTVYuEZilpSVj27NnT9rkqmMajfgIBAJowwhaYJ/ukkCE8nI+9PX1cf/+PaWaO5KHzNygFXS19NiKBQAISQVyy98z22/5mQr9bd68OXR0dJBZVvP6WJIpMZURwXJBOeNzWXkZrJt1h4lOE1bbIlEp7ha+YkRwdLOeMNVpymqbLypBbGESnj17Bq85c8FnidaqipWRgSHi792tMaqSRFSArFilpqbC1dUVZ8+ehaurKxISElhFgmueqiFGYI2ZRiM+ACAGXyk7CwsLWFhY4MGDBwprTUMDQylhkYehgSEzsbkykU+bNm0QHR2NV69eMfuqJr47d+4MAOjevTvT9aC/cVc01TViLbdEyMeT4hSpz+4WvmK1ZSO2MEmhTUFBAfjlZfiX2SC01jVmtckWFuN03kNGcGysrcEvqzmqMjI0RPzdSrGqLhB79uzBnj17WEVCElHl5+cjKSlJYVRFm1KaRaMSH0WRT3VxUhRGW1hY4O69uwpzLaampswNbGhgqDDyMTQwZGr7wYMHyzQFOnfuDBsbG8Y/Ho8HQwNDGXGRV67kwVcm8pEw2WwQWuqxC8qHimL8nveQ2W6ta4yO+qY1+gJURhr8sjLMNx+IdnLKfltRjKO5CcjJyYGFhQWrQEgS32wiUf03VNQ8aqhNKXUn4FVBoxEffX19lJfXHPkYGBhwruUkURIAmUiJrSkgEavqb8Ykie/27dtLNTPYmgILFiyQqumriqCEqjV9z549AUg3XwwNDBVGPlWjtZZ6xmir31ypa5JdIX9IS037lKW6QNT0Fk1TUHWHVU3oelDXNBrxuXHjBsr+G96zPZiAbG5BgrK1iqWlpUJbCwsLNGnSBHZ2doygSBK9ipoOs2bNQmZmJtq3b4/Lly9L2VUVwaqwPZxVxUpR9wCJoH0QFsk97+r7Tuc/lGPJztHcBE72DY366LC6ePFi5r5rLHASHx6PV+vXraqmY8eOzKJvEpStNbnUKsrYVhWUqihqOqxcuZK5wT62xpQngmwPhKGBIX5XIBBVo6R/mQ5CaznNqOyKYhlxUqbZBUAqOc32Fg2QX4GoE665p9okvn18fBpNxCOBk/j4+fnB1NQUU6dOxfjx42UWO2uocKlVhgwZ8t+OikNqtKuNeCh7g6Wmpko9nPJqTTYRrC6A1Zt0NUWNTLNPei1GaWraVwNZWVmYMH48p+S0JsEl90QT35VwEp+oqCg8ffoUERER+Omnn2Bra4spU6YovVCdpsKlVrl//z6EQiHu37+vYq/YqV5rVs8PVUcZEWRr0smLGo0MDHE6r+Zml9F/E99ApVgoanYZGRoCAPhlZfhXy+5orcdeqWVX8HH6wytGBLn2N9IkNCnxra5kNuecT79+/WBpaYk+ffogJCQEcXFxaNKkCTZs2MCsCNrQ4HLx1d32ViaaqU5d3VwWFhaI55j4jr/LLao6/UFx9wB5URJQrb+RhkZJmgaXtENdChUn8YmNjcWFCxdw584d2NnZ4ccff8Rnn32G169fY+7cuYiNjf0oZ+qC6k0SQHEyj8vF14S2N9dasi7flHBJfMuzZ7OViI8ykU9BQQH4ZWXw7tgNbQ3Zbd+V8XEsPZl5ha8qqnei1IQIjKtAcKlQ6/Je4iQ+oaGhcHNzw9atW2H431AZALp06aL2BxJgb5IAipN56o5mVE1DOL8WLVrAyNBQYeRjZPi/xPex9Lpbb7w28Hg81k6UgHojMK4CwaVCrct7iZP4hISE4Ny5czA0NERmZiZ+/fVX+Pv7o2nTpvDy8vpoZz4Wrm+ZJGhCNKNKapvIBupvGgkLCwvOTbS13fujkxH7kJA3
/BJsf/VEpT5LOlFunWCLrubNWG1Scgux7o87Ko/AqqLKyqYunxVO4rNq1SoMHjwYQOXYJBMTEwQEBGjUtKYNpceqplHbqFFZlHlDJ2mipaamyny/qr1EfDoZNUUPY/aHvj7pat4MfdpwX45ZVTSUypST+Lx79w6+vr4AgKZNm8LPzw+urq6q8ItSz9QmalT2lT+XN3SqFkFlULa/UX3yyQ+v0NXVxYsXL9CrVy8AQHJyMvT05I/2/tRpaDcMl6iRi6BweUNX26ZzTXBJCsvL47D1N/rl+PFa+VMbPvnhFWvXrsX8+fPRqlUrAJXTInBZY+tTozHeMBK4vvLnImx12XTm+lpeksfZ/Pln6NKcvVf264JirL/5CAUFBXXmpyIawksDrnASH2tra1y9ehVJSUnQ0dFB165d62Qy98ZKY7xhqtIQ8muS1/IbB/aDpQl7cjq1qATfJTyVio66NDdG7xbKDbStDxpKHocLnMQnJSUFp06dQmlpKQghEIvF4PF4OHXqlKr8a9A0xhumwaKCISGUj4Pz266xY8fi/v37mDJlCmJiYqReg1Iomsp3D5+q2wVKNTiJT0VFBfz9/VFeXo6+fftixowZmDZtmqp8o1DqjI2D+sHSWE6zq7ikXsRJE3tDqxNO4mNkZASBQABLS0skJiZi6NChzBw6FEp984ZfovQ+S+Om6NVcfX2CsrKyMGHCePD5ChLfRoaIj/80xqNxEh9nZ2csWLAAO3fuhLu7O65fv87MNUyh1BeSoRiKejBXHYqhSl7nFircV1BQAD6/DMEz7NC1lSmrbcr7fKw5d7Vee0OrE07iY2VlBVdXVxgbG+PkyZN48uQJRo4cqSrfKBRWqg/FAOSPrpe3Hlld8u0fd5S27drKFH07tFKhN6pFbaPa16xZg8uXLwMA2rVrh3bt2n3UwSmU2qLs6Pr6EJ9tE2zRRc7Yrte5hZzESdNR26j2rl27Yu/evRg0aJDULIYNdR4fCkUerwvkT4RffV8XFY/t0qQ149U2qr2goAD37t3DvXv3mM+0tLQQFhb20Y5QKJrE+puP1O0CAM1b6FBto9p//fXXOjkohaLpKDO8oj5Q5XzPql7uRxGcxMfT0xNaWrLdQWnkQ9F0Uovkv5Zn26dJwytUMd9zbZb7qWs4iU/Vdp5QKER0dDTMzDRnHhMKpTqS1/LfJdTcidDI0LDe3o5pAlyX+1EFnMTHxsZGanvEiBFwc3PDkiVL6tQpCqWu4PJaXrJW/KcC16Wm6xpO4pOZ+b81yAkhePHiRb1OK6AsDW0eHYpq4TrpvSpJeZ9Xq33Kou48Dhc4iY+Hhwfzv5aWFszMzBAYGFjnTn0sjXkeHUr9wOVVOxfWnLtW6+8qQhPyOFzgJD5Xr14Fn89nxniVlJQonfMRi8VYt24dXr9+DWNjY+zYsQPm5ubM/nnz5qG8vBxaWlro1KkTs755bWjs8+h8CqgrepXkiBS9zart0I3gGWPRtRX7M5PyPu+jxIntzVjPnj0/ahZIVcJJfC5cuICjR48iKioKb9++hZeXF9auXQsHBweF342OjoaBgQHOnDmDy5cv48iRI1izZg2zPycnB5GRkZxPgA06j07DR13Ra21W0eBC11ZmKh1eUf3NWH03K7nASXyOHj3K9PXp3Lkzzp8/Dy8vL6XE58GDBxg1ahQA4PPPP8eRI0eYfRkZGSgsLMS8efMgEAiwatUqWFlZsZYjEAggEAiY7eLi2ofAFM1FndEr14UOVYWyE9k31EGonOfzMTU1ZbbNzMxACFHqu8XFxTA2ruy01bRpU5SU/K9vhVgshpeXF2bPng0ej4evv/4a//d//8fap+jw4cMatVQPRTVwWWtM8roYqHkVDYm9srbqhMfjwcbGWmYKDpmJ7BvwFBycxGfYsGFYvnw5nJ2doaWlhUuXLjHreCnC2NiYEZySkhKYmJgw+9q2bYvp06dDR0cHlpaWMDExQV5enlROSMLXX38Nb29vZru4uBhj
xozhchqURgJbgpXrsjyamozNyckBn1+GHd7O6NquJatNytsPWH0sqsFOwcFJfNatW4dTp07h7Nmz0NXVha2tLWbOnKnUdwcOHIhbt27B3t4esbGxGDRoELPvxo0biIqKwt69e/Hu3Tvw+Xy5iWx9fX06aT0FAPdldriuuKEJdG3XEv06tVW3GyqBk/jw+XwIBAIcOnSIWS5ZIBAotXaXg4MDYmNjMXPmTOjp6WHPnj0IDg7GlClTYGdnx+zT0tLCli1bWJtcFEp1uDaZNLGJVd9oynSunMRn5cqVtV4uWUdHB9u3b5f6rOrbrqCgIC6uUCiUWsDj8WBrY4NSPl9mX9V8UhMjI9yJj1epANHlkimUT4icnByU8vn4IcAX3Tt1YLV59SYDy78/pPJc0kctl/zq1Su6XDLlkyelhjmcq+9LeZ8v37aGfXVN904d0L+HZb0dj41aL5espaWFvLw87NixQ1W+USgajaQ39DoF06QaGRrC0tISRkaGWHPuas22RtKj61PefZBrW3Wfsn2CPoa6nlGxVsslJyQkIDExEREREfDz80N8fHytHaBQGipcR8zHx3MfXb/65yiFfmRlZcFpwgSZPE71PkFNjIxw7JdfOJ2jRNTy8/MxdepUqW4N2traOH/+PLp161ar5hkn8Xnx4gXOnDmD33//HeXl5Vi5ciWmT5/O+aAUSmOBy4j52oyu3+HjjK5t5fTzefcBq3+OQkFBAUr5fOzyn4VuHdqw2iZnZGFlyClOs1DIEzUJYrEYrq6utU5OKyU+kZGROHPmDNLT0+Ho6IjQ0FCsXLkSXl5enA5GoVC40bWt8v18unVog35dO9bZsSWitjcoAN27dGK1efX6DZZu+L5WyWmlxGft2rUYP348tm7dim7dugEA7YdDodQDKW9ryPnUsE8RyW8yld9X06P+ETKglPj88ccfCA8Ph7e3N8zMzDBx4kSIRKLaH5VC+cRRNMasRYsWMDIyxOpjNed8jIxqN7XHsu8PKm27dP33nMtXBqXEx9LSEitWrMCyZctw48YNnD9/Hjk5OZg3bx5mzZoFe3t7lThHoagTLoNQuQxwVWaMWfUEdV1P7bEnYCG6dWrPui/5TaaUOO3dHIDulnKaXalvai1OnBLO2traGDt2LMaOHYvc3FxcuHABe/fupeJDaXRwGYTKdYCrsmPMuE7tkZyRJfd8qu/r1qm90v18ult2woDePZSy5QIn8amKubk5vL29pUaYUyiNBS6DULkOcAVUM8ZsZcipOi9TldRafCiUxg4XgdCEAavfzBiPjq1lp6EBgPTsXPx47v+Y7VdvMuSWU9O+uoSKD4XSwGnRogWaGBlJiQsbTYyMYGlpiSZGRlj+/SGFtpJE9qvXb+Ta1bRPEVR8KBQNR1Hi28LCAnfi45VKTle3lWcvaS42MTLC0g01J5SbGBnVaugGFR8KRYNRNvEtSU6npqbKlMEmVsr2tFZWqFQ+vIJCodQvXBLfbEIFfNy6XVWFqrqw1evAUgqFUv8o+4DX5q2bsqhi/msqPhRKI4JLJMKlY6Qq5r/WIsqufaOhFBcXY8iQIbh//z6zNA+FQqmZnJwc9OjRQ61LK9PIh0L5BFFlE01ZGrz4SAI3unIphcKNli1bomVL2bmCPuZZatq0qdIzXjR48ZEsREgXDqRQ1A+X9EeDz/mIxWJkZ2dLKa5kFdMbN24odSG42KvKVlP8aIg+Uz80x+dPKvLR1tZG27bsM70ZGxtzSkJzsVeVrab40RB9pn7U3lbVZbOh/VHfplAolFpCxYdCoaiFRik++vr68Pf3h76+fp3bq8pWU/xoiD5TPzTTZ0U0+IQzhUJpmDTKyIdCoWg+VHwoFIpaoOJDoVDUAhUfCoWiFqj4UCgUtUDFpwrKvPgrKyuDQCBQukw+n6+UXV5entJlAkBubq5S/nJFIBAo7TOF8jF8kuJz9epVODg4wMnJCZcvX2Y+nzt3roxtRkYGAgICsGXLFty+fRsu
Li6YPHky/vzzTxnbzMxMmT9vb2+8fftWxnb9+vUAgMePH2P8+PFYsGABnJ2d8ejRI1afw8PDERISgsePH8PJyQm+vr6YMGECYmNjZWxtbGwQExOj1LVIS0vDihUrsHHjRjx8+BBTpkyBq6sroqLYl+nl8Xjw9fWFo6MjrKys4OnpidWrVyM7O1up46kaVVQggGoqEVVVIEDdVSI8Hg9LlizBuHHjMG7cODg4OGDhwoWsc0VzhjQCfHx8iKenp9Sfh4cH8fT0ZLV3c3MjeXl5JC8vj/j4+JBz584RQgjx8PCQsfXw8CDx8fEkIiKCDB06lOTm5pLS0lIyc+ZMGVt7e3sydOhQ5vgeHh7MdnUkn82ZM4ekpqYSQgh5+/YtmTVrFqvP06ZNI+Xl5cTT05Okp6cTQgjJy8sjU6dOZbUNCAggCxcuJE+ePGEtr+r5/f333+TKlSvE1taWvH//npSVlRF3d3dWex8fH5KSkkIIIeTRo0dk9+7dJDExkcybN4/V/s2bN2Tx4sXE3t6e2Nvbky+//JL4+vqS169f1+iXIv7880/y5ZdfkgkTJpBLly4xn7Nd6/T0dLJ69WqyefNm8vfffxMHBwfi6OhIYmJiWMvOyMiQ+XN3dyeZmZkytoGBgYSQymvh6OhI3NzcyKRJk0hCQoKM7W+//Ub2799PHj16RCZMmECmT59OHB0dyY0bN1j9sLa2JtHR0Updj9TUVLJ8+XKyYcMG8uDBA+Lk5EQcHBzI77//LmP75s0b8vXXXxMHBwcyYMAA4uHhQVatWkWysrJkbD09PcmjR4+kPnv48KHc+4MLDX5gKQAsXboU69atw08//QQdHR2F9rq6ujA1NQUAhISEwNvbGy1btmQdjVtRUQFra2tYW1vj4cOHMDMzA1A5oLU6ly5dwoEDB5Ceno7Vq1ejTZs28PT0RFhYmIwt+W+Np6enh06dKtfBbtu2LUQikVyfdXR0YGpqilatWgEATExMWG2bNGmC4OBgPHjwACEhIcjOzsbw4cNhYWEBd3d3KVuBQIDhw4dDKBTC2NiYmd9F3sjkwsJCdOnSBQDQr18/7Ny5E8uXL5c7B8y6deuwcuVKWFlZMZ8lJCRgzZo1OHPmjJTtvHnzUFFRIfUZIQRaWloy1/DgwYM4d+4cAGDFihUoKSnB9OnTWSOJNWvWYMmSJcjIyMCSJUtw5coVGBoawsfHh3Wp7zlz5qCgoAB9+vRhyktOTkZAQICMH2lpaQCA3bt34/Dhw+jcuTPevXuHFStW4OTJk1K2p0+fxqlTpzB//nyEhoaiQ4cOyM/Px7x58zB69GgZPzp16oSYmBicP38eixYtQr9+/WRsJAQGBsLPzw/FxcVYuHAhLl68CBMTE8ydOxfOzs5Stps2bUJgYCC6dOmCx48fIyYmBk5OTvj2229x9OhRKduysjKp3w4ABg4cWCcRW6MQHysrK7i7u+PVq1ews7NTaN+nTx+sXr0aGzZsgLGxMX766Sf4+PiwNh2srKzg7++Pffv24bvvvgMA7Nq1ixGMqhgYGGDZsmVISUnBxo0bYWNjI/dH0tfXx5QpU1BYWIiTJ0/Czc0Na9asQa9evVjtp0+fDi8vL3Tp0gWzZ8+GjY0N4uLiMHHiRBlbyTEHDx6MgwcPoqCgAHfv3mUelKoMGDAAc+fOhZ6eHtq3b4/AwECYmJigW7durH4MGDAAS5YswYgRIxAbGwtbW1tERESwTkoFcLt5uVQiqqpAANVVIlwqEEB1lQiXCmT48OHw9fXFiBEjYGJigpKSEty6dQvDhg2T67eyfLLDK/766y/Y2NgwY1T4fD5OnjyJ+fPny9gmJCRIrWcUGxuLESNGQFe3Zu2OjIzEzZs3sXv3brk2mZmZEIvFaNu2La5fvw57e3u5UcebN29w+/Zt5ObmwtTUFIMGDULv3r1l7C5evIhJkybV6FtVXrx4gQ4dOoAQgsjISJiYmMDZ2VmuAFy7dg3Jycno3bs3Ro0aheTk
ZHTs2BEGBgYytnv27MGLFy9kbt6ePXti2bJlMvYnTpxA+/btFVYiQUFBKC4uZiqQnJwcpgK5ffu2lO22bduQmZmJffv2MYKza9cu5OTkYPv27XKPkZKSgh07dsDGxgZXr17FiRMnZGzmzZuH3NxcFBYWwtvbm6lEmjdvzlRWEsLDwxEZGYkuXbogMTFRqgLx8fGRKdvT0xO//vors121Epk3b56U7ZYtW/Dy5Uvo6elBIBCgU6dOMDExQVFREbZs2SJz7T58+MBUIH379kWHDh3w559/IiQkRMaPp0+f4sGDByguLoaxsTEGDhyIAQMGyL1uSvPRDTcNJikpSWX2mmDbUPx48uQJCQsLIwcOHCBhYWHk8ePHnMqWx82bN0l5eTmzXVpaSkJDQ1ltHz58KLV948YNUlFRodRxIiIiyPLly2u0ycjIIDwej1RUVJDo6GgiFotZ7dLS0siZM2fIgQMHyKlTp0hiYqLcMqOiopTyT8Lz589JUVERKSwsJGFhYSQiIoIIhUJW26tXr5LQ0FBy8+ZNQgghr169ImVlZUofi+v9wUajFp/g4GCV2WuCbUP1gxDNEM3GKPT1Zcv192ajUeR88vPzcfjwYcTFxaG0tBTGxsYYNGgQ/Pz8PtpeE2wbqh81cf78eQQEBDQY20/Vj8ePH2Pbtm3Q19fHN998g8GDBwMAa/cRzny0fGkAX331Fbl06RIpKioiYrGYFBUVkYsXL5I5c+Z8tL0m2DZUPygNH3d3d5KSkkJevnxJpk6dSq5du0YIYe/SwJVGEfkUFRXBycmJ2TY2NsbEiRNZE4Rc7TXBtqH6AQBr165l/RyATLJXE2ypH9K2urq6zJux0NBQeHl5wdzcXO73udAoxKdLly4ICgrCyJEjpd6oSC7ax9hrgm1D9QMA7O3t8cMPP2DTpk2s+zXNlvohTcuWLbF3714sWLAA5ubm2LdvH/z8/FBaWqr4BBTQKF61E0IQHR0t8zrQwcGBtS8HF3tNsG2ofkjYvXs3rKys8OWXX7Lu1zRb6sf/EAgEiIiIgLOzM5o0aQIAyMrKwqFDh7Bx40alzkEejUJ8KBRKw+OTHFhKoVDUDxUfCoWiFqj4UCgUtUDFh0KhqAUqPhQKRS1Q8aFQKGqBig+FQlELVHwaEenp6ejfvz9cXFzg6uqKyZMnY9asWUhKSpLaV/XvyJEjUmX8888/WLduXY3HqWpTVFSERYsWKfXdO3fuwNPTU+njsMHleFwQiUTw9/eXOyOjhMzMzBqHJ1CUp1EMr6D8j9atW+PChQvM9unTp7F+/Xrs3r1bZh8bAwYMUDhRVFWbgoICJCYmKv1dLsdho7bHU8SZM2dga2sLY2PjGu3at28PU1NTXL9+HV988UWdHPtThUY+jRwbGxskJSUpbV81Orlz5w68vLywePFiODk5wdvbG/n5+VI2QUFByM7Ohq+vL/O5UChEYGAg3N3dMW7cOPj5+cmspCCxDQsLY6KwCRMmoFevXkhNTZVbBtvxAODIkSNwcnKCs7MzgoODmWlM5Z1DVQghCAsLk5rrODs7G0uWLIGLiwuGDBmCXr16oVevXnj48CFcXFxw+PBhzr8FRRoqPo0YsViMyMhIZg6W7OxsmWbXgwcPaizj0aNHWLt2LS5fvgxDQ0OZJXU2bNiA1q1b49ChQ8xnDx8+hI6ODs6ePYvo6GiUlJSwLvEDVE7WfuHCBVy4cAF9+/bFwoULYWlpKbcMtuPFxsbiypUrCA8PR0REBNLS0qQmp1d0Ds+fP0fTpk2ZOaEBYMmSJRgyZAguXLiAmJgYGBsb448//mCmrk1JSeG81hpFGtrsamRIBAaoHBTYo0cPBAUFgRCiVLOrOj169ED79u0BVE68X1BQoPA71tbWMDExwcmTJ5GcnIyUlBSUlpZKPdzVOXjwIEpKSrB06dIay2Dj9u3bmDRpEoyMjAAA06ZNQ2RkJGbPnq3UOaSmpqJdu3bMdkJCAoqKiph13MzMzNCh
QwcpsWnXrh14PB4zGT2FO1R8GhnyBCY9Pb1W5VWdFF5LS0upJVNiYmKwd+9eeHt7Y9q0acjPz6/xe9HR0bh48SLOnj3LTJ7PpQyxWCwz6b5QKFT6HLS1taGnp8dsJyYmok+fPsx2UVER3r59KzVZv46OjlLLNFHkQ5tdlI9CV1dX6kEHgLi4ODg5OWHq1Klo1qwZ4uPj5a5H9vz5c2zduhUhISFSyV55ZbAdb9iwYYiKigKfz4dQKER4eDisra2VPofOnTtLibOZmRlevHgBgUCAiooKBAUFwd3dHU2bNmVssrKy0LFjR6WPQZGFRj6fEFWbZBL69u1b4/IximjVqhU6dOiAWbNmMU2m6dOnY/ny5bhw4QKaNGmCwYMHIz09nXWtsx07dkAoFGLFihWMQC1fvlxuGWzHGzt2LBITE+Hm5gahUIgRI0YwiWhl6N27N4qKilBYWIhmzZrBzs4ON27cgLOzM5o2bQp7e3v4+voy9klJSejSpQuaN29e6+tGofP5UCgAgJMnT6KiogJeXl4Kbbdu3YpRo0ZhzJgxqnesEUObXRQKAHd3d8TFxSnsZJiRkYHc3FwqPHUAjXwoFIpaoJEPhUJRC1R8KBSKWqDiQ6FQ1AIVHwqFohao+FAoFLVAxYdCoagFKj4UCkUtUPGhUChqgYoPhUJRC1R8KBSKWvh/pXq34p9/Up4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 300x160 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One point per (PE initialization, seed): mean accuracy over the remaining rows\n",
    "# (presumably all noise levels and puzzles -- depends on how ann_accuracy was built)\n",
    "acc_by_init_seed = ann_accuracy.groupby(['PE', 'Seed']).mean(numeric_only=True).reset_index()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(3, 1.6))\n",
    "sns.boxplot(data=acc_by_init_seed, x='PE', hue='PE', y='Accuracy',\n",
    "            fliersize=1, palette='rocket', ax=ax)\n",
    "\n",
    "# Label the categorical x positions with the PE initialization scales\n",
    "ax.set_xticks(np.arange(0, len(initializations)))\n",
    "ax.set_xticklabels(initializations, rotation=90, fontsize=7)\n",
    "ax.tick_params(axis='y', labelsize=7)\n",
    "ax.set_ylabel('Accuracy', fontsize=8)\n",
    "ax.set_xlabel(r'PE initialization ($\\sigma$)', fontsize=8)\n",
    "ax.set_title('Robustness to noise (averaged), learnable PEs', fontsize=8)\n",
    "\n",
    "sns.despine(ax=ax)\n",
    "fig.tight_layout()\n",
    "fig.savefig(f'{outputdir}perturbation_learnablePE_avg_across_noise.pdf', transparent=True, dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
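    "# Perturbation analysis on common PE schemes\n",
    "\n",
    "Gaussian noise with increasing standard deviation is added to the inputs of the 108 fMRI puzzle trials, and accuracy is measured for every positional-encoding (PE) scheme and seed."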
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inference for epoch 4000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tito/mambaforge/envs/lstnn/lib/python3.9/site-packages/torch/utils/checkpoint.py:542: UserWarning: torch.utils.checkpoint.checkpoint_sequential: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/Users/tito/mambaforge/envs/lstnn/lib/python3.9/site-packages/torch/utils/checkpoint.py:90: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inference for epoch 4000\n",
      "inference for epoch 4000\n",
      "inference for epoch 4000\n",
      "inference for epoch 4000\n",
      "inference for epoch 4000\n"
     ]
    }
   ],
   "source": [
    "# Load the 108 fMRI puzzle trials (one batch, fixed order)\n",
    "validation_file = '../data/nn/puzzle_data_original.csv'\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    get_dataset(validation_file), batch_size=108, shuffle=False)\n",
    "\n",
    "# Checkpoint-name fragment -> `positional_encoding` argument of the Transformer\n",
    "positional_encodings = {\n",
    "    'pe-1dpe_': 'absolute',\n",
    "    'pe-2dpe_': 'absolute2d',\n",
    "    'pe-shaw_': 'relative',\n",
    "    'pe-rope2_': 'rope2',\n",
    "    'pe-learn-0.2_': 'learn',\n",
    "    'pe-rndpe_': 'rndpe',\n",
    "    'pe-cnope_': 'cnope',\n",
    "    'pe-nope_': 'nope',\n",
    "}\n",
    "# Checkpoint-name fragment -> label used in the results table / figures\n",
    "pe_labels = {\n",
    "    'pe-1dpe_': '1d-fixed',\n",
    "    'pe-2dpe_': '2d-fixed',\n",
    "    'pe-shaw_': 'relative',\n",
    "    'pe-rope2_': 'rope',\n",
    "    'pe-learn-0.2_': 'learn-0.2',\n",
    "    'pe-rndpe_': 'random',\n",
    "    'pe-cnope_': 'c-nope',\n",
    "    'pe-nope_': 'nope',\n",
    "}\n",
    "\n",
    "resultdir = '../results/'\n",
    "noise_sds = [0.0, 0.02, 0.04, 0.06, 0.08, 0.1]  # SD of the Gaussian noise added to the inputs\n",
    "dropout = 0.0  # only the dropout-free checkpoints are evaluated\n",
    "\n",
    "# Long-format results: one entry per (noise level, model, seed, puzzle)\n",
    "ann_accuracy2 = {col: [] for col in ['Accuracy', 'Noise', 'Epoch', 'Seed', 'Layers', 'Heads',\n",
    "                                     'Dropout', 'Decay', 'Puzzle', 'PE', 'Model']}\n",
    "\n",
    "for noise_sd in noise_sds:\n",
    "    for epoch in [4000]:\n",
    "        print(f'inference for epoch {epoch}, noise sd {noise_sd}')\n",
    "        for pe, petype in positional_encodings.items():\n",
    "            pestr = pe_labels[pe]\n",
    "            # the learnable-PE models use pe_init=0.2 (cf. their name); all others use 1.0\n",
    "            pe_init = 0.2 if pestr == 'learn-0.2' else 1.0\n",
    "            for wdecay in wdecays:\n",
    "                for layer in nblocks:\n",
    "                    for attnhead in [1]:  # single-head models only (iterate `attnheads` to include more)\n",
    "                        modelname = (f'model-{model_label}_'\n",
    "                                     f'{pe}'\n",
    "                                     f'nl-{layer}_'\n",
    "                                     f'do-{dropout}_'\n",
    "                                     f'wd-{wdecay}_'\n",
    "                                     f'at-{attnhead}_'\n",
    "                                     f'hs-{hidden_size}_'\n",
    "                                     f'curr-{curriculum}_'\n",
    "                                     f'lr-{learning_rate}_'\n",
    "                                     f'co-{training_acc_cutoff}_'\n",
    "                                     f'col-{cutoff_length}/')\n",
    "                        model_str = pestr + '-' + str(attnhead) + 'H'\n",
    "                        for seed in seeds:\n",
    "                            checkpoint = f's-{seed}_e-{epoch}'\n",
    "                            torch.manual_seed(seed)\n",
    "                            try:\n",
    "                                model = transformer_main.Transformer(\n",
    "                                    nblocks=layer,\n",
    "                                    nhead=attnhead,\n",
    "                                    dropout=dropout,\n",
    "                                    embedding_dim=hidden_size,\n",
    "                                    positional_encoding=petype,\n",
    "                                    pe_init=pe_init)\n",
    "                                model = model.to(device=torch.device(device))\n",
    "                                model.load_state_dict(torch.load(resultdir + modelname + checkpoint + '.pt',\n",
    "                                                                 map_location=torch.device(device)))\n",
    "                            except FileNotFoundError:\n",
    "                                continue  # no checkpoint was saved for this seed/configuration\n",
    "                            except Exception as exc:\n",
    "                                # stay best-effort, but never drop an existing model silently\n",
    "                                print(f'skipping {modelname}{checkpoint}: {exc!r}')\n",
    "                                continue\n",
    "                            model.eval()  # inference mode (dropout etc. disabled)\n",
    "\n",
    "                            with torch.no_grad():\n",
    "                                for i, batch in enumerate(dataloader):\n",
    "                                    test_features, test_labels, index = batch[0], batch[1], batch[2]\n",
    "\n",
    "                                    # flatten to accommodate the transformer, then move to the device\n",
    "                                    test_features = torch.flatten(test_features, start_dim=1, end_dim=2).to(device)\n",
    "                                    test_labels = test_labels.to(device)\n",
    "\n",
    "                                    # perturb the inputs with zero-mean Gaussian noise\n",
    "                                    test_features = test_features + torch.empty(\n",
    "                                        test_features.size(), device=device).normal_(mean=0, std=noise_sd)\n",
    "\n",
    "                                    out = model(test_features)\n",
    "                                    accuracy = torch.argmax(out, dim=1) == torch.argmax(test_labels, dim=1)\n",
    "                                    accuracy = accuracy.cpu().numpy() * 1.0\n",
    "                                    n = len(accuracy)\n",
    "\n",
    "                                    ann_accuracy2['Accuracy'].extend(accuracy)\n",
    "                                    # position of each trial in the (unshuffled) validation set\n",
    "                                    ann_accuracy2['Puzzle'].extend(i * dataloader.batch_size + np.arange(n))\n",
    "                                    row_meta = {'Model': model_str, 'Noise': noise_sd, 'Seed': seed,\n",
    "                                                'Epoch': epoch, 'Layers': layer, 'Heads': attnhead,\n",
    "                                                'Dropout': dropout, 'Decay': wdecay, 'PE': pestr}\n",
    "                                    for col, value in row_meta.items():\n",
    "                                        ann_accuracy2[col].extend(np.repeat(value, n))\n",
    "ann_accuracy2 = pd.DataFrame(ann_accuracy2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the PEs by mean accuracy (averaged within each seed, then across seeds), best first\n",
    "tmpdf = ann_accuracy2.groupby(['PE', 'Seed']).mean(numeric_only=True).reset_index()\n",
    "tmpdf2 = tmpdf.groupby('PE').mean(numeric_only=True).reset_index()\n",
    "df_validation_order = tmpdf2.sort_values('Accuracy', ascending=False)\n",
    "\n",
    "noise_levels = ann_accuracy2.Noise.unique()\n",
    "# NOTE(review): not used to filter ann_accuracy2 below -- presumably labels for later cells; confirm\n",
    "Layers = 4\n",
    "Heads = 1\n",
    "Decay = 0.0\n",
    "\n",
    "# Mean accuracy for every (PE, noise level): rows follow df_validation_order, columns noise_levels.\n",
    "# Sized by the PEs actually present, so a PE with no loaded checkpoint is not shown as an all-zero row.\n",
    "noise_performance = (ann_accuracy2.groupby(['PE', 'Noise'])['Accuracy'].mean()\n",
    "                     .unstack('Noise')\n",
    "                     .reindex(index=df_validation_order.PE.values, columns=noise_levels)\n",
    "                     .to_numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAIEAAACXCAYAAAAyLzjWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYnklEQVR4nO2df1hN2f7H37tySCkhfadCclVjognFiNEgXDI4Sj+mTI3oiTKE25UpLvnZnVDmmqjMrS5dox+uMc+UMGFm3AaD0U2/3IxqEioqOj/X9w9PW8fZp/apOHVbr+dZz3P23p+91jrnvPf6+VlrM4QQAkqvRkvTGaBoHioCChUBhYqAAioCCqgIKKAioICKgAIqAgoAna6IpLGxEfv27UN+fj60tLTQv39/bNiwARMmTOiK6F871tbWsLGxYY91dHQQFhYGR0dHxMXFIS0tDUOGDFG45/PPP8cf/vCHN53V1wPpJDKZjHh4eJB9+/YRiURCCCHk5s2bZNKkSaSqqqqz0b8RrKysFI7z8vKIk5MTIYSQ2NhYEhsbq4lsvTE6XRJcuXIF1dXVWLNmDbS0XtQu48aNQ3R0NHt8+PBhZGVlQVtbG05OTti4cSN+//13BAUFwcrKCgUFBRgzZgwmTZqEjIwM1NfX4+DBgxg9ejRmzJiBefPm4ccff4RMJkNoaCiSkpLw3//+F3/605/g6uqKx48fY/PmzaisrISOjg7WrVuH999/H3FxcXjw4AF+++03VFZW4v3338eWLVva/U6Ojo54+PAh6urq2rSrra3F+vXrUVdXB5lMhk8//RSzZs3q7E/6xul0m6CwsBA2NjbsH97CtGnTYGJigosXLyInJwfp6enIzMzEvXv3kJaWBgAoLi6Gv78/zpw5g1u3bqGiogL//Oc/MX/+fHz99ddsXIMHD0ZGRgbefvttxMfHIzExEXv37kVCQgIAYPv27Zg4cSJOnz6N2NhYhIeH49GjR2z+EhIS8K9//Qu5ubkoKipq9ztlZGTAwsICRkZGAIC0tDQsXLiQDX5+fgCAb775BlZWVsjKysKePXtw7dq1zv6cGqHTJYGWlhb69u2r8vpPP/0EV1dX6OrqAgCWLFmCrKwsTJ8+HcbGxrC1tQUAvPXWW3jvvfcAAObm5sjPz2fjcHZ2BgCYmZnBxMQEOjo6MDc3x9OnTwG8KI22bdsGABg2bBjs7Oxw8+ZNAMB7770HgUAAgUCAESNG4MmTJ5z5XLhwIQBAIpHA1NQUBw4cYK95enoiJCRE6Z5Jkybhk08+QVVVFaZNm4bVq1e3/4N1QzotgrFjxyI1NRWEEDAMw56PjY2Fra0t5HK5wnkAkEqlAACBQKBwXltbmzONPn36vMywjnKW5XK5wjEhhE2jtUAZhgFRMXN+6tQpzvNtYW1tjezsbFy8eBEXLlxAUlISvv32W6VSsbvT6dxOmDABxsbG2L9/P/vD5+fnIy0tDVZWVpg8eTJOnz6N58+fQyqVIj09HQ4ODp3OeGsmT56MEydOAADu37+Pa9eu4d133+3SNLg4fPgwjhw5gnnz5mHr1q2ora1FY2Pja0+3q+l0ScAwDA4dOoTdu3djwYIF0NHRgYGBAQ4dOgRzc3OYm5ujsLAQbm5ukEqlmDJlCnx9fVFTU9MV+QcAbN68GZGRkezTvH37dpiYmHRZ/GlpacjNzVU4FxwcDDc3N2zYsIH93iEhITAwMOiydN8UDFFVPlJ6DT2r8qK8FqgIKFQEFCoCCqgIejQhISGoqKhgj8ViMYKCguDp6cl2mflARdADEYvFWL16NTsq2sK3336LadOm4dixY/juu+8gEol4xUdF0E0Qi8VobGxUCGKxWKWtn58fpkyZonC+oKAAEyZMgJaWFqysrFBWVsYrbd6DRbLHmbzsmJNn+dn14R4ifpVn+fW87LR1eZnh7A8jeNl9+PMefhHyhBCJ0jmZ/Hv2c3z8rzh48KDC9eDgYM45
C319fTg4OCA9PV3hfFNTE/r37w8A0NXVxbNnz3jlrUucSigdpFVxHRgYCH9/f4XLr86ttEf//v3x/PlzAMDz58+hr6/P6z5aHWiS5mY2CAQC6OvrKwR1RfDOO+/g559/BiEEhYWFGDlyJK/7qAg0CCMSsaEzJCQkoLS0lHW+cXNzw+zZs9uc4m8NrQ40iai5U7fv3r0bABAQEMCe++KLL9SOh4pAk4iVG4uagIpAg3S2GugqqAg0iYh7HOBNQ0WgSZppSUChJQGl54mAw8uXC2b4YH7xiaW8zPqN4jf0Ka3h94OOGVzLy+6NwPM3eN3QkkCTiGgXkdJMRdDrIaKX1QHTht3rhopAkzTTNkGvh1ARUIhIpuksAKAi0ChE3D0Wf1F/Ag0iF8nZwBdVHsXNzc1Yvnw5li5dipMnT6qVDyoCDSJvfhn4osqj+NKlS5g4cSLS0tKoCHoSctHLwBdVHsWWlpaQSCSQSqUK+znwgbYJNIhM9HJ0QCwWK7mYt+yw0hpVHsV9+vTBmTNnkJWVhaVLl6qVj64XwbC3+NnV1vMy07LkV18ytb/zsntrVAMvuzeBVPyyII6Pj+flcq7Kozg1NRWhoaFwcXFBcHAwKisrYWZmxisftCTQIDLpSxHwdTlv8SgePXo0CgsLsW7dOgAvSgU9PT1oaWlBX1+f95oDgLYJNIpErM0Gvi7nr3oUp6SkoLS0FL6+vkhMTIS3tzcGDhyI0aNH884HLQk0iETKbxVWa/r27avSo/jo0aMdygcVgQaRSNQXweuAikCDiGVUBL0ecQeqg9cBFYEGEcmpCHo9Unn36JxREWgQkYyKoNcj7nElAc9Nm4khv21dmabn/NL9v0H87LT5DRv3HclvufabQNTjREDpcnpeSUDpckQyTfoYv4SKQINICBVBr0cspyLo9YioCChi/v6lrxUqAg3SkWUHYrEYn376Kerq6iAUCllXsubmZmzatAnV1dWwsbHh9eq/FrpHH6WXIpYzbOCLKm/jEydOYPr06Th+/DgsLS3R3MzfhZmKQIOI5YQNfFHlbXz16lVUVFTA19cXenp66NevH+84qQg0iET+MvDd4FqVt/HTp08xdOhQHD16FKdOnUJtLf/NOLq+TdCH51asWjyLQL3+Hc8LV7LGel0aX2cQt2oTdNbb2MDAAI6OjtDR0cHYsWNRUVGBQYP4DbnTkkCDtK4OAgMDce3aNYUQGBiodI+q/YtbzgPAnTt3eO9rDFARaBSR7GXorLexl5cXcnNz4ebmBgcHBwwYMIB3PmgXUYOI5eoPFLTlbRwfH9+hfFARaBCxrHssTaci0CAS0j2GDKkINIhYTncq6fWICBVBr0cEuo9hr0fE0F3Oez0ipnOvv+kqul4EOjxX1WjztOPp5czwHYbuRkhBS4Jejwj8N5J4nVARaBAJeK69eM1QEWgQiZyWBL0eiZyWBL0eKgIKpLLuIQLqT6BBZHIxG/iiam/jFuLi4pCRkaFWPqgINIhM3swGvqjyNgaA2tpaTmG0BxWBBpHLn7OBL6q8jYEXb09ftGiR2vmgbQINQlpVA53d2/j3339HU1MT7OzsOpCRDiISiUhsbCwRiUTUrgN2rxIbG0usrKwUQmxsrJLd9u3bSVFRESGEkF27dpHCwkJCCCFbtmwh9+/fJ+np6SQ9PV2ttDssgoaGBmJlZUUaGhqoXQfsXkUkEpGGhgaFwCWkjIwMkpqaSuRyOVm2bBlpbm4mhBAiFAqJj48PmTNnDpkzZw4pLy/nnTatDroJXEU/F/PmzUNoaCgyMjIgFAqRkpICZ2dnpKenAwDbMxgxYgTvtKkIehhteRsDgFAoVDtO2jugdFwEAoEAwcHB7RZh1K77wxBCuofzO0Vj0OqAQkVAoSKggIqAAioC1NfXo722cXl5OfLy8lBVVdWmbWNjI+7cuYPGxsauzuZrhfdg0YwZM8AwDJqbm/Hs2TOYm5ujsrISQ4YMQU5ODmt369Yt7Ny5EwKB
AGvXrsX48eMBAGvXrsX+/fuV4r179y7OnTunMCUaHBysZPfgwQPs3bsXdXV1cHFxgY2NDezt7dX5rgrk5+djx44dEIvF+OMf/wgTExN4eHgo2SUnJ+PSpUt4+PAh3NzcUFJSgr/85S9Kdt999x0SExPR3NyM+fPnQyqVcn6PbolaA9yEkLVr15KamhpCCCG1tbUkJCRE4bqHhwe5e/cuKSkpIUKhkFy4cIEQQoiPjw9nfPPnzydHjx4lGRkZbOBixYoV5Pr168THx4fcu3ePuLu7K9l88MEHZMaMGWTKlCnk3XffJa6ursTe3p64uLgo2Xp7e5OGhgbi4+NDRCIRWbx4MWe63t7eCvlfsmQJp52XlxeRSCSsnar4uiNqDxtXVFTA2NgYAGBkZISqqiqF6zo6OuxWKUeOHIGfnx8GDRoEhuFeHGJmZgY/P7920xWJRLC3twfDMBg+fDj69lXesv78+fMAgHXr1iE8PBzGxsaoq6tTuaefvr4+GIaBQCCAnh73XkYymQwSiQQMw4AQovI9xHK5HDo6Ouz3VGf3ME2jtggcHR2xfPly2Nra4ubNm5g7d67C9SFDhuDAgQNYuXIlBg0ahNjYWKxatUrlGztdXFywfv16hT12uIrRAQMGICMjAyKRCLm5uW1ux9KeUAHAzs4OYWFhqKmpwc6dO2Ftbc0Zl6+vL9zc3PDo0SN4eXlxVhkAMHfuXHz88ceoqKhAcHAwnJ2dVeavu9GhEcOSkhKUlpZixIgRGDNmjMI1sViMzMxMLFiwgHV+ePDgAb788kvOJ9LNzQ2urq4wNDRkzy1evFjJ7unTp4iPj0dJSQksLS1ZkXERHR2NO3fusEKdOnUqAgIClOzy8vLY+GbMmKHy+9bX1+P+/fswNzeHkZGRSrvS0lKUlpbCwsICNjY2Ku26G2qLoKMNtLq6Os4fMCAgAAkJCe3ev2fPHgiFQt6vf21LqADw8OFDJCUloaysDKNGjUJgYCAGDhyoZHf16lXs3bsXtbW1MDIywpYtW2Bra6tkV1ZWhgMHDrDxbdy4EcOGDeOVV42jbiOCTwONiw0bNnCeX7NmDQkKCiJxcXFs4OLs2bMkODiYeHl5kWPHjrXptFFdXU1CQ0OJv78/OXbsGLl+/bqSjY+PD8nMzCRlZWUkPT2dBAQEcMa1ePFi8ttvvxFCCCkvLydLly7ltHN3dyf5+flEJBKRn376SWVDuDui9jgBnwYaF9HR0ZznnZ2d4eLiAnNzc5iZmal83fusWbMQFxeHuLg4XLlyBVOnTlWZVkREBHx8fCCRSODk5IRdu3Yp2WhpaWHRokWwtLSEUChU6KK2xsDAgH2iR4wYobLB179/fzg4OEAgEGDy5MnQ5rvquhugdsOwvQaaWCzGiRMn8O9//xtNTU3Q19fH+PHj4eXlxSkYV1dXnDhxAsXFxbCwsMBHH33Eme7t27dx6tQpXLt2DU5OTsjKylKZRz5CNTU1xVdffQUHBwcUFBSgX79+7GaQDg4OrJ2BgQFWrVoFBwcH3L59GyKRiN15tHUDdvDgwYiKioKjoyMKCgpACGHz2BEP4DeJ2iLYuXMn4uPjYWhoiKtXr2LHjh0K1//85z/j7bffxurVq6Gnp4empiZcvHgR69atw9/+9jel+CIjIzF06FDMnDkTP//8MzZt2oTPP/9cye7LL7+Em5sbNm3aBK129izg25MoKipCUVERgBd/YotrVmsRfPDBBwAAhmHaLH1aejfFxcXo06cPHBwcUFFR0WY+uw3q1h95eXkKx1999ZXC8UcffcR5n6enJ+f5V+vOV+0uXbpECHnhYJmZmakQVPHkyROyd+9esmLFCrJr1y5SW1vLaVdUVETOnDlDbt++rTIusVhMUlNTSWRkJElKSmrTi/jChQskPj6eZGdnq7TpjqhdEqSkpODKlSvw9fVFRESEUh1uaGiIxMREODk5YcCAAWhqasIPP/yAIUOGcMYnFotRX1+PgQMH4smTJ0pj
8w8fPgQAVFZW8s7jjRs3sHHjRvb473//Oz7++GMFm+TkZJw9exZ2dnZITU3F9OnTOfcS5ltSRUdHo7q6GhMnTkR2djby8/Px2Wef8c6zRumIckJDQ8mYMWNIamqq0rXnz5+TpKQkEhwcTPz8/EhISAhJTEwkTU1NnHFdvnyZzJ49myxdupTMnj2bffJfZffu3QrHW7ZsUZm/gIAAsmfPHlJVVUWWL19OIiMjlWw8PDyITCYjhBAilUpVDge3V1K14OXlpXDMt9fUHVC7JIiIiMCTJ0+QkJCAmJgYSKVShaesX79+8Pf3h7+/PwDgH//4h8rGHgA4OTkhOzub7Ye/Orz89ddf49ChQ3j8+DE7UUUIabMPfuTIEaxfvx6zZs1CeHg4Z/oymYxtW2hra6tszbdXUrUgkUggFoshEAg431PQnVFbBKNGjcL27dsBvPiDVXX9WsjOzub8E3x9fVXOJyQnJ7Of3d3d4e7ujsOHD2PlypW88tieUAFg9uzZ8PX1hb29PW7cuAEXFxfOuNasWQMPDw8MHDgQ9fX1iIiI4LRbtmwZlixZAisrKxQXF2PFihW88tod4C2Cy5cvY+rUqTA0NFTonr3zzjtt3qfqydm9ezcAICYmBvPmzcP48eNx+/ZtZGdnc9o7ODggMjISEsmLDSBramqQmJjIactHqOfPn8e2bdtQWlqK+fPnq5w7aF1StfUSCQMDAxgbG+PcuXMYNmwYTp48iQ8//FClfXeCtwh+/fVXTJ06Va0GGgAcPnyY83xLg7KmpgYzZ84EAEybNk2l/Y4dOxAYGIhvvvkGY8eO5Sxy1REqwzA4dOgQRo4cibKyMpw9e5Zz4urkyZNISkpSGEw6d+6ckl10dDS2bdsGExMTzvx3Z3iL4Pvvv0dQUBDu3bvXbhUAAAcPHkRqair09PRACAHDMJw/np6eHuLj42Fra4tffvmFnf17FUNDQ7i4uODChQsICAjgrGLUEaq7u3u7NsCLnsXRo0fb/XNNTU1ZB5qeBm8RmJqawtnZGXV1dbh+/ToAtPnn5ubm4uLFi+0uwoiJiUFmZibOnj0LCwsLlQLT0dHBjRs3IBKJcPnyZTx69EjJRh2hcs1UcmFubs7r6dbV1UVwcLDC7GFP8SziLYJ9+/YBeDFiGB4e3q798OHDecUrk8lgZGQEOzs7MAyD06dPcw6zbt26FXfv3sWqVauwf/9+zh9YXaHyzZ9QKFRoM3DNRfQk/4FX6fAKpIiICLbxxYWnpycqKythYWHxIiGGUWj1t+Dr6wsLCwuFp631H8zlENKCqakp53m+QuVDfn6+0jlHR8cuibu70GER+Pr6IiUlReX1goICpfl5rhnCZcuWcYqjdTotrl2tUSWq1rQnVHXoyri6Gx1emt7e+veoqCgcP3683XhsbGxw/vx5WFtbs+MGrZ/w1kJraGhAZWUlzM3N2fcBtkV5eXm7Nny5d+9el8XV3eAtgleL5VWrVrHnuIplhmEQGhqKkSNHsn8uVz1eWFiIwsJChfu4nvCOuHSrs1GDKsrLy3Hr1i0MHz4c0dHR8PDw4N3e6Snwrg58fX0BAI8fP8azZ89gY2ODkpISGBkZ4eTJk0r2mZmZigkxDK959ZqaGgwdOlTpvLe3N5KTk+Hv74+UlBQIhUKl/fo60n5oD09PT4SFhcHe3h6//PILYmJi2qwGeyK8S4KWLx4UFISYmBjo6upCLBYrvca1BWdnZ/zwww+QSqUghKCmpobTbv/+/UhPT2d37zI3N8fp06eV7Pi4dIeFhQHgL1S+tPhQdmaxS3dG7TbBgwcPWE8dLS0tdqr3VdasWQNra2sUFha26R7+448/4vz584iKisLKlSuxefNmTrs5c+a069KtrlD5MHToUBw8eBD29va4deuWyinxnozaIli8eDEWLVqE0aNHo7i4WOWkjra2Nj777DOEh4djx44d8Pb25rQbMGAA+vTpA5FIBDMzMzQ3c+/umZOTg6ioKBQXF2Pk
yJFtunTzFSof9uzZg+PHjyMnJweWlpacYwQ9Hd4iaL3GMDIyEgKBAKampoiKisKCBQuU7AkhqK+vx7Nnz9Dc3IyGhgbOeCUSCZKTk2FkZITQ0FCV07AMw+CLL75gx/pzc3NVNgwXLVqEhQsXdsmMnq6uLj755JMO398T4C2CnTt3YteuXZDJZAgLC0NISAjGjRuHx48fc9oHBQXhzJkz+PDDDzFz5ky4urpy2kkkEgiFQujq6iIvLw/jxo3jtOM71g+8GHsQCoW4e/cuzMzMMHjwYN739kZ4i0DdNYaTJ0+Gra0tKioqkJOTo7JfzzAMIiMj2a7kf/7zH84nnM9Y/6ZNm1Re+18sxrsK3iJQd40h3369Ok94e8ycORMxMTHYunVrl8XZK+DrhyYSiUhaWpqCr2B1dTXZunUrp72mlmr/9a9/JTk5OW8krf8VeJcEAoFAaUWuiYmJymXfmlqqvX79+jeSzv8Sr227mrlz58LPz69HLtXubXT53sYtS7SAF95Aurq6IIT0OA/c3kSXi6D1dLGZmRktAXoAdFtbCt3CjkJFQAEVAQVUBBRQEVBARUABFQEFVAQUAP8Pn24mfzy6kaAAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 140x161 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Accuracy heatmap for the common PEs under perturbation: x ticks are labelled with PE\n",
    "# names, rows correspond to noise levels (y tick labels are suppressed for this panel).\n",
    "# NOTE(review): the x tick labels come from `df_validation_order`, which the *next* cell\n",
    "# (re)defines; on a fresh kernel this cell relies on an earlier definition. Confirm that\n",
    "# the columns of `noise_performance.T` are in the same order as those labels.\n",
    "plt.figure(figsize=(1.4,1.61))\n",
    "ax = sns.heatmap(noise_performance.T, cmap='magma',square=False,cbar_kws={'fraction':0.046},vmin=0.4, vmax=1)\n",
    "ax.invert_yaxis()  # draw the first row at the bottom instead of the top\n",
    "cbar = ax.collections[0].colorbar\n",
    "# Optional colorbar label (currently disabled):\n",
    "# cbar.set_label('Accuracy', fontsize=6, labelpad=2)\n",
    "cbar.ax.tick_params(labelsize=6)  # shrink the colorbar tick labels\n",
    "plt.xticks(np.arange(0.5, len(pe_labels)+0.5),df_validation_order.PE.values,rotation=90,fontsize=7);\n",
    "# Hide the y (noise-level) ticks entirely for this panel.\n",
    "plt.yticks([])\n",
    "plt.xlabel('',fontsize=8)  # clear any x-axis label set by sns.heatmap\n",
    "# plt.ylabel('Noise amplitude',fontsize=8)\n",
    "plt.title('Common PEs',fontsize=8)\n",
    "plt.tight_layout()\n",
    "# '1head' is hard-coded (presumably the 1-attention-head model, cf. attnheads = [1]);\n",
    "# update it if a different model is analysed.\n",
    "plt.savefig(f'{outputdir}perturbation_commonPEs_1head.pdf',transparent=True,dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKoAAACVCAYAAADIWmH6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvjklEQVR4nO2deTyV6f//X4djVwYRoUwqbdO0E1oZpqmZqSnUtEjrNKUmKpmS6tPCFE3SMgppWkwbqmlliogWWdKKkp0s4eRwOOf9+6Of++tEQkqn7ufj0SPnnGu77/t1v6/tfV0Xh4gILCwfOVJtXQAWlqbACpVFImCFyiIRsEJlkQhYobJIBKxQWSQCVqgsEgErVBaJgBUqi0TACpVFIuB+iEx4PB62b9+OmzdvQkpKCoqKili+fDkGDRr0IbJ/ZwwNDdGzZ0/mM5fLhbOzM4YOHYqdO3ciKCgIHTp0EIvj6emJbt26feiitgnl5eVYtWoVdu3a1eQ4p06dws2bN+Hu7t6k8O9dqCKRCHPnzoWxsTGCg4PB5XKRlJSE+fPnIzg4GNra2u+7CK1CaGgo83dkZCQcHR0RFRUFAJgyZQocHBzaqmhtTmlpKR48ePBe83jvQo2NjUVeXh6WLFkCKalXLY1+/fph69atzGdfX1+EhIRAWloapqamWLFiBXJzc7Fw4UL06NED9+7dQ+/evWFkZIRTp07hxYsX8PHxQffu3TFmzBh89913uH79OoRCIRwdHeHv74+nT59i5cqVGD9+PIqKirB69WpkZ2eDy+Vi2bJlGDFiBHbu3In8/HxkZGQgOzsbI0aMgJub21uvaejQoXj+/DlKSkoaDVdcXAwnJyeUlJRAKBRi6dKlsLCweGP4nTt3Ijs7mynPb7/9hhs3biA+Ph6GhobYsWMHOBwOAgICcObMGdTU1GDIkCFwcXEBl8vF9u3bERMTg7KyMmhoaMDLywsaGhoYNmwYvv/+e9y+fRsCgQDu7u7o27evWN7h4eHw8fFBdXU1NDU1sXXrVqirqyM4OBh+fn7gcDjo06cPXF1doaSkBCMjI/Tr1w95eXlQU1NDQUEBfvnlF+zduxdnz56Fv78/hEIhunXrhvXr10NZWRkhISHYs2cPlJWVoaOjA0VFxbfeawZ6z+zfv58WLFjwxt8jIiJo0qRJVFFRQdXV1fTLL7/QoUOHKDMzk3r06EF3796lmpoaMjc3p23bthER0Y4dO2jTpk1ERDR69Gjy9/cnIiJnZ2eaOnUqVVdXU0xMDP34449ERLR06VLat28fERFlZGSQqakpPX/+nLy9vemnn36iqqoq4vF4ZGZmRg8fPqxXxh49eoh9Pnz4MFlaWhIRkbe3N5mYmNAPP/zA/LOzsyMiosDAQNq8eTMREd27d4/c3d0bvVfe3t40ceJEqq6uptjYWOrZsyelpKRQdXU1WVhY0IMHDyg6OpocHByopqaGRCIRrVmzhg4dOkTp6en066+/klAoJCKiVatWkZ+fH1P+CxcuEBFRQEAALV68uF7eEyZMoAcPHhARka+vL0VGRtKjR4/I3NycioqKiIho/fr1zDX06NGDoqOjiYgoMzOTRo8eTUREqampNHXqVKqsrCQiot27d5O7uzvl5eWRiYkJFRQUUE1NDc2ZM4ecnZ0bvR91ee8WVUpKCnJycm/8PSYmBuPHj4eCggIAYNKkSQgJCcHIkSOhoaHBvPna2toYNmwYAEBXVxc3b95k0hg1ahQAQEdHBx07dgSXy4Wuri7KysoAvLLqGzZsAADo6enh66+/RmJiIgBg2LBhkJWVhaysLLp06YLS0tIGy/njjz8CAKqrq9GpUyfs2LGD+e1NVb+RkRFmz56NnJwcDB8+HIsWLXrr/TIzMwOXy4WOjg40NDSYdq6WlhZKS0tx7do1JCYm4qeffgIAVFVVgcvlYtq0aXB2dsaxY8fw9OlTxMXFQUdHp9496tmzJ8LDw+vla2Fh
gV9++QXm5uYYM2YMTE1NcfjwYYwePRpqamoAABsbG7i4uDBxBgwYUC+dmJgYpKenw8bGBgBQU1MDPT09xMfHY8CAAdDQ0AAAjB8/HrGxsW+9H7W8d6F+9dVXOHToEIgIHA6H+d7b2xt9+/aFSCQS+x54dXEAICsrK/a9tLR0g3nIyMgwf3O59S9JJBKJfSYiJo+6LxGHwwG9wT23bhu1qRgaGuLixYuIjIzElStX4O/vj3PnzjFNnoZ427UIhULY2dlh9uzZAF51ZDgcDpKSkuDk5IQ5c+bg22+/hbS0tNi11F7n6/e6lkWLFmHs2LGIiIjA1q1bkZSUBCUlJbHwde8bAMa4vF6+sWPHwtXVFQBQUVEBgUCAGzduiJWnoWtrjPc+PDVo0CBoaGjgzz//ZC7y5s2bCAoKQo8ePWBsbIwzZ86Az+ejpqYGJ0+exJAhQ1q1DMbGxjh27BgAIDMzE3Fxcejfv3+r5tEQvr6+2LdvH7777jusW7cOxcXF4PF475SmsbExQkJCwOPxIBQKsWzZMpw8eRJxcXEwMjLClClToK+vj4iICAiFwianO378eBAR7O3tMWvWLNy/fx9Dhw5FWFgYiouLAQDHjh1r8NlwuVzm2RoZGeHy5ct4/vw5AGDLli3YvXs3Bg0ahISEBOTm5kIkEuHcuXPNuu73blE5HA727NkDd3d3fP/99+ByuWjfvj327NkDXV1d6Orq4sGDB5g8eTJqampgYmKCGTNmoKCgoNXKsHr1aqxdu5axiv/73//QsWPHVks/KCgIYWFhYt8tXrwYkydPxvLly5nrdnBwQPv27XH06FEUFBRg6dKlzc5rzJgxePToEWxtbSEUCjF06FBMmzYNRUVFWLRoESwtLSEnJ4e+ffsiMzOzyek6OTnht99+g4yMDOTl5bFu3Tr06NEDCxcuhJ2dHaqrq9G7d2+sX7++XlwNDQ3o6Ojg559/xpEjR+Dg4AB7e3sQEQwMDLBq1SooKSnB1dUV9vb2UFRUbPbQHYfeVNexvDeKi4vh7++P5cuXt3VRJAZWqG3ArVu3oKurKzFjyB8DrFBZJAJ2rp9FImCFyiIRtLlQiQg8Hu+N45csLMBHINSXL19i0KBBePnyZVsXheUjptlCDQsLE5tGq2X79u2wtrbGjBkzkJGR0SqFY2GppVkD/tu2bUNYWFi9WZ3k5GQ8ePAAx48fR2JiIrZu3YqdO3c2mIZAIIBAIGA+152pycrKYmZBcnNzG5zFUVZWhra2NtTU1KCrq9uc4rNIMM0S6ldffQUzMzOEhISIfX/nzh2YmJgAAL7++mvcv3//jWn89ddf8PHxqfd9dnY2LL/5BlV1RNwYcrKyiL5+nRXrZ0KzhGplZYUbN27U+57H44lNSTbWMVqwYAHs7e3F4o4cORIAUNOMuenmhP3YefbsWYNeWyoqKujSpUsblOjjo1Xm+pWVlcU6Q415B9W61L2Ojo4O/v33X6SlpYHH42HVqlUNCp7D4cDd3R1ff/11g9ZU0h56UVERhg0bVs/DC3jlLZaUlAR1dfU2KNnHRasItX///ti9ezdmzpyJhISEFq8V6t+/P9P+HTVqVLMF964P/UOKvG57PCAgADweDxkZGfDw8ICzszM6d+4MZWVlZGdng8/nf/ZNnHcSqru7OyZOnIh+/frB0NAQNjY24HA42LJlyzsXrDnCSEhIQFpaGgDAzc0NFRUVyMvLQ2BgIOzs7KClpQVFRUVcvXoVAGBgYMC8ELWCKS0tha2tbYMil5KSwj///AMVFZVW6cRlZWXB1NQUVVVVDf7u4eEh9llOTg7R0dGftVjbfK6fx+Nh0KBBiIuLg7KycrPjZ2VlwcjIqFm+l9LS0kxbe7iZGfiVlU2OqyAvj2tRUe8kmqSkJFhaWkKJowppTuO2Qkg1eEkluHTpEvr169fiPCWdD7Jc+n3DleY2S6hc6VeXXVxcDH5lJeb1/RLaSvW91V8n9yUf+5Kfori4GLq6us0a
TgNQzxpLc7jgcuq31+vxBlMiae3xd0Hihaqrq4vo69H1BHPt2jWEhYXBwsICw4cPb1AwtXH2JT9tdr5ZWVkttsa1CKn6rXHeFOZz64RJvFABMCsFADDV49atW1FYWIjY2Fj8+eefjcZvrkUF/s8aLzQfjE5q7VDC44NfXVMvjoIMF6rKCsgpLsee8NvMywEAL+nFG63l21BXV0dMTAxKS0uRkpKCRYsWYdeuXejevTtUVFQ+KZECn4hQG2LhwoXYs2cPFi5c+MYwampqUJCXb5ZFVZCXh5qaGiO4PeG3W1xGJc4XkObINBpGSNWvBN0Ar1fv3bt3/2TbsZ+sUJuCrq4urkVFiVk5APUsVF3qtjPl5eRQ+Yaee0PIy8mJiVyaI/NObdTPiU9WqHv27EF6ejr27NkDOzu7N4ar22x4ncYslK6uLqKio1s0tPX6i9Ec6nbgaklJSRH7vy6fik/EJyvUplT9r1Pbi379wb+pF11X5LXtxddpzR54VlYWhg83A5/fcAeuoQ0uFBTkce3auw2nfQx89EINDAxkBNeYZXwdOzu7ZoVvqBdd++Cb0ov+EMNBxcXF4PMr4bFkJrrqar01/JOsPDh7H2SG0ySZj16oTa3C35W6vejXaU4vuqUvFkvjfPRCbUkV3lJawyp+iBfL2fvge0n3Y+ajF2pzq/C25kO8WM2t+j8FPnqhShof4sXqqquF3l313mseHxusUNsQIdWfyWpKmCdZeU1Kv6nhJAFWqG2Ampoa5OTk8LKqpEmD+XL/f6IAeDXc1JzqXEFBnokrybBCbQN0dXURHR3d5BmxuoP21661fCZNkmGF2ka8y4xYS+JJOm2+AQULS1NghcoiEbBCZZEIWKGySARsZ0qCaa63lyTDClVCeVdvL0mDFaqE0lreXpICK1QJ5lOr3huD7UyxSASsUFkkAlaoLBIBK1QWiYAVKotEwAqVRSJghcoiEbBCZZEIWKGySASsUFkkgiZPoYpEIqxevRpPnz6FsrIy/vjjD7FFY3PmzEFVVRU4HA46d+6MTZs2vZcCs3yeNFmoly9fhpycHIKCgnDu3Dn4+vpi1apVzO9FRUX1DkpriMZO7vtc+Zzc9VpKk4V6584dmJmZAQCGDx8OX19f5rfs7GyUlZVhzpw5EAgEWLFixRsXmb3p5L7Plc/NXa+lNFmoPB6PObVESUlJ7AA0kUiEWbNmYdq0acjMzMSCBQtw4cIFcDiceuk0dnLf58jn5q7XUprcmap7Ot/Lly/Rrl075jctLS1YW1tDWloa+vr6aNeuHUpKShpMR1ZWFsrKymL/Pne6dOmCfv36IT4+HvPnz0d8fDz69evHVvt1aLJQ+/fvj+joaABAZGQkBgwYwPwWERHBtFfz8vLA5/OhqqraykX99Km7EyCLOE0WqqWlJfh8PqZMmYKgoCAsXLgQ7u7uePToEcaMGQMVFRVMmTIFy5Ytw8aNGxus9lkaZ+HChdDX1/8gW2xKGhJ/ch/L5wE74M8iEbBCZZEIWKGySASsUFkkAlaonzmBgYEwNjZGYGDgB43bXNh1/Z8ptf4FO3bsQE5ODnbs2IEBAwY0yb/gXeK2FHZ46jOkqKgIX331VYuOUH+XuO8CK9TPiISEBKSlpQEACgsLUVFRgby8PAQGBsLOzg5aWlpQVFREhw4dAAAGBgbo378/APEzWHNzc8Hj8ZCRkQEPDw84Ozujc+fOUFZWhra2NoDW35KdFepnQlZWFoyMjCAUCpscR1paGjdu3AAAmJmYolLQjJO0ZeWQnvGs2eV8E2xn6jOCK928Lklt+OLi4maJFECzw7+1LK2aGstHi66uLqKvt+zYduCVhWyuRW1NWKF+RjR0bPvZs2dx4sQJTJ48GePHj2+w566rq4uoFoi8NWHbqJ85xsbGSE9Ph76+PmJjY5sUp3Z46nXe5/AUa1E/c1pyyHBbOHS3uUUtLy/H4MGDERERwVrUTxAlJaVW8U1uc4tau7zlc1439SnT
Wk26NreoIpEIBQUFb3zzahf/NdfitjSeJMb9mMv7yVhUKSkpaGlpvTVcSxcCvssCQkmLK2nlbQ7sgD+LRMAKlUUi+OiFKisri8WLF0NWVvaDxJPEuJJW3pbQ5p0pFpam8NFbVBYWgBUqi4TACpVFImCFyiIRSJxQG3IvkyTYvmvLaDOhEhHKysrw4sWLJsepqqrCyZMnkZSU1Cr5t0X81to87kO/sM1ZwvI+aBOhCgQCnD9/Hvb29nB0dER4ePhb41RVVeHMmTMICgpCYWFhi/KtrKzEs2ev1vFwOJxmi62l8R8/fozt27dj1apVWLlyJe7fv9+8gr+GQCDA6dOnce/evUbDVVdXo6SkBPn5+e+c39mzZ/Hw4cN3SuedoDZAKBTSw4cPKSEhgVJSUuj777+nxMTEN4bn8/l0+vRpmj59OsXExIil01QEAgEFBATQ2LFj6dixY8z3IpGoyfEPHDhA3333XbPjV1VVUVpaGg0dOpR+/vlnys/Pb3K5G0rr0KFD9PXXX5OVlZXY/ahLWloabdmyhaZNm0Y///wzXb58+Z3yGzBgAM2dO5du3LjR4rK/C23ilCIlJQUDAwNwua+y79OnD7Kyshrc97+yshKXLl1CcHAwqqurIRQKERISggkTJkBKSgpE9NbqVCAQIDQ0FIcOHYKNjQ1Onz6N9u3bw8rKirGMjaUhEAgQEhKCgwcPwtbWFufOnUOHDh0wevToRuOLRCJISUlBVlYWe/fuRe/evfHHH39AQ0MDAJpU9obKERoain379oHP5+PQoUPo0KEDunXrxoR7/vw5/Pz80LNnT/z444+QkZHBsmXL0K1bN+jr6zcrv9DQUJw9exarV69G//794eDggHXr1mHo0KFNTqc1+GBCFQgEeP78OYBX4pOVlUWHDh1w/vx5AGDWkguFQkhLSwMA+Hw+rly5gqCgICQkJEBfXx/t27fH3r178eLFC8yaNeutD7qqqgqnT59GSEgIXFxcYG5ujsLCQhw5cgQCgQDff/99o2KrbXKEhobCy8sLhoaGKCkpwaVLl1BZWYmxY8c2GL9WpADg5OSEFy9eMCKtqqqCnJxcs0RaVVWFEydO4OLFi3Bzc4NIJMKpU6cgEAjg4eGBn3/+GaNHjwYA3Lx5E3w+H+bm5ujUqRMAMNvVN1Wotf2BS5cuYe7cuTA3NwcA6OvrIyUl5dMUqkAggJubG4KDgzFu3DgkJSVBVlYWKioqUFZWhqGhIbMYTFpammn7Xbx4EUeOHEFKSgosLS0hEolw48YN+Pn54bfffsOECRPwxRdfvDFfIsL58+dx8OBBeHh4oHPnziguLsa4ceOYk13y8vIwb968BkVTN/7GjRshLS2NH3/8Eerq6li6dCl27tyJ/Pz8Bl+YWpE6OzujrKwMmzdvRl5eHg4cOICSkhL07dsXP//8M3Jzc1FdXY3OnTs3eg9LSkpw9uxZLFq0CO3atcOiRYswb9489OvXDy9fvoSTkxM6d+4MAwMDZGdno0ePHujUqRPzwsyaNQs9evRgruttL0lJSQnOnDmDWbNmwdzcHEKhEKdOnYK2tjZEIhESEhKgoKAAQ0PDRtNpLT6IUGVlZTFv3jwkJCRg1KhR8PT0hEAgQEREBJ4+fYozZ87g4cOHMDU1xaxZs/Dvv/+iZ8+eGDx4MFRVVREdHY3Zs2dDS0sLc+fORUpKCoqKiurd7NcfAIfDwZAhQ6ClpQUVFRXY2NigU6dOGDJkCCwsLPDXX39h7ty5mDhxImPR61IbX1VVFf369cOiRYtgb2+Pa9euITU1Ff7+/nB1dQWfz4eCggLi4uKgqqqKrl27AgDWrl2Lx48fw9fXF4mJiYiOjgaPx8PEiRPh5eUFIkJxcTGKioqwZMmSRlduamlpYefOnejQoQMOHjwIU1NTpjYAXu1qUntm18CBA+Hi4oI+ffogOTkZioqKsLKygpKSEiorKyEvL//WZ6alpQUfHx+oq6tDKBTizJkziIqKgpaW
FgICAjBixAhcv34da9aswYgRI94ugnfkg/X6u3btCh8fHwQEBCAkJASysrJITk5GeXk5nJyc8OeffyI8PBze3t6Ijo5GREQEtLS0YGJigrS0NISFhQEAdu3aBR6PB1NTU0hJSeH27duIiYkB8EpYrw/b6OjowNjYGPfu3cPQoUPh4eGBfv364cSJE7C2toaUlBTU1dXFRhKoTm9eR0cHI0eOREVFBWpqamBmZgZvb2/8999/mDBhAmpqaiAnJ4fExEScPn0aBw8eZA58W7BgAfbv3w8NDQ2EhYVBU1MTrq6uMDExwezZs3Hp0iWUlpbCzMysScuLa8MUFxejXbt2jEjDwsJQXl6Ojh07IiEhAR07dsSaNWsQGRmJsrIy5Ofnw8HBAVlZWYxI58+fj9u3bzeaX+2BISdPnkRycjKeP3+O9PR0bNiwAevWrcPmzZuxa9cu5OXlvffx4Q/amTIwMIC7uzt27NgBoVCIq1evYu3atRg0aBCAVzfm4cOHGDp0KCZMmMB0tlatWoUlS5bg3r17ePnyJXbt2gWBQIAHDx7g1q1biIiIwJMnTzBt2jSmyn2dHj16YP369TA2Nsbw4cNx7NgxqKmpYevWrXjy5AkOHToES0tLDBs2rME2Z3V1NQQCAZKSkqCrqwtPT0/s378fffr0QVJSEs6cOQMZGRnMmzcPsrKyICLo6OgAeHVSTHp6OtasWQNlZWXk5OTg6tWraNeuHUaOHInhw4cDeHuVXHttw4YNw/r162FkZITY2FiUlZVh1KhRyMrKgr+/P4RCIebNm4eVK1dCRkYGAHDgwAFcvXoV06dPx6RJk5CZmYnw8HAIhUIYGRk1mt+AAQNQVlaGO3fuwNraGmZmZhCJROjatSt0dHSgqqr63g8XaRM3v+LiYsTGxuLatWvYsmULACA8PBzHjx+HqakpfvzxR7Rv317swWVmZiInJwdEBEVFRQQGBsLKygr6+vro1KkTpkyZghUrVjS6SDA1NRVeXl7MjnO7du1CcXExDh8+jDt37qBbt24wMzPDtGnTGoz/6NEjLFmyBKqqqujbty9WrlyJmJgY3LhxA9XV1Vi8eDFUVFSY8LXlLysrw6xZs7BkyRJ07doVly5dQlFREUaMGIFhw4YBeNWJ5HA4b3zRXicsLAznz59HRUUFfvjhB1RXVyMrKwsxMTH49ttvERQUBGdnZ5iamkIoFDIvva2tLZSVlbF161ZkZ2czy6W/+uqrRvOrqqqCu7s7xo0bh8GDBwMAfH19ceXKFRw9ehTe3t4wMTFhfmtt2mR4Sk1NDQMGDMC2bdsQGhqKZ8+egc/nY/DgwbCxsYGcnBxqamrA5XIhEAggKysLHR0d6OnpgcfjYebMmfjmm2/w4MED3L9/H19//TW6du0qdpogUN9CdevWDS4uLti7dy9sbW1RXl6Os2fPQklJCYcPH4ampibGjh2Lzp07M1auLoaGhjh06BAKCwuhqKiIe/fu4dq1axCJRFiyZAkuXryIjIwMfPnll5g0aRJjmdu3b4/NmzfD09MTeXl56NWrF4YNG4bevXsDeGWtay0f8OosVGlpaaat+zpEBAsLC4waNQpcLhcPHjzA0aNHoaqqCk9PT2hqakJKSgoXLlyAkZERI9IxY8bAwMAA+/btAwAoKChg8ODBqKp6+1Y9cnJyMDU1xfr16zF27FhkZWUhIiICp0+fRlBQEPbv3w8ejwcul8vsANiatNkUqra2Nvbt24fIyEiEh4dDW1sb3bt3Z7Y2rLUste09f39/xMbGMuOShoaGWLp0Kezt7bF69Wo8fvwYFhYWuHr1Km7evAmg4dkjPT09/P7776ioqMDRo0chEomwYMECaGpqorCwEAoKCvXKWrfdq6GhgV69euHly5c4deoUOBwOfvvtN3zxxRdQU1ODlZUVQkJCmF2Ya1+Unj17YtOmTRg+fDi+//57mJiYwMnJCbGxsZCRkQERISgoCE5OTpg+fTrs7OyQkJDQ4L2r
TZPL5SIhIQHHjx/HjRs3MGPGDGhqaqK4uBh3796Fnp4eZGRkwOFwcPv2bXTp0oURaXFxMUJDQ1FQUAA9PT0AjU8L174cv//+O4RCIbp06YLz58/j33//RVBQEDZs2AAzMzMsW7YMt27dauTJt5APMq3QCAUFBXTo0CGqrKykx48fk7m5Od2+fZv5fcGCBdS/f39auHAhWVtb07Nnz+ju3bs0adIk8vb2JltbW7K0tKSqqipKTEyk2bNn0+TJk+nQoUNMGg3NHhUWFpK7uzvl5OQQEVF6ejq5ubmRh4cHPX/+nCorK+n58+eNlnvbtm2UkpJCa9asoW+++Yb8/Pzo2bNnlJKSQhMmTKDMzEwiIiovL6fCwkIiIuLxeEwaFy5cIFtbW0pLSyMior/++osMDQ3Jy8uLKisr6Z9//nnjzBMRUWJiIq1Zs4ZWrVpFU6dOpYyMDCoqKqLg4GDasGEDXb58mXJycig2NpYePXokdu1Hjx6ladOmUXJyMkVFRZGTkxM5OzvTtWvXGn1edTlw4ACZm5tTQUEB852XlxedPXu2yWk0lY9iKUrdwfErV64gODgY3t7eAF6dZF1SUoKjR4+ic+fO8PT0hI2NDdTU1ODo6IjS0lKcPn0aSUlJCA4OhpSUFCwsLBAZGYnx48c3Wg3VTi48ffoUBw4cgIaGBvT09CAUCnHy5EnIyspi1KhRsLOzQ0REBHR1dWFgYCCWRnx8PDZv3ow9e/YgISEBMTExqKysRGlpKZYtW4aioiLExMSgffv2sLa2ZpYV1+adnJyMXr164dSpU1i3bh08PDxQVlYGKysr8Pl8PHz4EBYWFg2Wv7CwEH5+fnB0dMTt27fh7u4OBQUF9OzZE0pKSlBXV0doaCi++eYbXLx4EY6OjjAzM8Px48dx4sQJbNmyBU+fPoWbmxtcXFwgIyODvXv3wsXFBSYmJo0+sx07duD48eOwsrKCq6srgFe137fffgsXFxd88803TXr2TabVpd9CRCIRY/kEAgFVVlbSyJEjyc3NjUQiEcXHx1NCQoLYPHlOTg6lpqbS1atXydXVlfbu3UvZ2dlERPTo0SOytbWlmzdvMuFf9w0QiUQkEAho06ZN5ObmRjdv3iRfX19yd3enf/75h4qKimjatGm0e/duWr16Ne3cuZOqqqrE0sjMzCQLCwuKioqi7OxsCg0NpR9++IHOnTtHRUVFZGlpSRMnTqQHDx4QEVFGRka96z116hT16tWLrl+/TkRE4eHhNGTIEMYK14ZviOrqaubvjIwMSktLIz6fT+fPn6e1a9fSmDFj6Pr165SXl0dTp06lzMxMio+Pp8jISMrJySFjY2MxP4vg4GD666+/Gn1WO3fuJHNzc7p06RLNnTuXoqOj6fHjx/Tdd9+Rh4cHlZaW0t9//00BAQGN1gjN4aMRKtH/PYzy8nKysLCg5cuXE9Ergc2ePZtmz57NhK0VXXl5Obm7u9P27dspLy+PiF49sEWLFpG1tTVt2LBBzImkIQoLC+nRo0eUm5tL1tbWFBYWxvzm6upKc+fOpd27dzNV+evcu3ePFi1aRPb29jR58mQ6deoU8fl82rZtGzk6OpKjoyOlp6cTn8+nHTt2iDUpTp06RYaGhswLdfXqVbK0tKSgoCDKycmhEydONPm+1VJTU0OzZs2io0ePUmFhIU2ePJm8vLzojz/+EGt63L17l1xdXcXiLly4kLy8vJj72BAZGRmUmppKRESXL1+madOm0S+//EK7d++m9PR0srS0pOnTp9PVq1dp0qRJYvezpXx0QhWJRDRu3DhavHgxERFVVlbSggULaP78+UREtHXrVjp//rxYvOLiYsbSZmZm0rx588jT05Pi4uIoNzeXRo8eTeHh4W/NPy4ujn755Rfm8/Xr12nBggXk4+NDRUVF9cpal8LCQvruu+/o4MGDxOfzadOmTeTq6koZGRlUUVFBGzdupJSUFKqsrBSzyqmpqRQXF0dERBEREWRubk5Hjhwh
IqLo6GhycHBoNN+G4PP5tHz5csbTKSsri6ZPn05z584Vs9L379+ncePGUUJCAhEReXh4kIODAxUVFdHff/9NDg4OTPu5IWrLwuPxqKqqikQiEdnY2JC7uzvNmjWLLl++TLm5ubRq1ap6NVFzafMtfepS25v18/NDx44dUVVVhWXLlqGmpgb79+8H8KpdVl5eLhbviy++AIfDQU5ODlasWIGhQ4fCxsYG2trakJKSwsCBA1FSUvLW/Dt16oS0tDRcuHABhYWFyMvLQ9++fWFhYYHw8HC8fPkS6urqDTqyqKur4++//4aamhr8/PzA4XAwdepU6Orq4sCBAygrK0N8fDxKSkowZMgQJs/aNm90dDQcHR3h7OwMa2trAK82GGvXrh1evnyJ/Px8dO3atUkD6/Ly8rCyssL69esxYsQI9O3bF56enuDz+WInlvTq1QuOjo5YvXo1VFVVIRKJ4OzsjCNHjiA+Ph7S0tJwc3ODnZ1dg+3k2rIoKSkBAEpLS6Guro4VK1YgOzsbK1asgIqKCrhc7jtPCHyUS1E0NTUBACtXrgSfz2dEmp+fj+fPn0NTUxMXL14Umzqt/d/U1BSTJ0+Gjo4OpKSk4OPjg0uXLsHY2Pit+WppaWH37t24cOECjhw5AmVlZfTp0wfh4eH4559/oKqqCj8/PwQFBYnlW4uamhoEAgGuX7+ODh06oFevXti/fz/u37+P8ePHw9DQEMHBwcjNza2Xt66uLjw8PPDtt98iKysLR48exX///YeCggJMnToV69evx507d5p0/+j/DyW5urpCUVERT58+BZfLRZcuXcSGoIgIY8aMwZ49e7B9+3YcPnwYiYmJePHiBebNmwdfX1/8/vvv2L9/f4Mb975ORUUFCgoKEBcXBz09PXh6eoKIMGHCBLFx4pbwUfT630R+fj6UlJRw6dIl5OTk4NGjR7h8+TLGjRuHzMxM9OrVC0uXLmXmwIlIbBbGx8cHgYGBOHnyJDp37sxMHtQdZWiI58+fIzQ0FFOmTIG/vz+ys7OxcuVKqKurIzk5GV5eXti2bRtUVFQYb6+6ok1PT8ejR4+QnJyMwsJCWFlZYdSoUdiyZQvy8/Ph5eXVYP4VFRWYOHEiNDU1oampiZs3b2L9+vUwMDBAly5dIBKJIBKJmOt7Pd+WUNet0t/fH+fOncP48eNhY2MDRUVFJCQkYM2aNQgMDGzS+VFXrlzBpk2b0LNnT3Tp0gW//vorOBwOFBUV36mcH1XV/zodO3Zk5qS7dOmCkSNHQigUwtPTkwlTVlaGzMxM6OnpgcPhMDfd19cXPj4+uHz5MjIzM1FQUIDBgwczYq2oqEBkZCQqKirQuXNnsak/DQ0N2NvbQ0pKihFL7UOSkZGBsrIyYz2lpaXriUVfXx/t2rVDfHw8hg8fDhMTE8TFxeHp06dwd3eHlJQUrl27Bnl5ebFmgKKiIk6cOIF27dqBz+djxowZGDRoEIRCISIjIxEWFgYZGRno6+tjxowZrTK/Xnu/vLy8UFBQADU1NXz55ZdQVFREamoqjhw5Amtra6ipqYHP56OsrAwdO3ZsMC0iwujRo6GtrY2ioiKoqKgwzYJ35aMWKvBqJmnXrl0AgKNHjzLW5Nq1a7h48SKKiopARBgyZAjmzJnDPDxzc3NYWVlBT08Pjx8/xtatW7FixQqYmJigtLQU9vb2MDAwgIaGBvz8/ODs7CzmriYtLY3KykrIycnh8ePHMDAwgEAgwNq1axnBT5kyBb/++ivTfqtrqdXV1eHg4MDMpBUVFWHgwIFITk7G3bt3ceLECUydOhV9+/ZlZsOIiBlnjY+Ph4qKClRUVHD48GFcuHABpaWlOH78OObMmQMFBQVMnjy5Ve5xSUkJXr58ibFjx4KIsHHjRiQmJuLMmTOYMWMGZGRksHHjRsjKyiItLQ1TpkzBmDFj6qVTdxaultaw+oAECLUWkUiE/Px89O/fHzweD1euXEFSUhIGDx6MtWvXYsaMGejS
pQsjmtpOSnV1NczNzRnPp549e2LChAkYM2YM1q1bBwD48ssvERUVVc+vUl5eHitWrICLiwvCw8Px5MkTGBoawsnJCQsXLgSfz8edO3dQUlLCuAzWpdaaPHnyBJ6ennjx4gXGjRsHdXV1BAcHIzU1FampqSguLsbIkSPFpnz19fWhr6/PODBbW1vj4sWLCA4Oxq+//oqIiAgArSMEVVVVLFu2DFwuF/Ly8lBWVkZeXh5cXV3Rrl07/PXXX8jPz4exsTHc3d1hZ2eHjh07ok+fPm9Nu9W8qt5pzOAD8/TpU0pMTKTz58+TtbU1ZWdn07x588jb25v8/PzqLWCrHT6p/b+0tJRsbW3Jzc1NLNy+ffto586db8w3MzOToqKi6MGDB1RdXU2zZs2iVatWUW5uLuXl5ZGZmRndunWLCf/6xEJubi45OjrSzp07qaamhoiIAgICyNTUlIKDg2nmzJkUEBDQYN61Q0u1zJ8/n3766SemvE1dnNhUXi97aGgorV27lhITE8na2poiIiLI09OzxYsFW8pH2et/E/r6+szSi4EDB6JTp07w9vZGamoqTp48CQUFBYhEIlRUVAAQHw0AgKSkJOjo6GDt2rVMmmFhYdi1a1ejU4a6urowNTWFrq4uJk6cCG1tbWzZsgVaWlp49uwZNDQ0kJ2djaioKACoZ1m1tLSwceNGLF68GNLS0ggICMDBgwfRu3dvZGVlITAwEFeuXGHOKa2LvLw8Xrx4gfPnz6O4uBjbt2+HkZERunXrhvDwcOzbtw+nT58G8GoKs6io6B3u8P+Vnf6/Zc/Pz4esrCz69esHT09PeHt748qVK4wjy4dCooRay6BBg/Dff//Bz88PUVFR2LJlC2xtbSEnJ4d58+bBxcUFBw4cAPBqTX3tWnwejwcFBQXmYYSFhWHLli3YvHkzBg4cCJFIBHo1CdJgviUlJTAyMsLmzZsBvPKhvXTpEuNAvH79ehw/frzBuHW9sgoLC+Hk5ARfX1/ExcXhzz//hJycXIObPKirq2Pbtm0IDAzEkiVLcOzYMfzwww/IzMzEjh07IC8vD19fXxw7dgwxMTHYtm0bMjMzW3xva6l9uQcMGIDIyEjcunULubm52L17N6ZPn96k7exbk496eKoxUlNT4ePjAy6XiwkTJqCsrAxRUVFQV1eHra0tli9fjh9++AH5+fnIzc2Fm5sbszbJwsICUlJS8Pf3x+bNm2FpaSmWdnl5OeTl5VFSUsKM6b7OhQsXcOvWLXTo0AETJ06ElpYW9u/fj8rKSvz6669vHP4SiURwdXWFtrY2Fi9ejOrqaixduhSVlZXw9/d/4/UWFxczQ2tnz57FzZs3sXjxYnTr1g3Jycnw8fGBqqoqNDU18dtvv7Wqx31YWBhOnTqFsrIyrFu3Dl27dm2yg3er8UEbGq1MrUMGn8+niRMnirn27d69m+bMmUMuLi6UnJzMfP/48WPy8PAgX19funPnDhGJt/N4PB4zfTht2jS6evVqvXx5PB65u7uTp6cn418QGxtLFhYWTXJxS01NpW+//Za8vb1p3bp19OLFC0pPT2/SNfN4PLGNOHg8HgUEBNCkSZPI19e3Xrv8Xak7TUr0ymGoLZCYXn9D1I4BZmdnQ01NjVlCkp6ejrS0NKirq8Pe3h7du3dn4nTv3h0rV64US6eu9ZGTk4OZmRn69u0LKSkprFixAhoaGow3PvCqNz9//nxUV1czA/Nr167FzJkzMW7cuLeW28DAAD4+Prh16xYqKyuZYaimICMjAwUFBWbPrvj4eGRmZsLS0hLz5s0D0HpDQkD9adLa4cEPjUQLtfYmfvHFF8jMzMTFixchLy+Phw8fgsvlYs6cOWIirUtNTQ1Onz6NDh06iA1LcblcsU0aBg4ciMzMTDGhAv+3QvP27dtwcHDAkiVLmBelKUIxMDCo59vaFGRlZeHo6IgtW7YgJCQE1dXVGDBgwHsRaUO870V8b8yXSDLbqK/z+PFj7Nu3Dykp
KejVqxdmzpyJXr16NRonMzMT5eXl0NPTQ0VFBcrKyiAUChkLFxwcjDNnzmDDhg3M5g2vk5GRgZiYGNja2r6Py3oj+fn5+PPPP6Gjo4PFixcDeP8ibUs+GaECQEFBAXbv3o2pU6c2eQePmpoaLFiwANevX4e5uTnu3r0LeXl5qKuro6SkBEuXLsXIkSMbXEv1Oh9aKCUlJYxl/5RFCnxiQgXAzOU3hydPnmDp0qVwcHCApaUls5pVRkbmgxxN86586iIFPkGhtpTU1FT8/vvvsLGxEZtD/xxEIAlI5ID/+6Bbt2743//+hwsXLuDJkyfM96xIPw5Yi/oaL168gJSUFNq3b9/WRWGpAytUFomArfpZJAJWqCwSAStUFomAFSqLRMAKlUUiYIXKIhGwQmWRCFihskgErFBZJIL/B/THIX9bZauuAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 175x160 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Per-seed accuracy for each PE, averaged over all remaining rows of ann_accuracy2\n",
    "# (presumably the noise levels, per the output filename -- TODO confirm).\n",
    "seed_mean_accuracy = ann_accuracy2.groupby(['PE','Seed']).mean(numeric_only=True).reset_index()\n",
    "# Mean across seeds per PE; sorting by it fixes the PE order on the x-axis.\n",
    "# The perturbation heatmap cell also reads df_validation_order for its tick labels.\n",
    "pe_mean_accuracy = seed_mean_accuracy.groupby('PE').mean(numeric_only=True).reset_index()\n",
    "df_validation_order = pe_mean_accuracy.sort_values('Accuracy',ascending=False)\n",
    "\n",
    "plt.figure(figsize=(1.75,1.6))\n",
    "# hue mirrors x so that the 'rocket' palette colours each PE box.\n",
    "ax = sns.boxplot(data=seed_mean_accuracy,x=\"PE\",hue=\"PE\",y=\"Accuracy\",fliersize=1,palette='rocket',\n",
    "                 order=df_validation_order.PE.values)\n",
    "plt.xticks(rotation=-45,fontsize=8);\n",
    "plt.yticks(fontsize=7);\n",
    "plt.xlabel('',fontsize=10)  # drop the default 'PE' axis label\n",
    "plt.ylabel('',fontsize=8)  # drop the default 'Accuracy' axis label\n",
    "plt.title('Common PEs, mean sorted',fontsize=8)\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'{outputdir}perturbation_performance_avg_across_noise.pdf',transparent=True,dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "LSTANN",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
