
class State(object):
    """State held by every node: the current replay schedule plus selected actions.

    A state tracks which task it is at (``task``), how many tasks exist in
    total (``n_tasks``), the replay schedule built so far and the list of
    chosen action indices. Moving to a child state appends one action and
    advances ``task`` by one; the parent state is never mutated.
    """

    def __init__(self, n_tasks, task, action_space, actions=None, replay_schedule=None):
        """Create a state.

        Args:
            n_tasks: total number of tasks; the state is terminal once
                ``task`` reaches this value.
            task: current task index. Must not exceed ``n_tasks``.
            action_space: object providing ``get_actions_at_task(task)`` and
                ``generate_random_action(task)``.
            actions: list of selected action indices so far. Defaults to a
                new empty list.
            replay_schedule: list of actions chosen so far. Defaults to a new
                empty list.

        Raises:
            ValueError: if ``task > n_tasks``.
        """
        if task > n_tasks:
            raise ValueError('task cannot be greater than number of tasks!')
        self.n_tasks = n_tasks
        self.task = task
        self.action_space = action_space
        # Allocate fresh lists per instance. A literal ``[]`` default would be
        # shared by every State created without these arguments.
        self.replay_schedule = [] if replay_schedule is None else replay_schedule
        self.actions = [] if actions is None else actions

    def move(self, action, action_index=None):
        """Return the child state reached by taking ``action``.

        ``action`` is appended to a copy of the replay schedule. If
        ``action_index`` is given, it is appended to a copy of the action
        list. The child is at ``task + 1``; ``self`` is left unchanged.
        """
        next_rs = self.replay_schedule.copy()
        next_rs.append(action)
        actions = self.actions.copy()
        if action_index is not None:
            actions.append(action_index)
        return State(self.n_tasks,
                     self.task + 1,
                     self.action_space,
                     actions=actions,
                     replay_schedule=next_rs)

    def get_legal_actions(self):
        """Return the actions the action space offers at the current task."""
        return self.action_space.get_actions_at_task(self.task)

    def get_random_action(self):
        """Return an action generated by the action space for the current task."""
        return self.action_space.generate_random_action(self.task)

    def is_terminal_state(self):
        """Return True once the current task index has reached ``n_tasks``."""
        return self.task >= self.n_tasks

    def get_replay_schedule(self):
        """Return shallow copies of ``(replay_schedule, actions)``."""
        return self.replay_schedule.copy(), self.actions.copy()
