from pathlib import Path
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation
import gymnasium as gym

from irl_baselines.environments.walker2d.viewer import MujocoViewer


class Walker2D(gym.Env):
    """Planar Walker2D locomotion task backed by a MuJoCo model.

    The model is loaded from ``data/walker2d_v5.xml`` next to this file.
    Observations are the generalized positions without the first (x)
    coordinate followed by all generalized velocities; actions are written
    directly to the MuJoCo actuator controls.
    """

    def __init__(self, horizon=1000, render=False):
        """Load the model and set up spaces and reward parameters.

        Args:
            horizon: Number of ``step`` calls after which the episode is
                reported as truncated.
            render: When true, create a ``MujocoViewer`` and render after
                every reset and step.
        """
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "walker2d_v5.xml").as_posix()
        self.model = mujoco.MjModel.from_xml_path(xml_path)
        self.data = mujoco.MjData(self.model)

        # One environment step advances the simulation by
        # nr_intermediate_steps * nr_substeps physics steps.
        self.nr_substeps = 1
        self.nr_intermediate_steps = 1
        self.dt = self.model.opt.timestep * self.nr_substeps * self.nr_intermediate_steps

        self.viewer = MujocoViewer(self.model, self.dt) if render else None

        # Action bounds come straight from the actuator control ranges.
        action_bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        action_low, action_high = action_bounds.T
        self.action_space = gym.spaces.Box(low=action_low, high=action_high, dtype=np.float32)

        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float32)

        self.forward_reward_weight = 1.0
        self.healthy_z_range: tuple[float, float] = (0.8, 2.0)
        self.healthy_angle_range = (-1.0, 1.0)
        self.terminate_when_unhealthy = True
        self.ctrl_cost_weight: float = 1e-3
        self.healthy_reward = 1.0

        # Initialised here so that step() before reset() does not fail with
        # an AttributeError; reset() overwrites both.
        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

    def reset(self, seed=None, options=None):
        """Put the walker back into its fixed initial pose.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to seed the
                environment's random generator.
            options: Accepted for compatibility with the Gymnasium reset
                signature; currently unused.

        Returns:
            A tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

        # Index 1 (torso height per get_observation) starts at 1.25; every
        # other position and all velocities start at zero.
        qpos = np.zeros(self.model.nq)
        qpos[1] = 1.25

        self.data.qpos[:] = qpos
        self.data.qvel[:] = np.zeros(self.model.nv)
        mujoco.mj_forward(self.model, self.data)

        self._render_if_enabled()

        return self.get_observation(), {}

    def step(self, action):
        """Apply ``action``, advance the simulation and score the new state.

        Args:
            action: Array-like of length ``model.nu`` with actuator controls.

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is true only when ``terminate_when_unhealthy`` is
            set and the walker left the healthy range; ``truncated`` is true
            once ``horizon`` steps have been taken.
        """
        action = np.asarray(action, dtype=np.float64)

        # The control signal is constant over the whole environment step.
        self.data.ctrl[:] = action
        for _ in range(self.nr_intermediate_steps):
            mujoco.mj_step(self.model, self.data, self.nr_substeps)

        self._render_if_enabled()

        self.episode_step += 1
        self.current_action = action.copy()

        next_state = self.get_observation()
        reward, r_info = self.get_reward()
        # Honour the terminate_when_unhealthy switch instead of always
        # terminating on an unhealthy state.
        terminated = bool(self.terminate_when_unhealthy and r_info["is_healthy"] < 0.5)
        truncated = self.episode_step >= self.horizon
        info = dict(r_info)

        return next_state, reward, terminated, truncated, info

    def get_observation(self):
        """Build the observation vector from the current simulation state.

        Layout: ``qpos[1]`` (torso height), ``qpos[2]`` (torso pitch),
        ``qpos[3:]`` (joint positions), ``qvel[0]`` (x velocity), ``qvel[1]``
        (z velocity), ``qvel[2]`` (pitch angular velocity), ``qvel[3:]``
        (joint velocities). ``qpos[0]`` is deliberately left out.

        Returns:
            A 1-D ``float32`` array with NaN/inf values replaced by finite
            numbers, matching the dtype declared by ``observation_space``.
        """
        observation = np.concatenate([
            self.data.qpos[1:],  # height, pitch, joint positions
            self.data.qvel,      # vx, vz, pitch rate, joint velocities
        ])
        return np.nan_to_num(observation).astype(np.float32)

    def get_reward(self):
        """Compute the reward for the current simulation state.

        The reward is the weighted forward velocity (capped at 1e4) plus the
        healthy bonus minus the quadratic control cost.

        Returns:
            A tuple ``(reward, info)`` where ``info`` holds ``local_vel_x``,
            ``is_healthy`` (1.0 or 0.0) and ``ctrl_cost``.
        """
        torso_height = self.data.qpos[1]
        torso_pitch = self.data.qpos[2]
        local_lin_vel = self.data.qvel[0]  # only in x axis (walker is 2D)

        forward_reward = self.forward_reward_weight * local_lin_vel

        min_z, max_z = self.healthy_z_range
        min_angle, max_angle = self.healthy_angle_range
        # Strict bounds; a NaN height or pitch compares false and therefore
        # counts as unhealthy.
        in_healthy_range = bool(
            (min_z < torso_height < max_z) and (min_angle < torso_pitch < max_angle)
        )
        is_healthy = np.float32(in_healthy_range)

        # NOTE(review): with terminate_when_unhealthy set, the bonus is paid
        # on every step, including the final unhealthy one. Kept as-is to
        # preserve existing reward values; confirm this is intended.
        if self.terminate_when_unhealthy:
            healthy_reward = self.healthy_reward
        else:
            healthy_reward = self.healthy_reward * is_healthy

        ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(self.data.ctrl))
        reward = float(np.nan_to_num(np.clip(forward_reward, None, 1e4) + healthy_reward - ctrl_cost))

        info = {
            "local_vel_x": local_lin_vel,
            "is_healthy": is_healthy,
            "ctrl_cost": ctrl_cost,
        }

        return reward, info

    def close(self):
        """Close the viewer if one was created; safe to call repeatedly."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

    def _render_if_enabled(self):
        """Render the current simulation state when a viewer exists."""
        if self.viewer is not None:
            self.viewer.render(self.data)
