import gym
from gym import ObservationWrapper
import tensorflow as tf
import numpy as np

class MyObservationWrapper(ObservationWrapper):
    """Observation wrapper that appends an action vector to each observation.

    The wrapped observation is ``concatenate([raw_observation, prev_action])``
    and the advertised ``observation_space`` is a 1-D ``Box`` whose bounds are
    the raw observation bounds followed by the action bounds.

    What ``prev_action`` holds depends on ``mode``:

    * ``'original'``   -- a fresh random action, drawn inside the action
      bounds, is appended to every observation.
    * ``'confounded'`` -- the action passed to the most recent ``step`` call
      is appended to the observation returned by that step.
    * any other value  -- the random action drawn in ``__init__`` is reused
      for every observation.

    NOTE(review): ``reset`` is not overridden in this class, so in
    ``'confounded'`` mode ``prev_action`` is not refreshed by this class when
    an episode restarts -- confirm that this is intended.

    Both the observation space and the action space are indexed with
    ``.shape[0]`` and read through ``.low`` / ``.high``, so they are expected
    to be 1-D ``Box``-like spaces.
    """

    def __init__(self, env, mode):
        """Wrap ``env`` and build the augmented observation space.

        Args:
            env: Environment to wrap.
            mode: Selects how ``prev_action`` is maintained (see class
                docstring).
        """
        super(MyObservationWrapper, self).__init__(env)
        self.mode = mode
        self.prev_action = self._sample_random_action()
        # At this point ``self.observation_space`` still refers to the
        # wrapped environment's space; it is replaced on the last line.
        self.n = self.observation_space.shape[0] + self.action_space.shape[0]
        self.low = np.concatenate([self.observation_space.low, self.action_space.low])
        self.high = np.concatenate([self.observation_space.high, self.action_space.high])
        self.observation_space = gym.spaces.Box(self.low, self.high, (self.n,))

    def _sample_random_action(self):
        """Return a random 1-D action scaled into ``[low, high)`` per dimension."""
        low = self.action_space.low
        high = self.action_space.high
        unit_noise = np.random.random_sample(self.action_space.shape[0])
        return np.atleast_1d((high - low) * unit_noise + low)

    def observation(self, observation):
        """Return ``observation`` with ``prev_action`` appended.

        Args:
            observation: Raw observation from the wrapped environment. A 2-D
                array is reshaped to the wrapped environment's observation
                shape before concatenation.

        Returns:
            The concatenation of the raw observation and ``prev_action``.
        """
        assert self.prev_action is not None
        if self.mode == 'original':
            self.prev_action = self._sample_random_action()

        observation = np.asarray(observation)
        if observation.ndim == 2:
            # Reshape with the *wrapped* env's space: ``self.observation_space``
            # is the augmented space of size ``self.n`` and can never match
            # the element count of a raw observation.
            observation = observation.reshape(self.env.observation_space.shape)

        # Normalised independently of the observation so that two 2-D inputs
        # are not concatenated along the wrong axis.
        action = np.asarray(self.prev_action)
        if action.ndim == 2:
            action = action.reshape(self.action_space.shape)
        self.prev_action = action

        return np.concatenate([observation, action])

    def step(self, action):
        """Step the wrapped environment and augment the returned observation.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The tuple ``(observation, reward, done, info)`` where
            ``observation`` has been passed through ``self.observation``.
        """
        observation, reward, done, info = self.env.step(action)
        if self.mode == 'confounded':
            # Keep a private, at-least-1-D ndarray copy so that list/scalar
            # actions work and later mutation by the caller cannot leak into
            # the stored value.
            self.prev_action = np.atleast_1d(np.array(action))
        return self.observation(observation), reward, done, info