{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax. Perhaps you forgot a comma? (functions.py, line 146)",
     "output_type": "error",
     "traceback": [
      "Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
      "\u001b[0m  File \u001b[1;32m~/anaconda3/envs/tsp/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n    exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
      "\u001b[0m  Cell \u001b[1;32mIn[1], line 20\u001b[0m\n    from problems.tsp.problem_tsp import TSP\u001b[0m\n",
      "\u001b[0m  File \u001b[1;32m~/learning-tsp/problems/__init__.py:1\u001b[0m\n    from problems.tsp.problem_tsp import TSP, TSPSL\u001b[0m\n",
      "\u001b[0m  File \u001b[1;32m~/learning-tsp/problems/tsp/problem_tsp.py:9\u001b[0m\n    from problems.tsp.state_tsp import StateTSP\u001b[0m\n",
      "\u001b[0m  File \u001b[1;32m~/learning-tsp/problems/tsp/state_tsp.py:3\u001b[0m\n    from utils.boolmask import mask_long2bool, mask_long_scatter\u001b[0m\n",
      "\u001b[0;36m  File \u001b[0;32m~/learning-tsp/utils/__init__.py:1\u001b[0;36m\n\u001b[0;31m    from .functions import *\u001b[0;36m\n",
      "\u001b[0;36m  File \u001b[0;32m~/learning-tsp/utils/functions.py:146\u001b[0;36m\u001b[0m\n\u001b[0;31m    extra_logging=extra_logging\u001b[0m\n\u001b[0m                  ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax. Perhaps you forgot a comma?\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import argparse\n",
    "import pprint as pp\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "from datetime import timedelta\n",
    "\n",
    "import networkx as nx\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from problems.tsp.problem_tsp import TSP\n",
    "from utils import load_model, move_to\n",
    "from train import set_decode_type\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enumerate GPUs by PCI bus ID and expose only device 0 to this kernel.\n",
    "# CUDA reads these variables when it initialises, so keep this cell before any CUDA use.\n",
    "os.environ.update({\"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\", \"CUDA_VISIBLE_DEVICES\": \"0\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class opts:\n",
    "    \"\"\"Evaluation settings, used as a plain attribute namespace (never instantiated).\"\"\"\n",
    "    dataset_path = \"data/tsp/tsp10-200_concorde.txt\"\n",
    "    batch_size = 16\n",
    "\n",
    "    # The analysis loop treats every `accumulation_steps` consecutive batches as one TSP\n",
    "    # size, so derive it from the per-size sample count rather than hard-coding it:\n",
    "    # changing batch_size can then no longer silently mislabel the size groups.\n",
    "    samples_per_size = 1280\n",
    "    num_samples = 25600  # 20 TSP sizes x 1280 samples per size\n",
    "    accumulation_steps = samples_per_size // batch_size  # 80 with the values above\n",
    "    assert samples_per_size % batch_size == 0, \"batch_size must divide samples_per_size\"\n",
    "    assert num_samples % samples_per_size == 0, \"num_samples must be a multiple of samples_per_size\"\n",
    "\n",
    "    neighbors = 0.20\n",
    "    knn_strat = 'percentage'\n",
    "\n",
    "    model = \"pretrained/tsp_20-50/rl-ar-var-20pnn-gnn-max_20200313T002243\"\n",
    "    # Alternative checkpoints (assign one to `model` above to analyse it instead):\n",
    "    #   \"outputs/tsp_20-50/rl-ar-var-20pnn-gnn-max-ln_20200313T125908\"\n",
    "    #   \"outputs/tspsl_20-50/sl-ar-var-20pnn-gnn-sum_20200310T094801\"\n",
    "    #   \"outputs/tspsl_20-50/sl-ar-var-20pnn-gnn-max_20200308T172931\"\n",
    "    #   \"outputs/tspsl_20-50/sl-ar-var-20pnn-gnn-mean_20200310T094833\"\n",
    "    #   \"outputs/tspsl_20-50/sl-ar-var-full-mlp_20200306T182155\"\n",
    "    #   \"outputs/tspsl_20-50/sl-ar-var-20pnn-gnn-max-bntrack_20200310T095509\"\n",
    "    #   \"outputs/tspsl_20-50/sl-ar-var-20pnn-gnn-max-ln_20200310T095955\"\n",
    "    #   \"outputs/tsp_20-50/rl-ar-var-20pnn-gnn-max-gaggr-sum_20200411T145725\"\n",
    "\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the checkpoint with extra_logging=True (presumably what makes the forward pass\n",
    "# expose the *_batch tensors read by the analysis loop), move it to the configured\n",
    "# device, and switch it to greedy decoding in eval mode.\n",
    "# NOTE(review): a previous run of this cell raised KeyError: 'set_start' inside\n",
    "# utils/functions.py load_model, which reads args['set_start'] while this checkpoint's\n",
    "# saved args lack that key. Fix there (e.g. args.get('set_start', ...)) -- TODO.\n",
    "model, model_args = load_model(opts.model, extra_logging=True)\n",
    "model = model.to(opts.device)\n",
    "set_decode_type(model, \"greedy\")\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard writer for this analysis, placed inside the run's training log tree:\n",
    "#   {log_dir}/{problem}_{min_size}-{max_size}/{run_name}\n",
    "run_log_dir = os.path.join(\n",
    "    model_args[\"log_dir\"],\n",
    "    f'{model_args[\"problem\"]}_{model_args[\"min_size\"]}-{model_args[\"max_size\"]}',\n",
    "    model_args[\"run_name\"],\n",
    ")\n",
    "tb_logger = SummaryWriter(run_log_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instances are read in file order: shuffle=False keeps the batches in the order the\n",
    "# size-grouping analysis below relies on (consecutive batches share one TSP size,\n",
    "# presumably as laid out in the data file -- TODO confirm).\n",
    "dataset = TSP.make_dataset(\n",
    "    filename=opts.dataset_path,\n",
    "    batch_size=opts.batch_size,\n",
    "    num_samples=opts.num_samples,\n",
    "    neighbors=opts.neighbors,\n",
    "    knn_strat=opts.knn_strat,\n",
    "    supervised=True,\n",
    ")\n",
    "dataloader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=opts.batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each group of `opts.accumulation_steps` consecutive batches holds one TSP size;\n",
    "# group k (1-based) is labelled TSP{TSP_SIZE_STEP * k}, i.e. TSP10, TSP20, ...\n",
    "TSP_SIZE_STEP = 10\n",
    "\n",
    "\n",
    "def log_size_group(writer, step, node_embs, log_p, log_p_selected):\n",
    "    \"\"\"Log TensorBoard histograms for one group of same-size TSP instances.\n",
    "\n",
    "    Args:\n",
    "        writer: TensorBoard ``SummaryWriter`` to log to.\n",
    "        step: global step used for every histogram (the group's TSP size).\n",
    "        node_embs: node embeddings of the group, indexed (graph, node, dim)\n",
    "            -- assumed layout implied by ``mean(1)`` and per-graph ``pdist``; TODO confirm.\n",
    "        log_p: log-probabilities over all actions for the group.\n",
    "        log_p_selected: log-probabilities of the selected actions for the group.\n",
    "\n",
    "    Returns:\n",
    "        The group's graph embeddings (node embeddings averaged over axis 1).\n",
    "    \"\"\"\n",
    "    graph_embs = node_embs.mean(1)\n",
    "\n",
    "    # Prediction probabilities (for all actions and selected actions)\n",
    "    writer.add_histogram('probs', np.exp(log_p.flatten()), step)\n",
    "    writer.add_histogram('probs_selected', np.exp(log_p_selected.flatten()), step)\n",
    "\n",
    "    # Raw values\n",
    "    writer.add_histogram('emb_values', node_embs.flatten(), step)\n",
    "    writer.add_histogram('graph_emb_values', graph_embs.flatten(), step)\n",
    "\n",
    "    # L2 norms\n",
    "    writer.add_histogram('emb_2norm', np.linalg.norm(node_embs, axis=-1).flatten(), step)\n",
    "    writer.add_histogram('graph_emb_2norm', np.linalg.norm(graph_embs, axis=-1).flatten(), step)\n",
    "\n",
    "    # Distances between node embeddings within each graph. np.concatenate (not\n",
    "    # np.array) so graphs with different node counts cannot form a ragged array.\n",
    "    node_dists = np.concatenate([pdist(emb, metric='euclidean') for emb in node_embs])\n",
    "    writer.add_histogram('emb_dist', node_dists, step)\n",
    "\n",
    "    # Distances between graph embeddings of THIS group only. Previously every graph\n",
    "    # seen so far was included (mixing sizes, O(total^2) memory at the last step).\n",
    "    writer.add_histogram('graph_emb_dist', pdist(graph_embs, metric='euclidean'), step)\n",
    "\n",
    "    return graph_embs\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    graph_emb_groups = []  # per-group graph embeddings for the TB projector\n",
    "    graph_embs_meta = []   # \"TSP{size}\" label for each projected graph embedding\n",
    "\n",
    "    # Per-group buffers, flushed every `opts.accumulation_steps` batches. Collecting\n",
    "    # into lists and concatenating once avoids re-copying the arrays every batch.\n",
    "    node_embs_buf, log_p_buf, log_p_sel_buf = [], [], []\n",
    "\n",
    "    for bat_idx, bat in enumerate(tqdm(dataloader, ascii=True)):\n",
    "        x = move_to(bat['nodes'], opts.device)\n",
    "        graph = move_to(bat['graph'], opts.device)\n",
    "        # Return values are unused; the *_batch attributes read below are presumably\n",
    "        # populated because the model was loaded with extra_logging=True.\n",
    "        model(x, graph, return_pi=True)\n",
    "\n",
    "        node_embs_buf.append(model.embeddings_batch.cpu().numpy())\n",
    "        log_p_buf.append(model.log_p_batch.cpu().numpy())\n",
    "        log_p_sel_buf.append(model.log_p_sel_batch.cpu().numpy())\n",
    "\n",
    "        if (bat_idx + 1) % opts.accumulation_steps:\n",
    "            continue  # group not complete yet\n",
    "\n",
    "        tsp_size = TSP_SIZE_STEP * ((bat_idx + 1) // opts.accumulation_steps)\n",
    "        group_graph_embs = log_size_group(\n",
    "            tb_logger,\n",
    "            tsp_size,\n",
    "            np.concatenate(node_embs_buf, axis=0),\n",
    "            np.concatenate(log_p_buf, axis=0),\n",
    "            np.concatenate(log_p_sel_buf, axis=0),\n",
    "        )\n",
    "        graph_emb_groups.append(group_graph_embs)\n",
    "        graph_embs_meta += [f\"TSP{tsp_size}\"] * len(group_graph_embs)\n",
    "        node_embs_buf, log_p_buf, log_p_sel_buf = [], [], []\n",
    "\n",
    "    if node_embs_buf:\n",
    "        print(f\"Skipped {len(node_embs_buf)} trailing batch(es) that did not fill a \"\n",
    "              f\"group of {opts.accumulation_steps}.\")\n",
    "\n",
    "    # Log graph embeddings of all sizes to the projector\n",
    "    graph_embs = np.concatenate(graph_emb_groups, axis=0) if graph_emb_groups else None\n",
    "    if graph_embs is not None:\n",
    "        tb_logger.add_embedding(graph_embs, metadata=graph_embs_meta, tag='graph_emb')\n",
    "    tb_logger.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
