import numpy as np
from torch import Tensor
from ml4co_kit import *
from meta_diffusion.model.decoder.base import MetaDiffDecoder


class ATSPDecoder(MetaDiffDecoder):
    """Decode ATSP tours from a dense edge heatmap.

    An initial tour is built from the heatmap with the configured decoding
    strategy (only ``"greedy"`` is implemented) and then refined with
    2-opt local search against the actual distance matrix.
    """

    def __init__(
        self, 
        decoding_type: str = "greedy", 
    ) -> None:
        """
        Args:
            decoding_type: Strategy used to build the initial tour from the
                heatmap. Only ``"greedy"`` is supported; any other value
                raises ``NotImplementedError`` when ``_edge_dense_decode``
                is called.
        """
        super().__init__()
        self.decoding_type = decoding_type
    
    def _edge_dense_decode(
        self, heatmap: Tensor, x: Tensor, graph: Tensor, 
    ) -> np.ndarray:
        """Turn a dense edge heatmap into a 2-opt-refined ATSP tour.

        Args:
            heatmap: Dense edge scores passed (negated) to
                ``atsp_greedy_decoder``; assumed to be a square (n, n)
                matrix — TODO confirm against callers.
            x: Not used by this decoder; presumably kept so the signature
                matches the base-class call site — verify.
            graph: Distance matrix handed to ``atsp_2opt_local_search``.

        Returns:
            Whatever ``atsp_2opt_local_search`` returns for the greedy
            initial tour(s).

        Raises:
            NotImplementedError: If ``self.decoding_type`` is not
                ``"greedy"``.
        """
        # tensor -> numpy array. Take an explicit copy of the heatmap:
        # fill_diagonal below writes in place, and to_numpy may hand back
        # memory shared with the caller's tensor (torch's ``.numpy()`` does
        # so for CPU tensors), which would silently mutate the input.
        dists = to_numpy(graph)
        scores = np.array(to_numpy(heatmap), copy=True)
        # Zero the diagonal so self-loop edges (i -> i) carry no score.
        np.fill_diagonal(scores, 0)

        # decoding
        if self.decoding_type != "greedy":
            raise NotImplementedError(
                f"Unsupported decoding_type {self.decoding_type!r} for ATSP; "
                "only 'greedy' is implemented."
            )
        # Scores are negated before greedy decoding — presumably the greedy
        # decoder treats its input as a cost and prefers small values, so
        # high-scoring edges are chosen first; confirm in ml4co_kit.
        init_tours = atsp_greedy_decoder(-scores)

        # Refine the initial tour(s) with 2-opt on the true distances.
        return atsp_2opt_local_search(init_tours=init_tours, dists=dists)