from pathlib import Path
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation
import gymnasium as gym

from irl_baselines.environments.half_cheetah.viewer import MujocoViewer


class HalfCheetah(gym.Env):
    """Gymnasium environment around the MuJoCo model in ``data/half_cheetah.xml``.

    The observation is ``qpos[1:]`` followed by ``qvel``. The reward is the
    weighted velocity ``qvel[0]`` (capped at 1e4) minus a quadratic control
    cost. Episodes never terminate; they are truncated after ``horizon`` steps.
    """

    def __init__(self, horizon: int = 1000, render: bool = False):
        """Load the model and set up the action/observation spaces.

        Args:
            horizon: Number of environment steps after which ``step`` reports
                ``truncated=True``.
            render: If true, a ``MujocoViewer`` is created and updated on
                every ``reset`` and ``step``.
        """
        super().__init__()
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "half_cheetah.xml").as_posix()
        self.model = mujoco.MjModel.from_xml_path(xml_path)
        self.data = mujoco.MjData(self.model)

        # One environment step = nr_intermediate_steps calls to mj_step, each
        # advancing the simulation by nr_substeps physics steps.
        self.nr_substeps = 1
        self.nr_intermediate_steps = 1
        self.dt = self.model.opt.timestep * self.nr_substeps * self.nr_intermediate_steps

        self.viewer = MujocoViewer(self.model, self.dt) if render else None

        # actuator_ctrlrange has one (low, high) row per actuator.
        action_bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        action_low, action_high = action_bounds.T
        self.action_space = gym.spaces.Box(low=action_low, high=action_high, dtype=np.float32)

        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float32)

        self.forward_reward_weight = 1.0
        self.ctrl_cost_weight: float = 0.1

        # Episode bookkeeping; initialised here so that calling step() before
        # reset() does not fail with an AttributeError.
        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

    def reset(self, seed=None, options=None):
        """Reset the simulation to the all-zero configuration.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` to seed
                ``self.np_random``.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            A tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu)

        # Clear the complete simulation state (time, controls, solver
        # warm-start, ...) so nothing leaks over from the previous episode.
        mujoco.mj_resetData(self.model, self.data)

        self.data.qpos[:] = np.zeros(self.model.nq)
        self.data.qvel[:] = np.zeros(self.model.nv)
        mujoco.mj_forward(self.model, self.data)

        if self.viewer is not None:
            self.viewer.render(self.data)

        return self.get_observation(), {}

    def step(self, action):
        """Apply ``action`` and advance the simulation by one environment step.

        Args:
            action: Control values, one per actuator. Any array-like is
                accepted.

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is always ``False``; ``truncated`` becomes ``True``
            once ``horizon`` steps have been taken since the last reset.
        """
        action = np.asarray(action, dtype=np.float64)

        # Write into the existing control buffer; the control stays constant
        # for all intermediate steps, so it is set once up front.
        self.data.ctrl[:] = action
        for _ in range(self.nr_intermediate_steps):
            mujoco.mj_step(self.model, self.data, self.nr_substeps)

        if self.viewer is not None:
            self.viewer.render(self.data)

        self.episode_step += 1
        self.current_action = action.copy()

        next_state = self.get_observation()
        reward, r_info = self.get_reward()
        terminated = False
        truncated = self.episode_step >= self.horizon
        info = {**r_info}

        return next_state, reward, terminated, truncated, info

    def get_observation(self) -> np.ndarray:
        """Return the current observation in the dtype of ``observation_space``.

        The observation is ``qpos[1:]`` concatenated with ``qvel``; NaN and
        infinite entries are replaced by finite numbers.
        """
        position = self.data.qpos[1:]  # exclude x coordinate of front tip
        velocity = self.data.qvel[:]

        # Cast before sanitising so that replaced infinities are finite in the
        # target dtype as well.
        observation = np.concatenate([position, velocity]).astype(self.observation_space.dtype)

        return np.nan_to_num(observation)

    def get_reward(self):
        """Compute the reward: forward velocity reward minus control cost.

        Returns:
            A tuple ``(reward, info)``. ``info`` holds the raw x velocity under
            ``"local_vel_x"`` and the control cost under ``"ctrl_cost"``.
        """
        local_lin_vel = self.data.qvel[0]  # only in x axis (half cheetah is 2D)
        forward_reward = self.forward_reward_weight * local_lin_vel
        ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(self.data.ctrl))

        # Cap the forward reward from above and make sure the result is finite.
        reward = np.nan_to_num(np.minimum(forward_reward, 1e4) - ctrl_cost)

        info = {
            "local_vel_x": local_lin_vel,
            "ctrl_cost": ctrl_cost,
        }

        return reward, info

    def close(self):
        """Close the viewer, if one was created. Safe to call repeatedly."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
