import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
from deep_sprl.util.maze_env_utils import construct_maze
import deep_sprl.environments
import gym


def function_plot(n):
    """Evaluate a fixed, pseudo-random smooth function on an ``n`` x ``n`` grid.

    The function is a weighted sum of 1000 sinusoidal features
    ``sin(m_k . (x, y) + o_k)``. Offsets, multipliers and weights are drawn
    from a generator seeded with 0, so the result only depends on ``n``.

    Args:
        n: Number of sample points per axis; both axes span ``[0, 1]``.

    Returns:
        Array of shape ``(n, n)``. Rows follow the y samples and columns the
        x samples (default ``np.meshgrid`` "xy" indexing).
    """
    # A private RandomState keeps the draws identical to seeding the global
    # generator with 0, but leaves the caller's global numpy RNG state intact.
    rng = np.random.RandomState(0)
    offsets = rng.uniform(0, 2 * np.pi, size=(1000,))
    multipliers = rng.uniform(-2 * np.pi, 2 * np.pi, size=(1000, 2))
    weights = rng.uniform(-1., 1., size=(1000,))

    axis_samples = np.linspace(0., 1., n)
    X, Y = np.meshgrid(axis_samples, axis_samples)
    # One (x, y) row per grid point, flattened in row-major order.
    points = np.stack((X.reshape(-1), Y.reshape(-1)), axis=-1)

    # features[i, k] = sin(<points[i], multipliers[k]> + offsets[k])
    features = np.sin(np.einsum("ij,kj->ik", points, multipliers) + offsets[None, :])
    # Weighted sum over the features for every grid point.
    Z = np.einsum("k,ik->i", weights, features)

    return np.reshape(Z, (n, n))


def maze():
    """Build an RGB image of the maze returned by ``construct_maze(0)``.

    Returns:
        Float array of shape ``layout.shape + (3,)`` with one colour triple
        per maze cell: 255 everywhere, 0 for cells equal to ``'1'``, and only
        the first channel left at 255 for cells equal to ``"r"``.
    """
    layout = np.array(construct_maze(0))

    # Start from an all-white canvas, one colour triple per cell.
    canvas = np.full(layout.shape + (3,), 255.)

    # Cells marked '1' are painted black (presumably walls - confirm against construct_maze).
    wall_rows, wall_cols = np.where(layout == '1')
    canvas[wall_rows, wall_cols, :] = 0.

    # Cells marked "r" keep only their first channel, i.e. red in RGB
    # (presumably the robot/start cell - confirm against construct_maze).
    r_rows, r_cols = np.where(layout == "r")
    canvas[r_rows, r_cols, 1:] = 0.

    return canvas


def maze_mujoco_screenshot(path=None):
    """Render one frame of the ``Maze-v1`` environment and show or save it.

    Args:
        path: File to write the figure to. If ``None`` the figure is shown
            interactively instead of being saved.
    """
    env = gym.make("Maze-v1", maze_id=0)
    try:
        # The context is set before reset() so the rendered episode uses it
        # (presumably the goal position - confirm against the environment).
        env.unwrapped.context = np.array([12, 12])
        env.reset()
        # Copy the frame so it stays valid after the environment is closed.
        img = np.array(env.render("rgb_array"))
    finally:
        # Release the simulator / renderer resources even if rendering fails.
        env.close()

    # Axes fills the whole canvas so that only the frame itself is visible.
    f = plt.figure(figsize=(1.325, 1.1))
    ax = plt.Axes(f, [0, 0, 1, 1])
    f.add_axes(ax)
    ax.imshow(img)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params('both', length=0, width=0, which='major')
    ax.set_rasterized(True)
    if path is None:
        plt.show()
    else:
        # Save this figure explicitly rather than "the current one", then
        # release it so repeated calls do not accumulate open figures.
        f.savefig(path)
        plt.close(f)


def _discretization_axes():
    """Return ``(figure, axes)`` for a 1.325 x 1.1 inch figure whose axes covers the full canvas."""
    f = plt.figure(figsize=(1.325, 1.1))
    ax = plt.Axes(f, [0, 0, 1, 1])
    f.add_axes(ax)
    return f, ax


def _discretization_emit(f, ax, path, filename):
    """Hide tick labels/marks on ``ax``, rasterize it, then show ``f`` or save it as ``path/filename``."""
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params('both', length=0, width=0, which='major')
    ax.set_rasterized(True)
    if path is None:
        plt.show()
    else:
        f.savefig(os.path.join(path, filename))
        # Release the figure so the three plots do not pile up in memory.
        plt.close(f)


def maze_discretization_figure(path=None):
    """Create the three maze illustration figures.

    1. ``maze_only.pdf``: the plain maze image.
    2. ``continuous.pdf``: ``function_plot(50)`` with bilinear interpolation,
       overlaid with a faint maze image.
    3. ``discretized.pdf``: ``function_plot(10)`` overlaid with a faint maze
       image and a grid with spacing 1/10.

    Args:
        path: Directory the PDFs are written to. If ``None`` each figure is
            shown interactively instead of being saved.
    """
    # The maze image does not change between figures, so build it once.
    maze_image = maze()
    unit_square = [0, 1, 0, 1]

    # NOTE: tight_layout() is intentionally not called. The axes is placed
    # manually on [0, 0, 1, 1], which tight_layout does not manage.

    # Figure 1: maze only.
    f, ax = _discretization_axes()
    ax.imshow(maze_image, origin="lower", extent=unit_square)
    _discretization_emit(f, ax, path, "maze_only.pdf")

    # Figure 2: finely sampled function with the maze faintly on top.
    f, ax = _discretization_axes()
    ax.imshow(function_plot(50), origin="lower", extent=unit_square, interpolation='bilinear')
    ax.imshow(maze_image, origin="lower", extent=unit_square, alpha=0.05)
    _discretization_emit(f, ax, path, "continuous.pdf")

    # Figure 3: coarsely sampled function, maze overlay and a cell grid.
    f, ax = _discretization_axes()
    ax.imshow(function_plot(10), origin="lower", extent=unit_square)
    ax.imshow(maze_image, origin="lower", extent=unit_square, alpha=0.05)
    # Each axis gets its own locator: a Locator instance binds to the axis it
    # is installed on and must not be shared between axes.
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=1 / 10))
    ax.yaxis.set_major_locator(plticker.MultipleLocator(base=1 / 10))
    ax.grid(which='major', axis='both', linestyle='-', linewidth=1.)
    _discretization_emit(f, ax, path, "discretized.pdf")


if __name__ == "__main__":
    # Script entry point: write every figure into a local "figures" folder,
    # creating it first if it does not exist yet.
    figure_dir = "figures"
    os.makedirs(figure_dir, exist_ok=True)

    maze_mujoco_screenshot("figures/maze_mujoco.pdf")
    maze_discretization_figure(path=figure_dir)
