import copy
import os

import gym
import numpy as np
import torch
from gym.spaces.box import Box


class TransposeObs(gym.ObservationWrapper):
    """Shared base class for observation-transposing wrappers.

    It adds nothing on top of ``gym.ObservationWrapper`` itself. Subclasses
    are expected to provide the concrete ``observation`` transform.
    """

    def __init__(self, env=None):
        """Wrap ``env``; no wrapper-specific state is created here."""
        super().__init__(env)


class TransposeImage(TransposeObs):
    """Permute the axes of image observations (e.g. HWC -> CHW).

    The same 3-axis permutation ``op`` is applied to every entry named in
    ``transpose_keys``:

    * ``None`` means the observation itself is the array and is transposed
      directly;
    * any other key means the observation is a mapping, and ``ob[key]`` is
      transposed.

    ``observation_space`` is rewritten so that it describes the transposed
    observations.
    """

    def __init__(self, env, op, transpose_keys):
        """
        Args:
            env: environment to wrap.
            op: three axis indices in the form accepted by
                ``numpy.ndarray.transpose``; e.g. ``[2, 0, 1]`` turns HWC
                into CHW.
            transpose_keys: iterable of observation keys to transpose; use
                ``None`` for a plain (non-mapping) array observation.
        """
        super(TransposeImage, self).__init__(env)
        assert len(op) == 3
        self.op = op
        self.transpose_keys = transpose_keys

        space = self.observation_space
        if isinstance(space, gym.spaces.Dict):
            # Mapping observation space: rewrite only the sub-spaces whose
            # observations `observation()` will actually transpose. Copying
            # the container keeps the original key order.
            sub_spaces = space.spaces.copy()
            for k in transpose_keys:
                if k is not None:
                    sub_spaces[k] = self._transpose_box(sub_spaces[k])
            self.observation_space = gym.spaces.Dict(sub_spaces)
        else:
            self.observation_space = self._transpose_box(space)

    def _transpose_box(self, space):
        """Return a Box equivalent to ``space`` with axes permuted by ``op``."""
        # Transpose the full bound arrays rather than sampling low[0, 0, 0] /
        # high[0, 0, 0], so per-element bounds survive the permutation.
        return Box(
            low=space.low.transpose(*self.op),
            high=space.high.transpose(*self.op),
            dtype=space.dtype)

    def observation(self, ob):
        """Return ``ob`` with every entry named in ``transpose_keys`` transposed.

        A mapping observation is shallow-copied before any key is replaced,
        so the object handed in by the wrapped env is not modified.
        """
        copied = False
        for k in self.transpose_keys:
            if k is None:
                ob = ob.transpose(*self.op)
                continue
            if not copied:
                # Replace entries on a shallow copy instead of mutating the
                # inner env's (possibly shared or cached) observation mapping.
                ob = copy.copy(ob)
                copied = True
            ob[k] = ob[k].transpose(*self.op)
        return ob

