# import torch

# import utils.fileio
# import torch_ac
# from utils.reprod_setup import device
# from src.model import ACModel
# import os


# class Agent:
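#     """An agent wrapping a trained ACModel.

#     It is able to:
#     - choose an action for a single observation (get_action),
#     - choose actions for a batch of observations (get_actions),
#       either greedily (argmax=True) or by sampling from the policy."""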

#     def __init__(self, obs_space, action_space, model=None,
#                  argmax=False, use_memory=False, use_text=False):

#         self.argmax = argmax
#         if isinstance(model, str):
#             obs_space = torch_ac.format.default_preprocess_obss(obs_space)
#             self.acmodel = ACModel(obs_space, action_space)
            
#             self.acmodel.load_state_dict(utils.fileio.load_from_disk(model, device)['model_state'])
#             self.acmodel.to(device)
#         else:
#             self.acmodel = model

#         self.acmodel.eval()

#     def get_actions(self, obss):
#         preprocessed_obss = torch_ac.format.default_preprocess_obss(obss)

#         with torch.no_grad():
#                 dist, _ = self.acmodel(preprocessed_obss.to(device))

#         if self.argmax:
#             actions = dist.probs.max(1, keepdim=True)[1]
#         else:
#             actions = dist.sample()

#         return actions.cpu().numpy()

#     def get_action(self, obs):
#         return self.get_actions([obs])[0]
