{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "from STab import MyClassLoss,CatMap,Num_Cat\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load data\n",
    "# The np.float / np.int aliases were removed in NumPy 1.24; the builtin float / int\n",
    "# are exactly what those aliases referred to, so the resulting dtypes are unchanged.\n",
    "DATA_DIR = 'Data/higgs_small'\n",
    "\n",
    "\n",
    "def _load_split(name, dtype):\n",
    "    \"\"\"Read `<DATA_DIR>/<name>.npy` and return it as a DataFrame cast to `dtype`.\"\"\"\n",
    "    return pd.DataFrame(np.load(f'{DATA_DIR}/{name}.npy')).astype(dtype)\n",
    "\n",
    "\n",
    "# Features: load the training split once and impute every split with the\n",
    "# training-column means (DataFrame.mean skips NaNs), so no statistics are\n",
    "# taken from the validation or test data.\n",
    "X_train = _load_split('N_train', float)\n",
    "train_mean = X_train.mean()\n",
    "\n",
    "X_train = X_train.fillna(train_mean)\n",
    "X_valid = _load_split('N_val', float).fillna(train_mean)\n",
    "X_test = _load_split('N_test', float).fillna(train_mean)\n",
    "\n",
    "# Labels: column 0 of each label file, as an integer Series\n",
    "Y_train = _load_split('y_train', int)[0]\n",
    "y_valid = _load_split('y_val', int)[0]\n",
    "y_test = _load_split('y_test', int)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise inputs: the scaler is fitted on the training split only and the\n",
    "# same fitted statistics are then applied to the validation and test splits.\n",
    "scalerX = StandardScaler()\n",
    "scalerX.fit(X_train)\n",
    "\n",
    "# builtin float replaces the np.float alias removed in NumPy 1.24 (same dtype)\n",
    "X_train = scalerX.transform(X_train).astype(float)\n",
    "X_valid = scalerX.transform(X_valid).astype(float)\n",
    "X_test = scalerX.transform(X_test).astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File path passed to ModelCheckpoint and model.load_weights in the training cell\n",
    "checkpoint = 'saved/savefileHI'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_ACC(p, t):\n",
    "    \"\"\"Return the fraction of positions where `p` and `t` agree after np.round.\n",
    "\n",
    "    Both arguments are rounded with np.round (which rounds halves to the nearest\n",
    "    even value) and compared element-wise; the result is the mean of the matches.\n",
    "    \"\"\"\n",
    "    matches = np.round(p) == np.round(t)\n",
    "    # The mean of a boolean mask is the share of True entries\n",
    "    return np.mean(matches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# --- Model ---\n",
    "Or_model = MainModel(\n",
    "    categories        = (),               # tuple containing the number of unique values within each category\n",
    "    num_continuous    = 28,\n",
    "    dim               = 176,\n",
    "    dim_out           = 2,\n",
    "    depth             = 4,\n",
    "    heads             = 8,\n",
    "    attn_dropout      = 0.225,\n",
    "    ff_dropout        = 0.225,\n",
    "    U                 = 2,\n",
    "    cases             = 4\n",
    ")\n",
    "no_model = Num_Cat(Or_model, num_number=28, classes=2)\n",
    "model    = keras4torch.Model(no_model).build([28])\n",
    "\n",
    "# --- Warm-up: 5 epochs; LinearLR scales the lr from 0.001x up to 1x over 5 scheduler steps ---\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\n",
    "# `verbose` is not passed: False was its default and the argument is deprecated in recent PyTorch\n",
    "sch = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, total_iters=5)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.1, 1), metrics=['accuracy', F.cross_entropy])\n",
    "callbacks = [ModelCheckpoint(checkpoint, monitor='val_acc', mode='max'), LRScheduler(sch)]\n",
    "\n",
    "# Validate on the validation split here as well (this phase used the test split before),\n",
    "# so the val_acc-driven checkpoint is never selected using test labels.\n",
    "model.fit(X_train, Y_train.values,\n",
    "          epochs=5, batch_size=512,\n",
    "          validation_data=(X_valid, y_valid.values),\n",
    "          verbose=2, validation_batch_size=256,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# --- Main training: fresh optimizer, lr halved on plateau down to 1e-6 ---\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01, 1), metrics=['accuracy', F.cross_entropy])\n",
    "# NOTE(review): mode='min' presumes the LRScheduler callback feeds a loss-like metric - confirm against keras4torch\n",
    "scheduler = LRScheduler(\n",
    "    torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5, min_lr=0.000001)\n",
    ")\n",
    "callbacks = [ModelCheckpoint(checkpoint, monitor='val_acc', mode='max'), scheduler]\n",
    "\n",
    "model.fit(X_train, Y_train.values,\n",
    "          epochs=120, batch_size=512,\n",
    "          validation_data=(X_valid, y_valid.values),\n",
    "          verbose=2, validation_batch_size=256,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# --- Test evaluation with the weights stored at `checkpoint` ---\n",
    "model.load_weights(checkpoint)\n",
    "no_model.reset_Sample_size(1)\n",
    "\n",
    "# Predictions from repeated passes over the test set are summed before taking the\n",
    "# arg-max class. True_ACC comes from its own cell earlier in the notebook.\n",
    "N_EVAL_PASSES = 64\n",
    "logits = 0\n",
    "for _ in range(N_EVAL_PASSES):\n",
    "    logits += pd.DataFrame(model.predict(X_test, batch_size=4096))\n",
    "\n",
    "# idxmax over columns returns the column label, i.e. the predicted class index\n",
    "Test = True_ACC(logits.idxmax(axis=1).values.reshape((-1)), y_test.values.reshape((-1)))\n",
    "\n",
    "print('Test: ', Test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
