import torch
import numpy as np
from ml4co_kit import MISSolver, iterative_execution, SOLVER_TYPE
from meta_diffusion.model import MetaDiffModel


class MetaDiffMISSolver(MISSolver):
    """MIS solver that labels graphs using a trained ``MetaDiffModel``.

    On construction the wrapped network is put into eval mode and the model's
    environment is switched to ``"solve"`` mode. ``solve`` then runs batched
    inference over ``self.graph_data`` and writes each decoded solution back
    to the corresponding graph's ``nodes_label``.
    """

    def __init__(self, model: MetaDiffModel, seed: int = 1234):
        """Wrap ``model`` for inference and seed torch's RNG.

        Args:
            model: Trained MetaDiffModel. ``model.model.eval()`` is called and
                ``model.env.mode`` is set to ``"solve"``.
            seed: Value passed to ``torch.manual_seed`` for reproducible
                inference. Note this seeds torch's global default generator.
        """
        super(MetaDiffMISSolver, self).__init__(solver_type=SOLVER_TYPE.ML4MIS)
        self.model = model
        self.model.model.eval()
        self.model.env.mode = "solve"
        torch.manual_seed(seed=seed)

    def solve(self, batch_size: int = 1, show_time: bool = False) -> None:
        """Solve every graph in ``self.graph_data`` in place.

        Graphs are processed in consecutive chunks of ``batch_size``. The
        final chunk may be smaller when the number of graphs is not a
        multiple of ``batch_size``. Each graph's ``nodes_label`` is set to
        the solution decoded for it.

        Args:
            batch_size: Number of graphs per inference batch; must be >= 1.
            show_time: Forwarded to ``iterative_execution`` (presumably
                toggles progress/timing display — confirm in ml4co_kit).

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        msg = "Solving solutions using MetaDiffMISSolver"
        samples_num = len(self.graph_data)
        # Ceil division: a trailing partial batch must still be solved,
        # otherwise the last `samples_num % batch_size` graphs stay unlabeled.
        batches_num = (samples_num + batch_size - 1) // batch_size

        for batch_idx in iterative_execution(range, batches_num, msg, show_time):
            begin_idx = batch_idx * batch_size
            end_idx = min(begin_idx + batch_size, samples_num)

            # Convert this slice of graphs into the model's batched input.
            data = self.model.env.data_processor.mis_batch_data_process(
                graph_data=self.graph_data[begin_idx:end_idx]
            )

            # Infer per-node variables, then decode them into MIS solutions.
            with torch.no_grad():
                if self.model.env.sparse:
                    node_vars = self.model.inference_node_sparse_process(*data)
                    solutions = self.model.decoder.sparse_decode(node_vars, *data)
                else:
                    node_vars = self.model.inference_node_dense_process(*data)
                    solutions = self.model.decoder.dense_decode(node_vars, *data)

            # The decoders yield one solution per graph in the batch, in order.
            # Iterate over the real batch length, which may be < batch_size
            # for the last batch.
            for offset, graph_idx in enumerate(range(begin_idx, end_idx)):
                self.graph_data[graph_idx].nodes_label = solutions[offset]