{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a266e990",
   "metadata": {},
   "source": [
    "# Low-rank training and evaluation script\n",
    "\n",
    "1) Install the requirements (Linux and macOS)\n",
    "\n",
    "Open a terminal and run:\n",
    "\n",
    "```bash\n",
    "python3 -m venv venv\n",
    "source venv/bin/activate\n",
    "pip install -r requirements.txt\n",
    "```\n",
    "\n",
    "Then select the Jupyter kernel that uses this virtual environment.\n",
    "\n",
    "Alternatively, install the packages directly from Jupyter by running the cell below.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "75d1c685",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --quiet torch torchvision matplotlib pandas tqdm wandb tensorly tensorly-torch tntorch transformers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d407869",
   "metadata": {},
   "source": [
    "2) Download the data (CIFAR-10 test case)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7130b629",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████| 170M/170M [01:32<00:00, 1.84MB/s]\n",
      "Processing ../../Cifar10_train.txt: 100%|███████████████████████████████████████████████████████| 50000/50000 [00:18<00:00, 2739.37it/s]\n",
      "Processing ../../Cifar10_test.txt: 100%|████████████████████████████████████████████████████████| 10000/10000 [00:03<00:00, 2716.86it/s]\n"
     ]
    }
   ],
   "source": [
    "from dataset.data_adversarial_rs.create_cifar10 import save_cifar10_images\n",
    "\n",
    "save_cifar10_images(\"dataset/data_adversarial_rs/Cifar10\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e88251d1",
   "metadata": {},
   "source": [
    "3) Train the network\n",
    "\n",
    "To skip training and directly attack a pre-trained network, set `run_train=false` in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dd14165b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Using legacy-service, which is deprecated. If this is unintentional, you can fix it by ensuring you do not call `wandb.require('legacy-service')` and do not set the WANDB_X_REQUIRE_LEGACY_SERVICE environment variable.\n",
      "wandb: ERROR Find detailed error logs at: /home/8v5/Desktop/adversarial_rs_low_rank/last_test/supplementary_material/wandb/debug-cli.8v5.log\n",
      "Error: api_key not configured (no-tty). call wandb login [your_api_key]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration 1\n",
      "Network: vgg16\n",
      "DataParallel(\n",
      "  (module): VGG(\n",
      "    (features): Sequential(\n",
      "      (0): CustomConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (1): ReLU(inplace=True)\n",
      "      (2): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[32, 32, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (3): ReLU(inplace=True)\n",
      "      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (5): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[64, 32, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (6): ReLU(inplace=True)\n",
      "      (7): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[64, 64, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (8): ReLU(inplace=True)\n",
      "      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (10): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[128, 64, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (11): ReLU(inplace=True)\n",
      "      (12): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[128, 128, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (13): ReLU(inplace=True)\n",
      "      (14): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[128, 128, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (15): ReLU(inplace=True)\n",
      "      (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (17): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 128, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (18): ReLU(inplace=True)\n",
      "      (19): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (20): ReLU(inplace=True)\n",
      "      (21): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (22): ReLU(inplace=True)\n",
      "      (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (24): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (25): ReLU(inplace=True)\n",
      "      (26): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (27): ReLU(inplace=True)\n",
      "      (28): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (29): ReLU(inplace=True)\n",
      "      (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    )\n",
      "    (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "    (classifier): Sequential(\n",
      "      (0): LowRankLayerAugBUG(rmax=150, rank=150)\n",
      "      (1): ReLU(inplace=True)\n",
      "      (2): Dropout(p=0.5, inplace=False)\n",
      "      (3): LowRankLayerAugBUG(rmax=150, rank=150)\n",
      "      (4): ReLU(inplace=True)\n",
      "      (5): Dropout(p=0.5, inplace=False)\n",
      "      (6): CustomLinearLayer(in_features=4096, out_features=10, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "Start low-rank training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/8v5/Desktop/adversarial_rs_low_rank/venv/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:182: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 99; Epoch 1/30:  100/391 Time: 0.08 cls_loss = 1.618 acc = 0.405 cr = 94.14\n",
      "\n",
      "Step 199; Epoch 1/30:  200/391 Time: 0.08 cls_loss = 1.023 acc = 0.645 cr = 94.34\n",
      "\n",
      "Step 299; Epoch 1/30:  300/391 Time: 0.08 cls_loss = 0.870 acc = 0.700 cr = 94.43\n",
      "\n",
      "Step 399; Epoch 2/30:  9/391 Time: 0.08 cls_loss = 0.778 acc = 0.732 cr = 93.75\n",
      "\n",
      "Step 499; Epoch 2/30:  109/391 Time: 0.08 cls_loss = 0.721 acc = 0.750 cr = 93.75\n",
      "\n",
      "Step 599; Epoch 2/30:  209/391 Time: 0.08 cls_loss = 0.708 acc = 0.756 cr = 93.75\n",
      "\n",
      "Step 699; Epoch 2/30:  309/391 Time: 0.07 cls_loss = 0.675 acc = 0.769 cr = 93.75\n",
      "\n",
      "Step 799; Epoch 3/30:  18/391 Time: 0.07 cls_loss = 0.661 acc = 0.773 cr = 93.75\n",
      "\n",
      "Step 899; Epoch 3/30:  118/391 Time: 0.08 cls_loss = 0.613 acc = 0.796 cr = 93.75\n",
      "\n",
      "Step 999; Epoch 3/30:  218/391 Time: 0.08 cls_loss = 0.608 acc = 0.794 cr = 93.75\n",
      "\n",
      "Step 1099; Epoch 3/30:  318/391 Time: 0.08 cls_loss = 0.578 acc = 0.804 cr = 93.75\n",
      "\n",
      "Step 1199; Epoch 4/30:  27/391 Time: 0.08 cls_loss = 0.553 acc = 0.810 cr = 93.75\n",
      "\n",
      "Step 1299; Epoch 4/30:  127/391 Time: 0.08 cls_loss = 0.531 acc = 0.824 cr = 93.75\n",
      "\n",
      "Step 1399; Epoch 4/30:  227/391 Time: 0.08 cls_loss = 0.530 acc = 0.817 cr = 93.75\n",
      "\n",
      "Step 1499; Epoch 4/30:  327/391 Time: 0.08 cls_loss = 0.519 acc = 0.823 cr = 93.75\n",
      "\n",
      "Step 1599; Epoch 5/30:  36/391 Time: 0.08 cls_loss = 0.501 acc = 0.830 cr = 93.75\n",
      "\n",
      "Step 1699; Epoch 5/30:  136/391 Time: 0.08 cls_loss = 0.485 acc = 0.832 cr = 93.75\n",
      "\n",
      "Step 1799; Epoch 5/30:  236/391 Time: 0.08 cls_loss = 0.483 acc = 0.834 cr = 93.75\n",
      "\n",
      "Step 1899; Epoch 5/30:  336/391 Time: 0.08 cls_loss = 0.478 acc = 0.836 cr = 93.75\n",
      "\n",
      "Step 1999; Epoch 6/30:  45/391 Time: 0.08 cls_loss = 0.472 acc = 0.841 cr = 93.75\n",
      "\n",
      "Step 2099; Epoch 6/30:  145/391 Time: 0.08 cls_loss = 0.450 acc = 0.849 cr = 93.75\n",
      "\n",
      "Step 2199; Epoch 6/30:  245/391 Time: 0.08 cls_loss = 0.454 acc = 0.844 cr = 93.75\n",
      "\n",
      "Step 2299; Epoch 6/30:  345/391 Time: 0.08 cls_loss = 0.449 acc = 0.844 cr = 93.75\n",
      "\n",
      "Step 2399; Epoch 7/30:  54/391 Time: 0.08 cls_loss = 0.432 acc = 0.852 cr = 93.75\n",
      "\n",
      "Step 2499; Epoch 7/30:  154/391 Time: 0.08 cls_loss = 0.427 acc = 0.855 cr = 93.75\n",
      "\n",
      "Step 2599; Epoch 7/30:  254/391 Time: 0.08 cls_loss = 0.402 acc = 0.860 cr = 93.75\n",
      "\n",
      "Step 2699; Epoch 7/30:  354/391 Time: 0.08 cls_loss = 0.413 acc = 0.861 cr = 93.75\n",
      "\n",
      "Step 2799; Epoch 8/30:  63/391 Time: 0.08 cls_loss = 0.393 acc = 0.867 cr = 93.75\n",
      "\n",
      "Step 2899; Epoch 8/30:  163/391 Time: 0.08 cls_loss = 0.384 acc = 0.865 cr = 93.75\n",
      "\n",
      "Step 2999; Epoch 8/30:  263/391 Time: 0.08 cls_loss = 0.397 acc = 0.866 cr = 93.75\n",
      "\n",
      "Step 3099; Epoch 8/30:  363/391 Time: 0.08 cls_loss = 0.388 acc = 0.867 cr = 93.75\n",
      "\n",
      "Step 3199; Epoch 9/30:  72/391 Time: 0.08 cls_loss = 0.359 acc = 0.879 cr = 93.75\n",
      "\n",
      "Step 3299; Epoch 9/30:  172/391 Time: 0.08 cls_loss = 0.359 acc = 0.876 cr = 93.75\n",
      "\n",
      "Step 3399; Epoch 9/30:  272/391 Time: 0.08 cls_loss = 0.365 acc = 0.874 cr = 93.75\n",
      "\n",
      "Step 3499; Epoch 9/30:  372/391 Time: 0.08 cls_loss = 0.358 acc = 0.879 cr = 93.75\n",
      "\n",
      "Step 3599; Epoch 10/30:  81/391 Time: 0.08 cls_loss = 0.342 acc = 0.882 cr = 93.75\n",
      "\n",
      "Step 3699; Epoch 10/30:  181/391 Time: 0.08 cls_loss = 0.349 acc = 0.882 cr = 93.75\n",
      "\n",
      "Step 3799; Epoch 10/30:  281/391 Time: 0.08 cls_loss = 0.328 acc = 0.886 cr = 93.75\n",
      "\n",
      "Step 3899; Epoch 10/30:  381/391 Time: 0.08 cls_loss = 0.341 acc = 0.883 cr = 93.75\n",
      "\n",
      "Step 3999; Epoch 11/30:  90/391 Time: 0.08 cls_loss = 0.326 acc = 0.887 cr = 94.56\n",
      "\n",
      "Step 4099; Epoch 11/30:  190/391 Time: 0.08 cls_loss = 0.314 acc = 0.891 cr = 94.56\n",
      "\n",
      "Step 4199; Epoch 11/30:  290/391 Time: 0.08 cls_loss = 0.319 acc = 0.890 cr = 94.56\n",
      "\n",
      "Step 4299; Epoch 11/30:  390/391 Time: 0.08 cls_loss = 0.312 acc = 0.895 cr = 94.56\n",
      "\n",
      "Step 4399; Epoch 12/30:  99/391 Time: 0.08 cls_loss = 0.299 acc = 0.894 cr = 93.75\n",
      "\n",
      "Step 4499; Epoch 12/30:  199/391 Time: 0.08 cls_loss = 0.297 acc = 0.900 cr = 93.75\n",
      "\n",
      "Step 4599; Epoch 12/30:  299/391 Time: 0.08 cls_loss = 0.307 acc = 0.894 cr = 93.75\n",
      "\n",
      "Step 4699; Epoch 13/30:  8/391 Time: 0.08 cls_loss = 0.299 acc = 0.897 cr = 93.75\n",
      "\n",
      "Step 4799; Epoch 13/30:  108/391 Time: 0.08 cls_loss = 0.289 acc = 0.900 cr = 93.75\n",
      "\n",
      "Step 4899; Epoch 13/30:  208/391 Time: 0.08 cls_loss = 0.289 acc = 0.904 cr = 93.75\n",
      "\n",
      "Step 4999; Epoch 13/30:  308/391 Time: 0.08 cls_loss = 0.290 acc = 0.902 cr = 93.75\n",
      "\n",
      "Step 5099; Epoch 14/30:  17/391 Time: 0.08 cls_loss = 0.284 acc = 0.905 cr = 93.75\n",
      "\n",
      "Step 5199; Epoch 14/30:  117/391 Time: 0.08 cls_loss = 0.277 acc = 0.903 cr = 93.75\n",
      "\n",
      "Step 5299; Epoch 14/30:  217/391 Time: 0.08 cls_loss = 0.269 acc = 0.908 cr = 93.75\n",
      "\n",
      "Step 5399; Epoch 14/30:  317/391 Time: 0.08 cls_loss = 0.279 acc = 0.902 cr = 93.75\n",
      "\n",
      "Step 5499; Epoch 15/30:  26/391 Time: 0.08 cls_loss = 0.258 acc = 0.912 cr = 93.75\n",
      "\n",
      "Step 5599; Epoch 15/30:  126/391 Time: 0.07 cls_loss = 0.264 acc = 0.911 cr = 93.75\n",
      "\n",
      "Step 5699; Epoch 15/30:  226/391 Time: 0.07 cls_loss = 0.250 acc = 0.913 cr = 93.75\n",
      "\n",
      "Step 5799; Epoch 15/30:  326/391 Time: 0.07 cls_loss = 0.263 acc = 0.910 cr = 93.75\n",
      "\n",
      "Step 5899; Epoch 16/30:  35/391 Time: 0.08 cls_loss = 0.257 acc = 0.913 cr = 93.75\n",
      "\n",
      "Step 5999; Epoch 16/30:  135/391 Time: 0.08 cls_loss = 0.254 acc = 0.911 cr = 93.75\n",
      "\n",
      "Step 6099; Epoch 16/30:  235/391 Time: 0.08 cls_loss = 0.237 acc = 0.917 cr = 93.75\n",
      "\n",
      "Step 6199; Epoch 16/30:  335/391 Time: 0.08 cls_loss = 0.251 acc = 0.914 cr = 93.75\n",
      "\n",
      "Step 6299; Epoch 17/30:  44/391 Time: 0.08 cls_loss = 0.241 acc = 0.917 cr = 93.75\n",
      "\n",
      "Step 6399; Epoch 17/30:  144/391 Time: 0.08 cls_loss = 0.242 acc = 0.916 cr = 93.75\n",
      "\n",
      "Step 6499; Epoch 17/30:  244/391 Time: 0.08 cls_loss = 0.231 acc = 0.920 cr = 93.75\n",
      "\n",
      "Step 6599; Epoch 17/30:  344/391 Time: 0.08 cls_loss = 0.246 acc = 0.915 cr = 93.75\n",
      "\n",
      "Step 6699; Epoch 18/30:  53/391 Time: 0.08 cls_loss = 0.224 acc = 0.922 cr = 93.75\n",
      "\n",
      "Step 6799; Epoch 18/30:  153/391 Time: 0.08 cls_loss = 0.219 acc = 0.924 cr = 93.75\n",
      "\n",
      "Step 6899; Epoch 18/30:  253/391 Time: 0.08 cls_loss = 0.234 acc = 0.920 cr = 93.75\n",
      "\n",
      "Step 6999; Epoch 18/30:  353/391 Time: 0.08 cls_loss = 0.224 acc = 0.921 cr = 93.75\n",
      "\n",
      "Step 7099; Epoch 19/30:  62/391 Time: 0.08 cls_loss = 0.211 acc = 0.928 cr = 93.75\n",
      "\n",
      "Step 7199; Epoch 19/30:  162/391 Time: 0.07 cls_loss = 0.221 acc = 0.924 cr = 93.75\n",
      "\n",
      "Step 7299; Epoch 19/30:  262/391 Time: 0.08 cls_loss = 0.224 acc = 0.921 cr = 93.75\n",
      "\n",
      "Step 7399; Epoch 19/30:  362/391 Time: 0.08 cls_loss = 0.224 acc = 0.919 cr = 93.75\n",
      "\n",
      "Step 7499; Epoch 20/30:  71/391 Time: 0.08 cls_loss = 0.212 acc = 0.927 cr = 93.75\n",
      "\n",
      "Step 7599; Epoch 20/30:  171/391 Time: 0.08 cls_loss = 0.211 acc = 0.927 cr = 93.75\n",
      "\n",
      "Step 7699; Epoch 20/30:  271/391 Time: 0.08 cls_loss = 0.218 acc = 0.923 cr = 93.75\n",
      "\n",
      "Step 7799; Epoch 20/30:  371/391 Time: 0.08 cls_loss = 0.214 acc = 0.924 cr = 93.75\n",
      "\n",
      "Step 7899; Epoch 21/30:  80/391 Time: 0.08 cls_loss = 0.213 acc = 0.925 cr = 94.58\n",
      "\n",
      "Step 7999; Epoch 21/30:  180/391 Time: 0.08 cls_loss = 0.207 acc = 0.931 cr = 94.58\n",
      "\n",
      "Step 8099; Epoch 21/30:  280/391 Time: 0.08 cls_loss = 0.192 acc = 0.932 cr = 94.58\n",
      "\n",
      "Step 8199; Epoch 21/30:  380/391 Time: 0.08 cls_loss = 0.209 acc = 0.928 cr = 94.58\n",
      "\n",
      "Step 8299; Epoch 22/30:  89/391 Time: 0.08 cls_loss = 0.195 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 8399; Epoch 22/30:  189/391 Time: 0.07 cls_loss = 0.202 acc = 0.930 cr = 93.75\n",
      "\n",
      "Step 8499; Epoch 22/30:  289/391 Time: 0.07 cls_loss = 0.197 acc = 0.930 cr = 93.75\n",
      "\n",
      "Step 8599; Epoch 22/30:  389/391 Time: 0.07 cls_loss = 0.200 acc = 0.930 cr = 93.75\n",
      "\n",
      "Step 8699; Epoch 23/30:  98/391 Time: 0.08 cls_loss = 0.189 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 8799; Epoch 23/30:  198/391 Time: 0.08 cls_loss = 0.191 acc = 0.933 cr = 93.75\n",
      "\n",
      "Step 8899; Epoch 23/30:  298/391 Time: 0.08 cls_loss = 0.206 acc = 0.928 cr = 93.75\n",
      "\n",
      "Step 8999; Epoch 24/30:  7/391 Time: 0.08 cls_loss = 0.192 acc = 0.933 cr = 93.75\n",
      "\n",
      "Step 9099; Epoch 24/30:  107/391 Time: 0.08 cls_loss = 0.184 acc = 0.935 cr = 93.75\n",
      "\n",
      "Step 9199; Epoch 24/30:  207/391 Time: 0.08 cls_loss = 0.182 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 9299; Epoch 24/30:  307/391 Time: 0.08 cls_loss = 0.194 acc = 0.931 cr = 93.75\n",
      "\n",
      "Step 9399; Epoch 25/30:  16/391 Time: 0.08 cls_loss = 0.190 acc = 0.933 cr = 93.75\n",
      "\n",
      "Step 9499; Epoch 25/30:  116/391 Time: 0.08 cls_loss = 0.184 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 9599; Epoch 25/30:  216/391 Time: 0.08 cls_loss = 0.182 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 9699; Epoch 25/30:  316/391 Time: 0.08 cls_loss = 0.182 acc = 0.938 cr = 93.75\n",
      "\n",
      "Step 9799; Epoch 26/30:  25/391 Time: 0.08 cls_loss = 0.190 acc = 0.933 cr = 93.75\n",
      "\n",
      "Step 9899; Epoch 26/30:  125/391 Time: 0.08 cls_loss = 0.180 acc = 0.938 cr = 93.75\n",
      "\n",
      "Step 9999; Epoch 26/30:  225/391 Time: 0.08 cls_loss = 0.185 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 10099; Epoch 26/30:  325/391 Time: 0.08 cls_loss = 0.172 acc = 0.939 cr = 93.75\n",
      "\n",
      "Step 10199; Epoch 27/30:  34/391 Time: 0.08 cls_loss = 0.189 acc = 0.935 cr = 93.75\n",
      "\n",
      "Step 10299; Epoch 27/30:  134/391 Time: 0.07 cls_loss = 0.175 acc = 0.938 cr = 93.75\n",
      "\n",
      "Step 10399; Epoch 27/30:  234/391 Time: 0.08 cls_loss = 0.189 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 10499; Epoch 27/30:  334/391 Time: 0.08 cls_loss = 0.177 acc = 0.939 cr = 93.75\n",
      "\n",
      "Step 10599; Epoch 28/30:  43/391 Time: 0.08 cls_loss = 0.174 acc = 0.939 cr = 93.75\n",
      "\n",
      "Step 10699; Epoch 28/30:  143/391 Time: 0.08 cls_loss = 0.175 acc = 0.941 cr = 93.75\n",
      "\n",
      "Step 10799; Epoch 28/30:  243/391 Time: 0.08 cls_loss = 0.185 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 10899; Epoch 28/30:  343/391 Time: 0.07 cls_loss = 0.181 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 10999; Epoch 29/30:  52/391 Time: 0.08 cls_loss = 0.183 acc = 0.935 cr = 93.75\n",
      "\n",
      "Step 11099; Epoch 29/30:  152/391 Time: 0.08 cls_loss = 0.185 acc = 0.935 cr = 93.75\n",
      "\n",
      "Step 11199; Epoch 29/30:  252/391 Time: 0.08 cls_loss = 0.171 acc = 0.938 cr = 93.75\n",
      "\n",
      "Step 11299; Epoch 29/30:  352/391 Time: 0.08 cls_loss = 0.173 acc = 0.939 cr = 93.75\n",
      "\n",
      "Step 11399; Epoch 30/30:  61/391 Time: 0.07 cls_loss = 0.175 acc = 0.941 cr = 93.75\n",
      "\n",
      "Step 11499; Epoch 30/30:  161/391 Time: 0.08 cls_loss = 0.181 acc = 0.939 cr = 93.75\n",
      "\n",
      "Step 11599; Epoch 30/30:  261/391 Time: 0.08 cls_loss = 0.172 acc = 0.941 cr = 93.75\n",
      "\n",
      "Step 11699; Epoch 30/30:  361/391 Time: 0.07 cls_loss = 0.174 acc = 0.940 cr = 93.75\n",
      "\n",
      "Finished rank adaptive training, start finetuning\n",
      "Step 11799; Epoch 1/2:  70/391 Time: 0.08 cls_loss = 0.176 acc = 0.939 cr = 94.58\n",
      "\n",
      "Step 11899; Epoch 1/2:  170/391 Time: 0.07 cls_loss = 0.175 acc = 0.939 cr = 94.58\n",
      "\n",
      "Step 11999; Epoch 1/2:  270/391 Time: 0.07 cls_loss = 0.168 acc = 0.942 cr = 94.58\n",
      "\n",
      "Step 12099; Epoch 1/2:  370/391 Time: 0.07 cls_loss = 0.177 acc = 0.938 cr = 94.58\n",
      "\n",
      "Step 12199; Epoch 2/2:  79/391 Time: 0.07 cls_loss = 0.166 acc = 0.942 cr = 94.58\n",
      "\n",
      "Step 12299; Epoch 2/2:  179/391 Time: 0.07 cls_loss = 0.168 acc = 0.938 cr = 94.58\n",
      "\n",
      "Step 12399; Epoch 2/2:  279/391 Time: 0.07 cls_loss = 0.174 acc = 0.941 cr = 94.58\n",
      "\n",
      "Step 12499; Epoch 2/2:  379/391 Time: 0.07 cls_loss = 0.175 acc = 0.939 cr = 94.58\n",
      "\n",
      "Save Model at ./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.0_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Epoch[2]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[2]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[2]-Validation-[30/79] Batch OA: 89.06 %\n",
      "Epoch[2]-Validation-[40/79] Batch OA: 92.97 %\n",
      "Epoch[2]-Validation-[50/79] Batch OA: 96.88 %\n",
      "Epoch[2]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[2]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 91.80 %---------------\n",
      "---------------Accuracy of   automobile : 95.90 %---------------\n",
      "---------------Accuracy of         bird : 83.60 %---------------\n",
      "---------------Accuracy of          cat : 76.00 %---------------\n",
      "---------------Accuracy of         deer : 89.50 %---------------\n",
      "---------------Accuracy of          dog : 82.00 %---------------\n",
      "---------------Accuracy of         frog : 93.90 %---------------\n",
      "---------------Accuracy of        horse : 90.60 %---------------\n",
      "---------------Accuracy of         ship : 95.50 %---------------\n",
      "---------------Accuracy of        truck : 93.00 %---------------\n",
      "---------------Epoch[2]Validation-OA: 89.18 %---------------\n",
      "---------------Epoch[2]Validation-AA: 89.18 %---------------\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.0_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Low rank layer ranks\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[60, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[98, 64, 3, 3]\n",
      "[105, 101, 3, 3]\n",
      "[112, 106, 3, 3]\n",
      "[125, 96, 3, 3]\n",
      "[121, 119, 3, 3]\n",
      "[121, 120, 3, 3]\n",
      "[134, 124, 3, 3]\n",
      "[126, 127, 3, 3]\n",
      "[125, 127, 3, 3]\n",
      "140\n",
      "133\n",
      "------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [03:22<00:00, 49.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/low_rank/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[60, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[98, 64, 3, 3]\n",
      "[105, 101, 3, 3]\n",
      "[112, 106, 3, 3]\n",
      "[125, 96, 3, 3]\n",
      "[121, 119, 3, 3]\n",
      "[121, 120, 3, 3]\n",
      "[134, 124, 3, 3]\n",
      "[126, 127, 3, 3]\n",
      "[125, 127, 3, 3]\n",
      "140\n",
      "133\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 92.97 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 96.88 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 91.80 %---------------\n",
      "---------------Accuracy of   automobile : 95.90 %---------------\n",
      "---------------Accuracy of         bird : 83.60 %---------------\n",
      "---------------Accuracy of          cat : 76.00 %---------------\n",
      "---------------Accuracy of         deer : 89.50 %---------------\n",
      "---------------Accuracy of          dog : 82.00 %---------------\n",
      "---------------Accuracy of         frog : 93.90 %---------------\n",
      "---------------Accuracy of        horse : 90.60 %---------------\n",
      "---------------Accuracy of         ship : 95.50 %---------------\n",
      "---------------Accuracy of        truck : 93.00 %---------------\n",
      "---------------Epoch[1]Validation-OA: 89.18 %---------------\n",
      "---------------Epoch[1]Validation-AA: 89.18 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 76.56 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 65.62 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 71.88 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 78.91 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 71.88 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 75.78 %\n",
      "---------------Accuracy of     airplane : 77.10 %---------------\n",
      "---------------Accuracy of   automobile : 90.40 %---------------\n",
      "---------------Accuracy of         bird : 64.10 %---------------\n",
      "---------------Accuracy of          cat : 50.70 %---------------\n",
      "---------------Accuracy of         deer : 65.30 %---------------\n",
      "---------------Accuracy of          dog : 65.40 %---------------\n",
      "---------------Accuracy of         frog : 81.50 %---------------\n",
      "---------------Accuracy of        horse : 79.50 %---------------\n",
      "---------------Accuracy of         ship : 84.50 %---------------\n",
      "---------------Accuracy of        truck : 84.50 %---------------\n",
      "---------------Epoch[1]Validation-OA: 74.30 %---------------\n",
      "---------------Epoch[1]Validation-AA: 74.30 %---------------\n",
      "Clean Test Set OA: 89.18\n",
      "condlr_fgsm Test Set OA: 74.3\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.0_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Low rank layer ranks\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[60, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[98, 64, 3, 3]\n",
      "[105, 101, 3, 3]\n",
      "[112, 106, 3, 3]\n",
      "[125, 96, 3, 3]\n",
      "[121, 119, 3, 3]\n",
      "[121, 120, 3, 3]\n",
      "[134, 124, 3, 3]\n",
      "[126, 127, 3, 3]\n",
      "[125, 127, 3, 3]\n",
      "140\n",
      "133\n",
      "------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [03:18<00:00, 50.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/low_rank/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[60, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[98, 64, 3, 3]\n",
      "[105, 101, 3, 3]\n",
      "[112, 106, 3, 3]\n",
      "[125, 96, 3, 3]\n",
      "[121, 119, 3, 3]\n",
      "[121, 120, 3, 3]\n",
      "[134, 124, 3, 3]\n",
      "[126, 127, 3, 3]\n",
      "[125, 127, 3, 3]\n",
      "140\n",
      "133\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 92.97 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 96.88 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 91.80 %---------------\n",
      "---------------Accuracy of   automobile : 95.90 %---------------\n",
      "---------------Accuracy of         bird : 83.60 %---------------\n",
      "---------------Accuracy of          cat : 76.00 %---------------\n",
      "---------------Accuracy of         deer : 89.50 %---------------\n",
      "---------------Accuracy of          dog : 82.00 %---------------\n",
      "---------------Accuracy of         frog : 93.90 %---------------\n",
      "---------------Accuracy of        horse : 90.60 %---------------\n",
      "---------------Accuracy of         ship : 95.50 %---------------\n",
      "---------------Accuracy of        truck : 93.00 %---------------\n",
      "---------------Epoch[1]Validation-OA: 89.18 %---------------\n",
      "---------------Epoch[1]Validation-AA: 89.18 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 57.81 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 52.34 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 60.16 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 57.81 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 65.62 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 57.81 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 58.59 %\n",
      "---------------Accuracy of     airplane : 61.30 %---------------\n",
      "---------------Accuracy of   automobile : 80.30 %---------------\n",
      "---------------Accuracy of         bird : 47.10 %---------------\n",
      "---------------Accuracy of          cat : 32.50 %---------------\n",
      "---------------Accuracy of         deer : 43.00 %---------------\n",
      "---------------Accuracy of          dog : 49.20 %---------------\n",
      "---------------Accuracy of         frog : 66.00 %---------------\n",
      "---------------Accuracy of        horse : 64.00 %---------------\n",
      "---------------Accuracy of         ship : 72.70 %---------------\n",
      "---------------Accuracy of        truck : 73.70 %---------------\n",
      "---------------Epoch[1]Validation-OA: 58.98 %---------------\n",
      "---------------Epoch[1]Validation-AA: 58.98 %---------------\n",
      "Clean Test Set OA: 89.18\n",
      "condlr_fgsm Test Set OA: 58.98\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.0_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Low rank layer ranks\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[60, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[98, 64, 3, 3]\n",
      "[105, 101, 3, 3]\n",
      "[112, 106, 3, 3]\n",
      "[125, 96, 3, 3]\n",
      "[121, 119, 3, 3]\n",
      "[121, 120, 3, 3]\n",
      "[134, 124, 3, 3]\n",
      "[126, 127, 3, 3]\n",
      "[125, 127, 3, 3]\n",
      "140\n",
      "133\n",
      "------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [03:22<00:00, 49.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/low_rank/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[60, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[98, 64, 3, 3]\n",
      "[105, 101, 3, 3]\n",
      "[112, 106, 3, 3]\n",
      "[125, 96, 3, 3]\n",
      "[121, 119, 3, 3]\n",
      "[121, 120, 3, 3]\n",
      "[134, 124, 3, 3]\n",
      "[126, 127, 3, 3]\n",
      "[125, 127, 3, 3]\n",
      "140\n",
      "133\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 92.97 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 96.88 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 91.80 %---------------\n",
      "---------------Accuracy of   automobile : 95.90 %---------------\n",
      "---------------Accuracy of         bird : 83.60 %---------------\n",
      "---------------Accuracy of          cat : 76.00 %---------------\n",
      "---------------Accuracy of         deer : 89.50 %---------------\n",
      "---------------Accuracy of          dog : 82.00 %---------------\n",
      "---------------Accuracy of         frog : 93.90 %---------------\n",
      "---------------Accuracy of        horse : 90.60 %---------------\n",
      "---------------Accuracy of         ship : 95.50 %---------------\n",
      "---------------Accuracy of        truck : 93.00 %---------------\n",
      "---------------Epoch[1]Validation-OA: 89.18 %---------------\n",
      "---------------Epoch[1]Validation-AA: 89.18 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 44.53 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 41.41 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 50.78 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 47.66 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 50.00 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 46.09 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 46.09 %\n",
      "---------------Accuracy of     airplane : 50.70 %---------------\n",
      "---------------Accuracy of   automobile : 70.10 %---------------\n",
      "---------------Accuracy of         bird : 35.70 %---------------\n",
      "---------------Accuracy of          cat : 22.50 %---------------\n",
      "---------------Accuracy of         deer : 29.60 %---------------\n",
      "---------------Accuracy of          dog : 37.00 %---------------\n",
      "---------------Accuracy of         frog : 50.50 %---------------\n",
      "---------------Accuracy of        horse : 50.70 %---------------\n",
      "---------------Accuracy of         ship : 59.40 %---------------\n",
      "---------------Accuracy of        truck : 60.70 %---------------\n",
      "---------------Epoch[1]Validation-OA: 46.69 %---------------\n",
      "---------------Epoch[1]Validation-AA: 46.69 %---------------\n",
      "Clean Test Set OA: 89.18\n",
      "condlr_fgsm Test Set OA: 46.69\n",
      "Network: vgg16\n",
      "DataParallel(\n",
      "  (module): VGG(\n",
      "    (features): Sequential(\n",
      "      (0): CustomConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (1): ReLU(inplace=True)\n",
      "      (2): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[32, 32, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (3): ReLU(inplace=True)\n",
      "      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (5): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[64, 32, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (6): ReLU(inplace=True)\n",
      "      (7): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[64, 64, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (8): ReLU(inplace=True)\n",
      "      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (10): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[128, 64, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 128x64 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (11): ReLU(inplace=True)\n",
      "      (12): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[128, 128, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (13): ReLU(inplace=True)\n",
      "      (14): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[128, 128, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (15): ReLU(inplace=True)\n",
      "      (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (17): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 128, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 256x128 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (18): ReLU(inplace=True)\n",
      "      (19): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (20): ReLU(inplace=True)\n",
      "      (21): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (22): ReLU(inplace=True)\n",
      "      (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "      (24): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (25): ReLU(inplace=True)\n",
      "      (26): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (27): ReLU(inplace=True)\n",
      "      (28): Conv2dLowRankLayerAugBUG(\n",
      "        rank=[150, 150, 3, 3]\n",
      "        (Us): ParameterList(\n",
      "            (0): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (1): Parameter containing: [torch.float32 of size 512x150 (cuda:0)]\n",
      "            (2): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "            (3): Parameter containing: [torch.float32 of size 3x3 (cuda:0)]\n",
      "        )\n",
      "      )\n",
      "      (29): ReLU(inplace=True)\n",
      "      (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    )\n",
      "    (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "    (classifier): Sequential(\n",
      "      (0): LowRankLayerAugBUG(rmax=150, rank=150)\n",
      "      (1): ReLU(inplace=True)\n",
      "      (2): Dropout(p=0.5, inplace=False)\n",
      "      (3): LowRankLayerAugBUG(rmax=150, rank=150)\n",
      "      (4): ReLU(inplace=True)\n",
      "      (5): Dropout(p=0.5, inplace=False)\n",
      "      (6): CustomLinearLayer(in_features=4096, out_features=10, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n",
      "Start low-rank training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/8v5/Desktop/adversarial_rs_low_rank/venv/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:182: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 99; Epoch 1/30:  100/391 Time: 0.09 cls_loss = 66.757 acc = 0.399 cr = 94.08\n",
      "\n",
      "Step 199; Epoch 1/30:  200/391 Time: 0.08 cls_loss = 61.913 acc = 0.640 cr = 94.16\n",
      "\n",
      "Step 299; Epoch 1/30:  300/391 Time: 0.08 cls_loss = 57.696 acc = 0.690 cr = 94.16\n",
      "\n",
      "Step 399; Epoch 2/30:  9/391 Time: 0.09 cls_loss = 54.054 acc = 0.726 cr = 93.75\n",
      "\n",
      "Step 499; Epoch 2/30:  109/391 Time: 0.09 cls_loss = 51.153 acc = 0.753 cr = 93.75\n",
      "\n",
      "Step 599; Epoch 2/30:  209/391 Time: 0.09 cls_loss = 48.867 acc = 0.759 cr = 93.75\n",
      "\n",
      "Step 699; Epoch 2/30:  309/391 Time: 0.09 cls_loss = 47.087 acc = 0.769 cr = 93.75\n",
      "\n",
      "Step 799; Epoch 3/30:  18/391 Time: 0.09 cls_loss = 45.534 acc = 0.777 cr = 93.75\n",
      "\n",
      "Step 899; Epoch 3/30:  118/391 Time: 0.09 cls_loss = 44.307 acc = 0.796 cr = 93.75\n",
      "\n",
      "Step 999; Epoch 3/30:  218/391 Time: 0.09 cls_loss = 43.261 acc = 0.791 cr = 93.75\n",
      "\n",
      "Step 1099; Epoch 3/30:  318/391 Time: 0.08 cls_loss = 42.322 acc = 0.803 cr = 93.75\n",
      "\n",
      "Step 1199; Epoch 4/30:  27/391 Time: 0.09 cls_loss = 41.432 acc = 0.801 cr = 93.75\n",
      "\n",
      "Step 1299; Epoch 4/30:  127/391 Time: 0.08 cls_loss = 40.699 acc = 0.810 cr = 93.75\n",
      "\n",
      "Step 1399; Epoch 4/30:  227/391 Time: 0.08 cls_loss = 39.997 acc = 0.814 cr = 93.75\n",
      "\n",
      "Step 1499; Epoch 4/30:  327/391 Time: 0.09 cls_loss = 39.328 acc = 0.824 cr = 93.75\n",
      "\n",
      "Step 1599; Epoch 5/30:  36/391 Time: 0.09 cls_loss = 38.683 acc = 0.825 cr = 93.75\n",
      "\n",
      "Step 1699; Epoch 5/30:  136/391 Time: 0.09 cls_loss = 38.221 acc = 0.829 cr = 93.75\n",
      "\n",
      "Step 1799; Epoch 5/30:  236/391 Time: 0.09 cls_loss = 37.722 acc = 0.832 cr = 93.75\n",
      "\n",
      "Step 1899; Epoch 5/30:  336/391 Time: 0.09 cls_loss = 37.253 acc = 0.839 cr = 93.75\n",
      "\n",
      "Step 1999; Epoch 6/30:  45/391 Time: 0.09 cls_loss = 36.788 acc = 0.845 cr = 93.75\n",
      "\n",
      "Step 2099; Epoch 6/30:  145/391 Time: 0.09 cls_loss = 36.405 acc = 0.852 cr = 93.75\n",
      "\n",
      "Step 2199; Epoch 6/30:  245/391 Time: 0.09 cls_loss = 36.031 acc = 0.847 cr = 93.75\n",
      "\n",
      "Step 2299; Epoch 6/30:  345/391 Time: 0.09 cls_loss = 35.664 acc = 0.845 cr = 93.75\n",
      "\n",
      "Step 2399; Epoch 7/30:  54/391 Time: 0.09 cls_loss = 35.260 acc = 0.854 cr = 93.75\n",
      "\n",
      "Step 2499; Epoch 7/30:  154/391 Time: 0.09 cls_loss = 34.961 acc = 0.852 cr = 93.75\n",
      "\n",
      "Step 2599; Epoch 7/30:  254/391 Time: 0.09 cls_loss = 34.644 acc = 0.850 cr = 93.75\n",
      "\n",
      "Step 2699; Epoch 7/30:  354/391 Time: 0.09 cls_loss = 34.318 acc = 0.855 cr = 93.75\n",
      "\n",
      "Step 2799; Epoch 8/30:  63/391 Time: 0.09 cls_loss = 33.985 acc = 0.862 cr = 93.75\n",
      "\n",
      "Step 2899; Epoch 8/30:  163/391 Time: 0.09 cls_loss = 33.716 acc = 0.864 cr = 93.75\n",
      "\n",
      "Step 2999; Epoch 8/30:  263/391 Time: 0.09 cls_loss = 33.448 acc = 0.859 cr = 93.75\n",
      "\n",
      "Step 3099; Epoch 8/30:  363/391 Time: 0.09 cls_loss = 33.153 acc = 0.865 cr = 93.75\n",
      "\n",
      "Step 3199; Epoch 9/30:  72/391 Time: 0.09 cls_loss = 32.819 acc = 0.869 cr = 93.75\n",
      "\n",
      "Step 3299; Epoch 9/30:  172/391 Time: 0.08 cls_loss = 32.622 acc = 0.865 cr = 93.75\n",
      "\n",
      "Step 3399; Epoch 9/30:  272/391 Time: 0.08 cls_loss = 32.370 acc = 0.864 cr = 93.75\n",
      "\n",
      "Step 3499; Epoch 9/30:  372/391 Time: 0.09 cls_loss = 32.127 acc = 0.871 cr = 93.75\n",
      "\n",
      "Step 3599; Epoch 10/30:  81/391 Time: 0.09 cls_loss = 31.839 acc = 0.878 cr = 93.75\n",
      "\n",
      "Step 3699; Epoch 10/30:  181/391 Time: 0.09 cls_loss = 31.653 acc = 0.881 cr = 93.75\n",
      "\n",
      "Step 3799; Epoch 10/30:  281/391 Time: 0.09 cls_loss = 31.438 acc = 0.881 cr = 93.75\n",
      "\n",
      "Step 3899; Epoch 10/30:  381/391 Time: 0.09 cls_loss = 31.243 acc = 0.871 cr = 93.75\n",
      "\n",
      "Step 3999; Epoch 11/30:  90/391 Time: 0.09 cls_loss = 31.019 acc = 0.881 cr = 94.17\n",
      "\n",
      "Step 4099; Epoch 11/30:  190/391 Time: 0.09 cls_loss = 30.814 acc = 0.884 cr = 94.17\n",
      "\n",
      "Step 4199; Epoch 11/30:  290/391 Time: 0.09 cls_loss = 30.630 acc = 0.881 cr = 94.17\n",
      "\n",
      "Step 4299; Epoch 11/30:  390/391 Time: 0.09 cls_loss = 30.443 acc = 0.880 cr = 94.17\n",
      "\n",
      "Step 4399; Epoch 12/30:  99/391 Time: 0.09 cls_loss = 30.193 acc = 0.888 cr = 93.75\n",
      "\n",
      "Step 4499; Epoch 12/30:  199/391 Time: 0.09 cls_loss = 30.056 acc = 0.889 cr = 93.75\n",
      "\n",
      "Step 4599; Epoch 12/30:  299/391 Time: 0.09 cls_loss = 29.894 acc = 0.884 cr = 93.75\n",
      "\n",
      "Step 4699; Epoch 13/30:  8/391 Time: 0.09 cls_loss = 29.677 acc = 0.891 cr = 93.75\n",
      "\n",
      "Step 4799; Epoch 13/30:  108/391 Time: 0.09 cls_loss = 29.540 acc = 0.894 cr = 93.75\n",
      "\n",
      "Step 4899; Epoch 13/30:  208/391 Time: 0.09 cls_loss = 29.383 acc = 0.894 cr = 93.75\n",
      "\n",
      "Step 4999; Epoch 13/30:  308/391 Time: 0.09 cls_loss = 29.256 acc = 0.887 cr = 93.75\n",
      "\n",
      "Step 5099; Epoch 14/30:  17/391 Time: 0.09 cls_loss = 29.055 acc = 0.890 cr = 93.75\n",
      "\n",
      "Step 5199; Epoch 14/30:  117/391 Time: 0.09 cls_loss = 28.943 acc = 0.899 cr = 93.75\n",
      "\n",
      "Step 5299; Epoch 14/30:  217/391 Time: 0.09 cls_loss = 28.807 acc = 0.897 cr = 93.75\n",
      "\n",
      "Step 5399; Epoch 14/30:  317/391 Time: 0.09 cls_loss = 28.676 acc = 0.897 cr = 93.75\n",
      "\n",
      "Step 5499; Epoch 15/30:  26/391 Time: 0.09 cls_loss = 28.512 acc = 0.897 cr = 93.75\n",
      "\n",
      "Step 5599; Epoch 15/30:  126/391 Time: 0.09 cls_loss = 28.391 acc = 0.909 cr = 93.75\n",
      "\n",
      "Step 5699; Epoch 15/30:  226/391 Time: 0.09 cls_loss = 28.294 acc = 0.901 cr = 93.75\n",
      "\n",
      "Step 5799; Epoch 15/30:  326/391 Time: 0.09 cls_loss = 28.176 acc = 0.902 cr = 93.75\n",
      "\n",
      "Step 5899; Epoch 16/30:  35/391 Time: 0.09 cls_loss = 28.022 acc = 0.901 cr = 93.75\n",
      "\n",
      "Step 5999; Epoch 16/30:  135/391 Time: 0.09 cls_loss = 27.936 acc = 0.910 cr = 93.75\n",
      "\n",
      "Step 6099; Epoch 16/30:  235/391 Time: 0.09 cls_loss = 27.833 acc = 0.907 cr = 93.75\n",
      "\n",
      "Step 6199; Epoch 16/30:  335/391 Time: 0.09 cls_loss = 27.731 acc = 0.908 cr = 93.75\n",
      "\n",
      "Step 6299; Epoch 17/30:  44/391 Time: 0.09 cls_loss = 27.592 acc = 0.909 cr = 93.75\n",
      "\n",
      "Step 6399; Epoch 17/30:  144/391 Time: 0.09 cls_loss = 27.520 acc = 0.914 cr = 93.75\n",
      "\n",
      "Step 6499; Epoch 17/30:  244/391 Time: 0.09 cls_loss = 27.425 acc = 0.916 cr = 93.75\n",
      "\n",
      "Step 6599; Epoch 17/30:  344/391 Time: 0.09 cls_loss = 27.344 acc = 0.913 cr = 93.75\n",
      "\n",
      "Step 6699; Epoch 18/30:  53/391 Time: 0.09 cls_loss = 27.227 acc = 0.908 cr = 93.75\n",
      "\n",
      "Step 6799; Epoch 18/30:  153/391 Time: 0.09 cls_loss = 27.162 acc = 0.919 cr = 93.75\n",
      "\n",
      "Step 6899; Epoch 18/30:  253/391 Time: 0.09 cls_loss = 27.078 acc = 0.918 cr = 93.75\n",
      "\n",
      "Step 6999; Epoch 18/30:  353/391 Time: 0.09 cls_loss = 26.999 acc = 0.920 cr = 93.75\n",
      "\n",
      "Step 7099; Epoch 19/30:  62/391 Time: 0.09 cls_loss = 26.893 acc = 0.918 cr = 93.75\n",
      "\n",
      "Step 7199; Epoch 19/30:  162/391 Time: 0.09 cls_loss = 26.839 acc = 0.925 cr = 93.75\n",
      "\n",
      "Step 7299; Epoch 19/30:  262/391 Time: 0.09 cls_loss = 26.771 acc = 0.923 cr = 93.75\n",
      "\n",
      "Step 7399; Epoch 19/30:  362/391 Time: 0.09 cls_loss = 26.707 acc = 0.923 cr = 93.75\n",
      "\n",
      "Step 7499; Epoch 20/30:  71/391 Time: 0.09 cls_loss = 26.610 acc = 0.922 cr = 93.75\n",
      "\n",
      "Step 7599; Epoch 20/30:  171/391 Time: 0.09 cls_loss = 26.578 acc = 0.922 cr = 93.75\n",
      "\n",
      "Step 7699; Epoch 20/30:  271/391 Time: 0.09 cls_loss = 26.509 acc = 0.928 cr = 93.75\n",
      "\n",
      "Step 7799; Epoch 20/30:  371/391 Time: 0.09 cls_loss = 26.466 acc = 0.923 cr = 93.75\n",
      "\n",
      "Step 7899; Epoch 21/30:  80/391 Time: 0.09 cls_loss = 26.399 acc = 0.927 cr = 94.19\n",
      "\n",
      "Step 7999; Epoch 21/30:  180/391 Time: 0.09 cls_loss = 26.340 acc = 0.929 cr = 94.19\n",
      "\n",
      "Step 8099; Epoch 21/30:  280/391 Time: 0.09 cls_loss = 26.294 acc = 0.929 cr = 94.19\n",
      "\n",
      "Step 8199; Epoch 21/30:  380/391 Time: 0.09 cls_loss = 26.247 acc = 0.927 cr = 94.19\n",
      "\n",
      "Step 8299; Epoch 22/30:  89/391 Time: 0.09 cls_loss = 26.159 acc = 0.931 cr = 93.75\n",
      "\n",
      "Step 8399; Epoch 22/30:  189/391 Time: 0.09 cls_loss = 26.154 acc = 0.931 cr = 93.75\n",
      "\n",
      "Step 8499; Epoch 22/30:  289/391 Time: 0.09 cls_loss = 26.102 acc = 0.931 cr = 93.75\n",
      "\n",
      "Step 8599; Epoch 22/30:  389/391 Time: 0.09 cls_loss = 26.067 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 8699; Epoch 23/30:  98/391 Time: 0.09 cls_loss = 25.977 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 8799; Epoch 23/30:  198/391 Time: 0.09 cls_loss = 25.978 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 8899; Epoch 23/30:  298/391 Time: 0.08 cls_loss = 25.949 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 8999; Epoch 24/30:  7/391 Time: 0.09 cls_loss = 25.881 acc = 0.935 cr = 93.75\n",
      "\n",
      "Step 9099; Epoch 24/30:  107/391 Time: 0.08 cls_loss = 25.876 acc = 0.939 cr = 93.75\n",
      "\n",
      "Step 9199; Epoch 24/30:  207/391 Time: 0.08 cls_loss = 25.853 acc = 0.937 cr = 93.75\n",
      "\n",
      "Step 9299; Epoch 24/30:  307/391 Time: 0.09 cls_loss = 25.814 acc = 0.937 cr = 93.75\n",
      "\n",
      "Step 9399; Epoch 25/30:  16/391 Time: 0.09 cls_loss = 25.753 acc = 0.936 cr = 93.75\n",
      "\n",
      "Step 9499; Epoch 25/30:  116/391 Time: 0.08 cls_loss = 25.769 acc = 0.937 cr = 93.75\n",
      "\n",
      "Step 9599; Epoch 25/30:  216/391 Time: 0.09 cls_loss = 25.752 acc = 0.934 cr = 93.75\n",
      "\n",
      "Step 9699; Epoch 25/30:  316/391 Time: 0.09 cls_loss = 25.721 acc = 0.938 cr = 93.75\n",
      "\n",
      "Step 9799; Epoch 26/30:  25/391 Time: 0.09 cls_loss = 25.649 acc = 0.946 cr = 93.75\n",
      "\n",
      "Step 9899; Epoch 26/30:  125/391 Time: 0.09 cls_loss = 25.663 acc = 0.946 cr = 93.75\n",
      "\n",
      "Step 9999; Epoch 26/30:  225/391 Time: 0.09 cls_loss = 25.661 acc = 0.942 cr = 93.75\n",
      "\n",
      "Step 10099; Epoch 26/30:  325/391 Time: 0.09 cls_loss = 25.641 acc = 0.942 cr = 93.75\n",
      "\n",
      "Step 10199; Epoch 27/30:  34/391 Time: 0.09 cls_loss = 25.596 acc = 0.938 cr = 93.75\n",
      "\n",
      "Step 10299; Epoch 27/30:  134/391 Time: 0.09 cls_loss = 25.614 acc = 0.943 cr = 93.75\n",
      "\n",
      "Step 10399; Epoch 27/30:  234/391 Time: 0.09 cls_loss = 25.593 acc = 0.943 cr = 93.75\n",
      "\n",
      "Step 10499; Epoch 27/30:  334/391 Time: 0.09 cls_loss = 25.582 acc = 0.944 cr = 93.75\n",
      "\n",
      "Step 10599; Epoch 28/30:  43/391 Time: 0.09 cls_loss = 25.538 acc = 0.943 cr = 93.75\n",
      "\n",
      "Step 10699; Epoch 28/30:  143/391 Time: 0.09 cls_loss = 25.569 acc = 0.941 cr = 93.75\n",
      "\n",
      "Step 10799; Epoch 28/30:  243/391 Time: 0.09 cls_loss = 25.552 acc = 0.946 cr = 93.75\n",
      "\n",
      "Step 10899; Epoch 28/30:  343/391 Time: 0.09 cls_loss = 25.538 acc = 0.949 cr = 93.75\n",
      "\n",
      "Step 10999; Epoch 29/30:  52/391 Time: 0.09 cls_loss = 25.499 acc = 0.947 cr = 93.75\n",
      "\n",
      "Step 11099; Epoch 29/30:  152/391 Time: 0.09 cls_loss = 25.531 acc = 0.947 cr = 93.75\n",
      "\n",
      "Step 11199; Epoch 29/30:  252/391 Time: 0.09 cls_loss = 25.529 acc = 0.946 cr = 93.75\n",
      "\n",
      "Step 11299; Epoch 29/30:  352/391 Time: 0.09 cls_loss = 25.525 acc = 0.945 cr = 93.75\n",
      "\n",
      "Step 11399; Epoch 30/30:  61/391 Time: 0.09 cls_loss = 25.481 acc = 0.946 cr = 93.75\n",
      "\n",
      "Step 11499; Epoch 30/30:  161/391 Time: 0.09 cls_loss = 25.508 acc = 0.946 cr = 93.75\n",
      "\n",
      "Step 11599; Epoch 30/30:  261/391 Time: 0.09 cls_loss = 25.511 acc = 0.945 cr = 93.75\n",
      "\n",
      "Step 11699; Epoch 30/30:  361/391 Time: 0.08 cls_loss = 25.505 acc = 0.948 cr = 93.75\n",
      "\n",
      "Finished rank adaptive training, start finetuning\n",
      "Step 11799; Epoch 1/2:  70/391 Time: 0.08 cls_loss = 23.214 acc = 0.944 cr = 94.19\n",
      "\n",
      "Step 11899; Epoch 1/2:  170/391 Time: 0.09 cls_loss = 22.218 acc = 0.948 cr = 94.19\n",
      "\n",
      "Step 11999; Epoch 1/2:  270/391 Time: 0.09 cls_loss = 22.219 acc = 0.946 cr = 94.19\n",
      "\n",
      "Step 12099; Epoch 1/2:  370/391 Time: 0.09 cls_loss = 22.211 acc = 0.950 cr = 94.19\n",
      "\n",
      "Step 12199; Epoch 2/2:  79/391 Time: 0.09 cls_loss = 22.225 acc = 0.942 cr = 94.19\n",
      "\n",
      "Step 12299; Epoch 2/2:  179/391 Time: 0.09 cls_loss = 22.221 acc = 0.945 cr = 94.19\n",
      "\n",
      "Step 12399; Epoch 2/2:  279/391 Time: 0.09 cls_loss = 22.208 acc = 0.949 cr = 94.19\n",
      "\n",
      "Step 12499; Epoch 2/2:  379/391 Time: 0.09 cls_loss = 22.214 acc = 0.947 cr = 94.19\n",
      "\n",
      "Save Model at ./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.15_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Epoch[2]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[2]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[2]-Validation-[30/79] Batch OA: 86.72 %\n",
      "Epoch[2]-Validation-[40/79] Batch OA: 89.06 %\n",
      "Epoch[2]-Validation-[50/79] Batch OA: 95.31 %\n",
      "Epoch[2]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[2]-Validation-[70/79] Batch OA: 90.62 %\n",
      "---------------Accuracy of     airplane : 92.40 %---------------\n",
      "---------------Accuracy of   automobile : 95.70 %---------------\n",
      "---------------Accuracy of         bird : 86.50 %---------------\n",
      "---------------Accuracy of          cat : 74.70 %---------------\n",
      "---------------Accuracy of         deer : 89.00 %---------------\n",
      "---------------Accuracy of          dog : 82.10 %---------------\n",
      "---------------Accuracy of         frog : 92.80 %---------------\n",
      "---------------Accuracy of        horse : 91.00 %---------------\n",
      "---------------Accuracy of         ship : 94.30 %---------------\n",
      "---------------Accuracy of        truck : 93.60 %---------------\n",
      "---------------Epoch[2]Validation-OA: 89.21 %---------------\n",
      "---------------Epoch[2]Validation-AA: 89.21 %---------------\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.15_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Low rank layer ranks\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[62, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[115, 64, 3, 3]\n",
      "[123, 109, 3, 3]\n",
      "[125, 115, 3, 3]\n",
      "[146, 106, 3, 3]\n",
      "[143, 128, 3, 3]\n",
      "[142, 132, 3, 3]\n",
      "[147, 130, 3, 3]\n",
      "[140, 131, 3, 3]\n",
      "[133, 133, 3, 3]\n",
      "146\n",
      "137\n",
      "------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [02:52<00:00, 58.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/low_rank/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[62, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[115, 64, 3, 3]\n",
      "[123, 109, 3, 3]\n",
      "[125, 115, 3, 3]\n",
      "[146, 106, 3, 3]\n",
      "[143, 128, 3, 3]\n",
      "[142, 132, 3, 3]\n",
      "[147, 130, 3, 3]\n",
      "[140, 131, 3, 3]\n",
      "[133, 133, 3, 3]\n",
      "146\n",
      "137\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 86.72 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 95.31 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 90.62 %\n",
      "---------------Accuracy of     airplane : 92.40 %---------------\n",
      "---------------Accuracy of   automobile : 95.70 %---------------\n",
      "---------------Accuracy of         bird : 86.50 %---------------\n",
      "---------------Accuracy of          cat : 74.70 %---------------\n",
      "---------------Accuracy of         deer : 89.00 %---------------\n",
      "---------------Accuracy of          dog : 82.10 %---------------\n",
      "---------------Accuracy of         frog : 92.80 %---------------\n",
      "---------------Accuracy of        horse : 91.00 %---------------\n",
      "---------------Accuracy of         ship : 94.30 %---------------\n",
      "---------------Accuracy of        truck : 93.60 %---------------\n",
      "---------------Epoch[1]Validation-OA: 89.21 %---------------\n",
      "---------------Epoch[1]Validation-AA: 89.21 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 75.78 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 67.97 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 78.91 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 77.34 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 84.38 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 77.34 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 78.91 %\n",
      "---------------Accuracy of     airplane : 80.40 %---------------\n",
      "---------------Accuracy of   automobile : 90.30 %---------------\n",
      "---------------Accuracy of         bird : 69.60 %---------------\n",
      "---------------Accuracy of          cat : 53.00 %---------------\n",
      "---------------Accuracy of         deer : 68.60 %---------------\n",
      "---------------Accuracy of          dog : 68.60 %---------------\n",
      "---------------Accuracy of         frog : 83.60 %---------------\n",
      "---------------Accuracy of        horse : 81.20 %---------------\n",
      "---------------Accuracy of         ship : 85.60 %---------------\n",
      "---------------Accuracy of        truck : 85.80 %---------------\n",
      "---------------Epoch[1]Validation-OA: 76.67 %---------------\n",
      "---------------Epoch[1]Validation-AA: 76.67 %---------------\n",
      "Clean Test Set OA: 89.21\n",
      "condlr_fgsm Test Set OA: 76.67\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.15_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Low rank layer ranks\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[62, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[115, 64, 3, 3]\n",
      "[123, 109, 3, 3]\n",
      "[125, 115, 3, 3]\n",
      "[146, 106, 3, 3]\n",
      "[143, 128, 3, 3]\n",
      "[142, 132, 3, 3]\n",
      "[147, 130, 3, 3]\n",
      "[140, 131, 3, 3]\n",
      "[133, 133, 3, 3]\n",
      "146\n",
      "137\n",
      "------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [03:06<00:00, 53.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/low_rank/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[62, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[115, 64, 3, 3]\n",
      "[123, 109, 3, 3]\n",
      "[125, 115, 3, 3]\n",
      "[146, 106, 3, 3]\n",
      "[143, 128, 3, 3]\n",
      "[142, 132, 3, 3]\n",
      "[147, 130, 3, 3]\n",
      "[140, 131, 3, 3]\n",
      "[133, 133, 3, 3]\n",
      "146\n",
      "137\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 86.72 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 95.31 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 90.62 %\n",
      "---------------Accuracy of     airplane : 92.40 %---------------\n",
      "---------------Accuracy of   automobile : 95.70 %---------------\n",
      "---------------Accuracy of         bird : 86.50 %---------------\n",
      "---------------Accuracy of          cat : 74.70 %---------------\n",
      "---------------Accuracy of         deer : 89.00 %---------------\n",
      "---------------Accuracy of          dog : 82.10 %---------------\n",
      "---------------Accuracy of         frog : 92.80 %---------------\n",
      "---------------Accuracy of        horse : 91.00 %---------------\n",
      "---------------Accuracy of         ship : 94.30 %---------------\n",
      "---------------Accuracy of        truck : 93.60 %---------------\n",
      "---------------Epoch[1]Validation-OA: 89.21 %---------------\n",
      "---------------Epoch[1]Validation-AA: 89.21 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 57.81 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 57.81 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 64.84 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 59.38 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 73.44 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 60.16 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 64.84 %\n",
      "---------------Accuracy of     airplane : 66.10 %---------------\n",
      "---------------Accuracy of   automobile : 82.90 %---------------\n",
      "---------------Accuracy of         bird : 52.90 %---------------\n",
      "---------------Accuracy of          cat : 35.40 %---------------\n",
      "---------------Accuracy of         deer : 47.80 %---------------\n",
      "---------------Accuracy of          dog : 54.20 %---------------\n",
      "---------------Accuracy of         frog : 69.90 %---------------\n",
      "---------------Accuracy of        horse : 68.40 %---------------\n",
      "---------------Accuracy of         ship : 73.90 %---------------\n",
      "---------------Accuracy of        truck : 75.40 %---------------\n",
      "---------------Epoch[1]Validation-OA: 62.69 %---------------\n",
      "---------------Epoch[1]Validation-AA: 62.69 %---------------\n",
      "Clean Test Set OA: 89.21\n",
      "condlr_fgsm Test Set OA: 62.69\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/low_rank/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./model_lr/Cifar10/Pretrain/low_rank/vgg16/beta_0.15_tol_0.05_rmax_150.0_init_rank_50.0.pth\n",
      "Low rank layer ranks\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[62, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[115, 64, 3, 3]\n",
      "[123, 109, 3, 3]\n",
      "[125, 115, 3, 3]\n",
      "[146, 106, 3, 3]\n",
      "[143, 128, 3, 3]\n",
      "[142, 132, 3, 3]\n",
      "[147, 130, 3, 3]\n",
      "[140, 131, 3, 3]\n",
      "[133, 133, 3, 3]\n",
      "146\n",
      "137\n",
      "------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [03:22<00:00, 49.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/low_rank/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "[64, 3, 3, 3]\n",
      "10\n",
      "[32, 30, 3, 3]\n",
      "[62, 32, 3, 3]\n",
      "[64, 64, 3, 3]\n",
      "[115, 64, 3, 3]\n",
      "[123, 109, 3, 3]\n",
      "[125, 115, 3, 3]\n",
      "[146, 106, 3, 3]\n",
      "[143, 128, 3, 3]\n",
      "[142, 132, 3, 3]\n",
      "[147, 130, 3, 3]\n",
      "[140, 131, 3, 3]\n",
      "[133, 133, 3, 3]\n",
      "146\n",
      "137\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 89.84 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 85.16 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 86.72 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 95.31 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 87.50 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 90.62 %\n",
      "---------------Accuracy of     airplane : 92.40 %---------------\n",
      "---------------Accuracy of   automobile : 95.70 %---------------\n",
      "---------------Accuracy of         bird : 86.50 %---------------\n",
      "---------------Accuracy of          cat : 74.70 %---------------\n",
      "---------------Accuracy of         deer : 89.00 %---------------\n",
      "---------------Accuracy of          dog : 82.10 %---------------\n",
      "---------------Accuracy of         frog : 92.80 %---------------\n",
      "---------------Accuracy of        horse : 91.00 %---------------\n",
      "---------------Accuracy of         ship : 94.30 %---------------\n",
      "---------------Accuracy of        truck : 93.60 %---------------\n",
      "---------------Epoch[1]Validation-OA: 89.21 %---------------\n",
      "---------------Epoch[1]Validation-AA: 89.21 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 48.44 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 50.00 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 51.56 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 46.09 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 53.91 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 43.75 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 51.56 %\n",
      "---------------Accuracy of     airplane : 52.50 %---------------\n",
      "---------------Accuracy of   automobile : 73.10 %---------------\n",
      "---------------Accuracy of         bird : 39.20 %---------------\n",
      "---------------Accuracy of          cat : 22.60 %---------------\n",
      "---------------Accuracy of         deer : 29.60 %---------------\n",
      "---------------Accuracy of          dog : 42.10 %---------------\n",
      "---------------Accuracy of         frog : 55.20 %---------------\n",
      "---------------Accuracy of        horse : 55.30 %---------------\n",
      "---------------Accuracy of         ship : 61.90 %---------------\n",
      "---------------Accuracy of        truck : 62.70 %---------------\n",
      "---------------Epoch[1]Validation-OA: 49.42 %---------------\n",
      "---------------Epoch[1]Validation-AA: 49.42 %---------------\n",
      "Clean Test Set OA: 89.21\n",
      "condlr_fgsm Test Set OA: 49.419999999999995\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "# wandb login --relogin\n",
    "\n",
    "# ---- What to run ----\n",
    "run_train=true   # set to false to attack the pre-trained networks only\n",
    "run_attack=true\n",
    "\n",
    "# Number of repetitions of the whole sweep\n",
    "iterations=\"1\"\n",
    "\n",
    "# ---- Regularization ----\n",
    "# Robustness regularization strengths to sweep (further candidates: 0.05 0.1 0.2)\n",
    "regularizer_betas=\"0.0 0.15\"\n",
    "\n",
    "# ---- Adversarial attack ----\n",
    "attack_f_1=\"condlr_fgsm\"\n",
    "epsilons_1=\"0.002 0.004 0.006\"\n",
    "\n",
    "# ---- Model settings ----\n",
    "models=\"vgg16\"          # surrogate network(s) passed as --network / --surrogate_network\n",
    "target_model=\"vgg16\"    # passed as --target_network to the test script\n",
    "dataset=3               # 1 = UCM, 2 = AID, 3 = Cifar10\n",
    "crop_size=32\n",
    "\n",
    "# ---- DLRT parameters (good default values) ----\n",
    "tolerances=\"0.05\"\n",
    "rmax=150\n",
    "num_local_iter=10\n",
    "r_init=50\n",
    "\n",
    "# ---- Training parameters ----\n",
    "train_batch_size=128\n",
    "val_batch_size=128\n",
    "learning_rate=1e-4\n",
    "weight_decay=0.0\n",
    "num_epochs=30\n",
    "num_epochs_ft=2         # passed as --num_epochs_low_rank_ft\n",
    "print_per_batches=100\n",
    "\n",
    "# ---- Paths ----\n",
    "model_dir=\"./model_lr/\"                   # checkpoint prefix for training, attack and test\n",
    "adv_root=\"./dataset/data_adversarial_rs\"  # attacked-image folders cleaned before each attack\n",
    "\n",
    "# ---- Logging ----\n",
    "wandb=0                 # 1 for enabled, 0 for disabled\n",
    "wandb_tag=\"google_colab_example\"\n",
    "\n",
    "# Folder name of the attacked images for the selected dataset\n",
    "case \"$dataset\" in\n",
    "    1) adv_dataset_dir=\"UCM_adv\" ;;\n",
    "    2) adv_dataset_dir=\"AID_adv\" ;;\n",
    "    3) adv_dataset_dir=\"Cifar10_adv\" ;;\n",
    "    *) adv_dataset_dir=\"\" ;;\n",
    "esac\n",
    "\n",
    "# Loop over all combinations of iterations, models, tolerances and regularization betas\n",
    "for i in $(seq 1 $iterations); do\n",
    "  echo \"iteration $i\"\n",
    "  for model in $models; do\n",
    "    for tol in $tolerances; do\n",
    "      for beta in $regularizer_betas; do\n",
    "        if [ \"$run_train\" = true ]; then\n",
    "          python pretrain_cls_low_rank_robustness.py \\\n",
    "            --dataID $dataset \\\n",
    "            --num_local_iter $num_local_iter \\\n",
    "            --rmax $rmax \\\n",
    "            --init_r $r_init \\\n",
    "            --lr $learning_rate \\\n",
    "            --num_epochs $num_epochs \\\n",
    "            --num_epochs_low_rank_ft $num_epochs_ft \\\n",
    "            --network $model \\\n",
    "            --save_name \"$beta\" \\\n",
    "            --tol \"$tol\" \\\n",
    "            --wandb $wandb \\\n",
    "            --wandb_tag \"$wandb_tag\" \\\n",
    "            --robusteness_regularization_beta \"$beta\" \\\n",
    "            --val_batch_size $val_batch_size \\\n",
    "            --train_batch_size $train_batch_size \\\n",
    "            --crop_size $crop_size \\\n",
    "            --weight_decay $weight_decay \\\n",
    "            --print_per_batches $print_per_batches \\\n",
    "            --load_model 0 \\\n",
    "            --root_dir ./ \\\n",
    "            --save_path_prefix \"$model_dir\"\n",
    "        fi\n",
    "\n",
    "        if [ \"$run_attack\" = true ]; then\n",
    "          for eps_attack in $epsilons_1; do\n",
    "            echo \"Delete all attacked images for a fresh start\"\n",
    "            if [ -n \"$adv_dataset_dir\" ]; then\n",
    "              adv_dir=\"$adv_root/$adv_dataset_dir/$attack_f_1/low_rank/$model\"\n",
    "              echo \"Deleting $adv_dir/*.png\"\n",
    "              # Only call rm when the glob matches, instead of letting rm print\n",
    "              # \"No such file or directory\" for an unmatched pattern.\n",
    "              # NOTE(review): in the recorded output this folder never contains images,\n",
    "              # even after an attack has run, while the test script logs\n",
    "              # ./Cifar10_adv/... -- confirm where attack_cls_low_rank.py saves them.\n",
    "              if compgen -G \"$adv_dir/*.png\" > /dev/null; then\n",
    "                rm \"$adv_dir\"/*.png\n",
    "              else\n",
    "                echo \"No attacked images found in $adv_dir\"\n",
    "              fi\n",
    "            fi\n",
    "\n",
    "            python attack_cls_low_rank.py \\\n",
    "              --dataID $dataset \\\n",
    "              --rmax $rmax \\\n",
    "              --init_r $r_init \\\n",
    "              --tol \"$tol\" \\\n",
    "              --robusteness_regularization_beta \"$beta\" \\\n",
    "              --attack_func $attack_f_1 \\\n",
    "              --network $model \\\n",
    "              --crop_size $crop_size \\\n",
    "              --epsilon $eps_attack \\\n",
    "              --save_path_prefix ./ \\\n",
    "              --model_root_dir \"$model_dir\"\n",
    "\n",
    "            python test_cls_low_rank.py \\\n",
    "              --dataID $dataset \\\n",
    "              --rmax $rmax \\\n",
    "              --init_r $r_init \\\n",
    "              --tol \"$tol\" \\\n",
    "              --robusteness_regularization_beta \"$beta\" \\\n",
    "              --target_network $target_model \\\n",
    "              --surrogate_network $model \\\n",
    "              --attack_func $attack_f_1 \\\n",
    "              --wandb $wandb \\\n",
    "              --wandb_tag \"$wandb_tag\" \\\n",
    "              --attack_epsilon $eps_attack \\\n",
    "              --crop_size $crop_size \\\n",
    "              --val_batch_size $val_batch_size \\\n",
    "              --root_dir ./ \\\n",
    "              --save_path_prefix \"$model_dir\"\n",
    "          done\n",
    "        fi\n",
    "      done\n",
    "    done\n",
    "  done\n",
    "done\n"
   ]
  },
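  {
   "cell_type": "markdown",
   "id": "5e740383-6622-4361-b34d-4f8240cfade8",
   "metadata": {},
   "source": [
    "3) Run the full-rank baseline models for comparison\n",
    "\n",
    "These networks use standard (full-rank) layers and are saved under `./models/`, so their accuracy can be compared with the low-rank networks trained above."
   ]
  },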
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6da5aa80-3c6d-407b-98bf-d4a19f961f3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration 1\n",
      "[INFO] Saving path for the pre-trained model:  ./models/Cifar10/Pretrain/baseline/vgg16/\n",
      "VGG(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (3): ReLU(inplace=True)\n",
      "    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (6): ReLU(inplace=True)\n",
      "    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (8): ReLU(inplace=True)\n",
      "    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU(inplace=True)\n",
      "    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (13): ReLU(inplace=True)\n",
      "    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (15): ReLU(inplace=True)\n",
      "    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (18): ReLU(inplace=True)\n",
      "    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (20): ReLU(inplace=True)\n",
      "    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (22): ReLU(inplace=True)\n",
      "    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (25): ReLU(inplace=True)\n",
      "    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (27): ReLU(inplace=True)\n",
      "    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (29): ReLU(inplace=True)\n",
      "    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "  (classifier): Sequential(\n",
      "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Dropout(p=0.5, inplace=False)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU(inplace=True)\n",
      "    (5): Dropout(p=0.5, inplace=False)\n",
      "    (6): Linear(in_features=4096, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/8v5/Desktop/adversarial_rs_low_rank/venv/lib/python3.10/site-packages/numpy/_core/fromnumeric.py:3860: RuntimeWarning: Mean of empty slice.\n",
      "  return _methods._mean(a, axis=axis, dtype=dtype,\n",
      "/home/8v5/Desktop/adversarial_rs_low_rank/venv/lib/python3.10/site-packages/numpy/_core/_methods.py:145: RuntimeWarning: invalid value encountered in scalar divide\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 99; Epoch 1/30:  99/391 Time: nan cls_loss = nan acc = nan\n",
      "\n",
      "Step 199; Epoch 1/30:  199/391 Time: 0.02 cls_loss = 0.734 acc = 0.757\n",
      "\n",
      "Step 299; Epoch 1/30:  299/391 Time: 0.02 cls_loss = 0.635 acc = 0.788\n",
      "\n",
      "Step 399; Epoch 2/30:  8/391 Time: 0.02 cls_loss = 0.564 acc = 0.812\n",
      "\n",
      "Step 499; Epoch 2/30:  108/391 Time: 0.02 cls_loss = 0.493 acc = 0.834\n",
      "\n",
      "Step 599; Epoch 2/30:  208/391 Time: 0.02 cls_loss = 0.454 acc = 0.851\n",
      "\n",
      "Step 699; Epoch 2/30:  308/391 Time: 0.02 cls_loss = 0.452 acc = 0.846\n",
      "\n",
      "Step 799; Epoch 3/30:  17/391 Time: 0.02 cls_loss = 0.435 acc = 0.852\n",
      "\n",
      "Step 899; Epoch 3/30:  117/391 Time: 0.02 cls_loss = 0.363 acc = 0.878\n",
      "\n",
      "Step 999; Epoch 3/30:  217/391 Time: 0.02 cls_loss = 0.360 acc = 0.882\n",
      "\n",
      "Step 1099; Epoch 3/30:  317/391 Time: 0.02 cls_loss = 0.361 acc = 0.879\n",
      "\n",
      "Step 1199; Epoch 4/30:  26/391 Time: 0.02 cls_loss = 0.316 acc = 0.893\n",
      "\n",
      "Step 1299; Epoch 4/30:  126/391 Time: 0.02 cls_loss = 0.284 acc = 0.902\n",
      "\n",
      "Step 1399; Epoch 4/30:  226/391 Time: 0.02 cls_loss = 0.290 acc = 0.902\n",
      "\n",
      "Step 1499; Epoch 4/30:  326/391 Time: 0.02 cls_loss = 0.288 acc = 0.906\n",
      "\n",
      "Step 1599; Epoch 5/30:  35/391 Time: 0.02 cls_loss = 0.267 acc = 0.911\n",
      "\n",
      "Step 1699; Epoch 5/30:  135/391 Time: 0.02 cls_loss = 0.233 acc = 0.923\n",
      "\n",
      "Step 1799; Epoch 5/30:  235/391 Time: 0.02 cls_loss = 0.227 acc = 0.923\n",
      "\n",
      "Step 1899; Epoch 5/30:  335/391 Time: 0.02 cls_loss = 0.248 acc = 0.916\n",
      "\n",
      "Step 1999; Epoch 6/30:  44/391 Time: 0.02 cls_loss = 0.220 acc = 0.925\n",
      "\n",
      "Step 2099; Epoch 6/30:  144/391 Time: 0.02 cls_loss = 0.197 acc = 0.934\n",
      "\n",
      "Step 2199; Epoch 6/30:  244/391 Time: 0.02 cls_loss = 0.204 acc = 0.931\n",
      "\n",
      "Step 2299; Epoch 6/30:  344/391 Time: 0.02 cls_loss = 0.180 acc = 0.938\n",
      "\n",
      "Step 2399; Epoch 7/30:  53/391 Time: 0.02 cls_loss = 0.178 acc = 0.940\n",
      "\n",
      "Step 2499; Epoch 7/30:  153/391 Time: 0.02 cls_loss = 0.174 acc = 0.942\n",
      "\n",
      "Step 2599; Epoch 7/30:  253/391 Time: 0.02 cls_loss = 0.172 acc = 0.942\n",
      "\n",
      "Step 2699; Epoch 7/30:  353/391 Time: 0.02 cls_loss = 0.164 acc = 0.947\n",
      "\n",
      "Step 2799; Epoch 8/30:  62/391 Time: 0.02 cls_loss = 0.141 acc = 0.953\n",
      "\n",
      "Step 2899; Epoch 8/30:  162/391 Time: 0.02 cls_loss = 0.146 acc = 0.951\n",
      "\n",
      "Step 2999; Epoch 8/30:  262/391 Time: 0.02 cls_loss = 0.136 acc = 0.953\n",
      "\n",
      "Step 3099; Epoch 8/30:  362/391 Time: 0.02 cls_loss = 0.138 acc = 0.953\n",
      "\n",
      "Step 3199; Epoch 9/30:  71/391 Time: 0.02 cls_loss = 0.125 acc = 0.959\n",
      "\n",
      "Step 3299; Epoch 9/30:  171/391 Time: 0.02 cls_loss = 0.129 acc = 0.958\n",
      "\n",
      "Step 3399; Epoch 9/30:  271/391 Time: 0.02 cls_loss = 0.114 acc = 0.960\n",
      "\n",
      "Step 3499; Epoch 9/30:  371/391 Time: 0.02 cls_loss = 0.114 acc = 0.961\n",
      "\n",
      "Step 3599; Epoch 10/30:  80/391 Time: 0.02 cls_loss = 0.093 acc = 0.968\n",
      "\n",
      "Step 3699; Epoch 10/30:  180/391 Time: 0.02 cls_loss = 0.096 acc = 0.968\n",
      "\n",
      "Step 3799; Epoch 10/30:  280/391 Time: 0.02 cls_loss = 0.098 acc = 0.968\n",
      "\n",
      "Step 3899; Epoch 10/30:  380/391 Time: 0.02 cls_loss = 0.101 acc = 0.967\n",
      "\n",
      "Step 3999; Epoch 11/30:  89/391 Time: 0.02 cls_loss = 0.074 acc = 0.974\n",
      "\n",
      "Step 4099; Epoch 11/30:  189/391 Time: 0.02 cls_loss = 0.086 acc = 0.971\n",
      "\n",
      "Step 4199; Epoch 11/30:  289/391 Time: 0.02 cls_loss = 0.086 acc = 0.972\n",
      "\n",
      "Step 4299; Epoch 11/30:  389/391 Time: 0.02 cls_loss = 0.084 acc = 0.973\n",
      "\n",
      "Step 4399; Epoch 12/30:  98/391 Time: 0.02 cls_loss = 0.068 acc = 0.978\n",
      "\n",
      "Step 4499; Epoch 12/30:  198/391 Time: 0.02 cls_loss = 0.068 acc = 0.978\n",
      "\n",
      "Step 4599; Epoch 12/30:  298/391 Time: 0.02 cls_loss = 0.085 acc = 0.972\n",
      "\n",
      "Step 4699; Epoch 13/30:  7/391 Time: 0.02 cls_loss = 0.071 acc = 0.975\n",
      "\n",
      "Step 4799; Epoch 13/30:  107/391 Time: 0.02 cls_loss = 0.055 acc = 0.981\n",
      "\n",
      "Step 4899; Epoch 13/30:  207/391 Time: 0.02 cls_loss = 0.062 acc = 0.980\n",
      "\n",
      "Step 4999; Epoch 13/30:  307/391 Time: 0.02 cls_loss = 0.054 acc = 0.982\n",
      "\n",
      "Step 5099; Epoch 14/30:  16/391 Time: 0.02 cls_loss = 0.063 acc = 0.981\n",
      "\n",
      "Step 5199; Epoch 14/30:  116/391 Time: 0.02 cls_loss = 0.055 acc = 0.982\n",
      "\n",
      "Step 5299; Epoch 14/30:  216/391 Time: 0.02 cls_loss = 0.057 acc = 0.982\n",
      "\n",
      "Step 5399; Epoch 14/30:  316/391 Time: 0.02 cls_loss = 0.047 acc = 0.984\n",
      "\n",
      "Step 5499; Epoch 15/30:  25/391 Time: 0.02 cls_loss = 0.051 acc = 0.984\n",
      "\n",
      "Step 5599; Epoch 15/30:  125/391 Time: 0.02 cls_loss = 0.037 acc = 0.989\n",
      "\n",
      "Step 5699; Epoch 15/30:  225/391 Time: 0.02 cls_loss = 0.044 acc = 0.985\n",
      "\n",
      "Step 5799; Epoch 15/30:  325/391 Time: 0.02 cls_loss = 0.045 acc = 0.987\n",
      "\n",
      "Step 5899; Epoch 16/30:  34/391 Time: 0.02 cls_loss = 0.040 acc = 0.987\n",
      "\n",
      "Step 5999; Epoch 16/30:  134/391 Time: 0.02 cls_loss = 0.037 acc = 0.988\n",
      "\n",
      "Step 6099; Epoch 16/30:  234/391 Time: 0.02 cls_loss = 0.036 acc = 0.988\n",
      "\n",
      "Step 6199; Epoch 16/30:  334/391 Time: 0.02 cls_loss = 0.036 acc = 0.988\n",
      "\n",
      "Step 6299; Epoch 17/30:  43/391 Time: 0.02 cls_loss = 0.034 acc = 0.989\n",
      "\n",
      "Step 6399; Epoch 17/30:  143/391 Time: 0.02 cls_loss = 0.032 acc = 0.990\n",
      "\n",
      "Step 6499; Epoch 17/30:  243/391 Time: 0.02 cls_loss = 0.031 acc = 0.991\n",
      "\n",
      "Step 6599; Epoch 17/30:  343/391 Time: 0.02 cls_loss = 0.029 acc = 0.990\n",
      "\n",
      "Step 6699; Epoch 18/30:  52/391 Time: 0.02 cls_loss = 0.033 acc = 0.989\n",
      "\n",
      "Step 6799; Epoch 18/30:  152/391 Time: 0.02 cls_loss = 0.027 acc = 0.992\n",
      "\n",
      "Step 6899; Epoch 18/30:  252/391 Time: 0.02 cls_loss = 0.024 acc = 0.992\n",
      "\n",
      "Step 6999; Epoch 18/30:  352/391 Time: 0.02 cls_loss = 0.027 acc = 0.992\n",
      "\n",
      "Step 7099; Epoch 19/30:  61/391 Time: 0.02 cls_loss = 0.021 acc = 0.993\n",
      "\n",
      "Step 7199; Epoch 19/30:  161/391 Time: 0.02 cls_loss = 0.026 acc = 0.992\n",
      "\n",
      "Step 7299; Epoch 19/30:  261/391 Time: 0.02 cls_loss = 0.020 acc = 0.993\n",
      "\n",
      "Step 7399; Epoch 19/30:  361/391 Time: 0.02 cls_loss = 0.025 acc = 0.992\n",
      "\n",
      "Step 7499; Epoch 20/30:  70/391 Time: 0.02 cls_loss = 0.019 acc = 0.993\n",
      "\n",
      "Step 7599; Epoch 20/30:  170/391 Time: 0.02 cls_loss = 0.022 acc = 0.993\n",
      "\n",
      "Step 7699; Epoch 20/30:  270/391 Time: 0.02 cls_loss = 0.013 acc = 0.996\n",
      "\n",
      "Step 7799; Epoch 20/30:  370/391 Time: 0.02 cls_loss = 0.014 acc = 0.996\n",
      "\n",
      "Step 7899; Epoch 21/30:  79/391 Time: 0.02 cls_loss = 0.012 acc = 0.996\n",
      "\n",
      "Step 7999; Epoch 21/30:  179/391 Time: 0.02 cls_loss = 0.017 acc = 0.995\n",
      "\n",
      "Step 8099; Epoch 21/30:  279/391 Time: 0.02 cls_loss = 0.014 acc = 0.996\n",
      "\n",
      "Step 8199; Epoch 21/30:  379/391 Time: 0.02 cls_loss = 0.016 acc = 0.995\n",
      "\n",
      "Step 8299; Epoch 22/30:  88/391 Time: 0.02 cls_loss = 0.013 acc = 0.997\n",
      "\n",
      "Step 8399; Epoch 22/30:  188/391 Time: 0.02 cls_loss = 0.010 acc = 0.997\n",
      "\n",
      "Step 8499; Epoch 22/30:  288/391 Time: 0.02 cls_loss = 0.011 acc = 0.997\n",
      "\n",
      "Step 8599; Epoch 22/30:  388/391 Time: 0.02 cls_loss = 0.011 acc = 0.997\n",
      "\n",
      "Step 8699; Epoch 23/30:  97/391 Time: 0.02 cls_loss = 0.008 acc = 0.997\n",
      "\n",
      "Step 8799; Epoch 23/30:  197/391 Time: 0.02 cls_loss = 0.009 acc = 0.996\n",
      "\n",
      "Step 8899; Epoch 23/30:  297/391 Time: 0.02 cls_loss = 0.008 acc = 0.997\n",
      "\n",
      "Step 8999; Epoch 24/30:  6/391 Time: 0.02 cls_loss = 0.010 acc = 0.997\n",
      "\n",
      "Step 9099; Epoch 24/30:  106/391 Time: 0.02 cls_loss = 0.007 acc = 0.998\n",
      "\n",
      "Step 9199; Epoch 24/30:  206/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 9299; Epoch 24/30:  306/391 Time: 0.02 cls_loss = 0.009 acc = 0.997\n",
      "\n",
      "Step 9399; Epoch 25/30:  15/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 9499; Epoch 25/30:  115/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 9599; Epoch 25/30:  215/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 9699; Epoch 25/30:  315/391 Time: 0.02 cls_loss = 0.008 acc = 0.998\n",
      "\n",
      "Step 9799; Epoch 26/30:  24/391 Time: 0.02 cls_loss = 0.005 acc = 0.998\n",
      "\n",
      "Step 9899; Epoch 26/30:  124/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 9999; Epoch 26/30:  224/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 10099; Epoch 26/30:  324/391 Time: 0.02 cls_loss = 0.005 acc = 0.998\n",
      "\n",
      "Step 10199; Epoch 27/30:  33/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 10299; Epoch 27/30:  133/391 Time: 0.02 cls_loss = 0.007 acc = 0.999\n",
      "\n",
      "Step 10399; Epoch 27/30:  233/391 Time: 0.02 cls_loss = 0.006 acc = 0.998\n",
      "\n",
      "Step 10499; Epoch 27/30:  333/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 10599; Epoch 28/30:  42/391 Time: 0.02 cls_loss = 0.003 acc = 0.999\n",
      "\n",
      "Step 10699; Epoch 28/30:  142/391 Time: 0.02 cls_loss = 0.005 acc = 0.999\n",
      "\n",
      "Step 10799; Epoch 28/30:  242/391 Time: 0.02 cls_loss = 0.005 acc = 0.998\n",
      "\n",
      "Step 10899; Epoch 28/30:  342/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 10999; Epoch 29/30:  51/391 Time: 0.02 cls_loss = 0.003 acc = 0.999\n",
      "\n",
      "Step 11099; Epoch 29/30:  151/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 11199; Epoch 29/30:  251/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 11299; Epoch 29/30:  351/391 Time: 0.02 cls_loss = 0.006 acc = 0.999\n",
      "\n",
      "Step 11399; Epoch 30/30:  60/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 11499; Epoch 30/30:  160/391 Time: 0.02 cls_loss = 0.005 acc = 0.998\n",
      "\n",
      "Step 11599; Epoch 30/30:  260/391 Time: 0.02 cls_loss = 0.004 acc = 0.999\n",
      "\n",
      "Step 11699; Epoch 30/30:  360/391 Time: 0.02 cls_loss = 0.005 acc = 0.999\n",
      "\n",
      "Save Model at ./models/Cifar10/Pretrain/baseline/vgg16/beta_0.0.pth\n",
      "Epoch[30]-Validation-[10/79] Batch OA: 92.97 %\n",
      "Epoch[30]-Validation-[20/79] Batch OA: 92.19 %\n",
      "Epoch[30]-Validation-[30/79] Batch OA: 93.75 %\n",
      "Epoch[30]-Validation-[40/79] Batch OA: 90.62 %\n",
      "Epoch[30]-Validation-[50/79] Batch OA: 96.09 %\n",
      "Epoch[30]-Validation-[60/79] Batch OA: 89.06 %\n",
      "Epoch[30]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 93.30 %---------------\n",
      "---------------Accuracy of   automobile : 96.30 %---------------\n",
      "---------------Accuracy of         bird : 87.90 %---------------\n",
      "---------------Accuracy of          cat : 81.70 %---------------\n",
      "---------------Accuracy of         deer : 90.50 %---------------\n",
      "---------------Accuracy of          dog : 83.40 %---------------\n",
      "---------------Accuracy of         frog : 94.50 %---------------\n",
      "---------------Accuracy of        horse : 92.90 %---------------\n",
      "---------------Accuracy of         ship : 96.00 %---------------\n",
      "---------------Accuracy of        truck : 94.40 %---------------\n",
      "---------------Epoch[30]Validation-OA: 91.09 %---------------\n",
      "---------------Epoch[30]Validation-AA: 91.09 %---------------\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/baseline/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/baseline/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./models/Cifar10/Pretrain/baseline/vgg16/beta_0.0.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [01:03<00:00, 158.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/baseline/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 92.97 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 92.19 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 93.75 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 90.62 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 96.09 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 93.30 %---------------\n",
      "---------------Accuracy of   automobile : 96.30 %---------------\n",
      "---------------Accuracy of         bird : 87.90 %---------------\n",
      "---------------Accuracy of          cat : 81.70 %---------------\n",
      "---------------Accuracy of         deer : 90.50 %---------------\n",
      "---------------Accuracy of          dog : 83.40 %---------------\n",
      "---------------Accuracy of         frog : 94.50 %---------------\n",
      "---------------Accuracy of        horse : 92.90 %---------------\n",
      "---------------Accuracy of         ship : 96.00 %---------------\n",
      "---------------Accuracy of        truck : 94.40 %---------------\n",
      "---------------Epoch[1]Validation-OA: 91.09 %---------------\n",
      "---------------Epoch[1]Validation-AA: 91.09 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 72.66 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 70.31 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 78.91 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 79.69 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 85.94 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 78.91 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 83.59 %\n",
      "---------------Accuracy of     airplane : 78.80 %---------------\n",
      "---------------Accuracy of   automobile : 92.70 %---------------\n",
      "---------------Accuracy of         bird : 71.10 %---------------\n",
      "---------------Accuracy of          cat : 60.70 %---------------\n",
      "---------------Accuracy of         deer : 68.00 %---------------\n",
      "---------------Accuracy of          dog : 68.40 %---------------\n",
      "---------------Accuracy of         frog : 85.70 %---------------\n",
      "---------------Accuracy of        horse : 82.00 %---------------\n",
      "---------------Accuracy of         ship : 87.00 %---------------\n",
      "---------------Accuracy of        truck : 87.90 %---------------\n",
      "---------------Epoch[1]Validation-OA: 78.23 %---------------\n",
      "---------------Epoch[1]Validation-AA: 78.23 %---------------\n",
      "------------\n",
      "Clean Test Set OA: 91.09\n",
      "condlr_fgsm Test Set OA: 78.23\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/baseline/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/baseline/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./models/Cifar10/Pretrain/baseline/vgg16/beta_0.0.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [00:59<00:00, 169.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/baseline/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 92.97 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 92.19 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 93.75 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 90.62 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 96.09 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 93.30 %---------------\n",
      "---------------Accuracy of   automobile : 96.30 %---------------\n",
      "---------------Accuracy of         bird : 87.90 %---------------\n",
      "---------------Accuracy of          cat : 81.70 %---------------\n",
      "---------------Accuracy of         deer : 90.50 %---------------\n",
      "---------------Accuracy of          dog : 83.40 %---------------\n",
      "---------------Accuracy of         frog : 94.50 %---------------\n",
      "---------------Accuracy of        horse : 92.90 %---------------\n",
      "---------------Accuracy of         ship : 96.00 %---------------\n",
      "---------------Accuracy of        truck : 94.40 %---------------\n",
      "---------------Epoch[1]Validation-OA: 91.09 %---------------\n",
      "---------------Epoch[1]Validation-AA: 91.09 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 62.50 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 60.94 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 67.97 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 67.19 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 71.88 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 60.94 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 71.09 %\n",
      "---------------Accuracy of     airplane : 66.50 %---------------\n",
      "---------------Accuracy of   automobile : 84.50 %---------------\n",
      "---------------Accuracy of         bird : 56.00 %---------------\n",
      "---------------Accuracy of          cat : 40.20 %---------------\n",
      "---------------Accuracy of         deer : 50.00 %---------------\n",
      "---------------Accuracy of          dog : 54.70 %---------------\n",
      "---------------Accuracy of         frog : 72.20 %---------------\n",
      "---------------Accuracy of        horse : 67.80 %---------------\n",
      "---------------Accuracy of         ship : 76.50 %---------------\n",
      "---------------Accuracy of        truck : 77.40 %---------------\n",
      "---------------Epoch[1]Validation-OA: 64.58 %---------------\n",
      "---------------Epoch[1]Validation-AA: 64.58 %---------------\n",
      "------------\n",
      "Clean Test Set OA: 91.09\n",
      "condlr_fgsm Test Set OA: 64.58\n",
      "Delete all attacked images for a fresh start\n",
      "Deleting ./dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/baseline/vgg16/*.png\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "rm: cannot remove './dataset/data_adversarial_rs/Cifar10_adv/condlr_fgsm/baseline/vgg16/*.png': No such file or directory\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network: vgg16\n",
      "load model from\n",
      "./models/Cifar10/Pretrain/baseline/vgg16/beta_0.0.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batch: 10000/10000: 100%|██████████| 10000/10000 [00:59<00:00, 166.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "./Cifar10_adv/condlr_fgsm/baseline/vgg16/\n",
      "./dataset/Cifar10_test.txt\n",
      "----\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 92.97 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 92.19 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 93.75 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 90.62 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 96.09 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 89.06 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 93.75 %\n",
      "---------------Accuracy of     airplane : 93.30 %---------------\n",
      "---------------Accuracy of   automobile : 96.30 %---------------\n",
      "---------------Accuracy of         bird : 87.90 %---------------\n",
      "---------------Accuracy of          cat : 81.70 %---------------\n",
      "---------------Accuracy of         deer : 90.50 %---------------\n",
      "---------------Accuracy of          dog : 83.40 %---------------\n",
      "---------------Accuracy of         frog : 94.50 %---------------\n",
      "---------------Accuracy of        horse : 92.90 %---------------\n",
      "---------------Accuracy of         ship : 96.00 %---------------\n",
      "---------------Accuracy of        truck : 94.40 %---------------\n",
      "---------------Epoch[1]Validation-OA: 91.09 %---------------\n",
      "---------------Epoch[1]Validation-AA: 91.09 %---------------\n",
      "Epoch[1]-Validation-[10/79] Batch OA: 49.22 %\n",
      "Epoch[1]-Validation-[20/79] Batch OA: 53.12 %\n",
      "Epoch[1]-Validation-[30/79] Batch OA: 58.59 %\n",
      "Epoch[1]-Validation-[40/79] Batch OA: 57.81 %\n",
      "Epoch[1]-Validation-[50/79] Batch OA: 59.38 %\n",
      "Epoch[1]-Validation-[60/79] Batch OA: 52.34 %\n",
      "Epoch[1]-Validation-[70/79] Batch OA: 57.03 %\n",
      "---------------Accuracy of     airplane : 56.50 %---------------\n",
      "---------------Accuracy of   automobile : 77.90 %---------------\n",
      "---------------Accuracy of         bird : 43.00 %---------------\n",
      "---------------Accuracy of          cat : 29.00 %---------------\n",
      "---------------Accuracy of         deer : 37.90 %---------------\n",
      "---------------Accuracy of          dog : 41.80 %---------------\n",
      "---------------Accuracy of         frog : 59.20 %---------------\n",
      "---------------Accuracy of        horse : 56.90 %---------------\n",
      "---------------Accuracy of         ship : 66.30 %---------------\n",
      "---------------Accuracy of        truck : 67.30 %---------------\n",
      "---------------Epoch[1]Validation-OA: 53.58 %---------------\n",
      "---------------Epoch[1]Validation-AA: 53.58 %---------------\n",
      "------------\n",
      "Clean Test Set OA: 91.09\n",
      "condlr_fgsm Test Set OA: 53.580000000000005\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "# Optionally train each model, then attack it with every epsilon and evaluate\n",
    "# the target model on the clean and on the adversarial test set.\n",
    "run_train=true   # false: skip training and attack the existing checkpoint\n",
    "run_attack=true\n",
    "\n",
    "iterations=\"1\" # if you want to compute the mean over multiple runs, increase iteration count\n",
    "\n",
    "regularizer_betas=\"0.0\"\n",
    "\n",
    "attack_f_1=\"condlr_fgsm\" # l1 FGSM attack\n",
    "epsilons_1=\"0.002 0.004 0.006\"\n",
    "\n",
    "models=\"vgg16\" #\"alexnet vgg11 vgg16 vgg19 inception resnet18 resnet50 resnet101 resnext50_32x4d resnext101_32x8d densenet121 densenet169 densenet201 regnet_x_400mf regnet_x_8gf regnet_x_16gf\"\n",
    "target_model=\"vgg16\"\n",
    "\n",
    "dataset=3 # 1 = UCM, 2 = AID, 3 = Cifar10\n",
    "crop_size=32\n",
    "\n",
    "#training parameters\n",
    "train_batch_size=128\n",
    "val_batch_size=128\n",
    "num_epochs=30\n",
    "weight_decay=0.0\n",
    "learning_rate=1e-4\n",
    "print_per_batches=100\n",
    "\n",
    "# Enable or disable WandB logging\n",
    "wandb=0 # 1 for enabled, 0 for disabled\n",
    "wandb_tag=\"google_colab_example\"\n",
    "\n",
    "# Sub-directory of ./dataset/data_adversarial_rs holding the attacked images\n",
    "case \"$dataset\" in\n",
    "    1) adv_dataset_dir=\"UCM_adv\" ;;\n",
    "    2) adv_dataset_dir=\"AID_adv\" ;;\n",
    "    3) adv_dataset_dir=\"Cifar10_adv\" ;;\n",
    "    *) adv_dataset_dir=\"\" ;;  # unknown ID: skip the clean-up step\n",
    "esac\n",
    "\n",
    "# Loop over all combinations of iterations, models and regularizer betas\n",
    "for i in $(seq 1 \"$iterations\"); do\n",
    "    echo \"iteration $i\"\n",
    "    for model in $models; do\n",
    "        for beta in $regularizer_betas; do\n",
    "            if [ \"$run_train\" = true ]; then\n",
    "                python pretrain_cls_robustness.py \\\n",
    "                    --dataID \"$dataset\" \\\n",
    "                    --lr \"$learning_rate\" \\\n",
    "                    --num_epochs \"$num_epochs\" \\\n",
    "                    --network \"$model\" \\\n",
    "                    --save_name \"$beta\" \\\n",
    "                    --wandb \"$wandb\" \\\n",
    "                    --wandb_tag \"$wandb_tag\" \\\n",
    "                    --robusteness_regularization_beta \"$beta\" \\\n",
    "                    --train_batch_size \"$train_batch_size\" \\\n",
    "                    --val_batch_size \"$val_batch_size\" \\\n",
    "                    --crop_size \"$crop_size\" \\\n",
    "                    --weight_decay \"$weight_decay\" \\\n",
    "                    --print_per_batches \"$print_per_batches\" \\\n",
    "                    --root_dir ./\n",
    "            fi\n",
    "            if [ \"$run_attack\" = true ]; then\n",
    "                # first attack\n",
    "                for eps_attack in $epsilons_1; do\n",
    "                    echo \"Delete all attacked images for a fresh start\"\n",
    "                    if [ -n \"$adv_dataset_dir\" ]; then\n",
    "                        adv_dir=\"./dataset/data_adversarial_rs/$adv_dataset_dir/$attack_f_1/baseline/$model\"\n",
    "                        echo \"Deleting $adv_dir/*.png\"\n",
    "                        # NOTE(review): in the recorded run rm never found *.png here, even after\n",
    "                        # attack_cls.py had processed 10000 images, and test_cls.py reports\n",
    "                        # loading ./Cifar10_adv/... -- confirm this is where the attack saves.\n",
    "                        if [ -d \"$adv_dir\" ]; then\n",
    "                            rm -f \"$adv_dir\"/*.png  # -f: no matching images is not an error\n",
    "                        else\n",
    "                            echo \"Nothing to delete: $adv_dir does not exist\"\n",
    "                        fi\n",
    "                    fi\n",
    "\n",
    "                    python attack_cls.py \\\n",
    "                        --robusteness_regularization_beta \"$beta\" \\\n",
    "                        --dataID \"$dataset\" \\\n",
    "                        --attack_func \"$attack_f_1\" \\\n",
    "                        --network \"$model\" \\\n",
    "                        --epsilon \"$eps_attack\" \\\n",
    "                        --crop_size \"$crop_size\" \\\n",
    "                        --save_path_prefix ./\n",
    "\n",
    "                    python test_cls.py \\\n",
    "                        --robusteness_regularization_beta \"$beta\" \\\n",
    "                        --dataID \"$dataset\" \\\n",
    "                        --target_network \"$target_model\" \\\n",
    "                        --surrogate_network \"$model\" \\\n",
    "                        --attack_func \"$attack_f_1\" \\\n",
    "                        --wandb \"$wandb\" \\\n",
    "                        --wandb_tag \"$wandb_tag\" \\\n",
    "                        --attack_epsilon \"$eps_attack\" \\\n",
    "                        --crop_size \"$crop_size\" \\\n",
    "                        --val_batch_size \"$val_batch_size\" \\\n",
    "                        --root_dir ./\n",
    "                done\n",
    "            fi\n",
    "        done\n",
    "    done\n",
    "done"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "406d81ae-8933-4210-a150-4492bc8a5efb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "216583e5-2392-4bae-9b4c-59cd94f5c711",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
