{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch.distributions import Categorical, kl\n",
    "from d2l.torch import Animator\n",
    "\n",
    "from net import Net\n",
    "from aco import ACO\n",
    "from utils import gen_pyg_data, load_test_dataset\n",
    "\n",
    "torch.manual_seed(12345)\n",
    "\n",
    "EPS = 1e-10\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer_instance(model, instance, n_ants, t_aco_diff):\n",
    "    distances, adj_mat, prec_cons = instance\n",
    "    if model:\n",
    "        model.eval()\n",
    "        pyg_data = gen_pyg_data(distances, adj_mat, device)\n",
    "        heu_vec = model(pyg_data)\n",
    "        heu_mat = model.reshape(pyg_data, heu_vec) + EPS\n",
    "        aco = ACO(\n",
    "            distances=distances,\n",
    "            prec_cons=prec_cons,\n",
    "            n_ants=n_ants,\n",
    "            heuristic=heu_mat,\n",
    "            device=device\n",
    "            )\n",
    "    else:\n",
    "        aco = ACO(\n",
    "            distances=distances,\n",
    "            prec_cons=prec_cons,\n",
    "            n_ants=n_ants,\n",
    "            device=device\n",
    "            )\n",
    "    results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    for i, t in enumerate(t_aco_diff):\n",
    "        best_cost = aco.run(t)\n",
    "        results[i] = best_cost\n",
    "    return results\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(dataset, model, n_ants, t_aco):\n",
    "    _t_aco = [0] + t_aco\n",
    "    t_aco_diff = [_t_aco[i+1]-_t_aco[i] for i in range(len(_t_aco)-1)]\n",
    "    sum_results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    start = time.time()\n",
    "    for instance in dataset:\n",
    "        results = infer_instance(model, instance, n_ants, t_aco_diff)\n",
    "        sum_results += results\n",
    "    end = time.time()\n",
    "    \n",
    "    return sum_results / len(dataset), end-start"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on SOP20"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MetaACO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  148.10551619529724\n",
      "T=1, average cost is 13.649788856506348.\n",
      "T=10, average cost is 13.322684288024902.\n",
      "T=20, average cost is 13.289615631103516.\n",
      "T=30, average cost is 13.270099639892578.\n",
      "T=40, average cost is 13.267578125.\n",
      "T=50, average cost is 13.252917289733887.\n",
      "T=100, average cost is 13.241530418395996.\n"
     ]
    }
   ],
   "source": [
    "n_ants = 20\n",
    "n_node = 20\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "net = Net().to(device)\n",
    "net.load_state_dict(torch.load(f'../pretrained/sop/sop{n_node}.pt', map_location=device))\n",
    "avg_aco_best, duration = test(test_list, net, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "for i, t in enumerate(t_aco):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_aco_best[i]))    "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ACO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  146.17993807792664\n",
      "T=1, average cost is 15.694759368896484.\n",
      "T=10, average cost is 14.777204513549805.\n",
      "T=20, average cost is 14.582863807678223.\n",
      "T=30, average cost is 14.48759651184082.\n",
      "T=40, average cost is 14.444717407226562.\n",
      "T=50, average cost is 14.414101600646973.\n",
      "T=100, average cost is 14.376885414123535.\n"
     ]
    }
   ],
   "source": [
    "n_ants = 20\n",
    "n_node = 20\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "avg_aco_best, duration = test(test_list, None, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "for i, t in enumerate(t_aco):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_aco_best[i]))    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2f394aca7ca06fed1e6064aef884364492d7cdda3614a461e02e6407fc40ba69"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
