import jax
import jax.numpy as jnp
from functools import partial
from gymnax.environments import spaces
from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper
from flax import struct
from brax.envs.base import State

from .hopper_random import HopperRandom


@struct.dataclass
class EnvState:
    """Per-step state returned by ``HopperAvoidCeiling.reset`` and ``step``."""

    # Underlying Brax environment state (wrapped hopper).
    state: State
    # Goal value computed by ``calculate_g`` from the head position.
    g: float
    # Avoid indicator: 300.0 when ``is_avoid`` holds for the head position,
    # -300.0 otherwise.
    h: float
    # z: float   # added: extended state z, either -1 or 1 (currently disabled)


@struct.dataclass
class EnvParams:
    """Tunable parameters passed to the environment's ``step``/``reset``."""

    # Threshold on |action * obs[-k] / 2| above which ``step`` charges an
    # energy-consumption penalty for that action dimension.
    torque_limit: float = 0.2
    # Not read anywhere in this file.
    # NOTE(review): presumably consumed by callers elsewhere - confirm.
    max_torque: float = 1.0


class HopperAvoidCeiling:
    """Hopper task with a goal region to reach and a region to avoid.

    Wraps ``HopperRandom`` (inside Brax's ``EpisodeWrapper`` and
    ``AutoResetWrapper``) behind ``reset``/``step`` methods that take a PRNG
    key and optional ``EnvParams``. Every returned ``EnvState`` carries two
    scalars derived from the head position of the current observation:

    * ``g`` - see ``calculate_g``.
    * ``h`` - ``300.0`` when ``is_avoid`` holds, ``-300.0`` otherwise.
    """

    def __init__(self, backend="positional"):
        """Build the wrapped hopper environment.

        Args:
            backend: Brax physics backend name forwarded to ``HopperRandom``.
        """
        env = HopperRandom(backend=backend,
                           exclude_current_positions_from_observation=False,
                           terminate_when_unhealthy=False)
        env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
        env = AutoResetWrapper(env)
        self._env = env
        self.action_size = env.action_size
        self.observation_size = (env.observation_size,)
        self.default_params = EnvParams()

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key, params=None):
        """Reset the wrapped environment.

        Args:
            key: PRNG key forwarded to the wrapped environment's ``reset``.
            params: Unused; accepted for interface compatibility.

        Returns:
            ``(obs, env_state)`` where ``env_state`` is an ``EnvState`` holding
            the Brax state plus the ``g``/``h`` values of the initial
            observation.
        """
        state = self._env.reset(key)
        g_value, h_value = self._safety_values(state.obs)
        return state.obs, EnvState(state, g_value, h_value)

    @partial(jax.jit, static_argnums=(0,))
    def step(self, key, state, action, params=None):
        """Advance the environment by one (noisy) action.

        The raw action is squashed with ``tanh``, each dimension is scaled by
        an independent relative error drawn uniformly from [-0.1, 0.1], and
        the result is clipped back to [-1, 1] before being applied.

        Args:
            key: PRNG key used to sample the action noise.
            state: ``EnvState`` returned by the previous ``reset``/``step``.
            action: Unsquashed action vector.
            params: ``EnvParams``; ``self.default_params`` is used when
                ``None``.

        Returns:
            ``(obs, env_state, energy_consumption, done, pos_dict)`` where
            ``obs`` is the next observation, ``env_state`` the next
            ``EnvState``, ``energy_consumption`` the scalar penalty from
            ``_energy_consumption``, ``done`` a boolean flag and ``pos_dict``
            the body positions computed from the observation *before* the
            step.
        """
        # The signature advertises ``params=None``; fall back to the defaults
        # instead of failing on ``None.torque_limit``.
        if params is None:
            params = self.default_params

        u = jnp.tanh(action)

        # Independent relative error per action dimension in [-0.1, 0.1].
        _, noise_key = jax.random.split(key)
        eps = jax.random.uniform(noise_key, shape=u.shape, minval=-0.1, maxval=0.1)
        # Clip so the noise cannot push the action outside [-1, 1].
        u = jnp.clip(u * (1.0 + eps), -1.0, 1.0)

        prev_obs = state.state.obs
        energy_consumption = self._energy_consumption(u, prev_obs, params.torque_limit)

        next_state = self._env.step(state.state, u)
        g_value, h_value = self._safety_values(next_state.obs)

        # NOTE(review): positions are reported for the pre-step observation,
        # while g/h use the post-step one - confirm this is intended.
        head_pos, jaw_pos, thg_pos, leg_pos, foot_front_pos, foot_back_pos = (
            self.calculate_position(prev_obs))
        pos_dict = {"head_pos": head_pos, "jaw_pos": jaw_pos, "thg_pos": thg_pos, "leg_pos": leg_pos,
                    "foot_front_pos": foot_front_pos, "foot_back_pos": foot_back_pos}

        next_state_new = EnvState(next_state, g_value, h_value)
        return next_state.obs, next_state_new, energy_consumption, next_state.done > 0.5, pos_dict

    def _safety_values(self, obs):
        """Return ``(g, h)`` for the head position implied by ``obs``."""
        head_pos = self.calculate_position(obs)[0]
        h_value = jnp.where(self.is_avoid(head_pos), 300.0, -300.0)
        g_value = self.calculate_g(head_pos)
        return g_value, h_value

    @staticmethod
    def _energy_consumption(u, obs, torque_limit):
        """Sum the over-limit penalties of the first three action dimensions.

        For each ``k`` in 0..2 the load is ``|u[k] * obs[-3 + k] / 2|``; loads
        above ``torque_limit`` contribute ``0.6 * load ** 2``, others nothing.
        """
        load = jnp.fabs(u[:3] * obs[-3:] / 2.0)
        penalty = jnp.where(load > torque_limit, (load ** 2) * 0.6, 0.0)
        return jnp.sum(penalty)

    @partial(jax.jit, static_argnums=(0,))
    def calculate_position(self, obs):
        """Compute planar body-point positions from an observation.

        Reads ``obs[0]``/``obs[1]`` as the reference point and ``obs[2:6]`` as
        angles (assumes root x/z followed by joint angles - TODO confirm
        against ``HopperRandom``).

        Returns:
            Tuple ``(head_pos, jaw_pos, thg_pos, leg_pos, foot_front_pos,
            foot_back_pos)`` of length-2 arrays.
        """
        base_x, base_y = obs[0], obs[1]
        # Cumulative angles along the chain, each link subtracting its joint.
        ang_body = obs[2]
        ang_thg = ang_body - obs[3]
        ang_leg = ang_thg - obs[4]
        ang_foot = ang_leg - obs[5]

        head_pos = jnp.array([base_x + 0.2 * jnp.sin(ang_body),
                              base_y + 0.2 * jnp.cos(ang_body)])
        jaw_pos = jnp.array([base_x - 0.2 * jnp.sin(ang_body),
                             base_y - 0.2 * jnp.cos(ang_body)])
        thg_pos = jnp.array([jaw_pos[0] - 0.45 * jnp.sin(ang_thg),
                             jaw_pos[1] - 0.45 * jnp.cos(ang_thg)])
        leg_pos = jnp.array([thg_pos[0] - 0.5 * jnp.sin(ang_leg),
                             thg_pos[1] - 0.5 * jnp.cos(ang_leg)])
        foot_back_pos = jnp.array([leg_pos[0] - 0.13 * jnp.cos(ang_foot),
                                   leg_pos[1] + 0.13 * jnp.sin(ang_foot)])
        foot_front_pos = jnp.array([leg_pos[0] + 0.26 * jnp.cos(ang_foot),
                                    leg_pos[1] - 0.26 * jnp.sin(ang_foot)])
        return head_pos, jaw_pos, thg_pos, leg_pos, foot_front_pos, foot_back_pos

    @partial(jax.jit, static_argnums=(0,))
    def calculate_g(self, head_pos):
        """Return the goal value for ``head_pos``.

        The goal is the disc of radius 0.1 centred at (2.0, 1.4). Inside it
        the value is ``-300.0``; outside it is ``100 * (distance - 0.1)``.
        """
        dist = jnp.sqrt((head_pos[0] - 2.0) ** 2 + (head_pos[1] - 1.4) ** 2)
        value = jnp.where(dist < 0.1, -3.0, dist - 0.1)
        return value * 100.0

    @partial(jax.jit, static_argnums=(0,))
    def is_avoid(self, head_pos):
        """Return True when ``head_pos`` lies in the region to avoid.

        The region is ``0.95 <= x <= 1.05`` and ``y >= 1.3``.
        """
        avoid_1 = (head_pos[1] >= 1.3) & (head_pos[0] >= 0.95) & (head_pos[0] <= 1.05)
        return avoid_1

    def observation_space(self, params):
        """Return the unbounded ``Box`` observation space (``params`` unused)."""
        return spaces.Box(
            low=-jnp.inf,
            high=jnp.inf,
            shape=(self._env.observation_size,),
        )

    def action_space(self, params):
        """Return the ``Box`` action space in [-1, 1] (``params`` unused)."""
        return spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self._env.action_size,),
        )

