# Python imports.
import sys
import os
import logging
import json

import numpy as np
from simple_rl.agents import RandomAgent, QLearningAgent
# from simple_rl.tasks import GymMDP
from TestEnv import GymMDP, TestEnv, TestEnvEasy, TestEnvMedium, TestEnvLessBall, TestEnvMediumLava
from simple_rl.run_experiments import run_agents_on_mdp
from simple_rl.utils.chart_utils import make_plots
import minigrid
from minigrid.wrappers import FullyObsWrapper
from agents.FullyObsObjWrapperClass import FullyObsObjWrapper, SkillsAndObjWrapper
import rlang
from PIL import Image
from utils import hash_state
from dynamic_grounding import get_knowledge_from_file, get_primitives_for, SmartStateFeaturizer, get_stable_knowledge
from llm_rlang import stage_1, stage_2
from param_sweep import perform_param_sweep

from agents import QLearningAgentSimple, DynaQLearningAgent, RLangDynaQLearningAgent
import matplotlib.pyplot as plt

# Disable the default axes grid for every figure this script renders or saves.
plt.rcParams.update({"axes.grid": False})

def generate_trial_data(env_name, seeds, path, **kwargs):
    """Dump a rendered frame and grounded primitives for each seed of an environment.

    For every seed ``s`` in ``seeds`` three files are written into ``path``:

    * ``{env_name}-seed-{s}.png``: a frame of the environment after ``reset``.
    * ``{env_name}-seed-{s}-primitives.txt``: ``str()`` of the primitives
      returned by ``get_primitives_for`` for the initial state.
    * ``{env_name}-seed-{s}-advice.txt``: an empty placeholder to be filled in
      by hand.

    Args:
        env_name: Environment id forwarded to ``GymMDP``.
        seeds: Sequence of seeds. The first one also seeds environment
            construction. An empty sequence is a no-op.
        path: Output directory. It is created if it does not exist yet.
        **kwargs: Extra keyword arguments forwarded to ``GymMDP``
            (e.g. ``customized_env``).
    """
    if not seeds:
        return
    # Callers such as generate_all_trial_data do not create the directory.
    os.makedirs(path, exist_ok=True)

    env = GymMDP(env_name=env_name, wrapper=FullyObsObjWrapper, seed=seeds[0], render=True,
                 render_mode='rgb_array', **kwargs)

    for s in seeds:
        env.reset(seed=s)
        state = env.get_init_state()

        image_obs = env.env.get_frame(highlight=False)
        Image.fromarray(image_obs).save(os.path.join(path, f"{env_name}-seed-{s}.png"))

        primitives = get_primitives_for(state, env=env)
        # Use context managers so the text files are flushed and closed deterministically.
        with open(os.path.join(path, f"{env_name}-seed-{s}-primitives.txt"), "w", encoding="utf-8") as f:
            f.write(str(primitives))
        with open(os.path.join(path, f"{env_name}-seed-{s}-advice.txt"), "w", encoding="utf-8") as f:
            f.write("")
    


def generate_all_trial_data():
    """Dump trial artifacts for seeds 1-10 of each benchmark MiniGrid environment."""
    # (environment id, output directory) pairs, processed in this order.
    jobs = (
        ("MiniGrid-LavaCrossingS9N1-v0", "trial_data/LavaCrossing"),
        ("MiniGrid-LavaCrossingS11N5-v0", "trial_data/LavaCrossing"),
        ("MiniGrid-LockedRoom-v0", "trial_data/LockedRoom"),
        ("MiniGrid-MultiRoom-N4-S5-v0", "trial_data/MultiRoom"),
        ("MiniGrid-DoorKey-8x8-v0", "trial_data/DoorKey"),
        ("MiniGrid-MemoryS11-v0", "trial_data/Memory"),
    )
    for env_id, out_dir in jobs:
        generate_trial_data(env_id, list(range(1, 11)), out_dir)

def generate_all_trial_data_zy():
    """Dump trial artifacts for the custom ``TestEnv`` maze under seeds 1-4."""
    out_dir = "trial_data/nl2rlang"
    os.makedirs(out_dir, exist_ok=True)
    seeds = [1, 2, 3, 4]
    generate_trial_data("nl2rlang", seeds, out_dir, customized_env=TestEnv())

