import os
import numpy as np
import torch

from gym_minigrid.wrappers import *

class OneHotDynamicObjectsWrapper(gym.core.ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of the dynamic enemies environment.

    Every cell of ``obs['image']`` is turned into a 7-bit vector:

    * bit 0: wall (object type 2)
    * bit 1: ball / enemy (object type 6)
    * bit 2: goal (object type 8)
    * bits 3-6: agent (object type 10). The offset within these four bits
      is read from channel 1 of the encoded cell, which this wrapper treats
      as the agent's facing direction (0-3). NOTE(review): confirm this
      against the observation encoding the wrapped env actually produces.

    Empty cells (object type 1) stay all-zero. Any other object type raises
    ``KeyError``.

    Args:
        env: environment whose observations are dicts containing at least
            ``'image'`` and ``'mission'``.
        tile_size: unused; kept so existing call sites keep working.
    """

    def __init__(self, env, tile_size=8):
        super().__init__(env)

        # Number of bits per cell: 3 static objects (wall, enemy, goal)
        # + 4 directions the agent can face.
        num_bits = 7

        # NOTE(review): this replaces the 'image' entry of the observation
        # space in place; if that space object is shared with the wrapped
        # env, the inner env's space changes too -- confirm this is intended.
        self.observation_space.spaces["image"] = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, num_bits),
            dtype='uint8'
        )
        # Map of object type to bits for this wrapper (only agent, wall, enemy and goal)
        self.OBJECT_TO_BITS = {
            2:  0,  # wall
            6:  1,  # ball (enemy)
            8:  2,  # goal
            10: 3,  # agent (first of its 4 direction bits)
        }
        # Inverse map; bits 4-6 are the agent's remaining direction bits.
        self.BITS_TO_OBJECTS = {bit: obj for obj, bit in self.OBJECT_TO_BITS.items()}
        self.BITS_TO_OBJECTS.update({4: 10, 5: 10, 6: 10})

        # Value render() writes into channel 1 when decoding each bit.
        self.BITS_TO_COLOR = {
            0: 5,  # wall is grey
            1: 2,  # ball is blue
            2: 1,  # goal is green
            3: 0,  # agent bits 3-6 map back to channel-1 values 0-3
            4: 1,
            5: 2,
            6: 3,
        }

    def observation(self, obs):
        """
        Return ``{'mission': ..., 'image': one_hot}`` for a raw observation.

        ``one_hot`` is a uint8 array with the shape of this wrapper's image
        observation space. Only the cells covered by ``obs['image']`` are
        written; the rest stay zero.

        Raises:
            KeyError: if ``obs['image']`` contains an object type that is
                neither empty (1) nor listed in ``OBJECT_TO_BITS``.
        """
        img = obs['image']
        out = np.zeros(self.observation_space.spaces['image'].shape, dtype='uint8')

        obj_types = img[:, :, 0]
        # Empty cells (type 1) are left as zeros; any unmapped type is an
        # error, exactly as the per-cell dict lookup used to fail.
        unknown = set(np.unique(obj_types).tolist()) - set(self.OBJECT_TO_BITS) - {1}
        if unknown:
            raise KeyError(sorted(unknown)[0])

        for obj_type, bit in self.OBJECT_TO_BITS.items():
            rows, cols = np.nonzero(obj_types == obj_type)
            if obj_type == 10:
                # For the agent, channel 1 selects which direction bit is set.
                # Cast to intp so the offset cannot wrap around in uint8 arithmetic.
                offsets = img[rows, cols, 1].astype(np.intp)
                out[rows, cols, bit + offsets] = 1
            else:
                out[rows, cols, bit] = 1

        return {
            'mission': obs['mission'],
            'image': out
        }

    def render(self, img):
        """
        Decode a per-cell bit-score grid and render it with the wrapped env.

        For each (i, j) in the env's width x height, the highest-scoring
        bit is decoded via ``BITS_TO_OBJECTS`` / ``BITS_TO_COLOR`` when its
        score exceeds 0.5. Otherwise the cell is decoded as empty (type 1).
        Channel 2 of the decoded grid is left at 0.

        NOTE(review): this overrides the gym ``Wrapper.render(mode, ...)``
        signature, so an outer wrapper's ``render()`` call would pass a mode
        string here -- confirm this is never relied upon.

        Args:
            img: array indexed as ``img[i, j, bit]`` -- presumably scores of
                shape (width, height, 7), e.g. this wrapper's own output or
                a model prediction; TODO confirm.

        Returns:
            Whatever ``self.env.get_obs_render`` returns for the decoded grid.
        """
        indices = np.argmax(img, axis=-1)
        val = np.max(img, axis=-1)
        obs = np.zeros((self.env.width, self.env.height, 3))
        for i in range(self.env.width):
            for j in range(self.env.height):
                if val[i, j] > 0.5:  # only if confidence > 50% on argmax
                    obs[i, j, 0] = self.BITS_TO_OBJECTS[indices[i, j]]
                    obs[i, j, 1] = self.BITS_TO_COLOR[indices[i, j]]
                else:
                    obs[i, j, 0] = 1  # otherwise empty
        return self.env.get_obs_render(obs)

if __name__ == "__main__":
    # Roll out a uniformly random policy (actions 0-2) and dump each episode
    # to <savedir>/<episode>/: one <step>.pt frame per step (step 0 is the
    # reset observation) plus actions.pt and rewards.pt.
    env = gym.make('MiniGrid-Dynamic-Obstacles-16x16-v0')
    env = OneHotDynamicObjectsWrapper(env)
    # env.seed(123)  # enable for reproducible rollouts
    env = ImgObsWrapper(env)  # Get rid of the 'mission' field
    # For training data, use '../trainingdata-minigrid-v0/' and loop on
    # `total_steps < 40000` instead of the episode count.
    savedir = '../testingdata-minigrid-v0/'

    def _normalize(frame):
        """Return ``frame`` as float64 scaled by its maximum value."""
        return frame.astype('float') / frame.max()

    def _save_tensor(tensor, ep_idx, name):
        """Save ``tensor`` to <savedir>/<ep_idx>/<name>.

        Passing a path lets torch.save open and close the file itself.
        """
        torch.save(tensor, os.path.join(savedir, str(ep_idx), name))

    done = True
    ep = -1
    total_steps = 0
    while ep < 10:
        if done:
            obs = _normalize(env.reset())
            ep += 1
            step = 0
            actions = []
            rewards = []
            # Without exist_ok this raises if the episode directory already
            # exists, so earlier data is never silently overwritten.
            os.makedirs(os.path.join(savedir, str(ep)))
            # Move the channel axis (last) to the front before saving.
            _save_tensor(torch.from_numpy(obs.transpose(2, 0, 1)), ep, '{}.pt'.format(step))
        actions.append(np.random.randint(3))
        obs, r, done, _ = env.step(actions[-1])
        obs = _normalize(obs)
        rewards.append(r)
        step += 1
        _save_tensor(torch.from_numpy(obs.transpose(2, 0, 1)), ep, '{}.pt'.format(step))
        # Rewritten every step so an interrupted run still leaves action and
        # reward logs matching the frames already on disk.
        _save_tensor(torch.LongTensor(actions), ep, 'actions.pt')
        _save_tensor(torch.FloatTensor(rewards), ep, 'rewards.pt')
        total_steps += 1
        if total_steps % 1000 == 0:
            print("{} steps and {} eps written!".format(total_steps, ep))
