from machiavelli.agent.base_agent import BaseAgent
import numpy as np


class HumanAgent(BaseAgent):
    """Agent that replays a fixed, pre-recorded sequence of actions.

    Every call to ``get_action`` hands back the next entry of ``actions``
    and advances an internal cursor (``idx``). The cursor is rewound to the
    start when an episode reports ``done`` and whenever ``train`` or
    ``test`` is called.
    """

    def __init__(self, actions):
        """Store the action script and point the cursor at its first entry.

        Args:
            actions: Indexable sequence of actions to replay, in order.
        """
        super().__init__()
        self.actions = actions
        self.idx = 0  # position of the next scripted action to replay

    def _scripted_action(self):
        """Return the scripted action at the cursor without advancing it.

        Raises:
            IndexError: If the cursor has moved past the end of the script.
                The exception type matches what plain indexing would raise,
                but the message names the offending step.
        """
        try:
            return self.actions[self.idx]
        except IndexError as err:
            raise IndexError(
                f"HumanAgent has no scripted action for step {self.idx}: "
                "the recorded action sequence is exhausted"
            ) from err

    def get_action(self, obs, rew, done, info):
        """Return the next scripted action and advance the cursor.

        Args:
            obs: Current observation (unused).
            rew: Last reward (unused).
            done: Whether the episode has finished.
            info: Extra environment information (unused).

        Returns:
            The next scripted action, or ``None`` when ``done`` is truthy
            (in which case the cursor is also rewound to the start).

        Raises:
            IndexError: If the script has no action left for this step.
        """
        if done:
            # Episode finished: rewind so the script can be replayed again.
            self.idx = 0
            return None
        action = self._scripted_action()
        self.idx += 1
        return action

    def get_action_probabilities(self, obs, rew, done, info):
        """Return a one-hot list marking the scripted action at the cursor.

        The cursor is not advanced, so the result describes the action the
        next ``get_action`` call would return.

        Args:
            obs: Current observation (unused).
            rew: Last reward (unused).
            done: Whether the episode has finished (unused).
            info: Mapping that must provide ``'num_choices'``.
                NOTE(review): confirm this key matches the environment's
                info schema.

        Returns:
            A list of length ``info['num_choices']`` holding ``1`` at the
            position equal to the scripted action and ``0`` elsewhere.

        Raises:
            IndexError: If the script has no action left for this step.
        """
        num_actions = info['num_choices']
        # Look the action up once instead of re-indexing for every choice.
        chosen = self._scripted_action()
        return [1 if i == chosen else 0 for i in range(num_actions)]

    def train(self):
        """Rewind the cursor to the start of the script."""
        self.idx = 0

    def test(self):
        """Rewind the cursor to the start of the script."""
        self.idx = 0
