{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "460e18df",
      "metadata": {
        "id": "460e18df"
      },
      "source": [
        "# Warcraft Shortest Path"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1ef0b563",
      "metadata": {
        "id": "1ef0b563",
        "outputId": "deb1cd53-2206-4769-e76b-2749bf0a186b"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import random\n",
        "import pyepo\n",
        "import torch\n",
        "from torch import nn\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import time\n",
        "from tqdm import tqdm\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "# fix random seed\n",
        "random.seed(135)\n",
        "np.random.seed(135)\n",
        "torch.manual_seed(135)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "22b83a43",
      "metadata": {
        "id": "22b83a43"
      },
      "source": [
        "## 1 Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0d022200",
      "metadata": {
        "id": "0d022200"
      },
      "source": [
        "We use the Warcraft terrains shortest paths [dataset](https://edmond.mpdl.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.YJCQ5S). Datasets were randomly generated from the Warcraft II [tileset](http://github.com/war2/war2edit) and used in Vlastelica, Marin, et al. \"Differentiation of Blackbox Combinatorial Solvers\". The Warcraft dataset is a captivating benchmark because the image input feature allows learning the shortest path from 10000 RGB terrains"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8c4791e8",
      "metadata": {
        "id": "8c4791e8"
      },
      "outputs": [],
      "source": [
        "# map size\n",
        "k = 12"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7d8353b0",
      "metadata": {
        "id": "7d8353b0"
      },
      "source": [
        "### 1.1 Maps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a344ab63",
      "metadata": {
        "id": "a344ab63"
      },
      "outputs": [],
      "source": [
        "tmaps_train = np.load(\"./raw_data/{}x{}/train_maps.npy\".format(k,k))\n",
        "#tmaps_val = np.load(\"../data/warcraft_shortest_path_oneskin/{}x{}/val_maps.npy\".format(k,k))\n",
        "tmaps_test = np.load(\"./raw_data/{}x{}/test_maps.npy\".format(k,k))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b903781d",
      "metadata": {
        "id": "b903781d"
      },
      "source": [
        "### 1.2 Costs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "67ce1279",
      "metadata": {
        "id": "67ce1279"
      },
      "outputs": [],
      "source": [
        "costs_train = np.load(\"./raw_data/{}x{}/train_vertex_weights.npy\".format(k,k))\n",
        "#costs_val = np.load(\"../data/warcraft_shortest_path_oneskin/{}x{}/val_vertex_weights.npy\".format(k,k))\n",
        "costs_test = np.load(\"./raw_data/{}x{}/test_vertex_weights.npy\".format(k,k))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b723276c",
      "metadata": {
        "id": "b723276c"
      },
      "source": [
        "### 1.3 Paths"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "87438df2",
      "metadata": {
        "id": "87438df2"
      },
      "outputs": [],
      "source": [
        "paths_train = np.load(\"./raw_data/{}x{}/train_shortest_paths.npy\".format(k,k))\n",
        "#paths_val = np.load(\"../data/warcraft_shortest_path_oneskin/{}x{}/val_shortest_paths.npy\".format(k,k))\n",
        "paths_test = np.load(\"./raw_data/{}x{}/test_shortest_paths.npy\".format(k,k))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "955b3261",
      "metadata": {},
      "outputs": [],
      "source": [
        "# plot some cost matrices\n",
        "plt.axis(\"off\")\n",
        "plt.imshow(costs_train[0])\n",
        "plt.savefig('.imgs/warcraft_cost_matrix_0.pdf')\n",
        "plt.close()\n",
        "\n",
        "plt.axis(\"off\")\n",
        "plt.imshow(costs_train[27])\n",
        "plt.savefig('.imgs/warcraft_cost_matrix_1.pdf')\n",
        "plt.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3e83d132",
      "metadata": {
        "id": "3e83d132"
      },
      "source": [
        "## 2 Data Loader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0271d7a2",
      "metadata": {
        "id": "0271d7a2"
      },
      "outputs": [],
      "source": [
        "class mapDataset(Dataset):\n",
        "    def __init__(self, tmaps, costs, paths):\n",
        "        self.tmaps = tmaps\n",
        "        self.costs = costs\n",
        "        self.paths = paths\n",
        "        self.objs = (costs * paths).sum(axis=(1,2)).reshape(-1,1)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.costs)\n",
        "\n",
        "    def __getitem__(self, ind):\n",
        "        return (\n",
        "            torch.FloatTensor(self.tmaps[ind].transpose(2, 0, 1)/255).detach(), # image\n",
        "            torch.FloatTensor(self.costs[ind]).reshape(-1),\n",
        "            torch.FloatTensor(self.paths[ind]).reshape(-1),\n",
        "            torch.FloatTensor(self.objs[ind]),\n",
        "        )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2b71596c",
      "metadata": {
        "id": "2b71596c"
      },
      "outputs": [],
      "source": [
        "# datasets\n",
        "dataset_train = mapDataset(tmaps_train, costs_train, paths_train)\n",
        "#dataset_val = mapDataset(tmaps_val, costs_val, paths_val)\n",
        "dataset_test = mapDataset(tmaps_test, costs_test, paths_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "078a1328",
      "metadata": {
        "id": "078a1328"
      },
      "outputs": [],
      "source": [
        "# dataloader\n",
        "batch_size = 70\n",
        "loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)\n",
        "#loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)\n",
        "loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c190c011",
      "metadata": {
        "id": "c190c011"
      },
      "source": [
        "## 3 Neural Network: Truncated Resnet18"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ef5b1f76",
      "metadata": {
        "id": "ef5b1f76"
      },
      "source": [
        "Same as previous paper, we used the truncated ResNet18 (first five layers), $50$ epochs with batches of size $70$, learning rate $0.0005$ decaying at the epochs $30$ and $40$, and $n = 1, \\sigma = 1$."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0db14ca5",
      "metadata": {
        "id": "0db14ca5"
      },
      "source": [
        "### 3.2 Truncated Resnet18"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "77008efd",
      "metadata": {
        "id": "77008efd"
      },
      "outputs": [],
      "source": [
        "from torchvision.models import resnet18\n",
        "nnet = resnet18(pretrained=False)\n",
        "# build new ResNet18 with Max Pooling\n",
        "class partialResNet(nn.Module):\n",
        "\n",
        "    def __init__(self, k):\n",
        "        super(partialResNet, self).__init__()\n",
        "        # init resnet 18\n",
        "        resnet = resnet18(pretrained=False)\n",
        "        # first five layers of ResNet18\n",
        "        self.conv1 = resnet.conv1\n",
        "        self.bn = resnet.bn1\n",
        "        self.relu = resnet.relu\n",
        "        self.maxpool1 = resnet.maxpool\n",
        "        self.block = resnet.layer1\n",
        "        # conv to 1 channel\n",
        "        self.conv2  = nn.Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)\n",
        "        # max pooling\n",
        "        self.maxpool2 = nn.AdaptiveMaxPool2d((k,k))\n",
        "\n",
        "    def forward(self, x):\n",
        "        h = self.conv1(x)\n",
        "        h = self.bn(h)\n",
        "        h = self.relu(h)\n",
        "        h = self.maxpool1(h)\n",
        "        h = self.block(h)\n",
        "        h = self.conv2(h)\n",
        "        out = self.maxpool2(h)\n",
        "        # reshape for optmodel\n",
        "        out = torch.squeeze(out, 1)\n",
        "        out = out.reshape(out.shape[0], -1)\n",
        "        return out"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "fee41336",
      "metadata": {
        "id": "fee41336"
      },
      "source": [
        "### 3.3 Hyperparameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "60971dbc",
      "metadata": {
        "id": "60971dbc"
      },
      "outputs": [],
      "source": [
        "# number of epochs\n",
        "epochs = 50\n",
        "# learning rate\n",
        "lr = 5e-4\n",
        "# log step\n",
        "log_step = 1"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bbf12ea4",
      "metadata": {
        "id": "bbf12ea4"
      },
      "source": [
        "## 4 Optimization Model: Linear Programming"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "74d187f0",
      "metadata": {
        "id": "74d187f0"
      },
      "outputs": [],
      "source": [
        "import gurobipy as gp\n",
        "from gurobipy import GRB\n",
        "\n",
        "from pyepo.model.grb.grbmodel import optGrbModel\n",
        "\n",
        "class shortestPathModel(optGrbModel):\n",
        "    \"\"\"\n",
        "    This class is optimization model for shortest path problem on 2D grid with 8 neighbors\n",
        "\n",
        "    Attributes:\n",
        "        _model (GurobiPy model): Gurobi model\n",
        "        grid (tuple of int): Size of grid network\n",
        "        nodes (list): list of vertex\n",
        "        edges (list): List of arcs\n",
        "        nodes_map (ndarray): 2D array for node index\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, grid):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            grid (tuple of int): size of grid network\n",
        "        \"\"\"\n",
        "        self.grid = grid\n",
        "        self.nodes, self.edges, self.nodes_map = self._getEdges()\n",
        "        super().__init__()\n",
        "\n",
        "    def _getEdges(self):\n",
        "        \"\"\"\n",
        "        A method to get list of edges for grid network\n",
        "\n",
        "        Returns:\n",
        "            list: arcs\n",
        "        \"\"\"\n",
        "        # init list\n",
        "        nodes, edges = [], []\n",
        "        # init map from coord to ind\n",
        "        nodes_map = {}\n",
        "        for i in range(self.grid[0]):\n",
        "            for j in range(self.grid[1]):\n",
        "                u = self._calNode(i, j)\n",
        "                nodes_map[u] = (i,j)\n",
        "                nodes.append(u)\n",
        "                # edge to 8 neighbors\n",
        "                # up\n",
        "                if i != 0:\n",
        "                    v = self._calNode(i-1, j)\n",
        "                    edges.append((u,v))\n",
        "                    # up-right\n",
        "                    if j != self.grid[1] - 1:\n",
        "                        v = self._calNode(i-1, j+1)\n",
        "                        edges.append((u,v))\n",
        "                # right\n",
        "                if j != self.grid[1] - 1:\n",
        "                    v = self._calNode(i, j+1)\n",
        "                    edges.append((u,v))\n",
        "                    # down-right\n",
        "                    if i != self.grid[0] - 1:\n",
        "                        v = self._calNode(i+1, j+1)\n",
        "                        edges.append((u,v))\n",
        "                # down\n",
        "                if i != self.grid[0] - 1:\n",
        "                    v = self._calNode(i+1, j)\n",
        "                    edges.append((u,v))\n",
        "                    # down-left\n",
        "                    if j != 0:\n",
        "                        v = self._calNode(i+1, j-1)\n",
        "                        edges.append((u,v))\n",
        "                # left\n",
        "                if j != 0:\n",
        "                    v = self._calNode(i, j-1)\n",
        "                    edges.append((u,v))\n",
        "                    # top-left\n",
        "                    if i != 0:\n",
        "                        v = self._calNode(i-1, j-1)\n",
        "                        edges.append((u,v))\n",
        "        return nodes, edges, nodes_map\n",
        "\n",
        "    def _calNode(self, x, y):\n",
        "        \"\"\"\n",
        "        A method to calculate index of node\n",
        "        \"\"\"\n",
        "        v = x * self.grid[1] + y\n",
        "        return v\n",
        "\n",
        "    def _getModel(self):\n",
        "        \"\"\"\n",
        "        A method to build Gurobi model\n",
        "\n",
        "        Returns:\n",
        "            tuple: optimization model and variables\n",
        "        \"\"\"\n",
        "        # ceate a model\n",
        "        m = gp.Model(\"shortest path\")\n",
        "        # varibles\n",
        "        x = m.addVars(self.edges, ub=1, name=\"x\")\n",
        "        # sense\n",
        "        m.modelSense = GRB.MINIMIZE\n",
        "        # constraints\n",
        "        for i in range(self.grid[0]):\n",
        "            for j in range(self.grid[1]):\n",
        "                v = self._calNode(i, j)\n",
        "                expr = 0\n",
        "                for e in self.edges:\n",
        "                    # flow in\n",
        "                    if v == e[1]:\n",
        "                        expr += x[e]\n",
        "                    # flow out\n",
        "                    elif v == e[0]:\n",
        "                        expr -= x[e]\n",
        "                # source\n",
        "                if i == 0 and j == 0:\n",
        "                    m.addConstr(expr == -1)\n",
        "                # sink\n",
        "                elif i == self.grid[0] - 1 and j == self.grid[0] - 1:\n",
        "                    m.addConstr(expr == 1)\n",
        "                # transition\n",
        "                else:\n",
        "                    m.addConstr(expr == 0)\n",
        "        return m, x\n",
        "\n",
        "    def setObj(self, c):\n",
        "        \"\"\"\n",
        "        A method to set objective function\n",
        "\n",
        "        Args:\n",
        "            c (np.ndarray): cost of objective function\n",
        "        \"\"\"\n",
        "        # vector to matrix\n",
        "        c = c.reshape(self.grid)\n",
        "        # sum up vector cost\n",
        "        obj = c[0,0] + gp.quicksum(c[self.nodes_map[j]] * self.x[i,j] for i, j in self.x)\n",
        "        self._model.setObjective(obj)\n",
        "\n",
        "    def solve(self):\n",
        "        \"\"\"\n",
        "        A method to solve model\n",
        "\n",
        "        Returns:\n",
        "            tuple: optimal solution (list) and objective value (float)\n",
        "        \"\"\"\n",
        "        # update gurobi model\n",
        "        self._model.update()\n",
        "        # solve\n",
        "        self._model.optimize()\n",
        "        # kxk solution map\n",
        "        sol = np.zeros(self.grid)\n",
        "        for i, j in self.edges:\n",
        "            # active edge\n",
        "            if abs(1 - self.x[i,j].x) < 1e-3:\n",
        "                # node on active edge\n",
        "                sol[self.nodes_map[i]] = 1\n",
        "                sol[self.nodes_map[j]] = 1\n",
        "        # matrix to vector\n",
        "        sol = sol.reshape(-1)\n",
        "        return sol, self._model.objVal"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0b685b7a",
      "metadata": {
        "id": "0b685b7a",
        "outputId": "751e0639-fd0f-4048-a522-98e23a8c0fc0"
      },
      "outputs": [],
      "source": [
        "# init model\n",
        "grid = (k, k)\n",
        "optmodel = shortestPathModel(grid)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c4f6fee6",
      "metadata": {
        "id": "c4f6fee6"
      },
      "source": [
        "## 5 Useful Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f6f668ec",
      "metadata": {
        "id": "f6f668ec"
      },
      "outputs": [],
      "source": [
        "def evaluate(nnet, optmodel, dataloader):\n",
        "    # init data\n",
        "    data = {\"Regret\":[], \"Relative Regret\":[], \"Accuracy\":[], \"Optimal\":[]}\n",
        "    # eval\n",
        "    nnet.eval()\n",
        "    for x, c, w, z in tqdm(dataloader):\n",
        "        # cuda\n",
        "        if next(nnet.parameters()).is_cuda:\n",
        "            x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()\n",
        "        # predict\n",
        "        cp = nnet(x)\n",
        "        # to numpy\n",
        "        c = c.to(\"cpu\").detach().numpy()\n",
        "        w = w.to(\"cpu\").detach().numpy()\n",
        "        z = z.to(\"cpu\").detach().numpy()\n",
        "        cp = cp.to(\"cpu\").detach().numpy()\n",
        "        # solve\n",
        "        for i in range(cp.shape[0]):\n",
        "            # sol for pred cost\n",
        "            optmodel.setObj(cp[i])\n",
        "            wpi, _ = optmodel.solve()\n",
        "            # obj with true cost\n",
        "            zpi = np.dot(wpi, c[i])\n",
        "            # round\n",
        "            zpi = zpi.round(1)\n",
        "            zi = z[i,0].round(1)\n",
        "            # regret\n",
        "            regret = (zpi - zi).round(1)\n",
        "            data[\"Regret\"].append(regret)\n",
        "            data[\"Relative Regret\"].append(regret / zi)\n",
        "            # accuracy\n",
        "            data[\"Accuracy\"].append((abs(wpi - w[i]) < 0.5).mean())\n",
        "            # optimal\n",
        "            data[\"Optimal\"].append(abs(regret) < 1e-5)\n",
        "    # dataframe\n",
        "    df = pd.DataFrame.from_dict(data)\n",
        "    # print\n",
        "    time.sleep(1)\n",
        "    print(\"Avg Regret: {:.4f}\".format(df[\"Regret\"].mean()))\n",
        "    print(\"Avg Rel Regret: {:.2f}%\".format(df[\"Relative Regret\"].mean()*100))\n",
        "    print(\"Path Accuracy: {:.2f}%\".format(df[\"Accuracy\"].mean()*100))\n",
        "    print(\"Optimality Ratio: {:.2f}%\".format(df[\"Optimal\"].mean()*100))\n",
        "    return df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bd4d46e5",
      "metadata": {},
      "outputs": [],
      "source": [
        "def accuracy(predmodel, optmodel, dataloader):\n",
        "    \"\"\"\n",
        "    A function to evaluate model performance with accuracy.\n",
        "    Written in the style of PyEPO.\n",
        "\n",
        "    Args:\n",
        "        predmodel (nn): a regression neural network for cost prediction\n",
        "        optmodel (optModel): an PyEPO optimization model\n",
        "        dataloader (DataLoader): Torch dataloader from optDataSet\n",
        "\n",
        "    Returns:\n",
        "        float: true regret loss\n",
        "    \"\"\"\n",
        "    # evaluate\n",
        "    predmodel.eval()\n",
        "    loss = 0\n",
        "    optsum = 0\n",
        "    # load data\n",
        "    for data in dataloader:\n",
        "        x, c, w, z = data\n",
        "        batch_loss = 0\n",
        "        # cuda\n",
        "        if next(predmodel.parameters()).is_cuda:\n",
        "            x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()\n",
        "        # predict\n",
        "        with torch.no_grad(): # no grad\n",
        "            cp = predmodel(x).to(\"cpu\").detach().numpy()\n",
        "        # solve\n",
        "        for j in range(cp.shape[0]):\n",
        "            # accumulate loss\n",
        "            batch_loss += calAccuracy(optmodel, cp[j], w[j].to(\"cpu\").detach().numpy(),j)\n",
        "        loss += batch_loss/cp.shape[0]\n",
        "    # turn back train mode\n",
        "    predmodel.train()\n",
        "    # normalized\n",
        "    return loss # divide by batch size\n",
        "\n",
        "\n",
        "def calAccuracy(optmodel, pred_cost, true_sol,idx):\n",
        "    \"\"\"\n",
        "    A function to calculate normalized true regret for a batch\n",
        "\n",
        "    Args:\n",
        "        optmodel (optModel): optimization model\n",
        "        pred_cost (torch.tensor): predicted costs\n",
        "        true_sol (torch.tensor): true solution\n",
        "\n",
        "    Returns:\n",
        "        acc: 1 if true_sol matches that computed using pred_cost to within a small tolerance.\n",
        "        \n",
        "    \"\"\"\n",
        "    # opt sol for pred cost\n",
        "    optmodel.setObj(pred_cost)\n",
        "    sol, _ = optmodel.solve()\n",
        "    if np.linalg.norm(sol - true_sol)< 1e-4:\n",
        "        acc = 1\n",
        "    else:\n",
        "        acc = 0\n",
        "    return acc"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8ca5993d",
      "metadata": {
        "id": "8ca5993d"
      },
      "outputs": [],
      "source": [
        "def plotLearningCurve(loss_log, regret_log, epochs):\n",
        "    # draw loss during training\n",
        "    plt.figure(figsize=(8, 4))\n",
        "    plt.plot(loss_log, color=\"c\")\n",
        "    plt.xticks(fontsize=10)\n",
        "    plt.yticks(fontsize=10)\n",
        "    plt.xlim(-100, len(loss_log)+100)\n",
        "    plt.xlabel(\"Iters\", fontsize=12)\n",
        "    plt.ylabel(\"Loss\", fontsize=12)\n",
        "    plt.title(\"Learning Curve on Training Set\", fontsize=12)\n",
        "    plt.show()\n",
        "    # draw normalized regret on test\n",
        "    plt.figure(figsize=(8, 4))\n",
        "    plt.plot([i*log_step for i in range(len(regret_log))], regret_log, color=\"royalblue\")\n",
        "    plt.xticks(fontsize=10)\n",
        "    plt.yticks(fontsize=10)\n",
        "    plt.xlim(-epochs/50, epochs+epochs/50)\n",
        "    plt.ylim(0, max(regret_log[1:])*1.1)\n",
        "    plt.xlabel(\"Epochs\", fontsize=12)\n",
        "    plt.ylabel(\"Normalized Regret\", fontsize=12)\n",
        "    plt.title(\"Learning Curve on Test Set\", fontsize=12)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "871b6c66",
      "metadata": {
        "id": "871b6c66"
      },
      "source": [
        "## 6 Training"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a290a155",
      "metadata": {
        "id": "a290a155"
      },
      "source": [
        "### 6.1 Baseline\n",
        "Baseline model: training with binary cross entropy loss of solutions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9609a0a3",
      "metadata": {
        "id": "9609a0a3"
      },
      "outputs": [],
      "source": [
        "# init net\n",
        "nnet = partialResNet(k=12)\n",
        "# cuda\n",
        "if torch.cuda.is_available():\n",
        "    nnet = nnet.cuda()\n",
        "# set optimizer\n",
        "optimizer = torch.optim.Adam(nnet.parameters(), lr=lr)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e33f9bc6",
      "metadata": {
        "id": "e33f9bc6"
      },
      "outputs": [],
      "source": [
        "# set loss\n",
        "bceloss = nn.BCELoss()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "027f0153",
      "metadata": {
        "id": "027f0153",
        "outputId": "e35b1327-3d46-4f88-d98e-3602936a25d5"
      },
      "outputs": [],
      "source": [
        "# train\n",
        "loss_log_baseline = []\n",
        "regret_log_baseline = []\n",
        "acc_log_baseline = []\n",
        "epoch_time_log_baseline = []\n",
        "train_time_baseline = 0\n",
        "nnet.train()\n",
        "tbar = tqdm(range(150))\n",
        "for epoch in tbar:\n",
        "    start_time_epoch = time.time()\n",
        "    for x, c, w, z in loader_train:\n",
        "        # cuda\n",
        "        if torch.cuda.is_available():\n",
        "            x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()\n",
        "        # forward pass\n",
        "        h = nnet(x)\n",
        "        wp = torch.sigmoid(h)\n",
        "        loss = bceloss(wp, w) # loss\n",
        "        # backward pass\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # log\n",
        "        loss_log_baseline.append(loss.item())\n",
        "        tbar.set_description(\"Epoch: {:2}, Loss: {:3.4f}\".format(epoch, loss.item()))\n",
        "    \n",
        "    end_time_epoch = time.time()\n",
        "    epoch_time = end_time_epoch - start_time_epoch\n",
        "    train_time_baseline += epoch_time\n",
        "    epoch_time_log_baseline.append(epoch_time)\n",
        "    # scheduled learning rate\n",
        "    if (epoch == 90) or (epoch == 120):\n",
        "        for g in optimizer.param_groups:\n",
        "            g['lr'] /= 10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eef45f54",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Evaluate\n",
        "print(\"Test set:\")\n",
        "df_baseline = evaluate(nnet, optmodel, loader_test)\n",
        "test_acc = accuracy(nnet, optmodel, loader_test)\n",
        "print(f'True test accuracy is {test_acc}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a04beb0c",
      "metadata": {},
      "outputs": [],
      "source": [
        "print(train_time_baseline)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "dd15a583",
      "metadata": {
        "id": "dd15a583"
      },
      "source": [
        "### 6.4 DBB"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "509eb432",
      "metadata": {
        "id": "509eb432"
      },
      "source": [
        "DBB model: training with differentiable black-box optimizer and MSE of solutions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "00de53b7",
      "metadata": {
        "id": "00de53b7"
      },
      "outputs": [],
      "source": [
        "# init net\n",
        "nnet = partialResNet(k=12)\n",
        "# cuda\n",
        "if torch.cuda.is_available():\n",
        "    nnet = nnet.cuda()\n",
        "# set optimizer\n",
        "optimizer = torch.optim.Adam(nnet.parameters(), lr=1e-5)\n",
        "# set stopper\n",
        "#stopper = earlyStopper(patience=7)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6f982a1f",
      "metadata": {
        "id": "6f982a1f",
        "outputId": "9bd35b56-0d7d-47c5-c636-acdbb4fa0b41"
      },
      "outputs": [],
      "source": [
        "# init dbb\n",
        "dbb = pyepo.func.blackboxOpt(optmodel, lambd=10, processes=1)\n",
        "# set loss\n",
        "class hammingLoss(torch.nn.Module):\n",
        "    def forward(self, wp, w):\n",
        "        loss = wp * (1.0 - w) + (1.0 - wp) * w\n",
        "        return loss.mean(dim=0).sum()\n",
        "hmloss = hammingLoss()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "071cf53e",
      "metadata": {
        "id": "071cf53e",
        "outputId": "0d2e5d23-c89b-4949-e52e-ac3b8b1b8542"
      },
      "outputs": [],
      "source": [
        "# train\n",
        "regret = pyepo.metric.regret(nnet, optmodel, loader_test)\n",
        "loss_log_dbb = []\n",
        "regret_log_dbb = [regret]\n",
        "acc_log_dbb = []\n",
        "epoch_time_log_dbb = []\n",
        "train_time_dbb = 0\n",
        "print(f'regret is {regret}')\n",
        "tbar = tqdm(range(epochs))\n",
        "for epoch in tbar:\n",
        "    nnet.train()\n",
        "    start_time_epoch = time.time()\n",
        "    for x, c, w, z in loader_train:\n",
        "        # cuda\n",
        "        if torch.cuda.is_available():\n",
        "            x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()\n",
        "        # forward pass\n",
        "        cp = nnet(x) # predicted cost\n",
        "        wp = dbb(cp) # black-box optimizer\n",
        "        loss = hmloss(wp, w) # loss\n",
        "        # backward pass\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # log\n",
        "        loss_log_dbb.append(loss.item())\n",
        "        epoch_time = start_time_epoch - time.time()\n",
        "        train_time_dbb += epoch_time\n",
        "        epoch_time_log_dbb.append(epoch_time)\n",
        "\n",
        "        tbar.set_description(\"Epoch: {:2}, Loss: {:3.4f}\".format(epoch, loss.item()))\n",
        "    # scheduled learning rate\n",
        "    if (epoch == int(epochs*0.6)) or (epoch == int(epochs*0.8)):\n",
        "        for g in optimizer.param_groups:\n",
        "            g['lr'] /= 10\n",
        "    if epoch % log_step == 0:\n",
        "        # log regret\n",
        "        regret = pyepo.metric.regret(nnet, optmodel, loader_test) # regret on test\n",
        "        regret_log_dbb.append(regret)\n",
        "        print(f'regret is {regret}')\n",
        "        # early stop\n",
        "        #regret = pyepo.metric.regret(nnet, optmodel, loader_val) # regret on val\n",
        "        #if stopper.stop(regret):\n",
        "        #    break"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a6dbba91",
      "metadata": {
        "id": "a6dbba91",
        "outputId": "d9398d54-5d2c-4439-a0f5-7dccd0aa78fd"
      },
      "outputs": [],
      "source": [
        "# plot\n",
        "plotLearningCurve(loss_log_dbb,  regret_log_dbb)\n",
        "# eval\n",
        "print(\"Test set:\")\n",
        "df_dbb = evaluate(nnet, optmodel, loader_test)\n",
        "test_acc = accuracy(nnet, optmodel, loader_test)\n",
        "print(f'True test accuracy is {test_acc}')\n",
        "print(f'Training time is {train_time_dbb}')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "74fd3e7f",
      "metadata": {
        "id": "74fd3e7f"
      },
      "source": [
        "### 6.5 DPO"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5730bace",
      "metadata": {
        "id": "5730bace"
      },
      "source": [
        "DPO model: training with differentiable perturbed optimizer and MSE of solutions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e7688ea5",
      "metadata": {
        "id": "e7688ea5"
      },
      "outputs": [],
      "source": [
        "# init net\n",
        "nnet = partialResNet(k=12)\n",
        "# cuda\n",
        "if torch.cuda.is_available():\n",
        "    nnet = nnet.cuda()\n",
        "# set optimizer\n",
        "optimizer = torch.optim.Adam(nnet.parameters(), lr=lr)\n",
        "# set stopper\n",
        "#stopper = earlyStopper(patience=7)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2c344cf6",
      "metadata": {
        "id": "2c344cf6",
        "outputId": "291adb81-729d-44b1-c04b-8760b3b68393"
      },
      "outputs": [],
      "source": [
        "# init dpo\n",
        "ptb = pyepo.func.perturbedOpt(optmodel, n_samples=1, sigma=1.0, processes=1)\n",
        "# set loss\n",
        "mseloss = nn.MSELoss()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ef6cfe80",
      "metadata": {
        "id": "ef6cfe80",
        "outputId": "7cb86127-d353-4b9e-e31c-1194ffeda9b8"
      },
      "outputs": [],
      "source": [
        "# train\n",
        "loss_log5, regret_log5 = [], [pyepo.metric.regret(nnet, optmodel, loader_test)]\n",
        "tbar = tqdm(range(epochs))\n",
        "for epoch in tbar:\n",
        "    nnet.train()\n",
        "    for x, c, w, z in loader_train:\n",
        "        # cuda\n",
        "        if torch.cuda.is_available():\n",
        "            x, c, w, z = x.cuda(), c.cuda(), w.cuda(), z.cuda()\n",
        "        # forward pass\n",
        "        cp = nnet(x) # predicted cost\n",
        "        we = ptb(cp) # perturbed optimizer\n",
        "        loss = mseloss(we, w) # loss\n",
        "        # backward pass\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # log\n",
        "        loss_log5.append(loss.item())\n",
        "        tbar.set_description(\"Epoch: {:2}, Loss: {:3.4f}\".format(epoch, loss.item()))\n",
        "    # scheduled learning rate\n",
        "    if (epoch == int(epochs*0.6)) or (epoch == int(epochs*0.8)):\n",
        "        for g in optimizer.param_groups:\n",
        "            g['lr'] /= 10\n",
        "    if epoch % log_step == 0:\n",
        "        # log regret\n",
        "        regret = pyepo.metric.regret(nnet, optmodel, loader_test) # regret on test\n",
        "        regret_log5.append(regret)\n",
        "        # early stop\n",
        "        #regret = pyepo.metric.regret(nnet, optmodel, loader_val) # regret on val\n",
        "        #if stopper.stop(regret):\n",
        "        #    break"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c5db27ae",
      "metadata": {
        "id": "c5db27ae",
        "outputId": "7ef1052d-1cf2-49fe-db5f-8bd2bc1e5318"
      },
      "outputs": [],
      "source": [
        "# plot\n",
        "plotLearningCurve(loss_log5, regret_log5)\n",
        "# eval\n",
        "print(\"Test set:\")\n",
        "df5 = evaluate(nnet, optmodel, loader_test)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "fpo-dys-env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
