{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "from STab import MyClassLoss,CatMap,Num_Cat\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-split data (author's note: set the splits up as in the references).\n",
    "DATA_DIR = 'Data/jannis'\n",
    "\n",
    "\n",
    "def load_split(split):\n",
    "    \"\"\"Read one split from DATA_DIR and return (features, labels) DataFrames.\n",
    "\n",
    "    split is the file-name suffix used on disk: 'train', 'val' or 'test'.\n",
    "    Features are read from N_{split}.npy, cast to float, and NaNs are replaced by 0.\n",
    "    Labels are read from y_{split}.npy and cast to int.\n",
    "    \"\"\"\n",
    "    # Builtin float/int stand in for the np.float/np.int aliases, which were\n",
    "    # deprecated in NumPy 1.20 and removed in NumPy 1.24.\n",
    "    features = pd.DataFrame(np.load(f'{DATA_DIR}/N_{split}.npy')).astype(float).fillna(0)\n",
    "    labels   = pd.DataFrame(np.load(f'{DATA_DIR}/y_{split}.npy')).astype(int)\n",
    "    return features, labels\n",
    "\n",
    "\n",
    "X_test,  y_test  = load_split('test')\n",
    "X_train, Y_train = load_split('train')\n",
    "X_valid, y_valid = load_split('val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise inputs: divide every column by its maximum over the training split.\n",
    "# NOTE: the X_* names are rebound from DataFrames to plain arrays below, so this\n",
    "# cell cannot be re-run on its own; re-run the loading cell first.\n",
    "\n",
    "# A column whose training maximum is 0 would give NaN/inf after the division;\n",
    "# use a scale of 1 for such columns so they pass through unchanged.\n",
    "scale_factor = X_train.max().replace(0, 1)\n",
    "\n",
    "X_test  = (X_test  / scale_factor).values\n",
    "X_valid = (X_valid / scale_factor).values\n",
    "X_train = (X_train / scale_factor).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path handed to ModelCheckpoint during training and to model.load_weights afterwards.\n",
    "checkpoint='saved/savefileJA'\n",
    "# Create the target directory up front so writing the checkpoint cannot fail\n",
    "# on a fresh checkout where 'saved/' does not exist yet.\n",
    "os.makedirs(os.path.dirname(checkpoint), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_ACC(p,t):\n",
    "    \"\"\"Return the fraction of positions where round(p) equals round(t).\n",
    "\n",
    "    p and t are array-likes of predicted and true labels. When both hold the\n",
    "    same number of elements they are flattened first, so a (N,) vector and a\n",
    "    (N, 1) column are compared position by position instead of being\n",
    "    broadcast into an (N, N) matrix, which would give a wrong score.\n",
    "    Inputs with different element counts fall back to numpy broadcasting.\n",
    "    \"\"\"\n",
    "    pred = np.round(np.asarray(p))\n",
    "    true = np.round(np.asarray(t))\n",
    "    if pred.size == true.size:\n",
    "        pred, true = pred.ravel(), true.ravel()\n",
    "    return np.mean(pred == true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Values that are used in several places below; kept together so they cannot drift apart.\n",
    "NUM_FEATURES  = 54    # continuous input columns the model is built for\n",
    "NUM_CLASSES   = 4     # output classes\n",
    "CASES         = 64    # `cases` argument of MainModel\n",
    "N_TEST_PASSES = 64    # number of predict calls summed at test time\n",
    "BATCH_SIZE    = 512   # training batch size for both phases\n",
    "\n",
    "# Fail fast if the loaded data does not have the width the model is built for.\n",
    "assert X_train.shape[1] == NUM_FEATURES, 'X_train does not have NUM_FEATURES columns'\n",
    "\n",
    "Or_model = MainModel(\n",
    "    categories        = (),\n",
    "    num_continuous    = NUM_FEATURES,\n",
    "    dim               = 192,\n",
    "    dim_out           = NUM_CLASSES,\n",
    "    depth             = 4,\n",
    "    heads             = 8,\n",
    "    attn_dropout      = 0.25,\n",
    "    ff_dropout        = 0.25,\n",
    "    U                 = 2,\n",
    "    cases             = CASES,\n",
    ")\n",
    "no_model = Num_Cat(Or_model, num_number=NUM_FEATURES, classes=NUM_CLASSES)\n",
    "model    = keras4torch.Model(no_model).build([NUM_FEATURES])\n",
    "\n",
    "train_labels    = Y_train[0].values\n",
    "validation_data = (X_valid, y_valid[0].values)\n",
    "\n",
    "# Warm-up: 10 epochs at a small base LR, with LinearLR scaling it up from 0.001x.\n",
    "# `verbose` is no longer passed: it is deprecated in recent PyTorch and False was its default.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)\n",
    "sch       = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, total_iters=10)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01,1), metrics=['accuracy', F.cross_entropy])\n",
    "callbacks = [LRScheduler(sch)]\n",
    "model.fit(X_train, train_labels,\n",
    "          epochs=10, batch_size=BATCH_SIZE,\n",
    "          validation_data=validation_data,\n",
    "          verbose=2, validation_batch_size=128,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# Main training: fresh optimizer at a 10x higher LR, halved on plateau down to 1e-5.\n",
    "# ModelCheckpoint is configured to track the maximum of 'val_acc'.\n",
    "# NOTE(review): the plateau scheduler uses mode='min'; which metric LRScheduler feeds\n",
    "# it is decided inside keras4torch -- confirm it is a loss and not an accuracy.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01,1), metrics=['accuracy', F.cross_entropy])\n",
    "scheduler = LRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimizer, mode='min', patience=5, factor=0.5, min_lr=0.00001))\n",
    "callbacks = [ModelCheckpoint(checkpoint, monitor='val_acc', mode='max'), scheduler]\n",
    "model.fit(X_train, train_labels,\n",
    "          epochs=180, batch_size=BATCH_SIZE,\n",
    "          validation_data=validation_data,\n",
    "          verbose=2, validation_batch_size=256,\n",
    "          callbacks=callbacks)\n",
    "\n",
    "# Evaluate the checkpointed weights on the test split.\n",
    "model.load_weights(checkpoint)\n",
    "no_model.reset_Sample_size(1)\n",
    "\n",
    "# Sum the outputs of N_TEST_PASSES predict calls before taking the arg-max.\n",
    "# Presumably the model is stochastic at inference so the passes differ -- verify in STab.\n",
    "logits = 0\n",
    "for _pass in range(N_TEST_PASSES):\n",
    "    logits += pd.DataFrame(model.predict(X_test, batch_size=2048))\n",
    "\n",
    "# Column label of the largest summed output per row, shaped (N, 1) to line up with y_test.\n",
    "predicted = logits.idxmax(axis=1).values.reshape((-1,1))\n",
    "Test = True_ACC(predicted, y_test)\n",
    "\n",
    "print('Test: ',Test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
