from abc import abstractmethod
import numpy as np

import torch

from mas_sat.graph.base import BaseGraph

class BaseDecideAgent(object):
    """
    Base class for agents that choose one candidate of a graph per step.

    Subclasses implement ``get_action``. ``act`` wraps it with epsilon-greedy
    exploration: with probability ``eps`` it falls back to
    ``get_random_action`` instead.

    NOTE(review): ``@abstractmethod`` on ``get_action`` is NOT enforced,
    because this class does not use ``abc.ABCMeta``. The base implementation
    therefore raises ``NotImplementedError`` so a missing override fails
    loudly instead of returning ``None``.
    """

    def __init__(self) -> None:
        # Mode flag toggled by train()/eval(). It is not read in this class;
        # presumably subclasses consult it. Confirm against subclasses.
        self._training = False

    # basic get/set methods
    @classmethod
    def is_model_based(cls) -> bool:
        """Return whether this agent type is model-based (False for the base class)."""
        return False

    def train(self) -> None:
        """Switch the agent to training mode."""
        self._training = True

    def eval(self) -> None:
        """Switch the agent to evaluation mode."""
        self._training = False

    # get action methods
    @abstractmethod
    def get_action(self, graph: "BaseGraph") -> tuple[int, int, dict]:
        """
        graph -> action_idx, action, ret_dict

        Must be overridden by subclasses. ``action_idx`` is the chosen
        position within ``graph.get_candidate_indices()`` and ``action`` is the
        candidate stored there (the same convention as ``get_random_action``).

        NOTE(review): the exploration branch of ``act`` builds ``ret_dict``
        with keys ``"updated_graph"`` and ``"model_step"``, so overrides
        presumably return the same keys. Confirm with callers of ``act``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_action()"
        )

    # main interfaces
    def get_random_action(self, graph: "BaseGraph") -> tuple[int, int]:
        """
        Uniformly select a candidate.

        Returns:
            (action_idx, action): the sampled position within
            ``graph.get_candidate_indices()`` and the candidate at that position.

        Raises:
            ValueError: if the graph has no candidates.
        """
        num_candidates = graph.get_candidate_num()
        if num_candidates <= 0:
            raise ValueError("graph has no candidates to choose from")
        action_idx = np.random.choice(num_candidates)
        action = graph.get_candidate_indices()[action_idx]
        return action_idx, action

    def get_action_from_solution(
        self, graph: "BaseGraph", solution: torch.Tensor
    ) -> tuple[int, int]:
        """
        Uniformly select a candidate that also appears in ``solution``.

        Args:
            graph: graph whose ``get_candidate_indices()`` tensor is searched.
            solution: tensor of values. A candidate qualifies when its value
                is contained in ``solution`` (via ``torch.isin``).

        Returns:
            (action_idx, action): ``action_idx`` is the position within the
            candidate tensor (a 0-d tensor taken from ``torch.where``) and
            ``action`` is the candidate at that position.

        Raises:
            ValueError: if no candidate appears in ``solution``.
        """
        candidates = graph.get_candidate_indices()
        in_solution = torch.isin(candidates, solution)
        # int() works on any device, so no explicit .cpu() is needed before
        # handing the count to numpy.
        num_matches = int(in_solution.sum())
        if num_matches == 0:
            raise ValueError("no candidate of the graph appears in the solution")
        # Pick the k-th matching candidate, then map back to its position.
        pick = np.random.choice(num_matches)
        action_idx = torch.where(in_solution)[0][pick]
        action = candidates[action_idx]
        return action_idx, action

    def act(self, graph: "BaseGraph", eps: float = 0.0) -> tuple[int, int, dict]:
        """
        Choose an action epsilon-greedily.

        With probability ``eps``, a uniformly random candidate is returned,
        together with ``{"updated_graph": graph, "model_step": 0}``.
        Otherwise the result of ``get_action(graph)`` is returned unchanged.

        Args:
            graph: the graph to act on.
            eps: exploration probability. The default 0.0 never explores.

        Returns:
            (action_idx, action, ret_dict)
        """
        if np.random.random() < eps:
            action_idx, action = self.get_random_action(graph)
            ret_dict = {
                "updated_graph": graph,
                "model_step": 0
            }
        else:
            action_idx, action, ret_dict = self.get_action(graph)
        return action_idx, action, ret_dict
