import torch

from utils.snippets import convert_args_to_tensor

class ArgMaxPolicy:
    """Greedy policy that picks the action with the highest critic Q-value.

    The policy holds no parameters of its own; it queries
    ``critic.q_t_values`` and takes an argmax over dimension 1 of the result.

    NOTE(review): the public methods are wrapped by ``convert_args_to_tensor``
    from ``utils.snippets``; presumably it turns array-like arguments into
    ``torch.Tensor`` objects before the method body runs -- confirm there.
    """

    def __init__(self, critic, device='cpu'):
        """Store the critic to query and the device observations are moved to.

        Args:
            critic: Object exposing ``q_t_values(observation)``.
            device: Target passed to ``Tensor.to`` for every observation.
        """
        self.critic = critic
        self.device = device

    def _batch_on_device(self, obs):
        """Return ``obs`` on ``self.device``, with a leading axis if needed.

        Input with more than one dimension is treated as already batched and
        is only moved to the device. Anything else gets a new leading axis of
        size one first, so a single observation becomes a batch of one.
        """
        if obs.ndim > 1:
            return obs.to(self.device)
        return obs[None].to(self.device)

    @convert_args_to_tensor()
    def get_action(self, obs):
        """Return the greedy action(s) together with the critic's Q-values.

        Args:
            obs: A single observation or a batch of observations.

        Returns:
            A tuple ``(actions, q_values)`` of NumPy arrays: ``actions`` is the
            argmax of the critic output over dimension 1, ``q_values`` is the
            full critic output.
            Assumes the critic output is laid out as (batch, n_actions) --
            TODO confirm against the critic implementation.
        """
        # Inference only: keep autograd from recording anything, which also
        # keeps the results convertible to NumPy.
        with torch.no_grad():
            observation = self._batch_on_device(obs)
            q_t_values = self.critic.q_t_values(observation)
            greedy_actions = torch.argmax(q_t_values, dim=1)
            return greedy_actions.cpu().numpy(), q_t_values.cpu().numpy()

    @convert_args_to_tensor()
    def get_action_values(self, obs):
        """Return the critic's Q-values for ``obs`` as a NumPy array.

        Args:
            obs: A single observation or a batch of observations.

        Returns:
            The output of ``critic.q_t_values`` moved to the CPU and converted
            to a NumPy array.
        """
        # Inference only: keep autograd from recording anything, which also
        # keeps the result convertible to NumPy.
        with torch.no_grad():
            observation = self._batch_on_device(obs)
            return self.critic.q_t_values(observation).cpu().numpy()
