{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":26451,"status":"ok","timestamp":1727821832171,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"nG9DoQKc5HXI","outputId":"fa394c51-7d64-4f34-f3a2-b37eb7c68c1a"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","\n","# put folder name here\n","FOLDERNAME = 'multiclass_polyact/'\n","\n","import sys\n","sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))"]},{"cell_type":"code","execution_count":2,"metadata":{"executionInfo":{"elapsed":21215,"status":"ok","timestamp":1727821853384,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"CWYRu4JX4z0w"},"outputs":[],"source":["import torch\n","import numpy as np\n","import torchvision.datasets as datasets\n","import torchvision.transforms as transforms\n","from torch.utils.data import Dataset, DataLoader\n","from copy import deepcopy\n","import matplotlib.pyplot as plt\n","\n","from models.one_vs_all import MultiClassPolyAct\n","\n","from torch_solvers.robust_polyact_solver import adversarial_train_poly_act\n","from torch_solvers.robust_polyact_solver_batch import batch_adversarial_polyact_train, Spliced_PolyAct\n","from torch_solvers.alternate_solver import alt_solver\n","\n","from models.praresnet import PreActResNet18\n","from models.one_vs_all import MultiClassPolyAct\n","from models.spliced import Spliced, Spliced_PolyAct\n","\n","from utils import FeatureDataset\n","\n","from cvx_scripts.losses import *\n","from cvx_scripts.cvx_nn import *\n","from cvx_scripts.cvx_training import *\n","\n","from attacks.fgsm import eval_fgsm\n","\n","%load_ext autoreload\n","%autoreload 2"]},{"cell_type":"code","execution_count":3,"metadata":{"executionInfo":{"elapsed":17,"status":"ok","timestamp":1727821853384,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"qcep38RkOpC9"},"outputs":[],"source":["embedding_size = 512"]},{"cell_type":"code","execution_count":4,"metadata":{"executionInfo":{"elapsed":16,"status":"ok","timestamp":1727821853384,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"oJ1Z1vkyLx_r"},"outputs":[],"source":["def extract_features(dummy_loader, model, shuffle=True):\n","  X = torch.zeros(0, embedding_size)\n","  y = torch.zeros(0)\n","  for img, label in dummy_loader:\n","      img = img.to(device)\n","      out = model.truncated_forward(img).detach().cpu()\n","      X = torch.vstack((X, out))\n","      y = torch.cat((y, label))\n","      del img\n","\n","  X = X.view(X.shape[0], -1).detach()\n","  n = X.shape[0]\n","  if shuffle:\n","    scrambled_idxs = np.random.choice(n, n, replace=False)\n","    X = X[scrambled_idxs]\n","    y = y[scrambled_idxs]\n","  torch.cuda.empty_cache()\n","  return X, y"]},{"cell_type":"code","execution_count":5,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":16,"status":"ok","timestamp":1727821853384,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"lCITybAZ5ivs","outputId":"9515d0bb-e507-4fd5-a423-873523e747ab"},"outputs":[{"name":"stdout","output_type":"stream","text":["cuda:0\n"]}],"source":["device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","print(device)"]},{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":53615,"status":"ok","timestamp":1727821906984,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"EsI-twFt5mMs","outputId":"4d95e935-6322-4a05-ee86-e5003cdfc049"},"outputs":[{"name":"stdout","output_type":"stream","text":["==> Preparing data..\n","Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"]},{"name":"stderr","output_type":"stream","text":["100%|██████████| 170498071/170498071 [00:18<00:00, 9189551.36it/s] \n"]},{"name":"stdout","output_type":"stream","text":["Extracting ./data/cifar-10-python.tar.gz to ./data\n","Files already downloaded and verified\n"]},{"name":"stderr","output_type":"stream","text":["<ipython-input-6-7409c4dafe1c>:37: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n","  pr18_sam.load_state_dict(torch.load(sys.path[-1] + 'pretrained_models/praresnet.pth'))\n","<ipython-input-6-7409c4dafe1c>:41: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n","  pr18.load_state_dict(torch.load(sys.path[-1] + 'pretrained_models/praresnet_nonsam.pth'))\n"]}],"source":["# Load CIFAR-10 data\n","print('==> Preparing data..')\n","transform_train = transforms.Compose([\n","    transforms.RandomCrop(32, padding=4),\n","    transforms.RandomHorizontalFlip(),\n","    transforms.ToTensor(),\n","    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n","])\n","\n","transform_test = transforms.Compose([\n","    transforms.ToTensor(),\n","    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n","])\n","\n","trainset = datasets.CIFAR10(\n","    root='./data', train=True, download=True, transform=transform_train)\n","trainloader = torch.utils.data.DataLoader(\n","    trainset, batch_size=128, shuffle=True)\n","\n","testset = datasets.CIFAR10(\n","    root='./data', train=False, download=True, transform=transform_test)\n","testloader = torch.utils.data.DataLoader(\n","    testset, batch_size=100, shuffle=False)\n","\n","classes = ('plane', 'car', 'bird', 'cat', 'deer',\n","           'dog', 'frog', 'horse', 'ship', 'truck')\n","\n","mean = torch.tensor([0.4914, 0.4822, 0.4465]).to(device)\n","std = torch.tensor([0.2023, 0.1994, 0.2010]).to(device)\n","\n","# test loader with batch size of one for fast gradient sign method\n","testloader_fgsm = torch.utils.data.DataLoader(\n","    testset, batch_size=1000, shuffle=False)\n","\n","# load in pre-trained Pre-Activaiton ResNet-18 models trained via sharpness-aware minimization and standard training.\n","pr18_sam = PreActResNet18(10)\n","pr18_sam.load_state_dict(torch.load(sys.path[-1] + 'pretrained_models/praresnet.pth'))\n","pr18_sam = pr18_sam.to(device)\n","\n","pr18 = PreActResNet18(10)\n","pr18.load_state_dict(torch.load(sys.path[-1] + 'pretrained_models/praresnet_nonsam.pth'))\n","pr18 = pr18.to(device)\n","\n","dummy_train_loader= torch.utils.data.DataLoader(\n","    trainset, batch_size=1000, shuffle=False,\n","    pin_memory=True, sampler=None)\n","dummy_test_loader= torch.utils.data.DataLoader(\n","    testset, batch_size=1000, shuffle=False,\n","    pin_memory=True, sampler=None)\n","\n","X_train, y_train = extract_features(dummy_train_loader, pr18)\n","X_test, y_test = extract_features(dummy_test_loader, pr18)\n","_, trunc_d = X_train.shape"]},{"cell_type":"markdown","metadata":{},"source":["Specify the dataloader. Warning: batch sizes of more than ~200 are EXTREMELY memory intensive. "]},{"cell_type":"code","execution_count":7,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1727821909296,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"juylqzObFBzZ"},"outputs":[],"source":["batch_size = 100\n","train_dataset = FeatureDataset(X_train.cpu()[:100], y_train.cpu()[:100].type(torch.LongTensor), 10)\n","test_dataset = FeatureDataset(X_test[:500].cpu(), y_test[:500].cpu().type(torch.LongTensor), 10)\n","\n","train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n","test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)\n"]},{"cell_type":"markdown","metadata":{},"source":["Train a polynomial activation network."]},{"cell_type":"code","execution_count":8,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1727821909296,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"bipt-Lfjzs3C"},"outputs":[],"source":["beta = 0.01\n","r = 1.5\n","lr = 0.01\n","rho = 2\n","epochs = 30\n","model = MultiClassPolyAct(10, 512, device=device, init='zero')\n","best_model, losses, train_accs, val_accs, best_robust = batch_adversarial_polyact_train(model,\n","                                                                                        train_loader,\n","                                                                                        test_loader,\n","                                                                                        r,\n","                                                                                        beta,\n","                                                                                        device,\n","                                                                                        lr=lr,\n","                                                                                        epochs=epochs,\n","                                                                                        rho=rho,\n","                                                                                        batch_size=batch_size,\n","                                                                                        verbose=True,\n","                                                                                        base_model = pr18,#pr18,\n","                                                                                        robust_eval_loader = testloader_fgsm\n","                                                                                        )"]},{"cell_type":"markdown","metadata":{},"source":["Splice a robust polynomial activation network and base image classification model."]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":1,"status":"aborted","timestamp":1727810638401,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"Sh9FZaLl-vc5"},"outputs":[],"source":["spliced = Spliced_PolyAct(pr18, best_model)\n","robusts = []\n","for eps in [0, 1, 2, 3, 4, 5, 6, 7, 8]:\n","  spliced.robust = True\n","  robust = eval_fgsm(spliced, device, testloader_fgsm, eps/255, mean, std, is_polyact=True)\n","  robusts.append(robust)\n","print(robusts)"]},{"cell_type":"markdown","metadata":{},"source":["Load in a polynomial activation network"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3qdMPZ83r4P4"},"outputs":[],"source":["# load in a pre-trained convex two-layer ReLU network\n","cvx = custom_cvx_layer(512, 500)\n","cvx.load_state_dict(torch.load(sys.path[-1] + 'praresnet_nonsam_500_inf_5.pth', map_location=torch.device('cpu')))\n","cvx.to(device)\n","uvec = torch.from_numpy(torch.load(sys.path[-1] + 'u_vec_praresnet_nonsam_500.pth')).to(device).float()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"m3g2Ag9C1u5I"},"outputs":[],"source":["from prepare_data import *"]},{"cell_type":"markdown","metadata":{},"source":["Evaluate robustness of the base, sam, and robustified models"]},{"cell_type":"code","execution_count":14,"metadata":{"executionInfo":{"elapsed":60216,"status":"ok","timestamp":1727825963226,"user":{"displayName":"Daniel Kuelbs","userId":"00355856873650556171"},"user_tz":420},"id":"yFkMvfgMydim"},"outputs":[],"source":["spliced = Spliced_PolyAct(pr18, best_model)\n","for p in spliced.parameters():\n","  p.requires_grad = False\n","\n","for p in pr18_sam.parameters():\n","  p.requires_grad = False\n","\n","robust_poly = []\n","standards = []\n","sams = []\n","for eps in [0, 1, 2, 3, 4, 5, 6, 7, 8]:\n","  spliced.robust = True\n","  robust = eval_fgsm(spliced, device, testloader_fgsm, eps/255, mean, std, is_polyact=True)\n","  robust_poly.append(robust)\n","  spliced.robust = False\n","  standard = eval_fgsm(spliced, device, testloader_fgsm, eps/255, mean, std)\n","  standards.append(standard)\n","  sam = eval_fgsm(pr18_sam, device, testloader_fgsm, eps/255, mean, std)\n","  sams.append(sam)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["plt.plot(sams, label = 'SAM')\n","plt.plot(standards, label = 'Standard')\n","plt.plot(robust_poly, label = 'Poly')"]}],"metadata":{"accelerator":"GPU","colab":{"gpuType":"A100","machine_shape":"hm","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.12.6"}},"nbformat":4,"nbformat_minor":0}
