from rlf.il import TrajDataset


class GoalTrajDataset(TrajDataset):
    """Trajectory dataset that can split trajectories on a goal-found flag.

    If the loaded trajectory data contains an ``'ep_found_goal'`` entry, it
    decides where trajectories end. Otherwise the per-step ``done`` values are
    used.
    """

    def _setup(self, trajs):
        """Cache the optional ``'ep_found_goal'`` entry from ``trajs``.

        Args:
            trajs: Loaded trajectory data supporting ``in`` and ``[]`` lookup.
                Presumably a dict of tensors -- TODO confirm against the
                base class loader.

        Sets ``self.found_goal`` to ``trajs['ep_found_goal'].float()`` when
        that key exists, and to ``None`` otherwise.
        """
        self.found_goal = (
            trajs['ep_found_goal'].float() if 'ep_found_goal' in trajs else None
        )

    def should_terminate_traj(self, j, obs, next_obs, done, actions):
        """Return whether step ``j`` ends the current trajectory.

        Args:
            j: Index of the step being examined.
            obs, next_obs, actions: Not used by this override. They are
                presumably kept to match the hook signature in the base
                class -- verify against ``TrajDataset``.
            done: Indexable per-step termination values.

        Returns:
            ``self.found_goal[j] == 1.0`` when a goal-found flag was loaded,
            and ``done[j]`` otherwise. With tensor inputs this is likely a
            0-dim tensor rather than a Python ``bool``.
        """
        if self.found_goal is not None:
            # NOTE(review): done[j] is not consulted on this path, so episodes
            # whose flag never reaches 1.0 do not terminate here. Confirm this
            # is intended.
            return self.found_goal[j] == 1.0
        return done[j]
