{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-10T13:14:33.139581Z",
          "start_time": "2025-04-10T13:14:28.356552Z"
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import blackbox_model\n",
        "\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import DataLoader\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import configs\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "import FMfuncs\n",
        "from torch.utils.data import DataLoader\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "import PGFM\n",
        "import PGFM_perturb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b8bf962a4afe143d",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-10T13:14:35.348621Z",
          "start_time": "2025-04-10T13:14:35.341301Z"
        }
      },
      "outputs": [],
      "source": [
        "2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cd4daed6434168a6",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-10T13:14:36.432191Z",
          "start_time": "2025-04-10T13:14:36.326889Z"
        }
      },
      "outputs": [],
      "source": [
        "##%%\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "\n",
        "# NOTE(review): this variable is typically read when the OpenMP runtime initialises,\n",
        "# so setting it after the imports of the first cell may be too late -- consider\n",
        "# moving it to the top of the notebook.\n",
        "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
        "\n",
        "# Other record files that were inspected with this cell:\n",
        "#   RLFM_advs206_woref_iter_train_record\n",
        "#   Apr22RLFM_advs207_wref_iter_train_record\n",
        "RECORD_PATH = './saved_model/Apr23RLFM_advs207_iter_train_record.npz'\n",
        "\n",
        "# np.load on an .npz archive returns a lazily-read NpzFile that keeps the file open;\n",
        "# the context manager closes it as soon as the three arrays have been read out.\n",
        "# NOTE(review): the variable names do not match the stored keys (in_prob_record holds\n",
        "# 'adv_record', rwd_record holds 'flow_loss_record'); they are kept unchanged here.\n",
        "with np.load(RECORD_PATH) as record:\n",
        "    loss_record = record['constraint_loss_record']\n",
        "    in_prob_record = record['adv_record']\n",
        "    rwd_record = record['flow_loss_record']\n",
        "\n",
        "##%%\n",
        "# One figure per recorded curve, all sharing the same x-axis label.\n",
        "for curve, y_label in ((loss_record, 'loss'),\n",
        "                       (in_prob_record, 'adv num'),\n",
        "                       (rwd_record, 'rwd')):\n",
        "    plt.plot(curve)\n",
        "    plt.grid()\n",
        "    plt.xlabel('x10 iterations')\n",
        "    plt.ylabel(y_label)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ebbb10d62c4da927",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-08T18:02:39.589500Z",
          "start_time": "2025-04-08T18:02:39.243725Z"
        }
      },
      "outputs": [],
      "source": [
        "# Black-box model under attack; later cells only call its predict() method.\n",
        "bb_model = blackbox_model.black_box_model_class()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6c41b6e1d62bf2d3",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-08T18:02:39.631101Z",
          "start_time": "2025-04-08T18:02:39.618131Z"
        }
      },
      "outputs": [],
      "source": [
        "from torchvision.datasets.mnist import MNIST\n",
        "\n",
        "# Preprocessing applied to every image: resize to 32x32, then convert to a tensor.\n",
        "mnist_transform = transforms.Compose([\n",
        "    transforms.Resize((32, 32)),\n",
        "    transforms.ToTensor(),\n",
        "])\n",
        "\n",
        "# MNIST split selected by train=False, stored under ./data (download enabled).\n",
        "data_test = MNIST('./data', train=False, download=True, transform=mnist_transform)\n",
        "\n",
        "# Unshuffled loader with a batch size of 10000 samples.\n",
        "data_loader = DataLoader(data_test, batch_size=10000, shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "490174274f74f96a",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-08T18:02:41.237317Z",
          "start_time": "2025-04-08T18:02:39.710123Z"
        }
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "correct = 0\n",
        "total = 0\n",
        "\n",
        "# Batches are gathered in lists and concatenated once after the loop, instead of\n",
        "# re-allocating the growing tensors with torch.cat on every iteration.\n",
        "input_batches = []\n",
        "label_batches = []\n",
        "\n",
        "with torch.no_grad():\n",
        "    for inputs, labels in data_loader:\n",
        "        inputs, labels = inputs.to(device), labels.to(device)\n",
        "        predicted = bb_model.predict(inputs)\n",
        "        total += labels.size(0)\n",
        "        correct += (predicted == labels).sum().item()\n",
        "        input_batches.append(inputs)\n",
        "        label_batches.append(labels)\n",
        "\n",
        "# Fail with a clear message instead of a ZeroDivisionError when nothing was loaded.\n",
        "if total == 0:\n",
        "    raise RuntimeError('data_loader yielded no samples; cannot evaluate the black-box model')\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "\n",
        "# Full set of inputs and labels, moved to the device the project config selects.\n",
        "data_all = torch.cat(input_batches, dim=0).to(configs.device)\n",
        "label_all = torch.cat(label_batches, dim=0).to(configs.device)\n",
        "\n",
        "# The per-batch lists are no longer needed once the concatenated tensors exist.\n",
        "del input_batches, label_batches"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8e81ec036ddff8d9",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-08T18:02:41.335571Z",
          "start_time": "2025-04-08T18:02:41.330686Z"
        }
      },
      "outputs": [],
      "source": [
        "# Fraction of the samples that goes into the leading (train) slice.\n",
        "train_rate = 0.8\n",
        "\n",
        "# Boundary index between the train slice and the test slice, per tensor.\n",
        "data_split = int(train_rate * len(data_all))\n",
        "label_split = int(train_rate * len(label_all))\n",
        "\n",
        "# Sequential (unshuffled) split: everything before the boundary is train, the rest is test.\n",
        "adv_data_train, adv_data_test = data_all[:data_split], data_all[data_split:]\n",
        "adv_label_train, adv_label_test = label_all[:label_split], label_all[label_split:]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "48a5e3cadf9d5699",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-08T18:02:45.779041Z",
          "start_time": "2025-04-08T18:02:45.707620Z"
        }
      },
      "outputs": [],
      "source": [
        "\n",
        "# Instantiate the project helper classes. The two PGFM variants are built around the\n",
        "# black-box model; the flow-matching helper takes no arguments.\n",
        "FM_class = FMfuncs.OTFlowMatching()\n",
        "\n",
        "# Earlier revision of this notebook used: RLFMfuncs.RLFM(bb_model)\n",
        "PGFM_class = PGFM.PGFM(bb_model)\n",
        "\n",
        "PGFMp_class = PGFM_perturb.PGFM_perturb(bb_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b9aa9430-b81f-46bf-b770-ad190105f503",
      "metadata": {},
      "outputs": [],
      "source": [
        "a = torch.randn(10,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "688ed92d-e092-4e8b-9c91-a2cf42449f02",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Scratch check: minimum along dim 1 of the tensor a (values only, argmin indices dropped).\n",
        "a.min(dim=1).values"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8c65239efb8071e",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-08T14:24:39.514803Z",
          "start_time": "2025-04-08T14:24:38.536642Z"
        }
      },
      "outputs": [],
      "source": [
        "# ckpt1 = torch.load('./saved_model/FM_adv_iter_20000.pth', map_location=configs.device)\n",
        "\n",
        "# Checkpoint notes kept from earlier runs:\n",
        "#   FMworef_100000\n",
        "#   Apr22RLFM_advs207_wref_iter_train_300000: apr23 record, t0=0.7\n",
        "#   Apr23RLFM_advs207_iter_train_500000_record: t0=0.8, s=15\n",
        "#   Apr23RLFM_advs207_iter_train_730000\n",
        "#\n",
        "# weights_only=True restricts unpickling to tensors and plain containers. The result is\n",
        "# only fed to load_state_dict, so nothing else is needed, and a tampered checkpoint\n",
        "# file cannot execute arbitrary code while being loaded.\n",
        "ckpt2 = torch.load('./saved_model/Apr23RLFM_advs207_iter_train_730000.pth',\n",
        "                   map_location=configs.device, weights_only=True)\n",
        "\n",
        "stage2model = PGFM_class.policy\n",
        "# stage2model = PGFM_class.get_untrained_model_stage2_var()\n",
        "stage2model.load_state_dict(ckpt2)\n",
        "\n",
        "# res = RLFMclass.RLFMsample(stage1model, stage2model, 1000, mode = 'train2')\n",
        "res = PGFM_class.RLFMsample( stage2model, adv_data_test,\n",
        "                   default_stage1t = 0.8, default_RLstep_S = 30) # from_scratch train2 configs.default_stage1_t\n",
        "\n",
        "# Clamp the samples to [0, 1]. detach() guarantees the .numpy() calls further down work\n",
        "# even if the sampler returned a tensor that tracks gradients.\n",
        "res = torch.clip(res, 0, 1).detach()\n",
        "# res = FMfuncs.sampler(stage1model, inp, stoptime=1, default_generation_step = 100)\n",
        "\n",
        "# Per-sample L2 distance between each generated sample and its source image.\n",
        "# res = adv_data_test\n",
        "l2norm = torch.norm(res - adv_data_test, p=2, dim = (1,2,3))\n",
        "print('Mean l2norm:', l2norm.mean().item())\n",
        "\n",
        "# Accuracy of the black-box model on the generated samples, scored against the labels\n",
        "# of the source images.\n",
        "with torch.no_grad():\n",
        "    inputs, labels = res.to(device), adv_label_test.to(device)\n",
        "    predicted = bb_model.predict(inputs)\n",
        "    total = labels.size(0)\n",
        "    correct = (predicted == labels).sum().item()\n",
        "\n",
        "test_accuracy = 100 * correct / total\n",
        "\n",
        "print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "\n",
        "# Figure layout: `row` rows, each holding `num` (source, generated) image pairs.\n",
        "row = 3\n",
        "num = 3\n",
        "\n",
        "# Indices of the samples whose prediction differs from the source label.\n",
        "success_ind = torch.where(predicted!=labels)[0].cpu().numpy()\n",
        "# success_ind = torch.where((predicted != labels) & (labels == 6))[0].cpu().numpy()\n",
        "\n",
        "# np.random.choice(..., replace=False) raises ValueError when asked for more items than\n",
        "# exist, so the number of displayed pairs is capped at the number of misclassified samples.\n",
        "n_show = min(row * num, len(success_ind))\n",
        "\n",
        "if n_show == 0:\n",
        "    print('No misclassified samples to display.')\n",
        "else:\n",
        "    rand_ind = success_ind[np.random.choice(len(success_ind), size=n_show, replace=False)]\n",
        "    plt.figure(figsize=(12,6))\n",
        "    for pair_idx, j in enumerate(rand_ind):\n",
        "        # Left panel of the pair: source image, titled with its label.\n",
        "        plt.subplot(row, num*2, 2 * pair_idx + 1)\n",
        "        plt.title(adv_label_test[j].item())\n",
        "        plt.imshow(adv_data_test[j,0].cpu().numpy(), cmap='gray')\n",
        "        plt.axis('off')\n",
        "\n",
        "        # Right panel of the pair: generated sample, titled with the model prediction.\n",
        "        plt.subplot(row, num*2, 2 * pair_idx + 2)\n",
        "        plt.title(predicted[j].item())\n",
        "        plt.imshow(res[j,0].cpu().numpy(), cmap='gray')\n",
        "        plt.axis('off')\n",
        "    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}