from pathlib import Path
import mujoco
import numpy as np
import gymnasium as gym

from irl_baselines.environments.point_maze.viewer import MujocoViewer


class PointMaze(gym.Env):
    """Gymnasium environment around the MuJoCo model in ``data/point_maze.xml``.

    The model must contain a body named ``"particle"`` (the agent) and a body
    named ``"target"`` (the goal). Only the x/y world coordinates of both
    bodies are used.

    Observation (``float32``, shape ``(4,)``):
        ``[particle_x, particle_y, target_x, target_y]``.

    Action:
        One value per actuator, bounded by the model's ``actuator_ctrlrange``.
        Actions passed to :meth:`step` are clipped to these bounds.

    Reward:
        * ``reward_style == "sparse"``: ``1.0`` while the particle is within
          ``success_radius`` of the target, ``0.0`` otherwise.
        * any other value (default ``"dense"``):
          ``-distance - 0.001 * sum(action ** 2)``.

    Episode end:
        ``terminated`` when the particle is within ``success_radius`` of the
        target, ``truncated`` once ``horizon`` steps have been taken.

    Args:
        horizon: Number of steps after which the episode is truncated.
        reward_style: ``"dense"`` or ``"sparse"`` (see above).
        success_radius: Distance to the target that counts as success.
        render: If true, a ``MujocoViewer`` is created and a frame is drawn
            after every ``reset`` and ``step``.
    """

    def __init__(self, horizon: int = 100, reward_style: str = "dense",
                 success_radius: float = 0.1, render: bool = False):
        super().__init__()
        self.horizon = horizon
        self.reward_style = reward_style  # "dense" or "sparse"
        self.success_radius = success_radius

        # The model file lives next to this module, independent of the CWD.
        xml_path = (Path(__file__).resolve().parent / "data" / "point_maze.xml").as_posix()
        mj_model = mujoco.MjModel.from_xml_path(xml_path)
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
        self.model = mj_model
        self.data = mujoco.MjData(self.model)

        # One env step = nr_intermediate_steps calls to mj_step, each
        # advancing the simulation by nr_substeps physics timesteps.
        self.nr_substeps = 1
        self.nr_intermediate_steps = 1
        self.dt = self.model.opt.timestep * self.nr_substeps * self.nr_intermediate_steps
        self.viewer = MujocoViewer(self.model, self.dt) if render else None

        # Initial values written into the first two entries of qpos / qvel.
        self.init_qpos = np.zeros(2, dtype=np.float64)
        self.init_qvel = np.zeros(2, dtype=np.float64)

        # actuator_ctrlrange has one (low, high) row per actuator.
        action_bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        action_low, action_high = action_bounds.T
        self.action_space = gym.spaces.Box(low=action_low, high=action_high, dtype=np.float32)

        # [particle_x, particle_y, target_x, target_y]
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)

        self.particle_body_id = self.model.body("particle").id
        self.target_body_id = self.model.body("target").id

        self.episode_step = 0
        self.current_action = np.zeros(self.model.nu, dtype=np.float32)
        self.target_pos = np.zeros(2, dtype=np.float64)

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed the env's RNG.
            options: Accepted for API compatibility; not used.

        Returns:
            The ``float32`` observation and an info dict with the keys
            ``episode_step``, ``target_x`` and ``target_y``.
        """
        super().reset(seed=seed)
        self.episode_step = 0
        self.current_action[:] = 0.0

        # Bring the whole MjData back to the model defaults before writing the
        # initial state. Overwriting only qpos/qvel/ctrl would leave the
        # simulation time, solver warm-start and other accumulated fields of
        # the previous episode in place, making episodes depend on each other.
        mujoco.mj_resetData(self.model, self.data)

        qpos = np.zeros(self.model.nq, dtype=np.float64)
        qvel = np.zeros(self.model.nv, dtype=np.float64)
        qpos[:2] = self.init_qpos
        qvel[:2] = self.init_qvel

        self.data.qpos[:] = qpos
        self.data.qvel[:] = qvel
        self.data.ctrl[:] = 0.0
        # Recompute derived quantities (e.g. body world positions) for the new state.
        mujoco.mj_forward(self.model, self.data)

        # Target position is read from the target body's world position.
        self.target_pos = self.data.xpos[self.target_body_id, :2].copy()

        self._render_frame()

        obs = self.get_observation().astype(np.float32)
        return obs, self._episode_info()

    def step(self, action):
        """Apply ``action`` for one env step.

        Args:
            action: Control values; converted to ``float32`` and clipped to
                the action space bounds before being applied.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``info``
            holds the entries returned by :meth:`get_reward` plus
            ``episode_step``, ``target_x`` and ``target_y``.
        """
        action = np.asarray(action, dtype=np.float32)
        action = np.clip(action, self.action_space.low, self.action_space.high)

        for _ in range(self.nr_intermediate_steps):
            self.data.ctrl[:] = action
            mujoco.mj_step(self.model, self.data, self.nr_substeps)

        self._render_frame()

        self.episode_step += 1
        self.current_action = action.copy()

        # In case the target body had any dynamics (usually it doesn't), keep this in sync
        self.target_pos = self.data.xpos[self.target_body_id, :2].copy()

        obs = self.get_observation().astype(np.float32)
        reward, r_info = self.get_reward(action)

        # is_success is stored as a float (0.0 / 1.0) in the info dict.
        terminated = bool(r_info["is_success"] > 0.5)
        truncated = bool(self.episode_step >= self.horizon)

        info = {**r_info, **self._episode_info()}

        return obs, float(reward), terminated, truncated, info

    def get_observation(self):
        """Return ``[particle_x, particle_y, target_x, target_y]``.

        NaN and infinite entries are replaced via ``np.nan_to_num``. The array
        is not cast here; callers in this class cast it to ``float32``.
        """
        obs = np.concatenate([self._particle_xy(), self.target_pos], axis=0)
        return np.nan_to_num(obs)

    def get_reward(self, action):
        """Compute the reward for the current simulator state.

        Args:
            action: The action used for the control penalty.

        Returns:
            ``(reward, info)`` where ``info`` contains ``reward_dist``
            (negative distance to the target), ``reward_ctrl`` (negative sum of
            squared action values), ``is_success`` (``1.0`` if the distance is
            at most ``success_radius``, else ``0.0``) and ``distance_to_goal``.
        """
        dist = float(np.linalg.norm(self._particle_xy() - self.target_pos))

        reward_dist = -dist
        reward_ctrl = -float(np.sum(np.square(action)))
        is_success = float(dist <= self.success_radius)

        if self.reward_style == "sparse":
            reward = 1.0 if is_success > 0.5 else 0.0
        else:
            # Any style other than "sparse" falls back to the dense reward.
            reward = reward_dist + 0.001 * reward_ctrl

        info = {
            "reward_dist": reward_dist,
            "reward_ctrl": reward_ctrl,
            "is_success": is_success,
            "distance_to_goal": dist,
        }
        return reward, info

    def close(self):
        """Close the viewer, if one was created. Safe to call repeatedly."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

    def _particle_xy(self):
        """Return a view of the particle body's x/y world position."""
        return self.data.xpos[self.particle_body_id, :2]

    def _episode_info(self):
        """Return the info entries shared by ``reset`` and ``step``."""
        return {
            "episode_step": self.episode_step,
            "target_x": float(self.target_pos[0]),
            "target_y": float(self.target_pos[1]),
        }

    def _render_frame(self):
        """Draw the current simulator state if a viewer is attached."""
        if self.viewer is not None:
            self.viewer.render(self.data)
