{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08350fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aba1fb88",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ASR(d, r_th, query_budget):\n",
    "    \"\"\"Compute the attack success rate (ASR, in percent) for every query budget.\n",
    "\n",
    "    An image counts as a success at budget ``b`` when the first iteration whose\n",
    "    norm is ``<= r_th`` was reached with at most ``b`` queries.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    d : mapping\n",
    "        Must provide ``'all_norms'`` and ``'all_queries'``, both indexable as\n",
    "        ``[image, iteration]`` (for example the object returned by ``np.load``\n",
    "        on an ``.npz`` file, or a plain dict of arrays).\n",
    "    r_th : float\n",
    "        Norm threshold; an iteration succeeds when its norm is ``<= r_th``.\n",
    "    query_budget : int\n",
    "        Largest budget to evaluate.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Array of length ``query_budget + 1``; entry ``b`` is the percentage of\n",
    "        images that succeeded within ``b`` queries. All entries are NaN when\n",
    "        there are no images.\n",
    "    \"\"\"\n",
    "    # Read each array exactly once: an NpzFile loads the archive member again\n",
    "    # on every d[...] access, so indexing it inside a loop is very slow.\n",
    "    norms = np.asarray(d['all_norms'])\n",
    "    query_counts = np.asarray(d['all_queries'])\n",
    "    num_images = norms.shape[0]\n",
    "    if num_images == 0:\n",
    "        # No images: the rate is undefined for every budget.\n",
    "        return np.full(query_budget + 1, np.nan)\n",
    "\n",
    "    # Images that never reach the threshold keep a sentinel above any budget,\n",
    "    # so they always count as failures. This includes rows whose norms are NaN\n",
    "    # (NaN <= r_th is False); such rows must stay in the denominator.\n",
    "    first_success = np.full(num_images, query_budget + 100, dtype=float)\n",
    "    below = norms <= r_th\n",
    "    rows = np.flatnonzero(below.any(axis=1))\n",
    "    if rows.size:\n",
    "        first_iter = below[rows].argmax(axis=1)  # index of the first True per row\n",
    "        first_success[rows] = query_counts[rows, first_iter]\n",
    "\n",
    "    # On sorted data, searchsorted(side='right') counts the entries <= budget.\n",
    "    budgets = np.arange(query_budget + 1)\n",
    "    solved = np.searchsorted(np.sort(first_success), budgets, side='right')\n",
    "    return 100 * (solved / num_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd8453f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for this run. r_th and Q are passed to get_ASR below;\n",
    "# model_arc, im_num, Q and topK together form the results file name.\n",
    "model_arc = 'resnet101'\n",
    "im_num = 1000\n",
    "Q = 40010\n",
    "topK = 2\n",
    "r_th = 4\n",
    "\n",
    "results_file = f'Targeted_{model_arc}_imgNum_{im_num}_query_budget_{Q}_top_k{topK}.npz'\n",
    "d = np.load(results_file)\n",
    "\n",
    "asr = get_ASR(d, r_th=r_th, query_budget=Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55873d16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the ASR at a few selected query budgets.\n",
    "qs = [5000, 10000, 20000, 30000, 40000]\n",
    "for q in qs:\n",
    "    i = q  # the first index >= q is q itself\n",
    "    if i < len(asr):  # budgets past the end of asr print nothing\n",
    "        print(f'ASR for a query budget {q} is {round(asr[i],1)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1374688d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the ASR curve over the query-budget range [0, Q].\n",
    "label_font = dict(fontname='Times New Roman', fontsize=14)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(asr)\n",
    "ax.set_xlabel('Query budget', **label_font)\n",
    "ax.set_ylabel('ASR(%)', **label_font)\n",
    "ax.grid()\n",
    "ax.set_xlim(0, Q)\n",
    "ax.set_ylim(0, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d0098d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