def _append_advice_record(fpath, env_name, record):
    """Append ``record`` to the list stored under ``env_name`` in the JSON file ``fpath``.

    A missing file is treated as an empty mapping. The whole mapping is
    rewritten after each append, so progress survives a later failure.
    """
    try:
        with open(fpath, "r", encoding="utf-8") as f:
            advice_dict = json.load(f)
    except FileNotFoundError:
        advice_dict = {}
    advice_dict.setdefault(env_name, []).append(record)
    with open(fpath, "w", encoding="utf-8") as f:
        json.dump(advice_dict, f, sort_keys=True, indent=4 * ' ')


def nl2rlang(env_name, seed, advices, *, update_advice_dict=False, advice_dict_fpath=None, **kwargs):
    """Translate natural-language advice strings into RLang for one environment.

    The environment's initial state is grounded into primitives. Each advice
    string then goes through ``stage_1`` (grounding selection) and ``stage_2``
    (RLang generation).

    Args:
        env_name: Environment id forwarded to ``GymMDP``. It is also the key
            used in the advice JSON file.
        seed: Seed used to build the environment.
        advices: Iterable of natural-language advice strings.
        update_advice_dict: If true, append each ``{"advice", "grounding",
            "rlang"}`` record to the JSON file at ``advice_dict_fpath``.
        advice_dict_fpath: Path of the advice JSON file. Required when
            ``update_advice_dict`` is true.
        **kwargs: Extra keyword arguments forwarded to ``GymMDP``
            (e.g. ``customized_env``). ``update_advice_dict`` is no longer
            forwarded.

    Returns:
        List of RLang programs, one per advice, in input order.

    Raises:
        ValueError: if ``update_advice_dict`` is true but no path is given.
            This check runs before any LLM call is made.
    """
    if update_advice_dict and advice_dict_fpath is None:
        raise ValueError("update_advice_dict=True requires advice_dict_fpath")

    env = GymMDP(env_name=env_name, wrapper=SkillsAndObjWrapper, seed=seed, render=False,
                 render_mode='rgb_array', **kwargs)
    primitives = get_primitives_for(env.get_init_state(), env=env)

    rlang_advices = []
    for advice in advices:
        grounding_selection = stage_1(advice)
        rlang_advice = stage_2(advice, primitives, grounding_selection)
        rlang_advices.append(rlang_advice)
        if update_advice_dict:
            _append_advice_record(
                advice_dict_fpath,
                env_name,
                {"advice": advice, "grounding": grounding_selection, "rlang": rlang_advice},
            )
    return rlang_advices


def run_experiment_for_env(env_name="MiniGrid-Empty-5x5-v0", env_nickname="Empty", seed=None, render=False, verbose=False, action_idxs=list(range(3)), agent_types=None, *args, **kwargs):
    """Build an environment, its agents, and run them with ``run_agents_on_mdp``.

    Args:
        env_name: Environment id forwarded to ``GymMDP``.
        env_nickname: Short name that selects the agent/parameter configuration.
            For the skill-based nicknames, skill actions are appended to the
            primitive actions.
        seed: Environment seed forwarded to ``GymMDP``.
        render: If true, render in ``'human'`` mode.
        verbose: Forwarded to ``run_agents_on_mdp``.
        action_idxs: Primitive action indices. This list is copied, never mutated.
        agent_types: Optional list of variant keys ("effect", "policy", "plan",
            "combined"). If given, ``get_rlang_agents_for_env_fast`` is used
            instead of the hand-tuned per-environment configuration.
        *args, **kwargs: Forwarded to ``GymMDP``.
    """
    gym_mdp = GymMDP(env_name=env_name, render=render, wrapper=SkillsAndObjWrapper, seed=seed, render_mode='human' if render else 'rgb_array', *args, **kwargs)
    # Copy before extending with skills: `+=` on the original list would mutate the
    # caller's list (or the shared default) and accumulate skills across calls.
    actions = list(action_idxs)
    if env_nickname in ("EasyMaze", "HardMaze", "MidMaze", "DoorKey", "MultiRoom", "LockedRoom", "LavaCrossing", "MidMazeLava"):
        gym_mdp.reset()
        state = gym_mdp.get_init_state()
        knowledge = get_stable_knowledge()
        statefeaturizer = SmartStateFeaturizer(knowledge=knowledge)
        gym_mdp.env.skill_dict, skill_names = statefeaturizer.generate_skill_dict(state)
        actions += list(skill_names.keys())
        print(actions)

    # Setup agents and run.
    rand_agent = RandomAgent(actions)
    if render:
        gym_mdp.render = True
    if agent_types:
        rlang_agents, q_params, exp_params = get_rlang_agents_for_env_fast(gym_mdp, env_nickname, actions, agent_types=agent_types)
    else:
        rlang_agents, q_params, exp_params = get_rlang_agents_for_env(gym_mdp, env_nickname, actions)
    dyna_agent = DynaQLearningAgent(actions, state_hash_func=hash_state, name="Dyna-Q", **q_params)

    # Configs without RLang advice (e.g. "HardMaze") return None for the agent list.
    # `[*None]` raised TypeError, so run the baselines instead.
    agents = list(rlang_agents) if rlang_agents else [dyna_agent, rand_agent]
    run_agents_on_mdp(agents,
                      gym_mdp, **exp_params, open_plot=False, verbose=verbose,
                      cumulative_plot=True, track_disc_reward=False, seed_per_inst=False, seeds=list(range(1, 51)))
    

