{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "\n",
    "from net import Net\n",
    "from aco import ACO\n",
    "from utils import load_test_dataset\n",
    "\n",
    "torch.manual_seed(12345)\n",
    "\n",
    "EPS = 1e-10\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer_instance(model, instance, n_ants, t_aco_diff):\n",
    "    \"\"\"Run ACO on one instance and record what each ``aco.run`` call returns.\n",
    "\n",
    "    Args:\n",
    "        model: Network used to produce the heuristic matrix, or None to build\n",
    "            the ACO solver without a ``heuristic`` argument.\n",
    "        instance: Tuple ``(pyg_data, due_time, weights, processing_time)``.\n",
    "        n_ants: Forwarded to ``ACO`` as ``n_ants``.\n",
    "        t_aco_diff: Sequence of values passed, one per call, to ``aco.run``.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor on ``device`` whose i-th entry is the value returned by the\n",
    "        i-th ``aco.run`` call.\n",
    "    \"\"\"\n",
    "    pyg_data, due_time, weights, processing_time = instance\n",
    "    # Arguments shared by both the learned-heuristic and the plain ACO setup.\n",
    "    aco_kwargs = dict(\n",
    "        due_time=due_time,\n",
    "        weights=weights,\n",
    "        processing_time=processing_time,\n",
    "        n_ants=n_ants,\n",
    "        device=device,\n",
    "    )\n",
    "    # Check for None explicitly: truth-testing an object may go through\n",
    "    # __len__/__bool__, which is not what \"was a model supplied?\" means.\n",
    "    if model is not None:\n",
    "        model.eval()\n",
    "        heu_vec = model(pyg_data)\n",
    "        # EPS is added to every entry of the reshaped heuristic\n",
    "        # (presumably to keep entries strictly positive -- confirm against ACO).\n",
    "        aco_kwargs['heuristic'] = model.reshape(pyg_data, heu_vec) + EPS\n",
    "    aco = ACO(**aco_kwargs)\n",
    "    results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    for i, t in enumerate(t_aco_diff):\n",
    "        results[i] = aco.run(t)\n",
    "    return results\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(dataset, model, n_ants, t_aco):\n",
    "    \"\"\"Average the staged ACO results over a dataset and time the whole loop.\n",
    "\n",
    "    Args:\n",
    "        dataset: Sized iterable of instances accepted by ``infer_instance``.\n",
    "        model: Network forwarded to ``infer_instance``, or None.\n",
    "        n_ants: Forwarded to ``infer_instance``.\n",
    "        t_aco: List of cumulative checkpoints; the consecutive differences\n",
    "            (starting from 0) are passed to ``infer_instance`` as ``t_aco_diff``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(mean_results, seconds)``: the per-checkpoint results summed\n",
    "        over the dataset and divided by ``len(dataset)``, and the elapsed\n",
    "        time of the evaluation loop in seconds.\n",
    "    \"\"\"\n",
    "    _t_aco = [0] + t_aco\n",
    "    t_aco_diff = [nxt - prev for prev, nxt in zip(_t_aco, _t_aco[1:])]\n",
    "    sum_results = torch.zeros(size=(len(t_aco_diff),), device=device)\n",
    "    # perf_counter is the clock intended for measuring elapsed durations.\n",
    "    start = time.perf_counter()\n",
    "    for instance in dataset:\n",
    "        sum_results += infer_instance(model, instance, n_ants, t_aco_diff)\n",
    "    # CUDA work is queued asynchronously; wait for it to finish so the timing\n",
    "    # covers the computation itself and not only the kernel launches.\n",
    "    timing_device = torch.device(device)\n",
    "    if timing_device.type == 'cuda':\n",
    "        torch.cuda.synchronize(timing_device)\n",
    "    end = time.perf_counter()\n",
    "\n",
    "    return sum_results / len(dataset), end - start"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on SMTWTP50"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MetaACO: ACO run with the heuristic matrix produced by the pretrained `Net`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  223.717276096344\n",
      "T=1, average cost is 0.2536187767982483.\n",
      "T=10, average cost is 0.23605023324489594.\n",
      "T=20, average cost is 0.2332792729139328.\n",
      "T=30, average cost is 0.23270678520202637.\n",
      "T=40, average cost is 0.23149167001247406.\n",
      "T=50, average cost is 0.2314791977405548.\n",
      "T=100, average cost is 0.2310071438550949.\n"
     ]
    }
   ],
   "source": [
    "# Settings for this run; n_node also selects the checkpoint file below.\n",
    "n_ants, n_node = 20, 50\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "\n",
    "# Load the test instances, then the pretrained network for this size.\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "net = Net().to(device)\n",
    "net.load_state_dict(\n",
    "    torch.load(f'../pretrained/smtwtp/smtwtp{n_node}.pt', map_location=device)\n",
    ")\n",
    "\n",
    "avg_aco_best, duration = test(test_list, net, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "# One line per entry of t_aco, paired with the matching averaged cost.\n",
    "for t, avg_cost in zip(t_aco, avg_aco_best):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ACO: baseline run without a network (`model=None`), so no `heuristic` is passed to `ACO`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  220.85296893119812\n",
      "T=1, average cost is 5.915566921234131.\n",
      "T=10, average cost is 3.170252561569214.\n",
      "T=20, average cost is 2.5141656398773193.\n",
      "T=30, average cost is 1.8278592824935913.\n",
      "T=40, average cost is 1.519057273864746.\n",
      "T=50, average cost is 1.346314787864685.\n",
      "T=100, average cost is 0.8086017966270447.\n"
     ]
    }
   ],
   "source": [
    "# Settings for this run.\n",
    "n_ants, n_node = 20, 50\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "\n",
    "# Passing None as the model makes infer_instance build ACO without a heuristic.\n",
    "avg_aco_best, duration = test(test_list, None, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "# One line per entry of t_aco, paired with the matching averaged cost.\n",
    "for t, avg_cost in zip(t_aco, avg_aco_best):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on SMTWTP100"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MetaACO: ACO run with the heuristic matrix produced by the pretrained `Net`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  403.97222924232483\n",
      "T=1, average cost is 0.31544962525367737.\n",
      "T=10, average cost is 0.2729398012161255.\n",
      "T=20, average cost is 0.26319679617881775.\n",
      "T=30, average cost is 0.2606569826602936.\n",
      "T=40, average cost is 0.2584685981273651.\n",
      "T=50, average cost is 0.25688955187797546.\n",
      "T=100, average cost is 0.25367045402526855.\n"
     ]
    }
   ],
   "source": [
    "# Settings for this run; n_node also selects the checkpoint file below.\n",
    "n_ants, n_node = 20, 100\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "\n",
    "# Load the test instances, then the pretrained network for this size.\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "net = Net().to(device)\n",
    "net.load_state_dict(\n",
    "    torch.load(f'../pretrained/smtwtp/smtwtp{n_node}.pt', map_location=device)\n",
    ")\n",
    "\n",
    "avg_aco_best, duration = test(test_list, net, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "# One line per entry of t_aco, paired with the matching averaged cost.\n",
    "for t, avg_cost in zip(t_aco, avg_aco_best):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ACO: baseline run without a network (`model=None`), so no `heuristic` is passed to `ACO`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  384.3187634944916\n",
      "T=1, average cost is 31.881315231323242.\n",
      "T=10, average cost is 20.74480438232422.\n",
      "T=20, average cost is 18.556779861450195.\n",
      "T=30, average cost is 16.647319793701172.\n",
      "T=40, average cost is 14.073697090148926.\n",
      "T=50, average cost is 12.020098686218262.\n",
      "T=100, average cost is 7.338351249694824.\n"
     ]
    }
   ],
   "source": [
    "# Settings for this run.\n",
    "n_ants, n_node = 20, 100\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "\n",
    "# Passing None as the model makes infer_instance build ACO without a heuristic.\n",
    "avg_aco_best, duration = test(test_list, None, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "# One line per entry of t_aco, paired with the matching averaged cost.\n",
    "for t, avg_cost in zip(t_aco, avg_aco_best):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test on SMTWTP500"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MetaACO: ACO run with the heuristic matrix produced by the pretrained `Net`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  1819.7114448547363\n",
      "T=1, average cost is 1.1985440254211426.\n",
      "T=10, average cost is 0.7949917912483215.\n",
      "T=20, average cost is 0.7033312320709229.\n",
      "T=30, average cost is 0.6819817423820496.\n",
      "T=40, average cost is 0.6571812033653259.\n",
      "T=50, average cost is 0.6348550915718079.\n",
      "T=100, average cost is 0.6172884702682495.\n"
     ]
    }
   ],
   "source": [
    "# Settings for this run; n_node also selects the checkpoint file below.\n",
    "n_ants, n_node = 20, 500\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "\n",
    "# Load the test instances, then the pretrained network for this size.\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "net = Net().to(device)\n",
    "net.load_state_dict(\n",
    "    torch.load(f'../pretrained/smtwtp/smtwtp{n_node}.pt', map_location=device)\n",
    ")\n",
    "\n",
    "avg_aco_best, duration = test(test_list, net, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "# One line per entry of t_aco, paired with the matching averaged cost.\n",
    "for t, avg_cost in zip(t_aco, avg_aco_best):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_cost))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ACO: baseline run without a network (`model=None`), so no `heuristic` is passed to `ACO`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total duration:  1804.14484000206\n",
      "T=1, average cost is 1069.06787109375.\n",
      "T=10, average cost is 918.884765625.\n",
      "T=20, average cost is 886.0265502929688.\n",
      "T=30, average cost is 867.5519409179688.\n",
      "T=40, average cost is 852.0975341796875.\n",
      "T=50, average cost is 838.2426147460938.\n",
      "T=100, average cost is 649.79931640625.\n"
     ]
    }
   ],
   "source": [
    "# Settings for this run.\n",
    "n_ants, n_node = 20, 500\n",
    "t_aco = [1, 10, 20, 30, 40, 50, 100]\n",
    "test_list = load_test_dataset(n_node, device)\n",
    "\n",
    "# Passing None as the model makes infer_instance build ACO without a heuristic.\n",
    "avg_aco_best, duration = test(test_list, None, n_ants, t_aco)\n",
    "print('total duration: ', duration)\n",
    "# One line per entry of t_aco, paired with the matching averaged cost.\n",
    "for t, avg_cost in zip(t_aco, avg_aco_best):\n",
    "    print(\"T={}, average cost is {}.\".format(t, avg_cost))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2f394aca7ca06fed1e6064aef884364492d7cdda3614a461e02e6407fc40ba69"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
