#!/usr/bin/env python3

"""
Code derived from Maxime et al: https://github.com/maximecb/gym-minigrid
"""

import sys
sys.path.insert(1, '../')

import argparse
import gymnasium as gym
import matplotlib.pyplot as plt

from envs.make_env import *

# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "sans-serif",
#     "font.sans-serif": "Helvetica",
#     "text.latex.preamble":r'\usepackage{pifont,marvosym,scalerel}'
# })

# For Mujoco envs, to render rgb: export LD_PRELOAD=""
# For Mujoco envs, to render a regular window: export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so

def redraw(img):
    """Show the current frame of the environment.

    Args:
        img: Image to display when ``--agent_view`` is set (the callers in
            this file pass the observation returned by ``env.reset`` /
            ``env.step``). Without ``--agent_view`` it is ignored and
            replaced by the result of ``env.render()``.

    The matplotlib figure is only updated when ``--agent_view`` is set or
    ``--render_mode`` is ``"rgb_array"``; otherwise only ``env.render()`` is
    called.
    """
    if not args.agent_view:
        img = env.render()
    if args.agent_view or args.render_mode == "rgb_array":
        # plt.imshow adds a new image artist on every call. Drop the frames
        # drawn earlier so they do not pile up on the axes (and get
        # re-rendered underneath the newest one on every redraw).
        for stale_image in list(plt.gca().images):
            stale_image.remove()
        plt.imshow(img)
        fig.canvas.blit(ax.bbox)
        # Very short pause so the GUI event loop gets a chance to run.
        plt.pause(0.0000000001)

        # env_fig = env.gen_fig()
        # env_fig.savefig("./images/tmaze{}.png".format(args.maze_length), bbox_inches='tight')
        # env_fig.savefig("./images/pdf/tmaze{}.pdf".format(args.maze_length), bbox_inches='tight')

def reset():
    """Reset the environment with the CLI seed and draw the first frame."""
    # env.reset returns an (observation, info) pair; only the observation
    # is needed for drawing.
    initial_obs, _info = env.reset(seed=args.seed)
    redraw(initial_obs)

def step(action):
    """Apply ``action`` to the environment, log the outcome and redraw.

    Args:
        action: Action forwarded unchanged to ``env.step``.
    """
    observation, reward, terminated, truncated, info = env.step(action)
    summary = 'reward = {} | done = {} | truncated = {}'.format(
        reward, terminated, truncated
    )
    print(summary)
    print(info)
    redraw(observation)

def key_handler(event):
    """Translate a matplotlib key press into an environment interaction.

    ``escape`` closes the figure and ``backspace`` resets the environment.
    Any other key is resolved to an action in this order:

    1. ``env.unwrapped.actions[event.key]`` when that attribute exists, is a
       dict and contains the key;
    2. the MiniGrid key map below when ``args.env`` contains ``"MiniGrid"``;
    3. otherwise a random sample from ``env.action_space``.

    The resulting action is then passed to ``step``.

    Args:
        event: Key press event delivered by matplotlib; only ``event.key``
            is read.
    """
    print('pressed', event.key)
    # minigrid_actions = {"left":"left", "right": "right", "up":"forward", "enter":"pickup", "down":"drop", " ":"toggle"}
    minigrid_actions = {"left":0, "right":1, "up":2, "enter":3, "down":4, " ":5}

    if event.key == 'escape':
        plt.close()
        return
    if event.key == 'backspace':
        reset()
        return

    # isinstance (rather than an exact type comparison) also accepts dict
    # subclasses as a key -> action table.
    env_actions = getattr(env.unwrapped, "actions", None)
    if isinstance(env_actions, dict) and event.key in env_actions:
        action = env_actions[event.key]
    elif "MiniGrid" in args.env and event.key in minigrid_actions:
        action = minigrid_actions[event.key]
    else:
        action = env.action_space.sample()
        print("random action")

    step(action)
    


# Command-line options. The parsed namespace is handed to make_env below and
# read by the redraw/reset/key_handler callbacks in this file.
parser = argparse.ArgumentParser()
parser.add_argument('--env', default="tmaze-v0", help="Environment name")
parser.add_argument('--algo', default="PPO", help="RL learning algorithm: DQN or PPO")
parser.add_argument('--mask_type', default="fully_obs", help="fully_obs, no_stack, framestack, masked, ca_masked, ca_all_masked, demir")
parser.add_argument('--cube_cam', default="orthographic", help="full, face, orthographic")
parser.add_argument('--scramble_steps', type=int, default=5, help="Scramble steps for cube env")
parser.add_argument('--maze_length', type=int, default=1, help="Maze length for tmaze")
parser.add_argument("--random_length", help="", action='store_true', default=False)
parser.add_argument('--active', action='store_true', default=False, help="Active tmaze mode")
parser.add_argument('--continual', action='store_true', default=False, help="Continual setting")
parser.add_argument('--visible_goal_steps', type=int, default=2, help="Number of steps where the environment goal is visible in GCRL tasks")
parser.add_argument('--max_episode_steps', type=int, default=50, help="Max number of steps per episode")
parser.add_argument('--num_stack', type=int, default=1, help="Memory length (sequence length)")
# argparse only applies `type` to string defaults, so the default must
# already be an int; a bare 1e6 would leave args.maxiter as a float.
parser.add_argument('--maxiter', type=int, default=int(1e6), help="Max training timesteps")
parser.add_argument('--features_dim', type=int, default=256, help="Input dim of policy layer")
parser.add_argument('--hidden_size', type=int, default=128, help="Hidden dim of memory architecture layer")
parser.add_argument('--run', type=int, default=None, help="Random seed / run id")
parser.add_argument('--nenvs', type=int, default=1, help="Number of envs/processes")
parser.add_argument('--path', default="./data/", help="Save path for logs and models")
parser.add_argument('--device', default="cpu", help="Device for Pytorch")
parser.add_argument('--arch', choices=['cnn', 'mlp', 'transformer', 'lstm'], default='mlp', help="Policy architecture")
parser.add_argument('--render_mode', default='human', help="Render mode")
parser.add_argument("--agent_view", help="", action='store_true', default=False)
parser.add_argument("--seed", type=int, help="random seed to generate the environment with", default=None)
args = parser.parse_args()

if __name__ == "__main__":
    # gym.logger.set_level(gym.logger.ERROR)

    # Figure that receives the key presses; ticks and axes are hidden so
    # only the rendered frame is visible.
    fig, ax = plt.subplots(figsize=(1,1), facecolor='w', edgecolor='k')
    fig.canvas.mpl_connect('key_press_event', key_handler)
    plt.xticks([])
    plt.yticks([])
    plt.axis("off")
    fig.tight_layout()

    # `env` is read as a module-level global by redraw/reset/step/key_handler.
    env, name = make_env(args,args.seed)
    print("Observation space: ", env.observation_space)
    print("action_space: ", env.action_space)

    plt.get_current_fig_manager().set_window_title(args.env)

    # Draw the initial frame, then hand control to the matplotlib event loop.
    reset()
    plt.show()
