"""Locomotion Reward."""
from erl_lib.envs.reward import StateActionReward


class LocomotionReward(StateActionReward):
    """Reward model for locomotion robots.

    The state-dependent reward is a weighted "forward" term, read from the
    first component (last axis) of ``next_state``. An optional bonus of
    ``healthy_reward * health_checker.check(next_state)`` is added to it.

    Args:
        ctrl_cost_weight: Forwarded to :class:`StateActionReward`.
        action_scale: Forwarded to :class:`StateActionReward`.
        sparse: Forwarded to :class:`StateActionReward`.
        forward_reward_weight: Multiplier applied to the forward term.
        healthy_reward: Scale applied to the health checker's output.
        health_checker: Optional object exposing ``check(next_state)``. When
            ``None``, no healthy bonus is added.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        ctrl_cost_weight=1.0,
        action_scale=1.0,
        sparse=False,
        forward_reward_weight=1.0,
        healthy_reward=0.0,
        health_checker=None,
        **_,
    ):
        super().__init__(
            ctrl_cost_weight=ctrl_cost_weight, sparse=sparse, action_scale=action_scale
        )
        self.forward_reward_weight = forward_reward_weight
        self.healthy_reward = healthy_reward
        self.health_checker = health_checker

    def state_reward(self, state, next_state, log=False):
        """Return the state-dependent part of the reward.

        Args:
            state: Current state (not used by this reward).
            next_state: Next state. Its first component along the last axis
                is the forward-progress signal (presumably forward velocity;
                TODO confirm against the environment's observation layout).
            log: If True, record the mean forward term in ``self._info``
                under the key ``"forward_reward"``.

        Returns:
            The weighted forward term, shape ``next_state.shape[:-1] + (1,)``,
            plus the healthy bonus when a health checker is configured. The
            bonus is added in place, so its shape must broadcast to the
            forward term's shape.
        """
        # Slice (not index) so the trailing axis of size 1 is kept.
        forward_reward = next_state[..., :1]
        reward = self.forward_reward_weight * forward_reward
        if log:
            # Previously the mean was computed and discarded because update()
            # was called with no arguments. Assumes self._info is a dict-like
            # attribute set up by the base class (TODO confirm).
            forward_mean = forward_reward.mean().item()
            self._info.update(forward_reward=forward_mean)

        # Explicit None check: a checker object with __len__/__bool__ must not
        # be skipped just because it evaluates falsy.
        if self.health_checker is not None:
            reward += self.healthy_reward * self.health_checker.check(next_state)
        return reward