# class HopperAvoidCeilingDeterministic:
#     def __init__(self, backend="positional"):
#         env = HopperDeterministic(backend=backend,
#                            exclude_current_positions_from_observation=False,
#                            terminate_when_unhealthy=False)
#         env = EpisodeWrapper(env, episode_length=1000, action_repeat=2)
#         env = AutoResetWrapper(env)
#         self._env = env
#         self.action_size = env.action_size
#         self.observation_size = (env.observation_size,)
#         self.default_params = EnvParams()

#     @partial(jax.jit, static_argnums=(0,))
#     def reset(self, key, params=None):
#         state = self._env.reset(key)
#         # head_pos, _, _, _, _, _ = self.calculate_position(state.obs)
#         # h_value = jnp.where(self.is_avoid(head_pos), 300.0, -300.0)
#         # g_value = self.calculate_g(head_pos)
#         # env_state = EnvState(state, g_value, h_value)
#         # return state.obs, env_state
#         head_pos, _, _, _, _, _ = self.calculate_position(state.obs)
#         unsafe0 = self.is_avoid(head_pos)
#         z_value = jnp.where(unsafe0, 1.0, -1.0)
#         # h = 300 when z == 1 (unsafe), otherwise h = -300
#         h_value = jnp.where(z_value == 1.0, 300.0, -300.0)

