from pathlib import Path
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation
import gymnasium as gym

from irl_baselines.environments.ant.viewer import MujocoViewer


class Ant(gym.Env):
    """MuJoCo ant environment that rewards velocity along the global x axis.

    The model is loaded from ``data/ant.xml`` next to this file. Observations
    are the generalized positions (without the torso x/y coordinates), the
    generalized velocities, and the clipped external contact forces of every
    body except the first one. Actions are written directly to the actuator
    controls of the model.
    """

    def __init__(self, horizon=1000, render=False):
        """Load the model and set up spaces and reward parameters.

        Args:
            horizon: Number of ``step`` calls after which an episode is
                reported as truncated.
            render: When true, a ``MujocoViewer`` is created and updated on
                every ``reset`` and ``step``.
        """
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "ant.xml").as_posix()
        self.model = mujoco.MjModel.from_xml_path(xml_path)
        self.data = mujoco.MjData(self.model)

        # One environment step advances the simulation by
        # nr_intermediate_steps calls of mj_step with nr_substeps each.
        self.nr_substeps = 1
        self.nr_intermediate_steps = 1
        self.dt = self.model.opt.timestep * self.nr_substeps * self.nr_intermediate_steps

        self.viewer = MujocoViewer(self.model, self.dt) if render else None

        # actuator_ctrlrange holds one (low, high) row per actuator.
        action_bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        action_low, action_high = action_bounds.T
        self.action_space = gym.spaces.Box(low=action_low, high=action_high, dtype=np.float32)

        # NOTE(review): 105 is hard-coded and must match the length produced
        # by get_observation() for the loaded XML -- confirm if ant.xml changes.
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(105,), dtype=np.float32)

        self.forward_reward_weight = 1.0
        self.healthy_z_range: tuple[float, float] = (0.2, 1.0)
        self.terminate_when_unhealthy = True
        self.ctrl_cost_weight: float = 5e-4
        self.healthy_reward = 1.0
        self.contact_force_range: tuple[float, float] = (-1.0, 1.0)
        self.contact_cost_weight: float = 5e-4

        # Initialised here as well as in reset() so that the attributes exist
        # even if step() is called before the first reset().
        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

    def reset(self, seed=None, options=None):
        """Start a new episode from the fixed initial pose.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to seed the
                environment's random generator.
            options: Accepted for compatibility with the Gymnasium ``reset``
                signature; currently unused.

        Returns:
            A tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

        # Clear controls, time and solver state left over from the previous
        # episode before writing the initial configuration.
        mujoco.mj_resetData(self.model, self.data)

        qpos = np.zeros(self.model.nq)
        qpos[2] = 0.75  # z position
        qpos[3:7] = [1.0, 0.0, 0.0, 0.0]  # mujoco quaternion format: [w, x, y, z]

        self.data.qpos[:] = qpos
        self.data.qvel[:] = 0.0
        mujoco.mj_forward(self.model, self.data)

        if self.viewer is not None:
            self.viewer.render(self.data)

        return self.get_observation(), {}

    def step(self, action):
        """Apply ``action`` and advance the simulation by one environment step.

        Args:
            action: Array-like with one control value per actuator.

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is true when the torso leaves the healthy height
            range and ``terminate_when_unhealthy`` is set; ``truncated`` is
            true once ``horizon`` steps have been taken. ``info`` contains
            the entries produced by ``get_reward``.
        """
        action = np.asarray(action, dtype=np.float64)

        # Write in place; mj_step does not modify ctrl, so once is enough.
        self.data.ctrl[:] = action
        for _ in range(self.nr_intermediate_steps):
            mujoco.mj_step(self.model, self.data, self.nr_substeps)

        if self.viewer is not None:
            self.viewer.render(self.data)

        self.episode_step += 1
        self.current_action = action.copy()

        next_state = self.get_observation()
        reward, r_info = self.get_reward()
        # Only end the episode on an unhealthy state when the flag asks for it.
        terminated = bool(self.terminate_when_unhealthy and r_info["is_healthy"] < 0.5)
        truncated = self.episode_step >= self.horizon
        info = dict(r_info)

        return next_state, reward, terminated, truncated, info

    def _clipped_contact_forces(self):
        """Return the flattened, clipped ``cfrc_ext`` rows of all bodies but the first."""
        low, high = self.contact_force_range
        clipped = np.clip(self.data.cfrc_ext, low, high)
        return clipped[1:].flatten()

    def get_observation(self):
        """Build the current observation vector.

        Returns:
            A 1-D ``float32`` array made of ``qpos[2:]``, ``qvel`` and the
            clipped contact forces, with NaN/inf values replaced by finite
            numbers.
        """
        position = self.data.qpos[2:]  # exclude x and y coordinates of the torso
        velocity = self.data.qvel[:]
        contact_force = self._clipped_contact_forces()

        # Cast first so the result matches the declared float32 space and
        # nan_to_num clamps to the float32 range.
        observation = np.concatenate([position, velocity, contact_force]).astype(np.float32)
        return np.nan_to_num(observation)

    def get_reward(self):
        """Compute the reward for the current simulation state.

        The reward is the weighted global x velocity (capped at 1e4) plus the
        healthy reward, minus control and contact costs.

        Returns:
            A tuple ``(reward, info)``. ``info`` holds ``global_vel_x``,
            ``local_vel_x``, ``is_healthy`` (1.0 or 0.0), ``ctrl_cost`` and
            ``contact_cost``.
        """
        torso_height = self.data.qpos[2]

        # qpos stores the quaternion as [w, x, y, z]; scipy expects [x, y, z, w].
        base_orientation = [self.data.qpos[4], self.data.qpos[5], self.data.qpos[6], self.data.qpos[3]]
        inverted_rotation = Rotation.from_quat(base_orientation).inv()
        current_global_linear_velocity = self.data.qvel[:3]
        # Torso-frame x velocity; reported in info only, not used in the reward.
        current_local_linear_velocity = inverted_rotation.apply(current_global_linear_velocity)[0]
        forward_reward = self.forward_reward_weight * current_global_linear_velocity[0]

        min_z, max_z = self.healthy_z_range
        # A NaN height fails both comparisons and therefore counts as unhealthy.
        is_healthy = np.float32(min_z < torso_height < max_z)

        if self.terminate_when_unhealthy:
            # The episode ends on an unhealthy state, so the bonus is
            # granted unconditionally on every step.
            healthy_reward = self.healthy_reward
        else:
            healthy_reward = self.healthy_reward * is_healthy

        ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(self.data.ctrl))
        contact_cost = self.contact_cost_weight * np.sum(np.square(self._clipped_contact_forces()))

        reward = np.nan_to_num(
            np.clip(forward_reward, a_min=None, a_max=1e4) + healthy_reward - ctrl_cost - contact_cost
        )

        info = {
            "global_vel_x": current_global_linear_velocity[0],
            "local_vel_x": current_local_linear_velocity,
            "is_healthy": is_healthy,
            "ctrl_cost": ctrl_cost,
            "contact_cost": contact_cost,
        }

        return reward, info

    def close(self):
        """Close the viewer, if one was created; safe to call repeatedly."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
