{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import os\n",
    "from copy import deepcopy\n",
    "from datetime import datetime\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
    "\n",
    "import keras4torch\n",
    "from keras4torch.callbacks import ModelCheckpoint, LRScheduler\n",
    "\n",
    "import STab\n",
    "from STab import MyClassLoss, MyRegreLoss, CatMap, Num_Cat\n",
    "from STab import mainmodel, LWTA, Gsoftmax\n",
    "\n",
    "MainModel = mainmodel.MainModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load data\n",
    "# np.float (an alias of the builtin float) was removed in NumPy 1.24; np.float64 is the drop-in replacement.\n",
    "DATA_DIR = 'Data/diamond'\n",
    "\n",
    "def load_frame(name, dtype=None):\n",
    "    '''Read DATA_DIR/<name>.npy into a DataFrame, cast to `dtype` when one is given.'''\n",
    "    frame = pd.DataFrame(np.load(f'{DATA_DIR}/{name}.npy'))\n",
    "    return frame if dtype is None else frame.astype(dtype)\n",
    "\n",
    "X_test_num  = load_frame('X_num_test', np.float64)\n",
    "X_test_cat  = load_frame('X_cat_test')\n",
    "y_test      = load_frame('Y_test', np.float64)\n",
    "\n",
    "X_train_num = load_frame('X_num_train', np.float64)\n",
    "X_train_cat = load_frame('X_cat_train')\n",
    "Y_train     = load_frame('Y_train', np.float64)\n",
    "\n",
    "X_valid_num = load_frame('X_num_val', np.float64)\n",
    "X_valid_cat = load_frame('X_cat_val')\n",
    "y_valid     = load_frame('Y_val', np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the category -> integer mapping from the training split and apply it to every split.\n",
    "catmap = CatMap(X_train_cat)\n",
    "X_test_cat, X_valid_cat, X_train_cat = (\n",
    "    catmap(frame) for frame in (X_test_cat, X_valid_cat, X_train_cat)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise numerical inputs: fit on the training split only, then apply the same transform to every split.\n",
    "# np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype.\n",
    "scalerX = StandardScaler()\n",
    "X_train_num = scalerX.fit_transform(X_train_num).astype(np.float64)\n",
    "X_valid_num = scalerX.transform(X_valid_num).astype(np.float64)\n",
    "X_test_num  = scalerX.transform(X_test_num).astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the numerical block first and the categorical block last, column-wise, for each split.\n",
    "X_train = np.hstack([X_train_num, X_train_cat.values])\n",
    "X_test  = np.hstack([X_test_num, X_test_cat.values])\n",
    "X_valid = np.hstack([X_valid_num, X_valid_cat.values])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def True_MSE(p, t):\n",
    "    '''Mean squared error between predictions `p` and targets `t`.\n",
    "\n",
    "    Both inputs are flattened first, so an (n,) vs (n, 1) pair is compared element-wise\n",
    "    instead of silently broadcasting to an (n, n) matrix.\n",
    "    '''\n",
    "    return np.mean(np.square(np.ravel(p) - np.ravel(t)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights file written by ModelCheckpoint during training and reloaded before testing.\n",
    "checkpoint = 'saved/savefileDI'\n",
    "# Create the parent directory up front in case the saver does not create missing directories.\n",
    "os.makedirs(os.path.dirname(checkpoint), exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the STab regression model, warm up the learning rate, train with checkpointing,\n",
    "# then score the test set by averaging repeated predictions.\n",
    "# TODO(review): `dim_` and `cases_` (passed to MainModel as `U` and `cases`) are not defined\n",
    "# anywhere in this notebook; define them (ideally in a config cell) before running this cell.\n",
    "SEED           = 0\n",
    "CATEGORIES     = (5, 7, 8)\n",
    "NUM_CONTINUOUS = 6\n",
    "BATCH_SIZE     = 512\n",
    "VAL_BATCH_SIZE = 1024\n",
    "WARMUP_EPOCHS  = 5\n",
    "MAIN_EPOCHS    = 80\n",
    "MC_SAMPLES     = 64  # test-time predictions averaged per row\n",
    "\n",
    "# Seed torch and numpy so repeated runs of this cell are reproducible.\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "Or_model = MainModel(\n",
    "    categories     = CATEGORIES,\n",
    "    num_continuous = NUM_CONTINUOUS,\n",
    "    dim            = 96,\n",
    "    dim_out        = 1,\n",
    "    depth          = 4,\n",
    "    heads          = 8,\n",
    "    attn_dropout   = 0.1,\n",
    "    ff_dropout     = 0.1,\n",
    "    U              = dim_,\n",
    "    cases          = cases_,\n",
    ")\n",
    "no_model = Num_Cat(Or_model, Sample_size=16)\n",
    "# Input width = numerical columns followed by categorical columns, matching the concatenation cell.\n",
    "model = keras4torch.Model(no_model).build([NUM_CONTINUOUS + len(CATEGORIES)])\n",
    "\n",
    "def fit_stage(epochs, callbacks):\n",
    "    '''Run one training stage on the train split, validating on the validation split.'''\n",
    "    return model.fit(X_train, Y_train.values,\n",
    "                     epochs=epochs, batch_size=BATCH_SIZE,\n",
    "                     validation_data=(X_valid, y_valid.values),\n",
    "                     verbose=2, validation_batch_size=VAL_BATCH_SIZE,\n",
    "                     callbacks=callbacks)\n",
    "\n",
    "# Warm up: LinearLR scales the base LR (1e-4) from 1e-4x up to 1x over 5 scheduler steps.\n",
    "# `verbose` is left at its default (off); passing it explicitly is deprecated in recent PyTorch.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)\n",
    "warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-4, total_iters=5)\n",
    "model.compile(optimizer=optimizer, loss=MyRegreLoss(0.01, 1), metrics=['mse'])\n",
    "fit_stage(WARMUP_EPOCHS, [LRScheduler(warmup)])\n",
    "\n",
    "# Main training: fresh optimizer at 1e-3, ReduceLROnPlateau (halve, patience 5, floor 1e-6),\n",
    "# and a checkpoint monitored on val_mse.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
    "model.compile(optimizer=optimizer, loss=MyRegreLoss(0.01, 1), metrics=['mse'])\n",
    "scheduler = LRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5, min_lr=1e-6))\n",
    "fit_stage(MAIN_EPOCHS, [ModelCheckpoint(checkpoint, monitor='val_mse', mode='min'), scheduler])\n",
    "\n",
    "# Reload the weights ModelCheckpoint saved (monitored on val_mse, mode='min').\n",
    "model.load_weights(checkpoint)\n",
    "\n",
    "# Predict with Sample_size=1, MC_SAMPLES times, and average; assumes predict returns one value per row.\n",
    "no_model.reset_Sample_size(1)\n",
    "mc_preds = np.stack(\n",
    "    [np.asarray(model.predict(X_test, batch_size=4096)).reshape(-1) for _ in range(MC_SAMPLES)],\n",
    "    axis=1,\n",
    ")\n",
    "y_pred = mc_preds.mean(axis=1)\n",
    "\n",
    "test_mse = True_MSE(y_pred, y_test.values.reshape(-1))\n",
    "print('Test Score:', test_mse)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