#         g_value = self.calculate_g(head_pos)
#         env_state = EnvState(state, g_value, h_value, z_value)

#         obs = jnp.concatenate([state.obs, jnp.array([z_value])])
#         return obs, env_state

#     @partial(jax.jit, static_argnums=(0,))
#     def step(self, key, state, action, params=None):
#         u = jnp.tanh(action)
#         reach_limit_0 = jnp.fabs(u[0] * state.state.obs[-3] / 2.) > params.torque_limit
#         energy_consumption_0 = jnp.where(reach_limit_0, (jnp.fabs(u[0] * state.state.obs[-3] / 2.) ** 2) * 0.3, 0.)
#         reach_limit_1 = jnp.fabs(u[1] * state.state.obs[-2] / 2.) > params.torque_limit
#         energy_consumption_1 = jnp.where(reach_limit_1, (jnp.fabs(u[1] * state.state.obs[-2] / 2.) ** 2) * 0.3, 0.)
#         reach_limit_2 = jnp.fabs(u[2] * state.state.obs[-1] / 2.) > params.torque_limit
#         energy_consumption_2 = jnp.where(reach_limit_2, (jnp.fabs(u[2] * state.state.obs[-1] / 2.) ** 2) * 0.3, 0.)
#         energy_consumption = energy_consumption_0 + energy_consumption_1 + energy_consumption_2
#         next_state = self._env.step(state.state, u)
#         head_pos, _, _, _, _, _ = self.calculate_position(next_state.obs)
#         h_value = jnp.where(self.is_avoid(head_pos), 300.0, -300.0)
#         g_value = self.calculate_g(head_pos)
#         head_pos, jaw_pos, thg_pos, leg_pos, foot_front_pos, foot_back_pos = self.calculate_position(state.state.obs)
#         pos_dict = {"head_pos": head_pos, "jaw_pos": jaw_pos, "thg_pos": thg_pos, "leg_pos": leg_pos,
#                     "foot_front_pos": foot_front_pos, "foot_back_pos": foot_back_pos}
#         next_state_new = EnvState(next_state, g_value, h_value)

