import numpy as np
from scipy.signal import convolve2d
import re

# 9x12 maze stored as a single string; rows are separated by a literal backslash.
# Cell legend (same characters layout2str emits): '#' wall, 'O' free cell,
# 'G' goal, 'D' distractor.
LARGE_MAZE = \
    "############\\" + \
    "#DOOO#OOOOO#\\" + \
    "#O##O#O#O#O#\\" + \
    "#OOODOO#OOO#\\" + \
    "#O####O###O#\\" + \
    "#OO#D#OOOOO#\\" + \
    "##O#O#O#O###\\" + \
    "#OO#OOO#OGO#\\" + \
    "############"

# 9x15 maze stored as a single string; rows are separated by a literal backslash.
# '#' wall, 'O' free cell, 'G' goal (top-left corner), 'D' distractor (three cells).
LARGER_MAZE_BLUE = \
    "###############\\" + \
    "#GOOO#OOOOOODO#\\" + \
    "#O##O#OOO#OO#O#\\" + \
    "#OOOOOOOO#OOOO#\\" + \
    "#OO#####O####O#\\" + \
    "#OOO#DO#OOOOOO#\\" + \
    "##OO#OO#O#OO###\\" + \
    "#OOO#OOOO#ODOO#\\" + \
    "###############"

# Four 9x12 variants that share the same walls and the same four marked cells
# (row 1 col 1, row 1 col 10, row 5 col 4, row 7 col 9). Each variant puts the
# goal 'G' on a different marked cell and a distractor 'D' on the other three.
# Rows are separated by a literal backslash; '#' is a wall, 'O' a free cell.

# Goal at row 1, col 1.
LARGE_MAZE_BLUE = \
    "############\\" + \
    "#GOOO#OOOOD#\\" + \
    "#O##O#O#O#O#\\" + \
    "#OOOOOO#OOO#\\" + \
    "#O####O###O#\\" + \
    "#OO#D#OOOOO#\\" + \
    "##O#O#O#O###\\" + \
    "#OO#OOO#ODO#\\" + \
    "############"
# Goal at row 1, col 10.
LARGE_MAZE_RED = \
    "############\\" + \
    "#DOOO#OOOOG#\\" + \
    "#O##O#O#O#O#\\" + \
    "#OOOOOO#OOO#\\" + \
    "#O####O###O#\\" + \
    "#OO#D#OOOOO#\\" + \
    "##O#O#O#O###\\" + \
    "#OO#OOO#ODO#\\" + \
    "############"
# Goal at row 5, col 4.
LARGE_MAZE_MAGENTA = \
    "############\\" + \
    "#DOOO#OOOOD#\\" + \
    "#O##O#O#O#O#\\" + \
    "#OOOOOO#OOO#\\" + \
    "#O####O###O#\\" + \
    "#OO#G#OOOOO#\\" + \
    "##O#O#O#O###\\" + \
    "#OO#OOO#ODO#\\" + \
    "############"
# Goal at row 7, col 9.
LARGE_MAZE_YELLOW = \
    "############\\" + \
    "#DOOO#OOOOD#\\" + \
    "#O##O#O#O#O#\\" + \
    "#OOOOOO#OOO#\\" + \
    "#O####O###O#\\" + \
    "#OO#D#OOOOO#\\" + \
    "##O#O#O#O###\\" + \
    "#OO#OOO#OGO#\\" + \
    "############"

# 9x12 maze with a single goal 'G' and no distractor cells; rows are separated
# by a literal backslash. '#' wall, 'O' free cell.
# NOTE(review): presumably the evaluation counterpart of LARGE_MAZE -- confirm with callers.
LARGE_MAZE_EVAL = \
    "############\\" + \
    "#OO#OOO#OGO#\\" + \
    "##O###O#O#O#\\" + \
    "#OO#O#OOOOO#\\" + \
    "#O##O#OO##O#\\" + \
    "#OOOOOO#OOO#\\" + \
    "#O##O#O#O###\\" + \
    "#OOOO#OOOOO#\\" + \
    "############"

