{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ef6dd231-5d33-42ce-86b7-60fbf787f76d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from collections import deque\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from RTD_Lite_TSP import RTD_Lite, prim_algo\n",
    "\n",
    "import csv\n",
    "import json\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import trange\n",
    "\n",
    "import numpy as np\n",
    "import tsplib95\n",
    "import scipy.spatial\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99bc389f-9ad0-43df-a065-c74868e93392",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_problem(seed, N):\n",
    "    np.random.seed(seed)\n",
    "    dim = 2\n",
    "    cities = np.random.random((N, dim))\n",
    "    distance_matrix = scipy.spatial.distance.cdist(cities, cities)\n",
    "\n",
    "    return cities, distance_matrix\n",
    "\n",
    "\n",
    "def create_problem_cluster(seed, N, n_clusters):\n",
    "\n",
    "    city_list = []\n",
    "    cluster_radius = 0.05  \n",
    "    margin = 0.1 \n",
    "\n",
    "    centers = []\n",
    "    attempts = 0\n",
    "    max_attempts = 1000\n",
    "\n",
    "    while len(centers) < n_clusters and attempts < max_attempts:\n",
    "        candidate = 0.1 + 0.8 * np.random.rand(2) \n",
    "        too_close = False\n",
    "        for c in centers:\n",
    "            if np.linalg.norm(candidate - c) < (2 * cluster_radius + margin):\n",
    "                too_close = True\n",
    "                break\n",
    "        if not too_close:\n",
    "            centers.append(candidate)\n",
    "        attempts += 1\n",
    "\n",
    "    if len(centers) < n_clusters:\n",
    "        raise ValueError(\"Не удалось разместить кластеры без пересечений\")\n",
    "\n",
    "    points_per_cluster = [N // n_clusters] * n_clusters\n",
    "    points_per_cluster[-1] += N - sum(points_per_cluster)  \n",
    "\n",
    "    for i in range(n_clusters):\n",
    "        center = centers[i]\n",
    "        cov = (cluster_radius ** 2) * np.eye(2)\n",
    "        cnt = 0\n",
    "        cluster_points = []\n",
    "        while cnt < points_per_cluster[i]:\n",
    "            z = np.random.multivariate_normal(center, cov)\n",
    "            if 0 <= z[0] <= 1 and 0 <= z[1] <= 1:\n",
    "                cluster_points.append(z)\n",
    "                cnt += 1\n",
    "        city_list.extend(cluster_points)\n",
    "\n",
    "    cities = np.array(city_list)\n",
    "    distance_matrix = scipy.spatial.distance.cdist(cities, cities)\n",
    "    return cities, distance_matrix\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "802a21ee-3905-4c73-b21d-eda7b06039b3",
   "metadata": {},
   "source": [
    "## Base classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1741f71b-c4af-408e-bf16-df4ac3ad31f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TSP Environment Class\n",
    "class TSPEnv:\n",
    "    def __init__(self, num_cities, coords=None, start_city = 0, end_city = None, clusters = False, n_clusters=None):\n",
    "        self.num_cities = num_cities\n",
    "        if not clusters:\n",
    "            self.coords = (coords.astype(np.float32) if coords is not None \n",
    "                           else np.random.rand(num_cities, 2).astype(np.float32))\n",
    "            self.dist_matrix = self._compute_dist_matrix().astype(np.float32)\n",
    "        else:\n",
    "            # self.coords, self.dist_matrix = create_problem_cluster(np.random.randint(1), num_cities, n_clusters)\n",
    "            self.coords, self.dist_matrix = create_problem(np.random.randint(1), num_cities)\n",
    "        self.start_city = start_city\n",
    "        if end_city:\n",
    "            self.end_city = end_city\n",
    "        else:\n",
    "            self.end_city = num_cities - 1\n",
    "            \n",
    "        self.reset()\n",
    "\n",
    "    def _compute_dist_matrix(self):\n",
    "        d = np.linalg.norm(self.coords[None, :, :] - self.coords[:, None, :], axis=-1)\n",
    "        return d.astype(np.float32)\n",
    "\n",
    "    def reset(self):\n",
    "        self.current_city = self.start_city\n",
    "        self.visited = [self.start_city, self.end_city]\n",
    "        self.available = set(range(0, self.num_cities))\n",
    "        self.available.remove(self.start_city)\n",
    "        self.available.remove(self.end_city)\n",
    "        return self._get_state()\n",
    "\n",
    "    def _get_state(self):\n",
    "        dist = self.dist_matrix[self.current_city]\n",
    "        mask = np.zeros(self.num_cities, dtype=np.float32)\n",
    "        mask[self.visited] = 1.0\n",
    "        state = np.concatenate([dist, mask]).astype(np.float32)\n",
    "        return torch.tensor(state, dtype=torch.float32)\n",
    "\n",
    "    def step(self, action):\n",
    "        assert action in self.available, \"Invalid action\"\n",
    "        self.visited.append(action)\n",
    "        self.available.remove(action)\n",
    "        prev = self.current_city\n",
    "        self.current_city = action\n",
    "        reward = -self.dist_matrix[prev, action]\n",
    "        done = len(self.visited) == self.num_cities\n",
    "        if done:\n",
    "            reward -= self.dist_matrix[self.current_city, self.end_city]\n",
    "        return self._get_state(), float(reward), done, {}\n",
    "\n",
    "# Deep Q‑Network Model\n",
    "class DQN(nn.Module):\n",
    "    def __init__(self, num_cities, hidden_dim=128):\n",
    "        super().__init__()\n",
    "        input_dim = num_cities * 2\n",
    "        \n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(input_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.LayerNorm(hidden_dim),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.LayerNorm(hidden_dim),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, num_cities)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x.to(torch.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93e88ebb-a579-4492-b2de-c813d101e0a3",
   "metadata": {},
   "source": [
    "## Base DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0abea237-e111-43cb-9aa5-3d9e926a7bd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_dqn(env, episodes=1000, batch_size=64, gamma=0.99,\n",
    "              lr=1e-3, eps_start=1.0, eps_end=0.05, eps_stop_decrease=900, max_buffer = 10000, update_wait = 1):\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    policy_net = DQN(env.num_cities).to(device)\n",
    "    target_net = DQN(env.num_cities).to(device)\n",
    "    target_net.load_state_dict(policy_net.state_dict())\n",
    "    optimizer = optim.Adam(policy_net.parameters(), lr=lr)\n",
    "    memory = deque(maxlen=max_buffer)\n",
    "    eps = eps_start\n",
    "    eps_decay = round(pow(eps_end/eps_start, 1/eps_stop_decrease),6)\n",
    "    reward_history = []\n",
    "    loss_history = []\n",
    "\n",
    "    best_tour = []\n",
    "    best_reward = -np.inf\n",
    "    print(\"gamma\",gamma, \"lr\", lr, \"eps_start\", eps_start, \"eps_end\", eps_end, \"eps_decay\", eps_decay)\n",
    "\n",
    "    for ep in trange(episodes, desc='Training Base DQ-L'):\n",
    "\n",
    "        state = env.reset().to(device)\n",
    "        total_reward = 0.0\n",
    "        total_loss = 0.0\n",
    "        count_loss = 0\n",
    "        done = False\n",
    "        tour_l = [env.start_city]\n",
    "\n",
    "        \n",
    "        while not done:\n",
    "            if random.random() < eps:\n",
    "                action = random.choice(list(env.available))\n",
    "            else:\n",
    "                with torch.no_grad():\n",
    "                    q = policy_net(state.unsqueeze(0)).cpu().numpy()[0]\n",
    "                q[env.visited] = -np.inf\n",
    "                action = int(np.argmax(q))\n",
    "            tour_l.append(action)\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            next_state = next_state.to(device)\n",
    "            memory.append((state, action, reward, next_state, done))\n",
    "            state = next_state\n",
    "            total_reward += reward\n",
    "\n",
    "            if len(memory) >= batch_size:\n",
    "                batch = random.sample(memory, batch_size)\n",
    "                s, a, r, s2, d = zip(*batch)\n",
    "                s = torch.stack(s).to(device)\n",
    "                s2 = torch.stack(s2).to(device)\n",
    "                a = torch.tensor(a, device=device)\n",
    "                r = torch.tensor(r, device=device)\n",
    "                d = torch.tensor(d, device=device, dtype=torch.float32)\n",
    "\n",
    "                q_vals = policy_net(s).gather(1, a.unsqueeze(1)).squeeze()\n",
    "                next_q = target_net(s2)\n",
    "                # We should consider maximum expected reward only for legal actions from s`.\n",
    "                mask = s2[:,env.num_cities:] # right half of state vector is mask if visited sities 1 if visited else 0\n",
    "\n",
    "                next_q = next_q.masked_fill(mask.bool(), float('-inf'))\n",
    "\n",
    "                next_q = next_q.max(1)[0]\n",
    "                next_q = torch.where(next_q == float('-inf'), torch.tensor(0.0, device=next_q.device), next_q)\n",
    "                target = r + gamma * next_q * (1 - d)\n",
    "                # target = r + gamma * next_q\n",
    "\n",
    "\n",
    "                criterion = nn.SmoothL1Loss()\n",
    "                loss = criterion(q_vals, target.detach())\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                total_loss += loss.item()\n",
    "                count_loss += 1\n",
    "\n",
    "        tour_l.append(env.end_city)\n",
    "\n",
    "        if total_reward>best_reward:\n",
    "            best_reward = total_reward\n",
    "            best_tour = tour_l\n",
    "        \n",
    "        eps = max(eps_end, eps * eps_decay)\n",
    "        if ep % update_wait == 0:\n",
    "            target_net.load_state_dict(policy_net.state_dict())\n",
    "\n",
    "        reward_history.append(-total_reward)\n",
    "        loss_history.append(total_loss / count_loss if count_loss>0 else 0)\n",
    "\n",
    "    return policy_net, reward_history, loss_history, best_tour\n",
    "\n",
    "# 4) Compute tour length\n",
    "def compute_length(coords, tour):\n",
    "    length=0\n",
    "    for i in range(len(tour)-1): length+=np.linalg.norm(coords[tour[i]]-coords[tour[i+1]])\n",
    "    return length\n",
    "\n",
    "# 5) Plot tours\n",
    "def plot_tours(coords, t1, t2):\n",
    "    plt.figure()\n",
    "    plt.scatter(coords[:,0], coords[:,1])\n",
    "    p1=np.array([coords[i] for i in t1+[t1[0]]]); plt.plot(p1[:,0],p1[:,1],'o-')\n",
    "    p2=np.array([coords[i] for i in t2+[t2[0]]]); plt.plot(p2[:,0],p2[:,1],'x--')\n",
    "    plt.title(f\"Lengths: DQN={compute_length(coords,t1):.2f}, Rand={compute_length(coords,t2):.2f}\")\n",
    "    plt.legend(['DQN','Random'])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dc01ee0-c15f-40cc-a035-aab5785a037b",
   "metadata": {},
   "source": [
    "## RTD DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ecb9a39-1e6c-463f-bd45-0a32c4102442",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "from sklearn.manifold import MDS\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "def apply_positional_encoding(state: torch.Tensor, visit_order: list[int], end_city: int) -> torch.Tensor:\n",
    "\n",
    "    device = state.device\n",
    "    N = state.shape[0] // 2\n",
    "    dist_part = state[:N]\n",
    "\n",
    "    pos_enc = torch.zeros(N, dtype=torch.float32, device=device)\n",
    "    L = len(visit_order)\n",
    "    if L > 0:\n",
    "        for rank, city in enumerate(visit_order):\n",
    "            pos_enc[city] = float(rank + 1)\n",
    "    pos_enc[end_city] = 1\n",
    "\n",
    "    pos_enc = pos_enc/torch.linalg.norm(pos_enc)\n",
    "    return torch.cat([dist_part, pos_enc], dim=0)\n",
    "\n",
    "\n",
    "def masked_dist_matrix(edges, distmatrix):\n",
    "    n = distmatrix.shape[0]\n",
    "    masked = torch.full((n, n), float('inf'), device=device)\n",
    "    masked.fill_diagonal_(0)\n",
    "\n",
    "    for u, v in edges:\n",
    "        masked[u, v] = distmatrix[u, v]\n",
    "        masked[v, u] = distmatrix[v, u]\n",
    "\n",
    "    return masked\n",
    "\n",
    "def find_edge_index(tour, edge):\n",
    "    edge_set = set(edge)\n",
    "    for i, e in enumerate(tour):\n",
    "        if set(e) == edge_set:\n",
    "            return i\n",
    "    return \"not in tour\"\n",
    "\n",
    "\n",
    "def draw_mst_and_tour(positions, mst, tour):\n",
    "    pos_dict = {i: tuple(positions[i]) for i in range(len(positions))}\n",
    "\n",
    "    def draw_graph(ax, edge_list, title):\n",
    "        G = nx.Graph()\n",
    "        G.add_edges_from(edge_list)\n",
    "        nx.draw(\n",
    "            G, pos=pos_dict, with_labels=True, node_color='lightblue',\n",
    "            edge_color='gray', node_size=500, font_size=10, ax=ax\n",
    "        )\n",
    "        ax.set_title(title)\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
    "    draw_graph(axes[0], mst, \"MST\")\n",
    "    draw_graph(axes[1], tour, \"Tour\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def train_dqn_RTD(env, episodes=1000, batch_size=64, gamma=0.95,\n",
    "              lr=1e-3, eps_start=1.0, eps_end=0.05, eps_stop_decrease=900, beta = 0.1, max_buffer = 10000, pos_encoding = False, update_wait = 10):\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    policy_net = DQN(env.num_cities).to(device)\n",
    "    target_net = DQN(env.num_cities).to(device)\n",
    "    target_net.load_state_dict(policy_net.state_dict())\n",
    "    optimizer = optim.Adam(policy_net.parameters(), lr=lr)\n",
    "    memory = []\n",
    "    eps = eps_start\n",
    "    eps_decay = round(pow(eps_end/eps_start, 1/eps_stop_decrease),6)\n",
    "    reward_history = []\n",
    "    loss_history = []\n",
    "    \n",
    "    best_tour = []\n",
    "    best_reward = -np.inf\n",
    "    \n",
    "    print(\"gamma\",gamma, \"lr\", lr, \"eps_start\", eps_start, \"eps_end\", eps_end, \"eps_decay\", eps_decay, \"beta\", beta)\n",
    "\n",
    "    dist_matrix =  torch.tensor(env.dist_matrix).to(device)\n",
    "\n",
    "    for ep in trange(episodes, desc='Training RTD DQ-L'):\n",
    "        if ep < update_wait:\n",
    "            zero_shouting = 0\n",
    "        else:\n",
    "            zero_shouting = 1\n",
    "        state = env.reset().to(device)\n",
    "        tour_l = [env.start_city]\n",
    "        if pos_encoding:\n",
    "            state = apply_positional_encoding(state, tour_l, env.end_city)\n",
    "        total_reward = 0.0\n",
    "        total_loss = 0.0\n",
    "        count_loss = 0\n",
    "        done = False\n",
    "        \n",
    "        tour = []\n",
    "        current_city = env.start_city\n",
    "\n",
    "        while not done:\n",
    "            if random.random() < eps:\n",
    "                action = random.choice(list(env.available))\n",
    "            else:\n",
    "                with torch.no_grad():\n",
    "                    q = policy_net(state.unsqueeze(0)).cpu().numpy()[0]\n",
    "                q[env.visited] = -np.inf\n",
    "                action = int(np.argmax(q))\n",
    "\n",
    "            tour_l.append(action)\n",
    "            tour.append([action, current_city])\n",
    "            current_city = action\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            total_reward += reward\n",
    "            next_state = next_state.to(device)\n",
    "            if pos_encoding:\n",
    "                next_state = apply_positional_encoding(next_state, tour_l, env.end_city)\n",
    "\n",
    "            memory.append([state, action, reward, next_state, done])\n",
    "            state = next_state\n",
    "            \n",
    "            if len(memory[-max_buffer: ep*env.num_cities]) >= batch_size:\n",
    "                batch = random.sample(memory[-max_buffer: ep*env.num_cities], batch_size) ## add boundaries\n",
    "                s, a, r, s2, d = zip(*batch)\n",
    "                s = torch.stack(s).to(device)\n",
    "                s2 = torch.stack(s2).to(device)\n",
    "                a = torch.tensor(a, device=device)\n",
    "                r = torch.tensor(r, device=device)\n",
    "                d = torch.tensor(d, device=device, dtype=torch.float32)\n",
    "\n",
    "                q_vals = policy_net(s).gather(1, a.unsqueeze(1)).squeeze()\n",
    "                next_q = target_net(s2)\n",
    "                # We should consider maximum expected reward only for legal actions from s`.\n",
    "                mask = s2[:,env.num_cities:] # right half of state vector is mask if visited sities 1 if visited else 0\n",
    "\n",
    "                next_q = next_q.masked_fill(mask.bool(), float('-inf'))\n",
    "\n",
    "                next_q = next_q.max(1)[0]\n",
    "                next_q = torch.where(next_q == float('-inf'), torch.tensor(0.0, device=next_q.device), next_q)\n",
    "                target = r + gamma * next_q * (1 - d)*zero_shouting\n",
    "\n",
    "                \n",
    "                criterion = nn.SmoothL1Loss()\n",
    "                loss = criterion(q_vals, target.detach())\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "                total_loss += loss.item()\n",
    "                count_loss += 1\n",
    "        tour_l.append(env.end_city)\n",
    "        tour.append([env.end_city, current_city])\n",
    "\n",
    "        if total_reward>best_reward:\n",
    "            best_reward = total_reward\n",
    "            best_tour = tour_l\n",
    "        \n",
    "        eps = max(eps_end, eps * eps_decay)\n",
    "        if ep % update_wait == 0:\n",
    "            target_net.load_state_dict(policy_net.state_dict())\n",
    "\n",
    "        endges_w, rmin_edge_idx, path_edges_from_barcodes = RTD_Lite(dist_matrix, masked_dist_matrix(tour, dist_matrix))()\n",
    "        endges_w = endges_w[\"2->1\"]\n",
    "\n",
    "        for i, edge in enumerate(path_edges_from_barcodes):\n",
    "\n",
    "            s = find_edge_index(tour, edge)\n",
    "\n",
    "            if s != \"not in tour\" and s-env.num_cities+2!=0:\n",
    "                memory[s-env.num_cities+2][2] +=  float((endges_w[i][1] - endges_w[i][0]).detach().cpu())*beta # update reward\n",
    "\n",
    "\n",
    "        reward_history.append(-total_reward)\n",
    "        loss_history.append(total_loss / count_loss if count_loss>0 else 0)\n",
    "\n",
    "\n",
    "    return policy_net, reward_history, loss_history, best_tour"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48aee186-d6ee-44f4-aeba-0f4980903dc7",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9800506b-4268-4867-b0ff-fd414479ea89",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from ortools.constraint_solver import routing_enums_pb2\n",
    "from ortools.constraint_solver import pywrapcp\n",
    "from scipy.spatial import distance_matrix\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "def make_unique_folder(base_folder_name):\n",
    "    folder_name = base_folder_name.rstrip('/')\n",
    "    k = 1\n",
    "    final_name = folder_name\n",
    "    while os.path.exists(final_name):\n",
    "        final_name = f\"{folder_name}_try_{k}\"\n",
    "        k += 1\n",
    "    os.makedirs(final_name, exist_ok=True)\n",
    "    return final_name+\"/\"\n",
    "\n",
    "\n",
    "def solve_tsp_ortools(coords, start, end):\n",
    "    num_cities = len(coords)\n",
    "    dist = distance_matrix(coords, coords)\n",
    "\n",
    "    manager = pywrapcp.RoutingIndexManager(num_cities, 1, [start], [end])\n",
    "    routing = pywrapcp.RoutingModel(manager)\n",
    "\n",
    "    def distance_callback(from_index, to_index):\n",
    "        from_node = manager.IndexToNode(from_index)\n",
    "        to_node = manager.IndexToNode(to_index)\n",
    "        return int(dist[from_node][to_node] * 1e6)  # Scale to int\n",
    "\n",
    "    transit_callback_index = routing.RegisterTransitCallback(distance_callback)\n",
    "    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)\n",
    "\n",
    "    search_params = pywrapcp.DefaultRoutingSearchParameters()\n",
    "    search_params.first_solution_strategy = (\n",
    "        routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC\n",
    "    )\n",
    "\n",
    "    solution = routing.SolveWithParameters(search_params)\n",
    "\n",
    "    if solution:\n",
    "        index = routing.Start(0)\n",
    "        route = []\n",
    "        while not routing.IsEnd(index):\n",
    "            route.append(manager.IndexToNode(index))\n",
    "            index = solution.Value(routing.NextVar(index))\n",
    "        route.append(manager.IndexToNode(index))\n",
    "        return route\n",
    "    else:\n",
    "        raise ValueError(\"No solution found by OR-Tools.\")\n",
    "\n",
    "\n",
    "def compare_train_methods(episodes=300, num_cities=20, params_base={}, params_rtd={}, plot=True, save=True, prefix=\"Exp\",coords = None):\n",
    "    if coords is not None:\n",
    "        env_base = TSPEnv(num_cities, coords = coords.copy())\n",
    "    else:\n",
    "        env_base = TSPEnv(num_cities)\n",
    "    env_rtd = TSPEnv(num_cities, coords=env_base.coords.copy())\n",
    "\n",
    "    # Solve with OR-Tools\n",
    "    print(\"Calculating optimal route by solver\")\n",
    "    optimal_tour = solve_tsp_ortools(env_base.coords, env_base.start_city, env_base.end_city)\n",
    "    optimal_length = compute_length(env_base.coords, optimal_tour)\n",
    "\n",
    "    beta = abs(params_rtd[\"beta\"])\n",
    "    eps_stop_decrease = params_rtd[\"eps_stop_decrease\"]\n",
    "    update_wait = params_rtd[\"update_wait\"]\n",
    "    folder_name = f\"pics/{prefix}_cities_{num_cities}_ep_{episodes}_beta_{beta}_stop_eps_{eps_stop_decrease}_update_wait_{update_wait}/\"\n",
    "\n",
    "    if save:\n",
    "        try:\n",
    "            folder_name = make_unique_folder(folder_name)\n",
    "            print(\"Results will be saved at:\", folder_name)\n",
    "        except Exception as e:\n",
    "            print(f\"An error occurred: {e}\")\n",
    "\n",
    "\n",
    "    model_base, rewards_base, losses_base, best_tour_base = train_dqn(env_base, episodes=episodes, **params_base)\n",
    "    model_rtd, rewards_rtd, losses_rtd, best_tour_rtd = train_dqn_RTD(env_rtd, episodes=episodes, **params_rtd)\n",
    "\n",
    "    # _, rewards_random, _, best_tour_random = train_dqn(env_base, episodes=episodes, eps_start=1.0, eps_end=1.0)\n",
    "\n",
    "    # ---------- FIGURE 1: Reward and Loss Curves ----------\n",
    "    plt.figure(figsize=(12, 5))\n",
    "\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.plot(rewards_base, label='model_base')\n",
    "    plt.plot(rewards_rtd, label='model_RTD')\n",
    "    # plt.plot(rewards_random, label='random search')\n",
    "    plt.axhline(optimal_length, color='black', linestyle='--', label=f\"Optimal ({optimal_length:.2f})\")\n",
    "    plt.title('Path length per Episode')\n",
    "    plt.xlabel('Episode')\n",
    "    plt.ylabel('Path length')\n",
    "    plt.legend()\n",
    "    plt.grid()\n",
    "\n",
    "    plt.subplot(1, 2, 2)\n",
    "    plt.plot(losses_base, label='model_base')\n",
    "    plt.plot(losses_rtd, label='model_RTD')\n",
    "    plt.title('Loss per Episode')\n",
    "    plt.xlabel('Episode')\n",
    "    plt.ylabel('Loss')\n",
    "    plt.legend()\n",
    "    plt.grid()\n",
    "\n",
    "    if save:\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(folder_name + f'Curves_ep_{episodes}_cities_{num_cities}.png')\n",
    "    if plot:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.close()\n",
    "\n",
    "    # ---------- FIGURE 2: Four Tours (random, base, RTD, optimal) ----------\n",
    "    coords = env_base.coords\n",
    "\n",
    "    tours = {\n",
    "        # \"Random best Tour\": (rand, 'x--', f'random ({compute_length(coords, rand):.2f})'),\n",
    "        \"Base Model best Tour\": (best_tour_base, 'o-', f'model_base ({compute_length(coords, best_tour_base):.2f})'),\n",
    "        \"RTD Model best Tour\": (best_tour_rtd, 'o-', f'model_RTD ({compute_length(coords, best_tour_rtd):.2f})'),\n",
    "        \"Optimal Tour (OR-Tools)\": (optimal_tour, 's-', f'optimal ({optimal_length:.2f})')\n",
    "    }\n",
    "    if save:\n",
    "        with open(folder_name+ 'saved_tours.pkl', 'wb') as f:\n",
    "            pickle.dump(tours, f)\n",
    "\n",
    "    plt.figure(figsize=(24, 6))  # 4 horizontal subplots\n",
    "\n",
    "    for i, (title, (tour, style, label)) in enumerate(tours.items(), 1):\n",
    "        plt.subplot(1, 4, i)\n",
    "        plt.scatter(coords[:, 0], coords[:, 1], c='black')\n",
    "        plt.scatter(coords[env_base.start_city, 0], coords[env_base.start_city, 1], c='green', s=100, label=\"start\")\n",
    "        plt.scatter(coords[env_base.end_city, 0], coords[env_base.end_city, 1], c='red', s=100, label=\"end\")\n",
    "\n",
    "        path = np.array([coords[i] for i in tour])\n",
    "        plt.plot(path[:, 0], path[:, 1], style, label=label)\n",
    "        plt.title(title)\n",
    "        plt.legend()\n",
    "        plt.grid()\n",
    "\n",
    "    if save:\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(folder_name + f'Tours_ep_{episodes}_cities_{num_cities}.png')\n",
    "    if plot:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.close()\n",
    "\n",
    "    # ---------- SAVE METRICS TO CSV ----------\n",
    "    if save:\n",
    "        # 1. Save rewards and losses\n",
    "        metrics_path = folder_name + \"metrics.csv\"\n",
    "        df = pd.DataFrame({\n",
    "            \"episode\": list(range(episodes)),\n",
    "            \"reward_base\": rewards_base,\n",
    "            \"reward_rtd\": rewards_rtd,\n",
    "            # \"reward_random\": rewards_random,\n",
    "            \"loss_base\": losses_base,\n",
    "            \"loss_rtd\": losses_rtd\n",
    "        })\n",
    "        df.to_csv(metrics_path, index=False)\n",
    "\n",
    "        # 2. Save parameters to JSON\n",
    "        params_all = {\n",
    "            \"episodes\": episodes,\n",
    "            \"num_cities\": num_cities,\n",
    "            \"params_base\": params_base,\n",
    "            \"params_rtd\": params_rtd\n",
    "        }\n",
    "        with open(folder_name + \"params.json\", \"w\") as f:\n",
    "            json.dump(params_all, f, indent=4)\n",
    "\n",
    "        # 3. Save city coordinates\n",
    "        coords_path = folder_name + \"coords.csv\"\n",
    "        coords_df = pd.DataFrame(coords, columns=[\"x\", \"y\"])\n",
    "        coords_df.to_csv(coords_path, index=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28efd710-49b7-4a1b-8224-7f0d2ab51d38",
   "metadata": {},
   "source": [
    "## Set of experiments with one city for city_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2326cd98-a333-4e1f-a4d9-700d9ae9b3b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "eps_start = 1.0\n",
    "eps_end = 0.05\n",
    "\n",
    "\n",
    "update_wait = 1\n",
    "batch_size = 64\n",
    "\n",
    "episodes = 600\n",
    "\n",
    "eps_stop_decrease_list = [550, 550, 550, 900, 900, 550, 550, 550, 900, 900, 550, 550, 550, 900, 900]\n",
    "episodes_list = [600, 600, 600, 1000, 1000, 600, 600, 600, 1000, 1000, 600, 600, 600, 1000, 1000]\n",
    "lr_list = [1e-3, 1e-3, 1e-3, 1e-4, 1e-4, 1e-3, 1e-3, 1e-3, 1e-4, 1e-4, 1e-3, 1e-3, 1e-3, 1e-4, 1e-4]\n",
    "num_cities_list = [50, 50, 50, 50, 50, 70, 70, 70, 70, 70, 100, 100, 100, 100, 100]\n",
    "\n",
    "coords_list = [TSPEnv(n).coords for n in num_cities_list]\n",
    "\n",
    "for i, (num_cities, coords, lr, episodes, eps_stop_decrease) in enumerate(zip(num_cities_list, coords_list, lr_list, episodes_list, eps_stop_decrease_list)):\n",
    "    for iteration in range(5):\n",
    "        max_buffer = 10000\n",
    "        params_base = {\"batch_size\":batch_size, \"gamma\":0.99, \n",
    "                  \"lr\":1e-3, \"eps_start\":eps_start,\n",
    "                  \"eps_end\":eps_end, \"eps_stop_decrease\": eps_stop_decrease, \"max_buffer\": max_buffer, \"update_wait\": update_wait} \n",
    "        \n",
    "        params_rtd = {\"batch_size\":batch_size, \"gamma\":0.99, \n",
    "                  \"lr\":1e-3, \"eps_start\":eps_start,\n",
    "                  \"eps_end\":eps_end, \"eps_stop_decrease\": eps_stop_decrease, \"beta\" : -1.5, \"max_buffer\": max_buffer, \"pos_encoding\": True, \"update_wait\": update_wait}\n",
    "        \n",
    "        \n",
    "        compare_train_methods(episodes=episodes, num_cities=num_cities, plot = False, save = True, params_base = params_base, params_rtd = params_rtd, prefix = f\"Diff_params_folder_cities_{num_cities}_try_5_beta_1_5_cluster_10/lr_{lr}_fixed_city_{i}\", coords = coords)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01de4fd8-fa98-499a-a875-6c1e443c6d2f",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b6e281-3398-442a-ba32-3b71b2758199",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import pandas as pd\n",
    "import re\n",
    "from collections import defaultdict\n",
    "\n",
    "def average_grouped_metrics(grouped_metrics):\n",
    "    averaged_metrics = {}\n",
    "    for city_id, metrics_list in grouped_metrics.items():\n",
    "        keys = metrics_list[0].keys()\n",
    "        averaged_metrics[city_id] = {}\n",
    "        for key in keys:\n",
    "            if \"optimal\" not in key:\n",
    "                # Stack values into a 2D array: (num_experiments, series_length)\n",
    "                stacked = np.stack([m[key].values for m in metrics_list])\n",
    "                # Compute the mean across experiments\n",
    "                averaged = stacked.mean(axis=0)\n",
    "                averaged_metrics[city_id][key] = averaged\n",
    "        averaged_metrics[city_id][\"optimal_length\"] =metrics_list[0][\"optimal_length\"]\n",
    "    return averaged_metrics\n",
    "\n",
    "def extract_lengths_grouped_by_city(base_path=\"pics\", prefix_template=\"fixed_city_{}\", cities=[0, 1, 2 ,3, 4]):\n",
    "    grouped_data = defaultdict(list)\n",
    "    grouped_metrics = defaultdict(list)\n",
    "\n",
    "    def extract_length(label):\n",
    "        match = re.search(r\"\\(([\\d.]+)\\)\", label)\n",
    "        return float(match.group(1)) if match else None\n",
    "\n",
    "    for city_id in cities:\n",
    "        prefix = prefix_template.format(city_id)\n",
    "\n",
    "        for folder in os.listdir(base_path):\n",
    "            if prefix not in folder:\n",
    "                continue\n",
    "            full_path = os.path.join(base_path, folder)\n",
    "            pkl_path = os.path.join(full_path, \"saved_tours.pkl\")\n",
    "            metrics_path =  os.path.join(full_path, \"metrics.csv\")\n",
    "\n",
    "            if not os.path.isfile(pkl_path):\n",
    "                print(f\"Skipped (no pkl): {folder}\")\n",
    "                continue\n",
    "\n",
    "            try:\n",
    "                with open(pkl_path, \"rb\") as f:\n",
    "                    tours = pickle.load(f)\n",
    "\n",
    "                base_len = extract_length(tours[\"Base Model best Tour\"][2])\n",
    "                rtd_len = extract_length(tours[\"RTD Model best Tour\"][2])\n",
    "                opt_len = extract_length(tours[\"Optimal Tour (OR-Tools)\"][2])\n",
    "\n",
    "                grouped_data[city_id].append({\n",
    "                    \"experiment\": folder,\n",
    "                    \"base_length\": base_len,\n",
    "                    \"rtd_length\": rtd_len,\n",
    "                    \"optimal_length\": opt_len\n",
    "                })\n",
    "\n",
    "                metrics_data = pd.read_csv(metrics_path)\n",
    "                grouped_metrics[city_id].append({\n",
    "                    \"reward_base\": metrics_data[\"reward_base\"],\n",
    "                    \"reward_rtd\": metrics_data[\"reward_rtd\"],\n",
    "                    \"loss_base\": metrics_data[\"loss_base\"],\n",
    "                    \"loss_rtd\": metrics_data[\"loss_rtd\"],\n",
    "                    \"optimal_length\": opt_len\n",
    "                })\n",
    "\n",
    "            except Exception as e:\n",
    "                print(f\"Error reading {pkl_path}: {e}\")\n",
    "\n",
    "    # Convert to DataFrame\n",
    "    all_rows = []\n",
    "    for city_id, records in grouped_data.items():\n",
    "        for rec in records:\n",
    "            rec[\"city_id\"] = city_id\n",
    "            all_rows.append(rec)\n",
    "\n",
    "    df = pd.DataFrame(all_rows)\n",
    "\n",
    "    # Optionally aggregate\n",
    "    df_agg = df.groupby(\"city_id\")[[\"base_length\", \"rtd_length\", \"optimal_length\"]].mean().reset_index()\n",
    "    \n",
    "\n",
    "    return df, df_agg, grouped_metrics\n",
    "\n",
    "# Load every experiment stored in the results folder for city ids 0-4.\n",
    "df_all, df_avg, grouped_metrics = extract_lengths_grouped_by_city(\n",
    "    base_path=\"pics/Diff_params_folder_cities_50_try_5_beta_1_5_cluster_10\",\n",
    "    cities=np.arange(5),\n",
    ")\n",
    "\n",
    "# Last expression of the cell: render the per-experiment table.\n",
    "df_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca581345-2df1-4ef7-bd37-a5464afa48a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "# STIX fonts for regular and math text. The mathtext.rm/it/bf overrides are\n",
    "# only consulted by the 'custom' fontset, so they are not set for 'stix'.\n",
    "matplotlib.rcParams['mathtext.fontset'] = 'stix'\n",
    "matplotlib.rcParams['font.family'] = 'STIXGeneral'\n",
    "matplotlib.rcParams.update({'font.size': 15})\n",
    "\n",
    "metrics = average_grouped_metrics(grouped_metrics)\n",
    "\n",
    "# Plot every city that actually has results rather than assuming ids 0-4,\n",
    "# so a city without metrics is skipped instead of raising KeyError.\n",
    "for city_id in sorted(metrics):\n",
    "    city_metrics = metrics[city_id]\n",
    "    optimal_length = city_metrics[\"optimal_length\"]\n",
    "\n",
    "    # One figure per city: averaged reward curves left, loss curves right.\n",
    "    fig, (ax_length, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "    ax_length.plot(city_metrics[\"reward_base\"], label='DQN')\n",
    "    ax_length.plot(city_metrics[\"reward_rtd\"], label='DQN RTDL')\n",
    "    # NOTE(review): the legend says Concorde while the value is parsed from\n",
    "    # the 'Optimal Tour (OR-Tools)' label - confirm which solver is meant.\n",
    "    ax_length.axhline(optimal_length, color='black', linestyle='--', label=f\"Concorde ({optimal_length:.2f})\")\n",
    "    ax_length.set_title('Tour length per Episode averaged for 5 seeds')\n",
    "    ax_length.set_xlabel('Episode')\n",
    "    ax_length.set_ylabel('Tour length')\n",
    "    ax_length.legend()\n",
    "    ax_length.grid()\n",
    "\n",
    "    ax_loss.plot(city_metrics[\"loss_base\"], label='DQN')\n",
    "    ax_loss.plot(city_metrics[\"loss_rtd\"], label='DQN RTDL')\n",
    "    ax_loss.set_title('Loss per Episode')\n",
    "    ax_loss.set_xlabel('Episode')\n",
    "    ax_loss.set_ylabel('Loss')\n",
    "    ax_loss.legend()\n",
    "    ax_loss.grid()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
