import numpy as np
import matplotlib.pyplot as plt
from envfiles.funcs.utils import *
from envfiles.funcs.create_lava_env import initialize_img


def create_centerSquare(map_size, lava_size):
    """Build and save a square map with a square lava patch in its middle.

    The map is ``map_size`` x ``map_size``. ``grid`` is left all zeros, and
    ``lava`` is 1 inside a ``lava_size`` x ``lava_size`` square and 0
    elsewhere. Three arrays are written with ``np.save``:
    ``<prefix>_grid.npy``, ``<prefix>_lava.npy`` and ``<prefix>_img.npy``.
    The last one holds the image returned by ``initialize_img``. ``<prefix>``
    is whatever ``generate_dir`` returns for the environment name
    ``"centerSquare<lava_size>x<lava_size>"``.

    When ``map_size - lava_size`` is odd, the square cannot sit exactly in
    the centre. It is then offset by half a cell towards the top-left corner,
    so it keeps exactly ``lava_size`` cells per side.

    Args:
        map_size: Side length of the (square) map, in cells.
        lava_size: Side length of the lava square, in cells.

    Raises:
        ValueError: If ``lava_size`` is negative or larger than ``map_size``.
    """
    if not 0 <= lava_size <= map_size:
        raise ValueError(
            f"lava_size must be between 0 and map_size ({map_size}), "
            f"got {lava_size}"
        )

    env_name = f"centerSquare{lava_size}x{lava_size}"
    width = height = map_size
    grid = np.zeros([height, width])
    lava = np.zeros([height, width])

    # Integer offset of the square's top-left corner. Floor division keeps
    # the patch exactly lava_size wide. A float midpoint with strict bounds
    # would grow it by one cell whenever map_size - lava_size is odd.
    start = (map_size - lava_size) // 2
    stop = start + lava_size
    lava[start:stop, start:stop] = 1

    img = initialize_img(grid, lava)

    # NOTE(review): generate_dir comes from the star import of
    # envfiles.funcs.utils; presumably it prepares the output directory and
    # returns a path prefix -- confirm.
    file_name = generate_dir(env_name)
    np.save(file_name + "_grid.npy", grid)
    np.save(file_name + "_lava.npy", lava)
    np.save(file_name + "_img.npy", img)


def create_appleDoor():
    """Build and save the fixed "appleDoor" layout: 5 rows by 10 columns.

    Cells set to 1 in ``grid`` (presumably obstacles -- confirm against
    ``initialize_img``):

    * column 3, rows 0-1 and 3-4 (row 2 left open);
    * column 7, rows 1-4 (row 0 left open).

    The lava mask is all zeros. The grid, the lava mask and the image built
    by ``initialize_img`` are written with ``np.save`` under the prefix
    returned by ``generate_dir("appleDoor")``.
    """
    n_rows, n_cols = 5, 10
    layout = np.zeros([n_rows, n_cols])
    no_lava = np.zeros([n_rows, n_cols])

    # Column 3: top two and bottom two cells, leaving the middle row open.
    layout[:2, 3] = 1
    layout[-2:, 3] = 1
    # Column 7: bottom four cells, leaving the top row open.
    layout[-4:, 7] = 1

    picture = initialize_img(layout, no_lava)

    prefix = generate_dir("appleDoor")
    outputs = (
        ("_grid.npy", layout),
        ("_lava.npy", no_lava),
        ("_img.npy", picture),
    )
    for suffix, array in outputs:
        np.save(prefix + suffix, array)


if __name__ == '__main__':
    # Script entry point: build the 10x10 map with a 6x6 lava square.
    # The appleDoor layout is disabled; re-enable the call below to build it.
    create_centerSquare(lava_size=6, map_size=10)
    # create_appleDoor()