# 8x8 maze; rows are joined with a literal backslash.
# '#' wall, 'O' free cell, 'G' goal.
MEDIUM_MAZE = "\\".join([
    "########",
    "#OO##OO#",
    "#OO#OOO#",
    "##OOO###",
    "#OO#OOO#",
    "#O#OO#O#",
    "#OOO#OG#",
    "########",
])

# 8x8 maze with a different wall pattern and the goal in the top-right corner.
MEDIUM_MAZE_EVAL = "\\".join([
    "########",
    "#OOOOOG#",
    "#O#O##O#",
    "#OOOO#O#",
    "###OO###",
    "#OOOOOO#",
    "#OO##OO#",
    "########",
])

# 5x6 maze (a ring of free cells around a two-cell wall); rows are joined with
# a literal backslash. '#' wall, 'O' free cell. No goal cell is marked.
SMALL_MAZE = "\\".join([
    "######",
    "#OOOO#",
    "#O##O#",
    "#OOOO#",
    "######",
])

# 5x5 U-shaped corridor with the goal 'G' at the top-left end; rows are joined
# with a literal backslash. '#' wall, 'O' free cell.
U_MAZE = "\\".join([
    "#####",
    "#GOO#",
    "###O#",
    "#OOO#",
    "#####",
])

# 5x5 U-shaped corridor opening the other way, goal 'G' at the top-right end.
U_MAZE_EVAL = "\\".join([
    "#####",
    "#OOG#",
    "#O###",
    "#OOO#",
    "#####",
])

# Four 13x13 open rooms (border walls only) sharing the same four marked cells
# (row 1 col 9, row 2 col 6, row 5 col 9, row 8 col 3). Each variant puts the
# goal 'G' on a different marked cell and a distractor 'D' on the other three.
# Rows are separated by a literal backslash; '#' is a wall, 'O' a free cell.

# Goal at row 8, col 3.
OPEN_YELLOW = \
    "#############\\" + \
    "#OOOOOOOODOO#\\" + \
    "#OOOOODOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOODOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOGOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#############"
# Goal at row 2, col 6.
OPEN_RED = \
    "#############\\" + \
    "#OOOOOOOODOO#\\" + \
    "#OOOOOGOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOODOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OODOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#############"
# Goal at row 1, col 9.
OPEN_BLUE = \
    "#############\\" + \
    "#OOOOOOOOGOO#\\" + \
    "#OOOOODOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOODOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OODOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#############"
# Goal at row 5, col 9.
OPEN_MAGENTA = \
    "#############\\" + \
    "#OOOOOOOODOO#\\" + \
    "#OOOOODOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOGOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OODOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#OOOOOOOOOOO#\\" + \
    "#############"


# 11x21 maze stored as a single string; rows are separated by a literal backslash.
# '#' wall, 'O' free cell, 'G' goal (bottom-right), 'D' distractor (three cells).
HARD_EXP_MAZE = \
    "#####################\\" + \
    "#OODO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#ODOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#ODOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOG#\\" + \
    "#####################"

# Four 21x21 variants that share the same walls and the same four marked cells
# (row 4 col 7, row 4 col 16, row 16 col 7, row 16 col 13). Each variant puts
# the goal 'G' on a different marked cell and a distractor 'D' on the other three.
# Rows are separated by a literal backslash; '#' is a wall, 'O' a free cell.

# Goal at row 4, col 7.
LARGE_EXP_MAZE_BLUE_V2 = \
    "#####################\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOOOOOGOO#OOOOODOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "##OO###OOO#OOO###OO##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOO###OO###OO###OOO#\\" + \
    "#OOOOOODOO#OODOOOOOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#####################"
# Goal at row 4, col 16.
LARGE_EXP_MAZE_RED_V2 = \
    "#####################\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOOOOODOO#OOOOOGOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "##OO###OOO#OOO###OO##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOO###OO###OO###OOO#\\" + \
    "#OOOOOODOO#OODOOOOOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#####################"
