{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "18f25ae1-6354-4037-9fe7-39f0cf537af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from scipy.spatial.distance import cdist"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c430ab15-be3a-4d60-a1c6-8b6824029828",
   "metadata": {},
   "source": [
    "## Load problems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0ecd2b1d-d7e7-4a9b-8fc1-07569bb36f69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_tsp_file(N):\n",
    "    file_path = f\"./default_mcts/tsp{N}_test_concorde.txt\"\n",
    "    with open(file_path, 'r') as file:\n",
    "        lines = file.readlines()\n",
    "    data = []\n",
    "    true_tours = []\n",
    "    for line in lines:\n",
    "        parts = line.strip().split(\" output \")\n",
    "        coords_flat = np.array(parts[0].split(), dtype=np.float32)\n",
    "        tour = np.array(parts[1].split(), dtype=np.int32)[:-1] - 1\n",
    "        data.append(coords_flat[:2*N])\n",
    "        true_tours.append(tour)\n",
    "    data = np.array(data).reshape(-1, N, 2)\n",
    "    true_tours = np.array(true_tours)\n",
    "    return data, true_tours"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f358287-7751-4aec-98bf-df638b1a4608",
   "metadata": {},
   "source": [
    "## Load heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d52cdb25-baac-493b-a020-26174fa4f6b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_one_heatmap(path):\n",
    "    with open(path, 'r') as f:\n",
    "        for i, line in enumerate(f):\n",
    "            if i == 0:\n",
    "                N = int(line)\n",
    "                h_map = np.zeros((N, N))\n",
    "            else:\n",
    "                h_map[i-1] = list(map(float, line.split()))\n",
    "    return h_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "25d9f45b-d978-4df7-a38d-e68144cb4caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_heatmaps(base_dir):\n",
    "    \"\"\"Load all heatmap files of one model/size directory into one array.\n",
    "\n",
    "    N is parsed from the last component of base_dir (``.../tsp{N}``); every\n",
    "    file whose name contains ``heatmaptsp`` is loaded with load_one_heatmap\n",
    "    and stored at the index found after the last ``_`` of its name\n",
    "    (``..._{i}.txt``).\n",
    "\n",
    "    Returns:\n",
    "        float64 array of shape (num_files, N, N).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the file indices are not exactly 0 .. num_files - 1;\n",
    "            otherwise a missing index would silently leave an all-zero heatmap.\n",
    "    \"\"\"\n",
    "    # normpath + basename tolerate a trailing separator (e.g. '.../tsp500/').\n",
    "    N = int(os.path.basename(os.path.normpath(base_dir)).split('tsp')[-1])\n",
    "    valid_fns = sorted(fn for fn in os.listdir(base_dir) if 'heatmaptsp' in fn)\n",
    "    n = len(valid_fns)\n",
    "    print(f\"Loading {n} heatmaps of size {N}x{N} from {base_dir}\")\n",
    "    indexed_fns = [(int(fn.split('_')[-1].split('.txt')[0]), fn) for fn in valid_fns]\n",
    "    if sorted(i for i, _ in indexed_fns) != list(range(n)):\n",
    "        raise ValueError(f\"{base_dir}: heatmap file indices are not 0..{n - 1}\")\n",
    "    heatmaps = np.zeros((n, N, N))\n",
    "    for i, fn in tqdm(indexed_fns):\n",
    "        heatmaps[i] = load_one_heatmap(os.path.join(base_dir, fn))\n",
    "    return heatmaps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e419313e-a538-470b-b1ac-ab102422725d",
   "metadata": {},
   "source": [
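    "## Greedy decoding\n",
    "\n",
    "We use the greedy decoding procedure of https://arxiv.org/pdf/2302.08224 and https://arxiv.org/pdf/2206.09012, following the reference implementation at https://github.com/AlexGraikos/diffusion_priors/blob/97554a9171e268e3fefcd32ecd6839024cea0d3e/tsp/inference.py#L200. Edges are visited in decreasing order of symmetrised heatmap score divided by edge length. An edge is kept whenever both of its endpoints still have degree below 2 and it joins two different partial paths. Once a single path covers all cities, its two ends are joined to close the tour."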
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "acd659d4-aa4b-44d3-9270-9a94b356b309",
   "metadata": {},
   "outputs": [],
   "source": [
    "def greedy_decoding(heat_map, dist_mat):\n",
    "    adj_mat = heat_map + heat_map.T\n",
    "    components = np.zeros((adj_mat.shape[0], 2)).astype(int)\n",
    "    components[:] = np.arange(adj_mat.shape[0])[..., None]\n",
    "    real_adj_mat = np.zeros_like(adj_mat)\n",
    "    for edge in tqdm((-adj_mat/dist_mat).flatten().argsort()):\n",
    "        a, b = edge // adj_mat.shape[0], edge % adj_mat.shape[0]\n",
    "        if not (a in components and b in components): continue\n",
    "        ca = np.nonzero((components==a).sum(1))[0][0]\n",
    "        cb = np.nonzero((components==b).sum(1))[0][0]\n",
    "        if ca == cb: continue\n",
    "        cca = sorted(components[ca], key=lambda x:x==a)\n",
    "        ccb = sorted(components[cb], key=lambda x:x==b)\n",
    "        newc = np.array([[cca[0], ccb[0]]])\n",
    "        m, M = min(ca,cb), max(ca,cb)\n",
    "        real_adj_mat[a, b] = 1\n",
    "        components = np.concatenate([components[:m], components[m+1:M], components[M+1:], newc], 0)\n",
    "        if len(components) == 1: break\n",
    "    real_adj_mat[components[0, 1], components[0, 0]] = 1\n",
    "    real_adj_mat += real_adj_mat.T\n",
    "\n",
    "    tour = [0]\n",
    "    while len(tour) < adj_mat.shape[0] + 1:\n",
    "        n = np.nonzero(real_adj_mat[tour[-1]])[0]\n",
    "        if len(tour) > 1:\n",
    "            n = n[n!=tour[-2]]\n",
    "        tour.append(n.max())\n",
    "\n",
    "    return tour"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8a5a05ba-471b-41b3-8d2a-06d2becb0bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_greedy_tours(problems, tours, model_name, N):\n",
    "    model_dir = f\"tours/{model_name}\"\n",
    "    os.makedirs(model_dir, exist_ok=True)\n",
    "    filename = os.path.join(model_dir, f\"greedy_tsp{N}.txt\")\n",
    "    with open(filename, 'w') as f:\n",
    "        for i, (p, t) in enumerate(zip(problems, tours)):\n",
    "            for x, y in p:\n",
    "                f.write(f\"{x} {y} \")\n",
    "            f.write(f\"output \")\n",
    "            for j in t:\n",
    "                f.write(f\"{j} \")\n",
    "            f.write(f\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d9f4d12a-f23d-4344-a2a0-d13ad96edeed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_greedy_tours(model_name, N):\n",
    "    \"\"\"Greedy-decode all heatmaps of one model on TSP-N and save the tours.\n",
    "\n",
    "    Problems come from read_tsp_file(N); heatmaps are loaded from\n",
    "    ./all_heatmap/{model_name}/heatmap/tsp{N}. Heatmap i is decoded against\n",
    "    problem i, and the tours are written with save_greedy_tours.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are more heatmaps than problems.\n",
    "    \"\"\"\n",
    "    # load problems\n",
    "    data, _ = read_tsp_file(N=N)\n",
    "    # load model heatmaps\n",
    "    h_maps_dir = f'./all_heatmap/{model_name}/heatmap/tsp{N}'\n",
    "    h_maps = load_heatmaps(h_maps_dir)\n",
    "    if len(h_maps) > len(data):\n",
    "        raise ValueError(f\"{len(h_maps)} heatmaps in {h_maps_dir} but only {len(data)} TSP-{N} problems\")\n",
    "    if len(h_maps) < len(data):\n",
    "        print(f\"Warning: only {len(h_maps)} heatmaps for {len(data)} TSP-{N} problems; only those tours are saved\")\n",
    "    # find greedy decoding tours (one progress bar over instances)\n",
    "    tours = []\n",
    "    for i, hm in enumerate(tqdm(h_maps, desc=f\"greedy {model_name} tsp{N}\")):\n",
    "        problem = data[i]\n",
    "        dm = cdist(problem, problem)\n",
    "        # greedy_decoding returns a closed tour; drop the repeated start node\n",
    "        tours.append(greedy_decoding(hm, dm)[:-1])\n",
    "    save_greedy_tours(data, tours, model_name, N)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e2c2404-69b8-4e0c-af97-64ef0b25ff96",
   "metadata": {},
   "source": [
    "## DIFUSCO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1bc29ed9-0f81-4fa4-9dbb-678c49394cce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/difusco/greedy_tsp500.txt\n",
    "generate_greedy_tours(model_name='difusco', N=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0a50b931-d047-4c27-bf64-80777757c07e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/difusco/greedy_tsp1000.txt\n",
    "generate_greedy_tours(model_name='difusco', N=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a5ea7d78-568c-47a1-a99e-49cacce8707c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Writes tours/difusco/greedy_tsp10000.txt\n",
    "generate_greedy_tours(model_name='difusco', N=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0e1d98f-da37-4768-8c49-60864aaea9cb",
   "metadata": {},
   "source": [
    "## Att-GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9fa610ea-2f46-4096-b2b0-6f8e6a7e4993",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/attgcn/greedy_tsp500.txt\n",
    "generate_greedy_tours(model_name='attgcn', N=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ab322160-0ec1-4770-b835-640eb121fab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/attgcn/greedy_tsp1000.txt\n",
    "generate_greedy_tours(model_name='attgcn', N=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "73c05922-280e-4a04-918c-2fdc6031ccca",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Writes tours/attgcn/greedy_tsp10000.txt\n",
    "generate_greedy_tours(model_name='attgcn', N=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "012c74f4-66b4-4dbe-a9f8-d7fc9df9dd7a",
   "metadata": {},
   "source": [
    "## DIMES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fa5376e8-40a3-44eb-b589-aa9ca05137eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/dimes/greedy_tsp500.txt\n",
    "generate_greedy_tours(model_name='dimes', N=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4a129f9a-4bb1-4128-9896-a49d751f580e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/dimes/greedy_tsp1000.txt\n",
    "generate_greedy_tours(model_name='dimes', N=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9a7405de-9727-48f3-a2da-f5e62de43d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/dimes/greedy_tsp10000.txt\n",
    "generate_greedy_tours(model_name='dimes', N=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39ea2d40-6787-4e7a-a407-5f3a1befd82e",
   "metadata": {},
   "source": [
    "## UTSP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5bc4a304-7229-4c94-bc68-29118ac47e75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/utsp/greedy_tsp500.txt\n",
    "generate_greedy_tours(model_name='utsp', N=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "894f654f-f520-4fa2-aef5-0e684c29675c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/utsp/greedy_tsp1000.txt\n",
    "generate_greedy_tours(model_name='utsp', N=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89adb866-7bc4-404d-9000-0cf80ea93a31",
   "metadata": {},
   "source": [
    "## SoftDist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "346bd0a7-2820-46e4-84b6-7069c1279900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/softdist/greedy_tsp500.txt\n",
    "generate_greedy_tours(model_name='softdist', N=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b09a3b16-7858-41fd-bd46-4a9a6b9a04a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/softdist/greedy_tsp1000.txt\n",
    "generate_greedy_tours(model_name='softdist', N=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d860ca17-13cb-4b6e-8a86-9f8fb6bb6fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes tours/softdist/greedy_tsp10000.txt\n",
    "generate_greedy_tours(model_name='softdist', N=10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683ce2e3-c14f-4290-9a42-e00a2a0004b7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