def get_rlang_agents_for_env(gym_mdp, env_name, actions):
    """Build the hand-tuned RLang agents and parameters for a named environment.

    Each branch reads RLang advice from ``rlang_advice/<env>.rlang`` through
    ``get_knowledge_from_file`` and constructs one or more
    ``RLangDynaQLearningAgent`` variants. Several branches construct more
    variants than they return. The returned list is the subset currently being
    evaluated; the others are kept so the selection is easy to change.

    Args:
        gym_mdp: Environment object passed through to ``get_knowledge_from_file``.
        env_name: Environment nickname, e.g. ``"DoorKey"`` or ``"MidMazeLava"``.
        actions: Action list given to every agent.

    Returns:
        A tuple ``(agents, q_params, exp_params)``:

        * ``agents``: list of RLang agents, or ``None`` for ``"HardMaze"``,
          which has no advice file.
        * ``q_params``: keyword arguments for the Dyna-Q baseline.
        * ``exp_params``: keyword arguments for ``run_agents_on_mdp``.

    Raises:
        ValueError: if ``env_name`` has no configuration here. Previously this
            returned ``None``, which made callers fail with an opaque unpacking
            error.
    """
    if env_name == "Empty":
        get_knowledge = lambda state, filename="rlang_advice/empty.rlang": get_knowledge_from_file(state, filename, gym_mdp)

        params = {"gamma":0.99, "epsilon":0.1, "alpha":0.05, "anneal":True}
        
        rlang_dyna_agent = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, **params)
        
        return ([rlang_dyna_agent],
                {"gamma":0.99, "epsilon":0.1, "alpha":0.05, "anneal":True},
                {"instances":10, "episodes":15, "steps":100})
    
    elif env_name == "DoorKey":
        get_knowledge = lambda state, filename="rlang_advice/doorkey.rlang": get_knowledge_from_file(state, filename, gym_mdp)

        params = {"gamma":0.99, "epsilon":0.05, "alpha":0.5, "anneal":False}
        
        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, **params)
        rlang_dyna_agent_policy = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-policy",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=True, use_plan=False, use_effects=False, **params)
        rlang_dyna_agent_effect = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-effect",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=False, use_effects=True, **params)
        
        # The [1:2] slice selects only the policy variant.
        return ([rlang_dyna_agent_plan, rlang_dyna_agent_policy, rlang_dyna_agent_effect][1:2],
                {"gamma":0.99, "epsilon":0.1, "alpha":0.1, "anneal":False},
                {"instances":5, "episodes":50, "steps":200})
    
    elif env_name == "LavaCrossing":
        get_knowledge = lambda state, filename="rlang_advice/lavacrossing.rlang": get_knowledge_from_file(state, filename, gym_mdp)

        params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False, "default_q":0.001}

        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, policy_epsilon=1.0, **params)
        rlang_dyna_agent_effect = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-effect",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=False, use_effects=True, **params)
        rlang_dyna_agent_combined = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-combined",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=True, policy_epsilon=1.0, **params)
        return ([rlang_dyna_agent_effect],
                {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "default_q":0.001},
                {"instances":10, "episodes":25, "steps":500, "reset_at_terminal":False})
    
    elif env_name == "GoToObject":
        get_knowledge = lambda state, filename="rlang_advice/gotoobject.rlang": get_knowledge_from_file(state, filename, gym_mdp)

        params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False}

        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, **params)
        return ([rlang_dyna_agent_plan],
                {"gamma":0.99, "epsilon":0.01, "alpha":0.1},
                {"instances":5, "episodes":30, "steps":300})
    
    elif env_name == "MultiRoom":
        get_knowledge = lambda state, filename="rlang_advice/multiroom.rlang": get_knowledge_from_file(state, filename, gym_mdp)

        params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False, "default_q":0.001}

        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge, num_hallucinations=8,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, policy_epsilon=1.0, **params)
        rlang_dyna_agent_policy = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-policy",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=True, use_plan=False, use_effects=False, policy_epsilon=0.99, **params)
        return ([rlang_dyna_agent_plan],
                {"gamma":0.99, "epsilon":0.01, "alpha":1.0, "default_q":0.001, "num_hallucinations":8},
                {"instances":10, "episodes":50, "steps":500})
    
    elif env_name == "LockedRoom":
        get_knowledge = lambda state, filename="rlang_advice/lockedroom.rlang": get_knowledge_from_file(state, filename, gym_mdp)

        params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False, "default_q":0.001}

        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, policy_epsilon=1.0, **params)
        rlang_dyna_agent_effect = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-effect",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=False, use_effects=True, policy_epsilon=1.0, **params)
        rlang_dyna_agent_combined = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-combined",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=True, policy_epsilon=1.0, **params)
        return ([rlang_dyna_agent_effect],
                {"gamma":0.99, "epsilon":0.01, "alpha":1.0, "default_q":0.001},
                {"instances":10, "episodes":50, "steps":500})
    
    elif env_name == "Memory": # obj in start_room is randomized, this is a hard one
        get_knowledge = lambda state, filename="rlang_advice/memory.rlang": get_knowledge_from_file(state, filename, gym_mdp)
        params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False, "default_q":0.001}

        rlang_dyna_agent_effect = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-effect",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=False, use_effects=True, policy_epsilon=1.0, **params)
        rlang_dyna_agent_policy = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-policy",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=True, use_plan=False, use_effects=False, policy_epsilon=1.0, **params)
        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, policy_epsilon=1.0, **params)
        return ([rlang_dyna_agent_plan],
                {"gamma":0.99, "epsilon":0.01, "alpha":1.0, "default_q":0.001},
                {"instances":10, "episodes":50, "steps":500})

    elif env_name == 'HardMaze':
        # No advice file for this maze; the caller is expected to handle `None` agents.
        params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False, "default_q":0.0001}
        return (None, params, {"instances":1, "episodes":100, "steps":1000})
    
    elif env_name == 'MidMaze':
        get_knowledge = lambda state, filename="rlang_advice/midmaze.rlang": get_knowledge_from_file(state, filename, gym_mdp)
        params = {"gamma":0.99, "epsilon":0.1, "alpha":0.01, "anneal":False, "default_q":0.0}

        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, policy_epsilon=1.0, **params)
        rlang_dyna_agent_policy = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-policy",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=True, use_plan=False, use_effects=False, policy_epsilon=1.0, **params)
        
        return ([rlang_dyna_agent_plan, rlang_dyna_agent_policy], params, {"instances":5, "episodes":40, "steps":1000})
    
    elif env_name == 'MidMazeLava':
        get_knowledge = lambda state, filename="rlang_advice/midmazelava.rlang": get_knowledge_from_file(state, filename, gym_mdp)
        params = {"gamma":0.99, "epsilon":0.1, "alpha":0.01, "anneal":False, "default_q":0.0}

        rlang_dyna_agent_plan = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-plan",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=True, use_effects=False, policy_epsilon=1.0, **params)
        rlang_dyna_agent_policy = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-policy",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=True, use_plan=False, use_effects=False, policy_epsilon=1.0, **params)
        # The effect and combined variants use their own hyperparameters, not `params`.
        rlang_dyna_agent_effect = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-effect",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=False, use_plan=False, use_effects=True, policy_epsilon=1.0,
            gamma=0.99, epsilon=0.01, alpha=0.1, default_q=0.0)
        rlang_dyna_agent_combined = RLangDynaQLearningAgent(
            name="RLang-Dyna-Q-combined",
            actions=actions, get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=True, use_plan=True, use_effects=True, policy_epsilon=1.0,
            gamma=0.99, epsilon=0.01, alpha=0.1, default_q=0.0)
        
        return ([rlang_dyna_agent_combined, rlang_dyna_agent_plan], params, {"instances":10, "episodes":50, "steps":2000})

    # Unknown nickname (e.g. "EasyMaze"): fail loudly instead of returning None,
    # which callers would try to unpack into three values.
    raise ValueError(f"No RLang agent configuration for env_name={env_name!r}")

