{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "\n",
    "from rcpsp_inst import load_dataset\n",
    "from net import Net\n",
    "from aco import ACO_RCPSP\n",
    "\n",
    "# Evaluation configuration\n",
    "EPS = 1e-10  # added to the model's heuristic matrix (presumably keeps entries strictly positive)\n",
    "device = 'cpu'\n",
    "acoparam = dict(elitist=True, min_max=True)  # extra keyword arguments forwarded to every ACO_RCPSP\n",
    "\n",
    "# Seed torch so repeated Restart & Run All executions give comparable results.\n",
    "# NOTE(review): if ACO_RCPSP also samples from numpy / random, seed those too -- TODO confirm.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Not referenced by any cell in this evaluation notebook.\n",
    "lr = 1e-3\n",
    "ALPHA = 0.05\n",
    "T = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def infer_instance(model, rcpsp, n_ants, t_aco_diff):\n",
    "    \"\"\"Run ACO on one RCPSP instance and record the best cost after each iteration segment.\n",
    "\n",
    "    Args:\n",
    "        model: trained network that supplies a heuristic matrix, or None for plain ACO\n",
    "            (ACO_RCPSP is then built without pheromone/heuristic arguments).\n",
    "        rcpsp: the instance; must provide `to_pyg_data()` when a model is given.\n",
    "        n_ants: number of ants passed to ACO_RCPSP.\n",
    "        t_aco_diff: iterations to run per segment. One ACO object is reused for all\n",
    "            segments, so entry i is intended to reflect sum(t_aco_diff[:i + 1]) iterations\n",
    "            in total (assumes `aco.run` resumes from its previous state -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        1-D float tensor of length len(t_aco_diff) holding `aco.run(t).cost` per segment.\n",
    "    \"\"\"\n",
    "    aco_kwargs = dict(n_ants=n_ants, device=device, **acoparam)\n",
    "    # Compare against None explicitly: `if model:` relies on object truthiness, which is\n",
    "    # False for any module that defines __len__ and happens to be empty.\n",
    "    if model is not None:\n",
    "        model.eval()\n",
    "        pyg_data = rcpsp.to_pyg_data()\n",
    "        # The pheromone output is unused: ACO_RCPSP gets pheromone=None and only the learned heuristic.\n",
    "        _phe_vec, heu_vec = model(pyg_data, require_phe=True, require_heu=True)\n",
    "        heu_mat = model.reshape(pyg_data, heu_vec) + EPS\n",
    "        aco_kwargs.update(pheromone=None, heuristic=heu_mat)\n",
    "    aco = ACO_RCPSP(rcpsp, **aco_kwargs)\n",
    "\n",
    "    results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    for i, t in enumerate(t_aco_diff):\n",
    "        results[i] = aco.run(t).cost\n",
    "    return results\n",
    "        \n",
    "    \n",
    "@torch.no_grad()\n",
    "def test(dataset, model, n_ants, t_aco):\n",
    "    \"\"\"Average the best ACO cost over `dataset` at several cumulative iteration budgets.\n",
    "\n",
    "    Args:\n",
    "        dataset: sized iterable of RCPSP instances; must be non-empty.\n",
    "        model: trained network, or None for the plain-ACO baseline.\n",
    "        n_ants: number of ants per ACO instance.\n",
    "        t_aco: non-decreasing, non-negative cumulative iteration counts (any sequence).\n",
    "\n",
    "    Returns:\n",
    "        (mean_costs, seconds): tensor with the mean best cost for each budget in `t_aco`,\n",
    "        and the elapsed time of the evaluation loop.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataset` is empty or `t_aco` decreases / starts below 0.\n",
    "    \"\"\"\n",
    "    if len(dataset) == 0:\n",
    "        raise ValueError('test() needs a non-empty dataset')\n",
    "    # Convert cumulative budgets into per-segment increments so each instance runs a\n",
    "    # single ACO object and its best cost is sampled between segments.\n",
    "    budgets = [0, *t_aco]  # accepts tuples/ranges, not only lists\n",
    "    t_aco_diff = [b - a for a, b in zip(budgets, budgets[1:])]\n",
    "    if any(d < 0 for d in t_aco_diff):\n",
    "        raise ValueError(f't_aco must be non-decreasing and non-negative, got {budgets[1:]}')\n",
    "\n",
    "    sum_results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    start = time.perf_counter()  # monotonic, high-resolution timer meant for durations\n",
    "    for instance in dataset:\n",
    "        sum_results += infer_instance(model, instance, n_ants, t_aco_diff)\n",
    "    elapsed = time.perf_counter() - start\n",
    "\n",
    "    return sum_results / len(dataset), elapsed"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on j30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# j30 benchmark: evaluation settings and test split\n",
    "n_node = 30  # instance family size, used in the data directory name\n",
    "n_ants = 20\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]  # cumulative ACO iteration budgets to report\n",
    "data_dir = f\"../data/rcpsp/j{n_node}rcp\"\n",
    "_, test_list = load_dataset(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  132.94564294815063\n",
      "T=1, average cost is 61.06999969482422.\n",
      "T=10, average cost is 59.02000045776367.\n",
      "T=20, average cost is 58.599998474121094.\n",
      "T=30, average cost is 58.5.\n",
      "T=40, average cost is 58.15999984741211.\n",
      "T=50, average cost is 58.02000045776367.\n",
      "T=100, average cost is 57.75.\n"
     ]
    }
   ],
   "source": [
    "# MetaACO: ACO guided by the learned heuristic\n",
    "# NOTE: torch.load unpickles the checkpoint -- only load trusted files\n",
    "# (newer torch versions also accept weights_only=True for a safer load).\n",
    "ckpt_path = f'../pretrained/rcpsp/rcpsp{n_node}-5.pt'\n",
    "net_rcpsp = Net().to(device)\n",
    "net_rcpsp.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
    "avg_aco_best, duration = test(test_list, net_rcpsp, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "for t, cost in zip(t_aco, avg_aco_best.tolist()):\n",
    "    print(\"T={}, average cost is {}.\".format(t, cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  132.5256588459015\n",
      "T=1, average cost is 63.880001068115234.\n",
      "T=10, average cost is 59.959999084472656.\n",
      "T=20, average cost is 59.599998474121094.\n",
      "T=30, average cost is 59.25.\n",
      "T=40, average cost is 59.060001373291016.\n",
      "T=50, average cost is 58.970001220703125.\n",
      "T=100, average cost is 58.77000045776367.\n"
     ]
    }
   ],
   "source": [
    "# Plain ACO baseline (model=None)\n",
    "avg_aco_best, duration = test(test_list, model=None, n_ants=n_ants, t_aco=t_aco)\n",
    "print('total duration: ', duration)\n",
    "for t, cost in zip(t_aco, avg_aco_best.tolist()):\n",
    "    print(\"T={}, average cost is {}.\".format(t, cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on j60"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# j60 benchmark: evaluation settings and test split\n",
    "n_node = 60  # instance family size, used in the data directory name\n",
    "n_ants = 20\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]  # cumulative ACO iteration budgets to report\n",
    "data_dir = f\"../data/rcpsp/j{n_node}rcp\"\n",
    "_, test_list = load_dataset(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  261.567898273468\n",
      "T=1, average cost is 95.30000305175781.\n",
      "T=10, average cost is 90.62999725341797.\n",
      "T=20, average cost is 89.33999633789062.\n",
      "T=30, average cost is 88.69000244140625.\n",
      "T=40, average cost is 88.41000366210938.\n",
      "T=50, average cost is 88.0999984741211.\n",
      "T=100, average cost is 87.16999816894531.\n"
     ]
    }
   ],
   "source": [
    "# MetaACO: ACO guided by the learned heuristic\n",
    "# NOTE: torch.load unpickles the checkpoint -- only load trusted files\n",
    "# (newer torch versions also accept weights_only=True for a safer load).\n",
    "ckpt_path = f'../pretrained/rcpsp/rcpsp{n_node}-5.pt'\n",
    "net_rcpsp = Net().to(device)\n",
    "net_rcpsp.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
    "avg_aco_best, duration = test(test_list, net_rcpsp, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "for t, cost in zip(t_aco, avg_aco_best.tolist()):\n",
    "    print(\"T={}, average cost is {}.\".format(t, cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  261.57174706459045\n",
      "T=1, average cost is 99.33999633789062.\n",
      "T=10, average cost is 93.80000305175781.\n",
      "T=20, average cost is 92.80999755859375.\n",
      "T=30, average cost is 92.16999816894531.\n",
      "T=40, average cost is 91.86000061035156.\n",
      "T=50, average cost is 91.66000366210938.\n",
      "T=100, average cost is 90.77999877929688.\n"
     ]
    }
   ],
   "source": [
    "# Plain ACO baseline (model=None)\n",
    "avg_aco_best, duration = test(test_list, model=None, n_ants=n_ants, t_aco=t_aco)\n",
    "print('total duration: ', duration)\n",
    "for t, cost in zip(t_aco, avg_aco_best.tolist()):\n",
    "    print(\"T={}, average cost is {}.\".format(t, cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on j120"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# j120 benchmark: evaluation settings and test split\n",
    "n_node = 120  # instance family size, used in the data directory name\n",
    "n_ants = 20\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]  # cumulative ACO iteration budgets to report\n",
    "data_dir = f\"../data/rcpsp/j{n_node}rcp\"\n",
    "_, test_list = load_dataset(data_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  530.4030299186707\n",
      "T=1, average cost is 184.35000610351562.\n",
      "T=10, average cost is 177.24000549316406.\n",
      "T=20, average cost is 175.5.\n",
      "T=30, average cost is 174.55999755859375.\n",
      "T=40, average cost is 174.0500030517578.\n",
      "T=50, average cost is 173.5800018310547.\n",
      "T=100, average cost is 172.0800018310547.\n"
     ]
    }
   ],
   "source": [
    "# MetaACO: ACO guided by the learned heuristic\n",
    "# NOTE: torch.load unpickles the checkpoint -- only load trusted files\n",
    "# (newer torch versions also accept weights_only=True for a safer load).\n",
    "ckpt_path = f'../pretrained/rcpsp/rcpsp{n_node}-5.pt'\n",
    "net_rcpsp = Net().to(device)\n",
    "net_rcpsp.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
    "avg_aco_best, duration = test(test_list, net_rcpsp, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "for t, cost in zip(t_aco, avg_aco_best.tolist()):\n",
    "    print(\"T={}, average cost is {}.\".format(t, cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  530.2516784667969\n",
      "T=1, average cost is 193.77999877929688.\n",
      "T=10, average cost is 185.6199951171875.\n",
      "T=20, average cost is 183.3800048828125.\n",
      "T=30, average cost is 182.52999877929688.\n",
      "T=40, average cost is 181.75.\n",
      "T=50, average cost is 181.2899932861328.\n",
      "T=100, average cost is 179.64999389648438.\n"
     ]
    }
   ],
   "source": [
    "# Plain ACO baseline (model=None)\n",
    "avg_aco_best, duration = test(test_list, model=None, n_ants=n_ants, t_aco=t_aco)\n",
    "print('total duration: ', duration)\n",
    "for t, cost in zip(t_aco, avg_aco_best.tolist()):\n",
    "    print(\"T={}, average cost is {}.\".format(t, cost))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a93e3d9460341b3566123144586be69108c80018542c7977bec35f4a26a80b82"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
