{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ae30370",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import sklearn.metrics as metrics\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "from SupervisedAD_methods import *\n",
    "from kdd import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "optical-league",
   "metadata": {},
   "source": [
    "# Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "complete-bobby",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>protocol_type</th>\n",
       "      <th>service</th>\n",
       "      <th>flag</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_same_srv_rate</th>\n",
       "      <th>dst_host_diff_srv_rate</th>\n",
       "      <th>dst_host_same_src_port_rate</th>\n",
       "      <th>dst_host_srv_diff_host_rate</th>\n",
       "      <th>dst_host_serror_rate</th>\n",
       "      <th>dst_host_srv_serror_rate</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>dst_host_srv_rerror_rate</th>\n",
       "      <th>attack</th>\n",
       "      <th>level</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>ftp_data</td>\n",
       "      <td>SF</td>\n",
       "      <td>491</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.17</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.17</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>udp</td>\n",
       "      <td>other</td>\n",
       "      <td>SF</td>\n",
       "      <td>146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.60</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>private</td>\n",
       "      <td>S0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>neptune</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>232</td>\n",
       "      <td>8153</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>199</td>\n",
       "      <td>420</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125968</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>private</td>\n",
       "      <td>S0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.10</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>neptune</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125969</th>\n",
       "      <td>8</td>\n",
       "      <td>udp</td>\n",
       "      <td>private</td>\n",
       "      <td>SF</td>\n",
       "      <td>105</td>\n",
       "      <td>145</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125970</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>smtp</td>\n",
       "      <td>SF</td>\n",
       "      <td>2231</td>\n",
       "      <td>384</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.12</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125971</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>klogin</td>\n",
       "      <td>S0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>neptune</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125972</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>ftp_data</td>\n",
       "      <td>SF</td>\n",
       "      <td>151</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.30</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.30</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>125973 rows \u00d7 43 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        duration protocol_type   service flag  src_bytes  dst_bytes  land  \\\n",
       "0              0           tcp  ftp_data   SF        491          0     0   \n",
       "1              0           udp     other   SF        146          0     0   \n",
       "2              0           tcp   private   S0          0          0     0   \n",
       "3              0           tcp      http   SF        232       8153     0   \n",
       "4              0           tcp      http   SF        199        420     0   \n",
       "...          ...           ...       ...  ...        ...        ...   ...   \n",
       "125968         0           tcp   private   S0          0          0     0   \n",
       "125969         8           udp   private   SF        105        145     0   \n",
       "125970         0           tcp      smtp   SF       2231        384     0   \n",
       "125971         0           tcp    klogin   S0          0          0     0   \n",
       "125972         0           tcp  ftp_data   SF        151          0     0   \n",
       "\n",
       "        wrong_fragment  urgent  hot  ...  dst_host_same_srv_rate  \\\n",
       "0                    0       0    0  ...                    0.17   \n",
       "1                    0       0    0  ...                    0.00   \n",
       "2                    0       0    0  ...                    0.10   \n",
       "3                    0       0    0  ...                    1.00   \n",
       "4                    0       0    0  ...                    1.00   \n",
       "...                ...     ...  ...  ...                     ...   \n",
       "125968               0       0    0  ...                    0.10   \n",
       "125969               0       0    0  ...                    0.96   \n",
       "125970               0       0    0  ...                    0.12   \n",
       "125971               0       0    0  ...                    0.03   \n",
       "125972               0       0    0  ...                    0.30   \n",
       "\n",
       "        dst_host_diff_srv_rate  dst_host_same_src_port_rate  \\\n",
       "0                         0.03                         0.17   \n",
       "1                         0.60                         0.88   \n",
       "2                         0.05                         0.00   \n",
       "3                         0.00                         0.03   \n",
       "4                         0.00                         0.00   \n",
       "...                        ...                          ...   \n",
       "125968                    0.06                         0.00   \n",
       "125969                    0.01                         0.01   \n",
       "125970                    0.06                         0.00   \n",
       "125971                    0.05                         0.00   \n",
       "125972                    0.03                         0.30   \n",
       "\n",
       "        dst_host_srv_diff_host_rate  dst_host_serror_rate  \\\n",
       "0                              0.00                  0.00   \n",
       "1                              0.00                  0.00   \n",
       "2                              0.00                  1.00   \n",
       "3                              0.04                  0.03   \n",
       "4                              0.00                  0.00   \n",
       "...                             ...                   ...   \n",
       "125968                         0.00                  1.00   \n",
       "125969                         0.00                  0.00   \n",
       "125970                         0.00                  0.72   \n",
       "125971                         0.00                  1.00   \n",
       "125972                         0.00                  0.00   \n",
       "\n",
       "        dst_host_srv_serror_rate  dst_host_rerror_rate  \\\n",
       "0                           0.00                  0.05   \n",
       "1                           0.00                  0.00   \n",
       "2                           1.00                  0.00   \n",
       "3                           0.01                  0.00   \n",
       "4                           0.00                  0.00   \n",
       "...                          ...                   ...   \n",
       "125968                      1.00                  0.00   \n",
       "125969                      0.00                  0.00   \n",
       "125970                      0.00                  0.01   \n",
       "125971                      1.00                  0.00   \n",
       "125972                      0.00                  0.00   \n",
       "\n",
       "        dst_host_srv_rerror_rate   attack  level  \n",
       "0                           0.00   normal     20  \n",
       "1                           0.00   normal     15  \n",
       "2                           0.00  neptune     19  \n",
       "3                           0.01   normal     21  \n",
       "4                           0.00   normal     21  \n",
       "...                          ...      ...    ...  \n",
       "125968                      0.00  neptune     20  \n",
       "125969                      0.00   normal     21  \n",
       "125970                      0.00   normal     18  \n",
       "125971                      0.00  neptune     20  \n",
       "125972                      0.00   normal     21  \n",
       "\n",
       "[125973 rows x 43 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NSL-KDD training split: all 43 columns, including the attack and level labels\n",
    "df = get_df(\n",
    "    'data/KDDTrain+.txt',\n",
    "    columns=columns,\n",
    "    drop=False,\n",
    ")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "vulnerable-occupation",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>duration</th>\n",
       "      <th>protocol_type</th>\n",
       "      <th>service</th>\n",
       "      <th>flag</th>\n",
       "      <th>src_bytes</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>land</th>\n",
       "      <th>wrong_fragment</th>\n",
       "      <th>urgent</th>\n",
       "      <th>hot</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_same_srv_rate</th>\n",
       "      <th>dst_host_diff_srv_rate</th>\n",
       "      <th>dst_host_same_src_port_rate</th>\n",
       "      <th>dst_host_srv_diff_host_rate</th>\n",
       "      <th>dst_host_serror_rate</th>\n",
       "      <th>dst_host_srv_serror_rate</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>dst_host_srv_rerror_rate</th>\n",
       "      <th>attack</th>\n",
       "      <th>level</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>private</td>\n",
       "      <td>REJ</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>neptune</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>private</td>\n",
       "      <td>REJ</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>neptune</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>tcp</td>\n",
       "      <td>ftp_data</td>\n",
       "      <td>SF</td>\n",
       "      <td>12983</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>icmp</td>\n",
       "      <td>eco_i</td>\n",
       "      <td>SF</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>saint</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>tcp</td>\n",
       "      <td>telnet</td>\n",
       "      <td>RSTO</td>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.31</td>\n",
       "      <td>0.17</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.83</td>\n",
       "      <td>0.71</td>\n",
       "      <td>mscan</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22539</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>smtp</td>\n",
       "      <td>SF</td>\n",
       "      <td>794</td>\n",
       "      <td>333</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.06</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22540</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>317</td>\n",
       "      <td>938</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22541</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>http</td>\n",
       "      <td>SF</td>\n",
       "      <td>54540</td>\n",
       "      <td>8314</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>...</td>\n",
       "      <td>1.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.07</td>\n",
       "      <td>0.07</td>\n",
       "      <td>back</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22542</th>\n",
       "      <td>0</td>\n",
       "      <td>udp</td>\n",
       "      <td>domain_u</td>\n",
       "      <td>SF</td>\n",
       "      <td>42</td>\n",
       "      <td>42</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.99</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>normal</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22543</th>\n",
       "      <td>0</td>\n",
       "      <td>tcp</td>\n",
       "      <td>sunrpc</td>\n",
       "      <td>REJ</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.03</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.44</td>\n",
       "      <td>1.00</td>\n",
       "      <td>mscan</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>22544 rows \u00d7 43 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       duration protocol_type   service  flag  src_bytes  dst_bytes  land  \\\n",
       "0             0           tcp   private   REJ          0          0     0   \n",
       "1             0           tcp   private   REJ          0          0     0   \n",
       "2             2           tcp  ftp_data    SF      12983          0     0   \n",
       "3             0          icmp     eco_i    SF         20          0     0   \n",
       "4             1           tcp    telnet  RSTO          0         15     0   \n",
       "...         ...           ...       ...   ...        ...        ...   ...   \n",
       "22539         0           tcp      smtp    SF        794        333     0   \n",
       "22540         0           tcp      http    SF        317        938     0   \n",
       "22541         0           tcp      http    SF      54540       8314     0   \n",
       "22542         0           udp  domain_u    SF         42         42     0   \n",
       "22543         0           tcp    sunrpc   REJ          0          0     0   \n",
       "\n",
       "       wrong_fragment  urgent  hot  ...  dst_host_same_srv_rate  \\\n",
       "0                   0       0    0  ...                    0.04   \n",
       "1                   0       0    0  ...                    0.00   \n",
       "2                   0       0    0  ...                    0.61   \n",
       "3                   0       0    0  ...                    1.00   \n",
       "4                   0       0    0  ...                    0.31   \n",
       "...               ...     ...  ...  ...                     ...   \n",
       "22539               0       0    0  ...                    0.72   \n",
       "22540               0       0    0  ...                    1.00   \n",
       "22541               0       0    2  ...                    1.00   \n",
       "22542               0       0    0  ...                    0.99   \n",
       "22543               0       0    0  ...                    0.08   \n",
       "\n",
       "       dst_host_diff_srv_rate  dst_host_same_src_port_rate  \\\n",
       "0                        0.06                         0.00   \n",
       "1                        0.06                         0.00   \n",
       "2                        0.04                         0.61   \n",
       "3                        0.00                         1.00   \n",
       "4                        0.17                         0.03   \n",
       "...                       ...                          ...   \n",
       "22539                    0.06                         0.01   \n",
       "22540                    0.00                         0.01   \n",
       "22541                    0.00                         0.00   \n",
       "22542                    0.01                         0.00   \n",
       "22543                    0.03                         0.00   \n",
       "\n",
       "       dst_host_srv_diff_host_rate  dst_host_serror_rate  \\\n",
       "0                             0.00                  0.00   \n",
       "1                             0.00                  0.00   \n",
       "2                             0.02                  0.00   \n",
       "3                             0.28                  0.00   \n",
       "4                             0.02                  0.00   \n",
       "...                            ...                   ...   \n",
       "22539                         0.01                  0.01   \n",
       "22540                         0.01                  0.01   \n",
       "22541                         0.00                  0.00   \n",
       "22542                         0.00                  0.00   \n",
       "22543                         0.00                  0.00   \n",
       "\n",
       "       dst_host_srv_serror_rate  dst_host_rerror_rate  \\\n",
       "0                           0.0                  1.00   \n",
       "1                           0.0                  1.00   \n",
       "2                           0.0                  0.00   \n",
       "3                           0.0                  0.00   \n",
       "4                           0.0                  0.83   \n",
       "...                         ...                   ...   \n",
       "22539                       0.0                  0.00   \n",
       "22540                       0.0                  0.00   \n",
       "22541                       0.0                  0.07   \n",
       "22542                       0.0                  0.00   \n",
       "22543                       0.0                  0.44   \n",
       "\n",
       "       dst_host_srv_rerror_rate   attack  level  \n",
       "0                          1.00  neptune     21  \n",
       "1                          1.00  neptune     21  \n",
       "2                          0.00   normal     21  \n",
       "3                          0.00    saint     15  \n",
       "4                          0.71    mscan     11  \n",
       "...                         ...      ...    ...  \n",
       "22539                      0.00   normal     21  \n",
       "22540                      0.00   normal     21  \n",
       "22541                      0.07     back     15  \n",
       "22542                      0.00   normal     21  \n",
       "22543                      1.00    mscan     14  \n",
       "\n",
       "[22544 rows x 43 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df = get_df('data/KDDTest+.txt', columns=columns, drop=False)\n",
    "test_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "front-remove",
   "metadata": {},
   "outputs": [],
   "source": [
    "#  https://www.kaggle.com/code/avk256/nsl-kdd-anomaly-detection/notebook\n",
    "\n",
    "# map normal to 0, all attacks to 1\n",
    "is_attack = df.attack.map(lambda a: 0 if a == 'normal' else 1)\n",
    "test_attack = test_df.attack.map(lambda a: 0 if a == 'normal' else 1)\n",
    "\n",
    "#data_with_attack = df.join(is_attack, rsuffix='_flag')\n",
    "df['attack_flag'] = is_attack\n",
    "test_df['attack_flag'] = test_attack\n",
    "\n",
    "# map normal to 1, all attacks to 0\n",
    "is_normal = df.attack.map(lambda a: 1 if a == 'normal' else 0)\n",
    "test_normal = test_df.attack.map(lambda a: 1 if a == 'normal' else 0)\n",
    "\n",
    "df['normal_flag'] = is_normal\n",
    "test_df['normal_flag'] = test_normal\n",
    "\n",
    "# map the data and join to the data set\n",
    "attack_map = df.attack.apply(map_attack)\n",
    "df['attack_map'] = attack_map\n",
    "\n",
    "test_attack_map = test_df.attack.apply(map_attack)\n",
    "test_df['attack_map'] = test_attack_map\n",
    "\n",
    "# categorical features\n",
    "features_to_encode = ['protocol_type', 'service', 'flag']\n",
    "\n",
    "# get numeric features, we won't worry about encoding these at this point\n",
    "# numeric_features = ['duration', 'src_bytes', 'dst_bytes']\n",
    "# Use all features\n",
    "numeric_features = list(set(df.columns[:-5]) - set(features_to_encode))\n",
    "\n",
    "\n",
    "def feat_eng(df, test_df, features_to_encode=features_to_encode, numeric_features=numeric_features):\n",
    "#     https://www.kaggle.com/code/avk256/nsl-kdd-anomaly-detection/notebook\n",
    "\n",
    "    # get the intial set of encoded features and encode them\n",
    "    encoded = pd.get_dummies(df[features_to_encode])\n",
    "    test_encoded_base = pd.get_dummies(test_df[features_to_encode])\n",
    "\n",
    "    # not all of the features are in the test set, so we need to account for diffs\n",
    "    test_index = np.arange(len(test_df.index))\n",
    "    column_diffs = list(set(encoded.columns.values)-set(test_encoded_base.columns.values))\n",
    "\n",
    "    diff_df = pd.DataFrame(0, index=test_index, columns=column_diffs)\n",
    "\n",
    "    # we'll also need to reorder the columns to match, so let's get those\n",
    "    column_order = encoded.columns.to_list()\n",
    "\n",
    "    # append the new columns\n",
    "    test_encoded_temp = test_encoded_base.join(diff_df)\n",
    "\n",
    "    # reorder the columns\n",
    "    test_final = test_encoded_temp[column_order].fillna(0)\n",
    "\n",
    "    # model to fit/test\n",
    "    to_fit = encoded.join(df[numeric_features])\n",
    "    test_set = test_final.join(test_df[numeric_features])\n",
    "    \n",
    "    return to_fit, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "lyric-public",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>protocol_type_icmp</th>\n",
       "      <th>protocol_type_tcp</th>\n",
       "      <th>protocol_type_udp</th>\n",
       "      <th>service_IRC</th>\n",
       "      <th>service_X11</th>\n",
       "      <th>service_Z39_50</th>\n",
       "      <th>service_aol</th>\n",
       "      <th>service_auth</th>\n",
       "      <th>service_bgp</th>\n",
       "      <th>service_courier</th>\n",
       "      <th>...</th>\n",
       "      <th>dst_host_rerror_rate</th>\n",
       "      <th>is_guest_login</th>\n",
       "      <th>srv_serror_rate</th>\n",
       "      <th>srv_rerror_rate</th>\n",
       "      <th>diff_srv_rate</th>\n",
       "      <th>count</th>\n",
       "      <th>duration</th>\n",
       "      <th>num_file_creations</th>\n",
       "      <th>dst_bytes</th>\n",
       "      <th>num_root</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.05</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.15</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.07</td>\n",
       "      <td>123</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8153</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>30</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>420</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125968</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.06</td>\n",
       "      <td>184</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125969</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>145</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125970</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.01</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>384</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125971</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.05</td>\n",
       "      <td>144</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125972</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>125973 rows \u00c3\u2014 122 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        protocol_type_icmp  protocol_type_tcp  protocol_type_udp  service_IRC  \\\n",
       "0                        0                  1                  0            0   \n",
       "1                        0                  0                  1            0   \n",
       "2                        0                  1                  0            0   \n",
       "3                        0                  1                  0            0   \n",
       "4                        0                  1                  0            0   \n",
       "...                    ...                ...                ...          ...   \n",
       "125968                   0                  1                  0            0   \n",
       "125969                   0                  0                  1            0   \n",
       "125970                   0                  1                  0            0   \n",
       "125971                   0                  1                  0            0   \n",
       "125972                   0                  1                  0            0   \n",
       "\n",
       "        service_X11  service_Z39_50  service_aol  service_auth  service_bgp  \\\n",
       "0                 0               0            0             0            0   \n",
       "1                 0               0            0             0            0   \n",
       "2                 0               0            0             0            0   \n",
       "3                 0               0            0             0            0   \n",
       "4                 0               0            0             0            0   \n",
       "...             ...             ...          ...           ...          ...   \n",
       "125968            0               0            0             0            0   \n",
       "125969            0               0            0             0            0   \n",
       "125970            0               0            0             0            0   \n",
       "125971            0               0            0             0            0   \n",
       "125972            0               0            0             0            0   \n",
       "\n",
       "        service_courier  ...  dst_host_rerror_rate  is_guest_login  \\\n",
       "0                     0  ...                  0.05               0   \n",
       "1                     0  ...                  0.00               0   \n",
       "2                     0  ...                  0.00               0   \n",
       "3                     0  ...                  0.00               0   \n",
       "4                     0  ...                  0.00               0   \n",
       "...                 ...  ...                   ...             ...   \n",
       "125968                0  ...                  0.00               0   \n",
       "125969                0  ...                  0.00               0   \n",
       "125970                0  ...                  0.01               0   \n",
       "125971                0  ...                  0.00               0   \n",
       "125972                0  ...                  0.00               0   \n",
       "\n",
       "        srv_serror_rate  srv_rerror_rate  diff_srv_rate  count  duration  \\\n",
       "0                   0.0              0.0           0.00      2         0   \n",
       "1                   0.0              0.0           0.15     13         0   \n",
       "2                   1.0              0.0           0.07    123         0   \n",
       "3                   0.2              0.0           0.00      5         0   \n",
       "4                   0.0              0.0           0.00     30         0   \n",
       "...                 ...              ...            ...    ...       ...   \n",
       "125968              1.0              0.0           0.06    184         0   \n",
       "125969              0.0              0.0           0.00      2         8   \n",
       "125970              0.0              0.0           0.00      1         0   \n",
       "125971              1.0              0.0           0.05    144         0   \n",
       "125972              0.0              0.0           0.00      1         0   \n",
       "\n",
       "        num_file_creations  dst_bytes  num_root  \n",
       "0                        0          0         0  \n",
       "1                        0          0         0  \n",
       "2                        0          0         0  \n",
       "3                        0       8153         0  \n",
       "4                        0        420         0  \n",
       "...                    ...        ...       ...  \n",
       "125968                   0          0         0  \n",
       "125969                   0        145         0  \n",
       "125970                   0        384         0  \n",
       "125971                   0          0         0  \n",
       "125972                   0          0         0  \n",
       "\n",
       "[125973 rows x 122 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_train, data_test = feat_eng(df, test_df)\n",
    "data_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "comprehensive-shock",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    67352\n",
       "1    45927\n",
       "2    11656\n",
       "4      995\n",
       "3       43\n",
       "Name: attack_map, dtype: int64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['attack_map'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "matched-journey",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    9855\n",
       "1    7460\n",
       "4    2743\n",
       "2    2421\n",
       "3      65\n",
       "Name: attack_map, dtype: int64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df['attack_map'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "organic-cooperative",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "\n",
    "new_attacks = [1,2,3,4]\n",
    "test_classes = [0,1,2,3,4]\n",
    "\n",
    "\n",
    "def get_x_y(df, data, classes=[0,1]):\n",
    "\n",
    "    indices = df['attack_map'].isin(classes)\n",
    "    x = data[indices]\n",
    "    y = df['normal_flag'][indices]\n",
    "    \n",
    "    return x.to_numpy(), y.to_numpy()\n",
    "\n",
    "\n",
    "x_train, y = get_x_y(df, data_train)\n",
    "X = scaler.fit_transform(x_train)\n",
    "\n",
    "np.random.seed(0)\n",
    "np.random.shuffle(X)\n",
    "np.random.seed(0)\n",
    "np.random.shuffle(y)\n",
    "\n",
    "x_normal = X[y==1]\n",
    "\n",
    "# x_testing, y_test = get_x_y(test_df, data_test, classes=test_classes)\n",
    "# x_test = scaler.transform(x_testing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bulgarian-candle",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MirroredStrategy()\n",
    "num_inputs = X.shape[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dynamic-selection",
   "metadata": {},
   "source": [
    "# Modelling: Pretrain AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "timely-marks",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build Models\n",
    "\n",
    "def build_seq_layer(activation, sigma=0.5, train=False, layer_number=1,\n",
    "                seed=0, neurons=5, batchnorm=True, regulariser=None):\n",
    "    \n",
    "    layer = []\n",
    "    \n",
    "    initialiser = tf.keras.initializers.GlorotUniform(seed=seed)\n",
    "    \n",
    "    if activation == \"r\":\n",
    "        layer.append(RBFLayer(neurons, gamma=1.0, initializer=initialiser))\n",
    "        \n",
    "        if batchnorm:\n",
    "            layer.append(tf.keras.layers.BatchNormalization())\n",
    "            \n",
    "    else:\n",
    "        layer.append(tf.keras.layers.Dense(neurons,\n",
    "                      kernel_initializer=initialiser, kernel_regularizer=regulariser))\n",
    "        \n",
    "        if batchnorm:\n",
    "            layer.append(tf.keras.layers.BatchNormalization())\n",
    "            \n",
    "        if activation == \"b\":\n",
    "            layer.append(Bump(sigma=sigma, trainable=train,\n",
    "                              name=f\"bump{layer_number+1}\"))\n",
    "        elif activation == \"s\":\n",
    "            layer.append(tf.keras.layers.Activation(tf.keras.activations.sigmoid))\n",
    "        else:\n",
    "            layer.append(tf.keras.layers.LeakyReLU(alpha=0.01))\n",
    "    \n",
    "    return layer\n",
    "\n",
    "\n",
    "class AE_Module(tf.keras.Model):\n",
    "    \n",
    "    def __init__(self, input_dim=122, hidden_dim=[60, 45], output_dim=30,\n",
    "                 activation=tf.keras.layers.LeakyReLU(alpha=0.3),\n",
    "                 regulariser=None, decoder=False, seed=0, **kwargs):\n",
    "        \n",
    "        super(AE_Module, self).__init__()\n",
    "        \n",
    "        layers = tf.keras.Sequential([tf.keras.Input(shape=(input_dim,))])\n",
    "        \n",
    "        for i, neurons in enumerate(hidden_dim):\n",
    "            \n",
    "            if type(activation) is str:\n",
    "                layer = build_seq_layer(activation, sigma=0.5,\n",
    "                                        train=False, layer_number=i,\n",
    "                                        seed=i*10+seed, neurons=neurons,\n",
    "                                        batchnorm=True, regulariser=regulariser)\n",
    "                for l in layer:\n",
    "                    layers.add(l)\n",
    "                    \n",
    "            else:\n",
    "            \n",
    "                initialiser = tf.keras.initializers.GlorotUniform(seed)\n",
    "            \n",
    "                layers.add(\n",
    "                    tf.keras.layers.Dense(neurons,\n",
    "                                          kernel_initializer=initialiser,\n",
    "                                          kernel_regularizer=regulariser))\n",
    "                layers.add(tf.keras.layers.BatchNormalization())\n",
    "\n",
    "                layers.add(activation)\n",
    "                \n",
    "        self.hidden_layers = layers\n",
    "        \n",
    "        initialiser = tf.keras.initializers.GlorotUniform(seed=2023+seed)\n",
    "        self.last_layer = tf.keras.layers.Dense(output_dim, use_bias=decoder,\n",
    "                                  kernel_initializer=initialiser,\n",
    "                                  kernel_regularizer=regulariser)\n",
    "        \n",
    "    def call(self, x):\n",
    "        \n",
    "#         for layer in self.hidden_layers:\n",
    "#             x = layer(x)\n",
    "            \n",
    "        x = self.hidden_layers(x)\n",
    "            \n",
    "        return self.last_layer(x)\n",
    "    \n",
    "\n",
    "    \n",
    "class AE(tf.keras.Model):\n",
    "    def __init__(self, input_dim=122, hidden_dim=[60, 45], latent_dim=30,\n",
    "                 activation=tf.keras.layers.LeakyReLU(alpha=0.3),\n",
    "                 regulariser=None, seed=0, **kwargs):\n",
    "        super(AE, self).__init__()\n",
    "        self.encoder = AE_Module(input_dim, hidden_dim, latent_dim,\n",
    "                                 activation, regulariser, decoder=False, seed=seed)\n",
    "        self.decoder = AE_Module(latent_dim, list(reversed(hidden_dim)), input_dim,\n",
    "                                 activation, regulariser, decoder=True, seed=seed)\n",
    "\n",
    "    def call(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "numerous-priority",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "Epoch 1/50\n",
      "INFO:tensorflow:batch_all_reduce: 3 all-reduces with algorithm = nccl, num_packs = 1\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:batch_all_reduce: 3 all-reduces with algorithm = nccl, num_packs = 1\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "3186/3186 [==============================] - ETA: 0s - loss: 0.3605INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n",
      "3186/3186 [==============================] - 19s 4ms/step - loss: 0.3605 - val_loss: 0.2693\n",
      "Epoch 2/50\n",
      "3186/3186 [==============================] - 13s 4ms/step - loss: 0.0509 - val_loss: 0.1859\n",
      "Epoch 3/50\n",
      "3186/3186 [==============================] - 13s 4ms/step - loss: 0.0136 - val_loss: 0.1696\n",
      "Epoch 4/50\n",
      "3186/3186 [==============================] - 13s 4ms/step - loss: 0.0061 - val_loss: 0.1724\n",
      "Epoch 5/50\n",
      "3186/3186 [==============================] - 13s 4ms/step - loss: 0.0052 - val_loss: 0.1706\n",
      "Epoch 6/50\n",
      "3186/3186 [==============================] - 13s 4ms/step - loss: 0.0046 - val_loss: 0.1769\n"
     ]
    }
   ],
   "source": [
    "# sanity check\n",
    "lr = 3e-4\n",
    "epochs = 50\n",
    "verbose = 1      # can change this to 0 to suppress verbosity during training-\n",
    "shuffle = False\n",
    "val_split = 0.1\n",
    "repeats = 2\n",
    "epochs = 50\n",
    "\n",
    "early_stopping = tf.keras.callbacks.EarlyStopping(patience=3, monitor='val_loss',\n",
    "                                                  restore_best_weights=True)\n",
    "\n",
    "callbacks = [early_stopping]\n",
    "\n",
    "hidden_dim = []\n",
    "latent_dim = 122\n",
    "\n",
    "with strategy.scope():\n",
    "    # change activation to desired activation\n",
    "    ae = AE(input_dim=num_inputs, hidden_dim=hidden_dim, latent_dim=latent_dim,\n",
    "                     activation=\"b\",\n",
    "                     regulariser=None, seed=0)\n",
    "    \n",
    "    ae.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=lr))\n",
    "    \n",
    "#     ae.summary()\n",
    "    \n",
    "    hist = ae.fit(X, X, epochs=epochs, verbose=verbose, validation_split=val_split, shuffle=shuffle, callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "useful-highway",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2675 - val_loss: 0.3193\n",
      "Epoch 2/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2649 - val_loss: 0.3154\n",
      "Epoch 3/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2623 - val_loss: 0.3124\n",
      "Epoch 4/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2598 - val_loss: 0.3121\n",
      "Epoch 5/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2575 - val_loss: 0.3104\n",
      "Epoch 6/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2555 - val_loss: 0.3074\n",
      "Epoch 7/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2533 - val_loss: 0.3050\n",
      "Epoch 8/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2514 - val_loss: 0.3048\n",
      "Epoch 9/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2491 - val_loss: 0.3024\n",
      "Epoch 10/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2470 - val_loss: 0.3017\n",
      "Epoch 11/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2452 - val_loss: 0.2987\n",
      "Epoch 12/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2431 - val_loss: 0.2981\n",
      "Epoch 13/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2412 - val_loss: 0.2967\n",
      "Epoch 14/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2393 - val_loss: 0.2957\n",
      "Epoch 15/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2375 - val_loss: 0.2961\n",
      "Epoch 16/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2358 - val_loss: 0.2938\n",
      "Epoch 17/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2340 - val_loss: 0.2934\n",
      "Epoch 18/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2327 - val_loss: 0.2917\n",
      "Epoch 19/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2306 - val_loss: 0.2910\n",
      "Epoch 20/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2291 - val_loss: 0.2893\n",
      "Epoch 21/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2276 - val_loss: 0.2888\n",
      "Epoch 22/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2260 - val_loss: 0.2872\n",
      "Epoch 23/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2245 - val_loss: 0.2869\n",
      "Epoch 24/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2232 - val_loss: 0.2846\n",
      "Epoch 25/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2212 - val_loss: 0.2848\n",
      "Epoch 26/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2196 - val_loss: 0.2841\n",
      "Epoch 27/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2180 - val_loss: 0.2827\n",
      "Epoch 28/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2164 - val_loss: 0.2805\n",
      "Epoch 29/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2148 - val_loss: 0.2822\n",
      "Epoch 30/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2134 - val_loss: 0.2804\n",
      "Epoch 31/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2119 - val_loss: 0.2792\n",
      "Epoch 32/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2107 - val_loss: 0.2793\n",
      "Epoch 33/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2094 - val_loss: 0.2774\n",
      "Epoch 34/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2083 - val_loss: 0.2778\n",
      "Epoch 35/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2075 - val_loss: 0.2751\n",
      "Epoch 36/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2060 - val_loss: 0.2756\n",
      "Epoch 37/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2053 - val_loss: 0.2745\n",
      "Epoch 38/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2041 - val_loss: 0.2739\n",
      "Epoch 39/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2030 - val_loss: 0.2736\n",
      "Epoch 40/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2020 - val_loss: 0.2734\n",
      "Epoch 41/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2009 - val_loss: 0.2732\n",
      "Epoch 42/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.2006 - val_loss: 0.2699\n",
      "Epoch 43/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1993 - val_loss: 0.2702\n",
      "Epoch 44/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1984 - val_loss: 0.2697\n",
      "Epoch 45/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1978 - val_loss: 0.2702\n",
      "Epoch 46/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1967 - val_loss: 0.2678\n",
      "Epoch 47/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1960 - val_loss: 0.2698\n",
      "Epoch 48/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1950 - val_loss: 0.2658\n",
      "Epoch 49/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1945 - val_loss: 0.2683\n",
      "Epoch 50/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1937 - val_loss: 0.2652\n",
      "Epoch 51/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1925 - val_loss: 0.2681\n",
      "Epoch 52/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1921 - val_loss: 0.2677\n",
      "Epoch 53/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1912 - val_loss: 0.2686\n",
      "Epoch 54/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1908 - val_loss: 0.2649\n",
      "Epoch 55/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1901 - val_loss: 0.2675\n",
      "Epoch 56/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1892 - val_loss: 0.2644\n",
      "Epoch 57/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1882 - val_loss: 0.2677\n",
      "Epoch 58/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1879 - val_loss: 0.2640\n",
      "Epoch 59/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1869 - val_loss: 0.2633\n",
      "Epoch 60/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1864 - val_loss: 0.2650\n",
      "Epoch 61/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1855 - val_loss: 0.2682\n",
      "Epoch 62/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1856 - val_loss: 0.2672\n",
      "Epoch 63/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1844 - val_loss: 0.2653\n",
      "Epoch 64/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1838 - val_loss: 0.2642\n",
      "Epoch 65/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1833 - val_loss: 0.2589\n",
      "Epoch 66/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1827 - val_loss: 0.2574\n",
      "Epoch 67/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1822 - val_loss: 0.2603\n",
      "Epoch 68/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1813 - val_loss: 0.2663\n",
      "Epoch 69/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1807 - val_loss: 0.2655\n",
      "Epoch 70/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1806 - val_loss: 0.2645\n",
      "Epoch 71/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1798 - val_loss: 0.2636\n",
      "Epoch 72/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1793 - val_loss: 0.2607\n",
      "Epoch 73/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1789 - val_loss: 0.2604\n",
      "Epoch 74/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1784 - val_loss: 0.2612\n",
      "Epoch 75/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1778 - val_loss: 0.2613\n",
      "Epoch 76/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1776 - val_loss: 0.2532\n",
      "Epoch 77/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1770 - val_loss: 0.2539\n",
      "Epoch 78/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1764 - val_loss: 0.2583\n",
      "Epoch 79/500\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1763 - val_loss: 0.2545\n",
      "Epoch 80/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1758 - val_loss: 0.2547\n",
      "Epoch 81/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1755 - val_loss: 0.2529\n",
      "Epoch 82/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1751 - val_loss: 0.2551\n",
      "Epoch 83/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1744 - val_loss: 0.2567\n",
      "Epoch 84/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1741 - val_loss: 0.2532\n",
      "Epoch 85/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1738 - val_loss: 0.2583\n",
      "Epoch 86/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1735 - val_loss: 0.2680\n",
      "Epoch 87/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1729 - val_loss: 0.2813\n",
      "Epoch 88/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1725 - val_loss: 0.2679\n",
      "Epoch 89/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1726 - val_loss: 0.2684\n",
      "Epoch 90/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1722 - val_loss: 0.2739\n",
      "Epoch 91/500\n",
      "3186/3186 [==============================] - 19s 6ms/step - loss: 0.1718 - val_loss: 0.2750\n"
     ]
    }
   ],
   "source": [
    "# Continues training the autoencoder `ae`, which must already be built and compiled\n",
    "# earlier in the notebook. Its optimizer and learning rate were fixed at compile time,\n",
    "# so none is set here.\n",
    "hidden_dim = [90]   # architecture of `ae`; presumably matches how it was built -- confirm\n",
    "latent_dim = 60\n",
    "\n",
    "verbose = 1         # set to 0 to suppress per-epoch progress output during training\n",
    "shuffle = False     # keep the batch order fixed across epochs\n",
    "val_split = 0.1     # Keras holds out the *last* 10% of X (taken before any shuffling) for validation\n",
    "epochs = 500        # upper bound only; EarlyStopping below usually ends training sooner\n",
    "\n",
    "# Stop once val_loss has not improved for 10 epochs and roll back to the best weights seen.\n",
    "early_stopping = tf.keras.callbacks.EarlyStopping(patience=10, monitor='val_loss',\n",
    "                                                  restore_best_weights=True)\n",
    "callbacks = [early_stopping]\n",
    "\n",
    "# Only model construction and compile() need `strategy.scope()`; fit() runs under the\n",
    "# strategy the model was created with (see the tf.distribute Keras guide).\n",
    "hist = ae.fit(X, X, epochs=epochs, verbose=verbose, validation_split=val_split,\n",
    "              shuffle=shuffle, callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "brief-paraguay",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: models/KDD99/Encoder_Bump_90_60/assets\n",
      "INFO:tensorflow:Assets written to: models/KDD99/Decoder_Bump_90_60/assets\n"
     ]
    }
   ],
   "source": [
    "# Save encoder and decoder separately. The directory suffix is derived from the\n",
    "# hidden_dim/latent_dim set in the training cell, so saved names stay in sync with the\n",
    "# architecture instead of being hard-coded (for [90] and 60 this gives \"Bump_90_60\").\n",
    "model_tag = \"Bump_\" + \"_\".join(map(str, hidden_dim)) + f\"_{latent_dim}\"\n",
    "ae.encoder.save(f\"models/KDD99/Encoder_{model_tag}\")\n",
    "ae.decoder.save(f\"models/KDD99/Decoder_{model_tag}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}