from pathlib import Path
import mujoco
from mujoco import mjx
import jax
import jax.numpy as jnp

from irl_baselines.environments.point_maze_mjx.state import State


class PointMaze:
    """Point-mass maze environment simulated with MuJoCo MJX.

    The environment exposes a functional ``reset`` / ``step`` API operating on
    ``State`` objects. Episodes end when the particle is within
    ``success_radius`` of the target (terminated) or after ``horizon`` steps
    (truncated); ``step`` then returns a freshly reset state that carries the
    finished episode's data in ``info["final_observation"]`` and
    ``info["final_info"]``.
    """

    # Bookkeeping entries of ``State.info`` that are not per-episode logging
    # values and are therefore left out of ``info["final_info"]``.
    _NON_LOGGING_INFO_KEYS = ("key", "final_observation", "final_info", "done", "target_pos")

    def __init__(self, horizon: int = 100, reward_style: str = "dense", success_radius: float = 0.1):
        """Load the MuJoCo model and prepare the MJX system.

        Args:
            horizon: Number of steps after which an episode is truncated.
            reward_style: ``"sparse"`` selects the 0/1 success reward; any
                other value selects the dense distance-based reward.
            success_radius: Distance to the target at or below which the
                episode counts as a success.
        """
        self.horizon = horizon
        self.reward_style = reward_style  # "dense" or "sparse"
        self.success_radius = success_radius

        xml_path = (Path(__file__).resolve().parent / "data" / "point_maze.xml").as_posix()
        mj_model = mujoco.MjModel.from_xml_path(xml_path)
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON

        self.model = mj_model
        self.data = mujoco.MjData(mj_model)
        self.sys = mjx.put_model(mj_model)

        self.init_qpos = jnp.zeros(2)
        self.init_qvel = jnp.zeros(2)

        self.particle_body_id = self.model.body("particle").id
        self.target_body_id = self.model.body("target").id

    def reset(self, key):
        """Return the initial ``State`` of a new episode.

        Args:
            key: JAX PRNG key. It is split once and the second half is stored
                in ``info["key"]`` for the next reset.

        Returns:
            A ``State`` with zeroed position, velocity and control, zero
            reward, and freshly initialised logging info.
        """
        # Only the second half of the split is used; the first half is discarded.
        _, key_next = jax.random.split(key)

        data = mjx.put_data(self.model, self.data)
        data = data.replace(qpos=self.init_qpos, qvel=self.init_qvel, ctrl=jnp.zeros(self.sys.nu))
        data = mjx.forward(self.sys, data)

        target_pos = data.xpos[self.target_body_id][:2]
        observation = self.get_observation(data, target_pos)

        reward = 0.0
        terminated = False
        truncated = False

        logging_info = {
            "episode_return": 0.0,
            "episode_length": 0,
            "target_x": target_pos[0],
            "target_y": target_pos[1],
            "is_success": 0.0,
            "reward_dist": 0.0,
            "reward_ctrl": 0.0,
        }

        info = {
            **logging_info,
            "final_observation": jnp.zeros_like(observation),
            "final_info": {**logging_info},
            "done": False,
            "key": key_next,
            "target_pos": target_pos,
        }

        return State(data, observation, reward, terminated, truncated, info)

    def step(self, state, action):
        """Advance the simulation by one step and auto-reset on episode end.

        The input ``state`` is not modified; a new info dict is built for the
        returned state.

        Args:
            state: Current ``State``.
            action: Control vector written to ``data.ctrl``.

        Returns:
            The next ``State``. When the episode is done, this is a reset
            state whose reward/terminated/truncated describe the final
            transition and whose info carries ``final_observation`` and
            ``final_info`` of the finished episode.
        """
        data = state.data.replace(ctrl=action)
        data = mjx.step(self.sys, data)

        # The target position is read from the simulated data on every step.
        target_pos = data.xpos[self.target_body_id][:2]
        next_observation = self.get_observation(data, target_pos)

        reward, r_info = self.get_reward(data, target_pos, action)

        episode_length = state.info["episode_length"] + 1
        terminated = r_info["is_success"] > 0.5
        truncated = episode_length >= self.horizon
        done = terminated | truncated

        # Build a new dict instead of mutating ``state.info`` so the caller's
        # previous state stays intact. Every overridden key already exists in
        # ``state.info``, so the key set is unchanged.
        info = {
            **state.info,
            **r_info,
            "episode_length": episode_length,
            "episode_return": state.info["episode_return"] + reward,
            "done": done,
        }

        def when_done(_):
            # Derive the reset key from the key stored at the last reset.
            _unused, reset_key = jax.random.split(state.info["key"])
            start_state = self.reset(reset_key)

            final_info = {k: v for k, v in info.items() if k not in self._NON_LOGGING_INFO_KEYS}
            start_info = {
                **start_state.info,
                **r_info,
                "done": True,
                "final_observation": next_observation,
                "final_info": final_info,
            }
            return start_state.replace(
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                info=start_info,
            )

        def when_not_done(_):
            return state.replace(
                data=data,
                observation=next_observation,
                reward=reward,
                terminated=terminated,
                truncated=truncated,
                info=info,
            )

        return jax.lax.cond(done, when_done, when_not_done, None)

    def get_observation(self, data, target_pos):
        """Return the particle's (x, y) position concatenated with ``target_pos``."""
        particle_pos = data.xpos[self.particle_body_id][:2]
        return jnp.concatenate([particle_pos, target_pos])

    def get_reward(self, data, target_pos, action):
        """Compute the reward and its components for the current step.

        Args:
            data: Simulation data providing ``xpos``.
            target_pos: (x, y) position of the target.
            action: Control vector used for the control penalty.

        Returns:
            ``(reward, info)`` where ``info`` holds ``reward_dist`` (negative
            distance to the target), ``reward_ctrl`` (negative squared action
            norm) and ``is_success`` (1.0 if within ``success_radius``, else
            0.0).
        """
        particle_pos = data.xpos[self.particle_body_id][:2]
        dist = jnp.linalg.norm(particle_pos - target_pos)

        reward_dist = -dist
        reward_ctrl = -jnp.sum(jnp.square(action))
        is_success = (dist <= self.success_radius).astype(jnp.float32)

        # ``reward_style`` is a static Python string, so an ordinary branch
        # selects the reward without tracing the unused alternative.
        if self.reward_style == "sparse":
            # Cast to the dense reward's dtype so both styles yield the same type.
            reward = jnp.where(is_success > 0.5, 1.0, 0.0).astype(reward_dist.dtype)
        else:
            reward = reward_dist + 0.001 * reward_ctrl

        info = {"reward_dist": reward_dist, "reward_ctrl": reward_ctrl, "is_success": is_success}
        return reward, info
