{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "import keras4torch\n",
    "from   keras4torch.callbacks  import ModelCheckpoint,LRScheduler\n",
    "import torch\n",
    "import torch.nn    as nn\n",
    "import torch.optim as optim\n",
    "import  torch.nn.functional as     F\n",
    "import numpy       as np\n",
    "import pandas      as pd\n",
    "from copy import deepcopy\n",
    "import  matplotlib.pyplot   as     plt\n",
    "from    sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "from    datetime import datetime\n",
    "import  gc\n",
    "import STab\n",
    "from STab import MyClassLoss,MyRegreLoss,CatMap,Num_Cat\n",
    "from   STab import mainmodel, LWTA, Gsoftmax\n",
    "MainModel=mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##Load DAta\n",
    "X_test=pd.DataFrame(np.load('Data/year/N_test.npy')).astype(np.float)\n",
    "y_test=pd.DataFrame(np.load('Data/year/y_test.npy')).astype(np.float)\n",
    "\n",
    "X_train=pd.DataFrame(np.load('Data/year/N_train.npy')).astype(np.float)\n",
    "Y_train=pd.DataFrame(np.load('Data/year/y_train.npy')).astype(np.float)\n",
    "\n",
    "\n",
    "X_valid=pd.DataFrame(np.load('Data/year/N_val.npy')).astype(np.float)\n",
    "y_valid=pd.DataFrame(np.load('Data/year/y_val.npy')).astype(np.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalise In\n",
    "scalerX = StandardScaler()\n",
    "\n",
    "scalerX.fit(X_train+np.random.normal(0,0.000001,size=X_train.shape))\n",
    "\n",
    "X_train = scalerX.transform(X_train).astype(np.float) \n",
    "X_valid = scalerX.transform(X_valid).astype(np.float) \n",
    "X_test  = scalerX.transform(X_test).astype(np.float) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scale_factor=Y_train.std()\n",
    "y_test  =(y_test-Y_train.mean())/scale_factor\n",
    "y_valid =(y_valid-Y_train.mean())/scale_factor\n",
    "Y_train =(Y_train-Y_train.mean())/scale_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint='saved/savefileYE'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_MSE(p,t):\n",
    "    return np.mean((scale_factor*(t- p))  **2) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Or_model = MainModel(\n",
    "    categories        = (),               \n",
    "    num_continuous    = 90,               \n",
    "    dim               = 128,            \n",
    "    dim_out           = 1,                \n",
    "    depth             = 6,                \n",
    "    heads             = 8,               \n",
    "    attn_dropout      = 0.25 ,              \n",
    "    ff_dropout        = 0.25,              \n",
    "    U                 = 2, \n",
    "    cases             = 16\n",
    "\n",
    ")\n",
    "no_model = Num_Cat(Or_model,num_number=90)\n",
    "model    = keras4torch.Model(no_model,).build([90])\n",
    "\n",
    "#Warm UP\n",
    "optimizer=torch.optim.AdamW(model.parameters(),lr=0.0001,weight_decay=0.0001,)\n",
    "sch=torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, total_iters=1,  verbose=False)\n",
    "model.compile(optimizer=optimizer, loss=MyRegreLoss(0.1,1), metrics=['mse'])\n",
    "callbacks=[ModelCheckpoint(checkpoint,monitor='val_mse',mode='min'),LRScheduler(sch)]\n",
    "model.fit(X_train, Y_train.values,\n",
    "                      epochs=5, batch_size=1024,\n",
    "                      validation_data=(X_valid,y_valid.values),\n",
    "                      verbose=2,validation_batch_size=128,\n",
    "                      callbacks=callbacks)\n",
    "#Main Train\n",
    "optimizer=torch.optim.AdamW(model.parameters(),lr=0.001,weight_decay=0.0001,)\n",
    "model.compile(optimizer=optimizer, loss=MyRegreLoss(0.01,1), metrics=['mse'])\n",
    "scheduler =LRScheduler( torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5,min_lr=0.00001))\n",
    "callbacks=[ModelCheckpoint('savefile6YE',monitor='val_mse',mode='min'),scheduler]           \n",
    "model.fit(X_train, Y_train.values,\n",
    "        epochs=30, batch_size=1024,\n",
    "        validation_data=(X_valid,y_valid.values),\n",
    "        verbose=2,validation_batch_size=128,\n",
    "        callbacks=callbacks)\n",
    "\n",
    "\n",
    "model.load_weights(checkpoint)\n",
    "\n",
    "no_model.reset_Sample_size(1)\n",
    "Out2=pd.DataFrame()\n",
    "\n",
    "for i in range(0,64):\n",
    "\n",
    "        Out2[i]=model.predict(X_test,batch_size=4096)[:,0]\n",
    "Test = (True_MSE(Out2.mean(axis=1).values.reshape((-1)),y_test.values.reshape((-1))))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
