import torch

from utils.snippets import convert_args_to_tensor

class ArgMaxPolicy(object):
    """Greedy policy that selects the action with the highest critic Q-value.

    The critic is only required to expose ``q_t_values(observation)``; its
    output is indexed along dimension 1 when picking actions, so it is
    presumably shaped ``(batch, num_actions)`` -- TODO confirm against the
    critic implementation.
    """

    def __init__(self, critic, device='cpu'):
        """Store the critic and the device observations are moved to.

        Args:
            critic: Object providing a ``q_t_values(observation)`` method.
            device: Target passed to ``Tensor.to`` for every observation
                before it is handed to the critic.
        """
        self.critic = critic
        self.device = device

    def _batched_q_values(self, obs):
        """Return the critic's Q-values for ``obs`` as a tensor, without gradients.

        Shared by ``get_action`` and ``get_action_values`` so that both apply
        exactly the same batching and device handling.

        Args:
            obs: Observation tensor. An input with at most one dimension is
                treated as a single observation and gets a leading batch
                axis; anything with more dimensions is passed through as-is.
                NOTE(review): a single observation that is itself
                multi-dimensional (e.g. an image) is therefore NOT
                unsqueezed -- confirm callers always batch such inputs.

        Returns:
            Whatever ``self.critic.q_t_values`` returns for the batched
            observation placed on ``self.device``.
        """
        observation = obs if obs.ndim > 1 else obs[None]
        # Inference only: skip autograd bookkeeping so the result can be
        # converted to NumPy directly by the callers.
        with torch.no_grad():
            return self.critic.q_t_values(observation.to(self.device))

    @convert_args_to_tensor()
    def get_action(self, obs):
        """Return the greedy action(s) for ``obs`` as a NumPy array.

        ``obs`` is presumably converted to a tensor by the
        ``convert_args_to_tensor`` decorator -- verify in ``utils.snippets``.

        The result is the argmax over dimension 1 of the critic's Q-values,
        moved to the CPU. The batch axis is kept, so a single (1-D)
        observation still yields an array rather than a scalar.
        """
        q_values = self._batched_q_values(obs)
        return torch.argmax(q_values, dim=1).cpu().numpy()

    @convert_args_to_tensor()
    def get_action_values(self, obs):
        """Return the critic's Q-values for ``obs`` as a NumPy array.

        ``obs`` is presumably converted to a tensor by the
        ``convert_args_to_tensor`` decorator -- verify in ``utils.snippets``.

        Same preprocessing as ``get_action``, but the raw Q-values are
        returned (moved to the CPU) instead of their argmax.
        """
        return self._batched_q_values(obs).cpu().numpy()
