import torch
from ml4co_kit import ATSPSolver, iterative_execution, SOLVER_TYPE
from meta_diffusion.model import MetaDiffModel


class MetaDiffATSPSolver(ATSPSolver):
    """ATSP solver that produces tours with a ``MetaDiffModel``.

    At construction the wrapped network is put in eval mode and the model's
    environment is switched to ``"solve"`` mode; ``solve`` then runs batched,
    gradient-free inference over the instances stored on the base solver
    (``self.dists`` / ``self.ref_tours``).
    """

    def __init__(self, model: MetaDiffModel, seed: int = 1234):
        """Wrap ``model`` for inference.

        Args:
            model: MetaDiff model. ``model.model.eval()`` is called and
                ``model.env.mode`` is set to ``"solve"``.
            seed: Passed to ``torch.manual_seed``; note this seeds torch's
                global RNG, not a solver-local generator.
        """
        super().__init__(solver_type=SOLVER_TYPE.ML4ATSP)
        self.model = model
        self.model.model.eval()
        self.model.env.mode = "solve"
        torch.manual_seed(seed=seed)

    def solve(
        self,
        batch_size: int = 1,
        sampling_num: int = 1,
        show_time: bool = False,
    ):
        """Solve every loaded instance in mini-batches and return the tours.

        Args:
            batch_size: Number of instances per forward pass. The final batch
                is smaller when the instance count is not a multiple of
                ``batch_size``.
            sampling_num: Forwarded to the data processor's
                ``atsp_batch_data_process``.
            show_time: Forwarded to ``iterative_execution`` (presumably toggles
                progress/timing output — confirm against ml4co_kit).

        Returns:
            ``self.tours`` after the decoded solutions have been stored via
            ``self.from_data(tours=..., ref=False)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        msg = "Solving solutions using MetaDiffATSPSolver"
        samples_num = len(self.dists)
        # Ceil division so the trailing partial batch is solved as well;
        # floor division silently skipped it and left instances without tours.
        batch_num = (samples_num + batch_size - 1) // batch_size
        solutions_list = list()
        for idx in iterative_execution(range, batch_num, msg, show_time):
            begin_idx = idx * batch_size
            end_idx = begin_idx + batch_size  # slicing past the end is safe

            # Slicing None would raise TypeError when no reference tours are
            # loaded. NOTE(review): assumes atsp_batch_data_process accepts
            # ref_tours=None — confirm.
            batch_ref_tours = (
                None if self.ref_tours is None
                else self.ref_tours[begin_idx:end_idx]
            )
            data = self.model.env.data_processor.atsp_batch_data_process(
                dists=self.dists[begin_idx:end_idx],
                ref_tours=batch_ref_tours,
                sampling_num=sampling_num,
            )

            # Inference only: no autograd graph is needed for either path.
            with torch.no_grad():
                if self.model.env.sparse:
                    edge_vars = self.model.inference_edge_sparse_process(*data)
                    solutions = self.model.decoder.sparse_decode(edge_vars, *data)
                else:
                    edge_vars = self.model.inference_edge_dense_process(*data)
                    solutions = self.model.decoder.dense_decode(edge_vars, *data)

            solutions_list.extend(solutions)

        # Store the decoded tours as solver output (not as reference tours).
        self.from_data(tours=solutions_list, ref=False)

        return self.tours