{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 247,
   "id": "c533c175-ef08-46c3-8936-f2ba15555059",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple, Dict\n",
    "from copy import copy\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 248,
   "id": "6f3046cd-83b8-4855-b6ec-9bc085a63317",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the backdoor-side CSV (no header row). Later cells slice its last column out as the label.\n",
    "backdoor = pd.read_csv('b_train_ci.csv',header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 249,
   "id": "4bff77b2-9e3e-4183-81b1-9d0666dc9e1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.asarray keeps this cell re-runnable: `.values` raises AttributeError once\n",
    "# `backdoor` has already been rebound to an ndarray by a previous run of this cell.\n",
    "backdoor = np.asarray(backdoor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "id": "f2f4bb0e-954a-46e7-a73c-2d45ce1a130a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7., 2., 1., ..., 4., 5., 6.])"
      ]
     },
     "execution_count": 250,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the last CSV column (sliced out later as the label).\n",
    "backdoor[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "id": "e925088d-af9a-4310-b4af-78c8e98206ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.2367513e-03, 6.7218850e-03, 1.6727995e-03, ..., 4.3260370e-03,\n",
       "        1.3296805e-02, 7.0000000e+00],\n",
       "       [6.0803473e-02, 4.1190704e-04, 5.3082180e-01, ..., 9.5548425e-03,\n",
       "        1.9291420e-04, 2.0000000e+00],\n",
       "       [2.7711360e-03, 9.3953246e-01, 1.2262586e-03, ..., 6.9273340e-03,\n",
       "        1.1424081e-02, 1.0000000e+00],\n",
       "       ...,\n",
       "       [6.7468180e-03, 7.3916220e-02, 3.6551245e-03, ..., 5.6444503e-02,\n",
       "        1.7551164e-01, 4.0000000e+00],\n",
       "       [2.7768157e-02, 8.5878900e-04, 6.2541994e-03, ..., 2.4271207e-02,\n",
       "        1.2970519e-02, 5.0000000e+00],\n",
       "       [6.3859366e-02, 9.0431710e-03, 4.6380790e-02, ..., 1.2054573e-02,\n",
       "        4.3394650e-02, 6.0000000e+00]])"
      ]
     },
     "execution_count": 251,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the raw backdoor array (NumPy abbreviates the repr of large arrays).\n",
    "backdoor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 252,
   "id": "8c8ee7c7-b0d0-484c-b8e6-75a76403be80",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "250000\n"
     ]
    }
   ],
   "source": [
    "# Append a constant 0 column to every backdoor row (0 = backdoor in the membership column).\n",
    "res_backdoor = np.concatenate(\n",
    "    (backdoor, np.zeros((backdoor.shape[0], 1))),\n",
    "    axis=1,\n",
    ")\n",
    "print(len(res_backdoor))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8282c2ea-3dff-4818-9c56-b479ea36f4af",
   "metadata": {},
   "source": [
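    "Disabled alternative (markdown cell, so not executed): a random 210000-row train split of `res_backdoor`, with the remaining rows held out.\n",
    "\n",
    "```python\n",
    "train_idx = np.random.choice(np.arange(len(res_backdoor)), 210000, replace=False)\n",
    "b_train = res_backdoor[train_idx]\n",
    "print(len(b_train))\n",
    "backdoor_test_all = np.delete(res_backdoor, train_idx, axis=0)\n",
    "print(len(backdoor_test_all))\n",
    "```"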
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "id": "40aa183a-589a-4095-abb9-44a12af86a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed split by row position: the first 10000 rows go into the attack-model training set,\n",
    "# the following 100000 rows are held out for evaluation.\n",
    "b_train, backdoor_test_all = res_backdoor[:10000], res_backdoor[10000:110000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 254,
   "id": "84529ec8-29ab-4157-89e1-0989a1e8a00c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the clean-side CSV (no header row) and convert it straight to a NumPy array.\n",
    "clean = pd.read_csv('csv/adb/c_train_adb_cifar10.csv', header=None).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "id": "a3a6dfef-9e32-414a-9efd-acb955e074f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "250000\n"
     ]
    }
   ],
   "source": [
    "# Append a constant 1 column to every clean row (1 = clean in the membership column).\n",
    "res_clean = np.concatenate(\n",
    "    (clean, np.ones((clean.shape[0], 1))),\n",
    "    axis=1,\n",
    ")\n",
    "print(len(res_clean))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa9b031-6898-4391-b67e-7d68514ddbf2",
   "metadata": {},
   "source": [
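    "Disabled alternative (markdown cell, so not executed): a random 210000-row train split of `res_clean`, with the remaining rows held out.\n",
    "\n",
    "```python\n",
    "train_idx_c = np.random.choice(np.arange(len(res_clean)), 210000, replace=False)\n",
    "c_train = res_clean [train_idx_c]\n",
    "print(len(c_train))\n",
    "clean_test_all = np.delete(res_clean, train_idx_c, axis=0)\n",
    "print(len(clean_test_all))\n",
    "```"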
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "id": "de91d725-edcb-4e85-91a9-2bfc102c7be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed split by row position: the first 10000 rows go into the attack-model training set,\n",
    "# the following 100000 rows are held out for evaluation.\n",
    "c_train, clean_test_all = res_clean[:10000], res_clean[10000:110000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "id": "11ff6e2b-19de-4643-a4ed-022eef4f3c1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training set for the attack model: clean rows first, backdoor rows after them.\n",
    "model_results = np.concatenate((c_train, b_train), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 258,
   "id": "ae3fa0f7-74ed-4702-8719-343ef1752c5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2.1877690e-03, 1.5020993e-03, 2.3656820e-03, ..., 5.7962030e-02,\n",
       "        7.0000000e+00, 1.0000000e+00],\n",
       "       [4.0327020e-04, 3.1136550e-04, 9.2790380e-01, ..., 4.4738714e-05,\n",
       "        2.0000000e+00, 1.0000000e+00],\n",
       "       [1.8885657e-03, 9.0577010e-01, 1.8840097e-02, ..., 5.6283730e-03,\n",
       "        1.0000000e+00, 1.0000000e+00],\n",
       "       ...,\n",
       "       [1.0920937e-02, 4.6088308e-02, 7.8935330e-03, ..., 3.7767786e-01,\n",
       "        4.0000000e+00, 0.0000000e+00],\n",
       "       [5.0961570e-03, 3.8670547e-02, 5.5490886e-03, ..., 1.4629915e-02,\n",
       "        5.0000000e+00, 0.0000000e+00],\n",
       "       [2.8489404e-03, 2.4338505e-04, 2.0305684e-02, ..., 6.3995310e-04,\n",
       "        6.0000000e+00, 0.0000000e+00]])"
      ]
     },
     "execution_count": 258,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the stacked training rows; the two trailing columns are the label and the 1/0 membership flag.\n",
    "model_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2607ea0-3c66-4d7f-84e8-ca86f9880ed2",
   "metadata": {},
   "source": [
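    "Disabled experiment (markdown cell, so not executed): keep a 2000-row window every 9984 rows of `model_results` and stack the windows into `results`.\n",
    "\n",
    "```python\n",
    "batch_size = 2000\n",
    "stride = 9984\n",
    "\n",
    "num_batches = (model_results.shape[0] - batch_size) // stride + 1\n",
    "\n",
    "batch_list = []\n",
    "\n",
    "for i in range(num_batches):\n",
    "    start = i * stride\n",
    "    end = start + batch_size\n",
    "    batch = model_results[start:end, :]\n",
    "    batch_list.append(batch)\n",
    "\n",
    "results = np.vstack(batch_list)\n",
    "```"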
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "id": "92d40261-8042-436f-b1c3-01d6bc854b52",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from mblearn import AttackModels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 260,
   "id": "32020774-32a7-4d75-85f6-551f08258eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the fitted forest, and therefore the reported metrics, are repeatable across runs.\n",
    "rf_attack = RandomForestClassifier(n_estimators=100, random_state=0)\n",
    "# NOTE(review): mblearn presumably trains one copy of the learner per target class (10 here) - confirm.\n",
    "attacker = AttackModels(target_classes=10, attack_learner=rf_attack)\n",
    "# Row layout passed in: [model outputs..., label, membership flag (1 = clean, 0 = backdoor)].\n",
    "attacker.fit(model_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74879c86-0b4a-43c8-8834-16466456319d",
   "metadata": {},
   "source": [
    "Disabled experiment (markdown cell, so not executed): keep a 2000-row window every 9984 rows of `backdoor_test`.\n",
    "\n",
    "```python\n",
    "batch_size = 2000\n",
    "stride = 9984\n",
    "\n",
    "num_batches = (backdoor_test.shape[0] - batch_size) // stride + 1\n",
    "\n",
    "batch_list = []\n",
    "\n",
    "for i in range(num_batches):\n",
    "    start = i * stride\n",
    "    end = start + batch_size\n",
    "    batch = backdoor_test[start:end, :]\n",
    "    batch_list.append(batch)\n",
    "\n",
    "backdoor_test = np.vstack(batch_list)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92e250c6-3c0c-4018-91a5-7f1184aeba10",
   "metadata": {},
   "source": [
    "Disabled experiment (markdown cell, so not executed): keep a 2000-row window every 9984 rows of `clean_test`.\n",
    "\n",
    "```python\n",
    "batch_size = 2000\n",
    "stride = 9984\n",
    "\n",
    "num_batches = (clean_test.shape[0] - batch_size) // stride + 1\n",
    "\n",
    "batch_list = []\n",
    "\n",
    "for i in range(num_batches):\n",
    "    start = i * stride\n",
    "    end = start + batch_size\n",
    "    batch = clean_test[start:end, :]\n",
    "    batch_list.append(batch)\n",
    "\n",
    "clean_test = np.vstack(batch_list)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 261,
   "id": "61678a89-1e00-4f77-a73a-b9300815a366",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything except the two trailing columns is the attack-model input;\n",
    "# the second-to-last column is the label that came from the CSV.\n",
    "backdoor_test_out, backdoor_test_label = backdoor_test_all[:, :-2], backdoor_test_all[:, -2]\n",
    "clean_test_out, clean_test_label = clean_test_all[:, :-2], clean_test_all[:, -2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "id": "98bed915-52bf-40f0-aa05-0d0c6979c90b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The labels were stored as floats alongside the other columns; cast them to integers.\n",
    "backdoor_test_label, clean_test_label = (\n",
    "    backdoor_test_label.astype(int),\n",
    "    clean_test_label.astype(int),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 263,
   "id": "78943627-a498-40d8-95b3-2e5808f86d10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the held-out rows with the fitted attack model: clean rows first, then backdoor rows.\n",
    "# NOTE(review): batch=True presumably scores many rows per call - confirm against mblearn.\n",
    "mem_res_clean = attacker.predict(clean_test_out ,clean_test_label, batch=True)\n",
    "mem_res_backdoor = attacker.predict(backdoor_test_out ,backdoor_test_label, batch=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 264,
   "id": "48fccfa9-5cc4-40ea-8fb6-a75db4542528",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 265,
   "id": "a18e0de2-287f-4337-ba2f-cb22993d6aba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted flag per row: argmax over axis 1 of the attack-model output, clean rows first.\n",
    "y_pred = np.concatenate([\n",
    "    np.argmax(mem_res_clean, axis=1),\n",
    "    np.argmax(mem_res_backdoor, axis=1),\n",
    "])\n",
    "# Ground truth in the same order: 1 for every clean row, 0 for every backdoor row.\n",
    "y_true = np.concatenate([\n",
    "    np.ones_like(clean_test_label),\n",
    "    np.zeros_like(backdoor_test_label),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 266,
   "id": "c61487a9-2d5f-4cd2-9643-1fc0fdde9a38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.66272"
      ]
     },
     "execution_count": 266,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of held-out rows (clean and backdoor together) whose predicted flag matches y_true.\n",
    "accuracy_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "id": "b3d5ccce-7f00-4284-baa2-b045351bc051",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 1 0 ... 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "# Quick look at the predicted flags (clean rows first, backdoor rows after).\n",
    "print(y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "id": "ee729805-2a34-4f0d-8427-8b349a28c181",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 1, ..., 0, 0, 0])"
      ]
     },
     "execution_count": 268,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ground-truth flags in the same order as y_pred: ones for clean rows, then zeros for backdoor rows.\n",
    "y_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "id": "c7fd257e-01c5-4eba-80af-1affcc61729b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model 0:\n",
      "Ones: 6942\n",
      "Zeros: 3058\n",
      "Result: 1\n",
      "model 1:\n",
      "Ones: 8754\n",
      "Zeros: 1246\n",
      "Result: 1\n",
      "model 2:\n",
      "Ones: 7312\n",
      "Zeros: 2688\n",
      "Result: 1\n",
      "model 3:\n",
      "Ones: 6036\n",
      "Zeros: 3964\n",
      "Result: 1\n",
      "model 4:\n",
      "Ones: 6911\n",
      "Zeros: 3089\n",
      "Result: 1\n",
      "model 5:\n",
      "Ones: 6564\n",
      "Zeros: 3436\n",
      "Result: 1\n",
      "model 6:\n",
      "Ones: 5880\n",
      "Zeros: 4120\n",
      "Result: 1\n",
      "model 7:\n",
      "Ones: 6845\n",
      "Zeros: 3155\n",
      "Result: 1\n",
      "model 8:\n",
      "Ones: 5676\n",
      "Zeros: 4324\n",
      "Result: 1\n",
      "model 9:\n",
      "Ones: 7236\n",
      "Zeros: 2764\n",
      "Result: 1\n",
      "model 10:\n",
      "Ones: 3813\n",
      "Zeros: 6187\n",
      "Result: 0\n",
      "model 11:\n",
      "Ones: 3352\n",
      "Zeros: 6648\n",
      "Result: 0\n",
      "model 12:\n",
      "Ones: 2910\n",
      "Zeros: 7090\n",
      "Result: 0\n",
      "model 13:\n",
      "Ones: 4447\n",
      "Zeros: 5553\n",
      "Result: 0\n",
      "model 14:\n",
      "Ones: 4187\n",
      "Zeros: 5813\n",
      "Result: 0\n",
      "model 15:\n",
      "Ones: 4665\n",
      "Zeros: 5335\n",
      "Result: 0\n",
      "model 16:\n",
      "Ones: 3647\n",
      "Zeros: 6353\n",
      "Result: 0\n",
      "model 17:\n",
      "Ones: 1670\n",
      "Zeros: 8330\n",
      "Result: 0\n",
      "model 18:\n",
      "Ones: 3964\n",
      "Zeros: 6036\n",
      "Result: 0\n",
      "model 19:\n",
      "Ones: 2957\n",
      "Zeros: 7043\n",
      "Result: 0\n"
     ]
    }
   ],
   "source": [
    "# One row per group of 10000 consecutive predictions; each group is reported as one model.\n",
    "arr = y_pred.reshape(-1, 10000)\n",
    "\n",
    "# Majority vote inside each group; a tie is reported as 0.\n",
    "for model_idx, votes in enumerate(arr):\n",
    "    ones = (votes == 1).sum()\n",
    "    zeros = (votes == 0).sum()\n",
    "\n",
    "    print(f'model {model_idx}:')\n",
    "    print(f'Ones: {ones}')\n",
    "    print(f'Zeros: {zeros}')\n",
    "    print('Result: 1' if ones > zeros else 'Result: 0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 974,
   "id": "9cb1d0cd-89b9-40db-af82-cf532163ac4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[86164, 13676],\n",
       "       [15572, 84268]])"
      ]
     },
     "execution_count": 974,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows are the true flag and columns the predicted flag (scikit-learn convention), ordered 0 then 1.\n",
    "confusion_matrix(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 620,
   "id": "b2651196-54c6-43df-a347-1df76c3ca9b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1, 1, 0, ..., 1, 1, 1],\n",
       "       [0, 0, 0, ..., 1, 0, 1],\n",
       "       [1, 1, 1, ..., 1, 0, 0],\n",
       "       ...,\n",
       "       [0, 0, 0, ..., 0, 0, 1],\n",
       "       [1, 1, 1, ..., 0, 0, 0],\n",
       "       [1, 1, 0, ..., 0, 1, 0]])"
      ]
     },
     "execution_count": 620,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted flags reshaped to one row per 10000-prediction group.\n",
    "arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ccc31e-1d8f-49f6-a0f9-f0ed9fed17b3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