# Goal at row 16, col 13.
LARGE_EXP_MAZE_YELLOW_V2 = \
    "#####################\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOOOOODOO#OOOOODOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "##OO###OOO#OOO###OO##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOO###OO###OO###OOO#\\" + \
    "#OOOOOODOO#OOGOOOOOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#####################"
# Goal at row 16, col 7.
LARGE_EXP_MAZE_MAGENTA_V2 = \
    "#####################\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOOOOODOO#OOOOODOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "##OO###OOO#OOO###OO##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOO###OO###OO###OOO#\\" + \
    "#OOOOOOGOO#OODOOOOOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOOOOOOOOOOOO#\\" + \
    "#####################"
# Four 21x21 variants that share the same walls and the same four marked cells
# (row 4 col 7, row 10 col 18, row 13 col 2, row 16 col 13). Each variant puts
# the goal 'G' on a different marked cell and a distractor 'D' on the other three.
# Row 10 is a wall except for two cells (cols 17-18), the only link between
# the upper and lower halves.
# Rows are separated by a literal backslash; '#' is a wall, 'O' a free cell.

# Goal at row 10, col 18.
HARD_EXP_MAZE_RED_V2 = \
    "#####################\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###DOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#################OG##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#ODOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOD###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#####################"
# Goal at row 4, col 7.
HARD_EXP_MAZE_BLUE_V2 = \
    "#####################\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###GOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#################OD##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#ODOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOD###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#####################"
# Goal at row 16, col 13.
HARD_EXP_MAZE_YELLOW_V2 = \
    "#####################\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###DOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#################OD##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#ODOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOG###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#####################"
# Goal at row 13, col 2.
HARD_EXP_MAZE_MAGENTA_V2 = \
    "#####################\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###DOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#################OD##\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OGOO#OOOOOOOOO#OOOO#\\" + \
    "#OOOO#OOO###OOO#OOOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#OOO###OOO#OOD###OOO#\\" + \
    "#OOOOOOOOO#OOOOOOOOO#\\" + \
    "#OOO###OOO#OOO###OOO#\\" + \
    "#OOOO#OOOO#OOOO#OOOO#\\" + \
    "#####################"

def compute_sampling_probs(maze_layout, filter, temp):
    """Turn local wall density into a probability map over window positions.

    The layout is convolved with ``filter`` in 'valid' mode and a softmax over
    ``-temp * density`` is returned. For a positive ``temp`` and a non-negative
    averaging filter, windows that cover fewer walls get higher probability.

    :param maze_layout: 2D layout array (non-zero entries are walls)
    :param filter: 2D kernel handed to ``convolve2d``
    :param temp: scale applied to the convolved density before exponentiation
    :return: array with the shape of the 'valid' convolution; entries sum to 1
    """
    logits = -temp * convolve2d(maze_layout, filter, 'valid')
    # Shifting by the maximum leaves the softmax unchanged mathematically but
    # pins the largest term at exp(0) = 1, so a large |temp| can no longer
    # underflow (0 / 0) or overflow (inf / inf) every term into NaN.
    weights = np.exp(logits - np.max(logits))
    return weights / np.sum(weights)


def sample_2d(probs, rng):
    """Draw one (row, column) index pair from a 2D probability map.

    :param probs: 2D array of probabilities that sum to 1
    :param rng: numpy random Generator used for the draw
    :return: tuple (row index, column index) of the sampled cell
    """
    weights = probs.flatten()
    flat_choice = rng.choice(np.arange(weights.shape[0]), p=weights)
    # Map the flat row-major index back onto the 2D grid.
    coords = np.unravel_index(flat_choice, probs.shape)
    return coords[0], coords[1]


