{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5838cd24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset: community\n",
      "Dimensions: train set (n=1595, p=100) ; test set (n=399, p=100)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yu.bai/anaconda3/envs/cqr/lib/python3.7/site-packages/sklearn/utils/deprecation.py:58: DeprecationWarning: Class Imputer is deprecated; Imputer was deprecated in version 0.20 and will be removed in 0.22. Import impute.SimpleImputer from sklearn instead.\n",
      "  warnings.warn(msg, category=DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "np.warnings.filterwarnings('ignore')\n",
    "\n",
    "from datasets import datasets\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "seed = 1\n",
    "\n",
    "random_state_train_test = seed\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    \n",
    "# desired miscoverage error\n",
    "alpha = 0.1\n",
    "\n",
    "# desired quanitile levels\n",
    "quantiles = [0.05, 0.95]\n",
    "\n",
    "# used to determine the size of test set\n",
    "test_ratio = 0.2\n",
    "\n",
    "# name of dataset\n",
    "dataset_base_path = \"./datasets/\"\n",
    "dataset_name = \"community\"\n",
    "\n",
    "# load the dataset\n",
    "X, y = datasets.GetDataset(dataset_name, dataset_base_path)\n",
    "\n",
    "# divide the dataset into test and train based on the test_ratio parameter\n",
    "x_train, x_test, y_train, y_test = train_test_split(X,\n",
    "                                                    y,\n",
    "                                                    test_size=test_ratio,\n",
    "                                                    random_state=random_state_train_test)\n",
    "\n",
    "# reshape the data\n",
    "x_train = np.asarray(x_train)\n",
    "y_train = np.asarray(y_train)\n",
    "x_test = np.asarray(x_test)\n",
    "y_test = np.asarray(y_test)\n",
    "\n",
    "# compute input dimensions\n",
    "n_train = x_train.shape[0]\n",
    "in_shape = x_train.shape[1]\n",
    "\n",
    "# display basic information\n",
    "print(\"Dataset: %s\" % (dataset_name))\n",
    "print(\"Dimensions: train set (n=%d, p=%d) ; test set (n=%d, p=%d)\" % \n",
    "      (x_train.shape[0], x_train.shape[1], x_test.shape[0], x_test.shape[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "64e2607e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "train_dataset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))\n",
    "test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))\n",
    "\n",
    "batch_size = 64\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "907e969e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "class LinearModel(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim=1, bias=True):\n",
    "        super(LinearModel, self).__init__()\n",
    "        self.linear = nn.Linear(in_dim, out_dim, bias=bias)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.linear(x)\n",
    "    \n",
    "\n",
    "class PinballLoss():\n",
    "    \"\"\"Pinball loss for quantile regression\"\"\"\n",
    "    def __init__(self, quantile=0.10, reduction='none'):\n",
    "        self.quantile = quantile\n",
    "        assert 0 < self.quantile < 1\n",
    "        self.reduction = reduction\n",
    "        \n",
    "    def __call__(self, output, target):\n",
    "        assert output.shape == target.shape\n",
    "        loss = torch.zeros_like(target, dtype=torch.float)\n",
    "        error = target - output\n",
    "        smaller_index = error < 0\n",
    "        bigger_index = 0 < error\n",
    "        loss[smaller_index] = (1-self.quantile) * (abs(error)[smaller_index])\n",
    "        loss[bigger_index] = self.quantile * (abs(error)[bigger_index])\n",
    "\n",
    "        if self.reduction == 'sum':\n",
    "            loss = loss.sum()\n",
    "        elif self.reduction == 'mean':\n",
    "            loss = loss.mean()\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "4c89aef1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    print('\\nEpoch: %d' % epoch)\n",
    "    train_loss = 0.0\n",
    "    count = 0\n",
    "    covered = 0.0\n",
    "    for batch_idx, (inputs, targets) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(inputs).squeeze()\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss += loss.item() * inputs.shape[0]\n",
    "        count += inputs.shape[0]\n",
    "        covered += (targets <= outputs).sum().float()\n",
    "        if (batch_idx+1) % 10 == 0:\n",
    "            print(f\"Batch [{batch_idx+1}/{len(train_loader)}]: \"\n",
    "                  f\"Loss: {train_loss/count:.6f}, Coverage: {100.*covered/count:.3f}\")\n",
    "    # import pdb; pdb.set_trace()\n",
    "    return train_loss / count, 1.*covered/count\n",
    "\n",
    "\n",
    "def test(epoch):\n",
    "    test_loss = 0.0\n",
    "    count = 0\n",
    "    covered = 0.0\n",
    "    for batch_idx, (inputs, targets) in enumerate(test_loader):\n",
    "        with torch.no_grad():\n",
    "            outputs = net(inputs).squeeze()\n",
    "            loss = criterion(outputs, targets)\n",
    "            optimizer.step()\n",
    "            test_loss += loss.item() * inputs.shape[0]\n",
    "            count += inputs.shape[0]\n",
    "            covered += (targets <= outputs).sum().float()\n",
    "            if (batch_idx+1) % 10 == 0:\n",
    "                print(f\"Batch [{batch_idx+1}/{len(train_loader)}]: \"\n",
    "                      f\"Loss: {train_loss/count:.6f}, Coverage: {100.*covered/count:.3f}\")\n",
    "    return test_loss / count, 1.*covered/count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "ee936aca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch: 0\n",
      "Batch [10/25]: Loss: 0.305113, Coverage: 15.781\n",
      "Batch [20/25]: Loss: 0.185244, Coverage: 51.562\n",
      "Training loss = 0.1620, Training coverage = 0.6056\n",
      "Test loss = 0.0702, Test coverage = 0.9699\n",
      "\n",
      "Epoch: 1\n",
      "Batch [10/25]: Loss: 0.067229, Coverage: 95.781\n",
      "Batch [20/25]: Loss: 0.065940, Coverage: 96.016\n",
      "Training loss = 0.0641, Training coverage = 0.9580\n",
      "Test loss = 0.0574, Test coverage = 0.9373\n",
      "\n",
      "Epoch: 2\n",
      "Batch [10/25]: Loss: 0.057095, Coverage: 90.625\n",
      "Batch [20/25]: Loss: 0.054161, Coverage: 90.625\n",
      "Training loss = 0.0527, Training coverage = 0.9078\n",
      "Test loss = 0.0495, Test coverage = 0.9123\n",
      "\n",
      "Epoch: 3\n",
      "Batch [10/25]: Loss: 0.050552, Coverage: 88.906\n",
      "Batch [20/25]: Loss: 0.049254, Coverage: 90.000\n",
      "Training loss = 0.0477, Training coverage = 0.9091\n",
      "Test loss = 0.0448, Test coverage = 0.9148\n",
      "\n",
      "Epoch: 4\n",
      "Batch [10/25]: Loss: 0.043406, Coverage: 87.188\n",
      "Batch [20/25]: Loss: 0.042673, Coverage: 87.578\n",
      "Training loss = 0.0436, Training coverage = 0.8721\n",
      "Test loss = 0.0407, Test coverage = 0.9123\n",
      "\n",
      "Epoch: 5\n",
      "Batch [10/25]: Loss: 0.040644, Coverage: 91.719\n",
      "Batch [20/25]: Loss: 0.039109, Coverage: 91.562\n",
      "Training loss = 0.0387, Training coverage = 0.9147\n",
      "Test loss = 0.0369, Test coverage = 0.9173\n",
      "\n",
      "Epoch: 6\n",
      "Batch [10/25]: Loss: 0.034464, Coverage: 90.938\n",
      "Batch [20/25]: Loss: 0.035016, Coverage: 90.781\n",
      "Training loss = 0.0351, Training coverage = 0.9053\n",
      "Test loss = 0.0337, Test coverage = 0.9023\n",
      "\n",
      "Epoch: 7\n",
      "Batch [10/25]: Loss: 0.034226, Coverage: 90.781\n",
      "Batch [20/25]: Loss: 0.032347, Coverage: 91.016\n",
      "Training loss = 0.0326, Training coverage = 0.9053\n",
      "Test loss = 0.0322, Test coverage = 0.9098\n",
      "\n",
      "Epoch: 8\n",
      "Batch [10/25]: Loss: 0.033441, Coverage: 90.000\n",
      "Batch [20/25]: Loss: 0.031175, Coverage: 89.688\n",
      "Training loss = 0.0310, Training coverage = 0.8978\n",
      "Test loss = 0.0314, Test coverage = 0.9048\n",
      "\n",
      "Epoch: 9\n",
      "Batch [10/25]: Loss: 0.031004, Coverage: 88.594\n",
      "Batch [20/25]: Loss: 0.030985, Coverage: 88.594\n",
      "Training loss = 0.0304, Training coverage = 0.8897\n",
      "Test loss = 0.0311, Test coverage = 0.9073\n",
      "\n",
      "Epoch: 10\n",
      "Batch [10/25]: Loss: 0.030455, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.029842, Coverage: 89.844\n",
      "Training loss = 0.0298, Training coverage = 0.9028\n",
      "Test loss = 0.0307, Test coverage = 0.9023\n",
      "\n",
      "Epoch: 11\n",
      "Batch [10/25]: Loss: 0.028577, Coverage: 89.062\n",
      "Batch [20/25]: Loss: 0.029396, Coverage: 89.297\n",
      "Training loss = 0.0294, Training coverage = 0.8940\n",
      "Test loss = 0.0305, Test coverage = 0.9073\n",
      "\n",
      "Epoch: 12\n",
      "Batch [10/25]: Loss: 0.030727, Coverage: 88.594\n",
      "Batch [20/25]: Loss: 0.029433, Coverage: 89.375\n",
      "Training loss = 0.0293, Training coverage = 0.8915\n",
      "Test loss = 0.0305, Test coverage = 0.9023\n",
      "\n",
      "Epoch: 13\n",
      "Batch [10/25]: Loss: 0.030706, Coverage: 91.719\n",
      "Batch [20/25]: Loss: 0.029260, Coverage: 91.328\n",
      "Training loss = 0.0291, Training coverage = 0.9085\n",
      "Test loss = 0.0301, Test coverage = 0.9048\n",
      "\n",
      "Epoch: 14\n",
      "Batch [10/25]: Loss: 0.028418, Coverage: 89.531\n",
      "Batch [20/25]: Loss: 0.028557, Coverage: 89.219\n",
      "Training loss = 0.0288, Training coverage = 0.8966\n",
      "Test loss = 0.0300, Test coverage = 0.9148\n",
      "\n",
      "Epoch: 15\n",
      "Batch [10/25]: Loss: 0.029287, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.029205, Coverage: 89.141\n",
      "Training loss = 0.0286, Training coverage = 0.8966\n",
      "Test loss = 0.0296, Test coverage = 0.9073\n",
      "\n",
      "Epoch: 16\n",
      "Batch [10/25]: Loss: 0.028499, Coverage: 85.312\n",
      "Batch [20/25]: Loss: 0.028340, Coverage: 86.641\n",
      "Training loss = 0.0288, Training coverage = 0.8734\n",
      "Test loss = 0.0301, Test coverage = 0.9273\n",
      "\n",
      "Epoch: 17\n",
      "Batch [10/25]: Loss: 0.027889, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027357, Coverage: 90.703\n",
      "Training loss = 0.0283, Training coverage = 0.8972\n",
      "Test loss = 0.0296, Test coverage = 0.9048\n",
      "\n",
      "Epoch: 18\n",
      "Batch [10/25]: Loss: 0.029758, Coverage: 92.656\n",
      "Batch [20/25]: Loss: 0.029123, Coverage: 91.484\n",
      "Training loss = 0.0284, Training coverage = 0.9204\n",
      "Test loss = 0.0294, Test coverage = 0.8947\n",
      "\n",
      "Epoch: 19\n",
      "Batch [10/25]: Loss: 0.028999, Coverage: 85.625\n",
      "Batch [20/25]: Loss: 0.028418, Coverage: 87.812\n",
      "Training loss = 0.0284, Training coverage = 0.8834\n",
      "Test loss = 0.0293, Test coverage = 0.9123\n",
      "\n",
      "Epoch: 20\n",
      "Batch [10/25]: Loss: 0.028812, Coverage: 87.500\n",
      "Batch [20/25]: Loss: 0.028323, Coverage: 88.203\n",
      "Training loss = 0.0281, Training coverage = 0.8947\n",
      "Test loss = 0.0291, Test coverage = 0.9173\n",
      "\n",
      "Epoch: 21\n",
      "Batch [10/25]: Loss: 0.027596, Coverage: 87.188\n",
      "Batch [20/25]: Loss: 0.027758, Coverage: 87.500\n",
      "Training loss = 0.0281, Training coverage = 0.8821\n",
      "Test loss = 0.0296, Test coverage = 0.9323\n",
      "\n",
      "Epoch: 22\n",
      "Batch [10/25]: Loss: 0.029795, Coverage: 92.031\n",
      "Batch [20/25]: Loss: 0.028389, Coverage: 91.406\n",
      "Training loss = 0.0279, Training coverage = 0.9122\n",
      "Test loss = 0.0288, Test coverage = 0.9073\n",
      "\n",
      "Epoch: 23\n",
      "Batch [10/25]: Loss: 0.028917, Coverage: 84.688\n",
      "Batch [20/25]: Loss: 0.028824, Coverage: 86.719\n",
      "Training loss = 0.0281, Training coverage = 0.8765\n",
      "Test loss = 0.0290, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 24\n",
      "Batch [10/25]: Loss: 0.025762, Coverage: 88.281\n",
      "Batch [20/25]: Loss: 0.027610, Coverage: 87.578\n",
      "Training loss = 0.0277, Training coverage = 0.8796\n",
      "Test loss = 0.0296, Test coverage = 0.9348\n",
      "\n",
      "Epoch: 25\n",
      "Batch [10/25]: Loss: 0.028516, Coverage: 91.406\n",
      "Batch [20/25]: Loss: 0.028649, Coverage: 89.766\n",
      "Training loss = 0.0275, Training coverage = 0.9022\n",
      "Test loss = 0.0290, Test coverage = 0.9148\n",
      "\n",
      "Epoch: 26\n",
      "Batch [10/25]: Loss: 0.029365, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027292, Coverage: 90.625\n",
      "Training loss = 0.0274, Training coverage = 0.9028\n",
      "Test loss = 0.0288, Test coverage = 0.9023\n",
      "\n",
      "Epoch: 27\n",
      "Batch [10/25]: Loss: 0.027785, Coverage: 91.875\n",
      "Batch [20/25]: Loss: 0.027206, Coverage: 90.781\n",
      "Training loss = 0.0273, Training coverage = 0.9028\n",
      "Test loss = 0.0288, Test coverage = 0.9198\n",
      "\n",
      "Epoch: 28\n",
      "Batch [10/25]: Loss: 0.028360, Coverage: 92.500\n",
      "Batch [20/25]: Loss: 0.027437, Coverage: 91.250\n",
      "Training loss = 0.0274, Training coverage = 0.9053\n",
      "Test loss = 0.0287, Test coverage = 0.9123\n",
      "\n",
      "Epoch: 29\n",
      "Batch [10/25]: Loss: 0.028222, Coverage: 92.188\n",
      "Batch [20/25]: Loss: 0.028093, Coverage: 93.047\n",
      "Training loss = 0.0276, Training coverage = 0.9273\n",
      "Test loss = 0.0284, Test coverage = 0.8897\n",
      "\n",
      "Epoch: 30\n",
      "Batch [10/25]: Loss: 0.027142, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027578, Coverage: 87.969\n",
      "Training loss = 0.0272, Training coverage = 0.8809\n",
      "Test loss = 0.0284, Test coverage = 0.9073\n",
      "\n",
      "Epoch: 31\n",
      "Batch [10/25]: Loss: 0.028397, Coverage: 89.219\n",
      "Batch [20/25]: Loss: 0.027522, Coverage: 88.984\n",
      "Training loss = 0.0271, Training coverage = 0.8922\n",
      "Test loss = 0.0285, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 32\n",
      "Batch [10/25]: Loss: 0.027061, Coverage: 90.781\n",
      "Batch [20/25]: Loss: 0.026821, Coverage: 90.625\n",
      "Training loss = 0.0271, Training coverage = 0.9041\n",
      "Test loss = 0.0285, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 33\n",
      "Batch [10/25]: Loss: 0.027435, Coverage: 90.156\n",
      "Batch [20/25]: Loss: 0.027293, Coverage: 90.000\n",
      "Training loss = 0.0271, Training coverage = 0.9003\n",
      "Test loss = 0.0285, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 34\n",
      "Batch [10/25]: Loss: 0.026935, Coverage: 90.781\n",
      "Batch [20/25]: Loss: 0.026612, Coverage: 90.703\n",
      "Training loss = 0.0271, Training coverage = 0.9060\n",
      "Test loss = 0.0286, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 35\n",
      "Batch [10/25]: Loss: 0.026261, Coverage: 90.781\n",
      "Batch [20/25]: Loss: 0.027165, Coverage: 90.312\n",
      "Training loss = 0.0271, Training coverage = 0.9085\n",
      "Test loss = 0.0285, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 36\n",
      "Batch [10/25]: Loss: 0.028175, Coverage: 90.156\n",
      "Batch [20/25]: Loss: 0.027072, Coverage: 89.844\n",
      "Training loss = 0.0270, Training coverage = 0.9009\n",
      "Test loss = 0.0284, Test coverage = 0.9173\n",
      "\n",
      "Epoch: 37\n",
      "Batch [10/25]: Loss: 0.026519, Coverage: 90.938\n",
      "Batch [20/25]: Loss: 0.026863, Coverage: 89.453\n",
      "Training loss = 0.0271, Training coverage = 0.8915\n",
      "Test loss = 0.0283, Test coverage = 0.9073\n",
      "\n",
      "Epoch: 38\n",
      "Batch [10/25]: Loss: 0.025985, Coverage: 91.406\n",
      "Batch [20/25]: Loss: 0.026751, Coverage: 90.078\n",
      "Training loss = 0.0270, Training coverage = 0.8978\n",
      "Test loss = 0.0285, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 39\n",
      "Batch [10/25]: Loss: 0.029624, Coverage: 88.750\n",
      "Batch [20/25]: Loss: 0.027682, Coverage: 90.469\n",
      "Training loss = 0.0270, Training coverage = 0.9060\n",
      "Test loss = 0.0285, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 40\n",
      "Batch [10/25]: Loss: 0.027328, Coverage: 90.469\n",
      "Batch [20/25]: Loss: 0.027601, Coverage: 90.547\n",
      "Training loss = 0.0270, Training coverage = 0.9009\n",
      "Test loss = 0.0284, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 41\n",
      "Batch [10/25]: Loss: 0.026307, Coverage: 90.312\n",
      "Batch [20/25]: Loss: 0.026829, Coverage: 90.703\n",
      "Training loss = 0.0270, Training coverage = 0.9047\n",
      "Test loss = 0.0285, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 42\n",
      "Batch [10/25]: Loss: 0.028254, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.027308, Coverage: 90.469\n",
      "Training loss = 0.0270, Training coverage = 0.9003\n",
      "Test loss = 0.0284, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 43\n",
      "Batch [10/25]: Loss: 0.025991, Coverage: 90.469\n",
      "Batch [20/25]: Loss: 0.026958, Coverage: 90.781\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss = 0.0270, Training coverage = 0.9078\n",
      "Test loss = 0.0285, Test coverage = 0.9273\n",
      "\n",
      "Epoch: 44\n",
      "Batch [10/25]: Loss: 0.025716, Coverage: 91.406\n",
      "Batch [20/25]: Loss: 0.027181, Coverage: 90.000\n",
      "Training loss = 0.0270, Training coverage = 0.9009\n",
      "Test loss = 0.0284, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 45\n",
      "Batch [10/25]: Loss: 0.025609, Coverage: 91.094\n",
      "Batch [20/25]: Loss: 0.026640, Coverage: 90.312\n",
      "Training loss = 0.0270, Training coverage = 0.9009\n",
      "Test loss = 0.0284, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 46\n",
      "Batch [10/25]: Loss: 0.027719, Coverage: 88.125\n",
      "Batch [20/25]: Loss: 0.026703, Coverage: 89.062\n",
      "Training loss = 0.0270, Training coverage = 0.8959\n",
      "Test loss = 0.0283, Test coverage = 0.9173\n",
      "\n",
      "Epoch: 47\n",
      "Batch [10/25]: Loss: 0.025301, Coverage: 90.469\n",
      "Batch [20/25]: Loss: 0.026400, Coverage: 89.297\n",
      "Training loss = 0.0270, Training coverage = 0.8928\n",
      "Test loss = 0.0283, Test coverage = 0.9198\n",
      "\n",
      "Epoch: 48\n",
      "Batch [10/25]: Loss: 0.026835, Coverage: 90.625\n",
      "Batch [20/25]: Loss: 0.026639, Coverage: 90.234\n",
      "Training loss = 0.0270, Training coverage = 0.9034\n",
      "Test loss = 0.0284, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 49\n",
      "Batch [10/25]: Loss: 0.027458, Coverage: 90.000\n",
      "Batch [20/25]: Loss: 0.027228, Coverage: 90.312\n",
      "Training loss = 0.0270, Training coverage = 0.8984\n",
      "Test loss = 0.0283, Test coverage = 0.9198\n",
      "\n",
      "Epoch: 50\n",
      "Batch [10/25]: Loss: 0.026021, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.026974, Coverage: 90.156\n",
      "Training loss = 0.0270, Training coverage = 0.8997\n",
      "Test loss = 0.0284, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 51\n",
      "Batch [10/25]: Loss: 0.026703, Coverage: 90.156\n",
      "Batch [20/25]: Loss: 0.026681, Coverage: 90.938\n",
      "Training loss = 0.0270, Training coverage = 0.9060\n",
      "Test loss = 0.0285, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 52\n",
      "Batch [10/25]: Loss: 0.024535, Coverage: 92.500\n",
      "Batch [20/25]: Loss: 0.026892, Coverage: 90.156\n",
      "Training loss = 0.0270, Training coverage = 0.9072\n",
      "Test loss = 0.0284, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 53\n",
      "Batch [10/25]: Loss: 0.026522, Coverage: 91.875\n",
      "Batch [20/25]: Loss: 0.027049, Coverage: 89.922\n",
      "Training loss = 0.0270, Training coverage = 0.8940\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 54\n",
      "Batch [10/25]: Loss: 0.026051, Coverage: 90.156\n",
      "Batch [20/25]: Loss: 0.026740, Coverage: 90.547\n",
      "Training loss = 0.0269, Training coverage = 0.8984\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 55\n",
      "Batch [10/25]: Loss: 0.027849, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.027033, Coverage: 90.547\n",
      "Training loss = 0.0270, Training coverage = 0.9066\n",
      "Test loss = 0.0284, Test coverage = 0.9273\n",
      "\n",
      "Epoch: 56\n",
      "Batch [10/25]: Loss: 0.025599, Coverage: 92.031\n",
      "Batch [20/25]: Loss: 0.026223, Coverage: 90.469\n",
      "Training loss = 0.0269, Training coverage = 0.9047\n",
      "Test loss = 0.0284, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 57\n",
      "Batch [10/25]: Loss: 0.026516, Coverage: 91.875\n",
      "Batch [20/25]: Loss: 0.027347, Coverage: 90.859\n",
      "Training loss = 0.0269, Training coverage = 0.9097\n",
      "Test loss = 0.0285, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 58\n",
      "Batch [10/25]: Loss: 0.026059, Coverage: 92.500\n",
      "Batch [20/25]: Loss: 0.026365, Coverage: 91.172\n",
      "Training loss = 0.0269, Training coverage = 0.9110\n",
      "Test loss = 0.0284, Test coverage = 0.9273\n",
      "\n",
      "Epoch: 59\n",
      "Batch [10/25]: Loss: 0.026191, Coverage: 90.469\n",
      "Batch [20/25]: Loss: 0.026053, Coverage: 89.844\n",
      "Training loss = 0.0269, Training coverage = 0.8959\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 60\n",
      "Batch [10/25]: Loss: 0.029139, Coverage: 89.062\n",
      "Batch [20/25]: Loss: 0.027437, Coverage: 90.156\n",
      "Training loss = 0.0269, Training coverage = 0.8966\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 61\n",
      "Batch [10/25]: Loss: 0.028821, Coverage: 88.750\n",
      "Batch [20/25]: Loss: 0.026997, Coverage: 89.922\n",
      "Training loss = 0.0269, Training coverage = 0.8997\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 62\n",
      "Batch [10/25]: Loss: 0.028066, Coverage: 87.656\n",
      "Batch [20/25]: Loss: 0.027395, Coverage: 89.219\n",
      "Training loss = 0.0269, Training coverage = 0.8997\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 63\n",
      "Batch [10/25]: Loss: 0.026263, Coverage: 90.625\n",
      "Batch [20/25]: Loss: 0.026639, Coverage: 89.688\n",
      "Training loss = 0.0269, Training coverage = 0.8997\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 64\n",
      "Batch [10/25]: Loss: 0.026279, Coverage: 90.781\n",
      "Batch [20/25]: Loss: 0.027164, Coverage: 89.609\n",
      "Training loss = 0.0269, Training coverage = 0.8991\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 65\n",
      "Batch [10/25]: Loss: 0.028552, Coverage: 87.500\n",
      "Batch [20/25]: Loss: 0.027366, Coverage: 89.375\n",
      "Training loss = 0.0269, Training coverage = 0.8978\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 66\n",
      "Batch [10/25]: Loss: 0.027326, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.027499, Coverage: 89.922\n",
      "Training loss = 0.0269, Training coverage = 0.8978\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 67\n",
      "Batch [10/25]: Loss: 0.028882, Coverage: 88.594\n",
      "Batch [20/25]: Loss: 0.027380, Coverage: 89.141\n",
      "Training loss = 0.0269, Training coverage = 0.8984\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 68\n",
      "Batch [10/25]: Loss: 0.027294, Coverage: 90.312\n",
      "Batch [20/25]: Loss: 0.026778, Coverage: 89.922\n",
      "Training loss = 0.0269, Training coverage = 0.8978\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 69\n",
      "Batch [10/25]: Loss: 0.026033, Coverage: 88.438\n",
      "Batch [20/25]: Loss: 0.026933, Coverage: 89.297\n",
      "Training loss = 0.0269, Training coverage = 0.8991\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 70\n",
      "Batch [10/25]: Loss: 0.027514, Coverage: 89.062\n",
      "Batch [20/25]: Loss: 0.027242, Coverage: 90.234\n",
      "Training loss = 0.0269, Training coverage = 0.9003\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 71\n",
      "Batch [10/25]: Loss: 0.027609, Coverage: 90.000\n",
      "Batch [20/25]: Loss: 0.026858, Coverage: 89.922\n",
      "Training loss = 0.0269, Training coverage = 0.9003\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 72\n",
      "Batch [10/25]: Loss: 0.028518, Coverage: 89.375\n",
      "Batch [20/25]: Loss: 0.027264, Coverage: 89.609\n",
      "Training loss = 0.0269, Training coverage = 0.9003\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 73\n",
      "Batch [10/25]: Loss: 0.027153, Coverage: 88.906\n",
      "Batch [20/25]: Loss: 0.027397, Coverage: 89.688\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 74\n",
      "Batch [10/25]: Loss: 0.026352, Coverage: 90.000\n",
      "Batch [20/25]: Loss: 0.026892, Coverage: 90.234\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 75\n",
      "Batch [10/25]: Loss: 0.026561, Coverage: 92.188\n",
      "Batch [20/25]: Loss: 0.026572, Coverage: 90.156\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 76\n",
      "Batch [10/25]: Loss: 0.027391, Coverage: 88.750\n",
      "Batch [20/25]: Loss: 0.026969, Coverage: 89.453\n",
      "Training loss = 0.0269, Training coverage = 0.9022\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 77\n",
      "Batch [10/25]: Loss: 0.028308, Coverage: 89.062\n",
      "Batch [20/25]: Loss: 0.027363, Coverage: 90.234\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 78\n",
      "Batch [10/25]: Loss: 0.027193, Coverage: 90.156\n",
      "Batch [20/25]: Loss: 0.026993, Coverage: 89.609\n",
      "Training loss = 0.0269, Training coverage = 0.9016\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 79\n",
      "Batch [10/25]: Loss: 0.025170, Coverage: 90.312\n",
      "Batch [20/25]: Loss: 0.026625, Coverage: 89.922\n",
      "Training loss = 0.0269, Training coverage = 0.9016\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 80\n",
      "Batch [10/25]: Loss: 0.026068, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027520, Coverage: 89.531\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 81\n",
      "Batch [10/25]: Loss: 0.026839, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027489, Coverage: 89.375\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 82\n",
      "Batch [10/25]: Loss: 0.027249, Coverage: 89.531\n",
      "Batch [20/25]: Loss: 0.026509, Coverage: 89.688\n",
      "Training loss = 0.0269, Training coverage = 0.9003\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 83\n",
      "Batch [10/25]: Loss: 0.029003, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027077, Coverage: 89.844\n",
      "Training loss = 0.0269, Training coverage = 0.8997\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 84\n",
      "Batch [10/25]: Loss: 0.027100, Coverage: 90.312\n",
      "Batch [20/25]: Loss: 0.027163, Coverage: 90.781\n",
      "Training loss = 0.0269, Training coverage = 0.8991\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 85\n",
      "Batch [10/25]: Loss: 0.027940, Coverage: 89.531\n",
      "Batch [20/25]: Loss: 0.027052, Coverage: 90.000\n",
      "Training loss = 0.0269, Training coverage = 0.9003\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 86\n",
      "Batch [10/25]: Loss: 0.027132, Coverage: 89.531\n",
      "Batch [20/25]: Loss: 0.027123, Coverage: 89.531\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 87\n",
      "Batch [10/25]: Loss: 0.026157, Coverage: 89.688\n",
      "Batch [20/25]: Loss: 0.027327, Coverage: 89.531\n",
      "Training loss = 0.0269, Training coverage = 0.9003\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n",
      "\n",
      "Epoch: 88\n",
      "Batch [10/25]: Loss: 0.024255, Coverage: 90.938\n",
      "Batch [20/25]: Loss: 0.025946, Coverage: 90.469\n",
      "Training loss = 0.0269, Training coverage = 0.8991\n",
      "Test loss = 0.0283, Test coverage = 0.9223\n",
      "\n",
      "Epoch: 89\n",
      "Batch [10/25]: Loss: 0.027068, Coverage: 90.469\n",
      "Batch [20/25]: Loss: 0.027101, Coverage: 90.312\n",
      "Training loss = 0.0269, Training coverage = 0.9009\n",
      "Test loss = 0.0283, Test coverage = 0.9248\n"
     ]
    }
   ],
   "source": [
    "in_dim = x_train.shape[1]\n",
    "net = LinearModel(in_dim)\n",
    "lr = 1e-2\n",
    "momentum = 0.9\n",
    "quantile = 0.9\n",
    "optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)\n",
    "lambda1 = lambda epoch: np.power(0.1, epoch // 30)\n",
    "scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)\n",
    "criterion = PinballLoss(quantile=quantile, reduction=\"mean\")\n",
    "num_epochs = 90\n",
    "\n",
    "train_losses, train_covs = list(), list()\n",
    "test_losses, test_covs = list(), list()\n",
    "for epoch in range(num_epochs):\n",
    "    train_loss, train_cov = train(epoch)\n",
    "    print(f\"Training loss = {train_loss:.4f}, Training coverage = {train_cov:.4f}\")\n",
    "    test_loss, test_cov = test(epoch)\n",
    "    print(f\"Test loss = {test_loss:.4f}, Test coverage = {test_cov:.4f}\")\n",
    "    train_losses.append(train_loss)\n",
    "    train_covs.append(train_cov)\n",
    "    test_losses.append(test_loss)\n",
    "    test_covs.append(test_cov)\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "995da956",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
