from .simple_replay_pool import SimpleReplayPool


class ExtraPolicyInfoReplayPool(SimpleReplayPool):
    """SimpleReplayPool that registers one extra field, ``log_pis``.

    Every constructor argument is forwarded untouched to
    ``SimpleReplayPool.__init__``. After that, a single field named
    ``log_pis`` with shape ``(1,)`` and dtype ``'float32'`` is registered
    through the inherited ``add_fields`` method.
    """

    def __init__(self, *args, **kwargs):
        super(ExtraPolicyInfoReplayPool, self).__init__(*args, **kwargs)

        # NOTE(review): the name suggests 'log_pis' stores the policy's
        # log-probability for each stored action; confirm against the code
        # that writes samples into this pool.
        log_pis_spec = dict(shape=(1, ), dtype='float32')
        self.add_fields({'log_pis': log_pis_spec})
