import numpy as np
import torch

from ._body import Body


class ClipRewards(Body):
    """Body wrapper that replaces each reward with its sign.

    Rewards are mapped element-wise to -1, 0 or +1 using ``torch.sign`` for
    torch tensors and ``np.sign`` for everything else, so only the direction
    of the reward is kept and its magnitude is discarded.
    """

    def process_state(self, state, should_stack=False):
        """Return ``state`` with its ``reward`` replaced by the sign of that reward.

        Args:
            state: Object exposing a ``reward`` attribute and an
                ``update(key, value)`` method. NOTE(review): presumably
                ``update`` returns a new state rather than mutating in
                place; confirm against the state implementation.
            should_stack: Not used by this body; kept so the signature stays
                call-compatible with existing callers.

        Returns:
            Whatever ``state.update("reward", ...)`` returns.
        """
        return state.update("reward", self._clip(state.reward))

    def _clip(self, reward):
        """Return the element-wise sign of ``reward``.

        - torch tensor: ``torch.sign(reward)``, a tensor of the same shape.
        - single value (Python/NumPy scalar or any one-element array): a
          plain Python ``float``.
        - array-like with any other number of elements (including zero):
          the NumPy array produced by ``np.sign``.
        """
        if torch.is_tensor(reward):
            return torch.sign(reward)
        signs = np.sign(reward)
        # Single values keep returning a plain Python float as before.
        # ``float()`` cannot convert arrays that do not have exactly one
        # element (e.g. rewards from a vectorized NumPy env), so those are
        # returned element-wise instead of raising TypeError.
        if np.size(signs) == 1:
            return float(signs)
        return signs
