import numpy as np
from torch import Tensor
from ml4co_kit import TSPEvaluator
from ml4co_kit import *
from meta_diffusion.model.decoder.base import MetaDiffDecoder

class TSPDecoder(MetaDiffDecoder):
    """Decode TSP edge heatmaps into solutions.

    Both the sparse and the dense entry points share one pipeline:
    symmetrize the heatmap, clip it into [1e-14, 1 - 1e-14], zero the
    diagonal, build an initial solution with ``tsp_greedy_decoder``, and
    refine it with ``tsp_mcts_local_search``.
    """

    def __init__(
        self,
        decoding_type: str = "greedy",
        mcts_time_limit: float = 0.0,  # disabled
    ) -> None:
        """Configure the decoder.

        Args:
            decoding_type: Stored on the instance. NOTE(review): the decode
                methods in this class always use greedy decoding and never
                read this value.
            mcts_time_limit: Time limit forwarded to ``tsp_mcts_local_search``.
                Only ``0.0`` is supported in this version.

        Raises:
            ValueError: If ``mcts_time_limit`` is not ``0.0``.
        """
        super().__init__()
        self.decoding_type = decoding_type
        self.mcts_time_limit = mcts_time_limit
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept an unsupported setting.
        if mcts_time_limit != 0.0:
            raise ValueError("this example version does not support MCTS.")

    def _greedy_search_decode(
        self, heatmap: np.ndarray, points: np.ndarray
    ) -> np.ndarray:
        """Post-process a dense heatmap and decode it into a solution.

        Args:
            heatmap: Dense edge-score matrix; presumably (N, N) — confirm.
            points: Node coordinates, forwarded as ``points`` to local search.

        Returns:
            Whatever ``tsp_mcts_local_search`` returns for the greedy initial
            solution (presumably a tour — confirm against ml4co_kit).
        """
        # Average with the transpose so (i, j) and (j, i) share one score.
        # This also allocates a new array, so the caller's input is untouched.
        heatmap = (heatmap + heatmap.T) / 2
        # Keep scores strictly away from 0 and 1.
        heatmap = np.clip(heatmap, a_min=1e-14, a_max=1 - 1e-14)
        # A self-loop is never a valid tour edge; zero it on every path so
        # diagonal scores cannot influence decoding.
        np.fill_diagonal(heatmap, 0)

        sol = tsp_greedy_decoder(heatmap)
        return tsp_mcts_local_search(
            init_tours=sol,
            heatmap=heatmap,
            points=points,
            time_limit=self.mcts_time_limit,
            type_2opt=2,
        )

    def _edge_sparse_decode(
        self, heatmap: Tensor, x: Tensor, edge_index: Tensor, nodes_num: int
    ) -> np.ndarray:
        """Decode a heatmap given as scores on the sparse edges ``edge_index``.

        Args:
            heatmap: Per-edge scores, used as ``edge_attr`` for ``edge_index``.
            x: Node coordinates, forwarded as ``points`` to local search.
            edge_index: Edge list used to scatter ``heatmap`` into a dense
                matrix.
            nodes_num: Number of nodes; size passed to ``np_sparse_to_dense``.

        Returns:
            The solution produced by the shared greedy + local-search pipeline.
        """
        # tensor -> numpy array
        heatmap = to_numpy(heatmap)
        x = to_numpy(x)
        edge_index = to_numpy(edge_index)

        # heatmap: sparse -> dense. NOTE(review): edge_index may contain
        # self-loops (e.g. a kNN graph that counts each node as its own
        # neighbour); the shared pipeline zeroes the diagonal, so those
        # entries are ignored.
        dense_heatmap = np_sparse_to_dense(
            nodes_num=nodes_num, edge_index=edge_index, edge_attr=heatmap
        )
        return self._greedy_search_decode(dense_heatmap, x)

    def _edge_dense_decode(
        self, heatmap: Tensor, x: Tensor, graph: Tensor,
    ) -> np.ndarray:
        """Decode a dense edge heatmap.

        Args:
            heatmap: Dense edge-score matrix.
            x: Node coordinates, forwarded as ``points`` to local search.
            graph: Not used here; kept so the signature stays unchanged
                (presumably matching the base-class interface — confirm).

        Returns:
            The solution produced by the shared greedy + local-search pipeline.
        """
        # tensor -> numpy array
        x = to_numpy(x)
        heatmap = to_numpy(heatmap)
        return self._greedy_search_decode(heatmap, x)