import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict


class FetchObservation(gym.ObservationWrapper):
    """Flatten the Dict observation of a Fetch task into one float32 vector.

    The task family is chosen from the prefix of the wrapped env's spec id,
    once, at construction time:

    * ``FetchReach*`` -> 9 values: achieved position (3),
      ``observation[5:8]`` (3), desired goal (3).
    * ``FetchPush*`` -> 31 values: gripper position (3), object position (3),
      ``observation[6:]`` unchanged, achieved position (3), desired goal (3).
      The declared size of 31 only holds when ``observation`` has 25 entries.
    * any other id -> 11 values: gripper position (3), achieved position (3),
      ``observation[9:10]``, ``observation[10:11]``, desired goal (3).

    Every position listed above has a fixed per-family offset subtracted
    from it before concatenation; all other entries are passed through.
    """

    # Offsets subtracted from every x, y, z position.
    # NOTE(review): presumably a reference point in the simulator's world
    # frame -- confirm where these numbers come from.
    _REACH_POS_OFFSET = (1.3419, 0.7491, 0.455)
    _OBJECT_POS_OFFSET = (1.3419, 0.7491, 0.42)

    def __init__(self, env: gym.Env) -> None:
        """Wrap ``env`` and declare the flattened observation space.

        Args:
            env: Environment whose observation space is a ``Dict`` and whose
                ``spec.id`` identifies the Fetch task.

        Raises:
            TypeError: If the wrapped observation space is not a ``Dict``.
        """
        super().__init__(env)

        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not isinstance(self.observation_space, Dict):
            raise TypeError(
                "FetchObservation requires a Dict observation space, got "
                f"{type(self.observation_space).__name__}"
            )

        # Resolve the task family once; observation() only branches on it.
        env_id = self.env.spec.id
        if env_id.startswith('FetchReach'):
            self._task = 'reach'
            obs_dim = 9
            offset = self._REACH_POS_OFFSET
        elif env_id.startswith('FetchPush'):
            self._task = 'push'
            obs_dim = 31
            offset = self._OBJECT_POS_OFFSET
        else:
            self._task = 'object'
            obs_dim = 11
            offset = self._OBJECT_POS_OFFSET

        # Built once (float64) rather than on every step; it is only ever
        # read, never modified in place.
        self._pos_offset = np.array(offset)

        # NOTE(review): observation() neither clips nor rescales, so returned
        # values are not guaranteed to lie inside these [-1, 1] bounds.
        self.observation_space = Box(
            low=-1,
            high=1,
            shape=(obs_dim,),
            dtype=np.float32
        )

    def observation(self, observation: dict) -> np.ndarray:
        """Return the flattened float32 observation for the wrapped task.

        Args:
            observation: Mapping with ``'observation'``, ``'achieved_goal'``
                and ``'desired_goal'`` arrays. Slicing is done on the last
                axis, so leading batch axes are preserved.

        Returns:
            The selected entries concatenated along the last axis, cast to
            ``np.float32``.

        Raises:
            TypeError: If ``observation`` is not a ``dict``.
        """
        if not isinstance(observation, dict):
            raise TypeError(
                "FetchObservation expects a dict observation, got "
                f"{type(observation).__name__}"
            )

        offset = self._pos_offset
        state = observation['observation']

        # Achieved x, y, z (object pos for object tasks, gripper pos for
        # reach) and desired goal x, y, z, both relative to the offset.
        achieved_pos = observation['achieved_goal'][..., 0:3] - offset
        desired_goal = observation['desired_goal'][..., 0:3] - offset

        if self._task == 'reach':
            gripper_vel = state[..., 5:8]  # gripper velocity vx, vy, vz
            parts = [achieved_pos, gripper_vel, desired_goal]
        else:
            gripper_pos = state[..., 0:3] - offset  # gripper x, y, z
            if self._task == 'push':
                object_pos = state[..., 3:6] - offset  # object x, y, z
                # Everything after the two positions is passed through as-is.
                parts = [gripper_pos, object_pos, state[..., 6:],
                         achieved_pos, desired_goal]
            else:
                gripper_right = state[..., 9:10]   # right gripper finger
                gripper_left = state[..., 10:11]   # left gripper finger
                parts = [gripper_pos, achieved_pos, gripper_right,
                         gripper_left, desired_goal]

        return np.concatenate(parts, axis=-1).astype(np.float32)
