{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch.distributions import Categorical, kl\n",
    "from d2l.torch import Animator\n",
    "\n",
    "from net import Net\n",
    "from aco import ACO\n",
    "from utils import gen_pyg_data, load_test_dataset\n",
    "\n",
    "torch.manual_seed(12345)\n",
    "\n",
    "EPS = 1e-10\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def infer_instance(model, pyg_data, distances, n_ants, t_aco_diff, k_sparse=None):\n",
    "    if model:\n",
    "        model.eval()\n",
    "        heu_vec = model(pyg_data)\n",
    "        heu_mat = model.reshape(pyg_data, heu_vec) + EPS\n",
    "    \n",
    "        aco = ACO(\n",
    "        n_ants=n_ants,\n",
    "        heuristic=heu_mat,\n",
    "        distances=distances,\n",
    "        device=device\n",
    "        )\n",
    "    \n",
    "    else:\n",
    "        aco = ACO(\n",
    "        n_ants=n_ants,\n",
    "        distances=distances,\n",
    "        device=device\n",
    "        )\n",
    "        if k_sparse:\n",
    "            aco.sparsify(k_sparse)\n",
    "        \n",
    "    results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    for i, t in enumerate(t_aco_diff):\n",
    "        best_cost = aco.run(t)\n",
    "        results[i] = best_cost\n",
    "    return results\n",
    "        \n",
    "    \n",
    "@torch.no_grad()\n",
    "def test(dataset, model, n_ants, t_aco, k_sparse=None):\n",
    "    _t_aco = [0] + t_aco\n",
    "    t_aco_diff = [_t_aco[i+1]-_t_aco[i] for i in range(len(_t_aco)-1)]\n",
    "    sum_results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    start = time.time()\n",
    "    for pyg_data, distances in dataset:\n",
    "        results = infer_instance(model, pyg_data, distances, n_ants, t_aco_diff, k_sparse)\n",
    "        sum_results += results\n",
    "    end = time.time()\n",
    "    \n",
    "    return sum_results / len(dataset), end-start"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on TSP20"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MetaACO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  153.85154390335083\n",
      "T=1, average cost is 3.9320340156555176.\n",
      "T=10, average cost is 3.8233907222747803.\n",
      "T=20, average cost is 3.816669225692749.\n",
      "T=30, average cost is 3.8122851848602295.\n",
      "T=40, average cost is 3.8114993572235107.\n",
      "T=50, average cost is 3.8111512660980225.\n",
      "T=100, average cost is 3.809321165084839.\n"
     ]
    }
   ],
   "source": [
    "n_ants = 20\n",
    "n_node = 20\n",
    "k_sparse = 10\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, k_sparse, device)\n",
    "net_tsp = Net().to(device)\n",
    "net_tsp.load_state_dict(torch.load(f'../pretrained/tsp/tsp{n_node}.pt', map_location=device))\n",
    "avg_aco_best, duration = test(test_list, net_tsp, n_ants, t_aco, k_sparse)\n",
    "print('total duration: ', duration)\n",
    "for i, t in enumerate(t_aco):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_aco_best[i]))    "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ACO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  146.5802402496338\n",
      "T=1, average cost is 5.254907608032227.\n",
      "T=10, average cost is 4.2066216468811035.\n",
      "T=20, average cost is 4.0416178703308105.\n",
      "T=30, average cost is 3.9487416744232178.\n",
      "T=40, average cost is 3.911342144012451.\n",
      "T=50, average cost is 3.8893744945526123.\n",
      "T=100, average cost is 3.8348400592803955.\n"
     ]
    }
   ],
   "source": [
    "n_ants = 20\n",
    "n_node = 20\n",
    "k_sparse = 10\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, k_sparse, device)\n",
    "avg_aco_best, duration = test(test_list, None, n_ants, t_aco, k_sparse)\n",
    "print('total duration: ', duration)\n",
    "for i, t in enumerate(t_aco):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_aco_best[i]))    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2f394aca7ca06fed1e6064aef884364492d7cdda3614a461e02e6407fc40ba69"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
