{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "f1f6ef379d502b67",
      "metadata": {},
      "source": [
        "### Begin"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:33:57.489043Z",
          "start_time": "2025-05-07T22:33:32.734990Z"
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision.transforms as transforms\n",
        "from externel.resnet_models import *\n",
        "import os\n",
        "import torchvision\n",
        "import blackbox_model\n",
        "import configs\n",
        "import matplotlib.pyplot as plt\n",
        "import PGFM\n",
        "import numpy as np\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e0d75ef477db0642",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:33:57.524479Z",
          "start_time": "2025-05-07T22:33:57.519638Z"
        }
      },
      "outputs": [],
      "source": [
        "2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b894220cde800a99",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:34:04.198001Z",
          "start_time": "2025-05-07T22:33:58.065571Z"
        }
      },
      "outputs": [],
      "source": [
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
        "])\n",
        "\n",
        "# Load CIFAR-10 training dataset\n",
        "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
        "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)\n",
        "\n",
        "# Load CIFAR-10 test dataset\n",
        "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
        "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)\n",
        "\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "bb_model = blackbox_model.black_box_model_class()\n",
        "\n",
        "correct = 0\n",
        "total = 0\n",
        "\n",
        "data_all = None\n",
        "label_all = None\n",
        "with torch.no_grad():\n",
        "    for inputs, labels in testloader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "        # predicted = bb_model.predict(inputs)\n",
        "        predicted_prob = bb_model.predict_proba(inputs)\n",
        "        predicted = torch.argmax(predicted_prob, dim=1)\n",
        "        total += labels.size(0)\n",
        "        correct += (predicted == labels).sum().item()\n",
        "        if data_all is None:\n",
        "            data_all = inputs\n",
        "            label_all = labels\n",
        "        else:\n",
        "            data_all = torch.cat((data_all, inputs), dim=0)\n",
        "            label_all = torch.cat((label_all, labels), dim=0)\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "data_all = data_all.to(configs.device)\n",
        "label_all = label_all.to(configs.device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2081cef8e59af559",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:34:04.232415Z",
          "start_time": "2025-05-07T22:34:04.225939Z"
        }
      },
      "outputs": [],
      "source": [
        "train_rate = 0.8\n",
        "\n",
        "adv_data_train = data_all[:int(train_rate * len(data_all))]\n",
        "adv_data_test = data_all[int(train_rate * len(data_all)):]\n",
        "adv_label_train = label_all[:int(train_rate * len(label_all))]\n",
        "adv_label_test = label_all[int(train_rate * len(label_all)):]\n",
        "\n",
        "cifar_classes = ['airplanes', 'cars', 'birds', 'cats', 'deer', 'dogs', 'frogs', 'horses', 'ships', 'trucks']\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "de8ef476a3fa279f",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:34:04.311212Z",
          "start_time": "2025-05-07T22:34:04.284013Z"
        }
      },
      "outputs": [],
      "source": [
        "\n",
        "PGFM_class = PGFM.PGFM(bb_model)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3c8df5546f57a3f6",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:36:50.176469Z",
          "start_time": "2025-05-07T22:36:49.293387Z"
        }
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "def imshow(img, mean = torch.tensor([0.4914, 0.4822, 0.4465]), std = torch.tensor([0.2023, 0.1994, 0.2010])):\n",
        "    img = img.squeeze(0)  # Remove batch dimension\n",
        "    img = img.numpy().transpose((1, 2, 0))  # Convert to HWC format\n",
        "    img = img * std.numpy() + mean.numpy()  # Denormalize\n",
        "    img = np.clip(img, 0, 1)  # Clip values to be in range [0,1] for imshow\n",
        "    plt.imshow(img)\n",
        "    plt.axis('off')  # Hide axis\n",
        "    # plt.show()\n",
        "\n",
        "stage1_t = 1\n",
        "ref = adv_data_test\n",
        "x_prev = torch.randn(ref.shape[0], 1, 32, 32, dtype=torch.float32, device=configs.device)\n",
        "\n",
        "\n",
        "t_tensor_N = stage1_t * torch.ones(x_prev.shape[0], device=configs.device, dtype=torch.float32)\n",
        "res = PGFM_class.sample_xt_given_x1_x0(x_prev, ref, t_tensor_N)\n",
        "# res = ref+torch.randn_like(ref)*0.19\n",
        "\n",
        "l2norm = torch.norm(res - adv_data_test, p=2, dim = (1,2,3))\n",
        "testceloss = torch.mean(bb_model.CEloss(res.to(device), adv_label_test.to(device)))\n",
        "# testfid = compute_fid(ref.to(device), res.to(device))\n",
        "\n",
        "with torch.no_grad():\n",
        "    # for inputs, labels in data_loader:\n",
        "    inputs, labels = res.to(device), adv_label_test.to(device)\n",
        "    predicted = bb_model.predict(inputs)\n",
        "    total = labels.size(0)\n",
        "    correct = (predicted == labels).sum().item()\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "print('Mean l2norm:', l2norm.mean().item())\n",
        "print('CE loss:', testceloss.item())\n",
        "\n",
        "i=0\n",
        "plt.figure(figsize=(6,6))\n",
        "for j in range(6,12):\n",
        "    fig1 = adv_data_test[j].cpu()\n",
        "    fig2 = res[j].cpu()\n",
        "    i+=1\n",
        "    plt.subplot(3,4,i)\n",
        "    imshow(fig1)\n",
        "    plt.title(cifar_classes[labels[j].item()] + str(labels[j].item()))\n",
        "\n",
        "    i+=1\n",
        "    plt.subplot(3,4,i)\n",
        "    imshow(fig2)\n",
        "    plt.title(cifar_classes[predicted[j].item()] + str(predicted[j].item()))\n",
        "    plt.axis('off')\n",
        "plt.show()\n",
        "\n",
        "\n",
        "\n",
        "# print('FID:', testfid)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bec884e1299fd33",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-18T19:36:23.460870Z",
          "start_time": "2025-04-18T19:36:11.448759Z"
        }
      },
      "outputs": [],
      "source": [
        "# source_data, source_label, res = (\n",
        "# PGFM_class.train2_2stage(adv_data_train, adv_label_train, './saved_model/FMworef_10000.pth')\n",
        "PGFM_class.train2_2stage(adv_data_train, adv_label_train, './saved_model/Apr19PGFM_0p815s_iter_train_280000.pth')#FMworef_150000\n",
        "# './saved_model/FMworef_10000.pth'\n",
        "# PGFM_0p730s_iter_train_30000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a36ce332abb23f0b",
      "metadata": {
        "ExecuteTime": {
          "start_time": "2025-04-08T21:39:08.878569Z"
        }
      },
      "outputs": [],
      "source": [
        "PGFM_class.train_2stage(adv_data_train, adv_label_train, init_ckpt_path=None)#'./saved_model/FMworef_150000.pth'"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}