{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MixRL: Data Mixing Augmentation for Regression using Reinforcement Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Note:\n",
    "This demo performs MixRL experiments on the NO2 dataset. \\\n",
    "Since the regression model weights are initialized randomly, the performance of the model may change for different runs. \\\n",
    "However, the performance improvements are similar to those in the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mixrl import MixRL\n",
    "import copy\n",
    "import numpy as np\n",
    "\n",
    "# Inputs for the train / validation / test splits, loaded in that order.\n",
    "x_train, x_valid, x_test = (\n",
    "    np.load(\"NO2_data/x_train.npy\"),\n",
    "    np.load(\"NO2_data/x_valid.npy\"),\n",
    "    np.load(\"NO2_data/x_test.npy\"),\n",
    ")\n",
    "\n",
    "# Matching targets for the same three splits.\n",
    "y_train, y_valid, y_test = (\n",
    "    np.load(\"NO2_data/y_train.npy\"),\n",
    "    np.load(\"NO2_data/y_valid.npy\"),\n",
    "    np.load(\"NO2_data/y_test.npy\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Construct the regression model for the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class reg_model(nn.Module):\n",
    "    \"\"\"Fully connected regressor: in_features -> 512 -> 256 -> 1 with ReLU activations.\n",
    "\n",
    "    Args:\n",
    "        in_features: Width of each input vector. Defaults to 7, the value that was\n",
    "            previously hard-coded in the first layer, so ``reg_model()`` is unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features=7):\n",
    "        super().__init__()\n",
    "\n",
    "        # Attribute names and creation order are kept so state_dict keys and the\n",
    "        # random-initialisation sequence match the previous definition.\n",
    "        self.block_1 = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU())\n",
    "        self.block_2 = nn.Sequential(nn.Linear(512, 256), nn.ReLU())\n",
    "        self.fclayer = nn.Sequential(nn.Linear(256, 1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map inputs of shape (..., in_features) to predictions of shape (..., 1).\"\"\"\n",
    "        hidden = self.block_2(self.block_1(x))\n",
    "        return self.fclayer(hidden)\n",
    "\n",
    "def weight_inits(m):\n",
    "    \"\"\"Re-initialise the weight of an exact ``nn.Linear`` module with Kaiming-normal values.\n",
    "\n",
    "    Written for use with ``Module.apply``. Modules of any other type are skipped,\n",
    "    and biases are never touched.\n",
    "    \"\"\"\n",
    "    if type(m) is not nn.Linear:\n",
    "        return\n",
    "    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available; otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "reg_model(\n",
       "  (block_1): Sequential(\n",
       "    (0): Linear(in_features=7, out_features=512, bias=True)\n",
       "    (1): ReLU()\n",
       "  )\n",
       "  (block_2): Sequential(\n",
       "    (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "    (1): ReLU()\n",
       "  )\n",
       "  (fclayer): Sequential(\n",
       "    (0): Linear(in_features=256, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the base regressor, apply the custom weight initialisation, and move it to `device`.\n",
    "# `apply` and `to` both return the module itself, so the chain yields the same object.\n",
    "reg_model_base = reg_model().apply(weight_inits).to(device)\n",
    "reg_model_base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Construct the MixRL class with attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import Mixup_value_net\n",
    "\n",
    "knn_option = [4, 16, 64]\n",
    "\n",
    "# Derive the device flag from the runtime instead of hard-coding 1: forcing CUDA makes\n",
    "# `.to(device)` fail on CPU-only machines. The encoding is unchanged (1 -> CUDA, 0 -> CPU).\n",
    "device_option = int(torch.cuda.is_available())\n",
    "device = torch.device('cuda' if device_option else 'cpu')\n",
    "\n",
    "# Hyperparameters handed to MixRL below.\n",
    "batch_size_MixRL = len(x_train)  # equal to the number of training samples\n",
    "batch_size_reg_model = 32\n",
    "epoch_MixRL = 1000\n",
    "epoch_reg_model = 150\n",
    "lr_reg_model = 1e-3\n",
    "lambda_ratio = 0.5\n",
    "reward_scaling_factor = 5\n",
    "early_stopping_flag = 1\n",
    "early_stopping_patience = 20\n",
    "\n",
    "# Value-network input width: x columns + y columns + two entries per kNN option.\n",
    "# Indexing shape[1] requires x_train and y_train to be at least 2-D.\n",
    "MixRL_input_dim = np.shape(x_train)[1] + np.shape(y_train)[1] + len(knn_option)*2\n",
    "Mixup_value_network = Mixup_value_net(MixRL_input_dim).to(device)\n",
    "\n",
    "MixRL_class = MixRL(\n",
    "    x_train, y_train, x_valid, y_valid, knn_option, reg_model_base, Mixup_value_network,\n",
    "    batch_size_MixRL, batch_size_reg_model, epoch_MixRL, epoch_reg_model, lr_reg_model, lambda_ratio,\n",
    "    reward_scaling_factor, device_option, early_stopping_flag, early_stopping_patience,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression model performance (No augmentation used here)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean of RMSE: 0.5330, mean of R squared: 0.4735\n",
      "std of RMSE: 0.0023, std of R squared: 0.0058\n"
     ]
    }
   ],
   "source": [
    "# Baseline: train `num_iter` independent copies of the base model with the second\n",
    "# argument of train_reg_model set to 0 and an empty index list (the section header\n",
    "# states that no augmentation is used), then score each copy on the test split.\n",
    "num_iter = 5\n",
    "r_s_list, rmse_list = [], []\n",
    "for _ in range(num_iter):\n",
    "    temp_model = copy.deepcopy(reg_model_base)  # keep the shared base model untouched\n",
    "    trained_model = MixRL_class.train_reg_model(temp_model, 0, [])\n",
    "    r_s, mse = MixRL_class.cal_measures(trained_model, x_test, y_test)\n",
    "    rmse_list.append(np.sqrt(mse))\n",
    "    r_s_list.append(r_s)\n",
    "\n",
    "# Summarise the runs: mean and standard deviation of RMSE and R squared.\n",
    "mean_stats = (np.mean(rmse_list), np.mean(r_s_list))\n",
    "std_stats = (np.std(rmse_list), np.std(r_s_list))\n",
    "print(\"mean of RMSE: %1.4f, mean of R squared: %1.4f\" %mean_stats)\n",
    "print(\"std of RMSE: %1.4f, std of R squared: %1.4f\" %std_stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Train the Mixup value network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train the Mixup value network (per the section header). The call takes no arguments,\n",
    "# so it relies on whatever was passed to the MixRL constructor.\n",
    "MixRL_class.train_MixRL()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Evaluate using Mixup value network and sort h values in descending order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the h values from the trained MixRL instance (return type not visible here).\n",
    "values = MixRL_class.get_h_values()\n",
    "# argsort is ascending, so reversing it orders indices from highest to lowest value\n",
    "# (assumes `values` is 1-D -- TODO confirm).\n",
    "sorted_value_idx_list = np.argsort(values)[::-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Find the best threshold that minimizes the validation loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Threshold-search settings; presumably the first candidate and the increment between\n",
    "# candidates -- confirm against MixRL.find_best_threshold.\n",
    "init_threshold = 100\n",
    "search_step = 100\n",
    "best_threshold = MixRL_class.find_best_threshold(sorted_value_idx_list, init_threshold, search_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression model performance (MixRL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean of RMSE: 0.5209, mean of R squared: 0.5033\n",
      "std of RMSE: 0.0087, std of R squared: 0.0140\n"
     ]
    }
   ],
   "source": [
    "# MixRL run: pick the mixup indices for the best threshold, then repeat the baseline\n",
    "# protocol with the second argument of train_reg_model set to 1 and those indices.\n",
    "mixup_idx = MixRL_class.mixup_by_threshold(sorted_value_idx_list, best_threshold)\n",
    "num_iter = 5\n",
    "r_s_list, rmse_list = [], []\n",
    "for _ in range(num_iter):\n",
    "    temp_model = copy.deepcopy(reg_model_base)  # keep the shared base model untouched\n",
    "    trained_model = MixRL_class.train_reg_model(temp_model, 1, mixup_idx)\n",
    "    r_s, mse = MixRL_class.cal_measures(trained_model, x_test, y_test)\n",
    "    rmse_list.append(np.sqrt(mse))\n",
    "    r_s_list.append(r_s)\n",
    "\n",
    "# Summarise the runs: mean and standard deviation of RMSE and R squared.\n",
    "mean_stats = (np.mean(rmse_list), np.mean(r_s_list))\n",
    "std_stats = (np.std(rmse_list), np.std(r_s_list))\n",
    "print(\"mean of RMSE: %1.4f, mean of R squared: %1.4f\" %mean_stats)\n",
    "print(\"std of RMSE: %1.4f, std of R squared: %1.4f\" %std_stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Result\n",
    "Baseline / MixRL \\\n",
    "RMSE: 0.5330 / 0.5209 \\\n",
    "$R^2$: 0.4735 / 0.5033 "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Torch",
   "language": "python",
   "name": "torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