def place_wall(maze_layout, rng, min_len_frac, max_len_frac, temp):
    """Add one straight wall with a single door to the layout (modified in place).

    Orientation, wall length and door offset are drawn from ``rng``. The wall
    position is then sampled from ``compute_sampling_probs`` over a window that
    is 5 cells wide across the wall, so positions overlapping existing walls
    are weighted down according to ``temp``. The door is one wall cell reset to
    0; its four diagonal neighbours are set to 1.

    :param maze_layout: square 2D layout array, 1 = wall, 0 = free
    :param rng: numpy random Generator
    :param min_len_frac: minimum wall length as a fraction of the side length
    :param max_len_frac: maximum wall length as a fraction of the side length
    :param temp: temperature forwarded to ``compute_sampling_probs``
    :return: the same ``maze_layout`` array, after modification
    """
    side = maze_layout.shape[0]
    is_vertical = rng.random() < 0.5
    raw_len = (max_len_frac - min_len_frac) * side * rng.random() + min_len_frac * side
    wall_len = int(max(raw_len, 3))  # never shorter than 3 cells
    # Door sits strictly inside the wall, never on either end cell.
    door = rng.choice(np.arange(1, wall_len - 1))

    window_shape = (wall_len, 5) if is_vertical else (5, wall_len)
    kernel = np.ones(window_shape) / (5 * wall_len)
    probs = compute_sampling_probs(maze_layout, kernel, temp)
    picked = sample_2d(probs, rng)

    if is_vertical:
        # Window origin gives the top row; the wall runs down the window's centre column.
        top, col = picked[0], picked[1] + 2
        door_row = top + door
        maze_layout[top: top + wall_len, col] = 1
        maze_layout[door_row, col] = 0
        for row in (door_row - 1, door_row + 1):
            for side_col in (col + 1, col - 1):
                maze_layout[row, side_col] = 1
    else:
        # Window origin gives the leftmost column; the wall runs along the window's centre row.
        left, row = picked[1], picked[0] + 2
        door_col = left + door
        maze_layout[row, left: left + wall_len] = 1
        maze_layout[row, door_col] = 0
        for col in (door_col - 1, door_col + 1):
            for side_row in (row + 1, row - 1):
                maze_layout[side_row, col] = 1
    return maze_layout


def sample_layout(seed=None,
                  size=20,
                  max_len_frac=0.5,
                  min_len_frac=0.3,
                  coverage_frac=0.25,
                  temp=20):
    """
    Builds a square maze layout by repeatedly placing random walls.

    Walls are added one at a time (see 'place_wall') until the fraction of wall cells
    reaches 'coverage_frac'.
    :param seed: seed for the random generator; None gives a non-reproducible layout
    :param size: side length of the maze in cells
    :param max_len_frac: longest allowed wall, as fraction of the side length
    :param min_len_frac: shortest allowed wall, as fraction of the side length
    :param coverage_frac: target fraction of cells occupied by walls
    :param temp: wall-overlap temperature; higher values discourage overlapping walls more
    :return: (size x size) matrix with 1 for wall cells and 0 for free cells
    """
    generator = np.random.default_rng(seed=seed)
    grid = np.zeros((size, size))

    # Mean of a 0/1 grid is exactly the covered fraction.
    while np.mean(grid) < coverage_frac:
        grid = place_wall(grid, generator, min_len_frac, max_len_frac, temp)

    return grid


