from pathlib import Path
import mujoco
from mujoco import mjx
# from jax.scipy.spatial.transform import Rotation
import jax
import jax.numpy as jnp
from typing import Tuple

from irl_baselines.environments.walker2d_mjx.state import State


class Walker2D:
    """Planar Walker2d environment simulated with MuJoCo MJX.

    The interface is functional: ``reset`` builds an initial ``State`` and
    ``step`` returns a new ``State`` for a given action. When an episode ends
    inside ``step`` the environment auto-resets; the terminal observation and
    metrics are kept in ``info["final_observation"]`` and ``info["final_info"]``.
    """

    def __init__(self, horizon=1000):
        """Load the Walker2d model and set reward / termination parameters.

        Args:
            horizon: number of ``step`` calls after which an episode is
                truncated.
        """
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "walker2d_v5.xml").as_posix()
        mj_model = mujoco.MjModel.from_xml_path(xml_path)
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
        self.model = mj_model
        self.data = mujoco.MjData(mj_model)
        self.sys = mjx.put_model(mj_model)

        # Number of physics steps executed per call to `step` (action repeat).
        self.nr_intermediate_steps = 1

        # qpos[1] is read as the torso height and qpos[2] as the torso pitch
        # (see get_reward); everything else starts at zero.
        initial_qpos = [0.0, 1.25, 0.0] + [0.0] * (self.sys.nq - 3)

        self.initial_qpos = jnp.array(initial_qpos)
        self.initial_qvel = jnp.zeros(self.sys.nv)

        self.forward_reward_weight = 1.0
        self.healthy_z_range: Tuple[float, float] = (0.8, 2.0)
        self.healthy_angle_range: Tuple[float, float] = (-1.0, 1.0)
        self.terminate_when_unhealthy = True
        self.ctrl_cost_weight: float = 1e-3
        self.healthy_reward = 1.0

    def reset(self, key):
        """Return the initial ``State`` of a new episode.

        The start configuration is deterministic (``initial_qpos`` /
        ``initial_qvel`` with zero control). ``key`` is only used to derive
        the PRNG key stored in ``info["key"]``, which ``step`` consumes when
        it auto-resets.
        """
        _, subkey = jax.random.split(key)

        data = mjx.put_data(self.model, self.data)
        data = data.replace(qpos=self.initial_qpos, qvel=self.initial_qvel, ctrl=jnp.zeros(self.sys.nu))
        data = mjx.forward(self.sys, data)

        observation = self.get_observation(data)
        reward = 0.0
        terminated = False
        truncated = False
        logging_info = {
            "episode_return": reward,
            "episode_length": 0,
            "local_vel_x": 0.0,
            "is_healthy": 1.0,
            "ctrl_cost": 0.0,
        }
        info = {
            **logging_info,
            "final_observation": jnp.zeros_like(observation),
            "final_info": {**logging_info},
            "done": False,
            "key": subkey,
        }

        return State(data, observation, reward, terminated, truncated, info)

    def step(self, state, action):
        """Advance the simulation by one environment step.

        Args:
            state: current ``State``. It is not modified; a new ``State`` is
                returned.
            action: control vector written to ``data.ctrl`` before every
                physics sub-step.

        Returns:
            The next ``State``. If the episode terminated or was truncated, this
            is a freshly reset state whose ``reward`` / ``terminated`` /
            ``truncated`` describe the finished transition, with
            ``info["final_observation"]`` and ``info["final_info"]`` holding the
            terminal observation and episode metrics.
        """
        data, _ = jax.lax.scan(
            f=lambda data, _: (mjx.step(self.sys, data.replace(ctrl=action)), None),
            init=state.data,
            xs=(),
            length=self.nr_intermediate_steps,
        )

        # Shallow copy so the caller's `state.info` dict is never mutated in
        # place (relevant when `step` runs eagerly, outside of jit).
        info = dict(state.info)
        info["episode_length"] = info["episode_length"] + 1

        next_observation = self.get_observation(data)
        reward, r_info = self.get_reward(data)
        terminated = self._get_termination(r_info["is_healthy"])
        truncated = info["episode_length"] >= self.horizon
        done = terminated | truncated

        info.update(r_info)
        info["episode_return"] = info["episode_return"] + reward
        info["done"] = done

        def when_done(_):
            # Fresh episode; keep the finished transition's reward/flags and
            # stash the terminal observation and metrics.
            reset_key = jax.random.split(info["key"])[1]
            start_state = self.reset(reset_key)
            start_state = start_state.replace(reward=reward, terminated=terminated, truncated=truncated)
            start_state.info.update(r_info)
            start_state.info["done"] = True
            start_state.info["final_observation"] = next_observation
            info_keys_to_remove = ("key", "final_observation", "final_info", "done")
            # Same key set as reset()'s `final_info`, so both cond branches
            # return identically structured pytrees.
            start_state.info["final_info"] = {
                k: v for k, v in info.items() if k not in info_keys_to_remove
            }
            return start_state

        def when_not_done(_):
            return state.replace(
                data=data,
                observation=next_observation,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                info=info,
            )

        return jax.lax.cond(done, when_done, when_not_done, None)

    def _get_termination(self, is_healthy):
        """Return a boolean array that is True when the episode must terminate.

        An unhealthy state (``is_healthy`` of 0.0) only ends the episode when
        ``terminate_when_unhealthy`` is set, mirroring the reward logic in
        ``get_reward``.
        """
        return jnp.logical_and(is_healthy < 0.5, self.terminate_when_unhealthy)

    def get_observation(self, data):
        """Build the observation vector from the simulator state.

        Layout: ``qpos[1:]`` (torso height, torso pitch, joint angles) followed
        by the full ``qvel`` (torso x/z/pitch velocities, then joint
        velocities). The root x-position ``qpos[0]`` is not included.
        NaN/inf entries are replaced by finite values via ``jnp.nan_to_num``.
        """
        return jnp.nan_to_num(jnp.concatenate([data.qpos[1:], data.qvel]))

    def get_reward(self, data):
        """Compute the per-step reward and its logged components.

        reward = min(forward_reward_weight * qvel[0], 1e4)
                 + healthy_reward
                 - ctrl_cost_weight * sum(ctrl ** 2)

        ``healthy_reward`` is paid unconditionally when
        ``terminate_when_unhealthy`` is set, otherwise only while healthy.

        Returns:
            ``(reward, info)`` where ``info`` has the keys ``"local_vel_x"``,
            ``"is_healthy"`` (1.0 or 0.0, float32) and ``"ctrl_cost"``.
        """
        torso_height = data.qpos[1]
        torso_pitch = data.qpos[2]
        local_lin_vel = data.qvel[0]  # only the x axis (the walker is planar)

        forward_reward = self.forward_reward_weight * local_lin_vel

        min_z, max_z = self.healthy_z_range
        min_angle, max_angle = self.healthy_angle_range
        # Comparisons with NaN evaluate to False, so a NaN state is unhealthy;
        # the 0/1 cast itself can never produce NaN or leave [0, 1].
        is_healthy = (
            (torso_height > min_z)
            & (torso_height < max_z)
            & (torso_pitch > min_angle)
            & (torso_pitch < max_angle)
        ).astype(jnp.float32)

        # `terminate_when_unhealthy` is a plain Python flag, so an ordinary
        # conditional is enough (no need to trace both branches via lax.cond).
        if self.terminate_when_unhealthy:
            healthy_reward = self.healthy_reward
        else:
            healthy_reward = self.healthy_reward * is_healthy

        ctrl_cost = self.ctrl_cost_weight * jnp.sum(jnp.square(data.ctrl))
        # Upper-bound the forward term (same as clip with only a max) so a
        # blown-up simulation cannot produce an enormous reward.
        reward = jnp.nan_to_num(jnp.minimum(forward_reward, 1e4) + healthy_reward - ctrl_cost)

        info = {
            "local_vel_x": local_lin_vel,
            "is_healthy": is_healthy,
            "ctrl_cost": ctrl_cost,
        }

        return reward, info
