from pathlib import Path
import mujoco
from mujoco import mjx
from jax.scipy.spatial.transform import Rotation
import jax
import jax.numpy as jnp

from irl_baselines.environments.ant_mjx.state import State

class Ant:
    """MJX (JAX) implementation of the MuJoCo Ant locomotion environment.

    The environment is functional: ``reset`` and ``step`` return a new
    ``State`` and do not modify the ``State`` (or its ``info`` dict) that was
    passed in. When an episode ends, ``step`` returns the first state of the
    next episode and stores the last observation / logging values of the
    finished episode under ``info["final_observation"]`` and
    ``info["final_info"]``.
    """

    # Entries of ``State.info`` that are bookkeeping and therefore are not
    # copied into ``info["final_info"]`` when an episode ends.
    _NON_LOGGING_INFO_KEYS = ("key", "final_observation", "final_info", "done")

    def __init__(self, horizon=1000):
        """Load ``data/ant.xml`` (next to this file) and set up the task.

        Args:
            horizon: Number of environment steps after which an episode is
                truncated.
        """
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "ant.xml").as_posix()
        mj_model = mujoco.MjModel.from_xml_path(xml_path)
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
        self.model = mj_model
        self.data = mujoco.MjData(mj_model)
        self.sys = mjx.put_model(mj_model)

        # Number of physics steps simulated per call to ``step``.
        self.nr_intermediate_steps = 1

        initial_height = 0.75
        initial_rotation_quaternion = [1.0, 0.0, 0.0, 0.0]  # mujoco quaternion format: [w, x, y, z]
        initial_joint_angles = [0.0, 0.0] * 4
        self.initial_qpos = jnp.array([0.0, 0.0, initial_height, *initial_rotation_quaternion, *initial_joint_angles])
        self.initial_qvel = jnp.zeros(self.sys.nv)

        self.forward_reward_weight: float = 1.0
        self.healthy_z_range: tuple[float, float] = (0.2, 1.0)
        self.terminate_when_unhealthy: bool = True
        self.ctrl_cost_weight: float = 5e-4
        self.healthy_reward: float = 1.0
        self.contact_force_range: tuple[float, float] = (-1.0, 1.0)
        self.contact_cost_weight: float = 5e-4

    def reset(self, key):
        """Return the initial ``State`` of a new episode.

        Args:
            key: JAX PRNG key. It is split once and one half is stored in
                ``info["key"]`` so that ``step`` can derive the key for the
                next automatic reset.

        Returns:
            A ``State`` at ``initial_qpos`` / ``initial_qvel`` with zero
            control, zero reward and zeroed logging values.
        """
        _, subkey = jax.random.split(key)

        data = mjx.put_data(self.model, self.data)
        data = data.replace(qpos=self.initial_qpos, qvel=self.initial_qvel, ctrl=jnp.zeros(self.sys.nu))
        # Recompute the derived quantities for the freshly written qpos/qvel.
        data = mjx.forward(self.sys, data)

        observation = self.get_observation(data)
        reward = 0.0
        terminated = False
        truncated = False
        logging_info = {
            "episode_return": reward,
            "episode_length": 0,
            "global_vel_x": 0.0,
            "local_vel_x": 0.0,
            "is_healthy": 1.0,
            "ctrl_cost": 0.0,
            "contact_cost": 0.0,
        }
        info = {
            **logging_info,
            "final_observation": jnp.zeros_like(observation),
            "final_info": {**logging_info},
            "done": False,
            "key": subkey,
        }

        return State(data, observation, reward, terminated, truncated, info)

    def step(self, state, action):
        """Advance the simulation by one environment step.

        Args:
            state: Current ``State``. It is not modified; in particular its
                ``info`` dict is copied rather than updated in place.
            action: Control vector written to ``data.ctrl`` for every
                intermediate physics step.

        Returns:
            The next ``State``. If the episode ended on this step (unhealthy
            while ``terminate_when_unhealthy`` is set, or ``horizon``
            reached), the returned state is the start of a new episode that
            carries this step's reward / termination flags and the finished
            episode's ``final_observation`` and ``final_info``.
        """
        def physics_step(data, _):
            return mjx.step(self.sys, data.replace(ctrl=action)), None

        data, _ = jax.lax.scan(
            f=physics_step,
            init=state.data,
            xs=(),
            length=self.nr_intermediate_steps,
        )

        next_observation = self.get_observation(data)
        reward, r_info = self.get_reward(data)

        episode_length = state.info["episode_length"] + 1
        # Being unhealthy only ends the episode when the flag asks for it.
        terminated = jnp.logical_and(r_info["is_healthy"] < 0.5, self.terminate_when_unhealthy)
        truncated = episode_length >= self.horizon
        done = terminated | truncated

        # Build a fresh dict so the caller's ``state.info`` stays untouched.
        info = {
            **state.info,
            **r_info,
            "episode_length": episode_length,
            "episode_return": state.info["episode_return"] + reward,
            "done": done,
        }

        def when_done(_):
            _unused, reset_key = jax.random.split(state.info["key"])
            start_state = self.reset(reset_key)
            final_info = {
                name: value
                for name, value in info.items()
                if name not in self._NON_LOGGING_INFO_KEYS
            }
            start_info = {
                **start_state.info,
                **r_info,
                "done": True,
                "final_observation": next_observation,
                "final_info": final_info,
            }
            return start_state.replace(
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                info=start_info,
            )

        def when_not_done(_):
            return state.replace(
                data=data,
                observation=next_observation,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                info=info,
            )

        return jax.lax.cond(done, when_done, when_not_done, None)

    def _clipped_contact_forces(self, data):
        """Return ``data.cfrc_ext`` clipped to ``contact_force_range``, without row 0, flattened."""
        min_value, max_value = self.contact_force_range
        contact_forces = jnp.clip(data.cfrc_ext, min_value, max_value)
        # Row 0 is skipped (in MuJoCo, body 0 is the world body).
        return contact_forces[1:].flatten()

    def get_observation(self, data):
        """Build the flat observation vector from simulation data.

        The observation is the concatenation of ``qpos[2:]`` (everything but
        the torso x/y position), ``qvel`` and the clipped contact forces.
        NaN / infinite entries are replaced via ``jnp.nan_to_num``.
        """
        position = data.qpos[2:]  # exclude x and y coordinates of the torso
        velocity = data.qvel[:]
        contact_force = self._clipped_contact_forces(data)

        return jnp.nan_to_num(jnp.concatenate([position, velocity, contact_force]))

    def get_reward(self, data):
        """Compute the reward and its logged components.

        reward = min(forward_reward, 1e4) + healthy_reward - ctrl_cost - contact_cost

        ``forward_reward`` is the global x velocity (``qvel[0]``) times
        ``forward_reward_weight``. ``healthy_reward`` is always granted when
        ``terminate_when_unhealthy`` is set, otherwise only while the torso
        height is strictly inside ``healthy_z_range``.

        Returns:
            ``(reward, info)`` where ``info`` holds ``global_vel_x``,
            ``local_vel_x``, ``is_healthy`` (0.0 or 1.0), ``ctrl_cost`` and
            ``contact_cost``.
        """
        torso_height = data.qpos[2]

        # MuJoCo stores quaternions as [w, x, y, z]; Rotation.from_quat
        # expects [x, y, z, w], so rotate the four entries by one position.
        base_orientation = jnp.roll(data.qpos[3:7], -1)
        inverted_rotation = Rotation.from_quat(base_orientation).inv()
        current_global_linear_velocity = data.qvel[:3]
        current_local_linear_velocity = inverted_rotation.apply(current_global_linear_velocity)[0]
        forward_reward = self.forward_reward_weight * current_global_linear_velocity[0]

        min_z, max_z = self.healthy_z_range
        # A NaN height fails both comparisons and therefore counts as unhealthy.
        is_healthy = ((torso_height > min_z) & (torso_height < max_z)).astype(jnp.float32)

        # ``terminate_when_unhealthy`` is static configuration, so an ordinary
        # Python branch is sufficient here.
        if self.terminate_when_unhealthy:
            healthy_reward = self.healthy_reward
        else:
            healthy_reward = self.healthy_reward * is_healthy

        ctrl_cost = self.ctrl_cost_weight * jnp.sum(jnp.square(data.ctrl))

        contact_force = self._clipped_contact_forces(data)
        contact_cost = self.contact_cost_weight * jnp.sum(jnp.square(contact_force))

        # Cap the forward term so a diverging simulation cannot produce an
        # arbitrarily large reward; nan_to_num guards against NaN rewards.
        reward = jnp.nan_to_num(
            jnp.minimum(forward_reward, 1e4) + healthy_reward - ctrl_cost - contact_cost
        )

        info = {
            "global_vel_x": current_global_linear_velocity[0],
            "local_vel_x": current_local_linear_velocity,
            "is_healthy": is_healthy,
            "ctrl_cost": ctrl_cost,
            "contact_cost": contact_cost,
        }

        return reward, info
    