def layout2str(layout, seed):
    """Transfers a layout matrix to string format that is used by MazeEnv class.

    The layout is surrounded by a one-cell wall border, rendered row by row
    ('O' for cells equal to 0, '#' otherwise) with rows separated by a literal
    backslash. One goal 'G' and three distractors 'D' are then written onto
    randomly chosen free cells.

    :param layout: 2D numpy array, 0 = free cell, anything else = wall
    :param seed: seed for goal/distractor placement; None gives a random placement
    :return: maze string
    :raises ValueError: if the layout has fewer free cells than markers to place
    """
    num_dis = 3
    # Placement uses rejection sampling, which would never terminate without
    # enough free cells for the goal plus all distractors.
    if np.count_nonzero(layout == 0) < 1 + num_dis:
        raise ValueError(
            "layout needs at least %d free cells for goal and distractors" % (1 + num_dis))

    h, w = layout.shape
    padded_layout = np.ones((h + 2, w + 2))
    padded_layout[1:-1, 1:-1] = layout
    output_str = "\\".join(
        "".join("O" if cell == 0 else "#" for cell in row)
        for row in padded_layout
    )

    # Honour the caller's seed so that the complete maze string (not only the
    # walls) is reproducible; previously the seed argument was ignored here.
    rng = np.random.default_rng(seed=seed)
    length = len(output_str) - 1

    # add goal at random position
    output_str = find_empty_position(length, output_str, rng, "G")

    # add distractors at random positions
    for _ in range(num_dis):
        output_str = find_empty_position(length, output_str, rng, "D")

    return output_str


def find_empty_position(length, output_str, rng, new_char):
    """Replace one randomly chosen free cell ('O') of the maze string with new_char.

    Candidate indices are drawn uniformly from 1 .. length - 1 until one of
    them holds an 'O'. If the very first character is already 'O', index 0 is
    used without sampling.

    :param length: exclusive upper bound of the indices that may be sampled
    :param output_str: maze string to modify
    :param rng: numpy random Generator used for the draws
    :param new_char: character written at the chosen position
    :return: copy of output_str with new_char at the chosen position
    :raises ValueError: if no 'O' exists among the indices that can be sampled
    """
    index = 0
    if output_str[index] != "O":
        # Without this check the rejection loop below would spin forever.
        if "O" not in output_str[1:length]:
            raise ValueError("no free cell ('O') available to place %r" % new_char)
        candidates = np.arange(1, length)  # loop-invariant, build once
        while output_str[index] != "O":
            index = rng.choice(candidates)
    return output_str[:index] + new_char + output_str[index + 1:]


def rand_layout(seed=None, **kwargs):
    """Samples a random maze and returns it in string form.

    The seed is passed to both 'sample_layout' and 'layout2str'; all other
    keyword arguments are forwarded to 'sample_layout'.
    """
    layout_matrix = sample_layout(seed, **kwargs)
    return layout2str(layout_matrix, seed)


if __name__ == "__main__":
    # Demo: print one generated maze string (default 20x20 layout plus wall border).
    print(rand_layout(24))

    # Sample of the printed string, wrapped here after every backslash row
    # separator for readability (the real output is a single line).
    # NOTE(review): this sample contains no 'D' distractor cells although
    # layout2str inserts three, so it appears to predate that step --
    # regenerate it if the exact output matters.
    #
    # ######################\
    # #GOOOOO#OOOOOO#OOOOOO#\
    # #OOOOOO#OOOOOO#OOOOOO#\
    # #OOOOOO#OOOOOO#OOOOOO#\
    # #OOOOOO##O#OOO##O#OOO#\
    # #OOO#####O##O###O#####\
    # #OOOOO###O#OO###O#OOO#\
    # #OOOOOOO#OOOOOOOOOOOO#\
    # #OOO#O###OOO####OOOOO#\
    # #OOO#OO##OOO#OOOOOOOO#\
    # #OOO#OOO#OO###OO###OO#\
    # #OO###OO#OOOOOOOOOOOO#\
    # #OOOOOO###O###OO###OO#\
    # #OO###OOOOOO#OO#O#OOO#\
    # #OOO#OO###OO#OO#O#OOO#\
    # #OOOOOOOOOOO#OO#O#OOO#\
    # #OOOOOOO#O#OOO###OOOO#\
    # #OOOO####O###OOOOOOOO#\
    # #OOOOOOO#O#OOO###OOOO#\
    # #OOOOOOOOOOOOOO#OOOOO#\
    # #OOOOOOOOOOOOOOOOOOOO#\
    # ######################
