{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "440aef8e-c4ed-4af4-95d7-e01673a6ee61",
   "metadata": {},
   "outputs": [],
   "source": [
    "from generate_instances_lp import generate_setcover\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "87ad9cea-9575-49f1-b7e8-84dd683e1f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def surrogate_gen():\n",
    "    nrows = 500\n",
    "    ncols = 500\n",
    "    density = 0.05\n",
    "    nnzrs = int(nrows * ncols * density)\n",
    "    A, b, c = generate_setcover(nrows, ncols, nnzrs, rng)\n",
    "    return A, b, c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d9f90ea3-c355-44b8-96ef-ed67e86875af",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "62e55aba-3f61-49ee-abe3-cb389d92b383",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "A, b, c = surrogate_gen()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "625b77b8-7328-4c93-b96c-c370fa84bcc5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 500)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Sanity check on the generated constraint matrix; (500, 500) for the default sizes.\n",
     "A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bf87c16-8cc0-4350-bcc2-cb0fa33b2ab3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4ef07b22-5c78-4afa-9fb1-3f7c3611223e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyscipopt import Model, Branchrule, SCIP_RESULT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c66a91cb-1970-406e-950d-0cd46f8e52c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f894742d-721a-4aad-aa68-70f6d811d657",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class MostInfBranchRule(Branchrule):\n",
    "\n",
    "    def __init__(self, scip):\n",
    "        self.scip = scip\n",
    "\n",
    "    def branchexeclp(self, allowaddcons):\n",
    "\n",
    "        # Get the branching candidates. Only consider the number of priority candidates (they are sorted to be first)\n",
    "        # The implicit integer candidates in general shouldn't be branched on. Unless specified by the user\n",
    "        # npriocands and ncands are the same (npriocands are variables that have been designated as priorities)\n",
    "        branch_cands, branch_cand_sols, branch_cand_fracs, ncands, npriocands, nimplcands = self.scip.getLPBranchCands()\n",
    "\n",
    "        # Find the variable that is most fractional\n",
    "        best_cand_idx = 0\n",
    "        best_dist = np.inf\n",
    "        for i in range(npriocands):\n",
    "            if abs(branch_cand_fracs[i] - 0.5) <= best_dist:\n",
    "                best_dist = abs(branch_cand_fracs[i] - 0.5)\n",
    "                best_cand_idx = i\n",
    "\n",
    "        # Branch on the variable with the largest score\n",
    "        down_child, eq_child, up_child = self.model.branchVarVal(branch_cands[best_cand_idx], branch_cand_sols[best_cand_idx])\n",
    "\n",
    "        pdb.set_trace()\n",
    "\n",
    "        return {\"result\": SCIP_RESULT.BRANCHED}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "34ec571f-1440-4e5f-90d2-ba78145dfd8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class StrongBranchingRule(Branchrule):\n",
    "\n",
    "    def __init__(self, scip):\n",
    "        self.scip = scip\n",
    "\n",
    "    def branchexeclp(self, allowaddcons):\n",
    "\n",
    "        branch_cands, branch_cand_sols, branch_cand_fracs, ncands, npriocands, nimplcands = self.scip.getLPBranchCands()\n",
    "\n",
    "        # Initialise scores for each variable\n",
    "        scores = [-self.scip.infinity() for _ in range(npriocands)]\n",
    "        down_bounds = [None for _ in range(npriocands)]\n",
    "        up_bounds = [None for _ in range(npriocands)]\n",
    "\n",
    "        # Initialise placeholder values\n",
    "        num_nodes = self.scip.getNNodes()\n",
    "        lpobjval = self.scip.getLPObjVal()\n",
    "\n",
    "        lperror = False\n",
    "        best_cand_idx = 0\n",
    "\n",
    "\n",
    "        cur_node = self.scip.getCurrentNode()\n",
    "        # conss = cur_node.getAddedConss()\n",
    "\n",
    "        # Start strong branching and iterate over the branching candidates\n",
    "        self.scip.startStrongbranch()\n",
    "        for i in range(npriocands):\n",
    "\n",
    "            # Check the case that the variable has already been strong branched on at this node.\n",
    "            # This case occurs when events happen in the node that should be handled immediately.\n",
    "            # When processing the node again (because the event did not remove it), there's no need to duplicate work.\n",
    "            # if self.scip.getVarStrongbranchNode(branch_cands[i]) == num_nodes:\n",
    "            #     pdb.set_trace()\n",
    "            #     down, up, downvalid, upvalid, _, lastlpobjval = self.scip.getVarStrongbranchLast(branch_cands[i])\n",
    "            #     if downvalid:\n",
    "            #         down_bounds[i] = down\n",
    "            #     if upvalid:\n",
    "            #         up_bounds[i] = up\n",
    "            #     downgain = max([down - lastlpobjval, 0])\n",
    "            #     upgain = max([up - lastlpobjval, 0])\n",
    "            #     scores[i] = self.scip.getBranchScoreMultiple(branch_cands[i], [downgain, upgain])\n",
    "            #     continue\n",
    "\n",
    "            # variables = model.getVars()\n",
    "            # objective_coeffs = [var.getObj() for var in variables]\n",
    "            # assert np.all(np.array(objective_coeffs) == c)\n",
    "\n",
    "\n",
    "            # Strong branch!\n",
    "            down, up, downvalid, upvalid, downinf, upinf, downconflict, upconflict, lperror = self.scip.getVarStrongbranch(\n",
    "                branch_cands[i], 200, idempotent=False)\n",
    "            # down (float) – The dual bound of the LP after branching down on the variable\n",
    "            # up (float) – The dual bound of the LP after branchign up on the variable\n",
    "\n",
    "            # In the case of an LP error handle appropriately (for this example we just break the loop)\n",
    "            if lperror:\n",
    "                break\n",
    "\n",
    "            # In the case of both infeasible sub-problems cutoff the node\n",
    "            if downinf and upinf:\n",
    "                return {\"result\": SCIP_RESULT.CUTOFF}\n",
    "\n",
    "            # Calculate the gains for each up and down node that strong branching explored\n",
    "            if not downinf and downvalid:\n",
    "                down_bounds[i] = down\n",
    "                downgain = max([down - lpobjval, 0])\n",
    "            else:\n",
    "                downgain = 0\n",
    "            if not upinf and upvalid:\n",
    "                up_bounds[i] = up\n",
    "                upgain = max([up - lpobjval, 0])\n",
    "            else:\n",
    "                upgain = 0\n",
    "\n",
    "            # Update the pseudo-costs\n",
    "            lpsol = branch_cands[i].getLPSol()  # solution of this variable\n",
    "\n",
    "            if not downinf and downvalid:\n",
    "                self.scip.updateVarPseudocost(branch_cands[i], -self.scip.frac(lpsol), downgain, 1)\n",
    "            if not upinf and upvalid:\n",
    "                self.scip.updateVarPseudocost(branch_cands[i], 1 - self.scip.frac(lpsol), upgain, 1)\n",
    "\n",
    "            scores[i] = self.scip.getBranchScoreMultiple(branch_cands[i], [downgain, upgain])\n",
    "            if scores[i] > scores[best_cand_idx]:\n",
    "                best_cand_idx = i\n",
    "\n",
    "        # End strong branching\n",
    "        self.scip.endStrongbranch()\n",
    "\n",
    "\n",
    "        # In the case of an LP error\n",
    "        if lperror:\n",
    "            return {\"result\": SCIP_RESULT.DIDNOTRUN}\n",
    "\n",
    "        # Branch on the variable with the largest score\n",
    "        down_child, eq_child, up_child = self.model.branchVarVal(\n",
    "            branch_cands[best_cand_idx], branch_cands[best_cand_idx].getLPSol())\n",
    "\n",
    "        # Update the bounds of the down node and up node. Some cols might not exist due to pricing\n",
    "        if self.scip.allColsInLP():\n",
    "            if down_child is not None and down_bounds[best_cand_idx] is not None:\n",
    "                self.scip.updateNodeLowerbound(down_child, down_bounds[best_cand_idx])\n",
    "            if up_child is not None and up_bounds[best_cand_idx] is not None:\n",
    "                self.scip.updateNodeLowerbound(up_child, up_bounds[best_cand_idx])\n",
    "\n",
    "        return {\"result\": SCIP_RESULT.BRANCHED}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c7bc2e-3972-441c-96ff-bb1d3d5fb4b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "93704303-268a-440f-a3e1-5eda0d789af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# variables = model.getVars()\n",
    "# constraints = model.getConss()\n",
    "\n",
    "# n_vars = len(variables)\n",
    "# n_cons = len(constraints)\n",
    "\n",
    "# # Initialize A and b\n",
    "# A_ = np.zeros((n_cons, n_vars))\n",
    "# b_ = np.zeros(n_cons)\n",
    "\n",
    "# for i, cons in enumerate(constraints):\n",
    "#     var_coeff_dict = model.getValsLinear(cons)\n",
    "#     for var, coef in var_coeff_dict.items():\n",
    "#         j = int(var.split('_')[-1])\n",
    "#         A_[i, j] = coef\n",
    "    \n",
    "#     # Get the right-hand side (RHS) value\n",
    "#     b_[i] = model.getRhs(cons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e42c91-799d-40dd-b67f-072f90c701a3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f622b129-707c-43d3-bdba-ef8dac278e40",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "c6deb58a-e557-48a2-a2b2-6093d5c44c2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "presolving:\n",
      "   (0.0s) symmetry computation started: requiring (bin +, int +, cont +), (fixed: bin -, int -, cont -)\n",
      "[]\n",
      "   (0.0s) no symmetry present (symcode time: 0.00)\n",
      "presolving (0 rounds: 0 fast, 0 medium, 0 exhaustive):\n",
      " 0 deleted vars, 0 deleted constraints, 0 added constraints, 0 tightened bounds, 0 added holes, 0 changed sides, 0 changed coefficients\n",
      " 0 implications, 0 cliques\n",
      "presolved problem has 500 variables (0 bin, 500 int, 0 impl, 0 cont) and 500 constraints\n",
      "    500 constraints of type <linear>\n",
      "Presolving Time: 0.00\n",
      "\n",
      " time | node  | left  |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr|  dualbound   | primalbound  |  gap   | compl. \n",
      "  0.0s|     1 |     0 |   790 |     - |  6587k |   0 | 500 | 500 | 500 |   0 |  0 |   0 |   0 | 3.929222e+00 |      --      |    Inf | unknown\n",
      "  0.1s|     1 |     0 |   818 |     - |    23M |   0 | 500 | 500 | 501 |   1 |  1 |   0 |   0 | 3.945969e+00 |      --      |    Inf | unknown\n",
      "  0.1s|     1 |     0 |   879 |     - |    37M |   0 | 500 | 500 | 502 |   2 |  2 |   0 |   0 | 3.956411e+00 |      --      |    Inf | unknown\n",
      "  0.1s|     1 |     0 |   921 |     - |    53M |   0 | 500 | 500 | 503 |   3 |  3 |   0 |   0 | 3.960150e+00 |      --      |    Inf | unknown\n",
      "  0.1s|     1 |     0 |   956 |     - |    61M |   0 | 500 | 500 | 504 |   4 |  4 |   0 |   0 | 3.961228e+00 |      --      |    Inf | unknown\n",
      "  0.2s|     1 |     0 |  1016 |     - |    70M |   0 | 500 | 500 | 505 |   5 |  5 |   0 |   0 | 3.964499e+00 |      --      |    Inf | unknown\n",
      "  0.2s|     1 |     0 |  1044 |     - |    86M |   0 | 500 | 500 | 506 |   6 |  6 |   0 |   0 | 3.965780e+00 |      --      |    Inf | unknown\n",
      "  0.2s|     1 |     0 |  1113 |     - |   101M |   0 | 500 | 500 | 507 |   7 |  7 |   0 |   0 | 3.968737e+00 |      --      |    Inf | unknown\n",
      "  0.3s|     1 |     0 |  1159 |     - |   107M |   0 | 500 | 500 | 508 |   8 |  8 |   0 |   0 | 3.971322e+00 |      --      |    Inf | unknown\n",
      "  0.3s|     1 |     0 |  1198 |     - |   121M |   0 | 500 | 500 | 509 |   9 |  9 |   0 |   0 | 3.972788e+00 |      --      |    Inf | unknown\n",
      "  0.3s|     1 |     0 |  1247 |     - |   142M |   0 | 500 | 500 | 510 |  10 | 10 |   0 |   0 | 3.974048e+00 |      --      |    Inf | unknown\n",
      "  0.3s|     1 |     0 |  1259 |     - |   142M |   0 | 500 | 500 | 511 |  11 | 11 |   0 |   0 | 3.974158e+00 |      --      |    Inf | unknown\n",
      "  0.3s|     1 |     0 |  1276 |     - |   142M |   0 | 500 | 500 | 512 |  12 | 12 |   0 |   0 | 3.974333e+00 |      --      |    Inf | unknown\n",
      "  0.3s|     1 |     0 |  1292 |     - |   143M |   0 | 500 | 500 | 513 |  13 | 13 |   0 |   0 | 3.974422e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1316 |     - |   143M |   0 | 500 | 500 | 514 |  14 | 14 |   0 |   0 | 3.974738e+00 |      --      |    Inf | unknown\n",
      " time | node  | left  |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr|  dualbound   | primalbound  |  gap   | compl. \n",
      "  0.4s|     1 |     0 |  1337 |     - |   143M |   0 | 500 | 500 | 515 |  15 | 15 |   0 |   0 | 3.974826e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1355 |     - |   143M |   0 | 500 | 500 | 516 |  16 | 16 |   0 |   0 | 3.975250e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1392 |     - |   143M |   0 | 500 | 500 | 517 |  17 | 17 |   0 |   0 | 3.975543e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1405 |     - |   143M |   0 | 500 | 500 | 518 |  18 | 18 |   0 |   0 | 3.975625e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1413 |     - |   143M |   0 | 500 | 500 | 519 |  19 | 19 |   0 |   0 | 3.975647e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1433 |     - |   143M |   0 | 500 | 500 | 520 |  20 | 20 |   0 |   0 | 3.975749e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1438 |     - |   143M |   0 | 500 | 500 | 521 |  21 | 21 |   0 |   0 | 3.975766e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1442 |     - |   143M |   0 | 500 | 500 | 522 |  22 | 22 |   0 |   0 | 3.975779e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1448 |     - |   143M |   0 | 500 | 500 | 522 |  23 | 23 |   0 |   0 | 3.975793e+00 |      --      |    Inf | unknown\n",
      "  0.4s|     1 |     0 |  1451 |     - |   143M |   0 | 500 | 500 | 523 |  24 | 24 |   0 |   0 | 3.975795e+00 |      --      |    Inf | unknown\n",
      "  0.5s|     1 |     0 |  1457 |     - |   143M |   0 | 500 | 500 | 524 |  25 | 25 |   0 |   0 | 3.975803e+00 |      --      |    Inf | unknown\n",
      "  0.5s|     1 |     0 |  1458 |     - |   143M |   0 | 500 | 500 | 525 |  26 | 26 |   0 |   0 | 3.975803e+00 |      --      |    Inf | unknown\n",
      "  1.5s|     1 |     2 |  1458 |     - |   143M |   0 | 500 | 500 | 525 |  26 | 26 |   0 | 107 | 3.975803e+00 |      --      |    Inf | unknown\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "* 4.6s|    16 |    11 |  2888 |  95.3 |    LP  |  15 | 500 | 500 | 525 |  33 |  1 |   0 |1011 | 3.985991e+00 | 4.899391e+00 |  22.92%| unknown\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "* 7.6s|    25 |    12 |  3838 |  99.2 |    LP  |  15 | 500 | 500 | 521 |  36 |  4 |   0 |1639 | 3.985991e+00 | 4.623130e+00 |  15.98%|   1.39%\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      " time | node  | left  |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr|  dualbound   | primalbound  |  gap   | compl. \n",
      "* 9.5s|    33 |    14 |  4502 |  95.1 |    LP  |  15 | 500 | 526 | 518 |  37 |  1 |  26 |2116 | 3.985991e+00 | 4.583952e+00 |  15.00%|   3.54%\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "*10.2s|    38 |    11 |  4883 |  92.6 |    LP  |  15 | 500 | 537 | 518 |  37 |  1 |  37 |2306 | 3.985991e+00 | 4.429726e+00 |  11.13%|   6.16%\n",
      "[]\n",
      "[]\n",
      "*10.3s|    39 |    10 |  4926 |  91.3 |    LP  |  15 | 500 | 537 | 522 |  41 |  3 |  37 |2306 | 3.985991e+00 | 4.427964e+00 |  11.09%|   6.20%\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]*18.5s|    71 |     6 |  8366 |  98.7 |    LP  |  15 | 500 | 740 | 519 |  49 |  1 | 240 |3943 | 4.113430e+00 | 4.263190e+00 |   3.64%|  30.15%\n",
      "\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "[]\n",
      "\n",
      "SCIP Status        : problem is solved [optimal solution found]\n",
      "Solving Time (sec) : 21.16\n",
      "Solving Nodes      : 87\n",
      "Primal Bound       : +4.26319019743823e+00 (6 solutions)\n",
      "Dual Bound         : +4.26319019743823e+00\n",
      "Gap                : 0.00 %\n"
     ]
    }
   ],
   "source": [
    "from pyscipopt import Model, quicksum\n",
    "from pyscipopt import SCIP_PARAMSETTING\n",
    "\n",
    "# Example input (replace with your own data)\n",
    "# Initialize the SCIP model\n",
    "model = Model(\"Integer_Programming_Example\")\n",
    "model.setPresolve(SCIP_PARAMSETTING.OFF)\n",
    "model.setHeuristics(SCIP_PARAMSETTING.OFF)\n",
    "\n",
    "# Number of variables\n",
    "num_vars = len(c)\n",
    "\n",
    "# Add integer variables to the model\n",
    "x = [model.addVar(vtype=\"I\", lb=0, name=f\"x_{i}\") for i in range(num_vars)]\n",
    "\n",
    "# Set the objective function: min c^T x\n",
    "model.setObjective(quicksum(c[i] * x[i] for i in range(num_vars)), \"minimize\")\n",
    "\n",
    "# Add constraints: Ax <= b\n",
    "for row_idx in range(A.shape[0]):\n",
    "    model.addCons(quicksum(A[row_idx, j] * x[j] for j in range(num_vars)) <= b[row_idx])\n",
    "\n",
    "model.includeBranchrule(StrongBranchingRule(model), \"NNBranch\", \"Branching using NN predictions\", \n",
    "                        priority=1000000, maxdepth=-1,\n",
    "                        maxbounddist=1.)\n",
    "\n",
    "# Optimize the model\n",
    "model.optimize()\n",
    "\n",
    "# Output results\n",
    "# print(\"\\nOptimal Solution:\")\n",
    "# for i, var in enumerate(x):\n",
    "#     print(f\"x_{i} = {model.getVal(var)}\")\n",
    "\n",
    "# print(f\"Optimal Objective Value: {model.getObjVal()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb0ba62-844e-4731-834f-7fb443d1f025",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
