{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Square Attack has decreased the robust accuracy of 2.50%. This might indicate that the robustness evaluation using AutoAttack is unreliable. Consider running Square Attack with more iterations and restarts or an adaptive attack. See flags_doc.md for details.\n",
      "Testing, epoch  1126 : done with batch  1  out of  50\n",
      "Warning: Square Attack has decreased the robust accuracy of 1.50%. This might indicate that the robustness evaluation using AutoAttack is unreliable. Consider running Square Attack with more iterations and restarts or an adaptive attack. See flags_doc.md for details.\n",
      "Testing, epoch  1126 : done with batch  2  out of  50\n",
      "Testing, epoch  1126 : done with batch  3  out of  50\n",
      "Warning: Square Attack has decreased the robust accuracy of 1.00%. This might indicate that the robustness evaluation using AutoAttack is unreliable. Consider running Square Attack with more iterations and restarts or an adaptive attack. See flags_doc.md for details.\n",
      "Testing, epoch  1126 : done with batch  4  out of  50\n",
      "Testing, epoch  1126 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.011449344\n",
      "{'clean': 0.844, 'clean_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True,  True, False,\n",
      "         True, False,  True, False, False,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True]), tensor([False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True, False, False, False, False,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True, False,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "        False, False,  True,  True,  True, False,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True]), tensor([ True,  True,  True, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False, False,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True, False,  True, False, False, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False, False,\n",
      "         True,  True,  True, False,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True, False,  True,  True,\n",
      "        False,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False, False,  True,  True,  True, False,  True,\n",
      "         True, False, False,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True, False,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True]), tensor([ True, False,  True, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True, False,  True, False,  True,  True]), tensor([False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "        False,  True,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True, False,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False, False,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True, False,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True,  True,  True, False,  True, False, False,\n",
      "        False,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True])], 'PGD_L1_std': 0.44, 'PGD_L1_std_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True, False, False,  True, False,  True,\n",
      "        False, False, False, False,  True, False, False,  True, False,  True,\n",
      "        False, False, False, False,  True, False,  True,  True, False, False,\n",
      "         True,  True,  True, False, False,  True,  True,  True, False,  True,\n",
      "        False,  True, False, False, False, False,  True,  True,  True,  True,\n",
      "        False, False,  True, False, False,  True, False,  True, False, False,\n",
      "         True, False, False, False, False,  True, False,  True, False, False,\n",
      "         True, False,  True,  True,  True, False,  True, False, False, False,\n",
      "         True,  True,  True,  True, False, False,  True,  True, False,  True,\n",
      "         True, False, False,  True,  True, False,  True, False,  True, False,\n",
      "        False,  True,  True,  True, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False,  True, False, False,  True, False,  True,  True,  True,\n",
      "        False,  True, False, False, False,  True, False,  True,  True, False,\n",
      "        False,  True, False, False, False,  True,  True,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False,  True, False,  True,  True,\n",
      "        False, False, False, False,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True, False, False,  True, False,  True, False, False,\n",
      "         True, False,  True,  True, False, False,  True, False,  True,  True]), tensor([False,  True,  True, False,  True, False,  True, False, False, False,\n",
      "        False, False,  True,  True,  True,  True,  True, False, False, False,\n",
      "        False, False,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True, False,  True,  True,\n",
      "         True, False, False, False, False, False,  True, False,  True,  True,\n",
      "        False, False, False, False,  True,  True,  True, False,  True, False,\n",
      "         True, False,  True,  True, False,  True,  True,  True, False, False,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True, False, False,  True,  True,  True, False,  True, False,\n",
      "        False, False, False, False,  True,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True,  True, False, False, False, False,  True,\n",
      "         True, False, False, False,  True, False, False,  True, False, False,\n",
      "        False, False, False,  True, False, False,  True, False, False,  True,\n",
      "        False, False,  True,  True,  True, False, False,  True, False, False,\n",
      "        False,  True,  True,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True, False, False,  True, False, False, False, False,  True,\n",
      "         True, False, False, False, False,  True, False, False, False, False,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True,  True,\n",
      "        False, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True, False,  True, False, False, False,  True,  True,  True]), tensor([ True, False, False, False, False, False, False,  True,  True,  True,\n",
      "         True, False, False,  True,  True, False,  True, False,  True, False,\n",
      "        False,  True, False,  True, False,  True, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False,  True,  True, False,\n",
      "         True, False,  True,  True,  True, False, False, False, False, False,\n",
      "         True, False, False, False,  True,  True, False, False, False,  True,\n",
      "        False, False, False,  True, False,  True, False, False, False, False,\n",
      "         True, False, False,  True,  True,  True,  True,  True, False, False,\n",
      "         True, False,  True, False, False,  True, False, False, False,  True,\n",
      "        False,  True,  True,  True,  True, False,  True, False, False, False,\n",
      "         True,  True, False,  True, False,  True,  True, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True,  True, False, False,  True,\n",
      "        False, False, False,  True, False,  True, False,  True, False,  True,\n",
      "        False,  True, False, False,  True, False,  True, False, False,  True,\n",
      "        False, False, False, False,  True, False,  True, False,  True, False,\n",
      "        False, False, False, False, False,  True, False,  True, False, False,\n",
      "        False, False, False,  True,  True,  True,  True, False, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True, False, False, False,  True, False]), tensor([ True, False, False, False, False, False, False, False,  True,  True,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "         True, False, False,  True,  True, False,  True, False,  True, False,\n",
      "         True, False, False,  True, False, False,  True, False, False, False,\n",
      "         True, False, False, False, False, False, False,  True, False,  True,\n",
      "        False,  True,  True, False, False,  True,  True, False, False, False,\n",
      "         True, False, False,  True, False, False,  True, False,  True, False,\n",
      "         True, False,  True,  True, False,  True,  True,  True, False,  True,\n",
      "        False,  True,  True, False, False, False,  True,  True, False,  True,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True, False,\n",
      "         True, False,  True, False, False,  True,  True, False, False, False,\n",
      "         True, False,  True, False, False, False,  True, False,  True, False,\n",
      "        False, False, False, False, False,  True,  True,  True, False, False,\n",
      "         True,  True,  True,  True,  True,  True,  True, False, False,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True, False, False,\n",
      "        False, False, False,  True,  True,  True,  True, False, False,  True,\n",
      "        False,  True,  True, False, False,  True, False,  True,  True, False,\n",
      "         True,  True,  True, False,  True, False,  True,  True, False, False,\n",
      "        False, False, False,  True, False, False,  True, False,  True,  True,\n",
      "        False, False,  True, False, False, False, False, False,  True,  True]), tensor([False, False, False, False, False,  True, False, False, False,  True,\n",
      "         True,  True, False, False,  True,  True, False, False,  True,  True,\n",
      "        False, False,  True,  True, False,  True,  True, False, False, False,\n",
      "        False,  True, False, False, False,  True,  True, False, False, False,\n",
      "         True,  True,  True, False, False, False, False,  True,  True, False,\n",
      "         True,  True,  True,  True, False, False, False, False,  True,  True,\n",
      "        False,  True, False, False,  True, False, False, False,  True,  True,\n",
      "        False, False, False,  True, False,  True, False, False, False,  True,\n",
      "         True, False, False, False, False,  True,  True,  True, False,  True,\n",
      "        False, False,  True, False, False,  True, False, False, False,  True,\n",
      "        False,  True,  True, False, False,  True, False,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True, False, False,\n",
      "         True,  True, False,  True, False, False, False,  True,  True,  True,\n",
      "        False,  True, False,  True, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False,  True,  True, False, False,  True,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False,  True, False,  True,  True, False,  True,  True,  True, False,\n",
      "        False,  True, False, False,  True, False, False, False, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False,  True, False, False, False])], 'PGD_L2_std': 0.1, 'PGD_L2_std_bool_track_correct_preds': [tensor([ True, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True, False,  True,\n",
      "        False,  True, False, False, False, False, False, False,  True, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False,  True, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False,  True,  True, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False]), tensor([False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False,  True, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False,  True, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False]), tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False]), tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False,  True, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True,  True, False, False,  True,\n",
      "        False, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False,  True, False,\n",
      "         True, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False,  True, False, False, False, False, False, False,  True]), tensor([False, False, False, False, False,  True, False, False, False,  True,\n",
      "        False,  True, False, False, False, False, False, False,  True,  True,\n",
      "        False, False,  True, False,  True,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "         True, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False,  True, False, False,\n",
      "         True, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False])], 'PGD_Linf_std': 0.592, 'PGD_Linf_std_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True,  True, False,  True, False,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "        False,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True, False, False,  True,  True,  True, False,  True,\n",
      "         True,  True,  True, False, False, False,  True,  True,  True,  True,\n",
      "         True, False, False, False,  True,  True,  True,  True,  True, False,\n",
      "         True, False, False, False, False, False, False,  True, False,  True,\n",
      "         True,  True,  True, False,  True,  True,  True, False, False, False,\n",
      "         True,  True,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True,  True, False, False,  True, False,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True, False,  True, False,\n",
      "        False, False, False,  True, False, False,  True,  True,  True, False,\n",
      "        False, False,  True, False, False,  True,  True,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True, False,  True, False,\n",
      "         True,  True, False,  True, False,  True, False,  True,  True, False,\n",
      "        False, False, False,  True,  True, False,  True,  True, False, False,\n",
      "         True, False,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True, False, False, False,  True, False,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False,  True, False,  True,  True,  True,  True]), tensor([False,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False, False, False,\n",
      "         True,  True, False,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False, False, False, False, False, False,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "        False, False,  True,  True, False,  True,  True,  True, False, False,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True, False,\n",
      "        False, False,  True,  True,  True,  True,  True,  True, False, False,\n",
      "        False, False,  True, False,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False, False,\n",
      "        False, False,  True, False,  True, False,  True,  True, False,  True,\n",
      "         True, False,  True, False,  True, False, False,  True, False,  True,\n",
      "        False,  True,  True,  True, False,  True,  True, False,  True, False,\n",
      "         True,  True, False, False,  True, False,  True, False, False,  True,\n",
      "         True,  True, False, False, False,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True, False, False, False, False,  True,  True]), tensor([ True,  True, False, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True, False,  True,  True, False,  True, False,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True, False,  True,\n",
      "         True, False, False, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False, False, False,  True, False, False, False,\n",
      "         True, False, False,  True,  True,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True, False,  True, False, False, False, False,\n",
      "         True,  True, False,  True,  True,  True,  True, False, False, False,\n",
      "         True,  True,  True, False,  True,  True, False, False, False,  True,\n",
      "        False,  True,  True,  True,  True, False,  True, False, False, False,\n",
      "         True,  True, False,  True, False, False,  True, False, False,  True,\n",
      "        False, False,  True, False, False, False, False,  True, False,  True,\n",
      "        False, False,  True, False, False,  True,  True, False, False,  True,\n",
      "        False, False, False,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False, False,  True, False,  True,  True, False,  True,\n",
      "        False, False,  True, False, False,  True,  True, False,  True,  True,\n",
      "         True,  True, False, False,  True,  True, False,  True,  True, False,\n",
      "        False, False, False, False,  True,  True, False, False, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False, False,  True,  True,  True]), tensor([ True, False,  True, False, False, False, False,  True,  True,  True,\n",
      "        False, False, False, False, False,  True,  True, False, False, False,\n",
      "         True,  True,  True,  True, False, False,  True, False, False, False,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False, False,  True,  True,  True, False, False,\n",
      "         True, False,  True,  True, False,  True, False, False, False, False,\n",
      "        False, False,  True, False, False,  True,  True, False,  True,  True,\n",
      "         True,  True,  True, False, False, False,  True,  True, False, False,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True, False,  True,  True, False, False, False,  True,\n",
      "         True, False,  True,  True, False, False,  True, False,  True,  True,\n",
      "        False,  True,  True, False,  True,  True, False, False, False,  True,\n",
      "         True,  True,  True,  True, False,  True,  True, False, False, False,\n",
      "        False, False,  True,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True, False, False,  True,  True, False,  True,\n",
      "        False, False, False,  True,  True,  True,  True, False,  True, False,\n",
      "         True, False,  True, False,  True, False,  True, False,  True,  True]), tensor([False, False,  True, False,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "        False,  True, False, False,  True,  True, False, False,  True, False,\n",
      "         True, False,  True,  True, False, False, False, False,  True, False,\n",
      "         True,  True,  True, False,  True,  True, False, False,  True,  True,\n",
      "         True, False, False, False,  True, False, False, False, False,  True,\n",
      "        False,  True,  True,  True, False,  True, False, False,  True,  True,\n",
      "         True, False, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True, False,  True, False, False, False,  True,\n",
      "        False,  True, False,  True, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True, False,  True,  True,  True,\n",
      "        False,  True, False,  True, False, False,  True,  True, False, False,\n",
      "        False,  True,  True, False, False,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False,  True, False,  True,  True, False,  True,  True,  True, False,\n",
      "        False,  True,  True, False,  True, False,  True, False,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True, False,\n",
      "        False, False, False,  True,  True, False,  True,  True, False,  True])], 'Deepfool_base': 0.781, 'Deepfool_base_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True, False, False, False,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True,  True, False,\n",
      "         True, False, False, False, False,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True, False,\n",
      "        False, False,  True,  True, False, False,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True, False,\n",
      "         True, False,  True,  True,  True, False,  True,  True, False,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True]), tensor([False,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True, False, False, False, False,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True, False,  True,  True, False,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "        False, False,  True,  True,  True, False,  True,  True, False,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True]), tensor([ True,  True, False, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True, False,  True,  True, False,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True, False, False,\n",
      "         True, False, False,  True,  True,  True,  True, False,  True,  True,\n",
      "        False,  True,  True,  True, False,  True, False, False, False, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False, False,\n",
      "         True,  True,  True, False,  True,  True, False,  True, False,  True,\n",
      "        False,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True, False, False,  True,\n",
      "        False, False,  True, False, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False, False,  True,  True, False, False,  True,\n",
      "         True, False, False,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True, False,  True,\n",
      "        False, False,  True, False,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "        False,  True, False,  True,  True,  True,  True,  True, False,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True]), tensor([ True, False,  True, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True,  True, False,  True, False,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False, False,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True, False,  True,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True, False,  True, False,  True,  True]), tensor([False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True, False,\n",
      "        False,  True, False,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True, False,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True, False, False,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True, False, False, False,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True, False,  True, False,  True,  True,  True,\n",
      "        False,  True, False,  True, False,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True, False,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True,  True,  True, False,  True, False, False,\n",
      "        False,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True, False,  True])], 'CW_base': 0.623, 'CW_base_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True,  True,  True,  True, False, False,\n",
      "         True, False,  True,  True,  True,  True, False,  True, False,  True,\n",
      "        False,  True, False,  True, False,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True, False, False,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True, False, False,  True, False,  True,  True,\n",
      "         True, False,  True, False,  True,  True,  True,  True, False, False,\n",
      "         True, False, False, False, False,  True, False,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True, False, False, False, False,\n",
      "         True,  True,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True,  True, False, False,  True, False,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True, False,  True, False,\n",
      "        False, False, False,  True, False, False,  True,  True,  True, False,\n",
      "        False, False,  True, False, False, False, False,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True,  True,  True,  True, False,\n",
      "        False, False,  True,  True,  True, False,  True,  True, False, False,\n",
      "         True, False,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True, False, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True]), tensor([False,  True,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True,  True,  True, False, False, False,\n",
      "         True,  True,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True, False, False, False, False, False,  True, False,  True, False,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "        False, False,  True,  True, False,  True,  True,  True, False, False,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "        False, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "        False, False, False,  True,  True, False,  True,  True, False,  True,\n",
      "         True, False,  True,  True,  True,  True, False,  True, False,  True,\n",
      "        False,  True,  True, False, False,  True,  True, False,  True, False,\n",
      "         True, False, False, False,  True,  True,  True, False, False,  True,\n",
      "        False,  True, False, False, False, False, False, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False, False,  True,  True,  True,  True, False,  True,  True,\n",
      "         True,  True,  True,  True, False,  True, False,  True,  True,  True]), tensor([ True,  True, False, False,  True, False, False,  True,  True,  True,\n",
      "         True, False, False,  True,  True, False,  True,  True,  True,  True,\n",
      "        False,  True, False,  True,  True,  True, False,  True, False,  True,\n",
      "         True, False, False,  True,  True,  True, False,  True,  True,  True,\n",
      "         True,  True, False, False,  True, False,  True,  True, False, False,\n",
      "         True, False, False,  True,  True,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True, False,  True, False, False, False, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True, False, False,\n",
      "         True,  True,  True, False,  True,  True, False,  True, False,  True,\n",
      "        False,  True,  True,  True,  True, False, False,  True, False,  True,\n",
      "         True,  True, False,  True, False, False,  True, False, False,  True,\n",
      "        False, False,  True, False, False, False, False,  True,  True,  True,\n",
      "        False, False,  True, False, False,  True,  True, False, False,  True,\n",
      "        False, False, False,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True, False,  True,\n",
      "        False, False, False, False,  True,  True,  True, False,  True,  True,\n",
      "         True,  True, False, False, False,  True,  True,  True,  True, False,\n",
      "        False,  True, False, False,  True,  True,  True, False, False,  True,\n",
      "         True,  True, False, False,  True,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True,  True]), tensor([ True, False,  True, False, False, False, False,  True,  True,  True,\n",
      "        False, False, False, False,  True,  True, False,  True,  True, False,\n",
      "        False,  True,  True,  True, False,  True,  True, False, False, False,\n",
      "        False, False,  True,  True,  True,  True,  True, False, False,  True,\n",
      "         True, False,  True,  True,  True, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False, False,\n",
      "         True, False, False,  True, False,  True,  True, False, False, False,\n",
      "        False, False,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True,  True,  True, False, False, False,  True,  True, False, False,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True, False,  True,  True,  True, False, False,  True,\n",
      "         True, False,  True,  True, False,  True,  True, False,  True,  True,\n",
      "        False,  True,  True, False,  True,  True,  True,  True, False,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "        False, False,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "         True, False, False, False,  True,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True, False,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "        False, False, False,  True,  True,  True, False, False, False,  True,\n",
      "         True, False,  True,  True, False, False,  True, False,  True,  True]), tensor([False, False,  True, False, False,  True, False,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True,  True, False,  True,  True,  True, False, False, False,\n",
      "        False,  True, False, False,  True,  True,  True, False,  True, False,\n",
      "         True,  True,  True,  True, False, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False,  True, False, False, False,  True, False,\n",
      "         True, False, False, False,  True, False, False, False, False, False,\n",
      "        False,  True,  True,  True, False,  True, False, False,  True,  True,\n",
      "         True, False, False,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True, False,  True, False,  True,  True, False, False, False,  True,\n",
      "        False,  True,  True,  True, False, False,  True,  True,  True,  True,\n",
      "         True,  True,  True, False,  True, False,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False, False, False,  True,  True,  True,\n",
      "        False,  True, False,  True, False, False,  True,  True, False,  True,\n",
      "        False,  True,  True, False, False,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True, False, False, False,  True, False, False,\n",
      "        False,  True,  True, False,  True, False,  True,  True,  True, False,\n",
      "         True,  True,  True, False,  True, False,  True, False,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True, False,  True,  True, False,  True,  True, False,  True])], 'PGD_Linf_mod': 0.051, 'PGD_Linf_mod_bool_track_correct_preds': [tensor([False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True]), tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False,  True, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False, False, False, False, False]), tensor([False, False, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False]), tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False, False,  True]), tensor([False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False,  True, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False,  True, False, False])], 'Deepfool_mod': 0.194, 'Deepfool_mod_bool_track_correct_preds': [tensor([False, False, False,  True,  True, False, False, False, False, False,\n",
      "        False, False,  True,  True,  True, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False,  True, False, False, False,  True,\n",
      "         True, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False,  True, False, False,  True,  True, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False,  True, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False,  True, False, False,  True,\n",
      "        False, False, False, False,  True, False, False, False, False,  True,\n",
      "        False,  True, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "         True, False, False, False, False,  True,  True, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False,  True, False, False,  True,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False,  True, False,\n",
      "        False, False, False, False,  True, False, False, False, False,  True]), tensor([False, False, False, False,  True, False,  True, False, False, False,\n",
      "        False, False,  True, False,  True,  True, False, False, False, False,\n",
      "         True, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True, False,  True, False,\n",
      "        False, False, False, False, False, False,  True, False,  True, False,\n",
      "        False, False, False, False, False, False,  True, False,  True, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False,  True, False, False, False,  True,  True,  True,  True, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False,  True,  True, False,  True, False, False, False,\n",
      "        False, False,  True, False, False,  True,  True, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False,  True, False,  True,\n",
      "        False, False,  True, False,  True, False, False, False, False,  True,\n",
      "        False, False,  True,  True, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True,  True, False,  True, False, False, False, False,\n",
      "         True, False, False, False,  True, False, False, False, False, False,\n",
      "         True, False, False, False, False,  True, False, False, False, False]), tensor([ True, False, False, False,  True, False, False,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False,  True,  True, False, False, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "         True,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False,  True, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False,  True,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "         True, False, False,  True, False, False, False, False, False,  True,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False,  True,\n",
      "        False, False, False,  True,  True, False, False,  True, False,  True,\n",
      "         True, False, False, False, False,  True, False,  True, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False,  True, False, False,\n",
      "         True, False,  True, False,  True, False,  True, False, False, False]), tensor([False, False,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False, False,  True,  True,  True, False, False, False, False, False,\n",
      "        False,  True,  True, False, False,  True,  True, False, False, False,\n",
      "         True, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False,  True,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False,  True, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False,  True, False,  True, False,\n",
      "        False, False,  True, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False,  True,  True,  True, False, False, False, False, False,\n",
      "        False,  True, False, False,  True, False, False, False, False, False,\n",
      "        False,  True, False,  True,  True,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False,  True, False, False,  True]), tensor([False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False,  True, False,\n",
      "         True,  True, False,  True, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "         True,  True, False, False,  True, False, False, False,  True, False,\n",
      "        False, False, False, False,  True, False, False, False, False, False,\n",
      "        False,  True,  True,  True, False,  True, False, False,  True,  True,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "         True, False, False, False, False,  True, False, False, False,  True,\n",
      "        False, False, False,  True, False, False,  True, False, False, False,\n",
      "        False, False,  True, False,  True, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False,  True, False, False, False, False, False, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False, False,  True,\n",
      "        False, False, False, False,  True, False,  True, False,  True, False,\n",
      "        False, False, False,  True, False, False, False,  True, False, False])], 'CW_mod': 0.302, 'CW_mod_bool_track_correct_preds': [tensor([ True, False,  True, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False,  True,  True, False,  True, False, False,\n",
      "        False, False, False,  True, False,  True,  True, False, False,  True,\n",
      "        False, False, False, False, False,  True,  True,  True, False, False,\n",
      "        False,  True,  True, False, False, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False,  True, False, False,\n",
      "         True, False, False, False, False,  True, False,  True, False, False,\n",
      "         True,  True,  True, False, False, False, False, False, False, False,\n",
      "        False,  True,  True,  True, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "        False,  True, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False, False,  True,  True,  True, False,\n",
      "        False, False, False, False,  True,  True, False, False, False, False,\n",
      "         True,  True,  True,  True, False,  True, False, False, False, False,\n",
      "        False, False, False,  True, False, False,  True,  True, False, False,\n",
      "         True, False, False, False, False,  True, False, False,  True, False,\n",
      "        False, False, False, False,  True, False,  True, False,  True, False,\n",
      "        False, False,  True,  True, False, False, False,  True, False,  True,\n",
      "        False,  True, False,  True,  True, False,  True,  True, False,  True]), tensor([False, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False,  True,  True, False,  True,  True, False, False, False,\n",
      "        False, False, False,  True, False,  True, False, False,  True,  True,\n",
      "         True,  True, False, False,  True, False,  True, False,  True, False,\n",
      "         True, False, False, False, False, False,  True, False,  True, False,\n",
      "         True, False, False, False,  True, False, False, False,  True, False,\n",
      "        False, False,  True,  True, False, False, False,  True, False, False,\n",
      "        False,  True, False,  True, False, False,  True, False,  True, False,\n",
      "        False,  True, False, False,  True,  True, False,  True,  True, False,\n",
      "        False, False, False, False, False,  True, False,  True, False, False,\n",
      "        False, False,  True,  True, False,  True, False, False, False, False,\n",
      "         True, False, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False,  True, False,  True,\n",
      "         True, False,  True,  True,  True, False, False,  True, False, False,\n",
      "        False,  True,  True, False, False,  True, False, False,  True, False,\n",
      "        False, False, False, False,  True, False,  True, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True, False,\n",
      "         True, False, False, False,  True,  True, False, False,  True, False,\n",
      "         True, False, False,  True, False,  True, False, False, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False,  True]), tensor([ True, False, False, False, False, False, False,  True, False, False,\n",
      "         True, False, False, False, False, False,  True, False,  True, False,\n",
      "        False,  True, False,  True,  True,  True, False, False, False,  True,\n",
      "         True, False, False,  True, False, False, False, False,  True,  True,\n",
      "         True,  True, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False,  True,  True, False, False, False, False,  True,\n",
      "        False, False,  True, False, False,  True, False, False, False, False,\n",
      "        False,  True,  True, False,  True,  True,  True,  True, False, False,\n",
      "        False,  True,  True, False, False,  True, False, False, False,  True,\n",
      "        False, False,  True, False,  True, False, False, False, False, False,\n",
      "         True, False, False,  True, False, False, False, False, False,  True,\n",
      "        False, False, False, False, False, False, False,  True,  True,  True,\n",
      "        False, False, False, False, False,  True, False, False, False,  True,\n",
      "        False, False, False,  True, False,  True, False,  True, False,  True,\n",
      "         True,  True, False, False, False, False,  True,  True, False, False,\n",
      "        False, False, False, False, False, False, False, False, False,  True,\n",
      "         True,  True, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False,  True, False, False, False, False,  True,\n",
      "        False,  True, False, False, False,  True, False,  True,  True, False,\n",
      "        False, False,  True, False,  True, False,  True, False, False, False]), tensor([ True, False, False, False, False, False, False, False,  True,  True,\n",
      "        False, False, False, False,  True,  True, False, False,  True, False,\n",
      "        False, False, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False,  True, False,  True,  True, False, False, False,\n",
      "        False, False,  True, False,  True, False, False, False, False,  True,\n",
      "         True,  True,  True, False, False,  True, False,  True, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False,  True, False, False,  True,  True, False, False,  True,\n",
      "        False,  True, False, False, False, False, False,  True, False, False,\n",
      "        False, False, False,  True, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "         True, False, False,  True, False, False,  True, False,  True, False,\n",
      "        False,  True,  True, False, False,  True,  True, False, False, False,\n",
      "         True, False,  True, False, False,  True,  True, False, False, False,\n",
      "        False, False,  True, False,  True, False, False,  True, False,  True,\n",
      "        False, False, False, False,  True, False,  True, False, False, False,\n",
      "        False,  True,  True, False, False, False, False, False,  True, False,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True, False,\n",
      "        False, False, False,  True,  True,  True, False, False, False,  True,\n",
      "         True, False,  True, False, False, False,  True, False,  True,  True]), tensor([False, False, False, False, False, False, False, False,  True,  True,\n",
      "        False,  True,  True, False,  True,  True, False, False,  True, False,\n",
      "        False,  True,  True, False,  True,  True,  True, False, False, False,\n",
      "        False, False, False, False, False,  True,  True, False, False, False,\n",
      "        False, False,  True, False, False, False, False, False, False,  True,\n",
      "        False, False,  True, False, False, False, False, False,  True, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False,  True,  True, False, False, False, False,  True, False,\n",
      "        False, False, False,  True, False, False,  True,  True,  True,  True,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False,  True,  True,\n",
      "        False,  True, False, False, False, False, False,  True, False, False,\n",
      "        False,  True,  True,  True, False, False, False,  True,  True,  True,\n",
      "        False, False, False, False, False, False,  True,  True, False, False,\n",
      "        False,  True, False, False, False, False,  True, False, False, False,\n",
      "        False, False, False,  True, False, False, False, False, False, False,\n",
      "        False, False,  True, False,  True, False, False,  True,  True, False,\n",
      "         True, False, False, False, False, False, False, False,  True, False,\n",
      "         True,  True, False, False,  True, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False])], 'Autoattack': 0.55, 'Autoattack_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True,  True, False,  True, False, False,\n",
      "         True, False,  True,  True,  True,  True,  True,  True, False,  True,\n",
      "        False,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True, False, False,  True,  True,  True, False,  True,\n",
      "         True,  True,  True, False, False, False,  True, False,  True,  True,\n",
      "        False, False, False, False,  True,  True,  True,  True,  True, False,\n",
      "         True, False, False, False, False, False, False,  True, False,  True,\n",
      "         True,  True, False, False,  True,  True, False, False, False, False,\n",
      "        False,  True,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True,  True, False, False,  True, False,  True, False,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True, False, False, False,\n",
      "        False, False, False,  True, False, False,  True,  True,  True, False,\n",
      "        False, False,  True, False, False,  True, False,  True,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True, False,  True, False,\n",
      "         True,  True, False,  True, False,  True, False,  True,  True, False,\n",
      "        False, False, False, False,  True, False,  True,  True, False, False,\n",
      "         True, False,  True,  True, False,  True,  True, False,  True,  True,\n",
      "         True, False, False, False,  True, False,  True,  True,  True,  True,\n",
      "         True, False,  True,  True, False, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False,  True, False,  True,  True,  True,  True]), tensor([False, False,  True,  True,  True, False,  True, False,  True,  True,\n",
      "         True, False, False,  True,  True,  True,  True, False, False, False,\n",
      "         True,  True, False,  True,  True, False, False, False,  True,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False, False, False, False, False, False,  True, False,  True, False,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "        False, False,  True,  True, False,  True,  True,  True, False, False,\n",
      "        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True, False, False, False,  True,  True,  True,  True, False,\n",
      "        False, False,  True,  True,  True, False,  True,  True, False, False,\n",
      "        False, False,  True, False, False,  True,  True, False,  True, False,\n",
      "         True,  True,  True, False,  True,  True,  True,  True, False, False,\n",
      "        False, False, False, False,  True, False,  True,  True, False,  True,\n",
      "         True, False,  True, False,  True, False, False,  True, False,  True,\n",
      "        False,  True,  True,  True, False,  True,  True, False,  True, False,\n",
      "         True,  True, False, False,  True, False,  True, False, False,  True,\n",
      "         True,  True, False, False, False,  True, False, False,  True,  True,\n",
      "         True, False,  True,  True, False,  True,  True,  True,  True,  True,\n",
      "         True, False,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True, False, False, False, False, False,  True]), tensor([ True, False, False, False,  True, False, False,  True, False,  True,\n",
      "         True,  True, False,  True, False, False,  True, False,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True, False,  True,\n",
      "         True, False, False, False,  True, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False,  True,  True,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True, False,  True, False, False, False, False,\n",
      "         True, False, False,  True,  True,  True,  True, False, False, False,\n",
      "         True,  True,  True, False,  True,  True, False, False, False,  True,\n",
      "        False,  True,  True,  True,  True, False,  True, False, False, False,\n",
      "         True,  True, False,  True, False, False,  True, False, False,  True,\n",
      "        False, False,  True, False, False, False, False,  True, False, False,\n",
      "        False, False, False, False, False,  True,  True, False, False,  True,\n",
      "        False, False, False,  True,  True, False, False,  True, False,  True,\n",
      "         True,  True, False, False,  True, False,  True,  True, False, False,\n",
      "        False, False,  True, False, False,  True,  True, False,  True,  True,\n",
      "         True,  True, False, False,  True,  True, False,  True,  True, False,\n",
      "        False, False, False, False,  True,  True, False, False, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False, False,  True,  True,  True]), tensor([ True, False,  True, False, False, False, False,  True,  True,  True,\n",
      "        False, False, False, False, False,  True, False, False, False, False,\n",
      "         True,  True,  True,  True, False, False,  True, False, False, False,\n",
      "         True, False,  True,  True,  True,  True,  True, False,  True,  True,\n",
      "         True, False,  True,  True,  True, False, False,  True,  True,  True,\n",
      "         True,  True,  True, False, False,  True,  True,  True, False, False,\n",
      "         True, False,  True,  True, False,  True, False, False, False, False,\n",
      "        False, False,  True, False, False,  True,  True, False,  True,  True,\n",
      "         True,  True,  True, False, False, False,  True,  True, False, False,\n",
      "         True, False, False,  True, False,  True,  True,  True,  True, False,\n",
      "         True,  True,  True, False,  True,  True, False, False, False,  True,\n",
      "         True, False, False,  True, False, False,  True, False,  True, False,\n",
      "        False,  True,  True, False,  True,  True, False, False, False,  True,\n",
      "         True,  True,  True,  True, False,  True,  True, False, False, False,\n",
      "        False, False,  True,  True,  True,  True, False,  True, False,  True,\n",
      "         True,  True, False, False,  True,  True,  True,  True, False,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "         True, False,  True,  True, False, False,  True,  True, False,  True,\n",
      "        False, False, False,  True,  True,  True, False, False,  True, False,\n",
      "         True, False,  True, False,  True, False,  True, False,  True,  True]), tensor([False, False, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False, False,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True, False, False, False,\n",
      "        False,  True, False, False,  True,  True, False, False,  True, False,\n",
      "         True, False,  True,  True, False, False, False, False,  True, False,\n",
      "         True,  True,  True, False,  True, False, False, False,  True, False,\n",
      "         True, False, False, False,  True, False, False, False, False, False,\n",
      "        False,  True,  True,  True, False,  True, False, False,  True,  True,\n",
      "         True, False, False, False,  True,  True,  True,  True,  True,  True,\n",
      "        False, False,  True,  True, False,  True, False, False, False,  True,\n",
      "        False,  True, False,  True, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True, False,  True,  True,  True, False,\n",
      "         True,  True,  True,  True, False,  True, False,  True,  True,  True,\n",
      "        False,  True, False,  True, False, False,  True,  True, False, False,\n",
      "        False,  True,  True, False, False,  True,  True, False,  True,  True,\n",
      "        False, False,  True,  True, False, False, False, False, False, False,\n",
      "        False,  True, False, False,  True, False,  True,  True,  True, False,\n",
      "        False,  True,  True, False,  True, False,  True, False,  True,  True,\n",
      "         True,  True, False,  True,  True,  True,  True,  True,  True, False,\n",
      "        False, False, False,  True,  True, False,  True,  True, False,  True])], 'worst_no_skipped': 0.027, 'worst_with_skipped': 0.004, 'worst_seen': 0.592, 'worst_unseen': 0.004, 'worst_unseen_no_skipped': 0.027, 'worst_always_unseen': 0.028, 'worst_always_unseen_no_skipped': 0.093, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGD_Linf_1125.pt', 'seen_attacks': ['clean', 'PGD_Linf_std'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']}\n"
     ]
    }
   ],
   "source": [
    "# Clear the interactive namespace so the cell starts from a clean state on every run.\n",
    "# `run_line_magic` replaces the deprecated `InteractiveShell.magic`; get_ipython()\n",
    "# returns None outside IPython (e.g. when the notebook is exported and run as a script).\n",
    "from IPython import get_ipython\n",
    "_ipython_shell = get_ipython()\n",
    "if _ipython_shell is not None:\n",
    "    _ipython_shell.run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "timer = 0\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import advertorch.attacks as attacks\n",
    "# Resolved as a top-level `attacks` package on sys.path (presumably project-local);\n",
    "# a from-import does not rebind the `attacks` alias for advertorch.attacks above.\n",
    "from attacks.deepfool import DeepfoolLinfAttack\n",
    "from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import os, random\n",
    "\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Global seed: keeps the train/validation split identical across runs (e.g. even after\n",
    "# reloading a checkpoint) and is the default seed used by seed_init_fn.\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "    \"\"\"Seed NumPy's global RNG, Python's `random` module and torch with `seed`.\n",
    "\n",
    "    The default value is the module-level `seed`, bound when this function is defined.\n",
    "    Returns None.\n",
    "    \"\"\"\n",
    "    for seed_setter in (np.random.seed, random.seed, torch.manual_seed):\n",
    "        seed_setter(seed)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# DataLoader worker processes; 0 loads batches in the main process.\n",
    "num_workers = 0\n",
    "batch_size_train_and_valid = 128\n",
    "# len(test_data) must be a multiple of batch_size_test (asserted below).\n",
    "batch_size_test = 200\n",
    "\n",
    "# proportion of full training set used for validation\n",
    "valid_size = 0.2\n",
    "\n",
    "transform = transforms.ToTensor()\n",
    "train_and_valid_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transform)\n",
    "test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transform)\n",
    "assert len(test_data) % batch_size_test == 0, 'len(test_data) must be a multiple of batch_size_test'\n",
    "\n",
    "num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))\n",
    "num_train_samples = len(train_and_valid_data) - num_valid_samples\n",
    "# A fixed-seed generator makes the train/validation split identical on every run.\n",
    "train_data, valid_data = torch.utils.data.random_split(train_and_valid_data, [num_train_samples, num_valid_samples], generator=torch.Generator().manual_seed(seed))\n",
    "\n",
    "\n",
    "def seed_worker(worker_id):\n",
    "    \"\"\"DataLoader `worker_init_fn`: seed a worker process's RNGs from the global `seed`.\n",
    "\n",
    "    PyTorch calls `worker_init_fn(worker_id)`, so passing `seed_init_fn` directly would\n",
    "    use the bare worker index as the seed and ignore `seed`.\n",
    "    \"\"\"\n",
    "    seed_init_fn(seed + worker_id)\n",
    "\n",
    "\n",
    "# NOTE(review): train_loader does not shuffle (DataLoader's shuffle defaults to False);\n",
    "# confirm this is intended if it is used for training.\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid, num_workers = num_workers)\n",
    "valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid, num_workers = num_workers)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, num_workers = num_workers, worker_init_fn=seed_worker)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Fully connected MNIST classifier: 784 -> 512 -> 512 -> 10 with ReLU activations.\n",
    "\n",
    "    forward() returns raw logits (no softmax). The attribute names fc1/fc2/fc3 become\n",
    "    the state_dict keys, so they must match any checkpoint passed to load_state_dict.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(28 * 28, 512)\n",
    "        self.fc2 = nn.Linear(512, 512)\n",
    "        self.fc3 = nn.Linear(512, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Flatten every image to a 784-vector; -1 lets view() infer the batch dimension.\n",
    "        flat = x.view(-1, 28 * 28)\n",
    "        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))\n",
    "        return self.fc3(hidden)\n",
    "\n",
    "\n",
    "model = Net()\n",
    "# map_location lets a checkpoint that was saved from a CUDA device load on a CPU-only\n",
    "# machine; the model is then moved to `device` explicitly.\n",
    "model.load_state_dict(torch.load('model_no_dropout.pt', map_location=device))\n",
    "model.to(device)\n",
    "\n",
    "\n",
    "# if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "#     print(\"Using DataParallel\")\n",
    "#     model = torch.nn.DataParallel(model)\n",
    "# model.to(device)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# NOTE(review): the original note here read \"divided by 10 eps, eps_iter and CW's lr,\n",
    "# added as input binary_search_steps to CW attacks\"; confirm it still matches the values below.\n",
    "\n",
    "# Baseline attacks.\n",
    "adversary_PGD_Linf_std = attacks.LinfPGDAttack(\n",
    "    model,\n",
    "    loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"),\n",
    "    eps=0.3,\n",
    "    nb_iter=40,\n",
    "    eps_iter=0.01,\n",
    "    rand_init=True,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    "    targeted=False,\n",
    ")\n",
    "\n",
    "adversary_CW = attacks.CarliniWagnerL2Attack(\n",
    "    model,\n",
    "    num_classes=10,\n",
    "    max_iterations=20,\n",
    "    learning_rate=0.1,\n",
    "    binary_search_steps=5,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    ")\n",
    "\n",
    "adversary_deepfool = DeepfoolLinfAttack(\n",
    "    model,\n",
    "    num_classes=10,\n",
    "    nb_iter=30,\n",
    "    eps=0.11,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    ")\n",
    "\n",
    "# Unseen-attack variants used for validation: larger eps / step size / learning rate\n",
    "# than the baselines above, plus more iterations for CW and DeepFool.\n",
    "adversary_CW_unseen = attacks.CarliniWagnerL2Attack(\n",
    "    model,\n",
    "    num_classes=10,\n",
    "    max_iterations=30,\n",
    "    learning_rate=0.12,\n",
    "    binary_search_steps=7,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    ")\n",
    "\n",
    "adversary_PGD_Linf_unseen = attacks.LinfPGDAttack(\n",
    "    model,\n",
    "    loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"),\n",
    "    eps=0.4,\n",
    "    nb_iter=40,\n",
    "    eps_iter=0.033,\n",
    "    rand_init=True,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    "    targeted=False,\n",
    ")\n",
    "\n",
    "adversary_deepfool_unseen = DeepfoolLinfAttack(\n",
    "    model,\n",
    "    num_classes=10,\n",
    "    nb_iter=50,\n",
    "    eps=0.4,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    ")\n",
    "\n",
    "# AutoAttack, standard version, L-inf budget eps = 0.3.\n",
    "adversary_autoattack_unseen = AutoAttack(\n",
    "    model,\n",
    "    norm='Linf',\n",
    "    eps=0.3,\n",
    "    version='standard',\n",
    "    seed=None,\n",
    "    verbose=False,\n",
    ")\n",
    "\n",
    "# Standard L2 and L1 PGD attacks.\n",
    "adversary_PGD_L2_std = attacks.L2PGDAttack(\n",
    "    model,\n",
    "    loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"),\n",
    "    eps=2.0,\n",
    "    nb_iter=40,\n",
    "    eps_iter=0.1,\n",
    "    rand_init=True,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    "    targeted=False,\n",
    ")\n",
    "\n",
    "adversary_PGD_L1_std = attacks.L1PGDAttack(\n",
    "    model,\n",
    "    loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"),\n",
    "    eps=10.0,\n",
    "    nb_iter=40,\n",
    "    eps_iter=0.5,\n",
    "    rand_init=True,\n",
    "    clip_min=0.0,\n",
    "    clip_max=1.0,\n",
    "    targeted=False,\n",
    ")\n",
    "\n",
    "def get_fb_attack(attack_name):\n",
    "    if attack_name == 'PA_L1':\n",
    "        fb_attack = fb.attacks.PointwiseAttack()\n",
    "        fb_attack._distance = fb.distances.l1\n",
    "        metric = 'L1'\n",
    "    elif attack_name == 'PA_L2':\n",
    "        fb_attack = fb.attacks.PointwiseAttack()\n",
    "        fb_attack._distance = fb.distances.l2\n",
    "        metric = 'L2'\n",
    "    elif attack_name == 'BA_L2':\n",
    "        fb_attack = fb.attacks.BoundaryAttack(steps=5000)\n",
    "        metric = 'L2'\n",
    "    elif attack_name == \"VAT\":\n",
    "        fb_attack = fb.attacks.VirtualAdversarialAttack(steps=1000)\n",
    "        metric = 'L2'\n",
    "    elif attack_name == 'InvL2':\n",
    "        fb_attack = fb.attacks.InversionAttack(distance=fb.distances.l2)\n",
    "        metric = 'L2'\n",
    "    elif attack_name == 'LinContL2':\n",
    "        fb_attack = fb.attacks.LinearSearchContrastReductionAttack(distance=fb.distances.l2)\n",
    "        metric = 'L2'\n",
    "    else:\n",
    "        raise ValueError(\"Invalid fb attack:\", attack_name)\n",
    "    return fb_attack,  metric\n",
    "\n",
    "def generate_domains(domain_name, data, label, batch_size=batch_size_test, bool_correct_preds_per_domain=None):\n",
    "    \"\"\"Return the inputs of `domain_name` for the samples of one batch that are still unbroken.\n",
    "\n",
    "    Only samples whose entry in bool_correct_preds_per_domain[domain_name] is True are kept (every sample when\n",
    "    no tracker is given). 'clean' returns those inputs unchanged; every other domain perturbs them with the\n",
    "    matching module-level adversary.\n",
    "\n",
    "    Args:\n",
    "        domain_name: one of the domain names handled below.\n",
    "        data, label: a batch of inputs (indexed as data[mask, :, :, :]) and their labels.\n",
    "        batch_size: unused; kept for backward compatibility.\n",
    "        bool_correct_preds_per_domain: optional dict mapping domain name -> boolean tensor over the batch.\n",
    "    Returns:\n",
    "        The (possibly perturbed) inputs, or None when the mask removes every sample.\n",
    "    Raises:\n",
    "        ValueError: for an unknown domain name.\n",
    "    \"\"\"\n",
    "    if not bool_correct_preds_per_domain:\n",
    "        # Must be a boolean mask: an integer tensor of ones would select row 1 repeatedly instead of all rows.\n",
    "        mask = torch.ones_like(label, dtype=torch.bool)\n",
    "    else:\n",
    "        mask = bool_correct_preds_per_domain[domain_name]\n",
    "    masked_data = data[mask, :, :, :]\n",
    "    masked_label = label[mask]\n",
    "\n",
    "    # All the data might have been masked. In that case return None.\n",
    "    if len(masked_data) == 0:\n",
    "        return None\n",
    "\n",
    "    if domain_name == 'clean':\n",
    "        return masked_data\n",
    "    if domain_name == \"Autoattack\":\n",
    "        # AutoAttack has its own entry point; evaluate the whole masked batch in one call.\n",
    "        return adversary_autoattack_unseen.run_standard_evaluation(masked_data, masked_label, bs=len(masked_label))\n",
    "\n",
    "    # The remaining adversaries all share the .perturb(x, y) interface.\n",
    "    perturb_adversaries = {\n",
    "        'PGD_L1_std': adversary_PGD_L1_std,\n",
    "        'PGD_L2_std': adversary_PGD_L2_std,\n",
    "        'PGD_Linf_std': adversary_PGD_Linf_std,\n",
    "        'Deepfool_base': adversary_deepfool,\n",
    "        \"CW_base\": adversary_CW,\n",
    "        'PGD_Linf_mod': adversary_PGD_Linf_unseen,\n",
    "        'Deepfool_mod': adversary_deepfool_unseen,\n",
    "        'CW_mod': adversary_CW_unseen,\n",
    "    }\n",
    "    if domain_name not in perturb_adversaries:\n",
    "        # Fail loudly: a silent None would be treated by callers as 'all samples masked'.\n",
    "        raise ValueError(\"Unknown domain: {}\".format(domain_name))\n",
    "    return perturb_adversaries[domain_name].perturb(masked_data, masked_label)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def loss_helper(model, data_all_domains, label_all_domains, num_domains, num_correct_per_domain, tensor_list_losses_epoch):\n",
    "    list_losses = []\n",
    "    \n",
    "    for domain in range(0, num_domains):\n",
    "        preds = model(data_all_domains[domain])\n",
    "        list_losses.append(F.cross_entropy(preds, label_all_domains[domain]))\n",
    "        num_correct_per_domain[domain] += ((torch.argmax(preds, dim=1) == label_all_domains[domain]).sum().item())\n",
    "    \n",
    "    # Some spaghetti going on here between torch and lists types, as evidenced by how the loss_helper() is called in compute_loss()\n",
    "    tensor_list_losses = torch.stack(list_losses)\n",
    "    \n",
    "    ERM_term = torch.sum(tensor_list_losses) / num_domains\n",
    "    REx_variance_term = torch.var(tensor_list_losses)\n",
    "    \n",
    "    tensor_list_losses_epoch += tensor_list_losses\n",
    "    \n",
    "    return ERM_term, REx_variance_term\n",
    "\n",
    "def REx_loss(ERM_term, REx_variance_term, beta):\n",
    "    return beta * REx_variance_term + ERM_term\n",
    "\n",
    " \n",
    "def compute_loss(is_REx, beta, loss_terms, model, list_data_all_domains, list_label_all_domains, num_domains, \n",
    "                 num_train_correct_preds_per_domain, tensor_list_losses_epoch_train):\n",
    "    \"\"\"Return the training loss: the REx objective when is_REx, the plain ERM term otherwise.\n",
    "\n",
    "    The scalar terms are also added to `loss_terms` with +=: [ERM, variance] for REx, [ERM] for ERM.\n",
    "    \"\"\"\n",
    "    # Both modes need the same per-domain losses; only how they are combined differs.\n",
    "    ERM_term, REx_variance_term = loss_helper(model, list_data_all_domains, list_label_all_domains, num_domains,\n",
    "                                              num_train_correct_preds_per_domain, tensor_list_losses_epoch_train)\n",
    "    if not is_REx:\n",
    "        loss_terms += np.array([ERM_term.item()])\n",
    "        return ERM_term\n",
    "    loss_terms += np.array([ERM_term.item(), REx_variance_term.item()])\n",
    "    return REx_loss(ERM_term, REx_variance_term, beta)\n",
    "\n",
    "\n",
    "# Keep track across restarts of which samples were still correctly predicted, for each attack\n",
    "def track_correct_pred_per_domain(model, data_all_domains, labels, domains, bool_correct_per_domain):\n",
    "    for domain in domains:\n",
    "        # Case when the mask filtered all data\n",
    "        if data_all_domains[domain] == None:\n",
    "            continue\n",
    "\n",
    "        preds = model(data_all_domains[domain])\n",
    "        # bool_correct_per_domain[domain] = torch.logical_and(bool_correct_per_domain[domain], (torch.argmax(preds, dim=1) == label_all_domains[domain]))\n",
    "\n",
    "        # Array sizes of preds and bool_correct are different because of the mask when generating the domains, so handling it manually. Maybe\n",
    "        # there is/will be a native method to handle this but gotta go fast.\n",
    "        mask = bool_correct_per_domain[domain]\n",
    "        are_preds_right = (torch.argmax(preds, dim=1) == labels[mask])\n",
    "        i = 0\n",
    "        for k in range(len(bool_correct_per_domain[domain])):\n",
    "            if bool_correct_per_domain[domain][k]:\n",
    "                bool_correct_per_domain[domain][k] = are_preds_right[i]\n",
    "                i += 1\n",
    "    return\n",
    "\n",
    "# Compute the number of correct predictions against each attack after all the restarts\n",
    "def update_num_correct_pred_per_domain(num_correct_per_domain, bool_correct_per_domain, domains):\n",
    "    for domain in domains:\n",
    "        num_correct_per_domain[domain] += bool_correct_per_domain[domain].sum().item()\n",
    "    return\n",
    "\n",
    "# Compute the number of correct predictions if the attacker was using an ensemble of all attacks. Skip the attacks in skipped_domains_worst_case from calculation.\n",
    "def get_num_correct_worst_case(bool_correct_per_domain, domains, skipped_domains_worst_case=[]):\n",
    "    # TODO WARNING\n",
    "    # TODO WARNING\n",
    "    if len(domains) == 0:\n",
    "        raise ValueError(\"No domain has been defined !\")\n",
    "    \n",
    "    bool_correct_worst_case = torch.ones_like(bool_correct_per_domain[domains[0]], dtype=torch.bool)\n",
    "    for domain in domains:\n",
    "        if domain in skipped_domains_worst_case:\n",
    "            continue\n",
    "        bool_correct_worst_case = torch.logical_and(bool_correct_worst_case, bool_correct_per_domain[domain])\n",
    "\n",
    "    return bool_correct_worst_case.sum().item()\n",
    "\n",
    "# Get which attacks were seen based on model filename\n",
    "def get_seen_attacks(model_name):\n",
    "    split_model_name = model_name.split('_')\n",
    "    seen_attacks = []\n",
    "    if \"MSD\" in split_model_name:\n",
    "        if \"ERM\" in split_model_name:\n",
    "            seen_attacks = ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']\n",
    "        else:\n",
    "            seen_attacks = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']\n",
    "    if \"PGDs\" in split_model_name:\n",
    "        seen_attacks = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']\n",
    "    if \"std\" in split_model_name:\n",
    "        seen_attacks = ['clean', 'PGD_Linf_std', 'Deepfool_base', 'CW_base']\n",
    "    if \"clean\" in split_model_name:\n",
    "        seen_attacks = ['clean']\n",
    "    if \"L1\" in split_model_name:\n",
    "        seen_attacks = ['clean', 'PGD_L1_std']\n",
    "    if \"L2\" in split_model_name:\n",
    "        seen_attacks = ['clean', 'PGD_L2_std']\n",
    "    if \"Linf\" in split_model_name:\n",
    "        seen_attacks = ['clean', 'PGD_Linf_std']\n",
    "    return seen_attacks\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "resume = True\n",
    "# If you do not want restarts, set to 1 and not 0 as it's the number of times an adv is computed per sample\n",
    "num_attack_restarts = 10\n",
    "# Number of test minibatches evaluated per model (5 x batch_size_test samples); set to None for the whole test set.\n",
    "max_test_batches = 5\n",
    "\n",
    "WORKING_DIR = \"results/MNIST/\"\n",
    "TRAINED_MODEL_PATH = WORKING_DIR + \"models/\"\n",
    "# Evaluate only the regular files directly inside TRAINED_MODEL_PATH, sorted so runs are reproducible.\n",
    "# (os.walk would also visit sub-directories and leave only the last visited directory's files.)\n",
    "model_filenames = sorted(file for file in os.listdir(TRAINED_MODEL_PATH)\n",
    "                         if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, file)))\n",
    "model_paths = [TRAINED_MODEL_PATH + file for file in model_filenames]\n",
    "\n",
    "fb_attacks = []\n",
    "domains = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base',\n",
    "           'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack']\n",
    "skipped_domains_worst_case = ['PGD_Linf_mod']\n",
    "# includes foolbox attacks\n",
    "all_domains = domains + fb_attacks\n",
    "# Attacks that none of the training configurations in get_seen_attacks() uses.\n",
    "always_unseen_attacks = ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack']\n",
    "\n",
    "num_test_batches = len(test_loader)\n",
    "# Number of non foolbox domains\n",
    "num_domains = 0\n",
    "# Number of foolbox domains\n",
    "num_fb_domains = len(fb_attacks)\n",
    "\n",
    "\n",
    "######################\n",
    "# test the model #\n",
    "######################\n",
    "model.eval()\n",
    "if len(fb_attacks) > 0:\n",
    "    # We do not use foolbox in our evaluations. All foolbox code is just here for future-proofing.\n",
    "    # As such, we do not want the code to require foolbox and put the import in a conditional statement.\n",
    "    import foolbox as fb\n",
    "    fmodel = fb.PyTorchModel(model, bounds=(0, 1), device=device)\n",
    "\n",
    "for model_num, model_path in enumerate(model_paths):\n",
    "    # map_location lets checkpoints saved from a GPU run be evaluated on a CPU-only machine.\n",
    "    checkpoint = torch.load(model_path, map_location=device)\n",
    "    starting_epoch = checkpoint['epoch']\n",
    "    model.load_state_dict(checkpoint['current_model'])\n",
    "    model.to(device)\n",
    "\n",
    "    seen_attacks = get_seen_attacks(model_filenames[model_num])\n",
    "    unseen_attacks = [attack for attack in all_domains if attack not in seen_attacks]\n",
    "\n",
    "    # number of correct predictions on each domain\n",
    "    num_test_correct_preds_per_domain = {}\n",
    "    results = {}\n",
    "    for domain in all_domains:\n",
    "        results[domain] = 0\n",
    "        results[domain + \"_bool_track_correct_preds\"] = []\n",
    "        num_test_correct_preds_per_domain[domain] = 0\n",
    "\n",
    "    # number of correct predictions against ensemble of all attacks, first excludes the skipped domains in worst case calculation, second doesn't\n",
    "    num_test_correct_preds_per_domain['worst_no_skipped'] = 0\n",
    "    num_test_correct_preds_per_domain['worst_with_skipped'] = 0\n",
    "    # number of correct preds against worst ensemble of seen or unseen\n",
    "    num_test_correct_preds_per_domain['worst_seen'] = 0\n",
    "    num_test_correct_preds_per_domain['worst_unseen'] = 0\n",
    "    num_test_correct_preds_per_domain['worst_unseen_no_skipped'] = 0\n",
    "    num_test_correct_preds_per_domain['worst_always_unseen'] = 0\n",
    "    num_test_correct_preds_per_domain['worst_always_unseen_no_skipped'] = 0\n",
    "\n",
    "    which_batch_test = 1\n",
    "    # Samples actually evaluated, used as the accuracy denominator. Counting them directly stays correct when\n",
    "    # the last batch is smaller or the loader runs out before max_test_batches is reached.\n",
    "    num_test_samples = 0\n",
    "\n",
    "    for data, label in test_loader:\n",
    "        data, label = data.to(device), label.to(device)\n",
    "        num_test_samples += len(label)\n",
    "\n",
    "        # Keeps track for each sample and each domain of whether one restart succeeded in fooling the network by using logical and\n",
    "        # on (label == prediction) and bool_track_correct_pred each iteration. fb trackers are overwritten later in the code\n",
    "        bool_track_correct_pred_per_domain = {}\n",
    "        for domain in all_domains:\n",
    "            bool_track_correct_pred_per_domain[domain] = torch.ones_like(label, dtype=torch.bool)\n",
    "\n",
    "        for i_restarts in range(num_attack_restarts):\n",
    "            with ctx_noparamgrad_and_eval(model):\n",
    "                # Clean data is a domain. Each restart only attacks the samples that are still correctly classified.\n",
    "                data_all_domains = {}\n",
    "                for domain in domains:\n",
    "                    data_all_domains[domain] = generate_domains(domain, data, label, batch_size=batch_size_test, bool_correct_preds_per_domain=bool_track_correct_pred_per_domain)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                track_correct_pred_per_domain(model, data_all_domains, label, domains, bool_track_correct_pred_per_domain)\n",
    "\n",
    "        # Out of the block that is restarted due to historically testing Boundary attack here, which doesn't require restarts\n",
    "        for fb_attack_name in fb_attacks:\n",
    "            # Only notified on the first minibatch to avoid spamming\n",
    "            if which_batch_test == 1:\n",
    "                print(\"Using Foolbox attack \", fb_attack_name)\n",
    "            fb_attack, metric = get_fb_attack(fb_attack_name)\n",
    "            if metric == 'L0' or metric == 'L1':\n",
    "                epsilon = 10.\n",
    "            elif metric == 'L2':\n",
    "                epsilon = 2\n",
    "            _, temp_adv, bool_track_preds_temp = fb_attack(fmodel, data, label, epsilons=epsilon)\n",
    "            # invert the bool because foolbox reports the attack's successes as True and we track the model's successes against adv\n",
    "            bool_track_correct_pred_per_domain[fb_attack_name] = (~bool_track_preds_temp)\n",
    "\n",
    "            # # Measure distance between adv example and clean sample with the same norm as the attack, compute the median distance over minibatch\n",
    "            # temp = temp_adv-data\n",
    "            # temp = torch.reshape(temp, (100, -1))\n",
    "            # print(torch.linalg.norm(temp, dim=1, ord=int(metric[-1])).median())\n",
    "\n",
    "        with torch.no_grad():\n",
    "            update_num_correct_pred_per_domain(num_test_correct_preds_per_domain, bool_track_correct_pred_per_domain, all_domains)\n",
    "            num_test_correct_preds_per_domain['worst_no_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, all_domains, skipped_domains_worst_case)\n",
    "            num_test_correct_preds_per_domain['worst_with_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, all_domains)\n",
    "            num_test_correct_preds_per_domain['worst_seen'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, seen_attacks)\n",
    "            num_test_correct_preds_per_domain['worst_unseen'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, unseen_attacks)\n",
    "            num_test_correct_preds_per_domain['worst_unseen_no_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, unseen_attacks, skipped_domains_worst_case)\n",
    "            num_test_correct_preds_per_domain['worst_always_unseen'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, always_unseen_attacks)\n",
    "            num_test_correct_preds_per_domain['worst_always_unseen_no_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, always_unseen_attacks, skipped_domains_worst_case)\n",
    "\n",
    "        # Keep track of bool array to avoid having to redo the very costly perturbation with all attacks in case further metrics are needed\n",
    "        for domain in all_domains:\n",
    "            results[domain + \"_bool_track_correct_preds\"].append(bool_track_correct_pred_per_domain[domain].to('cpu'))\n",
    "\n",
    "        # Debugging\n",
    "        print(\"Testing, epoch \", starting_epoch, \": done with batch \", which_batch_test, \" out of \", num_test_batches)\n",
    "        if which_batch_test % 5 == 0:\n",
    "            print(\"GPU memory allocated in GB:\", torch.cuda.memory_allocated()/10**9)\n",
    "        if max_test_batches is not None and which_batch_test >= max_test_batches:\n",
    "            break\n",
    "        which_batch_test += 1\n",
    "\n",
    "    # calculate accuracies over the samples actually evaluated\n",
    "    for key, num_correct in num_test_correct_preds_per_domain.items():\n",
    "        results[key] = num_correct / num_test_samples\n",
    "    results['num_test_samples'] = num_test_samples\n",
    "    results['num_attack_restarts'] = num_attack_restarts\n",
    "    results['model_name'] = model_filenames[model_num]\n",
    "    results['seen_attacks'] = seen_attacks\n",
    "    results['unseen_attacks'] = unseen_attacks\n",
    "    results['always_unseen_attacks'] = always_unseen_attacks\n",
    "    results['skipped_domains_worst_case'] = skipped_domains_worst_case\n",
    "    print(results)\n",
    "\n",
    "    working_dir_of_save = WORKING_DIR + \"test_accs/\"\n",
    "    os.makedirs(working_dir_of_save, exist_ok=True)\n",
    "    # results is a dict, so np.save pickles it; read it back with np.load(path, allow_pickle=True).item()\n",
    "    np.save(working_dir_of_save + model_filenames[model_num], results)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# writer.add_scalar('Test_accuracy_clean', test_acc_per_domain[0], starting_epoch)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# writer.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing, epoch  656 : done with batch  1  out of  50\n",
      "Testing, epoch  656 : done with batch  2  out of  50\n",
      "Testing, epoch  656 : done with batch  3  out of  50\n",
      "Testing, epoch  656 : done with batch  4  out of  50\n",
      "Testing, epoch  656 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.026421248\n",
      "{'clean': 0.884, 'PGD_L1_std_eps_20.0': 0.812, 'PGD_L1_std_eps_40.0': 0.835, 'PGD_L1_std_eps_60.0': 0.809, 'PGD_L1_std_eps_80.0': 0.818, 'PGD_L1_std_eps_100.0': 0.818, 'PGD_L2_std_eps_1.0': 0.632, 'PGD_L2_std_eps_2.0': 0.497, 'PGD_L2_std_eps_3.0': 0.305, 'PGD_L2_std_eps_4.0': 0.07, 'PGD_L2_std_eps_5.0': 0.008, 'PGD_Linf_std_eps_0.1': 0.68, 'PGD_Linf_std_eps_0.2': 0.449, 'PGD_Linf_std_eps_0.3': 0.082, 'PGD_Linf_std_eps_0.4': 0.002, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_MSD_ERM_655.pt'}\n",
      "Testing, epoch  656 : done with batch  1  out of  50\n",
      "Testing, epoch  656 : done with batch  2  out of  50\n",
      "Testing, epoch  656 : done with batch  3  out of  50\n",
      "Testing, epoch  656 : done with batch  4  out of  50\n",
      "Testing, epoch  656 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.026421248\n",
      "{'clean': 0.902, 'PGD_L1_std_eps_20.0': 0.864, 'PGD_L1_std_eps_40.0': 0.866, 'PGD_L1_std_eps_60.0': 0.851, 'PGD_L1_std_eps_80.0': 0.828, 'PGD_L1_std_eps_100.0': 0.809, 'PGD_L2_std_eps_1.0': 0.696, 'PGD_L2_std_eps_2.0': 0.524, 'PGD_L2_std_eps_3.0': 0.157, 'PGD_L2_std_eps_4.0': 0.004, 'PGD_L2_std_eps_5.0': 0.0, 'PGD_Linf_std_eps_0.1': 0.821, 'PGD_Linf_std_eps_0.2': 0.763, 'PGD_Linf_std_eps_0.3': 0.604, 'PGD_Linf_std_eps_0.4': 0.009, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_MSD_REx_655.pt'}\n",
      "Testing, epoch  1106 : done with batch  1  out of  50\n",
      "Testing, epoch  1106 : done with batch  2  out of  50\n",
      "Testing, epoch  1106 : done with batch  3  out of  50\n",
      "Testing, epoch  1106 : done with batch  4  out of  50\n",
      "Testing, epoch  1106 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.026421248\n",
      "{'clean': 0.988, 'PGD_L1_std_eps_20.0': 0.944, 'PGD_L1_std_eps_40.0': 0.937, 'PGD_L1_std_eps_60.0': 0.922, 'PGD_L1_std_eps_80.0': 0.909, 'PGD_L1_std_eps_100.0': 0.909, 'PGD_L2_std_eps_1.0': 0.907, 'PGD_L2_std_eps_2.0': 0.597, 'PGD_L2_std_eps_3.0': 0.078, 'PGD_L2_std_eps_4.0': 0.003, 'PGD_L2_std_eps_5.0': 0.0, 'PGD_Linf_std_eps_0.1': 0.925, 'PGD_Linf_std_eps_0.2': 0.801, 'PGD_Linf_std_eps_0.3': 0.431, 'PGD_Linf_std_eps_0.4': 0.005, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGDs_ERM_1105.pt'}\n",
      "Testing, epoch  1106 : done with batch  1  out of  50\n",
      "Testing, epoch  1106 : done with batch  2  out of  50\n",
      "Testing, epoch  1106 : done with batch  3  out of  50\n",
      "Testing, epoch  1106 : done with batch  4  out of  50\n",
      "Testing, epoch  1106 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.026421248\n",
      "{'clean': 0.873, 'PGD_L1_std_eps_20.0': 0.83, 'PGD_L1_std_eps_40.0': 0.813, 'PGD_L1_std_eps_60.0': 0.816, 'PGD_L1_std_eps_80.0': 0.815, 'PGD_L1_std_eps_100.0': 0.767, 'PGD_L2_std_eps_1.0': 0.695, 'PGD_L2_std_eps_2.0': 0.482, 'PGD_L2_std_eps_3.0': 0.151, 'PGD_L2_std_eps_4.0': 0.004, 'PGD_L2_std_eps_5.0': 0.0, 'PGD_Linf_std_eps_0.1': 0.775, 'PGD_Linf_std_eps_0.2': 0.697, 'PGD_Linf_std_eps_0.3': 0.594, 'PGD_Linf_std_eps_0.4': 0.003, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGDs_REx_1105.pt'}\n",
      "Testing, epoch  931 : done with batch  1  out of  50\n",
      "Testing, epoch  931 : done with batch  2  out of  50\n",
      "Testing, epoch  931 : done with batch  3  out of  50\n",
      "Testing, epoch  931 : done with batch  4  out of  50\n",
      "Testing, epoch  931 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.026421248\n",
      "{'clean': 0.989, 'PGD_L1_std_eps_20.0': 0.897, 'PGD_L1_std_eps_40.0': 0.858, 'PGD_L1_std_eps_60.0': 0.835, 'PGD_L1_std_eps_80.0': 0.783, 'PGD_L1_std_eps_100.0': 0.755, 'PGD_L2_std_eps_1.0': 0.876, 'PGD_L2_std_eps_2.0': 0.342, 'PGD_L2_std_eps_3.0': 0.015, 'PGD_L2_std_eps_4.0': 0.0, 'PGD_L2_std_eps_5.0': 0.0, 'PGD_Linf_std_eps_0.1': 0.937, 'PGD_Linf_std_eps_0.2': 0.82, 'PGD_Linf_std_eps_0.3': 0.536, 'PGD_Linf_std_eps_0.4': 0.004, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_std_ERM_930.pt'}\n",
      "Testing, epoch  1126 : done with batch  1  out of  50\n",
      "Testing, epoch  1126 : done with batch  2  out of  50\n",
      "Testing, epoch  1126 : done with batch  3  out of  50\n",
      "Testing, epoch  1126 : done with batch  4  out of  50\n",
      "Testing, epoch  1126 : done with batch  5  out of  50\n",
      "GPU memory allocated in GB: 0.026421248\n",
      "{'clean': 0.9, 'PGD_L1_std_eps_20.0': 0.691, 'PGD_L1_std_eps_40.0': 0.661, 'PGD_L1_std_eps_60.0': 0.59, 'PGD_L1_std_eps_80.0': 0.547, 'PGD_L1_std_eps_100.0': 0.402, 'PGD_L2_std_eps_1.0': 0.619, 'PGD_L2_std_eps_2.0': 0.186, 'PGD_L2_std_eps_3.0': 0.028, 'PGD_L2_std_eps_4.0': 0.005, 'PGD_L2_std_eps_5.0': 0.0, 'PGD_Linf_std_eps_0.1': 0.824, 'PGD_Linf_std_eps_0.2': 0.745, 'PGD_Linf_std_eps_0.3': 0.66, 'PGD_Linf_std_eps_0.4': 0.037, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_std_REx_1125.pt'}\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mnotebook controller is DISPOSED. \n",
      "View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "# clear memory: empty the user namespace left over from earlier cells so this cell starts clean\n",
    "from IPython import get_ipython\n",
    "# InteractiveShell.magic() is deprecated; run_line_magic is the supported API (-s: soft reset, -f: no prompt).\n",
    "get_ipython().run_line_magic('reset', '-sf')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "timer = 0\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import advertorch.attacks as attacks\n",
    "from attacks.deepfool import DeepfoolLinfAttack\n",
    "import torch.nn as nn\n",
    "from autoattack import AutoAttack\n",
    "\n",
    "from advertorch.context import ctx_noparamgrad_and_eval\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "\n",
    "import os, random\n",
    "\n",
    "\n",
    "# import argparse\n",
    "\n",
    "# argument_parser = argparse.ArgumentParser()\n",
    "\n",
    "# argument_parser.add_argument(\"--lr_init\", type=float, help=\"Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.\")\n",
    "\n",
    "# parsed_args = argument_parser.parse_args()\n",
    "\n",
    "\n",
    "# Make sure validation splits are the same at all time (e.g. even after loading)\n",
    "seed = 0\n",
    "\n",
    "def seed_init_fn(seed=seed):\n",
    "   np.random.seed(seed)\n",
    "   random.seed(seed)\n",
    "   torch.manual_seed(seed)\n",
    "   return\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Number of DataLoader worker processes; 0 loads batches in the main process (worker_init_fn is then never called).\n",
    "num_workers = 0\n",
    "# Make sure test_data is a multiple of batch_size_test (MNIST's 10000 test images -> 50 batches of 200)\n",
    "batch_size_train_and_valid = 128\n",
    "batch_size_test = 200\n",
    "\n",
    "# proportion of full training set used for validation\n",
    "valid_size = 0.2\n",
    "\n",
    "transform = transforms.ToTensor()\n",
    "train_and_valid_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transform)\n",
    "test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transform)\n",
    "\n",
    "# The split uses its own seeded generator, so train/validation membership is identical on every run.\n",
    "num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))\n",
    "num_train_samples = len(train_and_valid_data) - num_valid_samples\n",
    "train_data, valid_data = torch.utils.data.random_split(train_and_valid_data, [num_train_samples, num_valid_samples], generator=torch.Generator().manual_seed(seed))\n",
    "\n",
    "# Pass the configured num_workers to every loader so the setting above actually takes effect.\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid, num_workers = num_workers)\n",
    "valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid, num_workers = num_workers)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, num_workers = num_workers, worker_init_fn=seed_init_fn)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net,self).__init__()\n",
    "        \n",
    "        self.fc1 = nn.Linear(28*28, 512)\n",
    "        self.fc2 = nn.Linear(512, 512)\n",
    "        self.fc3 = nn.Linear(512, 10)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        # vectorise input\n",
    "        x = x.view(-1,28*28)\n",
    "        # Hidden layer 1 + relu\n",
    "        x = F.relu(self.fc1(x))\n",
    "        # Hidden layer 2 + relu\n",
    "        x = F.relu(self.fc2(x))\n",
    "        # Output layer\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Instantiate the MLP and load its initial weights; the attack helpers below build their attacks on this `model`.\n",
    "model = Net()\n",
    "# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine (and vice versa).\n",
    "model.load_state_dict(torch.load('model_no_dropout.pt', map_location=device))\n",
    "model.to(device)\n",
    "\n",
    "# if str(device) == \"cuda\" and torch.cuda.device_count() > 1:\n",
    "#     print(\"Using DataParallel\")\n",
    "#     model = torch.nn.DataParallel(model)\n",
    "# model.to(device)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def _eval_PGD_increasing_eps(attack_cls, eps_iter, list_data_all_domains, epsilons, inputs, targets):\n",
    "    \"\"\"Shared body of the eval_PGD_*_increasing_eps helpers.\n",
    "\n",
    "    For each epsilon, builds a 200-step untargeted `attack_cls` PGD attack on `model` (random init, summed\n",
    "    cross-entropy loss, outputs clipped to [0, 1], step size `eps_iter`) and appends its adversarial version\n",
    "    of the batch to list_data_all_domains. `inputs`/`targets` of None fall back to the notebook-level\n",
    "    `data`/`label` variables, which these helpers have always read.\n",
    "    \"\"\"\n",
    "    for epsilon in epsilons:\n",
    "        batch = data if inputs is None else inputs\n",
    "        batch_labels = label if targets is None else targets\n",
    "        adversary_PGD_attempt = attack_cls(\n",
    "            model, loss_fn=nn.CrossEntropyLoss(reduction=\"sum\"), eps=epsilon,\n",
    "            nb_iter=200, eps_iter=eps_iter, rand_init=True, clip_min=0.0,\n",
    "            clip_max=1.0, targeted=False)\n",
    "        list_data_all_domains.append(adversary_PGD_attempt.perturb(batch, batch_labels))\n",
    "\n",
    "def eval_PGD_Linf_increasing_eps(list_data_all_domains, epsilons, inputs=None, targets=None):\n",
    "    \"\"\"Append an L-inf PGD adversarial batch (step 0.05) for every eps in `epsilons`.\n",
    "\n",
    "    inputs/targets are optional; by default the notebook-level `data`/`label` are attacked.\n",
    "    \"\"\"\n",
    "    _eval_PGD_increasing_eps(attacks.LinfPGDAttack, 0.05, list_data_all_domains, epsilons, inputs, targets)\n",
    "    return\n",
    "\n",
    "def eval_PGD_L2_increasing_eps(list_data_all_domains, epsilons, inputs=None, targets=None):\n",
    "    \"\"\"Append an L2 PGD adversarial batch (step 0.1) for every eps in `epsilons`.\n",
    "\n",
    "    inputs/targets are optional; by default the notebook-level `data`/`label` are attacked.\n",
    "    \"\"\"\n",
    "    _eval_PGD_increasing_eps(attacks.L2PGDAttack, 0.1, list_data_all_domains, epsilons, inputs, targets)\n",
    "    return\n",
    "\n",
    "def eval_PGD_L1_increasing_eps(list_data_all_domains, epsilons, inputs=None, targets=None):\n",
    "    \"\"\"Append an L1 PGD adversarial batch (step 0.5) for every eps in `epsilons`.\n",
    "\n",
    "    inputs/targets are optional; by default the notebook-level `data`/`label` are attacked.\n",
    "    \"\"\"\n",
    "    _eval_PGD_increasing_eps(attacks.L1PGDAttack, 0.5, list_data_all_domains, epsilons, inputs, targets)\n",
    "    return\n",
    "\n",
    "\n",
    "# Keep track across restarts of which samples were still correctly predicted, for each attack\n",
    "def track_correct_pred_per_domain(model, data_all_domains, label_all_domains, num_domains, bool_correct_per_domain):\n",
    "    for domain in range(0, num_domains):\n",
    "        preds = model(data_all_domains[domain])\n",
    "        # print((torch.argmax(preds, dim=1) == label_all_domains[domain]))\n",
    "        bool_correct_per_domain[domain] = torch.logical_and(bool_correct_per_domain[domain], (torch.argmax(preds, dim=1) == label_all_domains[domain]))\n",
    "    return\n",
    "\n",
    "# Compute the number of correct predictions against each attack after all the restarts\n",
    "def update_num_correct_pred_per_domain(num_correct_per_domain, bool_correct_per_domain, num_domains, only_update_fb=False):\n",
    "    start_for_loop = 0\n",
    "    # Avoids reaccumulating the first few entries\n",
    "    if only_update_fb:\n",
    "        start_for_loop = num_domains\n",
    "    for domain in range(start_for_loop, len(bool_correct_per_domain)):\n",
    "        num_correct_per_domain[domain] += bool_correct_per_domain[domain].sum().item()\n",
    "    return\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "WORKING_DIR = \"results/MNIST/\"\n",
    "TRAINED_MODEL_PATH = WORKING_DIR + \"models/\"\n",
    "for root, dirs, files in os.walk(TRAINED_MODEL_PATH):\n",
    "    model_filenames = files\n",
    "    model_paths = [TRAINED_MODEL_PATH + file for file in files]\n",
    "\n",
    "\n",
    "\n",
    "num_test_batches = len(test_loader)\n",
    "# Number of non foolbox domains\n",
    "num_domains = 0\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "num_attack_restarts = 10\n",
    "epsilons_Linf = [0.1, 0.2, 0.3, 0.4, 0.5]\n",
    "epsilons_L2 = [1.0, 2.0, 3.0, 4.0, 5.0]\n",
    "epsilons_L1 = [20.0, 40.0, 60.0, 80.0, 100.0]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "base_domains = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']\n",
    "domains = ['clean']\n",
    "for i, epsilons in enumerate([epsilons_L1, epsilons_L2, epsilons_Linf]):\n",
    "    for epsilon in epsilons:\n",
    "        domains.append(base_domains[i+1] + \"_eps_\" + str(epsilon))\n",
    "\n",
    "\n",
    "\n",
    "######################    \n",
    "# test the model #\n",
    "######################\n",
    "model.eval()\n",
    "for model_num, model_path in enumerate(model_paths):\n",
    "    checkpoint = torch.load(model_path)\n",
    "    starting_epoch = checkpoint['epoch']\n",
    "    model.load_state_dict(checkpoint['current_model'])\n",
    "    model.to(device)\n",
    "\n",
    "    results = {}\n",
    "    for domain in domains:\n",
    "        results[domain] = 0\n",
    "    # number of correct predictions on each domain\n",
    "    num_test_correct_preds_per_domain = []\n",
    "    # number of correct predictions against ensemble of all attacks\n",
    "\n",
    "\n",
    "\n",
    "    which_batch_test = 1\n",
    "\n",
    "    for _, (data, label) in enumerate(test_loader):\n",
    "        data, label = data.to(device), label.to(device)\n",
    "\n",
    "        # Keeps track for each sample and each domain of whether one restart succeeded in fooling the network by using logical and\n",
    "        # on (label == prediction) and bool_track_correct_pred each iteration. fb trackers are appended later in the code\n",
    "        bool_track_correct_pred_per_domain = [torch.ones_like(label)] * num_domains\n",
    "\n",
    "        for i_restarts in range(0, num_attack_restarts):\n",
    "            with ctx_noparamgrad_and_eval(model):\n",
    "                # Clean data is a domain\n",
    "                list_data_all_domains = [data]\n",
    "                # # Eval at multiple eps for each norm to test for gradient masking\n",
    "                eval_PGD_L1_increasing_eps(list_data_all_domains, epsilons=epsilons_L1)\n",
    "                eval_PGD_L2_increasing_eps(list_data_all_domains, epsilons=epsilons_L2)\n",
    "                eval_PGD_Linf_increasing_eps(list_data_all_domains, epsilons=epsilons_Linf)\n",
    "\n",
    "                num_domains = len(list_data_all_domains)\n",
    "                # Initialise count of correct predictions per domain. This array tracks both non fb AND fb domains\n",
    "                if len(num_test_correct_preds_per_domain) == 0:\n",
    "                    num_test_correct_preds_per_domain = np.zeros(num_domains)\n",
    "\n",
    "                list_label_all_domains = [label] * num_domains\n",
    "\n",
    "                if len(bool_track_correct_pred_per_domain) == 0:\n",
    "                    bool_track_correct_pred_per_domain = [torch.ones_like(label)] * num_domains\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "            with torch.no_grad():\n",
    "                track_correct_pred_per_domain(model, list_data_all_domains, list_label_all_domains, num_domains, bool_track_correct_pred_per_domain)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            update_num_correct_pred_per_domain(num_test_correct_preds_per_domain, bool_track_correct_pred_per_domain, num_domains)\n",
    "\n",
    "\n",
    "        # Debugging\n",
    "        print(\"Testing, epoch \", starting_epoch, \": done with batch \", which_batch_test, \" out of \", num_test_batches)\n",
    "        if which_batch_test % 5 == 0:\n",
    "            # print(\"Testing, epoch \", starting_epoch, \": done with batch \", which_batch_test, \" out of \", num_test_batches)\n",
    "            print(\"GPU memory allocated in GB:\", torch.cuda.memory_allocated()/10**9)\n",
    "            break\n",
    "        which_batch_test += 1\n",
    "\n",
    "\n",
    "\n",
    "    # calculate average loss over an epoch\n",
    "    test_acc_per_domain = num_test_correct_preds_per_domain / (which_batch_test * batch_size_test) #len(test_loader.sampler)\n",
    "    results['num_test_samples'] = (which_batch_test * batch_size_test)\n",
    "    results['num_attack_restarts'] = num_attack_restarts\n",
    "    results['model_name'] = model_filenames[model_num]\n",
    "    for i, domain in enumerate(domains):\n",
    "        results[domain] = test_acc_per_domain[i]\n",
    "    print(results)\n",
    "\n",
    "    working_dir_of_save = WORKING_DIR + \"increasing_eps/\"\n",
    "    if not os.path.exists(working_dir_of_save):\n",
    "        os.mkdir(WORKING_DIR + \"increasing_eps/\")\n",
    "    np.save(working_dir_of_save + model_filenames[model_num], results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_MSD_ERM_655.pt\n",
      "{'clean': 0.884, 'PGD_L1_std': 0.822, 'PGD_L2_std': 0.611, 'PGD_Linf_std': 0.193, 'Deepfool_base': 0.567, 'CW_base': 0.771, 'PGD_Linf_mod': 0.002, 'Deepfool_mod': 0.158, 'CW_mod': 0.402, 'Autoattack': 0.015, 'worst_no_skipped': 0.006, 'worst_with_skipped': 0.002, 'worst_seen': 0.193, 'worst_unseen': 0.002, 'worst_unseen_no_skipped': 0.006, 'worst_always_unseen': 0.002, 'worst_always_unseen_no_skipped': 0.006, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_MSD_ERM_655.pt', 'seen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std'], 'unseen_attacks': ['clean', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_MSD_REx_655.pt\n",
      "{'clean': 0.902, 'PGD_L1_std': 0.868, 'PGD_L2_std': 0.718, 'PGD_Linf_std': 0.674, 'Deepfool_base': 0.824, 'CW_base': 0.473, 'PGD_Linf_mod': 0.01, 'Deepfool_mod': 0.199, 'CW_mod': 0.129, 'Autoattack': 0.312, 'worst_no_skipped': 0.039, 'worst_with_skipped': 0.004, 'worst_seen': 0.601, 'worst_unseen': 0.004, 'worst_unseen_no_skipped': 0.039, 'worst_always_unseen': 0.004, 'worst_always_unseen_no_skipped': 0.039, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_MSD_REx_655.pt', 'seen_attacks': ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std'], 'unseen_attacks': ['Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_PGD_L1_95.pt\n",
      "{'clean': 0.985, 'PGD_L1_std': 0.968, 'PGD_L2_std': 0.177, 'PGD_Linf_std': 0.0, 'Deepfool_base': 0.057, 'CW_base': 0.069, 'PGD_Linf_mod': 0.0, 'Deepfool_mod': 0.0, 'CW_mod': 0.028, 'Autoattack': 0.0, 'worst_no_skipped': 0.0, 'worst_with_skipped': 0.0, 'worst_seen': 0.968, 'worst_unseen': 0.0, 'worst_unseen_no_skipped': 0.0, 'worst_always_unseen': 0.0, 'worst_always_unseen_no_skipped': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGD_L1_95.pt', 'seen_attacks': ['clean', 'PGD_L1_std'], 'unseen_attacks': ['PGD_L2_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_PGD_L2_75.pt\n",
      "{'clean': 0.983, 'PGD_L1_std': 0.968, 'PGD_L2_std': 0.635, 'PGD_Linf_std': 0.022, 'Deepfool_base': 0.859, 'CW_base': 0.565, 'PGD_Linf_mod': 0.0, 'Deepfool_mod': 0.0, 'CW_mod': 0.16, 'Autoattack': 0.001, 'worst_no_skipped': 0.0, 'worst_with_skipped': 0.0, 'worst_seen': 0.635, 'worst_unseen': 0.0, 'worst_unseen_no_skipped': 0.0, 'worst_always_unseen': 0.0, 'worst_always_unseen_no_skipped': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGD_L2_75.pt', 'seen_attacks': ['clean', 'PGD_L2_std'], 'unseen_attacks': ['PGD_L1_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_PGDs_ERM_1105.pt\n",
      "{'clean': 0.988, 'PGD_L1_std': 0.956, 'PGD_L2_std': 0.683, 'PGD_Linf_std': 0.58, 'Deepfool_base': 0.923, 'CW_base': 0.599, 'PGD_Linf_mod': 0.009, 'Deepfool_mod': 0.037, 'CW_mod': 0.164, 'Autoattack': 0.349, 'worst_no_skipped': 0.012, 'worst_with_skipped': 0.003, 'worst_seen': 0.555, 'worst_unseen': 0.003, 'worst_unseen_no_skipped': 0.012, 'worst_always_unseen': 0.003, 'worst_always_unseen_no_skipped': 0.012, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGDs_ERM_1105.pt', 'seen_attacks': ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std'], 'unseen_attacks': ['Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_PGDs_REx_1105.pt\n",
      "{'clean': 0.873, 'PGD_L1_std': 0.825, 'PGD_L2_std': 0.728, 'PGD_Linf_std': 0.708, 'Deepfool_base': 0.809, 'CW_base': 0.414, 'PGD_Linf_mod': 0.007, 'Deepfool_mod': 0.584, 'CW_mod': 0.121, 'Autoattack': 0.406, 'worst_no_skipped': 0.081, 'worst_with_skipped': 0.001, 'worst_seen': 0.645, 'worst_unseen': 0.001, 'worst_unseen_no_skipped': 0.081, 'worst_always_unseen': 0.001, 'worst_always_unseen_no_skipped': 0.081, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGDs_REx_1105.pt', 'seen_attacks': ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std'], 'unseen_attacks': ['Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_std_ERM_930.pt\n",
      "{'clean': 0.989, 'PGD_L1_std': 0.91, 'PGD_L2_std': 0.535, 'PGD_Linf_std': 0.672, 'Deepfool_base': 0.931, 'CW_base': 0.692, 'PGD_Linf_mod': 0.004, 'Deepfool_mod': 0.085, 'CW_mod': 0.234, 'Autoattack': 0.434, 'worst_no_skipped': 0.034, 'worst_with_skipped': 0.002, 'worst_seen': 0.628, 'worst_unseen': 0.002, 'worst_unseen_no_skipped': 0.034, 'worst_always_unseen': 0.002, 'worst_always_unseen_no_skipped': 0.034, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_std_ERM_930.pt', 'seen_attacks': ['clean', 'PGD_Linf_std', 'Deepfool_base', 'CW_base'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_std_REx_1125.pt\n",
      "{'clean': 0.9, 'PGD_L1_std': 0.726, 'PGD_L2_std': 0.44, 'PGD_Linf_std': 0.701, 'Deepfool_base': 0.846, 'CW_base': 0.683, 'PGD_Linf_mod': 0.04, 'Deepfool_mod': 0.648, 'CW_mod': 0.421, 'Autoattack': 0.588, 'worst_no_skipped': 0.259, 'worst_with_skipped': 0.024, 'worst_seen': 0.634, 'worst_unseen': 0.024, 'worst_unseen_no_skipped': 0.259, 'worst_always_unseen': 0.026, 'worst_always_unseen_no_skipped': 0.346, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_std_REx_1125.pt', 'seen_attacks': ['clean', 'PGD_Linf_std', 'Deepfool_base', 'CW_base'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_clean_0.pt\n",
      "{'clean': 0.981, 'PGD_L1_std': 0.955, 'PGD_L2_std': 0.018, 'PGD_Linf_std': 0.0, 'Deepfool_base': 0.033, 'CW_base': 0.044, 'PGD_Linf_mod': 0.0, 'Deepfool_mod': 0.0, 'CW_mod': 0.023, 'Autoattack': 0.0, 'worst_no_skipped': 0.0, 'worst_with_skipped': 0.0, 'worst_seen': 0.981, 'worst_unseen': 0.0, 'worst_unseen_no_skipped': 0.0, 'worst_always_unseen': 0.0, 'worst_always_unseen_no_skipped': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_clean_0.pt', 'seen_attacks': ['clean'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_PGD_Linf_3000.pt\n",
      "{'clean': 0.912, 'PGD_L1_std': 0.546, 'PGD_L2_std': 0.205, 'PGD_Linf_std': 0.647, 'Deepfool_base': 0.845, 'CW_base': 0.681, 'PGD_Linf_mod': 0.034, 'Deepfool_mod': 0.163, 'CW_mod': 0.386, 'Autoattack': 0.612, 'worst_no_skipped': 0.035, 'worst_with_skipped': 0.008, 'worst_seen': 0.647, 'worst_unseen': 0.008, 'worst_unseen_no_skipped': 0.035, 'worst_always_unseen': 0.018, 'worst_always_unseen_no_skipped': 0.079, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGD_Linf_3000.pt', 'seen_attacks': ['clean', 'PGD_Linf_std'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_std_ERM_1125.pt\n",
      "{'clean': 0.99, 'PGD_L1_std': 0.903, 'PGD_L2_std': 0.536, 'PGD_Linf_std': 0.677, 'Deepfool_base': 0.929, 'CW_base': 0.688, 'PGD_Linf_mod': 0.006, 'Deepfool_mod': 0.071, 'CW_mod': 0.232, 'Autoattack': 0.423, 'worst_no_skipped': 0.034, 'worst_with_skipped': 0.003, 'worst_seen': 0.632, 'worst_unseen': 0.003, 'worst_unseen_no_skipped': 0.034, 'worst_always_unseen': 0.003, 'worst_always_unseen_no_skipped': 0.034, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_std_ERM_1125.pt', 'seen_attacks': ['clean', 'PGD_Linf_std', 'Deepfool_base', 'CW_base'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n",
      "model_PGD_Linf_1125.pt\n",
      "{'clean': 0.844, 'PGD_L1_std': 0.44, 'PGD_L2_std': 0.1, 'PGD_Linf_std': 0.592, 'Deepfool_base': 0.781, 'CW_base': 0.623, 'PGD_Linf_mod': 0.051, 'Deepfool_mod': 0.194, 'CW_mod': 0.302, 'Autoattack': 0.55, 'worst_no_skipped': 0.027, 'worst_with_skipped': 0.004, 'worst_seen': 0.592, 'worst_unseen': 0.004, 'worst_unseen_no_skipped': 0.027, 'worst_always_unseen': 0.028, 'worst_always_unseen_no_skipped': 0.093, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_PGD_Linf_1125.pt', 'seen_attacks': ['clean', 'PGD_Linf_std'], 'unseen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# mask = torch.BoolTensor([False, False, False])\n",
    "# a = torch.rand(3,6,2)\n",
    "# print(a)\n",
    "# b = len(a[mask, :, :])\n",
    "# if b == 0:\n",
    "#     c = None\n",
    "# print(c == None)\n",
    "\n",
    "# for domain in domains:\n",
    "    # print(domain, data_all_domains[domain].shape)\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "WORKING_DIR = \"results/MNIST/\"\n",
    "RESULTS_PATH = WORKING_DIR + \"test_accs/\"\n",
    "for root, dirs, files in os.walk(RESULTS_PATH):\n",
    "    model_filenames = files\n",
    "    model_paths = [RESULTS_PATH + file for file in files]\n",
    "for path in model_paths:\n",
    "    temp = np.load(path, allow_pickle = True).item()\n",
    "    results = {}\n",
    "    for k, v in temp.items():\n",
    "        if \"bool\" in k.split('_'):\n",
    "            continue\n",
    "        results[k] = v\n",
    "    print(results['model_name'])\n",
    "    print(results, '\\n\\n')\n",
    "# PATH = \"results/CIFAR10/test_accs/model_PGDs_REx_370.pt.npy\"\n",
    "# for domain in all_domains:\n",
    "#     for k in range(len(results[domain + '_bool_track_correct_preds'])):\n",
    "#         results[domain + '_bool_track_correct_preds'][k] = results[domain + '_bool_track_correct_preds'][k].to('cpu')\n",
    "# results = np.save(PATH, results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 ('py3.8')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "65ece6dca1d30e560a7eedc0faf594e5b278fbb4bf81aedd4a08d0f646ac509d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
