{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "from STab import MyClassLoss,CatMap,Num_Cat\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-split HELENA data: N_<split>.npy holds the input features and\n",
    "# y_<split>.npy the integer targets (column 0 of the loaded frame).\n",
    "# NOTE: np.float / np.int were removed in NumPy 1.24; the builtins float / int\n",
    "# are exactly what those aliases referred to, so the dtypes are unchanged.\n",
    "DATA_DIR = 'Data/helena'\n",
    "\n",
    "def load_split(split):\n",
    "    # Returns (features as a float DataFrame, targets as an int Series) for one split.\n",
    "    X = pd.DataFrame(np.load(os.path.join(DATA_DIR, f'N_{split}.npy'))).astype(float)\n",
    "    y = pd.DataFrame(np.load(os.path.join(DATA_DIR, f'y_{split}.npy'))).astype(int)[0]\n",
    "    # Catch mismatched / corrupted files early instead of deep inside model.fit.\n",
    "    assert len(X) == len(y), f'{split}: {len(X)} feature rows vs {len(y)} targets'\n",
    "    return X, y\n",
    "\n",
    "X_test,  y_test  = load_split('test')\n",
    "X_train, Y_train = load_split('train')\n",
    "X_valid, y_valid = load_split('val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise inputs: map each feature through its empirical quantiles onto a\n",
    "# standard normal. The transformer is fit on the training split only, so no\n",
    "# validation/test statistics leak into preprocessing.\n",
    "# n_quantiles: roughly one per 30 training rows, clamped to [10, 1000].\n",
    "# subsample must be an int in recent scikit-learn (a float such as 1e9 fails\n",
    "# parameter validation); 10**9 means subsampling only kicks in above 1e9 rows.\n",
    "# NOTE(review): this cell overwrites X_train / X_valid / X_test, so re-running it\n",
    "# without first re-running the load cell would fit on already-transformed data.\n",
    "scalerX = QuantileTransformer(output_distribution='normal',\n",
    "                              n_quantiles=max(min(X_train.shape[0] // 30, 1000), 10),\n",
    "                              subsample=10**9)\n",
    "scalerX.fit(X_train)\n",
    "\n",
    "# np.float was removed in NumPy 1.24; builtin float is what it aliased.\n",
    "X_train = scalerX.transform(X_train).astype(float)\n",
    "X_valid = scalerX.transform(X_valid).astype(float)\n",
    "X_test  = scalerX.transform(X_test).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file written by ModelCheckpoint during the main training run and\n",
    "# reloaded before testing. Create its parent directory up front, because saving\n",
    "# the checkpoint does not create missing directories on a fresh checkout.\n",
    "chpfilename = 'saved/savefileHE'\n",
    "os.makedirs(os.path.dirname(chpfilename), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_ACC(p, t):\n",
    "    # Accuracy as the share of positions where p and t agree once both are\n",
    "    # rounded to the nearest integer (np.round sends exact halves to even).\n",
    "    # p and t may be any broadcast-compatible array-likes.\n",
    "    agree = np.round(p) == np.round(t)\n",
    "    return np.mean(agree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the STab MainModel behind the Num_Cat wrapper, train it in two phases\n",
    "# (short warm-up, then main training with checkpointing), then score the test\n",
    "# split by summing predictions from repeated forward passes.\n",
    "SEED          = 0    # fixed seed so init / dropout are repeatable (modulo GPU nondeterminism)\n",
    "N_FEATURES    = 27   # number of input columns (num_continuous / build shape)\n",
    "N_CLASSES     = 101  # number of output classes (dim_out / classes)\n",
    "TRAIN_SAMPLES = 64   # value passed to no_model.reset_Sample_size for main training\n",
    "TEST_PASSES   = 64   # forward passes summed at test time\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "Or_model = MainModel(\n",
    "    categories        = (),\n",
    "    num_continuous    = N_FEATURES,\n",
    "    dim               = 96,\n",
    "    dim_out           = N_CLASSES,\n",
    "    depth             = 7,\n",
    "    heads             = 8,\n",
    "    attn_dropout      = 0.25,\n",
    "    ff_dropout        = 0.25,\n",
    "    U                 = 2,\n",
    "    cases             = 16,\n",
    ")\n",
    "no_model = Num_Cat(Or_model, num_number=N_FEATURES, classes=N_CLASSES)\n",
    "model    = keras4torch.Model(no_model).build([N_FEATURES])\n",
    "\n",
    "# Warm-up: 5 epochs with sample size 1. LinearLR ramps the LR from 1e-5 x 1e-4 up\n",
    "# to 1e-4 over 5 scheduler steps (NOTE(review): whether keras4torch's LRScheduler\n",
    "# steps per epoch or per batch is not visible here).\n",
    "# `verbose` is deprecated in recent PyTorch; its default already matches False.\n",
    "no_model.reset_Sample_size(1)\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)\n",
    "sch = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.00001, total_iters=5)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.1, 1), metrics=['accuracy'])\n",
    "model.fit(X_train, Y_train.values,\n",
    "          epochs=5, batch_size=512,\n",
    "          validation_data=(X_valid, y_valid.values),\n",
    "          verbose=2, validation_batch_size=1024,\n",
    "          callbacks=[LRScheduler(sch)])\n",
    "\n",
    "# Main training: halve the LR (floor 1e-5) after 5 scheduler steps without a\n",
    "# decrease in the monitored quantity, and keep the checkpoint with the best\n",
    "# validation accuracy.\n",
    "no_model.reset_Sample_size(TRAIN_SAMPLES)\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.000001)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01, 1), metrics=['accuracy', F.cross_entropy])\n",
    "scheduler = LRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimizer, mode='min', patience=5, factor=0.5, min_lr=0.00001))\n",
    "callbacks = [scheduler, ModelCheckpoint(chpfilename, monitor='val_acc', mode='max')]\n",
    "model.fit(X_train, Y_train.values,\n",
    "          epochs=100, batch_size=512,\n",
    "          validation_data=(X_valid, y_valid.values),\n",
    "          verbose=2, validation_batch_size=256,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# Test: restore the best checkpoint, sum TEST_PASSES predictions (presumably\n",
    "# stochastic -- otherwise every pass is identical) and take the arg-max column.\n",
    "model.load_weights(chpfilename)\n",
    "no_model.reset_Sample_size(1)\n",
    "logits = 0\n",
    "for _ in range(TEST_PASSES):\n",
    "    logits += pd.DataFrame(model.predict(X_test, batch_size=4096))\n",
    "\n",
    "Test = True_ACC(logits.idxmax(axis=1).values.reshape(-1), y_test.values.reshape(-1))\n",
    "print('Test Score:', Test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
