from pathlib import Path
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation
import gymnasium as gym

from irl_baselines.environments.humanoid.viewer import MujocoViewer


class Humanoid(gym.Env):
    """MuJoCo humanoid environment exposing the Gymnasium ``Env`` API.

    The model is loaded from ``data/humanoid.xml`` next to this file.

    The reward is::

        forward bonus (root x-velocity, capped) + healthy bonus - ctrl cost - contact cost

    Episodes terminate on an unhealthy torso height (when
    ``terminate_when_unhealthy`` is True). They are truncated after ``horizon``
    steps.

    Args:
        horizon: Number of ``step`` calls after which an episode is truncated.
        render: If True, create a ``MujocoViewer`` and render after every reset
            and step.
    """

    def __init__(self, horizon=1000, render=False):
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "humanoid.xml").as_posix()
        self.model = mujoco.MjModel.from_xml_path(xml_path)
        self.data = mujoco.MjData(self.model)

        # One env step advances the physics by nr_intermediate_steps * nr_substeps timesteps.
        self.nr_substeps = 1
        self.nr_intermediate_steps = 1
        self.dt = self.model.opt.timestep * self.nr_substeps * self.nr_intermediate_steps

        self.viewer = MujocoViewer(self.model, self.dt) if render else None

        # actuator_ctrlrange has one (low, high) row per actuator.
        action_low, action_high = self.model.actuator_ctrlrange.astype(np.float32).T
        self.action_space = gym.spaces.Box(low=action_low, high=action_high, dtype=np.float32)

        # Episode bookkeeping. reset() re-initialises these; defining them here keeps
        # step() from failing with AttributeError if it is called before reset().
        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

        # Reward / termination parameters.
        self.forward_reward_weight = 1.25
        self.healthy_z_range: tuple[float, float] = (1.0, 2.0)
        self.terminate_when_unhealthy = True
        self.ctrl_cost_weight: float = 0.1
        self.healthy_reward = 5.0
        self.contact_cost_weight: float = 5e-7
        self.contact_cost_range: tuple[float, float] = (-np.inf, 10.0)

        # Derive the observation size from get_observation() itself, so the declared
        # space can never drift from what is actually returned.
        # This is expected to be 348 for the bundled humanoid.xml.
        obs_dim = self.get_observation().size
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        """Reset the simulation to the fixed initial pose and return the first observation.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
                The initial state itself is deterministic (no reset noise is
                applied).
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

        # Clear all simulation state (time, ctrl, activations, warm-start, ...).
        # This prevents the previous episode's state from leaking into this one.
        # In particular, the stale ctrl would otherwise feed qfrc_actuator in the
        # first observation.
        mujoco.mj_resetData(self.model, self.data)

        # Root at (0, 0, 1.4) with quaternion (w=1, x=y=z=0), i.e. identity
        # orientation; all remaining joint coordinates and velocities are zero.
        qpos = np.zeros(self.model.nq)
        qpos[2] = 1.4
        qpos[3] = 1.0
        self.data.qpos[:] = qpos
        self.data.qvel[:] = 0.0
        mujoco.mj_forward(self.model, self.data)

        if self.viewer is not None:
            self.viewer.render(self.data)

        return self.get_observation(), {}

    def step(self, action):
        """Apply ``action`` as actuator controls and advance the simulation by ``self.dt``.

        Args:
            action: Array-like of length ``model.nu`` written to ``data.ctrl``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``, where ``info``
            holds the reward terms produced by ``get_reward``.
        """
        action = np.asarray(action)
        for _ in range(self.nr_intermediate_steps):
            self.data.ctrl[:] = action
            mujoco.mj_step(self.model, self.data, self.nr_substeps)

        if self.viewer is not None:
            self.viewer.render(self.data)

        self.episode_step += 1
        self.current_action = action.copy()

        next_state = self.get_observation()
        reward, info = self.get_reward()
        # Only end the episode on an unhealthy pose when configured to do so. With
        # terminate_when_unhealthy=False, get_reward() instead withholds the healthy
        # bonus, and the episode continues.
        terminated = bool(self.terminate_when_unhealthy and info["is_healthy"] < 0.5)
        truncated = self.episode_step >= self.horizon

        return next_state, reward, terminated, truncated, info

    def get_observation(self):
        """Build the flat observation vector from the current simulation state.

        Concatenates, in order:

        1. ``qpos[2:]`` (torso x/y dropped).
        2. ``qvel``.
        3. ``cinert`` and ``cvel`` of every body except body 0.
        4. ``qfrc_actuator[6:]``.
        5. ``cfrc_ext`` of every body except body 0.

        NaN/inf entries are replaced with finite values.

        Returns:
            1-D float32 array matching ``self.observation_space``.
        """
        observation = np.concatenate([
            self.data.qpos[2:],            # exclude x and y coordinates of the torso
            self.data.qvel,
            self.data.cinert[1:].ravel(),  # body 0 is the MuJoCo world body
            self.data.cvel[1:].ravel(),
            self.data.qfrc_actuator[6:],   # presumably drops the 6 root free-joint DOFs
            self.data.cfrc_ext[1:].ravel(),
        ]).astype(np.float32)
        # Cast before nan_to_num. A value overflowing float32 then becomes inf and is
        # clipped to the largest finite float32, instead of turning back into inf
        # after the cast.
        return np.nan_to_num(observation)

    def get_reward(self):
        """Compute the reward for the current simulation state.

        The reward is::

            min(forward_reward_weight * qvel[0], 1e4)
            + healthy bonus - ctrl_cost - contact_cost

        NaNs are replaced by ``nan_to_num``.

        Assumes the root joint is a free joint: qpos[0:3] is the position,
        qpos[3:7] the orientation quaternion and qvel[0:3] the linear velocity.
        TODO: confirm against humanoid.xml.

        Returns:
            ``(reward, info)``. ``info`` contains:

            - ``global_vel_x``: root x-velocity.
            - ``local_vel_x``: the same velocity expressed in the root's frame;
              logged only, not rewarded.
            - ``is_healthy``: 1.0 when the torso height lies strictly inside
              ``healthy_z_range``, else 0.0.
            - ``ctrl_cost`` and ``contact_cost``.
        """
        torso_height = self.data.qpos[2]

        # MuJoCo stores quaternions as (w, x, y, z); scipy expects (x, y, z, w).
        qw, qx, qy, qz = self.data.qpos[3:7]
        world_to_local = Rotation.from_quat([qx, qy, qz, qw]).inv()
        global_linear_velocity = self.data.qvel[:3]
        local_vel_x = world_to_local.apply(global_linear_velocity)[0]
        forward_reward = self.forward_reward_weight * global_linear_velocity[0]

        min_z, max_z = self.healthy_z_range
        # A NaN height compares False and therefore counts as unhealthy.
        is_healthy = np.float32(min_z < torso_height < max_z)
        if self.terminate_when_unhealthy:
            # Unhealthy states end the episode, so the bonus is paid unconditionally.
            healthy_reward = self.healthy_reward
        else:
            healthy_reward = self.healthy_reward * is_healthy

        ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(self.data.ctrl))

        min_cost, max_cost = self.contact_cost_range
        contact_cost = np.clip(
            self.contact_cost_weight * np.sum(np.square(self.data.cfrc_ext)), min_cost, max_cost
        )

        reward = np.nan_to_num(
            np.clip(forward_reward, None, 1e4) + healthy_reward - ctrl_cost - contact_cost
        )

        info = {
            "global_vel_x": global_linear_velocity[0],
            "local_vel_x": local_vel_x,
            "is_healthy": is_healthy,
            "ctrl_cost": ctrl_cost,
            "contact_cost": contact_cost,
        }

        return reward, info

    def close(self):
        """Close the viewer, if one was created. Safe to call more than once."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