def get_rlang_agents_for_env_fast(gym_mdp, env_name, actions, agent_types):
    """Build RLang agent variants with shared default hyperparameters.

    Advice is read from ``rlang_advice/<env_name lowercased>.rlang``. All four
    variants are constructed in the order effect, policy, plan, combined. The
    ones named in ``agent_types`` are then returned in the requested order.

    Args:
        gym_mdp: Environment object passed through to ``get_knowledge_from_file``.
        env_name: Environment nickname used to locate the advice file.
        actions: Action list given to every agent.
        agent_types: Iterable of keys among "effect", "policy", "plan" and
            "combined". An unknown key raises ``KeyError``.

    Returns:
        ``(agents, q_params, exp_params)``, with the same meaning as in
        ``get_rlang_agents_for_env``.
    """
    advice_path = f"rlang_advice/{env_name.lower()}.rlang"

    def get_knowledge(state, filename=advice_path):
        return get_knowledge_from_file(state, filename, gym_mdp)

    shared_params = {"gamma":0.99, "epsilon":0.01, "alpha":0.1, "anneal":False, "default_q":0.001}

    # key -> (agent name, use_policy, use_plan, use_effects); dict order = construction order.
    variant_specs = {
        "effect": ("RLang-Dyna-Q-effect", False, False, True),
        "policy": ("RLang-Dyna-Q-policy", True, False, False),
        "plan": ("RLang-Dyna-Q-plan", False, True, False),
        "combined": ("RLang-Dyna-Q-combined", False, True, True),
    }
    built = {}
    for key, (agent_name, policy_on, plan_on, effects_on) in variant_specs.items():
        built[key] = RLangDynaQLearningAgent(
            name=agent_name,
            actions=actions,
            get_knowledge=get_knowledge,
            state_hash_func=hash_state,
            use_policy=policy_on,
            use_plan=plan_on,
            use_effects=effects_on,
            policy_epsilon=1.0,
            **shared_params,
        )

    selected = [built[kind] for kind in agent_types]
    baseline_params = {"gamma":0.99, "epsilon":0.01, "alpha":1.0, "default_q":0.001}
    experiment_params = {"instances":10, "episodes":50, "steps":500}
    return (selected, baseline_params, experiment_params)
                    

