import sys
import logging

import numpy as np
from simple_rl.agents import RandomAgent, QLearningAgent
# from simple_rl.tasks import GymMDP
from simple_rl.run_experiments import run_agents_on_mdp
import minigrid
from minigrid.wrappers import FullyObsWrapper
import rlang
from PIL import Image
from utils import hash_state
from dynamic_grounding import get_knowledge_from_file, get_primitives_for
from llm_rlang import stage_1, stage_2
from param_sweep import perform_param_sweep
from dynamic_grounding import get_stable_knowledge, SmartStateFeaturizer, StateFeaturizer

from TestEnv import TestEnv, GymMDP

from agents import QLearningAgentSimple, DynaQLearningAgent, RLangDynaQLearningAgent, FullyObsObjWrapper, SkillsAndObjWrapper
import matplotlib.pyplot as plt

# Turn off axes gridlines so plots of rendered frames are not overlaid with a grid.
plt.rcParams.update({"axes.grid": False})

# Build the LavaCrossing MiniGrid task through the project's GymMDP adapter,
# wrapped with SkillsAndObjWrapper and rendering frames as RGB arrays.
env = GymMDP(
    env_name="MiniGrid-LavaCrossingS9N1-v0",
    wrapper=SkillsAndObjWrapper,
    render=True,
    render_mode='rgb_array',
)
# Manual wrapping / inspection alternative, kept for reference:
# env = SkillsAndObjWrapper(env)
# env.render
# obs, _ = env.reset()
# print(obs)

print(type(env))
env.reset()
state = env.get_init_state()
# The frame is read from the inner env object (env.env); per the original note,
# this is required for the skills-wrapped env.
image_obs = env.env.get_frame(highlight=False)
print(state)
# plt.imshow(image_obs)
# plt.show()

# knowledge = get_stable_knowledge()
# statefeaturizer = SmartStateFeaturizer(knowledge=knowledge)
# knowledge.update(statefeaturizer.generate_rlang_objects(state))
# env.env.skill_dict, skill_names = statefeaturizer.generate_skill_dict(state)
# # rlang_objects = statefeaturizer.generate_objects()
# # print(rlang_objects)
# print(list(knowledge.keys()))
# # env.env.skill_dict = {8: "s"}
# print(env.env.skill_dict)

# knowledge.update(get_knowledge_from_file(state, "rlang_advice/lavacrossing.rlang", env=env)[0])

# print(list(knowledge.keys()))

# print(state[0].shape)

# s0 = np.ones_like(state[0]) * 9
# print(s0)

# print(knowledge['main_effect'].transition_function(state=(s0, state[1], state[2]), action=2))

# NOTE: get_knowledge_from_file does not actually update knowledge.transition_function (or the related knowledge functions), so the transition query above does not work.

# state = (state[0],(6, 1, 0), (3, 5, 1))

# print(statefeaturizer._find_obj_by_typecolor(state=state, targ_type_idx=6, targ_color_idx=1))

# print(state)

# carrying = knowledge['carrying']
# print(carrying(obj=knowledge['green_ball'], state=state))

# at = knowledge['at']
# print(at(obj=knowledge['green_ball'], state=state))

# s = rlang.VectorState(state, dtype=object)
# print(s)

# print(str(state[0].data.tobytes()))
# print(str(rlang_state_to_tuple(s).data.tobytes()))
# print(str(state[0].data.tobytes()) == str(rlang_state_to_tuple(s).data.tobytes()))
