import copy
import numpy as np
import rlang
from rlang.grounding import ActionReference, MDPObject, MDPObjectGrounding, Feature, ConstantGrounding, Domain, Predicate, MDPClassGrounding, Proposition, Plan
from grounding import find_obj_by_id, get_stable_knowledge
import matplotlib.pyplot as plt
from functools import lru_cache
from utils import shortest_path
from simple_rl.tasks.gym.GymStateClass import GymState

# MiniGrid color names, listed in index order (color channel of an encoded grid cell).
_COLOR_NAMES = ("red", "green", "blue", "purple", "yellow", "grey")
COLOR_TO_IDX = {name: idx for idx, name in enumerate(_COLOR_NAMES)}
IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}

# MiniGrid object type names, listed in index order (type channel of an encoded grid cell).
_OBJECT_NAMES = (
    "unseen",
    "empty",
    "wall",
    "floor",
    "door",
    "key",
    "ball",
    "box",
    "goal",
    "lava",
    "agent",
)
OBJECT_TO_IDX = {name: idx for idx, name in enumerate(_OBJECT_NAMES)}

IDX_TO_OBJECT = {idx: name for name, idx in OBJECT_TO_IDX.items()}


class SmartStateFeaturizer:
    """
    Turn MiniGrid observations into RLang objects, predicates and go_to skills.

    States are ``(image, inventory, agent_info)`` triples:
    - ``image[x, y]`` is a ``(type_idx, color_idx, state)`` cell.
    - ``inventory`` is the ``(type_idx, color_idx, state)`` of the carried object, with type index 1 when nothing
      is carried.
    - ``agent_info`` is ``(x, y, dir)``.
    GymState instances and tuples are used as-is; any other input is converted by ``_normalize_state``.

    This featurizer makes the following assumptions:
    1. There is not more than one object-color pair in the world, e.g. there is no more than one yellow door.
    2. After initialization, objects cannot appear nor disappear. If they are not present in the image grid,
       they are in the agent's inventory.
    3. There is *no* continuity in object locations (except walls, lava, doors, and goal tiles).

    The same set of objects is present at all times: if an object is in the initial state, it is in any given
    state. The RLang objects are therefore generated once. Their attributes are Features that look the object up
    again in whatever state they are evaluated on.

    Open questions (from the original notes): attributes are recomputed on every evaluation. Caching the last
    state seen would avoid repeating the same lookups. It is also undecided whether attributes should be computed
    for every state or only on demand; the current behavior is on demand.
    """

    # Unit step (dx, dy) for each agent direction index: 0 right (+x), 1 down (+y), 2 left (-x), 3 up (-y).
    _DIR_TO_VEC = ((1, 0), (0, 1), (-1, 0), (0, -1))
    # Inverse of _DIR_TO_VEC: turns a path step into the direction the agent must face.
    _VEC_TO_DIR = {(1, 0): 0, (0, 1): 1, (-1, 0): 2, (0, -1): 3}
    # Actions returned by go_to.
    _ROTATE_CCW = 0
    _ROTATE_CW = 1
    _FORWARD = 2

    def __init__(self, knowledge):
        """
        Args:
            knowledge: RLang knowledge base whose ``classes()`` provide Door, Key, Ball, Box, GoalTile and Agent.
        """
        self.knowledge = knowledge
        # Maps "go_to(<object name>)" to a skill index; built lazily by generate_skill_dict.
        self.skill_names_reversed = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_state(state):
        """Return ``state`` unchanged if it is a GymState or tuple; otherwise convert it to a plain tuple triple."""
        if isinstance(state, (GymState, tuple)):
            return state
        return (state[0].view(np.ndarray), tuple(state[1].view(np.ndarray)), tuple(state[2].view(np.ndarray)))

    @staticmethod
    def _class_name_to_idx(class_name):
        """Map an RLang class name (e.g. ``Key``, ``GoalTile``) to its MiniGrid object index."""
        name = class_name.lower()
        # The RLang goal class is called GoalTile, while MiniGrid calls the object "goal".
        if name == 'goaltile':
            name = 'goal'
        return OBJECT_TO_IDX[name]

    @staticmethod
    def _resolve_position(obj, state):
        """Return ``(x, y)`` of ``obj``, evaluating Feature-valued coordinates on ``state``."""
        obj_x = obj.x(state=state) if isinstance(obj.x, Feature) else obj.x
        obj_y = obj.y(state=state) if isinstance(obj.y, Feature) else obj.y
        return obj_x, obj_y

    def _find_obj_by_typecolor(self, state, targ_type_idx, targ_color_idx):
        """
        Find the object with the given type and color and return a dict of its attributes.

        The inventory is checked first; a carried object gets ``x == y == -1``. Otherwise the image is scanned
        (x outer, y inner) and the first match is returned. Doors (type index 4) also get ``is_open`` (state 0)
        and ``is_locked`` (state 2).

        Raises:
            LookupError: if the object is neither carried nor in the image.
        """
        state = self._normalize_state(state)

        inventory = state[1]
        if inventory[0] != 1:
            carried_type, carried_color, _ = inventory
            if carried_type == targ_type_idx and carried_color == targ_color_idx:
                return {'color': carried_color, 'x': -1, 'y': -1}

        image = state[0]
        width, height, channels = image.shape
        assert channels == 3

        for x in range(width):
            for y in range(height):
                type_idx, color_idx, obj_state = image[x, y]
                if type_idx != targ_type_idx or color_idx != targ_color_idx:
                    continue
                attributes = {'color': color_idx, 'x': x, 'y': y}
                if type_idx == 4:
                    attributes['is_open'] = obj_state == 0
                    attributes['is_locked'] = obj_state == 2
                return attributes

        raise LookupError(f"Object of type {targ_type_idx} and color {targ_color_idx} not found in state {state}")

    def _typecolor_feature(self, type_idx, color_idx, attribute):
        """Feature returning ``attribute`` of the type/color object, looked up in the state it is evaluated on."""
        return Feature(lambda state, action=None: self._find_obj_by_typecolor(state, type_idx, color_idx)[attribute],
                       name=attribute)

    def _positional_object(self, rlang_cls, obj_name, type_idx, color_idx):
        """Ground an RLang object with state-dependent ``x``/``y`` and a constant ``color``."""
        return MDPObjectGrounding(obj=rlang_cls(name=obj_name,
                                                x=self._typecolor_feature(type_idx, color_idx, 'x'),
                                                y=self._typecolor_feature(type_idx, color_idx, 'y'),
                                                color=ConstantGrounding(value=color_idx, codomain=Domain.REAL_VALUE)),
                                  domain=Domain.STATE)

    def _make_go_to_skill(self, type_idx, color_idx):
        """Return ``skill(state)``, which gives the next go_to action towards the given type/color object."""
        def skill(state):
            attributes = self._find_obj_by_typecolor(state, type_idx, color_idx)
            # go_to only reads x / y, so a throwaway class holding the attributes is enough.
            target = type('FastClass', (object,), attributes)
            return self.go_to(state=state, obj=target, return_action=True)
        return skill

    # ------------------------------------------------------------------ generation

    def generate_rlang_objects(self, state):
        """
        Build the RLang objects and predicates for the objects present in ``state``.

        Naming:
        - Doors, keys, balls and boxes are ``<color>_<type>``.
        - The goal tile is ``goal`` and the agent is ``agent``.
        - A carried key, ball or box is named the same way as one on the grid.
        Predicates ``carrying``, ``carrying_something``, ``reachable``, ``at`` and the plan ``go_to`` are added
        as well.
        """
        image = state[0]
        rlang_classes = self.knowledge.classes()
        rlang_objects = {}

        width, height, channels = image.shape
        assert channels == 3

        for i in range(width):
            for j in range(height):
                type_idx, color_idx, _ = image[i, j]
                obj_type = IDX_TO_OBJECT[type_idx]
                color = IDX_TO_COLOR[color_idx]

                if obj_type == 'door':
                    obj_name = f"{color}_door"
                    rlang_objects[obj_name] = MDPObjectGrounding(
                        obj=rlang_classes['Door'](name=obj_name,
                                                  x=self._typecolor_feature(type_idx, color_idx, 'x'),
                                                  y=self._typecolor_feature(type_idx, color_idx, 'y'),
                                                  color=ConstantGrounding(value=color_idx, codomain=Domain.REAL_VALUE),
                                                  is_open=self._typecolor_feature(type_idx, color_idx, 'is_open'),
                                                  is_locked=self._typecolor_feature(type_idx, color_idx, 'is_locked')),
                        domain=Domain.STATE)
                elif obj_type in ('key', 'ball', 'box', 'goal'):
                    if obj_type == 'goal':
                        obj_name, class_name = 'goal', 'GoalTile'
                    else:
                        obj_name, class_name = f"{color}_{obj_type}", obj_type.capitalize()
                    rlang_objects[obj_name] = self._positional_object(rlang_classes[class_name], obj_name,
                                                                      type_idx, color_idx)

        rlang_objects['agent'] = MDPObjectGrounding(obj=rlang_classes['Agent'](name="agent",
                                                                                x=Feature(lambda state, action=None: state[2][0], name="x"),
                                                                                y=Feature(lambda state, action=None: state[2][1], name="y"),
                                                                                dir=Feature(lambda state, action=None: state[2][2], name="dir")),
                                                    domain=Domain.STATE)

        if state[1][0] != 1:  # The agent is carrying something, either a key or a ball
            type_idx, color_idx, _ = state[1]
            obj_type = IDX_TO_OBJECT[type_idx]
            obj_name = f"{IDX_TO_COLOR[color_idx]}_{obj_type}"
            rlang_objects[obj_name] = self._positional_object(rlang_classes[obj_type.capitalize()], obj_name,
                                                              type_idx, color_idx)

        rlang_objects['carrying'] = Predicate(self.carrying, name="carrying")
        rlang_objects['carrying_something'] = Predicate(self.carrying_something, name="carrying_something")
        rlang_objects['reachable'] = Predicate(self.reachable, name="reachable")
        rlang_objects['go_to'] = Plan(function=self.go_to, name="go_to")
        rlang_objects['at'] = Predicate(self.at, name="at")

        return rlang_objects

    def generate_skill_dict(self, this_state):
        """
        Create one go_to skill per door, key, ball, box and goal in the image of ``this_state``.

        Skills are numbered from 7 upwards in image scan order. The lower indices are presumably MiniGrid's
        primitive actions — TODO confirm with the caller. Inventory items are not included: the agent is assumed
        to hold nothing.

        Also stores ``self.skill_names_reversed`` (skill name -> index). Besides ``go_to(<color>_goal)``, the goal
        is registered as ``go_to(goal)``, the name ``generate_rlang_objects`` gives it, so that
        ``go_to(..., return_action=False)`` can find it.

        Returns:
            ``(skill_dict, skill_names)``: index -> callable ``skill(state)``, and index -> skill name.
        """
        skill_dict = {}
        skill_names = {}
        skill_names_reversed = {}
        next_index = 7

        image = this_state[0]
        width, height, channels = image.shape
        assert channels == 3

        for i in range(width):
            for j in range(height):
                type_idx, color_idx, _ = image[i, j]
                type_idx = int(type_idx)
                color_idx = int(color_idx)
                obj_type = IDX_TO_OBJECT[type_idx]
                color = IDX_TO_COLOR[color_idx]
                if obj_type not in ('key', 'ball', 'box', 'goal', 'door'):
                    continue

                skill_name = f"go_to({color}_{obj_type})"
                skill_dict[next_index] = self._make_go_to_skill(type_idx, color_idx)
                skill_names[next_index] = skill_name
                skill_names_reversed[skill_name] = next_index
                if obj_type == 'goal':
                    skill_names_reversed["go_to(goal)"] = next_index
                next_index += 1

        self.skill_names_reversed = skill_names_reversed

        return skill_dict, skill_names

    # ------------------------------------------------------------------ predicates and plans

    def carrying(self, obj, state, **kwargs):
        """
        Return whether the agent carries ``obj``.

        ``obj`` may be a class (any carried instance of it matches) or an object instance (type and color must
        match). A ``None`` state carries nothing.
        """
        if state is None:
            return False
        state = self._normalize_state(state)

        if state[1][0] == 1:
            return False
        if isinstance(obj, type):
            return state[1][0] == self._class_name_to_idx(obj.__name__)

        type_idx, color_idx, _ = state[1]
        # Objects may arrive wrapped in their grounding, depending on whether they were instantiated in an
        # RLang file or in Python.
        if isinstance(obj, MDPObjectGrounding):
            obj = obj.obj
        if type_idx != self._class_name_to_idx(type(obj).__name__):
            return False
        obj_color = obj.color(state=state) if isinstance(obj.color, Feature) else obj.color
        return color_idx == obj_color

    def carrying_something(self, state, **kwargs):
        """Return whether the inventory is non-empty (type index other than 1)."""
        state = self._normalize_state(state)
        return state[1][0] != 1

    def reachable(self, obj, state, **kwargs):
        """Return whether the agent already faces ``obj`` or go_to can produce an action towards it."""
        if self.at(obj, state, **kwargs):
            return True
        return self.go_to(obj, state, **kwargs) is not None

    def at(self, obj, state, **kwargs):
        """
        Return whether the cell in front of the agent holds ``obj``.

        For a class, the cell's type is compared, and an off-grid cell gives False. For an instance, the cell is
        compared with the object's (possibly Feature-valued) ``x``/``y``.
        """
        state = self._normalize_state(state)
        agent_x, agent_y, agent_dir = state[2]
        dx, dy = self._DIR_TO_VEC[agent_dir]
        front_x = int(agent_x) + dx
        front_y = int(agent_y) + dy

        if isinstance(obj, type):
            image = state[0]
            if front_x < 0 or front_x >= len(image) or front_y < 0 or front_y >= len(image[0]):
                return False
            return image[front_x][front_y][0] == self._class_name_to_idx(obj.__name__)

        obj_x, obj_y = self._resolve_position(obj, state)
        return front_x == obj_x and front_y == obj_y

    def go_to(self, obj, state, return_action=True, **kwargs):
        """
        Give one action towards facing ``obj``. The agent is not moved onto it, even for the goal.

        Args:
            obj: target with ``x``/``y`` attributes. These may be plain numbers or Features evaluated on
                ``state``. When ``return_action`` is False, ``obj`` needs a ``name``.
            state: current state.
            return_action: if False, return the index of skill ``go_to(<obj.name>)`` instead. The skill table is
                built from ``state`` on first use.

        Returns:
            0 (rotate counterclockwise), 1 (rotate clockwise) or 2 (forward). Returns None if ``shortest_path``
            finds no path, the path is empty, or the final step is not walkable.
        """
        if not return_action:
            # Only the index of the matching go_to skill is wanted; build the skill table lazily.
            if self.skill_names_reversed is None:
                self.generate_skill_dict(state)
            return self.skill_names_reversed[f"go_to({obj.name})"]

        state = self._normalize_state(state)
        agent_x, agent_y, agent_dir = tuple(state[2])
        agent_dir = int(agent_dir)
        # RLang objects carry Feature-valued coordinates; evaluate them so shortest_path gets plain numbers.
        target = self._resolve_position(obj, state)

        path = shortest_path(state[0], (agent_x, agent_y), target)
        if path is None or len(path) == 0:
            return None

        step = path[0]
        wanted_dir = self._VEC_TO_DIR[step]

        if agent_dir != wanted_dir:
            # Clockwise quarter turns needed (1, 2 or 3). Only 3 is shorter counterclockwise.
            quarter_turns = (wanted_dir - agent_dir) % 4
            return self._ROTATE_CCW if quarter_turns == 3 else self._ROTATE_CW

        # Facing the right way. Mid-path, just walk forward.
        if len(path) != 2:
            return self._FORWARD
        # Last step: only walk forward if the cell ahead is walkable. That means an empty cell (1), a goal (8) or
        # a door (4) whose state channel is 0 (open); otherwise give up.
        ahead = state[0][agent_x + step[0]][agent_y + step[1]]
        if ahead[0] in (1, 8, 4) and ahead[2] == 0:
            return self._FORWARD
        return None

    # ------------------------------------------------------------------ model predictions

    def get_state_from_complete_predictions(self, state, predictions):
        """
        Build the state implied by attribute predictions. ``state`` is deep-copied, never modified.

        Args:
            state: ``(image, inventory, agent_info)`` triple.
            predictions: maps attribute references to dicts. Each reference exposes ``grounding.obj`` and
                ``attribute_chain``. The first key of a dict is the predicted value; an empty dict means no
                prediction.

        Returns:
            ``(image, inventory, agent_info)``, where inventory and agent_info are tuples. The agent's x / y / dir
            predictions are applied.

        Raises:
            NotImplementedError: for a prediction of a door's ``is_open`` attribute.

        Predictions for objects other than the agent, and other door attributes, are not applied yet (TODO).
        """
        image, inventory, agent_info = copy.deepcopy(state)
        inventory = tuple(inventory)
        agent_info = list(agent_info)
        agent_slots = (('x', 0), ('y', 1), ('dir', 2))

        for attribute_ref, distribution in predictions.items():
            if distribution == {}:
                continue
            value = int(next(iter(distribution)))
            obj = attribute_ref.grounding.obj
            chain = attribute_ref.attribute_chain

            if obj.name == 'agent':
                for attribute, slot in agent_slots:
                    if chain == [attribute]:
                        agent_info[slot] = value
            elif obj.name.endswith('_door') or type(obj).__name__ == 'Door':
                if chain == ['is_open']:
                    raise NotImplementedError
                # TODO: Modify door state in the image.
            # TODO: Extend this for objects other than the agent; their predictions are ignored for now.

        return image, inventory, tuple(agent_info)


