import gymnasium
import gymnasium.spaces as spaces
import numpy as np
import math


class FlatternEnv(gymnasium.ObservationWrapper):
    """Observation wrapper that flattens a ``Dict`` observation into one 1-D vector.

    Each sub-observation is raveled to 1-D, and the pieces are concatenated in
    the key order of the wrapped env's observation space. The element layout is
    therefore fixed across ``reset``/``step``, whatever order the env uses when
    it builds its observation dict.

    The class name keeps its original spelling ("Flattern") so existing callers
    keep working.
    """

    def __init__(self, env):
        super().__init__(env)  # Call the parent class initializer
        # Freeze the sub-space key order of the original (Dict) space; every
        # flattened observation is laid out in exactly this order.
        self._obs_keys = list(env.observation_space)
        # Count *all* elements of each sub-space (product of its shape), not
        # just shape[0]. This sizes multi-dimensional sub-spaces correctly and
        # gives scalar sub-spaces (shape ()) one slot instead of raising
        # IndexError.
        obs_size = sum(
            int(np.prod(env.observation_space[k].shape)) for k in self._obs_keys
        )
        # NOTE(review): Box's dtype defaults to float32, while observation()
        # keeps the inputs' dtype (possibly float64) — consider casting if the
        # env checker complains.
        self.observation_space = spaces.Box(-math.inf, math.inf, shape=(obs_size,))

    def observation(self, observation):
        """Flatten a dict observation into a single 1-D array.

        Args:
            observation: Mapping with (at least) the keys of the wrapped env's
                observation space; values are array-like.

        Returns:
            A 1-D ``np.ndarray`` whose length matches ``self.observation_space``.

        Raises:
            KeyError: If ``observation`` lacks a key of the original space.
        """
        # Ravel each piece so multi-dimensional entries contribute all their
        # elements, and use the frozen key order for a stable layout.
        return np.concatenate(
            [np.ravel(observation[k]) for k in self._obs_keys], axis=0
        )

def flattern_observation(obs):
    """Join the values of a dict-like observation along their last axis.

    Entries are taken in ``obs``'s own iteration order and handed to
    ``np.concatenate(..., axis=-1)``. Any leading dimensions (for example a
    batch axis, if present) are therefore preserved and must agree across
    entries.

    Args:
        obs: Mapping from key to array-like sub-observation.

    Returns:
        The concatenated ``np.ndarray``.
    """
    # Look each key up through the mapping itself, in iteration order.
    pieces = tuple(map(obs.__getitem__, obs))
    return np.concatenate(pieces, axis=-1)
