import numpy as np
from torch import Tensor
from ml4co_kit import *
from meta_diffusion.model.decoder.base import MetaDiffDecoder


class MCutDecoder(MetaDiffDecoder):
    """Decode node heatmaps into Max-Cut partitions.

    Only ``"greedy"`` decoding is implemented: heatmap entries above 0.5 are
    labelled 1 and all others 0. An RLSA local-search branch exists in
    ``_node_sparse_decode``, but this version disables post-processing. The
    constructor rejects any ``local_search_type`` other than ``None``, so the
    ``rlsa_*`` settings are stored only for configuration compatibility.
    """

    def __init__(
        self,
        # basic
        decoding_type: str = "greedy",
        local_search_type: str = None,
        # rlsa (disabled)
        rlsa_tau: float = 1.5,
        rlsa_d: int = 20,
        rlsa_k: int = 200,
        rlsa_t: int = 500,
        rlsa_alpha: float = 0.3,
        rlsa_beta: float = None,
        rlsa_device: str = "cuda",
        rlsa_seed: int = 1234
    ) -> None:
        """Store the decoding configuration.

        Args:
            decoding_type: Strategy used by ``_node_sparse_decode``. Only
                ``"greedy"`` is implemented; other values are rejected when
                decoding, not here.
            local_search_type: Post-processing step. Must be ``None`` in this
                version.
            rlsa_tau, rlsa_d, rlsa_k, rlsa_t, rlsa_alpha, rlsa_beta,
            rlsa_device, rlsa_seed: RLSA hyper-parameters. They are stored on
                the instance and are unused while local search is disabled.

        Raises:
            NotImplementedError: If ``local_search_type`` is not ``None``.
        """
        super().__init__()

        # basic
        self.decoding_type = decoding_type
        self.local_search_type = local_search_type
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently enable the disabled branch.
        if local_search_type is not None:
            raise NotImplementedError(
                "This example version does not support post processing."
            )

        # rlsa
        self.rlsa_tau = rlsa_tau
        self.rlsa_d = rlsa_d
        self.rlsa_k = rlsa_k
        self.rlsa_t = rlsa_t
        self.rlsa_alpha = rlsa_alpha
        # Previously accepted but never stored; kept alongside the other
        # RLSA settings for consistency.
        self.rlsa_beta = rlsa_beta
        self.rlsa_device = rlsa_device
        self.rlsa_seed = rlsa_seed

    def _node_sparse_decode(
        self, heatmap: Tensor, graph: Tensor, edge_index: Tensor
    ) -> np.ndarray:
        """Turn a node heatmap into 0/1 partition labels.

        Args:
            heatmap: Model output, converted with ``to_numpy``. Presumably
                per-node scores in [0, 1] -- TODO confirm against the model.
            graph: Used only by the RLSA local search (converted with
                ``to_numpy``).
            edge_index: Used only by the RLSA local search (converted with
                ``to_numpy``).

        Returns:
            An ``np.int32`` array of labels. For greedy decoding without
            local search, it has the same shape as ``heatmap``.

        Raises:
            NotImplementedError: For an unsupported ``decoding_type`` or
                ``local_search_type``.
        """
        # tensor -> numpy array
        heatmap = to_numpy(heatmap)

        # decoding
        if self.decoding_type == "greedy":
            # Entries scoring above 0.5 go to side 1, the rest to side 0.
            sol = (heatmap > 0.5).astype(np.int32)
        else:
            raise NotImplementedError(
                f"Unsupported decoding_type: {self.decoding_type!r}"
            )

        # local search
        if self.local_search_type == "rlsa":
            # Reachable only if local_search_type is changed after
            # construction, because __init__ rejects non-None values.
            # NOTE(review): rlsa_alpha / rlsa_beta are not forwarded. The
            # original left the rlsa_alpha kwarg commented out here, so
            # confirm the callee's signature before passing them.
            sol = mcut_rlsa_local_search(
                init_sol=sol,
                graph=to_numpy(graph),
                edge_index=to_numpy(edge_index),
                rlsa_tau=self.rlsa_tau,
                rlsa_d=self.rlsa_d,
                rlsa_k=self.rlsa_k,
                rlsa_t=self.rlsa_t,
                rlsa_device=self.rlsa_device,
                seed=self.rlsa_seed
            )
        elif self.local_search_type is not None:
            raise NotImplementedError(
                f"Unsupported local_search_type: {self.local_search_type!r}"
            )

        # The local search may return another dtype; normalise to int32
        # without copying when the array is already int32.
        return np.asarray(sol).astype(np.int32, copy=False)
