{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:31:54.279943Z",
          "start_time": "2025-05-08T15:31:49.853674Z"
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import blackbox_model\n",
        "\n",
        "import configs\n",
        "\n",
        "import torchvision.transforms as transforms\n",
        "import FMfuncs\n",
        "from torch.utils.data import DataLoader\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "# import RLFMfuncs\n",
        "import PGFM\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "53f744ebb6d73c72",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:31:54.311826Z",
          "start_time": "2025-05-08T15:31:54.308543Z"
        }
      },
      "outputs": [],
      "source": [
        "# NOTE(review): this cell only evaluates the literal 2 and assigns nothing;\n",
        "# it looks like a leftover scratch cell and is a candidate for removal.\n",
        "2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b894220cde800a99",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:31:54.392773Z",
          "start_time": "2025-05-08T15:31:54.368055Z"
        }
      },
      "outputs": [],
      "source": [
        "from torchvision.datasets.mnist import MNIST\n",
        "\n",
        "# Preprocessing applied to every image: resize to 32x32, then convert with ToTensor.\n",
        "mnist_preprocess = transforms.Compose([\n",
        "    transforms.Resize((32, 32)),\n",
        "    transforms.ToTensor(),\n",
        "])\n",
        "\n",
        "# train=False selects the MNIST test split; it is downloaded to ./data if missing.\n",
        "data_test = MNIST('./data', train=False, download=True, transform=mnist_preprocess)\n",
        "\n",
        "# Fixed order (shuffle=False). batch_size=10000 is presumably meant to return\n",
        "# the whole split in one batch - confirm against the dataset size.\n",
        "data_loader = DataLoader(data_test, batch_size=10000, shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bf73810e548d497b",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:31:54.594438Z",
          "start_time": "2025-05-08T15:31:54.453356Z"
        }
      },
      "outputs": [],
      "source": [
        "# Build the black-box model object from the project module blackbox_model.\n",
        "# Later cells call bb_model.predict(...) on image batches and hand bb_model to PGFM.PGFM.\n",
        "bb_model = blackbox_model.black_box_model_class()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "51cd3e4d8a6a6696",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:31:55.148271Z",
          "start_time": "2025-05-08T15:31:54.617636Z"
        }
      },
      "outputs": [],
      "source": [
        "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "# Running tallies for the accuracy, plus the per-batch tensors that are\n",
        "# stitched together once the loop has finished.\n",
        "correct = 0\n",
        "total = 0\n",
        "image_batches = []\n",
        "label_batches = []\n",
        "\n",
        "with torch.no_grad():\n",
        "    for batch_images, batch_labels in data_loader:\n",
        "        batch_images = batch_images.to(device)\n",
        "        batch_labels = batch_labels.to(device)\n",
        "        # batch_images = torch.clip(batch_images, 0, 1)\n",
        "        batch_pred = bb_model.predict(batch_images)\n",
        "        total += batch_labels.size(0)\n",
        "        correct += (batch_pred == batch_labels).sum().item()\n",
        "        image_batches.append(batch_images)\n",
        "        label_batches.append(batch_labels)\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f'Test Accuracy: {test_accuracy:.2f}%')\n",
        "\n",
        "# Join the batches along dim 0 and move the result to configs.device,\n",
        "# which is the device the later cells allocate their tensors on.\n",
        "data_all = torch.cat(image_batches, dim=0).to(configs.device)\n",
        "label_all = torch.cat(label_batches, dim=0).to(configs.device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2081cef8e59af559",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:31:55.170577Z",
          "start_time": "2025-05-08T15:31:55.168659Z"
        }
      },
      "outputs": [],
      "source": [
        "train_rate = 0.8\n",
        "\n",
        "\n",
        "def _split_by_rate(tensor, rate):\n",
        "    \"\"\"Return (head, tail), cutting tensor at index int(rate * len(tensor)).\"\"\"\n",
        "    cut = int(rate * len(tensor))\n",
        "    return tensor[:cut], tensor[cut:]\n",
        "\n",
        "\n",
        "# The leading train_rate fraction is the training part, the rest is held out.\n",
        "# No shuffling happens here; the order produced by the previous cell is kept.\n",
        "adv_data_train, adv_data_test = _split_by_rate(data_all, train_rate)\n",
        "adv_label_train, adv_label_test = _split_by_rate(label_all, train_rate)\n",
        "\n",
        "# res = ref+torch.randn_like(ref)*0.19\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d02ab15f7141859",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-08T15:33:08.197618Z",
          "start_time": "2025-05-08T15:33:07.969612Z"
        }
      },
      "outputs": [],
      "source": [
        "# PGFM_class is instantiated in a later cell of this notebook. Create it here\n",
        "# only when it does not exist yet, so that Restart & Run All does not stop with\n",
        "# a NameError, while an instance already living in the session is left untouched.\n",
        "if 'PGFM_class' not in globals():\n",
        "    PGFM_class = PGFM.PGFM(bb_model)\n",
        "\n",
        "stage1_t = 1\n",
        "ref = adv_data_test\n",
        "# Gaussian noise, one 1x32x32 sample per reference image.\n",
        "x_prev = torch.randn(ref.shape[0], 1, 32, 32, dtype=torch.float32, device=configs.device)\n",
        "\n",
        "# One time value (stage1_t) per sample.\n",
        "t_tensor_N = stage1_t * torch.ones(x_prev.shape[0], device=configs.device, dtype=torch.float32)\n",
        "res = PGFM_class.sample_xt_given_x1_x0(x_prev, ref, t_tensor_N)\n",
        "# res = ref+torch.randn_like(ref)*0.175\n",
        "# res = res.clip(res)\n",
        "\n",
        "# Per-sample L2 distance between the sampled images and the held-out references.\n",
        "l2norm = torch.norm(res - adv_data_test, p=2, dim=(1, 2, 3))\n",
        "# testceloss = torch.mean(bb_model.CEloss(res.to(device), adv_label_test.to(device)))\n",
        "# testfid = compute_fid(ref.to(device), res.to(device))\n",
        "\n",
        "# Black-box accuracy on the sampled images.\n",
        "with torch.no_grad():\n",
        "    inputs, labels = res.to(device), adv_label_test.to(device)\n",
        "    predicted = bb_model.predict(inputs)\n",
        "    total = labels.size(0)\n",
        "    correct = (predicted == labels).sum().item()\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f'Test Accuracy: {test_accuracy:.2f}%')\n",
        "print('Mean l2norm:', l2norm.mean().item())\n",
        "\n",
        "# 3x4 grid of six (reference, sample) pairs for held-out samples 6..11.\n",
        "# Left panel: reference image, titled with its label.\n",
        "# Right panel: sampled image, titled with the black-box prediction.\n",
        "plt.figure(figsize=(6, 6))\n",
        "for pair_idx, j in enumerate(range(6, 12)):\n",
        "    fig1 = adv_data_test[j].cpu()[0]\n",
        "    # Drop the channel axis so imshow gets a 2-D array; detach in case the\n",
        "    # sampler returned a tensor that is attached to an autograd graph.\n",
        "    fig2 = res[j].detach().cpu()[0]\n",
        "\n",
        "    plt.subplot(3, 4, 2 * pair_idx + 1)\n",
        "    plt.imshow(fig1, cmap='gray')\n",
        "    plt.title(str(labels[j].item()))\n",
        "    plt.axis('off')\n",
        "\n",
        "    plt.subplot(3, 4, 2 * pair_idx + 2)\n",
        "    # Show the sampled image here (the reference was being drawn twice).\n",
        "    plt.imshow(fig2, cmap='gray')\n",
        "    plt.title(str(predicted[j].item()))\n",
        "    plt.axis('off')\n",
        "plt.show()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "de8ef476a3fa279f",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-07T22:38:24.908532Z",
          "start_time": "2025-05-07T22:38:24.886701Z"
        }
      },
      "outputs": [],
      "source": [
        "# Instantiate the project objects used for sampling and training.\n",
        "# FM_class is not referenced by any other cell in this notebook.\n",
        "# PGFM_class is built around bb_model; it is used by the training cells below and\n",
        "# is also referenced by the sampling/evaluation cell that precedes this one.\n",
        "FM_class = FMfuncs.OTFlowMatching()\n",
        "# PGFM_class = RLFMfuncs.RLFM(bb_model)\n",
        "PGFM_class = PGFM.PGFM(bb_model)\n",
        "# PGFMp_class = PGFM_perturb.PGFM_perturb(bb_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1ac496e6fb6745f6",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-18T20:58:40.104836Z",
          "start_time": "2025-04-18T20:58:40.094381Z"
        }
      },
      "outputs": [],
      "source": [
        "# Run train_2stage on the training split (images and labels).\n",
        "# init= receives a checkpoint path under ./saved_model; presumably the weights\n",
        "# to start from - confirm in PGFM.train_2stage.\n",
        "PGFM_class.train_2stage(adv_data_train, adv_label_train, init='./saved_model/FMworef_100000.pth')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a36ce332abb23f0b",
      "metadata": {
        "ExecuteTime": {
          "start_time": "2025-04-18T21:00:08.028482Z"
        },
        "jupyter": {
          "is_executing": true
        }
      },
      "outputs": [],
      "source": [
        "# Run train2_2stage on the same training split; the third positional argument\n",
        "# is a checkpoint path under ./saved_model.\n",
        "# NOTE(review): how train2_2stage uses that path is not visible from this\n",
        "# notebook - confirm in PGFM. Under Run All this cell executes right after the\n",
        "# train_2stage cell; verify that running both in sequence is intended.\n",
        "PGFM_class.train2_2stage(adv_data_train, adv_label_train, './saved_model/Apr22RLFM_advs207_wref_iter_train_300000.pth')\n",
        "# Alternative call and checkpoint paths kept for reference:\n",
        "# PGFM_class.train2_2stage(adv_data_train, adv_label_train, './saved_model/FMworef_100000.pth')\n",
        "# './saved_model/Apr1RLFM_advs207_iter_train_300000.pth'\n",
        "# './saved_model/Apr18RLFM_advs207_wref_iter_train_220000.pth'"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}