#         return next_state.obs, next_state_new, energy_consumption, next_state.done > 0.5, pos_dict

#     @partial(jax.jit, static_argnums=(0,))
#     def calculate_position(self, obs):
#         head_pos = jnp.array([obs[0] + 0.2 * jnp.sin(obs[2]),
#                               obs[1] + 0.2 * jnp.cos(obs[2])])
#         jaw_pos = jnp.array([obs[0] - 0.2 * jnp.sin(obs[2]),
#                              obs[1] - 0.2 * jnp.cos(obs[2])])
#         thg_pos = jnp.array([jaw_pos[0] - 0.45 * jnp.sin(obs[2] - obs[3]),
#                              jaw_pos[1] - 0.45 * jnp.cos(obs[2] - obs[3])])
#         leg_pos = jnp.array([thg_pos[0] - 0.5 * jnp.sin(obs[2] - obs[3] - obs[4]),
#                              thg_pos[1] - 0.5 * jnp.cos(obs[2] - obs[3] - obs[4])])
#         foot_back_pos = jnp.array([leg_pos[0] - 0.13 * jnp.cos(obs[2] - obs[3] - obs[4] - obs[5]),
#                                     leg_pos[1] + 0.13 * jnp.sin(obs[2] - obs[3] - obs[4] - obs[5])])
#         foot_front_pos = jnp.array([leg_pos[0] + 0.26 * jnp.cos(obs[2] - obs[3] - obs[4] - obs[5]),
#                                    leg_pos[1] - 0.26 * jnp.sin(obs[2] - obs[3] - obs[4] - obs[5])])
#         return head_pos, jaw_pos, thg_pos, leg_pos, foot_front_pos, foot_back_pos

#     @partial(jax.jit, static_argnums=(0,))
#     def calculate_g(self, head_pos):
#         reach = jnp.sqrt((head_pos[0] - 2.0) ** 2 + (head_pos[1] - 1.4) ** 2) - 0.1
#         has_reached_goal = jnp.sqrt((head_pos[0] - 2.0) ** 2 + (head_pos[1] - 1.4) ** 2) < 0.1
#         value = jnp.where(has_reached_goal, -2.5, reach)
#         return value * 100.0

#     @partial(jax.jit, static_argnums=(0,))
#     def is_avoid(self, head_pos):
#         avoid_1 = (head_pos[1] >= 1.3) & (head_pos[0] >= 0.95) & (head_pos[0] <= 1.05)
#         return avoid_1
        
#     @partial(jax.jit, static_argnums=(0,))
#     def cross_product(self, array_1, array_2):
#         return array_1[0] * array_2[1] - array_1[1] * array_2[0]

#     def observation_space(self, params):
#         return spaces.Box(
#             low=-jnp.inf,
#             high=jnp.inf,
#             shape=(self._env.observation_size,),
#         )

#     def action_space(self, params):
#         return spaces.Box(
#             low=-1.0,
#             high=1.0,
#             shape=(self._env.action_size,),
#         )

    
  
