{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Attacker model\n",
    "\n",
    "Here we will put it all together using the synthetic data generated in `data_synthesis_playground.ipynb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mblearn import AttackModels, ShadowModels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need to make a target model (we will use RandomForest with 100 estimators)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Target model\n",
    "\n",
    "We are going to use the wine dataset, which has 13 features and 3 classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG once so Restart & Run All gives repeatable numbers:\n",
    "# scikit-learn estimators/splitters left at random_state=None draw from it,\n",
    "# as does the np.random.choice call used later in this notebook.\n",
    "np.random.seed(42)\n",
    "\n",
    "# Target model under attack, plus the raw wine features/labels it is trained on\n",
    "rf_target = RandomForestClassifier(n_estimators=100)\n",
    "data, target = load_wine(return_X_y=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale every feature to [0, 1] (the scaler is fitted on the full dataset)\n",
    "scaler = MinMaxScaler()\n",
    "data_std = scaler.fit_transform(data)\n",
    "\n",
    "# Hold out 40% of the rows: X_train is the 'in' (member) set and\n",
    "# X_test the 'out' (non-member) set for the membership test further down\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data_std, target, test_size=0.4\n",
    ")\n",
    "\n",
    "rf_target.fit(X_train, y_train);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shadow model\n",
    "\n",
    "Now train the Shadow models with synthetic data and the same learner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "      <th>12</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.648708</td>\n",
       "      <td>0.556606</td>\n",
       "      <td>0.475915</td>\n",
       "      <td>0.604599</td>\n",
       "      <td>0.296312</td>\n",
       "      <td>0.575440</td>\n",
       "      <td>0.590239</td>\n",
       "      <td>0.994838</td>\n",
       "      <td>0.260606</td>\n",
       "      <td>0.326840</td>\n",
       "      <td>0.892040</td>\n",
       "      <td>0.653399</td>\n",
       "      <td>0.449623</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.954101</td>\n",
       "      <td>0.608156</td>\n",
       "      <td>0.793287</td>\n",
       "      <td>0.375929</td>\n",
       "      <td>0.854361</td>\n",
       "      <td>0.659643</td>\n",
       "      <td>0.912113</td>\n",
       "      <td>0.117836</td>\n",
       "      <td>0.603246</td>\n",
       "      <td>0.645942</td>\n",
       "      <td>0.115053</td>\n",
       "      <td>0.860870</td>\n",
       "      <td>0.656279</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.618762</td>\n",
       "      <td>0.365778</td>\n",
       "      <td>0.232566</td>\n",
       "      <td>0.434179</td>\n",
       "      <td>0.800469</td>\n",
       "      <td>0.493136</td>\n",
       "      <td>0.386309</td>\n",
       "      <td>0.298686</td>\n",
       "      <td>0.182332</td>\n",
       "      <td>0.464039</td>\n",
       "      <td>0.449998</td>\n",
       "      <td>0.571479</td>\n",
       "      <td>0.949356</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.976243</td>\n",
       "      <td>0.462703</td>\n",
       "      <td>0.850762</td>\n",
       "      <td>0.347162</td>\n",
       "      <td>0.848853</td>\n",
       "      <td>0.698433</td>\n",
       "      <td>0.866704</td>\n",
       "      <td>0.552660</td>\n",
       "      <td>0.723015</td>\n",
       "      <td>0.372361</td>\n",
       "      <td>0.396304</td>\n",
       "      <td>0.549000</td>\n",
       "      <td>0.321202</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.765457</td>\n",
       "      <td>0.635224</td>\n",
       "      <td>0.525063</td>\n",
       "      <td>0.009490</td>\n",
       "      <td>0.581303</td>\n",
       "      <td>0.718589</td>\n",
       "      <td>0.749601</td>\n",
       "      <td>0.134644</td>\n",
       "      <td>0.388569</td>\n",
       "      <td>0.550524</td>\n",
       "      <td>0.408401</td>\n",
       "      <td>0.797959</td>\n",
       "      <td>0.753048</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          0         1         2         3         4         5         6  \\\n",
       "0  0.648708  0.556606  0.475915  0.604599  0.296312  0.575440  0.590239   \n",
       "1  0.954101  0.608156  0.793287  0.375929  0.854361  0.659643  0.912113   \n",
       "2  0.618762  0.365778  0.232566  0.434179  0.800469  0.493136  0.386309   \n",
       "3  0.976243  0.462703  0.850762  0.347162  0.848853  0.698433  0.866704   \n",
       "4  0.765457  0.635224  0.525063  0.009490  0.581303  0.718589  0.749601   \n",
       "\n",
       "          7         8         9        10        11        12  label  \n",
       "0  0.994838  0.260606  0.326840  0.892040  0.653399  0.449623      0  \n",
       "1  0.117836  0.603246  0.645942  0.115053  0.860870  0.656279      0  \n",
       "2  0.298686  0.182332  0.464039  0.449998  0.571479  0.949356      0  \n",
       "3  0.552660  0.723015  0.372361  0.396304  0.549000  0.321202      0  \n",
       "4  0.134644  0.388569  0.550524  0.408401  0.797959  0.753048      0  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synth_data = pd.read_csv('synthetic_data.csv')\n",
    "synth_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_shadow =  RandomForestClassifier(n_estimators=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `label` is the last column; every column before it is a feature\n",
    "x, y = synth_data.iloc[:, :-1].to_numpy(), synth_data['label'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using sklearn shadow models\n"
     ]
    },
    {
     "ename": "ImportError",
     "evalue": "IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m sh \u001b[38;5;241m=\u001b[39m ShadowModels(x, y, \u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m3\u001b[39m, rf_shadow)\n",
      "File \u001b[0;32m~/Reprogramming/notebooks/mblearn/shadow_model.py:76\u001b[0m, in \u001b[0;36mShadowModels.__init__\u001b[0;34m(self, X, y, n_models, target_classes, learner, **fit_kwargs)\u001b[0m\n\u001b[1;32m     73\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_model_list(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlearner, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_models)\n\u001b[1;32m     75\u001b[0m \u001b[38;5;66;03m# train models\u001b[39;00m\n\u001b[0;32m---> 76\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresults \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_predict_shadows(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mfit_kwargs)\n",
      "File \u001b[0;32m~/Reprogramming/notebooks/mblearn/shadow_model.py:145\u001b[0m, in \u001b[0;36mShadowModels.train_predict_shadows\u001b[0;34m(self, **fit_kwargs)\u001b[0m\n\u001b[1;32m    143\u001b[0m \u001b[38;5;66;03m# TRAIN and predict\u001b[39;00m\n\u001b[1;32m    144\u001b[0m results \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 145\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m model, data_subset \u001b[38;5;129;01min\u001b[39;00m tqdm_notebook(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodels, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_splits)):\n\u001b[1;32m    146\u001b[0m     X \u001b[38;5;241m=\u001b[39m data_subset[:, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m    147\u001b[0m     y \u001b[38;5;241m=\u001b[39m data_subset[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[0;32m~/.conda/envs/TomTestEnv/lib/python3.11/site-packages/tqdm/__init__.py:28\u001b[0m, in \u001b[0;36mtqdm_notebook\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnotebook\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m tqdm \u001b[38;5;28;01mas\u001b[39;00m _tqdm_notebook\n\u001b[1;32m     25\u001b[0m warn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis function will be removed in tqdm==5.0.0\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     26\u001b[0m      \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     27\u001b[0m      TqdmDeprecationWarning, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _tqdm_notebook(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m~/.conda/envs/TomTestEnv/lib/python3.11/site-packages/tqdm/notebook.py:238\u001b[0m, in \u001b[0;36mtqdm_notebook.__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    236\u001b[0m unit_scale \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munit_scale \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munit_scale \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    237\u001b[0m total \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal \u001b[38;5;241m*\u001b[39m unit_scale \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstatus_printer(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfp, total, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdesc, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mncols)\n\u001b[1;32m    239\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontainer\u001b[38;5;241m.\u001b[39mpbar \u001b[38;5;241m=\u001b[39m proxy(\u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m    240\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdisplayed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/TomTestEnv/lib/python3.11/site-packages/tqdm/notebook.py:113\u001b[0m, in \u001b[0;36mtqdm_notebook.status_printer\u001b[0;34m(_, total, desc, ncols)\u001b[0m\n\u001b[1;32m    104\u001b[0m \u001b[38;5;66;03m# Fallback to text bar if there's no total\u001b[39;00m\n\u001b[1;32m    105\u001b[0m \u001b[38;5;66;03m# DEPRECATED: replaced with an 'info' style bar\u001b[39;00m\n\u001b[1;32m    106\u001b[0m \u001b[38;5;66;03m# if not total:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    110\u001b[0m \n\u001b[1;32m    111\u001b[0m \u001b[38;5;66;03m# Prepare IPython progress bar\u001b[39;00m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m IProgress \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:  \u001b[38;5;66;03m# #187 #451 #558 #872\u001b[39;00m\n\u001b[0;32m--> 113\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(WARN_NOIPYW)\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m total:\n\u001b[1;32m    115\u001b[0m     pbar \u001b[38;5;241m=\u001b[39m IProgress(\u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mtotal)\n",
      "\u001b[0;31mImportError\u001b[0m: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html"
     ]
    }
   ],
   "source": [
    "# 5 shadow models, 3 target classes, rf_shadow as the learner\n",
    "sh = ShadowModels(x, y, n_models=5, target_classes=3, learner=rf_shadow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'sh' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m shadow_data \u001b[38;5;241m=\u001b[39m sh\u001b[38;5;241m.\u001b[39mresults\n",
      "\u001b[0;31mNameError\u001b[0m: name 'sh' is not defined"
     ]
    }
   ],
   "source": [
    "shadow_data = sh.results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attacker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the shadow dataset we can train the attacker model. The attacker learner doesn't need to be the same as the target so pick the one that performs the best.\n",
    "\n",
    "`AttackModels` trains a model for each original class in shadow data (and in target model) with the in/out of training label as the target label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_attack = RandomForestClassifier(n_estimators=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "attacker = AttackModels(target_classes=3, attack_learner=rf_attack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "attacker.fit(shadow_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's test on a 50/50 mix: a random sample of `X_train` (members) the same size as `X_test` (non-members)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw as many training rows as there are test rows so the in/out sets are balanced.\n",
    "# replace=False: np.random.choice samples WITH replacement by default, which would\n",
    "# put duplicated rows in the 'in' set instead of a true subset of X_train.\n",
    "train_idx = np.random.choice(len(X_train), size=len(X_test), replace=False)\n",
    "X_train = X_train[train_idx]\n",
    "y_train = y_train[train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "72 72\n"
     ]
    }
   ],
   "source": [
    "print(len(X_train), len(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_in = rf_target.predict_proba(X_train)\n",
    "res_in = attacker.predict(X_in, y_train, batch=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_out = rf_target.predict_proba(X_test)\n",
    "res_out = attacker.predict(X_out, y_test, batch=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9583333333333334"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of the 'in' (X_train) samples whose attacker argmax is 1\n",
    "np.argmax(res_in, axis=1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1527777777777778"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One minus the fraction of the 'out' (X_test) samples whose attacker argmax is 1\n",
    "1 - np.argmax(res_out, axis=1).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Precision, Recall and F-1 \n",
    "Since the class balance is 50/50, a dumb classifier that labels every sample as 'in' will achieve 0.5 precision, 1.0 recall and 0.67 F-1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score, f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted membership: argmax over the attacker outputs, 'in' samples first\n",
    "pred_in = np.argmax(res_in, axis=1)\n",
    "pred_out = np.argmax(res_out, axis=1)\n",
    "y_pred = np.concatenate((pred_in, pred_out))\n",
    "\n",
    "# Ground truth in the same order: 1 for the X_train rows, 0 for the X_test rows\n",
    "y_true = np.concatenate((np.ones_like(y_train), np.zeros_like(y_test)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5307692307692308"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9583333333333334"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6831683168316831"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1_score(y_true, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not bad for an out of the box setup"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-TomTestEnv] *",
   "language": "python",
   "name": "conda-env-.conda-TomTestEnv-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
