import numpy as np
from gym import core, spaces


# Integer tile codes used in the layout below (0..4). BLOCK_AGENT is defined
# but does not appear in ``four_room_map``.
BLOCK_NORMAL, BLOCK_WALL, BLOCK_HALLWAY, BLOCK_AGENT, BLOCK_GOAL = range(5)

# Named RGB triples, each stored as a numpy integer array.
RGB_COLORS = {
    name: np.array(rgb)
    for name, rgb in (
        ('red', (240, 52, 52)),
        ('black', (0, 0, 0)),
        ('green', (77, 181, 33)),
        ('blue', (29, 111, 219)),
        ('purple', (112, 39, 195)),
        ('yellow', (217, 213, 104)),
        ('grey', (192, 195, 196)),
        ('light_grey', (230, 230, 230)),
        ('white', (255, 255, 255)),
    )
}

# 9x9 layout as a list of lists of ints, one digit per tile code:
# 0 = normal, 1 = wall, 2 = hallway, 4 = goal.
four_room_map = [
    [int(tile) for tile in row]
    for row in (
        "111111111",
        "100010401",
        "100010001",
        "100010001",
        "112110001",
        "100011211",
        "100020001",
        "100010001",
        "111111111",
    )
]


class FourRoomGridWorld:
    """Four-room grid world with optional action noise.

    The layout comes from ``four_room_map`` with its outer wall ring removed.
    The map is flipped vertically and then transposed, so ``_grid[x, y]``
    addresses column ``x`` (left to right) and row ``y``, counted upward from
    the bottom of the printed map. Every step yields a reward of -1. The
    episode ends when the agent enters the goal tile.

    Args:
        stochasticity_fraction: probability that the requested action is
            replaced by one of the *other* actions, chosen uniformly at random.
    """

    def __init__(self, stochasticity_fraction=0.0):
        # Drop the border walls, flip so y grows upward, transpose to [x, y].
        layout = np.array(four_room_map, dtype=np.uint8)
        self._grid = np.transpose(np.flip(layout, axis=0)[1:-1, 1:-1])
        # _max_row: number of x positions (width); _max_col: number of y positions (height).
        self._max_row, self._max_col = self._grid.shape
        self._normal_tiles = np.where(self._grid == BLOCK_NORMAL)
        self._hallways_tiles = np.where(self._grid == BLOCK_HALLWAY)
        self._goal_tile = np.where(self._grid == BLOCK_GOAL)
        self._walls_tiles = np.where(self._grid == BLOCK_WALL)
        # Every cell, including interior walls, has a state index.
        self.num_states = self._grid.size

        self._state = None
        self.ACTION_UP, self.ACTION_DOWN, self.ACTION_RIGHT, self.ACTION_LEFT = 0, 1, 2, 3
        self.num_actions = 4
        self.action_space = spaces.Discrete(self.num_actions)
        self._stochasticity_fraction = stochasticity_fraction
        # (x, y) coordinates of the hallway tiles in the cropped grid.
        self.hallways = {
            0: (3, 1),
            1: (1, 3),
            2: (5, 2)
        }
        # Not used within this class; NOTE(review): presumably reserved for rendering.
        self._window, self._info = None, None

    def reset(self):
        """Place the agent at (0, 0), the bottom-left cell, and return its observation."""
        self._state = (0, 0)
        return self.get_rep(*self._state)

    def step(self, action):
        """Apply ``action`` and advance the environment by one step.

        With probability ``stochasticity_fraction`` the requested action is
        replaced by one of the remaining actions, chosen uniformly at random.

        Returns:
            ``(observation, reward, done, info)``, where:

            * ``observation`` is the one-hot vector of the new cell.
            * ``reward`` is always -1.
            * ``done`` is True when the new cell is the goal tile.
            * ``info`` holds the previous position (``x``, ``y``), the new
              position (``x_p``, ``y_p``), whether the action was replaced
              (``is_stochastic_selected``), and the action actually executed
              (``selected_action``).

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self._state is None:
            raise RuntimeError("step() called before reset()")
        x, y = self._state
        is_stochastic_selected = False
        # Strict '<': np.random.uniform() samples [0, 1), so a fraction of 0.0
        # never perturbs the action and a fraction of 1.0 always does.
        if np.random.uniform() < self._stochasticity_fraction:
            other_probability = 1 / (self.num_actions - 1)
            action_probability = [0 if a == action else other_probability
                                  for a in range(self.num_actions)]
            action = np.random.choice(self.num_actions, 1, p=action_probability)[0]
            is_stochastic_selected = True
        x_p, y_p = self._next(action, x, y)
        is_done = bool(self._grid[x_p, y_p] == BLOCK_GOAL)
        reward = -1
        self._state = (x_p, y_p)
        return self.get_rep(*self._state), reward, is_done, {
            'x': x, 'y': y,
            'x_p': x_p, 'y_p': y_p,
            'is_stochastic_selected': is_stochastic_selected,
            'selected_action': action}

    def get_xy(self, state):
        """Return the ``(x, y)`` cell for a state index; inverse of ``get_state_index``."""
        return state % self._max_row, state // self._max_row

    def get_state_index(self, x, y):
        """Return the flat state index of cell ``(x, y)``.

        Since ``x`` ranges over ``[0, _max_row)``, rows of constant ``y`` are
        ``_max_row`` apart. Indices therefore cover ``[0, num_states)``
        without collisions, for any grid shape.
        """
        return y * self._max_row + x

    def get_rep(self, x, y):
        """Return a one-hot float vector of length ``num_states`` for cell ``(x, y)``."""
        rep = np.zeros(self.num_states)
        rep[self.get_state_index(x, y)] = 1
        return rep

    def _next(self, action, x, y):
        """Return the cell reached by taking ``action`` from ``(x, y)``.

        UP increases ``y``, DOWN decreases ``y``, RIGHT increases ``x`` and
        LEFT decreases ``x``. A move that would leave the grid or enter a
        wall tile leaves the position unchanged.

        Raises:
            ValueError: if ``action`` is not one of the four action ids.
        """
        deltas = {
            self.ACTION_UP: (0, 1),
            self.ACTION_DOWN: (0, -1),
            self.ACTION_RIGHT: (1, 0),
            self.ACTION_LEFT: (-1, 0),
        }
        try:
            dx, dy = deltas[action]
        except KeyError:
            raise ValueError(f"invalid action: {action!r}") from None
        next_x, next_y = x + dx, y + dy
        # Bounds are checked first so the grid is never indexed out of range.
        if not (0 <= next_x < self._max_row and 0 <= next_y < self._max_col):
            return x, y
        if self._grid[next_x, next_y] == BLOCK_WALL:
            return x, y
        return next_x, next_y

