{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 363,
   "id": "6fa0bd1a-77cf-4b32-bb8a-9c84b3df5a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple, Dict\n",
    "from copy import copy\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 364,
   "id": "7f65a0da-5e13-4fb3-82dc-c00344c8e378",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backdoored samples used to train the attack model (last column: class label).\n",
    "backdoor = pd.read_csv('csv/stl10/cv/b_train.csv', header=None).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 365,
   "id": "f7ce4c24-0172-4b08-a6a3-0cde03d74540",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 0., 0., ..., 0., 0., 0.])"
      ]
     },
     "execution_count": 365,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarise the label column (distinct values and their counts) instead of a truncated dump.\n",
    "np.unique(backdoor[:, -1], return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 366,
   "id": "a62b813d-9f49-4138-a29b-6e249f9746f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6000\n"
     ]
    }
   ],
   "source": [
    "# Append a membership column of 0s: backdoored rows are labelled 0 for the attack model.\n",
    "res_backdoor = np.column_stack((backdoor, np.zeros(len(backdoor))))\n",
    "print(len(res_backdoor))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 367,
   "id": "b0d0cd63-0c2a-40bd-9b26-c642efb4ad69",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[19.67174   , -1.1232085 , -0.89969444, ..., -0.888966  ,\n",
       "         0.        ,  0.        ],\n",
       "       [24.00162   , -1.6951634 , -3.6819868 , ..., -2.3598003 ,\n",
       "         0.        ,  0.        ],\n",
       "       [19.522959  ,  3.7918465 , -5.430204  , ...,  0.47496593,\n",
       "         0.        ,  0.        ],\n",
       "       ...,\n",
       "       [16.596119  , -1.4022186 , -5.1754313 , ..., -1.912952  ,\n",
       "         0.        ,  0.        ],\n",
       "       [23.648275  , -5.083376  , -7.2755322 , ..., -1.8304331 ,\n",
       "         0.        ,  0.        ],\n",
       "       [23.174734  , -4.316598  , -1.4159449 , ..., -0.3981294 ,\n",
       "         0.        ,  0.        ]])"
      ]
     },
     "execution_count": 367,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the labelled backdoor array (`b_train` is not defined anywhere in this notebook).\n",
    "res_backdoor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 368,
   "id": "65af4700-598e-432a-bd57-72c2f3cf042b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean samples used to train the attack model (last column: class label).\n",
    "clean = pd.read_csv('csv/stl10/cv/c_train.csv', header=None).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 369,
   "id": "43050ad9-a381-4576-a4cb-80776440f360",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6000\n"
     ]
    }
   ],
   "source": [
    "# Append a membership column of 1s: clean rows are labelled 1 for the attack model.\n",
    "res_clean = np.column_stack((clean, np.ones(len(clean))))\n",
    "print(len(res_clean))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 370,
   "id": "b55632ed-a247-4b93-8734-954f637952ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean rows first, then backdoored rows, stacked into one attack-training matrix.\n",
    "model_results = np.concatenate([res_clean, res_backdoor], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 371,
   "id": "8e5d803e-3b47-45f2-acc0-c2b86b0950fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.3869573 ,  2.1415608 ,  9.037069  , ...,  6.801649  ,\n",
       "         2.        ,  1.        ],\n",
       "       [-6.4431252 , -8.644714  ,  0.94844633, ...,  2.1079938 ,\n",
       "         4.        ,  1.        ],\n",
       "       [-9.214293  ,  8.9453335 , -7.9739847 , ..., -3.933224  ,\n",
       "         4.        ,  1.        ],\n",
       "       ...,\n",
       "       [ 9.524741  ,  1.8170424 , -4.6482253 , ..., -4.493582  ,\n",
       "         0.        ,  0.        ],\n",
       "       [10.513698  ,  2.4764946 , -3.4415903 , ..., -3.591676  ,\n",
       "         0.        ,  0.        ],\n",
       "       [11.277263  , -0.5133069 , -3.7015157 , ..., -5.871673  ,\n",
       "         0.        ,  0.        ]])"
      ]
     },
     "execution_count": 371,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Membership-label balance of the attack-training set (1 = clean rows, 0 = backdoored rows).\n",
    "np.unique(model_results[:, -1], return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 372,
   "id": "2aad636c-903b-45de-9a85-54ff17f8012b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(12000, 12)"
      ]
     },
     "execution_count": 372,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.shape(model_results)  # (rows, columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 373,
   "id": "51bd650f-108d-494d-8965-eb2a8d7a5811",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from mblearn import AttackModels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 374,
   "id": "25bfac0e-5264-4557-923e-0e5723217ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed: the random forest's bootstrap and feature sub-sampling are otherwise\n",
    "# different on every run, making the attack metrics below irreproducible.\n",
    "rf_attack = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "# target_classes=10 presumably matches the 10 output columns of the evaluation data - confirm.\n",
    "attacker = AttackModels(target_classes=10, attack_learner=rf_attack)\n",
    "attacker.fit(model_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 376,
   "id": "2bb96d9c-918c-4e6c-a0a4-a7dc45af240f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation data (last column: class label).\n",
    "# NOTE(review): unlike the training files, these are read from the working directory rather\n",
    "# than csv/stl10/cv/ - confirm they are the intended evaluation sets.\n",
    "b_ci = pd.read_csv('b_train.csv', header=None).values\n",
    "c_ci = pd.read_csv('c_train.csv', header=None).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 377,
   "id": "4526dd77-3997-4f3b-8958-13080db55b44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schema check: evaluation rows must match the attack-training layout minus the\n",
    "# appended membership column (model outputs + class label).\n",
    "assert b_ci.shape[1] == c_ci.shape[1] == model_results.shape[1] - 1, (b_ci.shape, c_ci.shape, model_results.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 378,
   "id": "90a421be-774e-4bc2-9e3d-364ea2dee0d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split each evaluation array into model outputs (all but last column) and labels (last column).\n",
    "backdoor_test_out, backdoor_test_label = b_ci[:, :-1], b_ci[:, -1]\n",
    "clean_test_out, clean_test_label = c_ci[:, :-1], c_ci[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 379,
   "id": "229c854d-5a4b-431d-91c9-2543772f9cf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels are read from CSV as floats; cast them to int.\n",
    "backdoor_test_label, clean_test_label = (\n",
    "    backdoor_test_label.astype(int),\n",
    "    clean_test_label.astype(int),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 380,
   "id": "5c68d1b7-f899-43d4-86cb-2025858a11b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(90690, 10)"
      ]
     },
     "execution_count": 380,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.shape(clean_test_out)  # (rows, output columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 381,
   "id": "fc1ad802-9af3-429f-a030-6d0f97e16937",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(90690,)"
      ]
     },
     "execution_count": 381,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.shape(clean_test_label)  # one label per row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 383,
   "id": "82b96246-a2ed-4efe-ae67-b9c3efc541c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch-score the clean set, then the backdoor set, with the fitted attack models.\n",
    "mem_res_clean, mem_res_backdoor = (\n",
    "    attacker.predict(outputs, labels, batch=True)\n",
    "    for outputs, labels in (\n",
    "        (clean_test_out, clean_test_label),\n",
    "        (backdoor_test_out, backdoor_test_label),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 384,
   "id": "d13b92c2-69bf-4927-9591-62bbce0f7915",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 388,
   "id": "195d041e-a12f-49bb-a66a-4e6eb5d5e8df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground truth follows the attack-training convention: clean rows -> 1, backdoored rows -> 0.\n",
    "y_true = np.concatenate((np.ones_like(clean_test_label), np.zeros_like(backdoor_test_label)))\n",
    "# Predicted label = column index of the highest attack score in each row.\n",
    "# NOTE(review): assumes predict(..., batch=True) returns a 2-D score array whose column\n",
    "# order matches labels [0, 1] - confirm against mblearn.\n",
    "y_pred = np.concatenate([np.argmax(scores, axis=1) for scores in (mem_res_clean, mem_res_backdoor)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 389,
   "id": "82dbf228-729f-46ac-9a84-0b8affccde4d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 0, 0, 0])"
      ]
     },
     "execution_count": 389,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count of each predicted label - makes a degenerate (single-class) attack obvious at a glance.\n",
    "np.unique(y_pred, return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 390,
   "id": "9cfc1fe0-4345-402f-b847-3321f025d0e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 390,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of rows whose predicted label equals the ground truth.\n",
    "accuracy_score(y_true=y_true, y_pred=y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 391,
   "id": "239230e8-74d2-4d94-8cc5-50d7c10d27d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[90690,     0],\n",
       "       [90690,     0]])"
      ]
     },
     "execution_count": 391,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows: true label (0, 1); columns: predicted label (0, 1).\n",
    "confusion_matrix(y_true=y_true, y_pred=y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 409,
   "id": "37a557ec-d8c5-4796-ad73-2ce265372fe2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model 0:\n",
      "Ones: 892\n",
      "Zeros: 9098\n",
      "Result: 0\n",
      "model 1:\n",
      "Ones: 3183\n",
      "Zeros: 6807\n",
      "Result: 0\n",
      "model 2:\n",
      "Ones: 1034\n",
      "Zeros: 8956\n",
      "Result: 0\n",
      "model 3:\n",
      "Ones: 1086\n",
      "Zeros: 8904\n",
      "Result: 0\n",
      "model 4:\n",
      "Ones: 538\n",
      "Zeros: 9452\n",
      "Result: 0\n",
      "model 5:\n",
      "Ones: 628\n",
      "Zeros: 9362\n",
      "Result: 0\n",
      "model 6:\n",
      "Ones: 690\n",
      "Zeros: 9300\n",
      "Result: 0\n",
      "model 7:\n",
      "Ones: 898\n",
      "Zeros: 9092\n",
      "Result: 0\n",
      "model 8:\n",
      "Ones: 1412\n",
      "Zeros: 8578\n",
      "Result: 0\n",
      "model 9:\n",
      "Ones: 720\n",
      "Zeros: 9270\n",
      "Result: 0\n",
      "model 10:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 11:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 12:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 13:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 14:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 15:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 16:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 17:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 18:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n",
      "model 19:\n",
      "Ones: 0\n",
      "Zeros: 9990\n",
      "Result: 0\n"
     ]
    }
   ],
   "source": [
    "# Rows scored per target model: y_pred is cut into consecutive chunks of this size and\n",
    "# each chunk gets a majority vote (ties resolve to 0).\n",
    "ROWS_PER_MODEL = 9990\n",
    "# NOTE(review): each evaluation set currently has 90690 rows, so len(y_pred) == 181380 is not\n",
    "# a multiple of 9990 and the saved output of this cell must come from an earlier run -\n",
    "# confirm the chunk size (and whether chunks may straddle the clean/backdoor boundary).\n",
    "if len(y_pred) % ROWS_PER_MODEL:\n",
    "    raise ValueError(\n",
    "        f'y_pred has {len(y_pred)} rows, not a multiple of ROWS_PER_MODEL={ROWS_PER_MODEL}'\n",
    "    )\n",
    "\n",
    "for i, group in enumerate(y_pred.reshape(-1, ROWS_PER_MODEL)):\n",
    "    ones = np.count_nonzero(group == 1)\n",
    "    zeros = np.count_nonzero(group == 0)\n",
    "\n",
    "    print(f'model {i}:')\n",
    "    print(f'Ones: {ones}')\n",
    "    print(f'Zeros: {zeros}')\n",
    "    print(f'Result: {int(ones > zeros)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d649c44b-5be8-4082-a859-e97a5ad4f259",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
