import numpy as np

import torch
from torch import Tensor
from torch.distributions import Categorical

from mas_sat.agent.decide.base import BaseDecideAgent
from mas_sat.graph.base import BaseGraph

class ModelDecideAgent(BaseDecideAgent):
    """
    A decide agent that relies on a model.

    The wrapped model is called as ``model(graph)`` and must return a dict
    with a ``"solution"`` entry; when that entry is ``None`` the dict's
    ``"heuristic"`` entry (a tensor of per-candidate scores) is used to pick
    the action instead.
    """

    def __init__(self, model, args) -> None:
        """
        Args:
            model: Callable invoked as ``model(graph)`` by ``get_action``.
            args: Configuration object; only ``args.learner`` is read here.
        """
        super().__init__()
        self._model = model
        # Only these learners sample their actions; any other learner always
        # takes the arg-max of the heuristic.
        self._stochastic = args.learner in ("reinforce", "multi")

    # basic get/set methods
    @classmethod
    def is_model_based(cls) -> bool:
        """Return ``True``: this agent's decisions come from a model."""
        return True

    # get action methods
    def get_action_from_heurisic(self, graph: BaseGraph, heuristic: Tensor) -> tuple[int, int, dict]:
        """
        Choose an action from the model's heuristic scores.

        NOTE: the method name is misspelled ("heurisic"); it is kept as-is
        because it is part of the class' public interface.

        Args:
            graph: Graph whose ``get_candidate_indices()`` maps a position in
                ``heuristic`` to the corresponding action.
            heuristic: Unnormalised scores; the distribution is taken over
                the last dimension.

        Returns:
            ``(action_idx, action, ret_dict)`` where ``action_idx`` is the
            chosen position in ``heuristic``, ``action`` is the candidate at
            that position, and ``ret_dict`` holds ``"log_prob"`` when the
            action was sampled and is empty otherwise.
        """
        # ``self._training`` is not set in this class; presumably maintained
        # by BaseDecideAgent -- verify against the base class.
        if self._training and self._stochastic:
            # Passing the raw scores as logits is equivalent to
            # Categorical(probs=heuristic.softmax(-1)) but lets the
            # distribution normalise with log-sum-exp, so log_prob is not
            # computed from clamped probabilities.
            dist = Categorical(logits=heuristic)
            action_idx = dist.sample()
            ret_dict = {"log_prob": dist.log_prob(action_idx)}
        else:
            action_idx = torch.argmax(heuristic)
            ret_dict = {}
        action = graph.get_candidate_indices()[action_idx]
        return action_idx.item(), action.item(), ret_dict

    def get_action(self, graph: BaseGraph) -> tuple[int, int, dict]:
        """
        Run the model on ``graph`` and derive the next action.

        Args:
            graph: Graph passed to the model and used to resolve the action.

        Returns:
            ``(action_idx, action, ret_dict)``. ``ret_dict`` is the dict
            returned by the model; when the action was chosen from the
            heuristic it is additionally updated with the entries produced
            by ``get_action_from_heurisic`` (e.g. ``"log_prob"``).
        """
        ret_dict = self._model(graph)
        if ret_dict["solution"] is not None:
            # The model already produced a solution: follow it rather than
            # the heuristic. ``get_action_from_solution`` is not defined in
            # this class (presumably inherited from BaseDecideAgent).
            action_idx, action = self.get_action_from_solution(graph, ret_dict["solution"])
            return action_idx, action, ret_dict
        action_idx, action, ret_dict_ = self.get_action_from_heurisic(graph, ret_dict["heuristic"])
        ret_dict.update(ret_dict_)
        return action_idx, action, ret_dict
