{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports used throughout the notebook.\n",
    "# NOTE(review): nn, optim, deepcopy, plt, QuantileTransformer, datetime, gc, LWTA and\n",
    "# Gsoftmax are not referenced by any cell below -- consider removing them.\n",
    "import os\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "# NOTE(review): CatMap, Num_Cat and MyClassLoss are used in later cells but never\n",
    "# imported by name; they presumably come from this star import -- prefer explicit imports.\n",
    "from STab import *\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "# Short alias for the model class instantiated in the training cell.\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:2: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:4: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  after removing the cwd from sys.path.\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:6: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:8: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:11: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  # This is added back by InteractiveShellApp.init_path()\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:13: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  del sys.path[0]\n"
     ]
    }
   ],
   "source": [
    "## Load data\n",
    "# Each split is stored as three .npy files under DATA_DIR: numeric features\n",
    "# (N_<split>), categorical features (C_<split>) and labels (y_<split>).\n",
    "DATA_DIR = 'Data/adult'\n",
    "\n",
    "\n",
    "def load_split(split):\n",
    "    \"\"\"Load one data split ('train', 'val' or 'test') from DATA_DIR.\n",
    "\n",
    "    Returns (numeric features as a float DataFrame, categorical features as a\n",
    "    DataFrame, labels as an int Series taken from column 0).\n",
    "    \"\"\"\n",
    "    def _load(prefix):\n",
    "        return pd.DataFrame(np.load(os.path.join(DATA_DIR, f'{prefix}_{split}.npy')))\n",
    "\n",
    "    # Builtin float/int replace the np.float/np.int aliases, which were deprecated\n",
    "    # in NumPy 1.20 and removed in 1.24; the resulting dtypes are identical.\n",
    "    return _load('N').astype(float), _load('C'), _load('y').astype(int)[0]\n",
    "\n",
    "\n",
    "X_train_N, X_train_C, Y_train = load_split('train')\n",
    "X_valid_N, X_valid_C, y_valid = load_split('val')\n",
    "X_test_N,  X_test_C,  y_test  = load_split('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the categorical columns with a CatMap built from the training split,\n",
    "# then apply that same mapping to the validation and test splits.\n",
    "# NOTE(review): X_*_C are rebound to their encoded versions, so re-running this\n",
    "# cell alone rebuilds CatMap from already-encoded data; re-run the loading cell first.\n",
    "catmap = CatMap(X_train_C)\n",
    "X_train_C, X_valid_C, X_test_C = [catmap(frame) for frame in (X_train_C, X_valid_C, X_test_C)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:8: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:9: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  if __name__ == '__main__':\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:10: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  # Remove the CWD from sys.path while we load stuff.\n"
     ]
    }
   ],
   "source": [
    "# Normalise the numeric features with mean/std estimated on the training split only,\n",
    "# so validation and test statistics do not leak into the scaler.\n",
    "scalerX = StandardScaler()\n",
    "scalerX.fit(X_train_N)\n",
    "\n",
    "# Builtin float replaces the np.float alias removed in NumPy 1.24 (still float64).\n",
    "X_train_N = scalerX.transform(X_train_N).astype(float)\n",
    "X_valid_N = scalerX.transform(X_valid_N).astype(float)\n",
    "X_test_N  = scalerX.transform(X_test_N).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine numerical and categorical features column-wise (numeric columns first).\n",
    "def combine_features(numeric, categorical):\n",
    "    \"\"\"Concatenate a numeric array and a categorical DataFrame into one float32 matrix.\n",
    "\n",
    "    Casting to float32 once here avoids keras4torch's repeated float64 -> float32\n",
    "    auto-conversion (and its warning) on every fit/predict call.\n",
    "    \"\"\"\n",
    "    return np.concatenate([numeric, categorical.values], axis=1).astype(np.float32)\n",
    "\n",
    "\n",
    "X_train = combine_features(X_train_N, X_train_C)\n",
    "X_test  = combine_features(X_test_N, X_test_C)\n",
    "X_valid = combine_features(X_valid_N, X_valid_C)\n",
    "assert X_train.shape[1] == X_valid.shape[1] == X_test.shape[1], 'splits have different feature counts'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path where ModelCheckpoint saves the best weights during training.\n",
    "checkpoint = 'saved/savefileAD'\n",
    "# Ensure the parent directory exists so saving the checkpoint cannot fail on a fresh checkout.\n",
    "os.makedirs(os.path.dirname(checkpoint) or '.', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_ACC(p, t):\n",
    "    \"\"\"Return the fraction of elements where p and t agree after np.round.\n",
    "\n",
    "    p and t are array-likes of matching (broadcastable) shape, e.g. predicted and\n",
    "    true class labels; the result is a NumPy float.\n",
    "    \"\"\"\n",
    "    agree = np.round(p) == np.round(t)\n",
    "    return np.mean(agree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LWTA\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "Train on 26048 samples, validate on 6513 samples:\n",
      "Epoch 1/1 - 28s - loss: 0.5826 - acc: 0.6968 - cross_entropy: 0.5881 - val_loss: 0.5598 - val_acc: 0.7629 - val_cross_entropy: 0.5652 - lr: 1e-07\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "Train on 26048 samples, validate on 6513 samples:\n",
      "Epoch 1/1 - 28s - loss: 0.3822 - acc: 0.8243 - cross_entropy: 0.3857 - val_loss: 0.3373 - val_acc: 0.8399 - val_cross_entropy: 0.3404 - lr: 1e-03\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "Test Score: 0.851667587985996\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "923"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build, train (warm-up + main phase) and evaluate the STab model.\n",
    "\n",
    "# Reproducibility: seed NumPy and PyTorch before the model is built.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Model / data configuration.\n",
    "CATEGORIES     = (9, 16, 7, 15, 6, 5, 2, 42)  # number of classes per categorical feature\n",
    "NUM_CONTINUOUS = 6\n",
    "N_CLASSES      = 2\n",
    "N_FEATURES     = NUM_CONTINUOUS + len(CATEGORIES)  # 14 input columns: numeric + categorical\n",
    "\n",
    "# Training / evaluation configuration.\n",
    "WARMUP_EPOCHS  = 1\n",
    "MAIN_EPOCHS    = 1\n",
    "BATCH_SIZE     = 256\n",
    "VAL_BATCH_SIZE = 128\n",
    "N_MC_SAMPLES   = 64  # prediction passes whose logits are summed at test time\n",
    "\n",
    "Or_model = MainModel(\n",
    "    categories     = CATEGORIES,\n",
    "    num_continuous = NUM_CONTINUOUS,\n",
    "    dim            = 16,\n",
    "    dim_out        = N_CLASSES,\n",
    "    depth          = 3,\n",
    "    heads          = 8,\n",
    "    attn_dropout   = 0.1,\n",
    "    ff_dropout     = 0.1,\n",
    "    U              = 2,\n",
    "    cases          = 16,\n",
    ")\n",
    "\n",
    "# Wrapper allows size-N prediction sample averaging.\n",
    "no_model = Num_Cat(Or_model, num_number=NUM_CONTINUOUS, classes=N_CLASSES, Sample_size=16)\n",
    "model    = keras4torch.Model(no_model).build([N_FEATURES])\n",
    "\n",
    "# Warm-up phase: LinearLR starts at start_factor * lr = 1e-07 and ramps towards lr.\n",
    "# NOTE(review): total_iters=10 exceeds WARMUP_EPOCHS, so if LRScheduler steps once per\n",
    "# epoch the ramp never completes -- confirm the intended epoch counts.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)\n",
    "sch = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, total_iters=10, verbose=False)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01, 1), metrics=['accuracy', F.cross_entropy])\n",
    "callbacks = [ModelCheckpoint(checkpoint, monitor='val_acc', mode='max'), LRScheduler(sch)]\n",
    "model.fit(X_train, Y_train.values,\n",
    "          epochs=WARMUP_EPOCHS, batch_size=BATCH_SIZE,\n",
    "          validation_data=(X_valid, y_valid.values),\n",
    "          verbose=2, validation_batch_size=VAL_BATCH_SIZE,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# Main phase: fresh AdamW; ReduceLROnPlateau halves the LR (floor 1e-5) when the\n",
    "# monitored metric plateaus (patience=3). The checkpoint now tracks the lowest val_loss.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01, 1), metrics=['accuracy', F.cross_entropy])\n",
    "scheduler = LRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimizer, mode='min', patience=3, factor=0.5, min_lr=0.00001))\n",
    "callbacks = [ModelCheckpoint(checkpoint, monitor='val_loss', mode='min'), scheduler]\n",
    "model.fit(X_train, Y_train.values,\n",
    "          epochs=MAIN_EPOCHS, batch_size=BATCH_SIZE,\n",
    "          validation_data=(X_valid, y_valid.values),\n",
    "          verbose=2, validation_batch_size=VAL_BATCH_SIZE,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# Evaluate: reload the weights saved at `checkpoint`, set the wrapper's sample size\n",
    "# to 1, sum the logits over N_MC_SAMPLES prediction passes and take the arg-max class.\n",
    "model.load_weights(checkpoint)\n",
    "no_model.reset_Sample_size(1)\n",
    "logits = sum(pd.DataFrame(model.predict(X_test, batch_size=4096)) for _ in range(N_MC_SAMPLES))\n",
    "\n",
    "Test = True_ACC(logits.idxmax(axis=1).values.reshape(-1), y_test.values.reshape(-1))\n",
    "print('Test Score:', Test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
