import numpy as np
from collections import deque
from gym import ObservationWrapper


class StackObservations(ObservationWrapper):
    """Convert image observations to channel-first, scale them, and stack recent frames.

    Each incoming observation is moved from (H, W, C) to (C, H, W) layout
    (a 2-D observation is first given a trailing channel axis), divided by
    255, and pushed onto a fixed-length queue.  The queue contents are
    returned concatenated along the channel axis, so the output always has
    shape ``(num_frames * C, H, W)``.

    Until ``num_frames`` real frames have been seen, the missing history is
    filled with all-zero frames so the output shape stays constant from the
    very first step.

    Args:
        env: The environment to wrap.
        num_frames: How many of the most recent frames to stack.  Defaults
            to 4.  Must be at least 1.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """

    def __init__(self, env, num_frames=4):
        super(StackObservations, self).__init__(env)
        if num_frames < 1:
            raise ValueError("num_frames must be at least 1")
        self.num_frames = num_frames
        # The queue length is tied to num_frames so the two can never disagree.
        self.frame_queue = deque(maxlen=num_frames)

    def reset(self, **kwargs):
        """Clear the frame history, then reset the wrapped environment.

        Clearing first ensures frames from the previous episode are not
        stacked together with the first frames of the new one.
        """
        self.frame_queue.clear()
        return super(StackObservations, self).reset(**kwargs)

    def observation(self, observation):
        """Return the stacked, channel-first, scaled view of recent frames.

        Args:
            observation: Image observation laid out as (H, W) or (H, W, C).

        Returns:
            Array of shape ``(num_frames * C, H, W)`` holding the latest
            ``num_frames`` frames, oldest first, each divided by 255.
        """
        if len(observation.shape) < 3:
            # Give single-channel images an explicit trailing channel axis.
            observation = np.expand_dims(observation, 2)
        frame = np.transpose(observation, (2, 0, 1)) / 255

        if not self.frame_queue:
            # First frame since construction/reset: pad the history with
            # zero frames shaped like the real frame, so the padding matches
            # whatever channel count and resolution the environment emits.
            blank = np.zeros_like(frame)
            self.frame_queue.extend(blank for _ in range(self.num_frames - 1))

        self.frame_queue.append(frame)
        # Merge the (frame, channel) axes into a single channel axis.
        return np.array(self.frame_queue).reshape((-1,) + frame.shape[1:])


class NormaliseObservations(ObservationWrapper):
    """Observation wrapper that divides every observation by 255."""

    def __init__(self, env):
        super(NormaliseObservations, self).__init__(env)
        # Not read anywhere inside this class; kept as a public attribute.
        self.num_frames = 4

    def observation(self, observation):
        """Return ``observation`` divided by 255."""
        return observation / 255


class TransposeObservations(ObservationWrapper):
    """Observation wrapper that moves the last axis of an image to the front."""

    def __init__(self, env):
        super(TransposeObservations, self).__init__(env)

    def observation(self, observation):
        """Return the observation with axes reordered from (0, 1, 2) to (2, 0, 1).

        An observation with fewer than three dimensions first gets a new
        axis inserted at position 2.
        """
        has_channel_axis = len(observation.shape) >= 3
        image = (
            observation
            if has_channel_axis
            else np.expand_dims(observation, axis=2)
        )
        return np.transpose(image, axes=(2, 0, 1))
