{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "47986890",
   "metadata": {},
   "source": [
    "## Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6e29e80c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:32:27.675741Z",
     "start_time": "2024-02-02T03:32:18.098770Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests>=2.19.0->dgl==1.0.1+cu117) (2021.5.30)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
      "Requirement already satisfied: networkx==3.1 in /opt/conda/lib/python3.8/site-packages (3.1)\n"
     ]
    }
   ],
   "source": [
    "! pip install dgl==1.0.1+cu117 -f https://data.dgl.ai/wheels/cu117/repo.html | tail -n 1\n",
    "! pip install networkx==3.1 | tail -n 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "09810767",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:32:31.469108Z",
     "start_time": "2024-02-02T03:32:27.680407Z"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import os\n",
    "import copy\n",
    "from collections import OrderedDict, defaultdict\n",
    "from itertools import chain, islice, combinations\n",
    "from time import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "import dgl\n",
    "from dgl.nn.pytorch import GraphConv\n",
    "from dgl.nn.pytorch import SAGEConv\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import sklearn \n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "from networkx.algorithms.approximation.clique import maximum_independent_set as mis\n",
    "import networkx as nx\n",
    "\n",
    "from main import utils\n",
    "from main import gnn\n",
    "from main import instance\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "212a8735",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:32:31.482121Z",
     "start_time": "2024-02-02T03:32:31.474548Z"
    }
   },
   "outputs": [],
   "source": [
    "# fix seed\n",
    "SEED=0\n",
    "utils.fix_seed(SEED)\n",
    "torch_type=torch.float32"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78a34a51",
   "metadata": {},
   "source": [
    "## CRA-PI-GNN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac94ed4f",
   "metadata": {},
   "source": [
    "### Step 1: Load Problem and Set Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7b870913",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:32:43.735374Z",
     "start_time": "2024-02-02T03:32:31.484280Z"
    }
   },
   "outputs": [],
   "source": [
    "# Graph parameters\n",
    "N, d, p, graph_type = 5000, 20, None, \"reg\"\n",
    "nx_graph = nx.random_regular_graph(d=d, n=N, seed=SEED)\n",
    "dgl_graph = dgl.from_networkx(nx_graph).to(device)\n",
    "Q_mat = utils.qubo_dict_to_torch(nx_graph, instance.gen_q_dict_mis_sym(nx_graph,penalty=2)).to(device)\n",
    "\n",
    "# GNN Architecture\n",
    "in_feats = int(dgl_graph.number_of_nodes()**(0.5))\n",
    "hidden_size=int(in_feats)\n",
    "num_class=1\n",
    "dropout=0.0\n",
    "model=gnn.GCN_dev(in_feats, \n",
    "              hidden_size, \n",
    "              num_class, \n",
    "              dropout, \n",
    "              device).to(device)\n",
    "embedding= nn.Embedding(dgl_graph.number_of_nodes(), \n",
    "                        in_feats\n",
    "                       ).type(torch_type).to(device)\n",
    "\n",
    "# Learning Parameters\n",
    "num_epoch=int(1e+5)\n",
    "lr=1e-4 \n",
    "weight_decay=1e-2\n",
    "tol=1e-4\n",
    "patience=1000\n",
    "vari_param=0\n",
    "init_reg_param=-20\n",
    "annealing_rate=1e-3\n",
    "check_interval=1000\n",
    "curve_rate=2                     "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00b6d2f7",
   "metadata": {},
   "source": [
    "### Step 2: Define Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d96569b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:32:43.753330Z",
     "start_time": "2024-02-02T03:32:43.746975Z"
    }
   },
   "outputs": [],
   "source": [
    "def loss(probs, reg_param, curve_rate=2):    \n",
    "    probs_ = torch.unsqueeze(probs, 1)\n",
    "    # cost function\n",
    "    cost = (probs_.T @ Q_mat @ probs_).squeeze()\n",
    "    # annealed term\n",
    "    reg_term = torch.sum(1-(2*probs_-1)**curve_rate)\n",
    "    return cost+reg_param*reg_term , cost, reg_term"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc1d6ad5",
   "metadata": {},
   "source": [
    "### Step 3: Run PI-GNN solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0628962e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:35:18.514366Z",
     "start_time": "2024-02-02T03:32:43.758041Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "【START】\n",
      "【TRAIN EPOCH 0】LOSS 38718.742 COST 38718.742 REG 4937.979 PARAM 0.000\n",
      "【TRAIN EPOCH 1000】LOSS 284.547 COST 284.547 REG 733.183 PARAM 0.000\n",
      "【TRAIN EPOCH 2000】LOSS 42.714 COST 42.714 REG 291.718 PARAM 0.000\n",
      "【TRAIN EPOCH 3000】LOSS 14.270 COST 14.270 REG 169.905 PARAM 0.000\n",
      "【TRAIN EPOCH 4000】LOSS 6.171 COST 6.171 REG 112.166 PARAM 0.000\n",
      "【TRAIN EPOCH 5000】LOSS 3.005 COST 3.005 REG 78.453 PARAM 0.000\n",
      "【TRAIN EPOCH 6000】LOSS 1.558 COST 1.558 REG 56.588 PARAM 0.000\n",
      "【TRAIN EPOCH 7000】LOSS 0.838 COST 0.838 REG 41.554 PARAM 0.000\n",
      "【TRAIN EPOCH 8000】LOSS 0.461 COST 0.461 REG 30.850 PARAM 0.000\n",
      "【TRAIN EPOCH 9000】LOSS 0.257 COST 0.257 REG 23.068 PARAM 0.000\n",
      "【TRAIN EPOCH 10000】LOSS 0.145 COST 0.145 REG 17.332 PARAM 0.000\n",
      "【TRAIN EPOCH 11000】LOSS 0.082 COST 0.082 REG 13.067 PARAM 0.000\n",
      "【TRAIN EPOCH 12000】LOSS 0.047 COST 0.047 REG 9.876 PARAM 0.000\n",
      "【TRAIN EPOCH 13000】LOSS 0.027 COST 0.027 REG 7.480 PARAM 0.000\n",
      "【TRAIN EPOCH 14000】LOSS 0.015 COST 0.015 REG 5.674 PARAM 0.000\n",
      "【TRAIN EPOCH 15000】LOSS 0.009 COST 0.009 REG 4.309 PARAM 0.000\n",
      "【TRAIN EPOCH 16000】LOSS 0.005 COST 0.005 REG 3.277 PARAM 0.000\n",
      "【TRAIN EPOCH 17000】LOSS 0.003 COST 0.003 REG 2.494 PARAM 0.000\n",
      "【TRAIN EPOCH 18000】LOSS 0.002 COST 0.002 REG 1.900 PARAM 0.000\n",
      "【TRAIN EPOCH 19000】LOSS 0.001 COST 0.001 REG 1.449 PARAM 0.000\n",
      "【TRAIN EPOCH 20000】LOSS 0.001 COST 0.001 REG 1.105 PARAM 0.000\n",
      "【TRAIN EPOCH 21000】LOSS 0.000 COST 0.000 REG 0.844 PARAM 0.000\n",
      "【TRAIN EPOCH 22000】LOSS 0.000 COST 0.000 REG 0.645 PARAM 0.000\n",
      "【TRAIN EPOCH 23000】LOSS 0.000 COST 0.000 REG 0.493 PARAM 0.000\n",
      "【TRAIN EPOCH 24000】LOSS 0.000 COST 0.000 REG 0.377 PARAM 0.000\n",
      "【TRAIN EPOCH 25000】LOSS 0.000 COST 0.000 REG 0.289 PARAM 0.000\n",
      "Early Stopping 25297\n"
     ]
    }
   ],
   "source": [
    "model, bit_string_PI, cost, reg_term, runtime = gnn.fit_model(model,\n",
    "                                                         dgl_graph, \n",
    "                                                         embedding,\n",
    "                                                         loss,\n",
    "                                                         num_epoch=num_epoch,\n",
    "                                                         lr=lr, \n",
    "                                                         weight_decay=weight_decay,\n",
    "                                                         tol=tol, \n",
    "                                                         patience=patience, \n",
    "                                                         device=device,\n",
    "                                                         annealing=False,\n",
    "                                                         init_reg_param=0,\n",
    "                                                         annealing_rate=0,\n",
    "                                                         check_interval=check_interval,\n",
    "                                                         curve_rate=curve_rate\n",
    "                                                        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3c4317d",
   "metadata": {},
   "source": [
    "### Step 4: Run CRA-PI-GNN solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "977dfc07",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:37:56.390867Z",
     "start_time": "2024-02-02T03:35:18.516424Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "【START】\n",
      "【TRAIN EPOCH 0】LOSS -5.333 COST 0.000 REG 0.267 PARAM -20.000\n",
      "【TRAIN EPOCH 1000】LOSS -62621.617 COST 21374.662 REG 4420.857 PARAM -19.000\n",
      "【TRAIN EPOCH 2000】LOSS -58284.879 COST 20586.074 REG 4381.720 PARAM -18.000\n",
      "【TRAIN EPOCH 3000】LOSS -53962.191 COST 19744.613 REG 4335.694 PARAM -17.000\n",
      "【TRAIN EPOCH 4000】LOSS -49676.172 COST 18860.611 REG 4283.549 PARAM -16.000\n",
      "【TRAIN EPOCH 5000】LOSS -45436.805 COST 17932.307 REG 4224.607 PARAM -15.000\n",
      "【TRAIN EPOCH 6000】LOSS -41253.844 COST 16956.037 REG 4157.849 PARAM -14.000\n",
      "【TRAIN EPOCH 7000】LOSS -37137.938 COST 15928.104 REG 4082.003 PARAM -13.000\n",
      "【TRAIN EPOCH 8000】LOSS -33100.734 COST 14844.926 REG 3995.472 PARAM -12.000\n",
      "【TRAIN EPOCH 9000】LOSS -29155.059 COST 13703.106 REG 3896.197 PARAM -11.000\n",
      "【TRAIN EPOCH 10000】LOSS -25315.529 COST 12499.744 REG 3781.527 PARAM -10.000\n",
      "【TRAIN EPOCH 11000】LOSS -21599.461 COST 11233.095 REG 3648.062 PARAM -9.000\n",
      "【TRAIN EPOCH 12000】LOSS -18027.859 COST 9903.521 REG 3491.423 PARAM -8.000\n",
      "【TRAIN EPOCH 13000】LOSS -14626.691 COST 8514.641 REG 3305.905 PARAM -7.000\n",
      "【TRAIN EPOCH 14000】LOSS -11428.475 COST 7075.174 REG 3083.941 PARAM -6.000\n",
      "【TRAIN EPOCH 15000】LOSS -8474.523 COST 5602.078 REG 2815.320 PARAM -5.000\n",
      "【TRAIN EPOCH 16000】LOSS -5818.154 COST 4125.034 REG 2485.797 PARAM -4.000\n",
      "【TRAIN EPOCH 17000】LOSS -3529.412 COST 2698.770 REG 2076.061 PARAM -3.000\n",
      "【TRAIN EPOCH 18000】LOSS -2195.724 COST 674.407 REG 1435.065 PARAM -2.000\n",
      "【TRAIN EPOCH 19000】LOSS -1148.299 COST -133.128 REG 1015.171 PARAM -1.000\n",
      "【TRAIN EPOCH 20000】LOSS -839.166 COST -839.166 REG 19.068 PARAM 0.000\n",
      "【TRAIN EPOCH 21000】LOSS -850.262 COST -851.718 REG 1.456 PARAM 1.000\n",
      "【TRAIN EPOCH 22000】LOSS -852.000 COST -852.677 REG 0.338 PARAM 2.000\n",
      "【TRAIN EPOCH 23000】LOSS -852.556 COST -852.889 REG 0.111 PARAM 3.000\n",
      "【TRAIN EPOCH 24000】LOSS -852.784 COST -852.955 REG 0.043 PARAM 4.000\n",
      "【TRAIN EPOCH 25000】LOSS -852.889 COST -852.980 REG 0.018 PARAM 5.000\n",
      "【TRAIN EPOCH 26000】LOSS -852.941 COST -852.991 REG 0.008 PARAM 6.000\n",
      "【TRAIN EPOCH 27000】LOSS -852.968 COST -852.996 REG 0.004 PARAM 7.000\n",
      "【TRAIN EPOCH 28000】LOSS -852.982 COST -852.998 REG 0.002 PARAM 8.000\n",
      "【TRAIN EPOCH 29000】LOSS -852.990 COST -852.999 REG 0.001 PARAM 9.000\n",
      "Early Stopping 29897\n"
     ]
    }
   ],
   "source": [
    "model, bit_string_CRA, cost, reg_term, runtime = gnn.fit_model(model,\n",
    "                                                         dgl_graph, \n",
    "                                                         embedding,\n",
    "                                                         loss,\n",
    "                                                         num_epoch=num_epoch,\n",
    "                                                         lr=lr, \n",
    "                                                         weight_decay=weight_decay,\n",
    "                                                         tol=tol, \n",
    "                                                         patience=patience, \n",
    "                                                         device=device,\n",
    "                                                         annealing=True,\n",
    "                                                         init_reg_param=init_reg_param,\n",
    "                                                         annealing_rate=annealing_rate,\n",
    "                                                         check_interval=check_interval,\n",
    "                                                         curve_rate=curve_rate\n",
    "                                                        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "14c0ba82",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-02T03:37:57.760492Z",
     "start_time": "2024-02-02T03:37:56.394601Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Independent set size: (CRA) 853, (PI) 0\n"
     ]
    }
   ],
   "source": [
    "size_mis_CRA, _, number_violation = utils.postprocess_gnn_mis(bit_string_CRA, nx_graph)\n",
    "size_mis_PI, _, number_violation = utils.postprocess_gnn_mis(bit_string_PI, nx_graph)\n",
    "print(f\"Independent set size: (CRA) {size_mis_CRA.item()}, (PI) {size_mis_PI.item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ed0964e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "009de8e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
