"""Lightning module for training the TSP model."""

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from pytorch_lightning.utilities import rank_zero_info
from scipy.sparse import coo_matrix

from co_datasets.tsp_graph_dataset import TSPGraphDataset
from pl_meta_model_rddm import COMetaModel
from utils.tsp_utils import TSPEvaluator, batched_two_opt_torch, merge_tours, multip_cal_subheatmaps


class TSPModel(COMetaModel):
  """Lightning module for diffusion-based TSP inference.

  * ``test_step`` / ``validation_step`` denoise a heatmap for each instance
    (dense adjacency or sparse edge list), decode it with ``merge_tours`` and
    refine the tours with 2-opt.
  * ``predict_step`` handles one large sparse instance: it samples overlapping
    sub-graphs (``Graph_Sampling``), denoises them in chunks, fuses the
    sub-heatmaps into full-size heatmaps (``Graph_Fusion``) and keeps the best
    decoded tour.
  * Training is not implemented (``training_step`` exits the process).
  """

  def __init__(self,
               param_args=None):
    super(TSPModel, self).__init__(param_args=param_args, node_feature_only=False)

    # Datasets are built in the original order: train, test, validation, predict.
    self.train_dataset = self._build_dataset(self.args.training_split)
    self.test_dataset = self._build_dataset(self.args.test_split)
    self.validation_dataset = self._build_dataset(self.args.validation_split)
    self.predict_dataset = self._build_dataset(self.args.predict_split)

    if self.sparse:
      # Per-call accumulators filled and cleared by predict_step.
      self.all_edge_ranks = []
      self.best_sub_heatmaps = []

  def _build_dataset(self, split):
    """Build a TSPGraphDataset for the file ``split`` under ``args.storage_path``."""
    return TSPGraphDataset(
        data_file=os.path.join(self.args.storage_path, split),
        sparse_factor=self.args.sparse_factor,
    )

  def forward(self, x, adj, t, edge_index):
    """Run the denoising network.

    Note the argument reordering: ``self.model`` is called as
    ``(x, t, adj, edge_index)``.
    """
    return self.model(x, t, adj, edge_index)

  def training_step(self, batch, batch_idx):
    """Terminate the process with exit status 0: this model only supports inference.

    Uses ``sys.exit`` rather than the ``site``-provided ``exit`` builtin, which
    is intended for the interactive shell and is missing under ``python -S``
    or in frozen executables. Both raise ``SystemExit(0)``.
    """
    import sys  # local import: only needed on this terminating path
    sys.exit(0)

  def gaussian_denoise_step(self, points, xt, x_in, t, device, edge_index=None, target_t=None):
    """Take one reverse-diffusion step from time ``t`` to ``target_t`` (no grad).

    Args:
      points: node coordinates fed to the network (cast to float, moved to ``device``).
      xt: current noisy edge variables (cast to float, moved to ``device``).
      x_in: conditioning input; not read by this method.
      t: current time; reshaped to ``xt.shape[0]`` entries.
      device: device on which the network runs.
      edge_index: optional sparse edge list (cast to long) or None for dense graphs.
      target_t: time to step to.

    Returns:
      The updated ``xt`` on ``device``. The posterior update is computed on CPU
      by ``self.gaussian_posterior`` (defined in the base class).
    """
    with torch.no_grad():
      t = t.view(xt.shape[0])
      points = points.float().to(device)
      xt = xt.float().to(device)
      x_res_pred, noise_pred = self.forward(
          points,
          xt,
          t,
          edge_index.long().to(device) if edge_index is not None else None,
      )

      xt = self.gaussian_posterior(target_t.cpu(), t.cpu(), noise_pred.cpu(), x_res_pred.cpu(), xt.cpu())

      return xt.to(device)

  def _reverse_diffusion(self, points, xt, x_in, device, edge_index=None):
    """Denoise ``xt`` from t=1 to t=0 in ``args.inference_diffusion_steps`` steps.

    All steps have size ``1 / steps`` except the last, which consumes whatever
    time remains so the trajectory ends exactly at t=0.
    """
    steps = self.args.inference_diffusion_steps
    cur_time = torch.ones((xt.shape[0],), device=xt.device)
    step = 1.0 / steps
    for i in range(steps):
      s = torch.full((xt.shape[0], ), step, device=xt.device)
      if i == steps - 1:
        s = cur_time
      t1 = cur_time
      t2 = cur_time - s
      xt = self.gaussian_denoise_step(
          points, xt, x_in, t1, device, edge_index, target_t=t2)
      cur_time = cur_time - s
    return xt

  def generate_x_in(self, points, edge_index=None, is_subgraph=False):
    """Encode the identity tour 0 -> 1 -> ... -> n-1 -> 0 as the conditioning input.

    Edges on that tour map to +1 and all other edges to -1.

    * Sub-graph mode (``is_subgraph`` and ``args.use_multi_processing``):
      ``points`` is a sequence of sub-graphs, each assumed to hold
      ``args.sub_graph_size`` nodes. Returns a (num_edges, 1) tensor aligned
      with ``edge_index``.
    * Sparse mode: ``points`` has one row per node. Returns a (num_edges, 1)
      tensor aligned with ``edge_index``.
    * Dense mode: ``points`` is indexed as (batch, n, ...). Returns a
      (batch, n, n) float tensor.

    The sparse variants assume ``edge_index`` lists a fixed number of edges per
    source node, grouped by source node (``min(sparse_factor, sub_graph_size)``
    in sub-graph / sub-graph-sparse mode, ``sparse_factor`` otherwise). TODO:
    confirm against the dataset.
    """
    if is_subgraph and self.args.use_multi_processing:
      tour = np.arange(0, self.args.sub_graph_size)
      tour = np.append(tour, 0)
      # tour_edges[i] is the successor of node i on the identity tour.
      tour_edges = np.zeros(self.args.sub_graph_size, dtype=np.int64)
      tour_edges[tour[:-1]] = tour[1:]
      tour_edges = torch.from_numpy(tour_edges).to(edge_index.device)
      degree = min(self.args.sparse_factor, self.args.sub_graph_size)
      tour_edges = tour_edges.reshape((-1, 1)).repeat(1, degree).reshape(-1)
      # Tile the per-sub-graph pattern once per sub-graph in ``points``.
      tour_edges = tour_edges.repeat(1, len(points)).squeeze()
      x_in = torch.eq(edge_index[1], tour_edges).reshape(-1, 1)
      x_in = x_in.to(edge_index.device)
    elif self.sparse:
      tour = np.arange(0, points.shape[0])
      tour = np.append(tour, 0)
      tour_edges = np.zeros(points.shape[0], dtype=np.int64)
      tour_edges[tour[:-1]] = tour[1:]
      tour_edges = torch.from_numpy(tour_edges).to(points.device)
      if is_subgraph:
        degree = min(self.args.sparse_factor, self.args.sub_graph_size)
      else:
        degree = self.args.sparse_factor
      tour_edges = tour_edges.reshape((-1, 1)).repeat(1, degree).reshape(-1)
      x_in = torch.eq(edge_index[1], tour_edges).reshape(-1, 1)
      x_in = x_in.to(points.device)
    else:
      num_nodes = points.shape[1]
      tour = np.append(np.arange(0, num_nodes), 0)
      x_0 = np.zeros((num_nodes, num_nodes))
      x_0[tour[:-1], tour[1:]] = 1
      x_in = torch.from_numpy(x_0).float()
      x_in = x_in.unsqueeze(0).repeat(points.shape[0], 1, 1)
      x_in = x_in.to(points.device)

    # Map {0, 1} -> {-1, +1}.
    x_in = x_in * 2 - 1
    return x_in

  def test_step(self, batch, batch_idx, split='test'):
    """Solve the instance(s) in ``batch`` and log the tour cost and optimality gap.

    Draws ``args.sequential_sampling`` rounds of ``args.parallel_sampling``
    heatmaps. Each heatmap is decoded with ``merge_tours`` and refined with
    2-opt. The best tour cost and its gap (%) to the ground-truth tour are
    logged under ``{split}/...``.

    Returns:
      dict with ``{split}/gt_cost``, ``{split}/2opt_iterations`` and
      ``{split}/merge_iterations``.
    """
    edge_index = None
    np_edge_index = None
    device = batch[-1].device
    if not self.sparse:
      real_batch_idx, points, adj_matrix, gt_tour = batch
      np_points = points.cpu().numpy()[0]
      np_gt_tour = gt_tour.cpu().numpy()[0]
    else:
      real_batch_idx, graph_data, point_indicator, edge_indicator, gt_tour = batch
      edge_index = graph_data.edge_index
      num_edges = edge_index.shape[1]
      batch_size = point_indicator.shape[0]
      adj_matrix = graph_data.edge_attr.reshape((batch_size, num_edges // batch_size))
      points = graph_data.x.reshape((-1, 2))
      edge_index = edge_index.reshape((2, -1))
      np_points = points.cpu().numpy()
      np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
      np_edge_index = edge_index.cpu().numpy()

    parallel = self.args.parallel_sampling
    # repeat() arguments that copy a tensor ``parallel`` times along dim 0.
    repeat_dims = (parallel, 1) if self.sparse else (parallel, 1, 1)

    x_in_ori = self.generate_x_in(points, edge_index)
    if self.sparse:
      x_in_ori = x_in_ori.reshape((batch_size, num_edges // batch_size))
    if parallel > 1:
      points = points.repeat(*repeat_dims)
      if self.sparse:
        # np_points still holds the un-replicated node count.
        edge_index = self.duplicate_edge_index(edge_index, np_points.shape[0], device)

    stacked_tours = []
    ns, merge_iterations = 0, 0
    for _ in range(self.args.sequential_sampling):
      x_in = x_in_ori.clone()
      epsilon = torch.randn_like(adj_matrix.float())
      if parallel > 1:
        epsilon = epsilon.repeat(*repeat_dims)
        x_in = x_in.repeat(*repeat_dims)
      xt = x_in + epsilon
      xt.requires_grad = True

      if self.sparse:
        xt = xt.reshape(-1)
        x_in = x_in.reshape(-1)

      xt = self._reverse_diffusion(points, xt, x_in, device, edge_index)

      if self.sparse:
        xt = xt.squeeze()
      adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5

      if self.args.save_numpy_heatmap:
        if self.sparse:
          # Scatter the per-edge scores into a dense (num_nodes, num_nodes) matrix.
          tsp_size = points.shape[0]
          heatmap_to_save = coo_matrix(
              (adj_mat, (edge_index[0].cpu().numpy(), edge_index[1].cpu().numpy())),
              shape=(tsp_size, tsp_size)).todense()
        else:
          # Dense heatmaps are already matrices; save them unchanged.
          heatmap_to_save = adj_mat
        self.run_save_numpy_heatmap(heatmap_to_save, np_points, real_batch_idx, split)

      tours, merge_iterations = merge_tours(
          adj_mat, np_points, np_edge_index,
          sparse_graph=self.sparse,
          parallel_sampling=parallel,
      )

      solved_tours, ns = batched_two_opt_torch(
          np_points.astype("float64"), np.array(tours).astype('int64'),
          max_iterations=self.args.two_opt_iterations, device=device)
      stacked_tours.append(solved_tours)

    solved_tours = np.concatenate(stacked_tours, axis=0)

    tsp_solver = TSPEvaluator(np_points)
    gt_cost = tsp_solver.evaluate(np_gt_tour)

    total_sampling = parallel * self.args.sequential_sampling
    all_solved_costs = [tsp_solver.evaluate(solved_tours[i]) for i in range(total_sampling)]
    best_solved_cost = np.min(all_solved_costs)
    gap = (best_solved_cost - gt_cost) / gt_cost * 100

    metrics = {
        f"{split}/gt_cost": gt_cost,
        f"{split}/2opt_iterations": ns,
        f"{split}/merge_iterations": merge_iterations,
    }
    for k, v in metrics.items():
      self.log(k, v, on_epoch=True, sync_dist=True)
    self.log(f"{split}/solved_cost", best_solved_cost, prog_bar=True, on_epoch=True, sync_dist=True)
    self.log(f"{split}/gap(%)", gap, prog_bar=True, on_epoch=True, sync_dist=True)

    return metrics

  def run_save_numpy_heatmap(self, adj_mat, np_points, real_batch_idx, split):
    """Save a heatmap and its points as ``.npy`` files in the logger's run directory.

    Files go to ``<save_dir>/<name>/<version>/numpy_heatmap/`` and are named
    ``{split}-heatmap-{idx}.npy`` / ``{split}-points-{idx}.npy``, where ``idx``
    is the first element of ``real_batch_idx``.
    """
    exp_save_dir = os.path.join(self.logger.save_dir, self.logger.name, str(self.logger.version))
    heatmap_path = os.path.join(exp_save_dir, 'numpy_heatmap')
    rank_zero_info(f"Saving heatmap to {heatmap_path}")
    os.makedirs(heatmap_path, exist_ok=True)
    real_batch_idx = real_batch_idx.cpu().numpy().reshape(-1)[0]
    np.save(os.path.join(heatmap_path, f"{split}-heatmap-{real_batch_idx}.npy"), adj_mat)
    np.save(os.path.join(heatmap_path, f"{split}-points-{real_batch_idx}.npy"), np_points)

  def validation_step(self, batch, batch_idx):
    """Run ``test_step`` with metrics logged under the ``val/`` prefix."""
    return self.test_step(batch, batch_idx, split='val')

  def compute_subgraph_edges(self, edge_index, sub_point_idxes):
    """Extract the edges of the sub-graph induced by ``sub_point_idxes``.

    For every node in ``sub_point_idxes`` (in order), keeps its outgoing edges
    whose target is also in the sub-graph. Each node's edges are padded back to
    ``sparse_factor`` columns by repeating its last kept edge. Node ids stay
    global.

    Assumes ``edge_index`` is (2, num_nodes * sparse_factor) with exactly
    ``sparse_factor`` edges per source, grouped by source node (TODO: confirm).
    NOTE(review): a node with no kept edges contributes zero columns, which
    would break the fixed per-node edge count expected downstream.

    Returns:
      numpy array of shape (2, num_edges).
    """
    sparse_factor = self.args.sparse_factor
    edge_index = edge_index.T.reshape(-1, 2)
    subgraph_edges = [edge_index[idx * sparse_factor:(idx + 1) * sparse_factor, :].T.reshape(2, -1)
                      for idx in sub_point_idxes]
    subgraph_edges = [sub_edges[:, np.isin(sub_edges[1], sub_point_idxes)] for sub_edges in subgraph_edges]
    subgraph_edges = [
        pointi_edge if len(pointi_edge[0]) >= sparse_factor
        else np.column_stack((pointi_edge, np.tile(pointi_edge[:, -1:], sparse_factor - len(pointi_edge[0]))))
        for pointi_edge in subgraph_edges
    ]
    return np.column_stack(subgraph_edges)

  def convert_point_indices(self, ori_arr):
    """Relabel a (2, E) edge tensor from global node ids to local ids 0..k-1.

    Local ids are the rank of each id among the unique source ids (row 0).
    Every target id (row 1) must also appear as a source id; otherwise a
    ``KeyError`` is raised. The result is on the same device as ``ori_arr``.
    """
    device = ori_arr.device
    ori_arr = ori_arr.cpu().numpy()
    unique_values, index = np.unique(ori_arr[0], return_inverse=True)
    replaced_first_row = np.arange(len(unique_values))[index]
    value_to_index = {unique_values[i]: i for i in range(len(unique_values))}
    replaced_second_row = np.array([value_to_index[value] for value in ori_arr[1]])
    res_arr = np.vstack((replaced_first_row, replaced_second_row))
    return torch.from_numpy(res_arr).to(device=device)

  def Graph_Sampling(self, points, edge_index):
    """Cover one large TSP graph with overlapping sub-graphs of ``sub_graph_size`` nodes.

    Each sub-graph is a cluster centre plus its ``sub_graph_size - 1`` nearest
    nodes. Distances are Euclidean lengths of graph edges; node pairs without
    an edge count as distance 100. Each sub-graph's points are shifted and
    scaled so that its bounding box starts at 0 and its longest side is 1.
    The next centre is drawn uniformly among the least-covered nodes. Once
    every node has been covered, neighbours are drawn at random from the
    ``min(2 * sub_graph_size, sparse_factor) - 1`` nearest instead. Sampling
    continues until there are at least ``ceil(5 * n / sub_graph_size)``
    sub-graphs and every node is covered at least ``args.min_visited`` times.

    Uses O(n^2) memory (dense distance and coverage matrices).

    Returns:
      (sub_points, sub_edge_indices, Omega):
        * sub_points: list of rescaled point tensors, one per sub-graph.
        * sub_edge_indices: list of IntTensor edge lists with global node ids.
        * Omega: (n, n) FloatTensor counting, per edge, how many sub-graphs
          contain it.
    """
    node_num = points.shape[0]
    np_points = points.cpu().numpy()
    np_edge_index = edge_index.cpu().numpy()
    sub_graph_size = self.args.sub_graph_size
    top_k = sub_graph_size - 1
    top_k_expand = np.min((sub_graph_size * 2, self.args.sparse_factor)) - 1
    cluster_center = 0

    # Edge lengths; non-edges keep the sentinel distance 100.
    dist = np.ones(shape=(node_num, node_num)) * 100
    start_indices = np_edge_index[0]
    end_indices = np_edge_index[1]
    dist[start_indices, end_indices] = np.linalg.norm(np_points[start_indices] - np_points[end_indices], axis=1)
    dist[end_indices, start_indices] = dist[start_indices, end_indices]
    # Push every node away from itself. This is done in place to avoid
    # allocating another n x n array.
    dist[np.diag_indices(node_num)] += 100.0

    Omega_w = np.zeros(shape=(node_num, ), dtype=np.int32)
    Omega = np.zeros(shape=(node_num, node_num), dtype=np.int32)
    sub_points = []
    sub_edge_indices = []

    neighbor = np.argpartition(dist, kth=top_k, axis=1)
    neighbor_expand = np.argpartition(dist, kth=top_k_expand, axis=1)

    num_clusters_threshold = math.ceil((node_num / (top_k + 1)) * 5)
    all_visited = False
    num_clusters = 0
    min_visited = 0

    while num_clusters < num_clusters_threshold or min_visited < self.args.min_visited:
      if not all_visited:
        cluster_center_neighbor = neighbor[cluster_center, :top_k]
      else:
        # Randomise which of the wider neighbourhood joins the cluster.
        np.random.shuffle(neighbor_expand[cluster_center, :top_k_expand])
        cluster_center_neighbor = neighbor_expand[cluster_center, :top_k]
      cluster_center_neighbor = np.insert(cluster_center_neighbor, 0, cluster_center)
      cluster_center_neighbor = np.sort(cluster_center_neighbor)

      Omega_w[cluster_center_neighbor] += 1

      # Rescale sub-graph coordinates (advanced indexing yields a copy).
      tmp_sub_points = points[cluster_center_neighbor]
      x_min = tmp_sub_points[:, 0].min()
      x_max = tmp_sub_points[:, 0].max()
      y_min = tmp_sub_points[:, 1].min()
      y_max = tmp_sub_points[:, 1].max()
      s = 1 / max((x_max - x_min), (y_max - y_min))
      tmp_sub_points[:, 0] = s * (tmp_sub_points[:, 0] - x_min)
      tmp_sub_points[:, 1] = s * (tmp_sub_points[:, 1] - y_min)
      sub_points.append(tmp_sub_points)

      sub_edge_index = self.compute_subgraph_edges(np_edge_index, cluster_center_neighbor)
      sub_edge_indices.append(torch.IntTensor(sub_edge_index))

      # Fancy-index += counts each distinct edge at most once per sub-graph.
      Omega[sub_edge_index[0], sub_edge_index[1]] += 1
      num_clusters += 1

      if 0 not in Omega_w:
        all_visited = True
      min_visited = Omega_w.min()

      cluster_center = np.random.choice(np.where(Omega_w == np.min(Omega_w))[0])

    return sub_points, sub_edge_indices, torch.FloatTensor(Omega)

  def Graph_Fusion(self, sub_heatmaps, omegas, all_edge_ranks, node_num, fusion_method):
    """Fuse sub-graph heatmaps into one full-size heatmap (numpy array).

    ``sub_heatmaps`` are presumably sparse full-size torch tensors (the
    "mean"/"sum" paths use ``torch.sparse.sum``). TODO: confirm.

    * "mean": elementwise sum divided by ``omegas`` (+1e-8).
    * "max": elementwise maximum.
    * "sum": elementwise sum, symmetrised (H + H^T), then each row
      L2-normalised.

    ``all_edge_ranks`` and ``node_num`` are accepted but unused.

    Raises:
      ValueError: for an unknown ``fusion_method``.
    """
    stacked_heatmaps = torch.stack(sub_heatmaps)
    if fusion_method == "mean":
      heatmap = torch.sparse.sum(stacked_heatmaps, dim=0).to_dense()
      heatmap_res = (heatmap / (omegas + 1e-8)).numpy()
    elif fusion_method == "max":
      # NOTE(review): torch.max over stacked sparse tensors may be unsupported; confirm.
      heatmap_max, _ = torch.max(stacked_heatmaps, dim=0)
      heatmap_res = heatmap_max.numpy()
    elif fusion_method == "sum":
      heatmap = torch.sparse.sum(stacked_heatmaps, dim=0).to_dense()
      heatmap = heatmap + heatmap.T
      heatmap_res = F.normalize(heatmap, dim=1).numpy()
    else:
      raise ValueError(f"Unknown fusion_method: {fusion_method!r}")
    return heatmap_res

  def _denoise_sub_graph_chunk(self, points, chunk_points, chunk_edges, device):
    """Jointly denoise a chunk of sub-graphs and collect their sub-heatmaps.

    The sub-graphs are stacked into one batched sparse graph, replicated
    ``args.parallel_sampling`` times, and denoised together. The per-sub-graph
    samples are then passed to ``multip_cal_subheatmaps``, whose two outputs
    are appended to ``self.best_sub_heatmaps`` and ``self.all_edge_ranks``.
    """
    num_sub_graphs = len(chunk_points)
    parallel = self.args.parallel_sampling
    all_sub_points = torch.vstack(chunk_points)
    num_stacked_points = all_sub_points.shape[0]  # before parallel replication

    # Local (per-sub-graph) node ids, concatenated along the edge axis.
    edge_index = torch.hstack(
        [self.convert_point_indices(sub_edges) for sub_edges in chunk_edges]).to(device)
    x_in = self.generate_x_in(chunk_points, edge_index, is_subgraph=True)
    noise = torch.randn_like(x_in.float())

    # Shift sub-graph k's ids by k * sub_graph_size so they index the stacked points.
    offsets = torch.arange(0, num_sub_graphs).to(device) * self.args.sub_graph_size
    offsets = offsets.reshape(-1, 1).repeat(
        (1, self.args.sparse_factor * chunk_points[0].shape[0])).reshape(-1)
    edge_index = (edge_index + offsets).reshape((2, -1)).to(device)

    if parallel > 1:
      all_sub_points = all_sub_points.repeat(parallel, 1)
      edge_index = self.duplicate_edge_index(edge_index, num_stacked_points, device)
      noise = noise.repeat(parallel, 1)
      x_in = x_in.repeat(parallel, 1)
    xt = x_in + noise
    xt.requires_grad = True
    xt = xt.reshape(-1)
    x_in = x_in.reshape(-1)

    xt = self._reverse_diffusion(all_sub_points, xt, x_in, device, edge_index)
    sub_heatmaps = xt.squeeze().cpu().detach().numpy() * 0.5 + 0.5

    # Pieces sub_i, sub_i + num_sub_graphs, ... are the parallel samples of sub-graph sub_i.
    pieces = np.split(sub_heatmaps, parallel * num_sub_graphs, axis=0)
    for sub_i in range(num_sub_graphs):
      samples = np.hstack(pieces[sub_i::num_sub_graphs])
      resi = multip_cal_subheatmaps(self.args, points, samples,
                                    chunk_points[sub_i].cpu().numpy(), chunk_edges[sub_i])
      self.best_sub_heatmaps.extend(resi[0])
      self.all_edge_ranks.extend(resi[1])

  def _fuse_candidate_heatmaps(self, omegas, node_num):
    """Build ``args.num_trials + 2`` full-size candidate heatmaps.

    Trial i picks, for every sub-graph, one of its ``parallel_sampling``
    sub-heatmaps at random and fuses them with ``args.fusion_method``. The last
    two candidates fuse all sub-heatmaps with "mean" and then "sum".
    """
    parallel = self.args.parallel_sampling
    groups = [self.best_sub_heatmaps[k:k + parallel]
              for k in range(0, len(self.best_sub_heatmaps), parallel)]
    picks = np.random.randint(0, parallel, (self.args.num_trials, len(groups)))

    heatmaps = []
    for trial_picks in picks:
      chosen = [group[pick] for group, pick in zip(groups, trial_picks)]
      heatmaps.append(self.Graph_Fusion(chosen, omegas, [], node_num=node_num,
                                        fusion_method=self.args.fusion_method))
    for method in ("mean", "sum"):
      heatmaps.append(self.Graph_Fusion(self.best_sub_heatmaps, omegas, [],
                                        node_num=node_num, fusion_method=method))
    return heatmaps

  def predict_step(self, batch, batch_idx, split="predict"):
    """Solve one large sparse TSP instance via sub-graph sampling and fusion.

    The steps are:
      1. Sample sub-graphs (``Graph_Sampling``).
      2. Denoise them in chunks of ``args.multip_batchsize`` (all at once if
         unset).
      3. Fuse the sub-heatmaps into ``num_trials + 2`` candidate heatmaps.
      4. Decode the candidates with ``merge_tours``.
    If every entry of ``point_indicator`` is below 10000, all candidate tours
    are refined with 2-opt. Otherwise only the best raw candidate and the
    "mean"/"sum" tours are refined. The heatmap whose refined tour is cheapest
    is saved when ``args.save_numpy_heatmap`` is set.

    Returns:
      dict with ``{split}/gt_cost``, ``/2opt_iterations``,
      ``/merge_iterations``, ``/solved_cost`` and ``/gap(%)`` (also printed).

    Raises:
      ValueError: if the model is dense, or ``args.use_multi_processing`` is
        off (the only sub-heatmap path this method implements).
    """
    if not self.sparse:
      raise ValueError("Dense tsp not supported!")
    if not self.args.use_multi_processing:
      raise ValueError(
          "predict_step only implements sub-graph inference with use_multi_processing enabled.")
    device = batch[-1].device

    real_batch_idx, graph_data, point_indicator, edge_indicator, gt_tour = batch
    points = graph_data.x.reshape((-1, 2))
    edge_index = graph_data.edge_index.reshape((2, -1))
    np_points = points.cpu().numpy()
    np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
    np_edge_index = edge_index.cpu().numpy()

    # Start clean even if a previous call aborted part-way.
    self.all_edge_ranks = []
    self.best_sub_heatmaps = []

    sub_points, sub_edge_indices, Omegas = self.Graph_Sampling(points, edge_index)

    chunk = self.args.multip_batchsize or max(len(sub_points), 1)
    for start in range(0, len(sub_points), chunk):
      self._denoise_sub_graph_chunk(
          points, sub_points[start:start + chunk], sub_edge_indices[start:start + chunk], device)

    heatmap_all = self._fuse_candidate_heatmaps(Omegas, points.shape[0])
    num_candidates = len(heatmap_all)  # num_trials + 2

    tsp_solver = TSPEvaluator(np_points)
    gt_cost = tsp_solver.evaluate(np_gt_tour)
    tours, merge_iterations = merge_tours(
        np.concatenate(heatmap_all, axis=0), np_points, np_edge_index,
        sparse_graph=self.sparse,
        parallel_sampling=num_candidates,
        ls_test=True
    )

    # NOTE(review): point_indicator presumably holds per-graph node counts; confirm.
    if (point_indicator < 10000).all():
      # Refine every candidate tour, in chunks of 2 * multip_batchsize if configured.
      if self.args.multip_batchsize:
        two_opt_chunk = self.args.multip_batchsize * 2
      else:
        two_opt_chunk = max(len(tours), 1)
      stacked_tours = []
      for start in range(0, len(tours), two_opt_chunk):
        solved_chunk, ns = batched_two_opt_torch(
            np_points.astype("float64"), np.array(tours[start:start + two_opt_chunk]).astype('int64'),
            max_iterations=self.args.two_opt_iterations, device=device)
        stacked_tours.append(solved_chunk)
      solved_tours = np.concatenate(stacked_tours, axis=0)

      all_solved_costs = [tsp_solver.evaluate(solved_tours[i]) for i in range(num_candidates)]
      best_heatmap = heatmap_all[int(np.argmin(all_solved_costs))]
    else:
      # Very large instance: rank raw tours, then refine only the best one
      # together with the "sum" (last) and "mean" (second-to-last) tours.
      raw_costs = [tsp_solver.evaluate(tours[i]) for i in range(num_candidates)]
      best_raw_idx = int(np.argmin(raw_costs))
      candidate_tours = [tours[best_raw_idx], tours[-1], tours[-2]]
      candidate_heatmaps = [heatmap_all[best_raw_idx], heatmap_all[-1], heatmap_all[-2]]
      solved_tours, ns = batched_two_opt_torch(
          np_points.astype("float64"), np.array(candidate_tours).astype('int64'),
          max_iterations=self.args.two_opt_iterations, device=device)
      all_solved_costs = [tsp_solver.evaluate(solved_tours[i]) for i in range(len(candidate_tours))]
      # Pick the heatmap whose *refined* tour is cheapest.
      best_heatmap = candidate_heatmaps[int(np.argmin(all_solved_costs))]

    best_solved_cost = np.min(all_solved_costs)
    gap = (best_solved_cost - gt_cost) / gt_cost * 100

    # Release the accumulated sub-heatmaps.
    self.all_edge_ranks = []
    self.best_sub_heatmaps = []

    if self.args.save_numpy_heatmap:
      self.run_save_numpy_heatmap(best_heatmap, np_points, real_batch_idx, split)

    metrics = {
        f"{split}/gt_cost": gt_cost,
        f"{split}/2opt_iterations": ns,
        f"{split}/merge_iterations": merge_iterations,
        f"{split}/solved_cost": best_solved_cost,
        f"{split}/gap(%)": gap,
    }
    print(metrics)

    return metrics