{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a95d775b-ad73-4083-bcf8-7a51fb948957",
   "metadata": {},
   "source": [
    "### Playbook\n",
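    "#### This notebook is an interactive version of the training script `forge/train.py`.\n",
    "\n",
    "It runs the full pipeline end to end: data generation, unsupervised pre-training, triplet data collection, supervised fine-tuning for warm-start prediction, integrality-gap data collection and training, and finally inference on a test instance."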
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcac3385-cc57-4934-be87-c8dddbb3cc5e",
   "metadata": {},
   "source": [
    "##### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "509b2fd0-8012-47ed-9e35-d9b437b74e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the parent directory ('../', relative to the kernel's working directory)\n",
    "# importable so the project packages `data` and `forge` below resolve.\n",
    "# Guarded so re-running this cell does not append a duplicate sys.path entry.\n",
    "if '../' not in sys.path:\n",
    "    sys.path.append('../')\n",
    "\n",
    "from data.processor import DataProcessor\n",
    "from data.collector import Collector\n",
    "from forge import Forge"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "727d0535-6c0f-4bbc-b6ae-d4a5e06a781b",
   "metadata": {},
   "source": [
    "##### Data Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb67b157-cc53-426b-b27f-30e6e25369b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process the CA-easy training instances and write the result to mips_to_dgl.pkl\n",
    "# (return_dict=False, so nothing is kept in memory here), then build the\n",
    "# training list from that same pickle.\n",
    "dp = DataProcessor()\n",
    "dp.get_instance_dict(\n",
    "    input_path='./data/train_instances/CA-easy',\n",
    "    output_file='./data/intermediate_files/mips_to_dgl.pkl',\n",
    "    perturb=[0.05, 0.01],\n",
    "    return_dict=False,\n",
    ")\n",
    "\n",
    "train_list = dp.get_train_list(['./data/intermediate_files/mips_to_dgl.pkl'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e46dc1ef-7cd1-4f1a-b64f-33b271d0aab5",
   "metadata": {},
   "source": [
    "##### Unsupervised Pre-Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f036889b-f302-46b9-b400-d0efea47b103",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unsupervised pre-training on the first 15 entries of train_list\n",
    "# (1 epoch, 10 steps per instance). The checkpoint written to\n",
    "# model_save_path is the one loaded by the later stages of this notebook.\n",
    "model = Forge()\n",
    "model.train_unsupervised(\n",
    "    model_save_path='./models/unsupervised_model.pkl',\n",
    "    train_list=train_list[:15],\n",
    "    epochs=1,\n",
    "    steps_per_instance=10,\n",
    "    lr=1e-4,\n",
    "    log_path='./data/log/unsupervised_train_log.pkl',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa6489e7-de74-4bce-8ce5-df75bbcc851e",
   "metadata": {},
   "source": [
    "##### Triplet Data Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d92053-eef4-46bd-805c-7c79c7c413bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect triplets for the CA-easy instances using the unsupervised checkpoint\n",
    "# from the previous stage; output goes to mips_to_triplet.pkl.\n",
    "dc = Collector()\n",
    "dc.get_triplets(\n",
    "    input_path='./data/train_instances/CA-easy/',\n",
    "    gnn_model_path='./models/unsupervised_model.pkl',\n",
    "    output_file='./data/intermediate_files/mips_to_triplet.pkl',\n",
    "    return_dict=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb1aee6f-8ccc-4b5b-bd90-815ce976bdff",
   "metadata": {},
   "source": [
    "##### Supervised Fine-Tuning: Warm-Start Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e258949-1b49-47c8-b746-718025c9865f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune with the probability head enabled, starting from the unsupervised\n",
    "# checkpoint and using the triplets collected above.\n",
    "model = Forge(prob_head=True)\n",
    "model.train_triplets(\n",
    "    pretrained_path='./models/unsupervised_model.pkl',\n",
    "    model_save_path='./models/warm_start_model.pkl',\n",
    "    mips_to_triplet=None,  # not passed in memory; presumably read from mips_to_triplet_path\n",
    "    mips_to_triplet_path='./data/intermediate_files/mips_to_triplet.pkl',\n",
    "    epochs=1,\n",
    "    steps_per_instance=1,\n",
    "    lr=1e-5,\n",
    "    batch_size=1024,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fe3b66c-bb76-4456-a509-e95200b4cb1b",
   "metadata": {},
   "source": [
    "##### Supervised Fine-Tuning: Pseudo-Cut Generation via Integrality Gap Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e45359-1511-488a-87ab-c5e71cf84f8d",
   "metadata": {},
   "source": [
    "##### Integrality Gap Data Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a550cb6-e686-4317-a214-ffc1168abb38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect cut ratios for the CA-easy instances (this call takes no model path)\n",
    "# and write them to mips_to_gaps.pkl.\n",
    "dc = Collector()\n",
    "dc.get_cut_ratios(\n",
    "    input_path='./data/train_instances/CA-easy/',\n",
    "    output_file='./data/intermediate_files/mips_to_gaps.pkl',\n",
    "    return_dict=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e137315d-c656-4455-8338-e2758515e679",
   "metadata": {},
   "source": [
    "##### Train Integrality Gap Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60011b7b-5be6-4543-8c4d-39d97b610eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable both the probability and cut heads and continue from the warm-start\n",
    "# checkpoint, training on the integrality-gap data collected above.\n",
    "model = Forge(prob_head=True, cut_head=True)\n",
    "model.train_lp_gaps(\n",
    "    pretrained_path='./models/warm_start_model.pkl',\n",
    "    model_save_path='./models/lp_gap_model.pkl',\n",
    "    mips_to_gaps=None,  # not passed in memory; presumably read from mips_to_gaps_path\n",
    "    mips_to_gaps_path='./data/intermediate_files/mips_to_gaps.pkl',\n",
    "    epochs=1,\n",
    "    steps_per_instance=1,\n",
    "    lr=1e-4,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1172fadd-dfd2-4924-808b-679367ac4e42",
   "metadata": {},
   "source": [
    "##### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f82fd794-48fa-44ac-a775-d65dab8c7728",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the unsupervised checkpoint into a model built without the prob/cut heads.\n",
    "# NOTE(review): the inference cells below call mip_to_lp_cut and mip_to_hint,\n",
    "# which look like they rely on the cut/prob heads; confirm they work with this\n",
    "# head-less model, or load ./models/lp_gap_model.pkl with both heads enabled.\n",
    "model = Forge(prob_head=False, cut_head=False)\n",
    "model.load_model('./models/unsupervised_model.pkl', model_type='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c972c9e6-ad41-439c-8033-83078b187c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed one SC-easy test instance with the loaded model; the result is unpacked\n",
    "# into an instance vector plus constraint and variable embeddings.\n",
    "mip_vec, emb_of_cons, emb_of_vars = model.mip_to_vector(\n",
    "    mip_instance='./data/test/SC-easy/instance_1002.lp',\n",
    "    gnn_model_path=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "996322c4-1623-46d0-b106-e0adf4edfc09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run mip_to_lp_cut on the same test instance; its return value is shown as\n",
    "# the cell output.\n",
    "model.mip_to_lp_cut(\n",
    "    mip_instance_path='./data/test/SC-easy/instance_1002.lp',\n",
    "    gnn_model_path=None,\n",
    "    prob_type='SC',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6920802-ffa4-46d9-b9f5-07fa64310d6a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Compute hints for the same test instance via mip_to_hint; the four results\n",
    "# are unpacked into the one/zero hint lists and their priority variants.\n",
    "hint_ones, hint_zeros, hint_pri_ones, hint_pri_zeros = model.mip_to_hint(\n",
    "    mip_instance_path='./data/test/SC-easy/instance_1002.lp',\n",
    "    gnn_model_path=None,\n",
    "    prob_type='SC',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
