{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Branching with Imitation Learning and a GNN\n",
    "\n",
    "In this tutorial we will reproduce a simplified version of the paper of Gasse et al. (2019) on learning to branch with Ecole with `pytorch` and `pytorch geometric`. We collect strong branching examples on randomly generated maximum set covering instances, then train a graph neural network with bipartite state encodings to imitate the expert by classification. Finally, we will evaluate the quality of the policy.\n",
    "\n",
    "The biggest difference with Gasse et al. (2019) is that only n=1,000 training examples of expert decisions are collected for training, to keep the time needed to run the tutorial reasonable. As a consequence, the resulting policy is undertrained and is not competitive with SCIP's default branching rule.\n",
    "\n",
    "Users that are interested in reproducing competitive performance should use a larger sample size, such as the n=100,000 samples used for training in the paper. In this case, we strongly recommend to parallelize data collection, as in the original Gasse et al. (2019) code.\n",
    "\n",
    "### Requirements\n",
    "This tutorial requires the following libraries. The version numbers used when writing this tutorial are given in parentheses.\n",
    "- `python` (3.8.2)\n",
    "- `numpy` (1.19.4)\n",
    "- `pytorch` (1.7.0)\n",
    "- `pytorch-geometric` (1.6.2)\n",
    "- `ecole` (0.6.0)\n",
    "\n",
    "The tutorial was designed with the provided version numbers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data collection\n",
    "\n",
    "Our first step will be to run explore-then-strong-branch on randomly generated maximum set covering instances, and save the branching decisions to build a dataset. We will also record the state of the branch-and-bound process as a bipartite graph, which is already implemented in Ecole with the same features as Gasse et al. (2019)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'ecole'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpickle\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mecole\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpathlib\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Path\n\u001b[1;32m      7\u001b[0m MAX_SAMPLES \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1000\u001b[39m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'ecole'"
     ]
    }
   ],
   "source": [
    "import gzip\n",
    "import pickle\n",
    "import numpy as np\n",
    "import ecole\n",
    "from pathlib import Path\n",
    "\n",
    "MAX_SAMPLES = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the Ecole-provided set cover instance generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'ecole' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m instances \u001b[38;5;241m=\u001b[39m \u001b[43mecole\u001b[49m\u001b[38;5;241m.\u001b[39minstance\u001b[38;5;241m.\u001b[39mSetCoverGenerator(n_rows\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m500\u001b[39m, n_cols\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, density\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.05\u001b[39m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'ecole' is not defined"
     ]
    }
   ],
   "source": [
    "instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The explore-then-strong-branch scheme described in the paper is not implemented by default in Ecole. In this scheme, to diversify the states in which we collect examples of strong branching behavior, we mostly follow a weak but cheap expert (pseudocost branching) and only occasionally call the strong expert (strong branching). This also ensures that samples are closer to being independent and identically distributed.\n",
    "\n",
    "This can be realized in Ecole by creating a custom observation function, which will randomly compute and return the pseudocost scores (cheap) or the strong branching scores (expensive). It also showcases extensibility in Ecole by showing how easily a custom observation function can be created and used, directly in Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExploreThenStrongBranch:\n",
    "    \"\"\"\n",
    "    This custom observation function class will randomly return either strong branching scores (expensive expert) \n",
    "    or pseudocost scores (weak expert for exploration) when called at every node.\n",
    "    \"\"\"\n",
    "    def __init__(self, expert_probability):\n",
    "        self.expert_probability = expert_probability\n",
    "        self.pseudocosts_function = ecole.observation.Pseudocosts()\n",
    "        self.strong_branching_function = ecole.observation.StrongBranchingScores()\n",
    "    \n",
    "    def before_reset(self, model):\n",
    "        \"\"\"\n",
    "        This function will be called at initialization of the environment (before dynamics are reset).\n",
    "        \"\"\"\n",
    "        self.pseudocosts_function.before_reset(model)\n",
    "        self.strong_branching_function.before_reset(model)\n",
    "    \n",
    "    def extract(self, model, done):\n",
    "        \"\"\"\n",
    "        Should we return strong branching or pseudocost scores at time node?\n",
    "        \"\"\"\n",
    "        probabilities = [1-self.expert_probability, self.expert_probability]\n",
    "        expert_chosen = bool(np.random.choice(np.arange(2), p=probabilities))\n",
    "        if expert_chosen:\n",
    "            return (self.strong_branching_function.extract(model, done), True)\n",
    "        else:\n",
    "            return (self.pseudocosts_function.extract(model, done), False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now create the environment with the correct parameters (no restarts, 1h time limit, 5% expert sampling probability).\n",
    "\n",
    "Besides the (pseudocost or strong branching) scores, our environment will return the node bipartite graph representation of \n",
    "branch-and-bound states used in Gasse et al. (2019), using the `ecole.observation.NodeBipartite` observation function.\n",
    "On one side of that bipartite graph, nodes represent the variables of the problem, with a vector encoding features of \n",
    "that variable. On the other side of the bipartite graph, nodes represent the constraints of the problem, similarly with \n",
    "a vector encoding features of that constraint. An edge links a variable and a constraint node if the variable participates \n",
    "in that constraint, that is, its coefficient is nonzero in that constraint. The constraint coefficient is attached as an\n",
    "attribute of the edge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can pass custom SCIP parameters easily\n",
    "scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0, 'limits/time': 3600}\n",
    "\n",
    "# Note how we can tuple observation functions to return complex state information\n",
    "env = ecole.environment.Branching(observation_function=(ExploreThenStrongBranch(expert_probability=0.05), \n",
    "                                                        ecole.observation.NodeBipartite()), \n",
    "                                  scip_params=scip_parameters)\n",
    "\n",
    "# This will seed the environment for reproducibility\n",
    "env.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we loop over the instances, following the strong branching expert 5% of the time and saving its decision, until enough samples are collected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 1, 2 samples collected so far\n",
      "Episode 2, 2 samples collected so far\n",
      "Episode 3, 2 samples collected so far\n",
      "Episode 4, 7 samples collected so far\n",
      "Episode 5, 7 samples collected so far\n",
      "Episode 6, 7 samples collected so far\n",
      "Episode 7, 7 samples collected so far\n",
      "Episode 8, 8 samples collected so far\n",
      "Episode 9, 8 samples collected so far\n",
      "Episode 10, 8 samples collected so far\n",
      "Episode 11, 8 samples collected so far\n",
      "Episode 12, 9 samples collected so far\n",
      "Episode 13, 9 samples collected so far\n",
      "Episode 14, 10 samples collected so far\n",
      "Episode 15, 13 samples collected so far\n",
      "Episode 16, 13 samples collected so far\n",
      "Episode 17, 16 samples collected so far\n",
      "Episode 18, 24 samples collected so far\n",
      "Episode 19, 24 samples collected so far\n",
      "Episode 20, 57 samples collected so far\n",
      "Episode 21, 59 samples collected so far\n",
      "Episode 22, 59 samples collected so far\n",
      "Episode 23, 60 samples collected so far\n",
      "Episode 24, 60 samples collected so far\n",
      "Episode 25, 60 samples collected so far\n",
      "Episode 26, 60 samples collected so far\n",
      "Episode 27, 63 samples collected so far\n",
      "Episode 28, 63 samples collected so far\n",
      "Episode 29, 64 samples collected so far\n",
      "Episode 30, 70 samples collected so far\n",
      "Episode 31, 70 samples collected so far\n",
      "Episode 32, 70 samples collected so far\n",
      "Episode 33, 76 samples collected so far\n",
      "Episode 34, 79 samples collected so far\n",
      "Episode 35, 93 samples collected so far\n",
      "Episode 36, 93 samples collected so far\n",
      "Episode 37, 94 samples collected so far\n",
      "Episode 38, 104 samples collected so far\n",
      "Episode 39, 109 samples collected so far\n",
      "Episode 40, 109 samples collected so far\n",
      "Episode 41, 111 samples collected so far\n",
      "Episode 42, 122 samples collected so far\n",
      "Episode 43, 122 samples collected so far\n",
      "Episode 44, 123 samples collected so far\n",
      "Episode 45, 125 samples collected so far\n",
      "Episode 46, 125 samples collected so far\n",
      "Episode 47, 125 samples collected so far\n",
      "Episode 48, 143 samples collected so far\n",
      "Episode 49, 145 samples collected so far\n",
      "Episode 50, 152 samples collected so far\n",
      "Episode 51, 152 samples collected so far\n",
      "Episode 52, 154 samples collected so far\n",
      "Episode 53, 155 samples collected so far\n",
      "Episode 54, 155 samples collected so far\n",
      "Episode 55, 155 samples collected so far\n",
      "Episode 56, 155 samples collected so far\n",
      "Episode 57, 160 samples collected so far\n",
      "Episode 58, 160 samples collected so far\n",
      "Episode 59, 162 samples collected so far\n",
      "Episode 60, 164 samples collected so far\n",
      "Episode 61, 167 samples collected so far\n",
      "Episode 62, 167 samples collected so far\n",
      "Episode 63, 168 samples collected so far\n",
      "Episode 64, 168 samples collected so far\n",
      "Episode 65, 169 samples collected so far\n",
      "Episode 66, 171 samples collected so far\n",
      "Episode 67, 171 samples collected so far\n",
      "Episode 68, 171 samples collected so far\n",
      "Episode 69, 174 samples collected so far\n",
      "Episode 70, 174 samples collected so far\n",
      "Episode 71, 175 samples collected so far\n",
      "Episode 72, 178 samples collected so far\n",
      "Episode 73, 179 samples collected so far\n",
      "Episode 74, 180 samples collected so far\n",
      "Episode 75, 204 samples collected so far\n",
      "Episode 76, 204 samples collected so far\n",
      "Episode 77, 204 samples collected so far\n",
      "Episode 78, 205 samples collected so far\n",
      "Episode 79, 208 samples collected so far\n",
      "Episode 80, 220 samples collected so far\n",
      "Episode 81, 220 samples collected so far\n",
      "Episode 82, 220 samples collected so far\n",
      "Episode 83, 220 samples collected so far\n",
      "Episode 84, 221 samples collected so far\n",
      "Episode 85, 221 samples collected so far\n",
      "Episode 86, 221 samples collected so far\n",
      "Episode 87, 223 samples collected so far\n",
      "Episode 88, 223 samples collected so far\n",
      "Episode 89, 223 samples collected so far\n",
      "Episode 90, 224 samples collected so far\n",
      "Episode 91, 225 samples collected so far\n",
      "Episode 92, 226 samples collected so far\n",
      "Episode 93, 226 samples collected so far\n",
      "Episode 94, 252 samples collected so far\n",
      "Episode 95, 252 samples collected so far\n",
      "Episode 96, 254 samples collected so far\n",
      "Episode 97, 254 samples collected so far\n",
      "Episode 98, 255 samples collected so far\n",
      "Episode 99, 255 samples collected so far\n",
      "Episode 100, 255 samples collected so far\n",
      "Episode 101, 257 samples collected so far\n",
      "Episode 102, 259 samples collected so far\n",
      "Episode 103, 259 samples collected so far\n",
      "Episode 104, 260 samples collected so far\n",
      "Episode 105, 262 samples collected so far\n",
      "Episode 106, 267 samples collected so far\n",
      "Episode 107, 267 samples collected so far\n",
      "Episode 108, 267 samples collected so far\n",
      "Episode 109, 270 samples collected so far\n",
      "Episode 110, 277 samples collected so far\n",
      "Episode 111, 278 samples collected so far\n",
      "Episode 112, 279 samples collected so far\n",
      "Episode 113, 279 samples collected so far\n",
      "Episode 114, 282 samples collected so far\n",
      "Episode 115, 282 samples collected so far\n",
      "Episode 116, 286 samples collected so far\n",
      "Episode 117, 290 samples collected so far\n",
      "Episode 118, 298 samples collected so far\n",
      "Episode 119, 301 samples collected so far\n",
      "Episode 120, 301 samples collected so far\n",
      "Episode 121, 303 samples collected so far\n",
      "Episode 122, 303 samples collected so far\n",
      "Episode 123, 303 samples collected so far\n",
      "Episode 124, 305 samples collected so far\n",
      "Episode 125, 305 samples collected so far\n",
      "Episode 126, 307 samples collected so far\n",
      "Episode 127, 312 samples collected so far\n",
      "Episode 128, 323 samples collected so far\n",
      "Episode 129, 324 samples collected so far\n",
      "Episode 130, 325 samples collected so far\n",
      "Episode 131, 325 samples collected so far\n",
      "Episode 132, 330 samples collected so far\n",
      "Episode 133, 330 samples collected so far\n",
      "Episode 134, 330 samples collected so far\n",
      "Episode 135, 330 samples collected so far\n",
      "Episode 136, 334 samples collected so far\n",
      "Episode 137, 335 samples collected so far\n",
      "Episode 138, 335 samples collected so far\n",
      "Episode 139, 340 samples collected so far\n",
      "Episode 140, 341 samples collected so far\n",
      "Episode 141, 342 samples collected so far\n",
      "Episode 142, 349 samples collected so far\n",
      "Episode 143, 351 samples collected so far\n",
      "Episode 144, 365 samples collected so far\n",
      "Episode 145, 365 samples collected so far\n",
      "Episode 146, 377 samples collected so far\n",
      "Episode 147, 377 samples collected so far\n",
      "Episode 148, 379 samples collected so far\n",
      "Episode 149, 379 samples collected so far\n",
      "Episode 150, 381 samples collected so far\n",
      "Episode 151, 386 samples collected so far\n",
      "Episode 152, 386 samples collected so far\n",
      "Episode 153, 389 samples collected so far\n",
      "Episode 154, 390 samples collected so far\n",
      "Episode 155, 390 samples collected so far\n",
      "Episode 156, 390 samples collected so far\n",
      "Episode 157, 392 samples collected so far\n",
      "Episode 158, 393 samples collected so far\n",
      "Episode 159, 395 samples collected so far\n",
      "Episode 160, 401 samples collected so far\n",
      "Episode 161, 401 samples collected so far\n",
      "Episode 162, 401 samples collected so far\n",
      "Episode 163, 403 samples collected so far\n",
      "Episode 164, 403 samples collected so far\n",
      "Episode 165, 403 samples collected so far\n",
      "Episode 166, 403 samples collected so far\n",
      "Episode 167, 403 samples collected so far\n",
      "Episode 168, 404 samples collected so far\n",
      "Episode 169, 404 samples collected so far\n",
      "Episode 170, 406 samples collected so far\n",
      "Episode 171, 406 samples collected so far\n",
      "Episode 172, 409 samples collected so far\n",
      "Episode 173, 412 samples collected so far\n",
      "Episode 174, 417 samples collected so far\n",
      "Episode 175, 422 samples collected so far\n",
      "Episode 176, 422 samples collected so far\n",
      "Episode 177, 422 samples collected so far\n",
      "Episode 178, 422 samples collected so far\n",
      "Episode 179, 422 samples collected so far\n",
      "Episode 180, 423 samples collected so far\n",
      "Episode 181, 424 samples collected so far\n",
      "Episode 182, 427 samples collected so far\n",
      "Episode 183, 427 samples collected so far\n",
      "Episode 184, 428 samples collected so far\n",
      "Episode 185, 428 samples collected so far\n",
      "Episode 186, 428 samples collected so far\n",
      "Episode 187, 429 samples collected so far\n",
      "Episode 188, 430 samples collected so far\n",
      "Episode 189, 431 samples collected so far\n",
      "Episode 190, 434 samples collected so far\n",
      "Episode 191, 435 samples collected so far\n",
      "Episode 192, 437 samples collected so far\n",
      "Episode 193, 437 samples collected so far\n",
      "Episode 194, 437 samples collected so far\n",
      "Episode 195, 437 samples collected so far\n",
      "Episode 196, 438 samples collected so far\n",
      "Episode 197, 440 samples collected so far\n",
      "Episode 198, 441 samples collected so far\n",
      "Episode 199, 459 samples collected so far\n",
      "Episode 200, 459 samples collected so far\n",
      "Episode 201, 459 samples collected so far\n",
      "Episode 202, 462 samples collected so far\n",
      "Episode 203, 463 samples collected so far\n",
      "Episode 204, 466 samples collected so far\n",
      "Episode 205, 472 samples collected so far\n",
      "Episode 206, 472 samples collected so far\n",
      "Episode 207, 475 samples collected so far\n",
      "Episode 208, 476 samples collected so far\n",
      "Episode 209, 477 samples collected so far\n",
      "Episode 210, 477 samples collected so far\n",
      "Episode 211, 480 samples collected so far\n",
      "Episode 212, 491 samples collected so far\n",
      "Episode 213, 494 samples collected so far\n",
      "Episode 214, 494 samples collected so far\n",
      "Episode 215, 494 samples collected so far\n",
      "Episode 216, 502 samples collected so far\n",
      "Episode 217, 507 samples collected so far\n",
      "Episode 218, 509 samples collected so far\n",
      "Episode 219, 509 samples collected so far\n",
      "Episode 220, 510 samples collected so far\n",
      "Episode 221, 510 samples collected so far\n",
      "Episode 222, 511 samples collected so far\n",
      "Episode 223, 512 samples collected so far\n",
      "Episode 224, 512 samples collected so far\n",
      "Episode 225, 512 samples collected so far\n",
      "Episode 226, 512 samples collected so far\n",
      "Episode 227, 512 samples collected so far\n",
      "Episode 228, 514 samples collected so far\n",
      "Episode 229, 514 samples collected so far\n",
      "Episode 230, 517 samples collected so far\n",
      "Episode 231, 517 samples collected so far\n",
      "Episode 232, 517 samples collected so far\n",
      "Episode 233, 528 samples collected so far\n",
      "Episode 234, 537 samples collected so far\n",
      "Episode 235, 537 samples collected so far\n",
      "Episode 236, 537 samples collected so far\n",
      "Episode 237, 537 samples collected so far\n",
      "Episode 238, 537 samples collected so far\n",
      "Episode 239, 537 samples collected so far\n",
      "Episode 240, 537 samples collected so far\n",
      "Episode 241, 537 samples collected so far\n",
      "Episode 242, 538 samples collected so far\n",
      "Episode 243, 539 samples collected so far\n",
      "Episode 244, 545 samples collected so far\n",
      "Episode 245, 545 samples collected so far\n",
      "Episode 246, 546 samples collected so far\n",
      "Episode 247, 546 samples collected so far\n",
      "Episode 248, 546 samples collected so far\n",
      "Episode 249, 547 samples collected so far\n",
      "Episode 250, 551 samples collected so far\n",
      "Episode 251, 554 samples collected so far\n",
      "Episode 252, 555 samples collected so far\n",
      "Episode 253, 555 samples collected so far\n",
      "Episode 254, 556 samples collected so far\n",
      "Episode 255, 556 samples collected so far\n",
      "Episode 256, 562 samples collected so far\n",
      "Episode 257, 562 samples collected so far\n",
      "Episode 258, 563 samples collected so far\n",
      "Episode 259, 569 samples collected so far\n",
      "Episode 260, 574 samples collected so far\n",
      "Episode 261, 575 samples collected so far\n",
      "Episode 262, 578 samples collected so far\n",
      "Episode 263, 583 samples collected so far\n",
      "Episode 264, 584 samples collected so far\n",
      "Episode 265, 588 samples collected so far\n",
      "Episode 266, 590 samples collected so far\n",
      "Episode 267, 590 samples collected so far\n",
      "Episode 268, 590 samples collected so far\n",
      "Episode 269, 592 samples collected so far\n",
      "Episode 270, 594 samples collected so far\n",
      "Episode 271, 597 samples collected so far\n",
      "Episode 272, 599 samples collected so far\n",
      "Episode 273, 600 samples collected so far\n",
      "Episode 274, 603 samples collected so far\n",
      "Episode 275, 606 samples collected so far\n",
      "Episode 276, 607 samples collected so far\n",
      "Episode 277, 608 samples collected so far\n",
      "Episode 278, 619 samples collected so far\n",
      "Episode 279, 619 samples collected so far\n",
      "Episode 280, 621 samples collected so far\n",
      "Episode 281, 623 samples collected so far\n",
      "Episode 282, 624 samples collected so far\n",
      "Episode 283, 624 samples collected so far\n",
      "Episode 284, 624 samples collected so far\n",
      "Episode 285, 627 samples collected so far\n",
      "Episode 286, 628 samples collected so far\n",
      "Episode 287, 630 samples collected so far\n",
      "Episode 288, 630 samples collected so far\n",
      "Episode 289, 632 samples collected so far\n",
      "Episode 290, 633 samples collected so far\n",
      "Episode 291, 633 samples collected so far\n",
      "Episode 292, 633 samples collected so far\n",
      "Episode 293, 633 samples collected so far\n",
      "Episode 294, 633 samples collected so far\n",
      "Episode 295, 634 samples collected so far\n",
      "Episode 296, 638 samples collected so far\n",
      "Episode 297, 639 samples collected so far\n",
      "Episode 298, 639 samples collected so far\n",
      "Episode 299, 640 samples collected so far\n",
      "Episode 300, 642 samples collected so far\n",
      "Episode 301, 642 samples collected so far\n",
      "Episode 302, 644 samples collected so far\n",
      "Episode 303, 647 samples collected so far\n",
      "Episode 304, 650 samples collected so far\n",
      "Episode 305, 650 samples collected so far\n",
      "Episode 306, 651 samples collected so far\n",
      "Episode 307, 652 samples collected so far\n",
      "Episode 308, 657 samples collected so far\n",
      "Episode 309, 657 samples collected so far\n",
      "Episode 310, 660 samples collected so far\n",
      "Episode 311, 660 samples collected so far\n",
      "Episode 312, 664 samples collected so far\n",
      "Episode 313, 664 samples collected so far\n",
      "Episode 314, 666 samples collected so far\n",
      "Episode 315, 666 samples collected so far\n",
      "Episode 316, 666 samples collected so far\n",
      "Episode 317, 673 samples collected so far\n",
      "Episode 318, 676 samples collected so far\n",
      "Episode 319, 677 samples collected so far\n",
      "Episode 320, 678 samples collected so far\n",
      "Episode 321, 679 samples collected so far\n",
      "Episode 322, 680 samples collected so far\n",
      "Episode 323, 680 samples collected so far\n",
      "Episode 324, 686 samples collected so far\n",
      "Episode 325, 687 samples collected so far\n",
      "Episode 326, 687 samples collected so far\n",
      "Episode 327, 688 samples collected so far\n",
      "Episode 328, 688 samples collected so far\n",
      "Episode 329, 688 samples collected so far\n",
      "Episode 330, 688 samples collected so far\n",
      "Episode 331, 689 samples collected so far\n",
      "Episode 332, 692 samples collected so far\n",
      "Episode 333, 692 samples collected so far\n",
      "Episode 334, 694 samples collected so far\n",
      "Episode 335, 694 samples collected so far\n",
      "Episode 336, 694 samples collected so far\n",
      "Episode 337, 694 samples collected so far\n",
      "Episode 338, 696 samples collected so far\n",
      "Episode 339, 696 samples collected so far\n",
      "Episode 340, 699 samples collected so far\n",
      "Episode 341, 700 samples collected so far\n",
      "Episode 342, 700 samples collected so far\n",
      "Episode 343, 701 samples collected so far\n",
      "Episode 344, 704 samples collected so far\n",
      "Episode 345, 705 samples collected so far\n",
      "Episode 346, 706 samples collected so far\n",
      "Episode 347, 706 samples collected so far\n",
      "Episode 348, 713 samples collected so far\n",
      "Episode 349, 713 samples collected so far\n",
      "Episode 350, 715 samples collected so far\n",
      "Episode 351, 718 samples collected so far\n",
      "Episode 352, 718 samples collected so far\n",
      "Episode 353, 731 samples collected so far\n",
      "Episode 354, 732 samples collected so far\n",
      "Episode 355, 733 samples collected so far\n",
      "Episode 356, 737 samples collected so far\n",
      "Episode 357, 738 samples collected so far\n",
      "Episode 358, 738 samples collected so far\n",
      "Episode 359, 743 samples collected so far\n",
      "Episode 360, 743 samples collected so far\n",
      "Episode 361, 743 samples collected so far\n",
      "Episode 362, 746 samples collected so far\n",
      "Episode 363, 746 samples collected so far\n",
      "Episode 364, 746 samples collected so far\n",
      "Episode 365, 746 samples collected so far\n",
      "Episode 366, 746 samples collected so far\n",
      "Episode 367, 746 samples collected so far\n",
      "Episode 368, 747 samples collected so far\n",
      "Episode 369, 750 samples collected so far\n",
      "Episode 370, 750 samples collected so far\n",
      "Episode 371, 753 samples collected so far\n",
      "Episode 372, 756 samples collected so far\n",
      "Episode 373, 759 samples collected so far\n",
      "Episode 374, 759 samples collected so far\n",
      "Episode 375, 759 samples collected so far\n",
      "Episode 376, 759 samples collected so far\n",
      "Episode 377, 759 samples collected so far\n",
      "Episode 378, 759 samples collected so far\n",
      "Episode 379, 760 samples collected so far\n",
      "Episode 380, 760 samples collected so far\n",
      "Episode 381, 761 samples collected so far\n",
      "Episode 382, 762 samples collected so far\n",
      "Episode 383, 765 samples collected so far\n",
      "Episode 384, 765 samples collected so far\n",
      "Episode 385, 766 samples collected so far\n",
      "Episode 386, 766 samples collected so far\n",
      "Episode 387, 773 samples collected so far\n",
      "Episode 388, 773 samples collected so far\n",
      "Episode 389, 774 samples collected so far\n",
      "Episode 390, 775 samples collected so far\n",
      "Episode 391, 778 samples collected so far\n",
      "Episode 392, 779 samples collected so far\n",
      "Episode 393, 779 samples collected so far\n",
      "Episode 394, 782 samples collected so far\n",
      "Episode 395, 782 samples collected so far\n",
      "Episode 396, 782 samples collected so far\n",
      "Episode 397, 802 samples collected so far\n",
      "Episode 398, 804 samples collected so far\n",
      "Episode 399, 804 samples collected so far\n",
      "Episode 400, 805 samples collected so far\n",
      "Episode 401, 808 samples collected so far\n",
      "Episode 402, 809 samples collected so far\n",
      "Episode 403, 813 samples collected so far\n",
      "Episode 404, 813 samples collected so far\n",
      "Episode 405, 816 samples collected so far\n",
      "Episode 406, 817 samples collected so far\n",
      "Episode 407, 817 samples collected so far\n",
      "Episode 408, 818 samples collected so far\n",
      "Episode 409, 829 samples collected so far\n",
      "Episode 410, 829 samples collected so far\n",
      "Episode 411, 829 samples collected so far\n",
      "Episode 412, 829 samples collected so far\n",
      "Episode 413, 831 samples collected so far\n",
      "Episode 414, 832 samples collected so far\n",
      "Episode 415, 835 samples collected so far\n",
      "Episode 416, 842 samples collected so far\n",
      "Episode 417, 844 samples collected so far\n",
      "Episode 418, 844 samples collected so far\n",
      "Episode 419, 846 samples collected so far\n",
      "Episode 420, 848 samples collected so far\n",
      "Episode 421, 848 samples collected so far\n",
      "Episode 422, 853 samples collected so far\n",
      "Episode 423, 854 samples collected so far\n",
      "Episode 424, 855 samples collected so far\n",
      "Episode 425, 857 samples collected so far\n",
      "Episode 426, 857 samples collected so far\n",
      "Episode 427, 858 samples collected so far\n",
      "Episode 428, 873 samples collected so far\n",
      "Episode 429, 876 samples collected so far\n",
      "Episode 430, 879 samples collected so far\n",
      "Episode 431, 879 samples collected so far\n",
      "Episode 432, 879 samples collected so far\n",
      "Episode 433, 879 samples collected so far\n",
      "Episode 434, 880 samples collected so far\n",
      "Episode 435, 881 samples collected so far\n",
      "Episode 436, 893 samples collected so far\n",
      "Episode 437, 893 samples collected so far\n",
      "Episode 438, 893 samples collected so far\n",
      "Episode 439, 897 samples collected so far\n",
      "Episode 440, 899 samples collected so far\n",
      "Episode 441, 900 samples collected so far\n",
      "Episode 442, 901 samples collected so far\n",
      "Episode 443, 903 samples collected so far\n",
      "Episode 444, 905 samples collected so far\n",
      "Episode 445, 910 samples collected so far\n",
      "Episode 446, 915 samples collected so far\n",
      "Episode 447, 915 samples collected so far\n",
      "Episode 448, 916 samples collected so far\n",
      "Episode 449, 916 samples collected so far\n",
      "Episode 450, 916 samples collected so far\n",
      "Episode 451, 916 samples collected so far\n",
      "Episode 452, 916 samples collected so far\n",
      "Episode 453, 919 samples collected so far\n",
      "Episode 454, 919 samples collected so far\n",
      "Episode 455, 920 samples collected so far\n",
      "Episode 456, 930 samples collected so far\n",
      "Episode 457, 930 samples collected so far\n",
      "Episode 458, 934 samples collected so far\n",
      "Episode 459, 935 samples collected so far\n",
      "Episode 460, 938 samples collected so far\n",
      "Episode 461, 938 samples collected so far\n",
      "Episode 462, 942 samples collected so far\n",
      "Episode 463, 942 samples collected so far\n",
      "Episode 464, 947 samples collected so far\n",
      "Episode 465, 955 samples collected so far\n",
      "Episode 466, 955 samples collected so far\n",
      "Episode 467, 955 samples collected so far\n",
      "Episode 468, 956 samples collected so far\n",
      "Episode 469, 956 samples collected so far\n",
      "Episode 470, 958 samples collected so far\n",
      "Episode 471, 962 samples collected so far\n",
      "Episode 472, 965 samples collected so far\n",
      "Episode 473, 965 samples collected so far\n",
      "Episode 474, 965 samples collected so far\n",
      "Episode 475, 966 samples collected so far\n",
      "Episode 476, 966 samples collected so far\n",
      "Episode 477, 968 samples collected so far\n",
      "Episode 478, 969 samples collected so far\n",
      "Episode 479, 969 samples collected so far\n",
      "Episode 480, 970 samples collected so far\n",
      "Episode 481, 973 samples collected so far\n",
      "Episode 482, 973 samples collected so far\n",
      "Episode 483, 989 samples collected so far\n",
      "Episode 484, 990 samples collected so far\n",
      "Episode 485, 992 samples collected so far\n",
      "Episode 486, 992 samples collected so far\n",
      "Episode 487, 993 samples collected so far\n",
      "Episode 488, 1000 samples collected so far\n"
     ]
    }
   ],
   "source": [
    "episode_counter, sample_counter = 0, 0\n",
    "Path('samples/').mkdir(exist_ok=True)\n",
    "\n",
    "# We will solve problems (run episodes) until we have saved enough samples\n",
    "max_samples_reached = False\n",
    "while not max_samples_reached:\n",
    "    episode_counter += 1\n",
    "    \n",
    "    observation, action_set, _, done, _ = env.reset(next(instances))\n",
    "    while not done:\n",
    "        (scores, scores_are_expert), node_observation = observation\n",
    "        action = action_set[scores[action_set].argmax()]\n",
    "\n",
    "        # Only save samples if they are coming from the expert (strong branching)\n",
    "        if scores_are_expert and not max_samples_reached:\n",
    "            sample_counter += 1\n",
    "            data = [node_observation, action, action_set, scores]\n",
    "            filename = f'samples/sample_{sample_counter}.pkl'\n",
    "\n",
    "            with gzip.open(filename, 'wb') as f:\n",
    "                pickle.dump(data, f)\n",
    "            \n",
    "            # If we collected enough samples, we finish the current episode but stop saving samples\n",
    "            if sample_counter == MAX_SAMPLES:\n",
    "                max_samples_reached = True\n",
    "\n",
    "        observation, action_set, _, done, _ = env.step(action)\n",
    "    \n",
    "    print(f\"Episode {episode_counter}, {sample_counter} samples collected so far\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Train a GNN\n",
    "\n",
    "Our next step is to train a GNN classifier on these collected samples to predict similar choices to strong branching."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch_geometric\n",
    "\n",
    "LEARNING_RATE = 0.001\n",
    "NB_EPOCHS = 50\n",
    "PATIENCE = 10\n",
    "EARLY_STOPPING = 20\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will first define pytorch geometric data classes to handle the bipartite graph data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BipartiteNodeData(torch_geometric.data.Data):\n",
    "    \"\"\"\n",
    "    This class encode a node bipartite graph observation as returned by the `ecole.observation.NodeBipartite` \n",
    "    observation function in a format understood by the pytorch geometric data handlers.\n",
    "    \"\"\"\n",
    "    def __init__(self, constraint_features, edge_indices, edge_features, variable_features,\n",
    "                 candidates, candidate_choice, candidate_scores):\n",
    "        super().__init__()\n",
    "        self.constraint_features = torch.FloatTensor(constraint_features)\n",
    "        self.edge_index = torch.LongTensor(edge_indices.astype(np.int64))\n",
    "        self.edge_attr = torch.FloatTensor(edge_features).unsqueeze(1)\n",
    "        self.variable_features = torch.FloatTensor(variable_features)\n",
    "        self.candidates = candidates\n",
    "        self.nb_candidates = len(candidates)\n",
    "        self.candidate_choices = candidate_choice\n",
    "        self.candidate_scores = candidate_scores\n",
    "\n",
    "    def __inc__(self, key, value):\n",
    "        \"\"\"\n",
    "        We overload the pytorch geometric method that tells how to increment indices when concatenating graphs \n",
    "        for those entries (edge index, candidates) for which this is not obvious.\n",
    "        \"\"\"\n",
    "        if key == 'edge_index':\n",
    "            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])\n",
    "        elif key == 'candidates':\n",
    "            return self.variable_features.size(0)\n",
    "        else:\n",
    "            return super().__inc__(key, value)\n",
    "\n",
    "\n",
    "class GraphDataset(torch_geometric.data.Dataset):\n",
    "    \"\"\"\n",
    "    This class encodes a collection of graphs, as well as a method to load such graphs from the disk.\n",
    "    It can be used in turn by the data loaders provided by pytorch geometric.\n",
    "    \"\"\"\n",
    "    def __init__(self, sample_files):\n",
    "        super().__init__(root=None, transform=None, pre_transform=None)\n",
    "        self.sample_files = sample_files\n",
    "\n",
    "    def len(self):\n",
    "        return len(self.sample_files)\n",
    "\n",
    "    def get(self, index):\n",
    "        \"\"\"\n",
    "        This method loads a node bipartite graph observation as saved on the disk during data collection.\n",
    "        \"\"\"\n",
    "        with gzip.open(self.sample_files[index], 'rb') as f:\n",
    "            sample = pickle.load(f)\n",
    "\n",
    "        sample_observation, sample_action, sample_action_set, sample_scores = sample\n",
    "        \n",
    "        # We note on which variables we were allowed to branch, the scores as well as the choice \n",
    "        # taken by strong branching (relative to the candidates)\n",
    "        candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))\n",
    "        candidate_scores = torch.FloatTensor([sample_scores[j] for j in candidates])\n",
    "        candidate_choice = torch.where(candidates == sample_action)[0][0]\n",
    "\n",
    "        graph = BipartiteNodeData(sample_observation.row_features, sample_observation.edge_features.indices, \n",
    "                                  sample_observation.edge_features.values, sample_observation.column_features,\n",
    "                                  candidates, candidate_choice, candidate_scores)\n",
    "        \n",
    "        # We must tell pytorch geometric how many nodes there are, for indexing purposes\n",
    "        graph.num_nodes = sample_observation.row_features.shape[0]+sample_observation.column_features.shape[0]\n",
    "        \n",
    "        return graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then prepare the data loaders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_files = [str(path) for path in Path('samples/').glob('sample_*.pkl')]\n",
    "train_files = sample_files[:int(0.8*len(sample_files))]\n",
    "valid_files = sample_files[int(0.8*len(sample_files)):]\n",
    "\n",
    "train_data = GraphDataset(train_files)\n",
    "train_loader = torch_geometric.data.DataLoader(train_data, batch_size=32, shuffle=True)\n",
    "valid_data = GraphDataset(valid_files)\n",
    "valid_loader = torch_geometric.data.DataLoader(valid_data, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will define our graph neural network architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNNPolicy(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        emb_size = 64\n",
    "        cons_nfeats = 5\n",
    "        edge_nfeats = 1\n",
    "        var_nfeats = 19\n",
    "\n",
    "        # CONSTRAINT EMBEDDING\n",
    "        self.cons_embedding = torch.nn.Sequential(\n",
    "            torch.nn.LayerNorm(cons_nfeats),\n",
    "            torch.nn.Linear(cons_nfeats, emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(emb_size, emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        # EDGE EMBEDDING\n",
    "        self.edge_embedding = torch.nn.Sequential(\n",
    "            torch.nn.LayerNorm(edge_nfeats),\n",
    "        )\n",
    "\n",
    "        # VARIABLE EMBEDDING\n",
    "        self.var_embedding = torch.nn.Sequential(\n",
    "            torch.nn.LayerNorm(var_nfeats),\n",
    "            torch.nn.Linear(var_nfeats, emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(emb_size, emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.conv_v_to_c = BipartiteGraphConvolution()\n",
    "        self.conv_c_to_v = BipartiteGraphConvolution()\n",
    "\n",
    "        self.output_module = torch.nn.Sequential(\n",
    "            torch.nn.Linear(emb_size, emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(emb_size, 1, bias=False),\n",
    "        )\n",
    "\n",
    "    def forward(self, constraint_features, edge_indices, edge_features, variable_features):\n",
    "        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)\n",
    "        \n",
    "        # First step: linear embedding layers to a common dimension (64)\n",
    "        constraint_features = self.cons_embedding(constraint_features)\n",
    "        edge_features = self.edge_embedding(edge_features)\n",
    "        variable_features = self.var_embedding(variable_features)\n",
    "\n",
    "        # Two half convolutions\n",
    "        constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, edge_features, constraint_features)\n",
    "        variable_features = self.conv_c_to_v(constraint_features, edge_indices, edge_features, variable_features)\n",
    "\n",
    "        # A final MLP on the variable features\n",
    "        output = self.output_module(variable_features).squeeze(-1)\n",
    "        return output\n",
    "    \n",
    "\n",
    "class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):\n",
    "    \"\"\"\n",
    "    The bipartite graph convolution is already provided by pytorch geometric and we merely need \n",
    "    to provide the exact form of the messages being passed.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__('add')\n",
    "        emb_size = 64\n",
    "        \n",
    "        self.feature_module_left = torch.nn.Sequential(\n",
    "            torch.nn.Linear(emb_size, emb_size)\n",
    "        )\n",
    "        self.feature_module_edge = torch.nn.Sequential(\n",
    "            torch.nn.Linear(1, emb_size, bias=False)\n",
    "        )\n",
    "        self.feature_module_right = torch.nn.Sequential(\n",
    "            torch.nn.Linear(emb_size, emb_size, bias=False)\n",
    "        )\n",
    "        self.feature_module_final = torch.nn.Sequential(\n",
    "            torch.nn.LayerNorm(emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(emb_size, emb_size)\n",
    "        )\n",
    "        \n",
    "        self.post_conv_module = torch.nn.Sequential(\n",
    "            torch.nn.LayerNorm(emb_size)\n",
    "        )\n",
    "\n",
    "        # output_layers\n",
    "        self.output_module = torch.nn.Sequential(\n",
    "            torch.nn.Linear(2*emb_size, emb_size),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(emb_size, emb_size),\n",
    "        )\n",
    "\n",
    "    def forward(self, left_features, edge_indices, edge_features, right_features):\n",
    "        \"\"\"\n",
    "        This method sends the messages, computed in the message method.\n",
    "        \"\"\"\n",
    "        output = self.propagate(edge_indices, size=(left_features.shape[0], right_features.shape[0]), \n",
    "                                node_features=(left_features, right_features), edge_features=edge_features)\n",
    "        return self.output_module(torch.cat([self.post_conv_module(output), right_features], dim=-1))\n",
    "\n",
    "    def message(self, node_features_i, node_features_j, edge_features):\n",
    "        output = self.feature_module_final(self.feature_module_left(node_features_i) \n",
    "                                           + self.feature_module_edge(edge_features) \n",
    "                                           + self.feature_module_right(node_features_j))\n",
    "        return output\n",
    "    \n",
    "\n",
    "policy = GNNPolicy().to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this model we can predict a probability distribution over actions as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0159, 0.0159, 0.0159, 0.0158, 0.0158, 0.0159, 0.0159, 0.0159, 0.0159,\n",
      "        0.0159, 0.0159, 0.0159, 0.0158, 0.0158, 0.0159, 0.0158, 0.0159, 0.0159,\n",
      "        0.0159, 0.0159, 0.0159, 0.0159, 0.0159, 0.0159, 0.0159, 0.0159, 0.0159,\n",
      "        0.0158, 0.0159, 0.0159, 0.0159, 0.0158, 0.0159, 0.0159, 0.0159, 0.0159,\n",
      "        0.0158, 0.0158, 0.0159, 0.0159, 0.0158, 0.0159, 0.0159, 0.0158, 0.0159,\n",
      "        0.0159, 0.0159, 0.0159, 0.0159, 0.0158, 0.0158, 0.0159, 0.0159, 0.0159,\n",
      "        0.0159, 0.0159, 0.0159, 0.0159, 0.0159, 0.0159, 0.0158, 0.0159, 0.0159],\n",
      "       grad_fn=<SoftmaxBackward>)\n"
     ]
    }
   ],
   "source": [
    "observation = train_data[0].to(DEVICE)\n",
    "\n",
    "logits = policy(observation.constraint_features, observation.edge_index, observation.edge_attr, observation.variable_features)\n",
    "action_distribution = F.softmax(logits[observation.candidates], dim=-1)\n",
    "\n",
    "print(action_distribution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, with randomly initialized weights, the initial distributions tend to be close to uniform.\n",
    "Next, we will define two helper functions: one to train or evaluate the model on a whole epoch and compute metrics for monitoring, and one for padding tensors when doing predictions on a batch of graphs of potentially different number of variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process(policy, data_loader, optimizer=None):\n",
    "    \"\"\"\n",
    "    This function will process a whole epoch of training or validation, depending on whether an optimizer is provided.\n",
    "    \"\"\"\n",
    "    mean_loss = 0\n",
    "    mean_acc = 0\n",
    "\n",
    "    n_samples_processed = 0\n",
    "    with torch.set_grad_enabled(optimizer is not None):\n",
    "        for batch in data_loader:\n",
    "            batch = batch.to(DEVICE)\n",
    "            # Compute the logits (i.e. pre-softmax activations) according to the policy on the concatenated graphs\n",
    "            logits = policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features)\n",
    "            # Index the results by the candidates, and split and pad them\n",
    "            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)\n",
    "            # Compute the usual cross-entropy classification loss\n",
    "            loss = F.cross_entropy(logits, batch.candidate_choices)\n",
    "\n",
    "            if optimizer is not None:\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)\n",
    "            true_bestscore = true_scores.max(dim=-1, keepdims=True).values\n",
    "            \n",
    "            predicted_bestindex = logits.max(dim=-1, keepdims=True).indices\n",
    "            accuracy = (true_scores.gather(-1, predicted_bestindex) == true_bestscore).float().mean().item()\n",
    "\n",
    "            mean_loss += loss.item() * batch.num_graphs\n",
    "            mean_acc += accuracy * batch.num_graphs\n",
    "            n_samples_processed += batch.num_graphs\n",
    "\n",
    "    mean_loss /= n_samples_processed\n",
    "    mean_acc /= n_samples_processed\n",
    "    return mean_loss, mean_acc\n",
    "\n",
    "\n",
    "def pad_tensor(input_, pad_sizes, pad_value=-1e8):\n",
    "    \"\"\"\n",
    "    This utility function splits a tensor and pads each split to make them all the same size, then stacks them.\n",
    "    \"\"\"\n",
    "    max_pad_size = pad_sizes.max()\n",
    "    output = input_.split(pad_sizes.cpu().numpy().tolist())\n",
    "    output = torch.stack([F.pad(slice_, (0, max_pad_size-slice_.size(0)), 'constant', pad_value)\n",
    "                          for slice_ in output], dim=0)\n",
    "    return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After this, we can actually create the model and train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "Train loss: 3.963, accuracy 0.367\n",
      "Valid loss: 3.555, accuracy 0.425\n",
      "Epoch 2\n",
      "Train loss: 3.429, accuracy 0.485\n",
      "Valid loss: 3.412, accuracy 0.440\n",
      "Epoch 3\n",
      "Train loss: 3.373, accuracy 0.490\n",
      "Valid loss: 3.439, accuracy 0.445\n",
      "Epoch 4\n",
      "Train loss: 3.392, accuracy 0.492\n",
      "Valid loss: 3.376, accuracy 0.460\n",
      "Epoch 5\n",
      "Train loss: 3.370, accuracy 0.490\n",
      "Valid loss: 3.435, accuracy 0.425\n",
      "Epoch 6\n",
      "Train loss: 3.351, accuracy 0.496\n",
      "Valid loss: 3.375, accuracy 0.425\n",
      "Epoch 7\n",
      "Train loss: 3.332, accuracy 0.487\n",
      "Valid loss: 3.446, accuracy 0.445\n",
      "Epoch 8\n",
      "Train loss: 3.319, accuracy 0.509\n",
      "Valid loss: 3.419, accuracy 0.430\n",
      "Epoch 9\n",
      "Train loss: 3.336, accuracy 0.501\n",
      "Valid loss: 3.427, accuracy 0.485\n",
      "Epoch 10\n",
      "Train loss: 3.336, accuracy 0.499\n",
      "Valid loss: 3.469, accuracy 0.425\n",
      "Epoch 11\n",
      "Train loss: 3.359, accuracy 0.494\n",
      "Valid loss: 3.381, accuracy 0.430\n",
      "Epoch 12\n",
      "Train loss: 3.332, accuracy 0.501\n",
      "Valid loss: 3.453, accuracy 0.455\n",
      "Epoch 13\n",
      "Train loss: 3.315, accuracy 0.500\n",
      "Valid loss: 3.400, accuracy 0.485\n",
      "Epoch 14\n",
      "Train loss: 3.314, accuracy 0.509\n",
      "Valid loss: 3.375, accuracy 0.465\n",
      "Epoch 15\n",
      "Train loss: 3.331, accuracy 0.497\n",
      "Valid loss: 3.372, accuracy 0.435\n",
      "Epoch 16\n",
      "Train loss: 3.312, accuracy 0.515\n",
      "Valid loss: 3.382, accuracy 0.485\n",
      "Epoch 17\n",
      "Train loss: 3.318, accuracy 0.500\n",
      "Valid loss: 3.409, accuracy 0.490\n",
      "Epoch 18\n",
      "Train loss: 3.320, accuracy 0.500\n",
      "Valid loss: 3.425, accuracy 0.475\n",
      "Epoch 19\n",
      "Train loss: 3.310, accuracy 0.506\n",
      "Valid loss: 3.384, accuracy 0.425\n",
      "Epoch 20\n",
      "Train loss: 3.314, accuracy 0.494\n",
      "Valid loss: 3.425, accuracy 0.440\n",
      "Epoch 21\n",
      "Train loss: 3.282, accuracy 0.509\n",
      "Valid loss: 3.392, accuracy 0.440\n",
      "Epoch 22\n",
      "Train loss: 3.341, accuracy 0.499\n",
      "Valid loss: 3.440, accuracy 0.455\n",
      "Epoch 23\n",
      "Train loss: 3.316, accuracy 0.500\n",
      "Valid loss: 3.389, accuracy 0.470\n",
      "Epoch 24\n",
      "Train loss: 3.287, accuracy 0.515\n",
      "Valid loss: 3.353, accuracy 0.450\n",
      "Epoch 25\n",
      "Train loss: 3.269, accuracy 0.510\n",
      "Valid loss: 3.412, accuracy 0.445\n",
      "Epoch 26\n",
      "Train loss: 3.218, accuracy 0.522\n",
      "Valid loss: 3.299, accuracy 0.490\n",
      "Epoch 27\n",
      "Train loss: 3.284, accuracy 0.497\n",
      "Valid loss: 3.352, accuracy 0.470\n",
      "Epoch 28\n",
      "Train loss: 3.210, accuracy 0.527\n",
      "Valid loss: 3.247, accuracy 0.480\n",
      "Epoch 29\n",
      "Train loss: 3.205, accuracy 0.519\n",
      "Valid loss: 3.326, accuracy 0.470\n",
      "Epoch 30\n",
      "Train loss: 3.200, accuracy 0.512\n",
      "Valid loss: 3.267, accuracy 0.440\n",
      "Epoch 31\n",
      "Train loss: 3.144, accuracy 0.532\n",
      "Valid loss: 3.244, accuracy 0.485\n",
      "Epoch 32\n",
      "Train loss: 3.097, accuracy 0.531\n",
      "Valid loss: 3.241, accuracy 0.465\n",
      "Epoch 33\n",
      "Train loss: 3.103, accuracy 0.512\n",
      "Valid loss: 3.261, accuracy 0.495\n",
      "Epoch 34\n",
      "Train loss: 3.088, accuracy 0.532\n",
      "Valid loss: 3.244, accuracy 0.485\n",
      "Epoch 35\n",
      "Train loss: 3.067, accuracy 0.522\n",
      "Valid loss: 3.268, accuracy 0.420\n",
      "Epoch 36\n",
      "Train loss: 3.034, accuracy 0.521\n",
      "Valid loss: 3.277, accuracy 0.515\n",
      "Epoch 37\n",
      "Train loss: 3.047, accuracy 0.515\n",
      "Valid loss: 3.321, accuracy 0.495\n",
      "Epoch 38\n",
      "Train loss: 3.035, accuracy 0.529\n",
      "Valid loss: 3.241, accuracy 0.455\n",
      "Epoch 39\n",
      "Train loss: 3.047, accuracy 0.519\n",
      "Valid loss: 3.244, accuracy 0.415\n",
      "Epoch 40\n",
      "Train loss: 3.012, accuracy 0.530\n",
      "Valid loss: 3.329, accuracy 0.485\n",
      "Epoch 41\n",
      "Train loss: 3.016, accuracy 0.531\n",
      "Valid loss: 3.221, accuracy 0.500\n",
      "Epoch 42\n",
      "Train loss: 2.974, accuracy 0.519\n",
      "Valid loss: 3.230, accuracy 0.445\n",
      "Epoch 43\n",
      "Train loss: 2.975, accuracy 0.542\n",
      "Valid loss: 3.261, accuracy 0.480\n",
      "Epoch 44\n",
      "Train loss: 3.036, accuracy 0.510\n",
      "Valid loss: 3.308, accuracy 0.495\n",
      "Epoch 45\n",
      "Train loss: 2.979, accuracy 0.526\n",
      "Valid loss: 3.187, accuracy 0.460\n",
      "Epoch 46\n",
      "Train loss: 2.966, accuracy 0.512\n",
      "Valid loss: 3.173, accuracy 0.415\n",
      "Epoch 47\n",
      "Train loss: 2.951, accuracy 0.525\n",
      "Valid loss: 3.143, accuracy 0.465\n",
      "Epoch 48\n",
      "Train loss: 2.956, accuracy 0.525\n",
      "Valid loss: 3.141, accuracy 0.495\n",
      "Epoch 49\n",
      "Train loss: 2.941, accuracy 0.509\n",
      "Valid loss: 3.134, accuracy 0.470\n",
      "Epoch 50\n",
      "Train loss: 2.948, accuracy 0.544\n",
      "Valid loss: 3.429, accuracy 0.355\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)\n",
    "for epoch in range(NB_EPOCHS):\n",
    "    print(f\"Epoch {epoch+1}\")\n",
    "    \n",
    "    train_loss, train_acc = process(policy, train_loader, optimizer)\n",
    "    print(f\"Train loss: {train_loss:0.3f}, accuracy {train_acc:0.3f}\" )\n",
    "\n",
    "    valid_loss, valid_acc = process(policy, valid_loader, None)\n",
    "    print(f\"Valid loss: {valid_loss:0.3f}, accuracy {valid_acc:0.3f}\" )\n",
    "\n",
    "torch.save(policy.state_dict(), 'trained_params.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 Evaluation\n",
    "\n",
    "Finally, we can evaluate the performance of the model. We first define appropriate environments. For benchmarking purposes, we include a trivial environment that merely runs SCIP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0, 'limits/time': 3600}\n",
    "env = ecole.environment.Branching(observation_function=ecole.observation.NodeBipartite(), \n",
    "                                  information_function={\"nb_nodes\": ecole.reward.NNodes(), \n",
    "                                                        \"time\": ecole.reward.SolvingTime()}, \n",
    "                                  scip_params=scip_parameters)\n",
    "default_env = ecole.environment.Configuring(observation_function=None,\n",
    "                                            information_function={\"nb_nodes\": ecole.reward.NNodes(), \n",
    "                                                                  \"time\": ecole.reward.SolvingTime()}, \n",
    "                                            scip_params=scip_parameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can simply follow the environments, taking steps appropriately according to the GNN policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instance   0 | SCIP nb nodes       3  | SCIP time     2.42 \n",
      "             | GNN  nb nodes      47  | GNN  time     3.23 \n",
      "             | Gain         -1466.67% | Gain        -33.69%\n",
      "Instance   1 | SCIP nb nodes      13  | SCIP time     4.20 \n",
      "             | GNN  nb nodes     145  | GNN  time     7.29 \n",
      "             | Gain         -1015.38% | Gain        -73.39%\n",
      "Instance   2 | SCIP nb nodes       1  | SCIP time     0.37 \n",
      "             | GNN  nb nodes       1  | GNN  time     0.32 \n",
      "             | Gain             0.00% | Gain         15.18%\n",
      "Instance   3 | SCIP nb nodes      17  | SCIP time     2.41 \n",
      "             | GNN  nb nodes     123  | GNN  time     4.58 \n",
      "             | Gain          -623.53% | Gain        -89.74%\n",
      "Instance   4 | SCIP nb nodes      11  | SCIP time     2.06 \n",
      "             | GNN  nb nodes      91  | GNN  time     3.33 \n",
      "             | Gain          -727.27% | Gain        -61.42%\n",
      "Instance   5 | SCIP nb nodes       1  | SCIP time     1.16 \n",
      "             | GNN  nb nodes       9  | GNN  time     1.59 \n",
      "             | Gain          -800.00% | Gain        -36.98%\n",
      "Instance   6 | SCIP nb nodes      15  | SCIP time     4.05 \n",
      "             | GNN  nb nodes     151  | GNN  time     7.58 \n",
      "             | Gain          -906.67% | Gain        -87.07%\n",
      "Instance   7 | SCIP nb nodes      13  | SCIP time     2.91 \n",
      "             | GNN  nb nodes     111  | GNN  time     4.66 \n",
      "             | Gain          -753.85% | Gain        -60.40%\n",
      "Instance   8 | SCIP nb nodes       1  | SCIP time     2.09 \n",
      "             | GNN  nb nodes      43  | GNN  time     2.50 \n",
      "             | Gain         -4200.00% | Gain        -19.57%\n",
      "Instance   9 | SCIP nb nodes       1  | SCIP time     0.91 \n",
      "             | GNN  nb nodes       1  | GNN  time     0.87 \n",
      "             | Gain             0.00% | Gain          4.71%\n",
      "Instance  10 | SCIP nb nodes      33  | SCIP time     5.52 \n",
      "             | GNN  nb nodes     263  | GNN  time     8.78 \n",
      "             | Gain          -696.97% | Gain        -59.05%\n",
      "Instance  11 | SCIP nb nodes       1  | SCIP time     0.70 \n",
      "             | GNN  nb nodes       5  | GNN  time     0.78 \n",
      "             | Gain          -400.00% | Gain        -11.78%\n",
      "Instance  12 | SCIP nb nodes       9  | SCIP time     2.53 \n",
      "             | GNN  nb nodes      92  | GNN  time     6.07 \n",
      "             | Gain          -922.22% | Gain       -140.34%\n",
      "Instance  13 | SCIP nb nodes       1  | SCIP time     0.95 \n",
      "             | GNN  nb nodes      11  | GNN  time     0.98 \n",
      "             | Gain         -1000.00% | Gain         -3.75%\n",
      "Instance  14 | SCIP nb nodes       1  | SCIP time     1.19 \n",
      "             | GNN  nb nodes       1  | GNN  time     1.17 \n",
      "             | Gain             0.00% | Gain          1.93%\n",
      "Instance  15 | SCIP nb nodes       1  | SCIP time     1.46 \n",
      "             | GNN  nb nodes       9  | GNN  time     1.56 \n",
      "             | Gain          -800.00% | Gain         -6.69%\n",
      "Instance  16 | SCIP nb nodes       3  | SCIP time     2.30 \n",
      "             | GNN  nb nodes      30  | GNN  time     2.54 \n",
      "             | Gain          -900.00% | Gain        -10.65%\n",
      "Instance  17 | SCIP nb nodes       1  | SCIP time     2.53 \n",
      "             | GNN  nb nodes      21  | GNN  time     2.08 \n",
      "             | Gain         -2000.00% | Gain         17.88%\n",
      "Instance  18 | SCIP nb nodes      10  | SCIP time     2.11 \n",
      "             | GNN  nb nodes      61  | GNN  time     3.16 \n",
      "             | Gain          -510.00% | Gain        -49.80%\n",
      "Instance  19 | SCIP nb nodes      65  | SCIP time     4.95 \n",
      "             | GNN  nb nodes     479  | GNN  time    13.93 \n",
      "             | Gain          -636.92% | Gain       -181.42%\n"
     ]
    }
   ],
   "source": [
    "instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=1000, density=0.05)\n",
    "for instance_count, instance in zip(range(20), instances):\n",
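    "    # Run the GNN brancher\n",
    "    # The info dict reports increments since the previous call here, so we sum them over the episode\n",
    "    nb_nodes, time = 0, 0\n",
    "    observation, action_set, _, done, info = env.reset(instance)\n",
    "    nb_nodes += info['nb_nodes']\n",
    "    time += info['time']\n",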
    "    while not done:\n",
    "        with torch.no_grad():\n",
    "            observation = (torch.from_numpy(observation.row_features.astype(np.float32)).to(DEVICE),\n",
    "                           torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(DEVICE), \n",
    "                           torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(DEVICE),\n",
    "                           torch.from_numpy(observation.column_features.astype(np.float32)).to(DEVICE))\n",
    "            logits = policy(*observation)\n",
    "            action = action_set[logits[action_set.astype(np.int64)].argmax()]\n",
    "            observation, action_set, _, done, info = env.step(action)\n",
    "        nb_nodes += info['nb_nodes']\n",
    "        time += info['time']\n",
    "\n",
    "    # Run SCIP's default brancher\n",
    "    default_env.reset(instance)\n",
    "    _, _, _, _, default_info = default_env.step({})\n",
    "    \n",
    "    print(f\"Instance {instance_count: >3} | SCIP nb nodes    {int(default_info['nb_nodes']): >4d}  | SCIP time   {default_info['time']: >6.2f} \")\n",
    "    print(f\"             | GNN  nb nodes    {int(nb_nodes): >4d}  | GNN  time   {time: >6.2f} \")\n",
    "    print(f\"             | Gain         {100*(1-nb_nodes/default_info['nb_nodes']): >8.2f}% | Gain      {100*(1-time/default_info['time']): >8.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also evaluate the policy on instances that are larger and harder than those seen during training, for example with 600 rather than 500 constraints.\n",
    "In addition, the cumulative number of nodes and the cumulative solving time can be obtained directly by calling the `.cumsum()` method on the reward functions, so they no longer need to be accumulated by hand inside the loop."
   ]
  },
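  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference concrete, here is a minimal sketch in plain Python that does not use Ecole. The per-step values are made up for illustration, and `itertools.accumulate` only stands in for what a cumulative reward reports; it is not a description of how Ecole implements `.cumsum()`.\n",
    "\n",
    "```python\n",
    "from itertools import accumulate\n",
    "\n",
    "# Hypothetical per-step increments, as a non-cumulative information\n",
    "# function might report them after reset() and after each step().\n",
    "step_nb_nodes = [1, 2, 2, 4]\n",
    "step_times = [0.25, 0.5, 0.125, 0.125]\n",
    "\n",
    "# Manual bookkeeping: add every increment to a running total.\n",
    "total_nodes, total_time = 0, 0.0\n",
    "for nodes, seconds in zip(step_nb_nodes, step_times):\n",
    "    total_nodes += nodes\n",
    "    total_time += seconds\n",
    "\n",
    "# Cumulative view: each entry is the total so far,\n",
    "# so only the last entry is needed at the end of the episode.\n",
    "cumulative_nodes = list(accumulate(step_nb_nodes))\n",
    "cumulative_times = list(accumulate(step_times))\n",
    "\n",
    "assert cumulative_nodes[-1] == total_nodes\n",
    "assert cumulative_times[-1] == total_time\n",
    "print(cumulative_nodes)\n",
    "```\n",
    "\n",
    "With cumulative values, reading the last `info` dictionary once after the episode ends is enough."
   ]
  },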
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instance   0 | SCIP nb nodes       7  | SCIP time     3.10 \n",
      "             | GNN  nb nodes      79  | GNN  time     4.27 \n",
      "             | Gain         -1028.57% | Gain        -37.85%\n",
      "Instance   1 | SCIP nb nodes       7  | SCIP time     2.72 \n",
      "             | GNN  nb nodes      59  | GNN  time     3.26 \n",
      "             | Gain          -742.86% | Gain        -19.87%\n",
      "Instance   2 | SCIP nb nodes       3  | SCIP time     2.75 \n",
      "             | GNN  nb nodes      17  | GNN  time     2.35 \n",
      "             | Gain          -466.67% | Gain         14.46%\n",
      "Instance   3 | SCIP nb nodes       9  | SCIP time     2.67 \n",
      "             | GNN  nb nodes      89  | GNN  time     4.16 \n",
      "             | Gain          -888.89% | Gain        -55.91%\n",
      "Instance   4 | SCIP nb nodes      19  | SCIP time     3.64 \n",
      "             | GNN  nb nodes     115  | GNN  time     5.58 \n",
      "             | Gain          -505.26% | Gain        -53.45%\n",
      "Instance   5 | SCIP nb nodes      23  | SCIP time     3.39 \n",
      "             | GNN  nb nodes      83  | GNN  time     5.21 \n",
      "             | Gain          -260.87% | Gain        -53.53%\n",
      "Instance   6 | SCIP nb nodes      21  | SCIP time     2.80 \n",
      "             | GNN  nb nodes     158  | GNN  time     4.83 \n",
      "             | Gain          -652.38% | Gain        -72.61%\n",
      "Instance   7 | SCIP nb nodes       1  | SCIP time     1.39 \n",
      "             | GNN  nb nodes       7  | GNN  time     1.60 \n",
      "             | Gain          -600.00% | Gain        -15.16%\n",
      "Instance   8 | SCIP nb nodes       1  | SCIP time     0.72 \n",
      "             | GNN  nb nodes       1  | GNN  time     0.75 \n",
      "             | Gain             0.00% | Gain         -5.16%\n",
      "Instance   9 | SCIP nb nodes      17  | SCIP time     3.94 \n",
      "             | GNN  nb nodes     113  | GNN  time     3.73 \n",
      "             | Gain          -564.71% | Gain          5.21%\n",
      "Instance  10 | SCIP nb nodes      61  | SCIP time     7.68 \n",
      "             | GNN  nb nodes     493  | GNN  time    18.76 \n",
      "             | Gain          -708.20% | Gain       -144.21%\n",
      "Instance  11 | SCIP nb nodes       1  | SCIP time     1.53 \n",
      "             | GNN  nb nodes      15  | GNN  time     1.77 \n",
      "             | Gain         -1400.00% | Gain        -15.42%\n",
      "Instance  12 | SCIP nb nodes      11  | SCIP time     5.13 \n",
      "             | GNN  nb nodes     153  | GNN  time     8.24 \n",
      "             | Gain         -1290.91% | Gain        -60.47%\n",
      "Instance  13 | SCIP nb nodes      77  | SCIP time     4.20 \n",
      "             | GNN  nb nodes     397  | GNN  time    13.28 \n",
      "             | Gain          -415.58% | Gain       -216.23%\n",
      "Instance  14 | SCIP nb nodes     283  | SCIP time     9.85 \n",
      "             | GNN  nb nodes    2510  | GNN  time    86.71 \n",
      "             | Gain          -786.93% | Gain       -780.63%\n",
      "Instance  15 | SCIP nb nodes      13  | SCIP time     3.90 \n",
      "             | GNN  nb nodes     139  | GNN  time     6.27 \n",
      "             | Gain          -969.23% | Gain        -60.63%\n",
      "Instance  16 | SCIP nb nodes       3  | SCIP time     3.50 \n",
      "             | GNN  nb nodes      61  | GNN  time     4.32 \n",
      "             | Gain         -1933.33% | Gain        -23.65%\n",
      "Instance  17 | SCIP nb nodes      21  | SCIP time     3.46 \n",
      "             | GNN  nb nodes     215  | GNN  time     9.02 \n",
      "             | Gain          -923.81% | Gain       -160.96%\n",
      "Instance  18 | SCIP nb nodes     247  | SCIP time     7.16 \n",
      "             | GNN  nb nodes     523  | GNN  time    20.03 \n",
      "             | Gain          -111.74% | Gain       -179.56%\n",
      "Instance  19 | SCIP nb nodes      21  | SCIP time     3.76 \n",
      "             | GNN  nb nodes     205  | GNN  time     7.50 \n",
      "             | Gain          -876.19% | Gain        -99.36%\n"
     ]
    }
   ],
   "source": [
    "instances = ecole.instance.SetCoverGenerator(n_rows=600, n_cols=1000, density=0.05)\n",
    "scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0, 'limits/time': 3600}\n",
    "env = ecole.environment.Branching(observation_function=ecole.observation.NodeBipartite(), \n",
    "                                  information_function={\"nb_nodes\": ecole.reward.NNodes().cumsum(), \n",
    "                                                        \"time\": ecole.reward.SolvingTime().cumsum()}, \n",
    "                                  scip_params=scip_parameters)\n",
    "default_env = ecole.environment.Configuring(observation_function=None,\n",
    "                                            information_function={\"nb_nodes\": ecole.reward.NNodes().cumsum(), \n",
    "                                                                  \"time\": ecole.reward.SolvingTime().cumsum()}, \n",
    "                                            scip_params=scip_parameters)\n",
    "\n",
    "for instance_count, instance in zip(range(20), instances):\n",
    "    # Run the GNN brancher\n",
    "    observation, action_set, _, done, info = env.reset(instance)\n",
    "    while not done:\n",
    "        with torch.no_grad():\n",
    "            observation = (torch.from_numpy(observation.row_features.astype(np.float32)).to(DEVICE),\n",
    "                           torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(DEVICE), \n",
    "                           torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(DEVICE),\n",
    "                           torch.from_numpy(observation.column_features.astype(np.float32)).to(DEVICE))\n",
    "            logits = policy(*observation)\n",
    "            action = action_set[logits[action_set.astype(np.int64)].argmax()]\n",
    "            observation, action_set, _, done, info = env.step(action)\n",
    "    # With .cumsum(), the last info dict already holds the totals for the whole episode\n",
    "    nb_nodes = info['nb_nodes']\n",
    "    time = info['time']\n",
    "\n",
    "    # Run SCIP's default brancher\n",
    "    default_env.reset(instance)\n",
    "    _, _, _, _, default_info = default_env.step({})\n",
    "\n",
    "    print(f\"Instance {instance_count: >3} | SCIP nb nodes    {int(default_info['nb_nodes']): >4d}  | SCIP time   {default_info['time']: >6.2f} \")\n",
    "    print(f\"             | GNN  nb nodes    {int(nb_nodes): >4d}  | GNN  time   {time: >6.2f} \")\n",
    "    print(f\"             | Gain         {100*(1-nb_nodes/default_info['nb_nodes']): >8.2f}% | Gain      {100*(1-time/default_info['time']): >8.2f}%\")"
   ]
  },
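  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Gain` rows printed above are relative to SCIP's default brancher. The sketch below restates the formula used in the print statements as a small helper (`gain_percent` is a name introduced here for illustration, not part of Ecole) and checks it against the node counts printed for instance 0.\n",
    "\n",
    "```python\n",
    "def gain_percent(gnn_value, scip_value):\n",
    "    # Relative improvement of the GNN brancher over SCIP, in percent.\n",
    "    # Positive: the GNN needed less. Negative: the GNN needed more.\n",
    "    return 100 * (1 - gnn_value / scip_value)\n",
    "\n",
    "# Node counts printed for instance 0 above: SCIP 7 nodes, GNN 79 nodes.\n",
    "print(f'{gain_percent(79, 7):.2f}%')\n",
    "\n",
    "# Equal node counts give a gain of exactly zero.\n",
    "print(f'{gain_percent(1, 1):.2f}%')\n",
    "```\n",
    "\n",
    "A gain of -1028.57% therefore means the GNN policy explored roughly eleven times as many nodes as SCIP on that instance. The time gains cannot be reproduced exactly from the printed times, since those are rounded to two decimals before display."
   ]
  },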
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### References\n",
    "\n",
    "Gasse, M., Chételat, D., Ferroni, N., Charlin, L. and Lodi, A. (2019). Exact combinatorial optimization with graph convolutional neural networks. In Advances in Neural Information Processing Systems (pp. 15580-15592)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
