"""
 Copyright 2021 - 2024 brax authors.
 Copyright 2024 Anonymous Authors.

 Licensed under the Apache License, Version 2.0 (the "License"); you may not
 use this file except in compliance with the License. You may obtain a copy of
 the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 License for the specific language governing permissions and limitations under
 the License.
"""


from brax.v1 import jumpy as jp
from brax.v1 import math
from brax.v1.envs import env


def ablated_step(self, state: env.State, action: jp.ndarray) -> env.State:
    """Advance the environment by one timestep using the ablated reward.

    The reward is a fixed bonus (2.0) granted whenever the new x position
    (`qp.pos[0, 0]`) is at or beyond 1.3, minus a quadratic control cost
    weighted by 0.1. This replaces the earlier forward reward that was
    proportional to the x velocity; the velocity is still computed, but only
    reported as a metric. The healthy reward is likewise computed and reported
    in the metrics, yet deliberately left out of the returned reward.

    Args:
      self: Environment instance providing `sys`, `_get_obs`,
        `_healthy_z_range`, `_healthy_angle_range`, `_healthy_reward` and
        `_terminate_when_unhealthy`.
      state: Current environment state; its `qp` is stepped forward.
      action: Action forwarded to `self.sys.step`.

    Returns:
      `state` with `qp`, `obs`, `reward` and `done` replaced. `state.metrics`
      is updated in place with the individual reward terms, the x position
      and the x velocity.
    """
    # Constants of the ablated reward.
    ctrl_cost_weight = 1e-1
    goal_x = 1.3
    goal_bonus = 2.

    next_qp, _ = self.sys.step(state.qp, action)

    x_pos = next_qp.pos[0, 0]
    z_pos = next_qp.pos[0, 2]
    vel_x = (x_pos - state.qp.pos[0, 0]) / self.sys.config.dt

    # Threshold bonus instead of a velocity-proportional reward.
    goal_reward = jp.where(x_pos >= goal_x, goal_bonus, 0.)

    # Health starts at 1.0 and is zeroed by any violated height/angle bound.
    z_low, z_high = self._healthy_z_range
    angle_low, angle_high = self._healthy_angle_range
    pitch = math.quat_to_euler(next_qp.rot[0])[1]
    is_healthy = 1.0
    for violated in (
        z_pos < z_low,
        z_pos > z_high,
        pitch > angle_high,
        pitch < angle_low,
    ):
        is_healthy = jp.where(violated, x=0.0, y=is_healthy)  # pytype: disable=wrong-arg-types  # jax-ndarray

    if self._terminate_when_unhealthy:
        # Episodes end as soon as the agent is unhealthy.
        healthy_reward = self._healthy_reward
        done = 1.0 - is_healthy
    else:
        healthy_reward = self._healthy_reward * is_healthy
        done = 0.0

    control_penalty = ctrl_cost_weight * jp.sum(jp.square(action))

    # The healthy reward is intentionally excluded from the total reward.
    total_reward = goal_reward - control_penalty
    observation = self._get_obs(next_qp)

    state.metrics.update(
        reward_forward=goal_reward,
        reward_ctrl=-control_penalty,
        reward_healthy=healthy_reward,
        x_position=x_pos,
        x_velocity=vel_x)

    return state.replace(
        qp=next_qp, obs=observation, reward=total_reward, done=done)
