{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  Shadow Models\n",
    "\n",
    "The adversary trains $k$ shadow models $f^i_{shadow}()$, each on a dataset (shadow data) that is similar\n",
    "in format and distribution as the __target model's private training set__.\n",
    "\n",
    "Generate shadow data drawing samples from the population where the target model's training data\n",
    "are drawn from. The shadow model must be __trained in a similar fashion as the target model__.\n",
    "The larger the number of shadow models, the better the accuracy of the attack model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mblearn import ShadowModels\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to use the synthetic data generated in `data_synthesis_playground.ipynb`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "synth_data = pd.read_csv('synthetic_irir.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(150, 5)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.495151</td>\n",
       "      <td>0.471551</td>\n",
       "      <td>0.113910</td>\n",
       "      <td>0.047197</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.083423</td>\n",
       "      <td>0.210501</td>\n",
       "      <td>0.006217</td>\n",
       "      <td>0.152407</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.694754</td>\n",
       "      <td>0.832320</td>\n",
       "      <td>0.222306</td>\n",
       "      <td>0.176190</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.620594</td>\n",
       "      <td>0.812440</td>\n",
       "      <td>0.244068</td>\n",
       "      <td>0.134671</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.625834</td>\n",
       "      <td>0.568340</td>\n",
       "      <td>0.177182</td>\n",
       "      <td>0.100158</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          0         1         2         3  label\n",
       "0  0.495151  0.471551  0.113910  0.047197      0\n",
       "1  0.083423  0.210501  0.006217  0.152407      0\n",
       "2  0.694754  0.832320  0.222306  0.176190      0\n",
       "3  0.620594  0.812440  0.244068  0.134671      0\n",
       "4  0.625834  0.568340  0.177182  0.100158      0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(synth_data.shape)\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we know which learner is used in the target model we are going to use the same:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "shadow_learner =  RandomForestClassifier(n_estimators=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = synth_data.iloc[:, :-1].to_numpy()\n",
    "y = synth_data['label'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using sklearn shadow models\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "174f02720f414f268d77fc002bab9865",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sh = ShadowModels(x, y, n_models=5, target_classes=3, learner=shadow_learner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Shadow models: 5, <class 'sklearn.ensemble.forest.RandomForestClassifier'>\n",
       "lengths of data splits : [30, 30, 30, 30, 30]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shadow models results are in `results` attribute. The array structure is the following:\n",
    "\n",
    "+ Last column `[:, -1]` specifies if the record was in a training (1) or test (0) set of a shadow model\n",
    "+ Second last column `[:, -2]` is the label of the synthetic data.\n",
    "+ The remanining columns are the predicted class probability vector of the record satisfying: `len(result[:, :-2]) == target_classes`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.1 , 0.9 , 2.  , 1.  ],\n",
       "       [0.92, 0.08, 0.  , 0.  , 1.  ],\n",
       "       [0.02, 0.98, 0.  , 1.  , 1.  ],\n",
       "       [0.17, 0.81, 0.02, 1.  , 1.  ],\n",
       "       [0.  , 1.  , 0.  , 1.  , 1.  ]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sh.results[:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
