{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-29T14:51:02.001613Z",
          "start_time": "2025-04-29T14:51:01.234918Z"
        }
      },
      "outputs": [],
      "source": [
        "##%%\n",
        "import os\n",
        "\n",
        "# Set before importing numeric libraries that may ship their own OpenMP runtime\n",
        "# (previously it was set after numpy/matplotlib and an unused torch import).\n",
        "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Training curves saved by a PGFM training run. Other available records:\n",
        "# PGFM_0p730s_iter_train_record, PGFM_0p910s_iter_train_record,\n",
        "# RLFM_defadvs2_0p730s_iter_train_record\n",
        "record = np.load('./saved_model/Apr24PGFM_0p815s_iter_train_record.npz')\n",
        "loss_record = record['constraint_loss_record']\n",
        "in_prob_record = record['adv_record']\n",
        "# NOTE(review): this curve is plotted with the label 'rwd' but comes from\n",
        "# 'flow_loss_record' - confirm which name is correct.\n",
        "rwd_record = record['flow_loss_record']\n",
        "\n",
        "##%%\n",
        "# One figure per curve; the x axis label says entries were logged every 10 iterations.\n",
        "for curve, ylabel in ((loss_record, 'loss'), (in_prob_record, 'adv num'), (rwd_record, 'rwd')):\n",
        "    plt.plot(curve)\n",
        "    plt.grid()\n",
        "    plt.xlabel('x10 iterations')\n",
        "    plt.ylabel(ylabel)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cbc3cef51a536f29",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:41:52.224607Z",
          "start_time": "2025-05-07T13:41:52.219367Z"
        }
      },
      "outputs": [],
      "source": [
        "# NOTE(review): leftover scratch cell; it only displays the literal 1 and can be removed.\n",
        "1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cc36cbefdc0bd17d",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:41:52.191870Z",
          "start_time": "2025-05-07T13:41:33.965451Z"
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision.transforms as transforms\n",
        "from externel.resnet_models import *\n",
        "import os\n",
        "import torchvision\n",
        "import blackbox_model\n",
        "import configs\n",
        "import matplotlib.pyplot as plt\n",
        "import PGFM\n",
        "import numpy as np\n",
        "import PGFM_unet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8e94c29b9f8b8020",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:42:01.827024Z",
          "start_time": "2025-05-07T13:41:55.992299Z"
        }
      },
      "outputs": [],
      "source": [
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
        "])\n",
        "\n",
        "# Load CIFAR-10 training dataset\n",
        "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
        "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)\n",
        "\n",
        "# Load CIFAR-10 test dataset\n",
        "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
        "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)\n",
        "\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "bb_model = blackbox_model.black_box_model_class()\n",
        "\n",
        "# Measure the black-box model's clean test accuracy and, in the same pass,\n",
        "# gather every normalized test image/label for the adversarial experiments below.\n",
        "correct = 0\n",
        "total = 0\n",
        "\n",
        "data_batches = []\n",
        "label_batches = []\n",
        "with torch.no_grad():\n",
        "    for inputs, labels in testloader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "        # predicted = bb_model.predict(inputs)\n",
        "        predicted_prob = bb_model.predict_proba(inputs)\n",
        "        predicted = torch.argmax(predicted_prob, dim=1)\n",
        "        total += labels.size(0)\n",
        "        correct += (predicted == labels).sum().item()\n",
        "        # Concatenate once after the loop: torch.cat on a growing tensor inside\n",
        "        # the loop re-copies all previous batches every iteration (quadratic).\n",
        "        data_batches.append(inputs)\n",
        "        label_batches.append(labels)\n",
        "\n",
        "if total == 0:\n",
        "    raise RuntimeError(\"CIFAR-10 test loader yielded no samples\")\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "data_all = torch.cat(data_batches, dim=0).to(configs.device)\n",
        "label_all = torch.cat(label_batches, dim=0).to(configs.device)\n",
        "del data_batches, label_batches  # drop per-batch references so they can be freed"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "85a371900a704d54",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:42:01.859852Z",
          "start_time": "2025-05-07T13:42:01.853265Z"
        }
      },
      "outputs": [],
      "source": [
        "train_rate = 0.8\n",
        "\n",
        "# Deterministic split of the collected test set (no shuffling here): the first\n",
        "# train_rate fraction becomes the *_train split, the remainder the *_test split.\n",
        "n_data_train = int(train_rate * len(data_all))\n",
        "n_label_train = int(train_rate * len(label_all))\n",
        "adv_data_train, adv_data_test = data_all[:n_data_train], data_all[n_data_train:]\n",
        "adv_label_train, adv_label_test = label_all[:n_label_train], label_all[n_label_train:]\n",
        "\n",
        "cifar_classes = ['airplanes', 'cars', 'birds', 'cats', 'deer', 'dogs', 'frogs', 'horses', 'ships', 'trucks']\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f23350ac8cedcdf9",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:42:02.351208Z",
          "start_time": "2025-05-07T13:42:02.325110Z"
        }
      },
      "outputs": [],
      "source": [
        "# Build the PGFM object from the black-box classifier; later cells use its sampling\n",
        "# methods and its .policy network. PGFM_unet.PGFM is the commented-out alternative.\n",
        "PGFM_class = PGFM.PGFM(bb_model)\n",
        "# PGFM_class = PGFM_unet.PGFM(bb_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "624a8495df3e9f9e",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:44:20.262140Z",
          "start_time": "2025-05-07T13:44:19.266688Z"
        }
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "def imshow(img, mean = torch.tensor([0.4914, 0.4822, 0.4465]), std = torch.tensor([0.2023, 0.1994, 0.2010])):\n",
        "    \"\"\"Draw one normalized image on the current matplotlib axes.\n",
        "\n",
        "    img: image tensor shaped (C, H, W), or (1, C, H, W) with a batch axis; it may live\n",
        "        on any device and may track gradients.\n",
        "    mean, std: per-channel values of the Normalize transform, undone here.\n",
        "    Pixel values are clipped to [0, 1] and axis ticks are hidden; calling plt.show()\n",
        "    is left to the caller so several images can share one figure.\n",
        "    \"\"\"\n",
        "    if img.dim() == 4:\n",
        "        img = img.squeeze(0)  # Remove batch dimension (never the channel axis)\n",
        "    # detach()/cpu() so GPU tensors or tensors requiring grad can be converted to numpy.\n",
        "    img = img.detach().cpu().numpy().transpose((1, 2, 0))  # Convert to HWC format\n",
        "    img = img * std.numpy() + mean.numpy()  # Denormalize\n",
        "    img = np.clip(img, 0, 1)  # Clip values to be in range [0,1] for imshow\n",
        "    plt.imshow(img)\n",
        "    plt.axis('off')  # Hide axis\n",
        "\n",
        "# Stage-1 sanity check: sample x_t at t = stage1_t from Gaussian noise and the\n",
        "# clean held-out images, then report L2 distortion, CE loss and black-box accuracy.\n",
        "stage1_t = 1\n",
        "ref = adv_data_test\n",
        "n_ref = ref.shape[0]\n",
        "# NOTE(review): the noise has 1 channel while ref holds 3-channel CIFAR images;\n",
        "# presumably sample_xt_given_x1_x0 broadcasts it across channels - confirm.\n",
        "x_prev = torch.randn(n_ref, 1, 32, 32, dtype=torch.float32, device=configs.device)\n",
        "t_tensor_N = stage1_t * torch.ones(x_prev.shape[0], device=configs.device, dtype=torch.float32)\n",
        "\n",
        "res = PGFM_class.sample_xt_given_x1_x0(x_prev, ref, t_tensor_N)\n",
        "\n",
        "# Per-sample L2 distortion, computed two equivalent ways as a cross-check.\n",
        "delta = res - adv_data_test\n",
        "l2norm = torch.norm(delta, p=2, dim = (1,2,3))\n",
        "l2normp = torch.sqrt(torch.sum(delta**2, dim = (1,2,3)))\n",
        "testceloss = torch.mean(bb_model.CEloss(res.to(device), adv_label_test.to(device)))\n",
        "# FID is not computed here: testfid = compute_fid(ref.to(device), res.to(device))\n",
        "\n",
        "with torch.no_grad():\n",
        "    inputs = res.to(device)\n",
        "    labels = adv_label_test.to(device)\n",
        "    predicted = bb_model.predict(inputs)\n",
        "    total = labels.size(0)\n",
        "    correct = (predicted == labels).sum().item()\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "\n",
        "# Six clean/sample pairs on a 3x4 grid: the clean image is titled with its true class,\n",
        "# the sample with the black-box prediction (class name followed by class index).\n",
        "plt.figure(figsize=(6,6))\n",
        "for pair, j in enumerate(range(6, 12)):\n",
        "    true_cls = labels[j].item()\n",
        "    pred_cls = predicted[j].item()\n",
        "\n",
        "    plt.subplot(3, 4, 2 * pair + 1)\n",
        "    imshow(adv_data_test[j].cpu())\n",
        "    plt.title(cifar_classes[true_cls] + str(true_cls))\n",
        "\n",
        "    plt.subplot(3, 4, 2 * pair + 2)\n",
        "    imshow(res[j].cpu())\n",
        "    plt.title(cifar_classes[pred_cls] + str(pred_cls))\n",
        "    plt.axis('off')\n",
        "plt.show()\n",
        "\n",
        "\n",
        "print('Mean l2norm:', l2norm.mean().item(), l2normp.mean().item())\n",
        "print('CE loss:', testceloss.item())\n",
        "# print('FID:', testfid)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e8860c6a650a7289",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T13:44:12.032956Z",
          "start_time": "2025-05-07T13:44:12.030229Z"
        }
      },
      "outputs": [],
      "source": [
        "# Scratch check of the held-out label tensor's shape.\n",
        "adv_label_test.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4ddbc5e2826b7adb",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-29T15:01:40.402782Z",
          "start_time": "2025-04-29T15:01:38.116299Z"
        }
      },
      "outputs": [],
      "source": [
        "# Stage-2 evaluation: load a trained PGFM policy, attack the held-out clean images,\n",
        "# report distortion and black-box accuracy, and plot random successful attacks.\n",
        "source_data = adv_data_test\n",
        "source_label = adv_label_test\n",
        "\n",
        "# source_data = adv_data_train[:100]\n",
        "# source_label = adv_label_train[:100]\n",
        "\n",
        "# Other checkpoints: Apr19PGFM_0p815s_iter_train_280000 (default: 0.8),\n",
        "# Apr24PGFM_0p815s_iter_train_20000, FMworef_10000.\n",
        "# weights_only=True: the file is only consumed as a state dict, so refuse arbitrary pickled objects.\n",
        "ckpt2 = torch.load('./saved_model/Apr19PGFM_0p815s_iter_train_240000.pth', map_location=configs.device, weights_only=True)\n",
        "stage2model = PGFM_class.policy\n",
        "stage2model.load_state_dict(ckpt2)\n",
        "\n",
        "res = PGFM_class.RLFMsample(stage2model, source_data, default_stage1t=0.8, default_RLstep_S=30)\n",
        "\n",
        "l2norm = torch.norm(res - source_data, p=2, dim = (1,2,3))\n",
        "print('Mean l2norm:', l2norm.mean().item())\n",
        "\n",
        "with torch.no_grad():\n",
        "    inputs, labels = res.to(device), source_label.to(device)\n",
        "    predicted = bb_model.predict(inputs)\n",
        "    total = labels.size(0)\n",
        "    correct = (predicted == labels).sum().item()\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "\n",
        "row = 3\n",
        "num = 3\n",
        "PERT_SCALE = 5  # amplification applied to the perturbation panel so it is visible\n",
        "\n",
        "# Indices where the attack flipped the black-box prediction.\n",
        "success_ind = torch.where(predicted!=labels)[0].cpu().numpy()\n",
        "print(success_ind.shape)\n",
        "\n",
        "# Draw distinct indices: randint sampled with replacement (duplicate panels) and\n",
        "# raised ValueError when there were no successful attacks at all.\n",
        "n_show = min(row * num, len(success_ind))\n",
        "if n_show == 0:\n",
        "    print('No successful adversarial examples to plot.')\n",
        "else:\n",
        "    rand_ind = np.random.choice(success_ind, size=n_show, replace=False)\n",
        "    plt.figure(figsize=(16,6))\n",
        "    for k, j in enumerate(rand_ind):\n",
        "        # detach()/cpu() so the later .numpy() conversion works even if res tracks gradients.\n",
        "        fig1 = source_data[j].detach().cpu()\n",
        "        fig2 = res[j].detach().cpu()\n",
        "\n",
        "        # Panel 1: clean image, titled with its true class.\n",
        "        plt.subplot(row, num*3, 3*k + 1)\n",
        "        imshow(fig1)\n",
        "        plt.title(cifar_classes[labels[j].item()])\n",
        "        plt.axis('off')\n",
        "\n",
        "        # Panel 2: amplified perturbation; imshow adds the normalization mean back,\n",
        "        # so a zero perturbation renders as the mean colour.\n",
        "        plt.subplot(row, num*3, 3*k + 2)\n",
        "        imshow((fig2-fig1)*PERT_SCALE)\n",
        "        plt.axis('off')\n",
        "\n",
        "        # Panel 3: adversarial image, titled with the black-box prediction.\n",
        "        plt.subplot(row, num*3, 3*k + 3)\n",
        "        imshow(fig2)\n",
        "        plt.title(cifar_classes[predicted[j].item()])\n",
        "        plt.axis('off')\n",
        "\n",
        "    os.makedirs('./figs', exist_ok=True)  # savefig does not create missing directories\n",
        "    plt.savefig(\"./figs/adv_cifar_samples.png\", dpi=300, bbox_inches='tight')\n",
        "    plt.show()\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}