from pathlib import Path
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp
from typing import Tuple

from irl_baselines.environments.half_cheetah_mjx.state import State


class HalfCheetah:
    """HalfCheetah environment simulated with MuJoCo MJX.

    ``reset`` and ``step`` take and return ``State`` values rather than
    mutating the environment object. Episodes are never terminated; they are
    truncated once ``horizon`` steps have been taken, at which point ``step``
    returns a freshly reset state and records the last observation and episode
    statistics under ``info["final_observation"]`` / ``info["final_info"]``.
    """

    # Bookkeeping entries of ``info`` that are not per-episode statistics;
    # they are left out of ``final_info`` when an episode ends.
    _FINAL_INFO_EXCLUDED_KEYS = ("key", "final_observation", "final_info", "done")

    def __init__(
        self,
        horizon=1000,
        *,
        forward_reward_weight: float = 1.0,
        ctrl_cost_weight: float = 0.1,
        nr_intermediate_steps: int = 1,
    ):
        """Load the MuJoCo model and set up the initial state and reward weights.

        Args:
            horizon: Number of steps after which an episode is truncated.
            forward_reward_weight: Scale of the forward-velocity reward term.
            ctrl_cost_weight: Scale of the quadratic control penalty.
            nr_intermediate_steps: Number of ``mjx.step`` calls per environment
                step; the action is held fixed across them.
        """
        self.horizon = horizon

        xml_path = (Path(__file__).resolve().parent / "data" / "half_cheetah.xml").as_posix()
        mj_model = mujoco.MjModel.from_xml_path(xml_path)
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
        self.model = mj_model
        # Host-side MjData template; ``reset`` copies it with ``mjx.put_data``.
        self.data = mujoco.MjData(mj_model)
        self.sys = mjx.put_model(mj_model)

        self.nr_intermediate_steps = nr_intermediate_steps

        # Every episode starts from the all-zero configuration at rest.
        self.initial_qpos = jnp.zeros(self.sys.nq)
        self.initial_qvel = jnp.zeros(self.sys.nv)

        self.forward_reward_weight = forward_reward_weight
        self.ctrl_cost_weight: float = ctrl_cost_weight

    def reset(self, key):
        """Return the initial ``State`` of a new episode.

        The start configuration is deterministic (zero ``qpos``, ``qvel`` and
        control). ``key`` is only split to produce the RNG key stored in
        ``info["key"]``, which ``step`` uses to derive the next reset key.

        Args:
            key: JAX PRNG key.

        Returns:
            A ``State`` with zero reward, ``terminated``/``truncated`` False and
            an ``info`` dict holding zeroed episode statistics.
        """
        _, subkey = jax.random.split(key)

        data = mjx.put_data(self.model, self.data)
        data = data.replace(qpos=self.initial_qpos, qvel=self.initial_qvel, ctrl=jnp.zeros(self.sys.nu))
        # Recompute quantities derived from qpos/qvel for the new start state.
        data = mjx.forward(self.sys, data)

        observation = self.get_observation(data)
        reward = 0.0
        terminated = False
        truncated = False
        logging_info = {
            "episode_return": reward,
            "episode_length": 0,
            "local_vel_x": 0.0,
            "ctrl_cost": 0.0,
        }
        info = {
            **logging_info,
            # Placeholders so the info pytree has the same structure whether or
            # not ``step`` just auto-reset.
            "final_observation": jnp.zeros_like(observation),
            "final_info": {**logging_info},
            "done": False,
            "key": subkey
        }

        return State(data, observation, reward, terminated, truncated, info)

    def step(self, state, action):
        """Advance the environment by one step, auto-resetting at episode end.

        Args:
            state: Current ``State`` (as returned by ``reset`` or ``step``).
                It is not modified.
            action: Control vector written to ``data.ctrl`` (presumably of
                length ``sys.nu`` -- confirm against callers).

        Returns:
            The next ``State``. When the episode is truncated this is a freshly
            reset state whose ``reward``/``terminated``/``truncated`` are those
            of the final transition, with ``info["done"]`` True and the final
            observation/statistics in ``info["final_observation"]`` and
            ``info["final_info"]``.
        """
        def physics_substep(data, _):
            return mjx.step(self.sys, data.replace(ctrl=action)), None

        data, _ = jax.lax.scan(
            f=physics_substep,
            init=state.data,
            xs=(),
            length=self.nr_intermediate_steps
        )

        # Shallow copy so the caller's ``state.info`` dict is never mutated.
        info = dict(state.info)
        info["episode_length"] = info["episode_length"] + 1

        next_observation = self.get_observation(data)
        reward, r_info = self.get_reward(data)
        terminated = False
        truncated = info["episode_length"] >= self.horizon
        done = terminated | truncated

        info.update(r_info)
        info["episode_return"] = info["episode_return"] + reward
        info["done"] = done

        def when_done(_):
            _, reset_key = jax.random.split(info["key"])
            start_state = self.reset(reset_key)
            # Report the final transition's outcome on the reset state.
            start_state = start_state.replace(reward=reward, terminated=terminated, truncated=truncated)
            start_state.info.update(r_info)
            start_state.info["done"] = True
            start_state.info["final_observation"] = next_observation
            start_state.info["final_info"] = {
                name: value
                for name, value in info.items()
                if name not in self._FINAL_INFO_EXCLUDED_KEYS
            }
            return start_state

        def when_not_done(_):
            return state.replace(
                data=data,
                observation=next_observation,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                info=info,
            )

        return jax.lax.cond(done, when_done, when_not_done, None)

    def get_observation(self, data):
        """Return ``qpos[1:]`` followed by ``qvel``, with NaN/inf replaced.

        ``qpos[0]`` is dropped. Per the original comment it is the x coordinate
        of the front tip (assumes the model's first DOF is the root x slide --
        confirm against half_cheetah.xml).
        """
        position = data.qpos[1:]  # exclude x coordinate of front tip
        velocity = data.qvel[:]

        observation = jnp.nan_to_num(jnp.concatenate([
            position,
            velocity,
        ]))

        return observation

    def get_reward(self, data):
        """Reward forward velocity minus a quadratic control cost.

        reward = min(forward_reward_weight * qvel[0], 1e4)
                 - ctrl_cost_weight * sum(ctrl ** 2)

        NaN/inf in the result are replaced by ``jnp.nan_to_num``.

        Returns:
            ``(reward, info)`` where ``info`` holds ``local_vel_x`` (``qvel[0]``)
            and ``ctrl_cost``.
        """
        local_lin_vel = data.qvel[0]  # only in x axis (half cheetah is 2D)
        forward_reward = self.forward_reward_weight * local_lin_vel
        ctrl_cost = self.ctrl_cost_weight * jnp.sum(jnp.square(data.ctrl))
        # Upper-only clip: equivalent to jnp.clip(x, a_max=1e4) without the
        # deprecated ``a_max`` keyword.
        reward = jnp.nan_to_num(jnp.minimum(forward_reward, 1e4) - ctrl_cost)

        info = {
            "local_vel_x": local_lin_vel,
            "ctrl_cost": ctrl_cost,
        }

        return reward, info

