{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attacker model\n",
    "\n",
    "Here we will put it all together using the synthetic data generated in `data_synthesis_playground.ipynb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mblearn import AttackModels, ShadowModels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need to make a target model (we will use RandomForest with 100 estimators)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Target model\n",
    "\n",
    "We are going to use the wine dataset, which has 13 features and 3 classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the forest so the target model is identical after Restart & Run All.\n",
    "rf_target = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "data, target = load_wine(return_X_y=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = MinMaxScaler()\n",
    "data_std = scaler.fit_transform(data)\n",
    "\n",
    "# split to test membership in X_train\n",
    "# (seeded: the membership ground truth depends on which rows land in X_train)\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data_std, target, test_size=0.4, random_state=42\n",
    ")\n",
    "\n",
    "rf_target.fit(X_train, y_train);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shadow model\n",
    "\n",
    "Now train the Shadow models with synthetic data and the same learner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "      <th>12</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.648708</td>\n",
       "      <td>0.556606</td>\n",
       "      <td>0.475915</td>\n",
       "      <td>0.604599</td>\n",
       "      <td>0.296312</td>\n",
       "      <td>0.575440</td>\n",
       "      <td>0.590239</td>\n",
       "      <td>0.994838</td>\n",
       "      <td>0.260606</td>\n",
       "      <td>0.326840</td>\n",
       "      <td>0.892040</td>\n",
       "      <td>0.653399</td>\n",
       "      <td>0.449623</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.954101</td>\n",
       "      <td>0.608156</td>\n",
       "      <td>0.793287</td>\n",
       "      <td>0.375929</td>\n",
       "      <td>0.854361</td>\n",
       "      <td>0.659643</td>\n",
       "      <td>0.912113</td>\n",
       "      <td>0.117836</td>\n",
       "      <td>0.603246</td>\n",
       "      <td>0.645942</td>\n",
       "      <td>0.115053</td>\n",
       "      <td>0.860870</td>\n",
       "      <td>0.656279</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.618762</td>\n",
       "      <td>0.365778</td>\n",
       "      <td>0.232566</td>\n",
       "      <td>0.434179</td>\n",
       "      <td>0.800469</td>\n",
       "      <td>0.493136</td>\n",
       "      <td>0.386309</td>\n",
       "      <td>0.298686</td>\n",
       "      <td>0.182332</td>\n",
       "      <td>0.464039</td>\n",
       "      <td>0.449998</td>\n",
       "      <td>0.571479</td>\n",
       "      <td>0.949356</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.976243</td>\n",
       "      <td>0.462703</td>\n",
       "      <td>0.850762</td>\n",
       "      <td>0.347162</td>\n",
       "      <td>0.848853</td>\n",
       "      <td>0.698433</td>\n",
       "      <td>0.866704</td>\n",
       "      <td>0.552660</td>\n",
       "      <td>0.723015</td>\n",
       "      <td>0.372361</td>\n",
       "      <td>0.396304</td>\n",
       "      <td>0.549000</td>\n",
       "      <td>0.321202</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.765457</td>\n",
       "      <td>0.635224</td>\n",
       "      <td>0.525063</td>\n",
       "      <td>0.009490</td>\n",
       "      <td>0.581303</td>\n",
       "      <td>0.718589</td>\n",
       "      <td>0.749601</td>\n",
       "      <td>0.134644</td>\n",
       "      <td>0.388569</td>\n",
       "      <td>0.550524</td>\n",
       "      <td>0.408401</td>\n",
       "      <td>0.797959</td>\n",
       "      <td>0.753048</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          0         1         2         3         4         5         6  \\\n",
       "0  0.648708  0.556606  0.475915  0.604599  0.296312  0.575440  0.590239   \n",
       "1  0.954101  0.608156  0.793287  0.375929  0.854361  0.659643  0.912113   \n",
       "2  0.618762  0.365778  0.232566  0.434179  0.800469  0.493136  0.386309   \n",
       "3  0.976243  0.462703  0.850762  0.347162  0.848853  0.698433  0.866704   \n",
       "4  0.765457  0.635224  0.525063  0.009490  0.581303  0.718589  0.749601   \n",
       "\n",
       "          7         8         9        10        11        12  label  \n",
       "0  0.994838  0.260606  0.326840  0.892040  0.653399  0.449623      0  \n",
       "1  0.117836  0.603246  0.645942  0.115053  0.860870  0.656279      0  \n",
       "2  0.298686  0.182332  0.464039  0.449998  0.571479  0.949356      0  \n",
       "3  0.552660  0.723015  0.372361  0.396304  0.549000  0.321202      0  \n",
       "4  0.134644  0.388569  0.550524  0.408401  0.797959  0.753048      0  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = pd.read_csv('synthetic_data.csv')\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_shadow =  RandomForestClassifier(n_estimators=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = synth_data.iloc[:, :-1].to_numpy()\n",
    "y = synth_data['label'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using sklearn shadow models\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a716f52faa4a475eabbc09adc24843c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sh = ShadowModels(x, y, 5, 3, rf_shadow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "shadow_data = sh.results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attacker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the shadow dataset we can train the attacker model. The attacker learner doesn't need to be the same as the target's, so pick whichever performs best.\n",
    "\n",
    "`AttackModels` trains one model for each original class in the shadow data (and in the target model), with the in/out-of-training label as the target label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_attack = RandomForestClassifier(n_estimators=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "attacker = AttackModels(target_classes=3, attack_learner=rf_attack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "attacker.fit(shadow_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's test with samples from both `X_train` (members) and `X_test` (non-members), balanced 50/50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep as many training rows as there are test rows so members/non-members are 50/50.\n",
    "# replace=False is required: np.random.choice samples WITH replacement by default,\n",
    "# which would put duplicate rows in the member set.\n",
    "train_idx = np.random.choice(len(X_train), size=len(X_test), replace=False)\n",
    "X_train = X_train[train_idx]\n",
    "y_train = y_train[train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "72 72\n"
     ]
    }
   ],
   "source": [
    "print(len(X_train), len(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_in = rf_target.predict_proba(X_train)\n",
    "res_in = attacker.predict(X_in, y_train, batch=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_out = rf_target.predict_proba(X_test)\n",
    "res_out = attacker.predict(X_out, y_test, batch=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9583333333333334"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(np.argmax(res_in, axis=1)) / len(res_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1527777777777778"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1 - np.sum(np.argmax(res_out, axis=1)) / len(res_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Precision, Recall and F1\n",
    "\n",
    "Since the class balance is 50/50, a dumb classifier that always predicts `in` will achieve 0.5 precision, 1.0 recall and 0.67 F1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = np.concatenate((np.argmax(res_in, axis=1), np.argmax(res_out, axis=1)))\n",
    "y_true = np.concatenate((np.ones_like(y_train), np.zeros_like(y_test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5307692307692308"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9583333333333334"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6831683168316831"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1_score(y_true, y_pred)"
   ]
  },
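  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not bad for an out-of-the-box setup, although in the run recorded here the scores are only slightly above the dumb-classifier baseline (0.5 precision, 1.0 recall, 0.67 F1), so there is room to improve the attack."
   ]
  }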
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
