import jax
import jax.numpy as jnp

from brax.envs.base import State
from brax.envs.half_cheetah import Halfcheetah

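# NOTE: earlier version of HalfCheetahRandom, kept commented out for reference.
# It shifts the root x position by a uniform offset in [0, 4] on every reset.
# class HalfCheetahRandom(Halfcheetah):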

#     def reset(self, rng: jnp.ndarray) -> State:
#         """Resets the environment to an initial state."""
#         rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

#         low, hi = -self._reset_noise_scale, self._reset_noise_scale
#         qpos = self.sys.init_q + jax.random.uniform(
#             rng1, (self.sys.q_size(),), minval=low, maxval=hi
#         )
#         qpos = qpos.at[0].set(qpos[0] + jax.random.uniform(rng3, minval=0., maxval=4.0))
#         qvel = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

#         pipeline_state = self.pipeline_init(qpos, qvel)

#         obs = self._get_obs(pipeline_state)
#         reward, done, zero = jnp.zeros(3)
#         metrics = {
#             'x_position': zero,
#             'x_velocity': zero,
#             'reward_ctrl': zero,
#             'reward_run': zero,
#         }
#         return State(pipeline_state, obs, reward, done, metrics)


class HalfCheetahRandom(Halfcheetah):
    """Half-cheetah variant whose ``reset`` randomizes the starting x position.

    Only ``reset`` is overridden. The sampling parameters are exposed as class
    attributes so a subclass can retune them without re-implementing ``reset``;
    the defaults reproduce the previously hard-coded values.

    According to the original author's notes, the "near" bands are meant to
    place the agent close to an obstacle (a pit) so that it gets several
    decision steps to adjust its leg height before crossing. The obstacle
    itself is not defined in this class.
    """

    # Probability that an episode starts in one of the near-obstacle bands.
    NEAR_OBSTACLE_PROB = 0.6
    # Given a near-obstacle start, probability of using the "front" band.
    FRONT_PHASE_PROB = 0.5
    # Offsets added to root x. Author's intent: in the front band the front
    # leg is about to reach the pit edge; in the back band the front leg has
    # already crossed and the back leg is about to reach the edge.
    FRONT_OFFSET_RANGE = (2.0, 2.3)
    BACK_OFFSET_RANGE = (2.3, 2.6)
    # Offset range for the remaining episodes (same range the earlier,
    # unbiased version of this environment used).
    FAR_OFFSET_RANGE = (0.0, 4.0)
    # Extra forward velocity added to qvel[0] for near-obstacle starts
    # (m/s per the original notes). Kept small on purpose: too much speed
    # would carry the agent over the pit before it can learn to lift its legs.
    INIT_VX_RANGE = (0.3, 1.0)

    def reset(self, rng: jnp.ndarray) -> State:
        """Resets the environment to a randomized initial state.

        The pose and velocity are first perturbed with the usual reset noise
        (uniform noise on ``qpos``, scaled normal noise on ``qvel``). Then:

        - with probability ``NEAR_OBSTACLE_PROB`` the root x is shifted by an
          offset drawn from ``FRONT_OFFSET_RANGE`` or ``BACK_OFFSET_RANGE``
          (chosen with probability ``FRONT_PHASE_PROB`` for the front band),
          and a forward velocity drawn from ``INIT_VX_RANGE`` is added to
          ``qvel[0]``;
        - otherwise the root x is shifted by an offset drawn from
          ``FAR_OFFSET_RANGE`` and the velocity is left as sampled.

        Args:
            rng: JAX PRNG key used for all sampling in this reset.

        Returns:
            A ``State`` built from the initialized pipeline state, its
            observation, zero reward/done, and zero-initialized metrics.
        """
        # The split count stays at 9 (first key unused) so the remaining keys,
        # and therefore the sampled values, match the previous implementation.
        (
            _,
            pose_key,
            vel_key,
            near_key,
            phase_key,
            front_key,
            back_key,
            far_key,
            vx_key,
        ) = jax.random.split(rng, 9)

        noise = self._reset_noise_scale

        # 1) Baseline random pose and velocity around the default configuration.
        qpos = self.sys.init_q + jax.random.uniform(
            pose_key, (self.sys.q_size(),), minval=-noise, maxval=noise
        )
        qvel = noise * jax.random.normal(vel_key, (self.sys.qd_size(),))

        # 2) Decide which start distribution this episode uses.
        start_near = jax.random.bernoulli(near_key, p=self.NEAR_OBSTACLE_PROB)
        front_phase = jax.random.bernoulli(phase_key, p=self.FRONT_PHASE_PROB)

        # Every candidate is sampled unconditionally; the booleans above only
        # select which one is applied.
        front_lo, front_hi = self.FRONT_OFFSET_RANGE
        back_lo, back_hi = self.BACK_OFFSET_RANGE
        far_lo, far_hi = self.FAR_OFFSET_RANGE
        vx_lo, vx_hi = self.INIT_VX_RANGE

        front_offset = jax.random.uniform(
            front_key, minval=front_lo, maxval=front_hi
        )
        back_offset = jax.random.uniform(
            back_key, minval=back_lo, maxval=back_hi
        )
        far_offset = jax.random.uniform(far_key, minval=far_lo, maxval=far_hi)
        init_vx = jax.random.uniform(vx_key, minval=vx_lo, maxval=vx_hi)

        # 3) Apply the selected offset / velocity. The candidates are plain
        # scalars, so an element-wise select yields the same result as
        # branching on the predicates.
        near_offset = jnp.where(front_phase, front_offset, back_offset)
        x_offset = jnp.where(start_near, near_offset, far_offset)
        qpos = qpos.at[0].set(qpos[0] + x_offset)

        # Only near-obstacle starts receive the extra forward velocity.
        root_vx = jnp.where(start_near, qvel[0] + init_vx, qvel[0])
        qvel = qvel.at[0].set(root_vx)

        # 4) From here on, same construction as the commonly used
        # Halfcheetah reset: init pipeline, observe, zero metrics.
        pipeline_state = self.pipeline_init(qpos, qvel)

        obs = self._get_obs(pipeline_state)
        reward, done, zero = jnp.zeros(3)
        metrics = {
            'x_position': zero,
            'x_velocity': zero,
            'reward_ctrl': zero,
            'reward_run': zero,
        }
        return State(pipeline_state, obs, reward, done, metrics)

