import numpy as np
from torch import Tensor
from ml4co_kit import *
from meta_diffusion.model.decoder.base import MetaDiffDecoder


class MClDecoder(MetaDiffDecoder):
    """Decode a model heatmap into an integer solution using ml4co_kit ``mcl_*`` helpers.

    Decoding is done by ``mcl_greedy_decoder`` or ``mcl_beam_decoder``,
    selected by ``decoding_type``. An RLSA local-search step
    (``mcl_rlsa_local_search``) is wired into ``_node_sparse_decode``, but
    ``__init__`` currently rejects any ``local_search_type`` other than
    ``None``, so that branch is not reachable in this version.

    NOTE(review): the ``mcl_`` prefix presumably stands for Maximum Clique —
    confirm against ml4co_kit.
    """

    def __init__(
        self,
        # basic
        decoding_type: str = "greedy",
        local_search_type: str = None,
        # rlsa (disabled)
        rlsa_tau: float = 0.01,
        rlsa_d: int = 2,
        rlsa_k: int = 200,
        rlsa_t: int = 500,
        rlsa_beta: float = 1.02,
        rlsa_alpha: float = 0.3,
        rlsa_device: str = "cuda",
        rlsa_seed: int = 1234,
        # beam search (appended last so existing positional calls still work)
        beam_size: int = 16,
    ) -> None:
        """Store decoding / post-processing configuration.

        Args:
            decoding_type: ``"greedy"`` or ``"beam"``; any other value makes
                ``_node_sparse_decode`` raise ``NotImplementedError``.
            local_search_type: must be ``None`` in this version.
            rlsa_tau, rlsa_d, rlsa_k, rlsa_t, rlsa_beta, rlsa_alpha,
            rlsa_device, rlsa_seed: forwarded unchanged to
                ``mcl_rlsa_local_search`` (currently unreachable).
            beam_size: beam width passed to ``mcl_beam_decoder`` when
                ``decoding_type == "beam"``.
                NOTE(review): default of 16 chosen to mirror the ml4co_kit
                helper's default — confirm.

        Raises:
            AssertionError: if ``local_search_type`` is not ``None``.
        """
        super().__init__()

        # basic
        self.decoding_type = decoding_type
        self.local_search_type = local_search_type
        # Explicit check rather than ``assert`` so it survives ``python -O``;
        # AssertionError is kept so the exception type seen by callers is unchanged.
        if local_search_type is not None:
            raise AssertionError(
                "This example version does not support post processing."
            )

        # beam search: read by the "beam" branch of ``_node_sparse_decode``,
        # which previously failed with AttributeError because it was never set.
        self.beam_size = beam_size

        # rlsa
        self.rlsa_tau = rlsa_tau
        self.rlsa_d = rlsa_d
        self.rlsa_k = rlsa_k
        self.rlsa_t = rlsa_t
        self.rlsa_beta = rlsa_beta
        self.rlsa_alpha = rlsa_alpha
        self.rlsa_device = rlsa_device
        self.rlsa_seed = rlsa_seed

    def _node_sparse_decode(
        self, heatmap: Tensor, graph: Tensor
    ) -> np.ndarray:
        """Decode ``heatmap`` on ``graph`` and return the solution as ``np.int32``.

        Args:
            heatmap: model scores, converted with ``to_numpy`` before decoding.
                NOTE(review): presumably one score per node — confirm.
            graph: graph structure, converted with ``to_numpy`` before decoding.
                NOTE(review): presumably a dense adjacency matrix — confirm.

        Returns:
            The decoder's (optionally locally-searched) solution cast to ``np.int32``.

        Raises:
            NotImplementedError: for an unknown ``decoding_type`` or a
                non-``None``, non-``"rlsa"`` ``local_search_type``.
        """
        # tensor -> numpy array
        heatmap = to_numpy(heatmap)
        np_graph = to_numpy(graph)

        # decoding
        if self.decoding_type == "greedy":
            sol = mcl_greedy_decoder(
                heatmap=heatmap, graph=np_graph
            )
        elif self.decoding_type == "beam":
            sol = mcl_beam_decoder(
                heatmap=heatmap, graph=np_graph, beam_size=self.beam_size
            )
        else:
            raise NotImplementedError()

        # local search (only reachable if the __init__ check is relaxed)
        if self.local_search_type == "rlsa":
            sol = mcl_rlsa_local_search(
                init_sol=sol,
                graph=np_graph,
                rlsa_tau=self.rlsa_tau,
                rlsa_d=self.rlsa_d,
                rlsa_k=self.rlsa_k,
                rlsa_t=self.rlsa_t,
                rlsa_alpha=self.rlsa_alpha,
                rlsa_beta=self.rlsa_beta,
                rlsa_device=self.rlsa_device,
                seed=self.rlsa_seed
            )
        elif self.local_search_type is not None:
            raise NotImplementedError()

        return sol.astype(np.int32)