import numpy as np
import gym
from gym import spaces
from gym.utils import seeding

# Unit test environment for CNNs.
# Looks like this (RGB observations):
#
#  ---------------------------
# |                           |
# |         ******            |
# |         ******            |
# |       **      **          |
# |       **      **          |
# |               **          |
# |               **          |
# |           ****            |
# |           ****            |
# |       ****                |
# |       ****                |
# |       **********          |
# |       **********          |
# |                           |
#  ---------------------------
#
# The agent should pick the action equal to the digit shown in the latest observation
# (action 2 for the picture above) to gain reward. Catches off-by-one errors in your agent.
#
# To see how it works, run:
#
# python examples/agents/keyboard_agent.py MemorizeDigits-v0

FIELD_W = 32
FIELD_H = 24

bogus_mnist = [
    [" **** ", "*    *", "*    *", "*    *", "*    *", " **** "],
    ["  **  ", " * *  ", "   *  ", "   *  ", "   *  ", "  *** "],
    [" **** ", "*    *", "     *", "  *** ", "**    ", "******"],
    [" **** ", "*    *", "   ** ", "     *", "*    *", " **** "],
    [" *  * ", " *  * ", " *  * ", " **** ", "    * ", "    * "],
    [" **** ", " *    ", " **** ", "    * ", "    * ", " **** "],
    ["  *** ", " *    ", " **** ", " *  * ", " *  * ", " **** "],
    [" **** ", "    * ", "   *  ", "   *  ", "  *   ", "  *   "],
    [" **** ", "*    *", " **** ", "*    *", "*    *", " **** "],
    [" **** ", "*    *", "*    *", " *****", "     *", " **** "],
]

color_black = np.array((0, 0, 0)).astype("float32")
color_white = np.array((255, 255, 255)).astype("float32")


class MemorizeDigits(gym.Env):
    """Tiny image environment for sanity-checking CNN-based agents.

    Every observation is a ``FIELD_H`` x ``FIELD_W`` RGB image (uint8) that
    shows a single 6x6 digit glyph. The glyph position and the colours are
    chosen in ``reset()`` and stay fixed until the next reset; only the digit
    changes from step to step. An action earns +1 when it equals the digit
    shown in the previous observation and -1 otherwise.
    """

    metadata = {
        "render.modes": ["human", "rgb_array"],
        "video.frames_per_second": 60,
        "video.res_w": FIELD_W,
        "video.res_h": FIELD_H,
    }

    # When True, reset() draws random background/digit colours instead of
    # the default white digit on a black background.
    use_random_colors = False

    def __init__(self):
        """Seed the RNG, build the spaces and glyph table, and reset once."""
        self.seed()
        self.viewer = None
        self.observation_space = spaces.Box(
            0, 255, (FIELD_H, FIELD_W, 3), dtype=np.uint8
        )
        self.action_space = spaces.Discrete(10)
        # Glyphs as character codes, shape (10, 6, 6); step() compares the
        # codes against ord("*") to find the foreground pixels.
        self.bogus_mnist = np.array(
            [[[ord(char) for char in row] for row in glyph] for glyph in bogus_mnist],
            dtype=np.uint8,
        )
        # The initial reset also populates self.last_obs, so render() works
        # right after construction.
        self.reset()

    def seed(self, seed=None):
        """Re-create ``self.np_random`` from ``seed`` and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def random_color(self):
        """Return a random RGB colour as a uint8 array of three channels.

        NOTE: ``high`` is exclusive in ``randint``, so each channel is drawn
        from 0..254. This is kept as is so seeded colour sequences do not
        change.
        """
        return np.array(
            [self.np_random.randint(low=0, high=255) for _ in range(3)]
        ).astype("uint8")

    def reset(self):
        """Start a new episode and return its first observation.

        Picks a new glyph centre and colours, clears the step counter and
        takes one internal step (its reward is discarded) to draw the first
        digit.
        """
        # The centre stays in the middle three fifths of each axis, so the
        # 6x6 patch drawn around it in step() always fits inside the field.
        self.digit_x = self.np_random.randint(low=FIELD_W // 5, high=FIELD_W // 5 * 4)
        self.digit_y = self.np_random.randint(low=FIELD_H // 5, high=FIELD_H // 5 * 4)
        self.color_bg = self.random_color() if self.use_random_colors else color_black
        self.step_n = 0
        # Re-draw the digit colour until it is far enough from the background.
        while True:
            self.color_digit = (
                self.random_color() if self.use_random_colors else color_white
            )
            # Random colours are uint8; subtracting them directly would wrap
            # modulo 256 and report a large distance for near-identical
            # colours, so measure the distance in floating point.
            contrast = np.linalg.norm(
                self.color_digit.astype("float32") - self.color_bg.astype("float32")
            )
            if contrast >= 50:
                break
        # -1 means "no digit shown yet": the internal step below is not scored
        # as a hit and cannot end the episode.
        self.digit = -1
        return self.step(0)[0]

    def step(self, action):
        """Score ``action`` against the displayed digit and show a new one.

        Returns ``(obs, reward, done, info)``: ``obs`` is the new uint8 image,
        ``reward`` is +1 if ``action`` equals the digit shown before this
        call and -1 otherwise, ``done`` can only become true once
        ``self.step_n`` exceeds 20 (then with a 1-in-5 chance per step), and
        ``info`` is an empty dict.
        """
        reward = -1
        done = False
        self.step_n += 1
        if self.digit != -1:
            if self.digit == action:
                reward = +1
            # The termination roll is only made after step 20 (short-circuit).
            done = self.step_n > 20 and 0 == self.np_random.randint(low=0, high=5)
        self.digit = self.np_random.randint(low=0, high=10)

        # Paint the background, then stamp the 6x6 glyph around its centre.
        obs = np.zeros((FIELD_H, FIELD_W, 3), dtype=np.uint8)
        obs[:, :, :] = self.color_bg
        digit_img = np.zeros((6, 6, 3), dtype=np.uint8)
        digit_img[:] = self.color_bg
        glyph_mask = self.bogus_mnist[self.digit] == ord("*")
        digit_img[glyph_mask] = self.color_digit
        obs[
            self.digit_y - 3 : self.digit_y + 3, self.digit_x - 3 : self.digit_x + 3
        ] = digit_img
        self.last_obs = obs
        return obs, reward, done, {}

    def render(self, mode="human"):
        """Render the most recent observation.

        ``"rgb_array"`` returns the last observation array. ``"human"`` shows
        it in a ``SimpleImageViewer`` window and returns ``viewer.isopen``.
        Any other mode raises ``AssertionError``.
        """
        if mode == "rgb_array":
            return self.last_obs

        if mode == "human":
            # Imported lazily so the rendering module is only loaded when a
            # window is actually requested.
            from gym.envs.classic_control import rendering

            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(self.last_obs)
            return self.viewer.isopen

        # Raised explicitly: a bare ``assert 0`` is stripped under ``python -O``
        # and would make unsupported modes silently return None.
        raise AssertionError("Render mode '%s' is not supported" % mode)

    def close(self):
        """Close the viewer window, if one was opened."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
