import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict


class ManipObservation(gym.ObservationWrapper):
    """Flatten a goal-conditioned Dict observation into one float32 vector.

    The wrapped environment must expose a ``Dict`` observation space whose
    observations contain the keys ``'observation'`` and ``'desired_goal'``.
    The flattened vector is laid out along the last axis as::

        [ee_pos(3), ee_quat(4), ee_velp(3), ee_velr(3),
         desired_goal_pos(3), desired_goal_quat(4)]        # 20 values

    Both positions are translated by ``pos_offset`` before concatenation.
    Leading (batch) axes of the inputs are preserved.

    Args:
        env: Environment to wrap; its observation space must be a ``Dict``.
        pos_offset: Optional length-3 translation subtracted from the gripper
            and goal positions. Defaults to ``DEFAULT_POS_OFFSET``.

    Raises:
        TypeError: If the wrapped observation space is not a ``Dict``.
        ValueError: If ``pos_offset`` does not have shape ``(3,)``.
    """

    # Translation subtracted from both positions by default.
    # NOTE(review): presumably the nominal/initial gripper position of the
    # target environment, used to centre positions near zero - confirm.
    DEFAULT_POS_OFFSET = (4.92026489e-01, 1.34014449e-01, 4.37990367e-01)

    # Leading entries consumed from each input vector:
    # pos(3) + quat(4) + linear vel(3) + angular vel(3), and pos(3) + quat(4).
    _EE_STATE_DIM = 13
    _GOAL_DIM = 7

    def __init__(self, env, *, pos_offset=None):
        super().__init__(env)

        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not isinstance(self.observation_space, Dict):
            raise TypeError(
                "ManipObservation requires a Dict observation space, got "
                f"{type(self.observation_space).__name__}"
            )

        offset = self.DEFAULT_POS_OFFSET if pos_offset is None else pos_offset
        # Kept as float64 so the arithmetic matches a plain np.array() offset;
        # the final vector is cast to float32 in observation().
        self._pos_offset = np.asarray(offset, dtype=np.float64)
        if self._pos_offset.shape != (3,):
            raise ValueError(
                f"pos_offset must have shape (3,), got {self._pos_offset.shape}"
            )

        obs_dim = self._EE_STATE_DIM + self._GOAL_DIM

        # NOTE(review): observation() does not clip or rescale, so values are
        # not guaranteed to lie inside these [-1, 1] bounds - confirm that the
        # consumers of this space do not rely on them.
        self.observation_space = Box(
            low=-1,
            high=1,
            shape=(obs_dim,),
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return the flattened float32 vector for one Dict observation.

        Args:
            observation: Mapping with array-like entries ``'observation'``
                (at least 13 values on the last axis) and ``'desired_goal'``
                (at least 7 values on the last axis).

        Returns:
            ``np.ndarray`` of dtype float32 with 20 values on the last axis.

        Raises:
            TypeError: If ``observation`` is not a dict.
            ValueError: If either entry is too short on its last axis.
        """
        if not isinstance(observation, dict):
            raise TypeError(
                f"expected a dict observation, got {type(observation).__name__}"
            )

        state = np.asarray(observation['observation'])
        goal = np.asarray(observation['desired_goal'])

        # Slicing past the end would silently produce a shorter vector that no
        # longer matches the declared observation space, so fail loudly.
        if state.ndim == 0 or state.shape[-1] < self._EE_STATE_DIM:
            raise ValueError(
                f"'observation' needs at least {self._EE_STATE_DIM} values on "
                f"its last axis, got shape {state.shape}"
            )
        if goal.ndim == 0 or goal.shape[-1] < self._GOAL_DIM:
            raise ValueError(
                f"'desired_goal' needs at least {self._GOAL_DIM} values on "
                f"its last axis, got shape {goal.shape}"
            )

        ee_pos = state[..., 0:3] - self._pos_offset  # gripper x, y, z (shifted)
        ee_quat = state[..., 3:7]  # gripper orientation (quaternion)
        ee_velp = state[..., 7:10]  # gripper linear velocity vx, vy, vz
        ee_velr = state[..., 10:13]  # gripper angular velocity wx, wy, wz

        goal_pos = goal[..., 0:3] - self._pos_offset  # desired goal x, y, z (shifted)
        goal_quat = goal[..., 3:7]  # desired goal orientation (quaternion)

        parts = [ee_pos, ee_quat, ee_velp, ee_velr, goal_pos, goal_quat]
        return np.concatenate(parts, axis=-1).astype(np.float32)