def main():
    """Script entry point: regenerate the MidMazeLava comparison plots.

    The commented-out calls below are a menu of earlier experiments:
    NL-to-RLang translation, single-environment runs, parameter sweeps and
    trial-data dumps. Uncomment one to rerun it.
    """
    # --- Natural language -> RLang translation ---------------------------------
    # rlang_advices = nl2rlang("MiniGrid-LavaCrossingS11N5-v0", seed=1, advices=["go to the goal", "don't walk into lava"])
    # rlang_advices = nl2rlang("MiniGrid-MultiRoom-N4-S5-v0", seed=1, advices=["first go t o the blue door, then the green door, then the grey door, then the purple door"])
    # rlang_advices = nl2rlang("MiniGrid-LockedRoom-v0", seed=1, advices=["get the red key from behind the grey door, then go to the red door", "the key is behind the grey door"])
    # rlang_advices = nl2rlang("MiniGrid-MemoryS11-v0", seed=1, advices=["pick up the green ball if there are two green ball"] ,update_advice_dict=True) # doesn't work yet
    # rlang_advices = nl2rlang("MiniGrid-BlockedUnlockPickup-v0", seed=1, advices=["pick up the key if door is locked", "a key can open a locked door"] ,update_advice_dict=True)
    # rlang_advices = nl2rlang("MiniGrid-DistShift1-v0", seed=1, advices=["never step on lava", "go to the goal"], update_advice_dict=True)
    # rlang_advices = nl2rlang("MiniGrid-Dynamic-Obstacles-5x5-v0", seed=1, advices=["avoid the obstacles", "go to the goal"], update_advice_dict=True) # GPT thinks lava, wall, and ball as obstacles
    # rlang_advices = nl2rlang("MiniGrid-RedBlueDoors-6x6-v0", seed=1, advices=["never go to blue door first"], update_advice_dict=True)
    # rlang_advices = nl2rlang("MiniGrid-Unlock-v0", seed=1, advices=["pickup the key", "a key can open a door"], update_advice_dict=True)
    # rlang_advices = nl2rlang("BabyAI-GoToRedBallGrey-v0", seed=1, advices=["pick up the red ball"], update_advice_dict=True) # Want to use "don't pick any object that is not red"
    # rlang_advices = nl2rlang("BabyAI-GoToRedBall-v0", seed=1, advices=["pick up the red ball"], update_advice_dict=True) # Want to use "don't pick any object that is not red"
    # rlang_advices = nl2rlang("BabyAI-GoToObj-v0", seed=1, advices=["pick up the ball or key or box", ""], update_advice_dict=True) # Want to use "don't pick any object that is not red"
    # rlang_advices = nl2rlang("BabyAI-MiniBossLevel-v0", seed=1, advices=["a key can open a door"], update_advice_dict=True) # Want to use "don't pick any object that is not red"
    # rlang_advices = nl2rlang("nl2rlang", seed=1, advices=["you can open blue door with a blue key"], update_advice_dict=False, customized_env=TestEnv())
    # rlang_advices = nl2rlang("MidMazeLava", seed=1, advices=["pick up the blue ball and drop it to your right. Then pick up the green key and unlock the green door. Then drop the key to your right and go pick up the green ball.",
    #                                                          "Some general advice: If you have the green key and the green door is closed, open it if you are at it, otherwise go to it if it is closed. The same applies for the grey key and door.",
    #                                                          "Don't walk into lava"], update_advice_dict=False, customized_env=TestEnvMediumLava())
    # [print(rlang_advice) for rlang_advice in rlang_advices]

    # --- Single-environment experiments ----------------------------------------
    # run_experiment_for_env("MiniGrid-Empty-8x8-v0", "Empty", action_idxs=[0, 1, 2], agent_start_pos=None)
    # run_experiment_for_env("MiniGrid-DoorKey-5x5-v0", "DoorKey", action_idxs=[0, 1, 2, 3, 5], render=True)
    # run_experiment_for_env("MiniGrid-LavaCrossingS9N1-v0", "LavaCrossing", action_idxs=[0, 1, 2], render=False, max_steps=1000)
    # run_experiment_for_env("MiniGrid-GoToObject-6x6-N2-v0", "GoToObject", action_idxs=[0, 1, 2, 6])
    # run_experiment_for_env("MiniGrid-MultiRoom-N4-S5-v0", "MultiRoom", action_idxs=[0, 1, 2, 5], render=False, max_steps=500)
    # run_experiment_for_env("MiniGrid-LockedRoom-v0", "LockedRoom", action_idxs=[0, 1, 2, 3, 5], max_steps=500, render=True)
    # run_experiment_for_env("MiniGrid-MemoryS11-v0", "memory", action_idxs=[0, 1, 2, 3, 5], max_steps=500, render=False, agent_types=["effect", "plan"])
    # run_experiment_for_env("MiniGrid-BlockedUnlockPickup-v0", "BlockedUnlockPickup", action_idxs=[0, 1, 2, 3], max_steps=500, render=False, agent_types=["effect"]) # doesn't work yet, have more than one color
    # run_experiment_for_env("MiniGrid-DistShift1-v0", "DistShift", action_idxs=[0, 1, 2], max_steps=500, render=False, agent_types=["effect", "plan", "combined"])
    # run_experiment_for_env("MiniGrid-Dynamic-Obstacles-5x5-v0", "DynamicObstacles", action_idxs=[0, 1, 2], max_steps=3000, render=False, agent_types=["effect", "plan", "combined"]) 
    # run_experiment_for_env("MiniGrid-RedBlueDoors-6x6-v0", "RedBlueDoors", action_idxs=[0, 1, 2, 5], max_steps=500, render=False, agent_types=["plan"])
    # run_experiment_for_env("MiniGrid-Unlock-v0", "Unlock", action_idxs=[0, 1, 2, 3, 5], max_steps=500, render=False, agent_types=["plan", "effect"])
    # run_experiment_for_env("BabyAI-GoToRedBallGrey-v0", "GoToRedBallGrey", action_idxs=[0, 1, 2, 3], max_steps=500, render=False, agent_types=["plan"])
    # run_experiment_for_env("BabyAI-GoToRedBall-v0", "GoToRedBallGrey", action_idxs=[0, 1, 2, 3], max_steps=500, render=False, agent_types=["plan"])
    # run_experiment_for_env("BabyAI-MiniBossLevel-v0", "MiniBossLevel", action_idxs=[0, 1, 2, 3, 4], max_steps=500, render=False, agent_types=["effect"])
    # run_experiment_for_env("HardMaze", "HardMaze", action_idxs=[0, 1, 2, 3, 4, 5], max_steps=2000, render=True, agent_types=None, customized_env=TestEnv())
    # run_experiment_for_env("EasyMaze", "EasyMaze", action_idxs=[0, 1, 2, 3, 4, 5], max_steps=2000, render=False, agent_types=None, customized_env=TestEnvEasy())
    # run_experiment_for_env("MidMaze", "MidMaze", action_idxs=[0, 1, 2, 3, 4, 5], max_steps=1000, render=False, verbose=False, agent_types=None, customized_env=TestEnvMedium())
    # run_experiment_for_env("MidMazeLava", "MidMazeLava", action_idxs=[0, 1, 2, 3, 4, 5], max_steps=2000, render=False, verbose=False, agent_types=None, customized_env=TestEnvMediumLava())

    # --- Parameter sweeps -------------------------------------------------------
    # perform_param_sweep("MidMaze", action_ids=[0, 1, 2, 3, 4, 5], env_nickname="MidMaze", customized_env=TestEnvMedium())
    # perform_param_sweep("MidMazeLava", action_ids=[0, 1, 2, 3, 4, 5], env_nickname="MidMazeLava", customized_env=TestEnvMediumLava())
    # perform_param_sweep("HardMaze", action_ids=[0, 1, 2, 3, 4, 5], env_nickname="HardMaze", customized_env=TestEnv())
    # perform_param_sweep("HardMazeLight", action_ids=[0, 1, 2, 3, 4, 5], env_nickname="HardMazeLight", customized_env=TestEnvLessBall())
    # perform_param_sweep("MiniGrid-MultiRoom-N4-S5-v0", action_ids=[0, 1, 2, 5], seed=1)
    # perform_param_sweep("MiniGrid-LavaCrossingS9N1-v0", action_ids=[0, 1, 2])
    # perform_param_sweep("MiniGrid-LockedRoom-v0", action_ids=[0, 1, 2, 3, 5])

    # --- Trial data ---------------------------------------------------------------
    # generate_all_trial_data()

    plot_dir = "results/gym-MidMazeLava"
    plotted_agents = [
        "Random",
        "Dyna-Q",
        "RLang-Dyna-Q-effect",
        "RLang-Dyna-Q-policy",
        "RLang-Dyna-Q-plan",
        "RLang-Dyna-Q-combined",
    ]
    make_plots(plot_dir, plotted_agents)

def test():
    """Manual smoke check: reset LockedRoom, print observations and primitives, show a frame."""
    mdp = GymMDP(env_name="MiniGrid-LockedRoom-v0", render=False, wrapper=FullyObsObjWrapper, seed=1)

    # The first element of the reset result is printed for both an unseeded and a seeded reset.
    unseeded = mdp.reset()
    print(unseeded[0])

    seeded = mdp.reset(seed=1)
    print(seeded[0])

    # NOTE(review): the whole reset result (not just element 0) is passed as
    # state_tuple here; confirm this matches get_primitives_for's expectation.
    print(get_primitives_for(state_tuple=seeded, env=mdp))

    frame = mdp.env.get_frame(highlight=False)
    plt.imshow(frame)
    plt.show()


if __name__ == "__main__":
    # Raise the threshold of this script's own logger ("__main__") to ERROR.
    # Loggers of imported libraries are not affected by this call.
    LOGGER = logging.getLogger(__name__)
    LOGGER.setLevel(level=logging.ERROR)
    main()
    # test()
