from typing import Any, Callable, Literal, SupportsFloat, Optional
import numpy as np
import numpy.typing as npt

import mujoco
from metaworld.envs.mujoco.sawyer_xyz import v2
from metaworld.envs.mujoco.sawyer_xyz.sawyer_xyz_env import RenderMode


class SawyerReachEnv(v2.SawyerReachEnvV2):
    """Sawyer reach environment whose ``step`` ends the episode on success.

    Subclasses ``v2.SawyerReachEnvV2`` and overrides :meth:`step` so that:

    * ``terminated`` mirrors ``info["success"]`` from ``evaluate_state``;
    * a step taken while ``_did_see_sim_exception`` is set returns the last
      stable observation with ``truncated=True``.

    ``max_steps`` and ``base_penalty`` are stored on the instance but are not
    read by :meth:`step`; they only parameterise the disabled alternative
    reward described near the end of :meth:`step`. Truncation is driven by
    ``self.max_path_length`` (defined outside this class), not ``max_steps``.
    """

    def __init__(
        self,
        render_mode: Optional[RenderMode] = None,
        camera_name: Optional[str] = None,
        camera_id: Optional[int] = None,
        base_penalty: Optional[float] = 0.9,
        max_steps: Optional[int] = 100,
    ) -> None:
        """Create the environment.

        Args:
            render_mode: Forwarded unchanged to the parent constructor.
            camera_name: Forwarded unchanged to the parent constructor.
            camera_id: Forwarded unchanged to the parent constructor.
            base_penalty: Stored as ``self.base_penalty``; scale factor of the
                (currently disabled) step-count penalty.
            max_steps: Stored as ``self.max_steps``; normaliser of the
                (currently disabled) step-count penalty.
        """
        super().__init__(
            render_mode=render_mode,
            camera_name=camera_name,
            camera_id=camera_id,
        )
        self.max_steps = max_steps
        self.base_penalty = base_penalty

    def step(
        self, action: npt.NDArray[np.float32]
    ) -> tuple[npt.NDArray[np.float64], SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the environment.

        Args:
            action: The action to take. Must be a 4 element array of floats.
                The first three elements are passed to ``set_xyz_action``; the
                last one is sent to the simulation with both signs.

        Returns:
            The (next_obs, reward, terminated, truncated, info) tuple.
            ``terminated`` is ``bool(info["success"])`` and ``truncated`` is
            True once ``curr_path_length`` reaches ``max_path_length`` or when
            a simulator exception has been recorded.

        Raises:
            AssertionError: If ``action`` does not have exactly 4 elements.
            ValueError: If the episode already used up ``max_path_length``
                steps and the environment was not reset.
        """
        assert len(action) == 4, f"Actions should be size 4, got {len(action)}"
        # Reject the call before touching any simulator state, so a step on an
        # exhausted episode leaves the environment exactly as it was.
        if self.curr_path_length >= self.max_path_length:
            raise ValueError("You must reset the env manually once truncate==True")

        self.set_xyz_action(action[:3])
        # The last action component is applied with opposite signs to the two
        # controls (presumably the two gripper fingers -- confirm upstream).
        self.do_simulation([action[-1], -action[-1]], n_frames=self.frame_skip)
        self.curr_path_length += 1

        # Running the simulator can sometimes mess up site positions, so
        # re-position them here to make sure they're accurate
        for site in self._target_site_config:
            self._set_pos_site(*site)

        if self._did_see_sim_exception:
            assert self._last_stable_obs is not None
            return (
                self._last_stable_obs,  # observation just before going unstable
                0.0,  # reward (penalize for causing instability)
                False,  # terminated
                # truncated: end the episode after a simulator exception.
                # NOTE(review): the previous comment on this value called it
                # "terminated", but this slot is `truncated` -- confirm intent.
                True,
                {  # info
                    "success": False,
                    "near_object": 0.0,
                    "grasp_success": False,
                    "grasp_reward": 0.0,
                    "in_place_reward": 0.0,
                    "obj_to_target": 0.0,
                    "unscaled_reward": 0.0,
                },
            )

        mujoco.mj_forward(self.model, self.data)
        # Keep the cached observation inside the declared observation bounds.
        self._last_stable_obs = np.clip(
            self._get_obs(),
            a_max=self.sawyer_observation_space.high,
            a_min=self.sawyer_observation_space.low,
            dtype=np.float64,
        )
        assert isinstance(self._last_stable_obs, np.ndarray)
        reward, info = self.evaluate_state(self._last_stable_obs, action)

        truncate = self.curr_path_length == self.max_path_length
        # Reaching the goal ends the episode.
        terminate = bool(info["success"])

        # Alternative sparse reward (disabled): on success,
        # 1 - base_penalty * curr_path_length / max_steps; otherwise 0.

        return (
            # Return a copy so callers cannot mutate the cached observation.
            np.array(self._last_stable_obs, dtype=np.float64),
            reward,
            terminate,
            truncate,
            info,
        )