{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "from STab import *\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Each split (`train`, `val`, `test`) under `Data/adult` is stored as three `.npy` files: numerical features (`N_<split>.npy`), categorical features (`C_<split>.npy`) and labels (`y_<split>.npy`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `float` / `int` replace `np.float` / `np.int`: those were plain aliases of the builtins and\n",
    "# were removed in NumPy 1.24, where the original calls raise AttributeError.\n",
    "DATA_DIR = 'Data/adult'\n",
    "\n",
    "def load_split(split):\n",
    "    \"\"\"Load one data split ('train', 'val' or 'test') from DATA_DIR.\n",
    "\n",
    "    Returns a tuple (numerical features as a float DataFrame, categorical features as a\n",
    "    DataFrame, labels as an int Series taken from column 0 of the label file).\n",
    "    \"\"\"\n",
    "    def path(prefix):\n",
    "        return os.path.join(DATA_DIR, prefix + '_' + split + '.npy')\n",
    "\n",
    "    numerical   = pd.DataFrame(np.load(path('N'))).astype(float)\n",
    "    categorical = pd.DataFrame(np.load(path('C')))\n",
    "    labels      = pd.DataFrame(np.load(path('y'))).astype(int)[0]\n",
    "    return numerical, categorical, labels\n",
    "\n",
    "X_train_N, X_train_C, Y_train = load_split('train')\n",
    "X_valid_N, X_valid_C, y_valid = load_split('val')\n",
    "X_test_N,  X_test_C,  y_test  = load_split('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the categorical mapping from the training split only (CatMap receives X_train_C),\n",
    "# then apply that same mapping to the train, validation and test splits.\n",
    "# NOTE(review): CatMap comes from STab; presumably it encodes category values as integer\n",
    "# codes for the embedding layers - confirm. The cell overwrites its inputs, so re-running it\n",
    "# rebuilds and re-applies the mapping on already-mapped data: re-run the load cell first.\n",
    "catmap = CatMap(X_train_C)\n",
    "X_train_C, X_valid_C, X_test_C = [catmap(frame) for frame in (X_train_C, X_valid_C, X_test_C)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardise the numerical features. The scaler is fitted on the training split only, so no\n",
    "# validation/test statistics leak into preprocessing, and is then applied to every split.\n",
    "# After this cell the X_*_N variables are NumPy arrays rather than DataFrames.\n",
    "# `float` replaces `np.float`, an alias of the builtin that was removed in NumPy 1.24.\n",
    "scalerX = StandardScaler()\n",
    "scalerX.fit(X_train_N)\n",
    "\n",
    "X_train_N = scalerX.transform(X_train_N).astype(float)\n",
    "X_valid_N = scalerX.transform(X_valid_N).astype(float)\n",
    "X_test_N  = scalerX.transform(X_test_N).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join numerical and categorical features column-wise into one matrix per split:\n",
    "# numerical columns first, then the mapped categorical columns.\n",
    "# NOTE(review): the Num_Cat wrapper is later built with num_number=6, which presumably relies\n",
    "# on this column order - keep the two consistent.\n",
    "X_train = np.hstack((X_train_N, X_train_C.values))\n",
    "X_valid = np.hstack((X_valid_N, X_valid_C.values))\n",
    "X_test  = np.hstack((X_test_N,  X_test_C.values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File the ModelCheckpoint callbacks save weights to; it is reloaded before testing.\n",
    "# Create its parent directory up front in case the checkpoint writer does not.\n",
    "checkpoint = 'saved/savefileAD'\n",
    "os.makedirs(os.path.dirname(checkpoint), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_ACC(p, t):\n",
    "    \"\"\"Accuracy: fraction of positions where `p` and `t` agree after rounding.\n",
    "\n",
    "    `p` and `t` are broadcast-compatible array-likes (here: predicted class indices and\n",
    "    true labels). Both are rounded to the nearest integer (NumPy rounds halves to even)\n",
    "    before comparison; the result is the mean of the boolean match mask.\n",
    "    \"\"\"\n",
    "    is_match = np.round(p) == np.round(t)\n",
    "    return np.mean(is_match)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and training\n",
    "\n",
    "Training runs in two stages that write to the same checkpoint file: a short warm-up with a linearly ramped learning rate, then the main stage with plateau-based learning-rate decay. The saved checkpoint is reloaded and evaluated on the test split at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "SEED           = 0\n",
    "CATEGORIES     = (9, 16, 7, 15, 6, 5, 2, 42)  # number of classes per categorical feature\n",
    "N_NUMERICAL    = 6                            # numerical columns, placed first in X_*\n",
    "N_CLASSES      = 2\n",
    "N_FEATURES     = N_NUMERICAL + len(CATEGORIES)\n",
    "BATCH_SIZE     = 256\n",
    "VAL_BATCH_SIZE = 128\n",
    "WARMUP_EPOCHS  = 1\n",
    "MAIN_EPOCHS    = 1\n",
    "N_TEST_PASSES  = 64                           # single-sample test predictions summed per row\n",
    "\n",
    "# Fail fast if the preprocessed data does not match the hard-coded model configuration.\n",
    "assert X_train.shape[1] == N_FEATURES, 'feature count does not match the model configuration'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch and NumPy right before building the model so that re-running from here is\n",
    "# reproducible (assuming MainModel/Num_Cat draw from these RNGs; some GPU kernels may\n",
    "# still be nondeterministic).\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "Or_model = MainModel(\n",
    "    categories     = CATEGORIES,   # setting up all classes per categorical feature\n",
    "    num_continuous = N_NUMERICAL,\n",
    "    dim            = 16,\n",
    "    dim_out        = N_CLASSES,\n",
    "    depth          = 3,\n",
    "    heads          = 8,\n",
    "    attn_dropout   = 0.1,\n",
    "    ff_dropout     = 0.1,\n",
    "    U              = 2,\n",
    "    cases          = 16,\n",
    ")\n",
    "\n",
    "# Num_Cat wrapper: allows averaging predictions over Sample_size samples.\n",
    "no_model = Num_Cat(Or_model, num_number=N_NUMERICAL, classes=N_CLASSES, Sample_size=16)\n",
    "model    = keras4torch.Model(no_model).build([N_FEATURES])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compile_model(optimizer):\n",
    "    \"\"\"(Re)compile the global keras4torch `model` with `optimizer` and the shared loss/metrics.\"\"\"\n",
    "    model.compile(optimizer=optimizer, loss=MyClassLoss(0.01, 1),\n",
    "                  metrics=['accuracy', F.cross_entropy])\n",
    "\n",
    "\n",
    "def fit_stage(epochs, callbacks):\n",
    "    \"\"\"Train the global `model` on the training split, validating on the validation split.\"\"\"\n",
    "    model.fit(X_train, Y_train.values,\n",
    "              epochs=epochs, batch_size=BATCH_SIZE,\n",
    "              validation_data=(X_valid, y_valid.values),\n",
    "              verbose=2, validation_batch_size=VAL_BATCH_SIZE,\n",
    "              callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warm-up: AdamW at lr=1e-4, ramped linearly from 0.001x of that over 10 scheduler steps,\n",
    "# checkpointing on maximum val_acc.\n",
    "# `verbose` is no longer passed to LinearLR: False was its default and recent PyTorch\n",
    "# deprecates the argument.\n",
    "# NOTE(review): with WARMUP_EPOCHS = 1 the 10-step ramp only completes if the LRScheduler\n",
    "# callback steps per batch rather than per epoch - confirm against keras4torch.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)\n",
    "sch = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, total_iters=10)\n",
    "compile_model(optimizer)\n",
    "fit_stage(WARMUP_EPOCHS, [ModelCheckpoint(checkpoint, monitor='val_acc', mode='max'),\n",
    "                          LRScheduler(sch)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Main training: fresh AdamW at lr=1e-3 with ReduceLROnPlateau (factor 0.5, patience 3,\n",
    "# floor 1e-5); checkpointing now tracks minimum val_loss and writes to the same file.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\n",
    "compile_model(optimizer)\n",
    "scheduler = LRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimizer, mode='min', patience=3, factor=0.5, min_lr=0.00001))\n",
    "fit_stage(MAIN_EPOCHS, [ModelCheckpoint(checkpoint, monitor='val_loss', mode='min'), scheduler])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "Reload the saved checkpoint, sum `N_TEST_PASSES` single-sample predictions for every test row and score the arg-max class against the labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_weights(checkpoint)\n",
    "no_model.reset_Sample_size(1)  # presumably one sample per predict call; passes are summed below\n",
    "\n",
    "# NOTE(review): predict output is assumed to be per-class scores (N_CLASSES columns) - confirm.\n",
    "logits = sum(pd.DataFrame(model.predict(X_test, batch_size=4096))\n",
    "             for _ in range(N_TEST_PASSES))\n",
    "\n",
    "Test = True_ACC(logits.idxmax(axis=1).values.reshape(-1), y_test.values.reshape(-1))\n",
    "print('Test Score:', Test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
