{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "from STab import MyClassLoss,CatMap,Num_Cat\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:2: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:3: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:5: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \"\"\"\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:6: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  \n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:9: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  if __name__ == '__main__':\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:10: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  # Remove the CWD from sys.path while we load stuff.\n"
     ]
    }
   ],
   "source": [
    "## Load data\n",
    "# N_*.npy files are loaded as the feature matrices X_*, y_*.npy as the label Series.\n",
    "# np.float / np.int were deprecated in NumPy 1.20 and removed in 1.24, so the explicit\n",
    "# float64 / int64 dtypes are used (what the old aliases resolved to on 64-bit Linux).\n",
    "DATA_DIR = 'Data/helena'\n",
    "\n",
    "def _load_features(split):\n",
    "    # Feature matrix of one split as a float64 DataFrame.\n",
    "    return pd.DataFrame(np.load(os.path.join(DATA_DIR, f'N_{split}.npy'))).astype(np.float64)\n",
    "\n",
    "def _load_labels(split):\n",
    "    # Labels of one split as an int64 Series (column 0 of the loaded array).\n",
    "    return pd.DataFrame(np.load(os.path.join(DATA_DIR, f'y_{split}.npy'))).astype(np.int64)[0]\n",
    "\n",
    "X_test,  y_test  = _load_features('test'),  _load_labels('test')\n",
    "X_train, Y_train = _load_features('train'), _load_labels('train')\n",
    "X_valid, y_valid = _load_features('val'),   _load_labels('val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:9: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  if __name__ == '__main__':\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:10: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  # Remove the CWD from sys.path while we load stuff.\n",
      "/home/andreas/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:11: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  # This is added back by InteractiveShellApp.init_path()\n"
     ]
    }
   ],
   "source": [
    "# Normalise inputs\n",
    "# Quantile-map every feature to a normal distribution, fitted on the training split\n",
    "# only and then applied unchanged to the validation and test splits.\n",
    "# n_quantiles: one per 30 training rows, clamped to [10, 1000].\n",
    "# subsample must be an int: scikit-learn >= 1.2 validates it and rejects the float 1e9.\n",
    "scalerX = QuantileTransformer(output_distribution='normal',\n",
    "                              n_quantiles=max(min(X_train.shape[0]//30,1000),10),\n",
    "                              subsample=10**9,)\n",
    "scalerX.fit(X_train)\n",
    "\n",
    "# np.float was removed in NumPy 1.24; np.float64 is the same dtype.\n",
    "# NOTE: this overwrites X_train/X_valid/X_test, so re-run the load cell before\n",
    "# re-running this one (otherwise the scaler is re-fitted on already transformed data).\n",
    "X_train = scalerX.transform(X_train).astype(np.float64)\n",
    "X_valid = scalerX.transform(X_valid).astype(np.float64)\n",
    "X_test  = scalerX.transform(X_test).astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file for the best weights written by ModelCheckpoint during training.\n",
    "# Create its directory up front so saving does not fail because 'saved/' is missing.\n",
    "chpfilename='saved/savefileHE'\n",
    "os.makedirs(os.path.dirname(chpfilename), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_ACC(p, t):\n",
    "    \"\"\"Fraction of positions where p and t agree once both are rounded.\n",
    "\n",
    "    p, t: array-likes compared element-wise (here: predicted vs. true class ids).\n",
    "    Returns the mean agreement as a numpy float.\n",
    "    \"\"\"\n",
    "    matches = np.round(p) == np.round(t)\n",
    "    return np.mean(np.asarray(matches, dtype=int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "Train on 41724 samples, validate on 10432 samples:\n",
      "Epoch 1/5 - 36s - loss: 4.3877 - acc: 0.0093 - val_loss: 4.3868 - val_acc: 0.0074 - lr: 1e-09\n",
      "Epoch 2/5 - 36s - loss: 4.1064 - acc: 0.0498 - val_loss: 3.9564 - val_acc: 0.0782 - lr: 2e-05\n",
      "Epoch 3/5 - 36s - loss: 3.7745 - acc: 0.1101 - val_loss: 3.7086 - val_acc: 0.1257 - lr: 4e-05\n",
      "Epoch 4/5 - 37s - loss: 3.5702 - acc: 0.1516 - val_loss: 3.5300 - val_acc: 0.1609 - lr: 6e-05\n",
      "Epoch 5/5 - 36s - loss: 3.3946 - acc: 0.1849 - val_loss: 3.3637 - val_acc: 0.1942 - lr: 8e-05\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "[Warning] Auto convert float64 to float32, this could lead to extra memory usage.\n",
      "Train on 41724 samples, validate on 10432 samples:\n",
      "Epoch 1/20 - 98s - loss: 3.2329 - acc: 0.2540 - cross_entropy: 3.2611 - val_loss: 2.9855 - val_acc: 0.2941 - val_cross_entropy: 3.0112 - lr: 1e-03\n",
      "Epoch 2/20 - 99s - loss: 2.9464 - acc: 0.2967 - cross_entropy: 2.9717 - val_loss: 2.8278 - val_acc: 0.3161 - val_cross_entropy: 2.8520 - lr: 1e-03\n",
      "Epoch 3/20 - 99s - loss: 2.8398 - acc: 0.3165 - cross_entropy: 2.8641 - val_loss: 2.7606 - val_acc: 0.3250 - val_cross_entropy: 2.7840 - lr: 1e-03\n",
      "Epoch 4/20 - 99s - loss: 2.7811 - acc: 0.3255 - cross_entropy: 2.8048 - val_loss: 2.7178 - val_acc: 0.3382 - val_cross_entropy: 2.7408 - lr: 1e-03\n",
      "Epoch 5/20 - 99s - loss: 2.7418 - acc: 0.3321 - cross_entropy: 2.7651 - val_loss: 2.7037 - val_acc: 0.3351 - val_cross_entropy: 2.7265 - lr: 1e-03\n",
      "Epoch 6/20 - 98s - loss: 2.7068 - acc: 0.3388 - cross_entropy: 2.7297 - val_loss: 2.6739 - val_acc: 0.3454 - val_cross_entropy: 2.6965 - lr: 1e-03\n",
      "Epoch 7/20 - 98s - loss: 2.6872 - acc: 0.3446 - cross_entropy: 2.7098 - val_loss: 2.6617 - val_acc: 0.3501 - val_cross_entropy: 2.6841 - lr: 1e-03\n",
      "Epoch 8/20 - 99s - loss: 2.6642 - acc: 0.3483 - cross_entropy: 2.6866 - val_loss: 2.6566 - val_acc: 0.3491 - val_cross_entropy: 2.6789 - lr: 1e-03\n",
      "Epoch 9/20 - 99s - loss: 2.6381 - acc: 0.3528 - cross_entropy: 2.6602 - val_loss: 2.6447 - val_acc: 0.3515 - val_cross_entropy: 2.6668 - lr: 1e-03\n",
      "Epoch 10/20 - 99s - loss: 2.6131 - acc: 0.3553 - cross_entropy: 2.6350 - val_loss: 2.6359 - val_acc: 0.3519 - val_cross_entropy: 2.6580 - lr: 1e-03\n",
      "Epoch 11/20 - 99s - loss: 2.6035 - acc: 0.3597 - cross_entropy: 2.6252 - val_loss: 2.6119 - val_acc: 0.3579 - val_cross_entropy: 2.6337 - lr: 1e-03\n",
      "Epoch 12/20 - 99s - loss: 2.5923 - acc: 0.3604 - cross_entropy: 2.6139 - val_loss: 2.6071 - val_acc: 0.3565 - val_cross_entropy: 2.6289 - lr: 1e-03\n",
      "Epoch 13/20 - 99s - loss: 2.5710 - acc: 0.3651 - cross_entropy: 2.5924 - val_loss: 2.5956 - val_acc: 0.3634 - val_cross_entropy: 2.6172 - lr: 1e-03\n",
      "Epoch 14/20 - 99s - loss: 2.5607 - acc: 0.3705 - cross_entropy: 2.5820 - val_loss: 2.5928 - val_acc: 0.3595 - val_cross_entropy: 2.6144 - lr: 1e-03\n",
      "Epoch 15/20 - 98s - loss: 2.5474 - acc: 0.3709 - cross_entropy: 2.5685 - val_loss: 2.5901 - val_acc: 0.3625 - val_cross_entropy: 2.6116 - lr: 1e-03\n",
      "Epoch 16/20 - 98s - loss: 2.5313 - acc: 0.3718 - cross_entropy: 2.5523 - val_loss: 2.5775 - val_acc: 0.3663 - val_cross_entropy: 2.5989 - lr: 1e-03\n",
      "Epoch 17/20 - 99s - loss: 2.5216 - acc: 0.3735 - cross_entropy: 2.5424 - val_loss: 2.5721 - val_acc: 0.3663 - val_cross_entropy: 2.5934 - lr: 1e-03\n",
      "Epoch 18/20"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train STab on HELENA: a short warm-up, the main run (checkpointing the best\n",
    "# val_acc), then evaluation of the best checkpoint on the test split.\n",
    "# NOTE(review): U_ is not defined in any cell of this notebook; it presumably came\n",
    "# from a deleted cell. Define it before running. TODO: restore its value here.\n",
    "NUM_FEATURES = 27   # input columns (num_continuous / num_number / build shape)\n",
    "NUM_CLASSES  = 101  # output size (dim_out / classes)\n",
    "MC_SAMPLES   = 64   # number of test-set predict passes summed before the argmax\n",
    "SEED         = 0\n",
    "\n",
    "# Seed numpy and torch so runs (starting with weight init below) are more repeatable.\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "assert X_train.shape[1] == NUM_FEATURES, 'feature count does not match NUM_FEATURES'\n",
    "\n",
    "Or_model = MainModel(\n",
    "    categories        = (),\n",
    "    num_continuous    = NUM_FEATURES,\n",
    "    dim               = 96,\n",
    "    dim_out           = NUM_CLASSES,\n",
    "    depth             = 7,\n",
    "    heads             = 8,\n",
    "    attn_dropout      = 0.25,\n",
    "    ff_dropout        = 0.25,\n",
    "    U                 = U_,\n",
    "    cases             = 16\n",
    ")\n",
    "no_model = Num_Cat(Or_model,num_number=NUM_FEATURES,classes=NUM_CLASSES)\n",
    "model    = keras4torch.Model(no_model,).build([NUM_FEATURES])\n",
    "\n",
    "# Warm-up: LinearLR scales lr=1e-4 by a factor rising linearly from 1e-5 to 1\n",
    "# over 5 scheduler steps.\n",
    "no_model.reset_Sample_size(1)\n",
    "optimizer=torch.optim.AdamW(model.parameters(),lr=0.0001,weight_decay=0.0001,)\n",
    "sch=torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.00001, total_iters=5,  verbose=False)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.1,1), metrics=['accuracy'])\n",
    "callbacks=[LRScheduler(sch)]\n",
    "model.fit(X_train, Y_train.values,\n",
    "        epochs=5, batch_size=512,\n",
    "        validation_data=(X_valid,y_valid.values),\n",
    "        verbose=2,validation_batch_size=1024,\n",
    "        callbacks=callbacks)\n",
    "\n",
    "# Main training: ReduceLROnPlateau halves the lr (floor 1e-5) after 5 epochs without\n",
    "# improvement; ModelCheckpoint keeps the weights with the highest val_acc.\n",
    "no_model.reset_Sample_size(64)\n",
    "optimizer=torch.optim.AdamW(model.parameters(),lr=0.001,weight_decay=0.000001,)\n",
    "model.compile(optimizer=optimizer, loss=MyClassLoss(0.01,1), metrics=['accuracy', F.cross_entropy])\n",
    "scheduler =LRScheduler( torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5,min_lr=0.00001))\n",
    "callbacks=[scheduler,ModelCheckpoint(chpfilename,monitor='val_acc',mode='max')]\n",
    "model.fit(X_train, Y_train.values,\n",
    "        epochs=20, batch_size=512,\n",
    "        validation_data=(X_valid,y_valid.values),\n",
    "        verbose=2,validation_batch_size=256,\n",
    "        callbacks=callbacks)\n",
    "\n",
    "# Evaluate the best checkpoint: sum MC_SAMPLES predict passes, then take the argmax class.\n",
    "model.load_weights(chpfilename)\n",
    "no_model.reset_Sample_size(1)\n",
    "logits = sum(pd.DataFrame(model.predict(X_test,batch_size=4096)) for _ in range(MC_SAMPLES))\n",
    "Test = True_ACC(logits.idxmax(axis=1).values.reshape((-1)),y_test.values.reshape((-1)))\n",
    "print('Test Score:',Test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