class StateFeaturizer:
    """
    Track MiniGrid objects across observation images and expose them as RLang objects and predicates.

    Objects are kept in ``object_directory``, which maps an integer id to a
    ``(class_idx, x, y, color_idx, state_idx)`` tuple. ``update_objects`` re-matches the directory against each
    new image. An object therefore keeps its id, and the RLang object ``generate_objects`` built for it, when it
    moves or changes state. Objects that cannot be found keep their last known entry and are reported as missing.
    """

    # Direction index -> (dx, dy): 0 right (+x), 1 down (+y), 2 left (-x), 3 up (-y).
    _DIR_TO_VEC = ((1, 0), (0, 1), (-1, 0), (0, -1))
    # MiniGrid object indices used directly below.
    _AGENT_IDX = 10
    _DOOR_IDX = 4

    def __init__(self, state_tuple, knowledge):
        """
        Args:
            state_tuple: observation; its first entry is the image and its second is stored as ``inv``.
            knowledge: RLang knowledge base providing the classes used by ``generate_objects``.

        The object directory starts empty; call ``generate_objects`` to populate it.
        """
        # TODO: the constructor may need to stop taking a state.
        self.state = state_tuple[0]
        self.inv = state_tuple[1]
        self.knowledge = knowledge
        self.object_directory = {}
        # RLang class name -> MiniGrid object index, for every class this featurizer tracks.
        self.class_names_to_add = {'Lava': 9, 'Key': 5, 'Door': 4, 'Box': 7, 'Agent': 10, 'GoalTile': 8, 'Wall': 2, 'Ball': 6}

    # ------------------------------------------------------------------ object tracking

    def _find_closest_object(self, x, y, class_idx, color_idx, candidate_ids):
        """Return the id in ``candidate_ids`` closest to (x, y) by Manhattan distance with the same class and color, or -1."""
        best_id = -1
        best_distance = float('inf')
        for obj_id in candidate_ids:
            obj_class, obj_x, obj_y, obj_color, _ = self.object_directory[obj_id]
            if obj_class == class_idx and obj_color == color_idx:
                distance = abs(x - obj_x) + abs(y - obj_y)
                if distance < best_distance:
                    best_id, best_distance = obj_id, distance
        return best_id

    def update_objects(self, state):
        """
        Re-match ``object_directory`` against ``state`` and remember it as ``self.state``.

        ``state`` is an image array indexed ``state[x][y] -> (class, color, state)``.

        1. An entry whose cell still holds the same class, color and state is kept. If the cell has the same
           class and color but a new state (e.g. a door that opened), only the entry's state is updated.
        2. Every other entry is missing. Each cell holding a tracked class that no entry explained is given to
           the closest missing entry of the same class and color, which is moved there. Cells with no candidate
           are reported with a print and left untracked.

        Returns:
            list of ids still missing afterwards. ``in_inventory`` treats these as carried.
        """
        self.state = state
        width, height = state.shape[0], state.shape[1]
        # Cells already explained by a directory entry. A mask is used instead of overwriting a copy of the image
        # with -5: on uint8 images that value wraps to 251, and NumPy >= 2 rejects it with OverflowError.
        claimed = np.zeros((width, height), dtype=bool)
        missing_obj_inds = []

        for obj_id, (class_idx, x, y, color_idx, state_idx) in list(self.object_directory.items()):
            cell = state[x][y]
            expected = np.array([class_idx, color_idx, state_idx])
            if claimed[x][y]:
                # Another entry already explains this cell.
                missing_obj_inds.append(obj_id)
            elif np.all(cell == expected):
                claimed[x][y] = True
            elif np.all(cell[:2] == expected[:2]):
                self.object_directory[obj_id] = (class_idx, x, y, color_idx, int(cell[2]))
                claimed[x][y] = True
            else:
                # Objects do not change color, so the object has moved or disappeared.
                missing_obj_inds.append(obj_id)

        tracked_classes = set(self.class_names_to_add.values())
        for i in range(width):
            for j in range(height):
                if claimed[i][j]:
                    continue
                cell_class, cell_color, cell_state = state[i][j]
                if int(cell_class) not in tracked_classes:
                    continue
                obj_id = self._find_closest_object(i, j, int(cell_class), int(cell_color), missing_obj_inds)
                if obj_id == -1:
                    # TODO: spawn a new object when one appears.
                    print("An object has appeared!!")
                    continue
                self.object_directory[obj_id] = (int(cell_class), i, j, int(cell_color), int(cell_state))
                claimed[i][j] = True
                missing_obj_inds.remove(obj_id)

        return missing_obj_inds

    def get_object_from_id(self, object_id, state=None):
        """Return the directory entry of ``object_id``, after re-matching against ``state`` when one is given."""
        if state is not None:
            self.update_objects(state)
        return self.object_directory[object_id]

    def get_objects_by_class(self, class_id):
        """Return every directory entry whose class index is ``class_id``."""
        return [entry for entry in self.object_directory.values() if entry[0] == class_id]

    def _cell_in_front_of_agent(self):
        """Return the (x, y) cell faced by the tracked agent; its direction is kept in the state slot."""
        _, agent_x, agent_y, _, agent_dir = self.get_objects_by_class(self._AGENT_IDX)[0]
        dx, dy = self._DIR_TO_VEC[agent_dir]
        return agent_x + dx, agent_y + dy

    # ------------------------------------------------------------------ predicates

    def is_on_a(self, obj, cls, state, **kwargs):
        """
        Return whether the agent shares a cell with some tracked object of class ``cls``.

        Raises:
            NotImplementedError: if ``obj`` is not the agent.
        """
        if state is not None:
            self.update_objects(state)
        if obj.name != 'agent':
            raise NotImplementedError(f"is_on_a is only implemented for the agent, not {obj.name!r}")
        agent = self.get_objects_by_class(self._AGENT_IDX)[0]

        if isinstance(cls, MDPClassGrounding):
            cls = cls()

        cls_id = self.class_names_to_add[cls.__name__]
        return any(other[1] == agent[1] and other[2] == agent[2] for other in self.get_objects_by_class(cls_id))

    def in_inventory(self, obj, state, **kwargs):
        """
        Return whether an object of ``obj``'s class and color is missing from the image, i.e. carried.

        When ``state`` is None, the last state seen (``self.state``) is used.
        """
        missing_ids = self.update_objects(state if state is not None else self.state)
        for obj_id in missing_ids:
            entry = self.object_directory[obj_id]
            if entry[0] == self.class_names_to_add[type(obj).__name__] and entry[3] == obj.color:
                return True
        return False

    def at(self, obj, state, **kwargs):
        """Return whether the agent faces the cell ``(obj.x, obj.y)``."""
        if state is not None:
            self.update_objects(state)
        front_x, front_y = self._cell_in_front_of_agent()
        return bool(front_x == obj.x and front_y == obj.y)

    def at_any(self, cls, state, **kwargs):
        """Return whether the agent faces any tracked object of class ``cls``; untracked classes give False."""
        if state is not None:
            self.update_objects(state)
        front_x, front_y = self._cell_in_front_of_agent()

        if cls.__name__ not in self.class_names_to_add:
            return False
        candidates = self.get_objects_by_class(self.class_names_to_add[cls.__name__])
        return any(entry[1] == front_x and entry[2] == front_y for entry in candidates)

    # ------------------------------------------------------------------ generation

    def _tracked_feature(self, obj_id, field, name):
        """Feature reading slot ``field`` of directory entry ``obj_id``, re-matched on the evaluated state."""
        return Feature(lambda state, action=None: self.get_object_from_id(obj_id, state)[field], name=name)

    def generate_objects(self):
        """
        Populate ``object_directory`` from ``self.state`` and return the RLang objects and predicates.

        Naming:
        - Doors are ``<color>_door``, with ``_2``, ``_3``... suffixes for duplicates.
        - The agent is ``agent``; lava tiles are ``lava_<n>`` (n from 0).
        - Keys, balls and boxes are ``<color>_<class>``; the goal tile is ``goal``.
        - Walls are tracked but get no RLang object.
        Every object's dict key equals its RLang name, because ``get_state_from_complete_predictions`` looks
        objects up by name. The predicates ``is_on_a``, ``at``, ``at_any`` and ``in_inventory`` are added too.
        """
        all_classes = self.knowledge.classes()
        classes_to_add = {name: all_classes[name] for name in self.class_names_to_add}
        rlang_objects = {}
        object_directory = {}

        obj_id = 0
        for class_name, rlang_cls in classes_to_add.items():
            class_idx = self.class_names_to_add[class_name]
            found_positions = []
            for instance_idx in range(500):  # Assuming a maximum of 500 object instances per class
                x, y, c, s = find_obj_by_id(state=self.state, id=class_idx, that_isnt=found_positions)
                if x == -1 or y == -1:
                    # The next query would use the same exclusions, so (assuming find_obj_by_id is deterministic)
                    # nothing more of this class can be found.
                    break
                found_positions.append((x, y))
                object_directory[obj_id] = (class_idx, x, y, c, s)

                if class_name == 'Door':
                    obj_name = f"{IDX_TO_COLOR[c]}_door"
                    suffix = 2
                    while obj_name in rlang_objects:
                        obj_name = f"{IDX_TO_COLOR[c]}_door_{suffix}"
                        suffix += 1
                    rlang_objects[obj_name] = MDPObjectGrounding(obj=rlang_cls(name=obj_name,
                                                                               x=self._tracked_feature(obj_id, 1, "x"),
                                                                               y=self._tracked_feature(obj_id, 2, "y"),
                                                                               color=ConstantGrounding(value=c, codomain=Domain.REAL_VALUE),
                                                                               is_open=self._tracked_feature(obj_id, 4, "is_open") == 0,
                                                                               is_locked=self._tracked_feature(obj_id, 4, "is_locked") == 2),
                                                                 domain=Domain.STATE)
                elif class_name == 'Agent':
                    rlang_objects["agent"] = MDPObjectGrounding(obj=rlang_cls(name="agent",
                                                                              x=self._tracked_feature(obj_id, 1, "x"),
                                                                              y=self._tracked_feature(obj_id, 2, "y"),
                                                                              dir=self._tracked_feature(obj_id, 4, "dir")),
                                                                domain=Domain.STATE)
                elif class_name == 'Lava':
                    lava_name = f"lava_{instance_idx}"
                    rlang_objects[lava_name] = MDPObjectGrounding(obj=rlang_cls(name=lava_name,
                                                                                x=self._tracked_feature(obj_id, 1, "x"),
                                                                                y=self._tracked_feature(obj_id, 2, "y")),
                                                                  domain=Domain.STATE)
                elif class_name in ('Key', 'Ball', 'Box', 'GoalTile'):
                    obj_name = 'goal' if class_name == 'GoalTile' else f"{IDX_TO_COLOR[c]}_{class_name.lower()}"
                    rlang_objects[obj_name] = MDPObjectGrounding(obj=rlang_cls(name=obj_name,
                                                                               x=self._tracked_feature(obj_id, 1, "x"),
                                                                               y=self._tracked_feature(obj_id, 2, "y"),
                                                                               color=ConstantGrounding(value=c, codomain=Domain.REAL_VALUE)),
                                                                 domain=Domain.STATE)
                obj_id += 1

        self.object_directory = object_directory

        rlang_objects['is_on_a'] = Predicate(self.is_on_a, name="is_on_a")
        rlang_objects['at'] = Predicate(self.at, name="at")
        rlang_objects['at_any'] = Predicate(self.at_any, name="at_any")
        rlang_objects['in_inventory'] = Predicate(self.in_inventory, name="in_inventory")

        return rlang_objects

    # ------------------------------------------------------------------ model predictions

    def get_state_from_complete_predictions(self, state, predictions):
        """
        Apply attribute predictions to a deep copy of the image ``state``.

        Args:
            state: image array indexed ``state[x][y]``.
            predictions: maps attribute references (exposing ``grounding.obj`` and ``attribute_chain``) to dicts
                whose first key is the predicted value; an empty dict means "no prediction".

        For each prediction, the object's old cell is vacated: it becomes empty, or an open door if a tracked
        door is there. The object is then written to its predicted cell with its predicted color or state. The
        current object is presumably obtained by evaluating ``self.knowledge[<name>]`` on ``state`` — confirm.

        Returns:
            the new image, or None if a prediction targets an attribute other than x, y, color, is_open or
            is_locked (e.g. dir).
        """
        new_state = copy.deepcopy(state)

        for attribute_ref, distribution in predictions.items():
            if distribution == {}:
                continue
            value = int(next(iter(distribution)))
            grounded_obj = attribute_ref.grounding.obj
            old_object = self.knowledge[grounded_obj.name](state=state)

            x, y = old_object.x, old_object.y
            color = state[x][y][1]
            cell_state = state[x][y][2]

            # Vacate the old cell; a tracked door there is restored as an open door (state 0).
            replacement = [1, 0, 0]
            for _, door_x, door_y, door_color, _ in self.get_objects_by_class(self._DOOR_IDX):
                if door_x == x and door_y == y:
                    replacement = [4, door_color, 0]
                    break
            new_state[x][y] = replacement

            chain = attribute_ref.attribute_chain
            if chain == ['x']:
                x = value
            elif chain == ['y']:
                y = value
            elif chain == ['color']:
                color = value
            elif chain == ['is_open']:
                cell_state = 0 if value else 1
            elif chain == ['is_locked']:
                cell_state = 2 if value else 1
            else:
                return None

            obj_class_id = self.class_names_to_add[type(grounded_obj).__name__]
            new_state[x][y] = [obj_class_id, color, cell_state]

        return new_state


def get_primitives_for(state_tuple, env=None):
    """
    Return the names of all RLang primitives available for ``state_tuple``.

    The stable knowledge base is extended with the objects and predicates that SmartStateFeaturizer generates
    for this state. ``env`` is accepted for interface compatibility and is not used.
    """
    knowledge = get_stable_knowledge()
    generated = SmartStateFeaturizer(knowledge=knowledge).generate_rlang_objects(state_tuple)
    knowledge.update(generated)
    primitive_names = list(knowledge.keys())
    return primitive_names


def get_knowledge_from_file(state_tuple, filename="rlang_advice/minigrid.rlang", env=None, show_image=False):
    """
    Parse an RLang advice file on top of the stable knowledge plus the objects generated for ``state_tuple``.

    Args:
        state_tuple: initial ``(image, inventory, agent_info)`` state used to generate the RLang objects.
        filename: path of the RLang file to parse.
        env: environment, used only for debugging output when ``show_image`` is True.
        show_image: print the knowledge keys. When ``env`` is given, also print the mission and show a
            rendering of the environment.

    Returns:
        ``(knowledge, statefeaturizer)``: the parsed knowledge and the featurizer whose predicates it uses.
    """
    knowledge = get_stable_knowledge()
    statefeaturizer = SmartStateFeaturizer(knowledge=knowledge)
    knowledge.update(statefeaturizer.generate_rlang_objects(state_tuple))
    if show_image:
        print(list(knowledge.keys()))
        # The mission and the rendering come from the environment; skip them instead of crashing when none was given.
        if env is not None:
            print("Mission:", env.env.mission)
            img = env.env.render()
            plt.imshow(img)
            plt.show()
    knowledge = rlang.parse_file(filename, knowledge)
    return knowledge, statefeaturizer
