import numpy as np
import rlang
from rlang.grounding import ActionReference, MDPObject, MDPObjectGrounding, Feature, ConstantGrounding, Domain, ParameterizedAction, Predicate, Plan
from utils import shortest_path

class Agent(MDPObject):
    """MDP object describing the agent.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the agent in the state grid.
        dir: Facing direction of the agent (presumably one of the
            pointing_* direction codes 0-3 -- confirm with callers).
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y', 'dir']

    def __init__(self, name, x, y, dir):
        self.name, self.x, self.y, self.dir = name, x, y, dir

class Wall(MDPObject):
    """MDP object describing a wall cell.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the wall in the state grid.
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y']

    def __init__(self, name, x, y):
        self.name, self.x, self.y = name, x, y

class Door(MDPObject):
    """MDP object describing a door.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the door in the state grid.
        color: Color of the door (presumably an integer color code --
            confirm with callers).
        is_open: Whether the door is open.
        is_locked: Whether the door is locked.
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y', 'color', 'is_open', 'is_locked']

    def __init__(self, name, x, y, color, is_open, is_locked):
        # Identity and placement.
        self.name, self.x, self.y = name, x, y
        # Appearance and door status.
        self.color, self.is_open, self.is_locked = color, is_open, is_locked

class Key(MDPObject):
    """MDP object describing a key.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the key in the state grid.
        color: Color of the key (presumably an integer color code --
            confirm with callers).
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y', 'color']

    def __init__(self, name, x, y, color):
        self.name, self.x, self.y, self.color = name, x, y, color

class Ball(MDPObject):
    """MDP object describing a ball.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the ball in the state grid.
        color: Color of the ball (presumably an integer color code --
            confirm with callers).
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y', 'color']

    def __init__(self, name, x, y, color):
        self.name, self.x, self.y, self.color = name, x, y, color

class Box(MDPObject):
    """MDP object describing a box.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the box in the state grid.
        color: Color of the box (presumably an integer color code --
            confirm with callers).
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y', 'color']

    def __init__(self, name, x, y, color):
        self.name, self.x, self.y, self.color = name, x, y, color

class GoalTile(MDPObject):
    """MDP object describing a goal tile.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the goal tile in the state grid.
        color: Color of the tile (presumably an integer color code --
            confirm with callers).
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y', 'color']

    def __init__(self, name, x, y, color):
        self.name, self.x, self.y, self.color = name, x, y, color

class Lava(MDPObject):
    """MDP object describing a lava cell.

    Attributes:
        name: Identifier for this object.
        x, y: Coordinates of the lava cell in the state grid.
    """

    # Names of the attributes that make up this object.
    attr_list = ['name', 'x', 'y']

    def __init__(self, name, x, y):
        self.name, self.x, self.y = name, x, y


# Primitive actions, each bound to an integer action index (0-6).
# go_to_obj returns these same indices directly: 0 and 1 to rotate, 2 to move forward.
# NOTE(review): the numbering presumably mirrors the environment's own action
# enumeration -- confirm against the environment before changing any index.
left = ActionReference(0, name='left')
right = ActionReference(1, name='right')
forward = ActionReference(2, name='forward')
pickup = ActionReference(3, name='pickup')
drop = ActionReference(4, name='drop')
toggle = ActionReference(5, name='toggle')
done = ActionReference(6, name='done')

# Direction codes as real-valued constants. The encoding (0 = right/+x,
# 1 = down/+y, 2 = left/-x, 3 = up/-y) is the same one go_to_obj and at_ use
# when converting between a direction code and an (dx, dy) step.
pointing_right = ConstantGrounding(codomain=Domain.REAL_VALUE, value=0, name='pointing_right')
pointing_down = ConstantGrounding(codomain=Domain.REAL_VALUE, value=1, name='pointing_down')
pointing_left = ConstantGrounding(codomain=Domain.REAL_VALUE, value=2, name='pointing_left')
pointing_up = ConstantGrounding(codomain=Domain.REAL_VALUE, value=3, name='pointing_up')


def find_obj_by_id(state, id=10, that_isnt=None, has_color=None):
    """Locate the first grid cell that holds an object with the given id.

    The grid is scanned in order of the first index, then the second index,
    and the first cell passing every filter wins.

    Args:
        state: Array-like with a ``shape`` attribute, indexed as
            ``state[i][j][k]``. Channel 0 of a cell is compared against
            ``id``, channel 1 against ``has_color``; channels 1 and 2 are
            returned as integers.
        id: Object id to look for. The default, 10, is the id the other
            helpers in this module use to find the agent.
        that_isnt: Optional container of ``(i, j)`` positions to skip.
        has_color: Optional value that channel 1 of the cell must equal.

    Returns:
        ``(i, j, channel_1, channel_2)`` for the first matching cell, or
        ``(-1, -1, -1, -1)`` when no cell matches.
    """
    # The unconditional debug dump of the whole grid was removed: this
    # lookup runs on every step of the navigation helpers.
    for i in range(state.shape[0]):
        for j in range(state.shape[1]):
            cell = state[i][j]
            if cell[0] != id:
                continue
            if that_isnt is not None and (i, j) in that_isnt:
                continue
            if has_color is not None and int(cell[1]) != has_color:
                continue
            return i, j, int(cell[1]), int(cell[2])
    return -1, -1, -1, -1

# agent_x = Feature(lambda state: find_obj_by_id(state, 10)[0], name='agent_x')
# agent_y = Feature(lambda state: find_obj_by_id(state, 10)[1], name='agent_y')
# agent_dir = Feature(lambda state: find_obj_by_id(state, 10)[3], name='agent_dir')

# goal_x = Feature(lambda state: find_obj_by_id(state, 8)[0], name='goal_x')
# goal_y = Feature(lambda state: find_obj_by_id(state, 8)[1], name='goal_y')


# goal = MDPObjectGrounding(obj=GoalTile("goal", goal_x, goal_y, 1))
# agent = MDPObjectGrounding(obj=Agent("agent", agent_x, agent_y, agent_dir))

def turn_towards(agent_dir, goal_dir):
    """Return the rotation action that moves ``agent_dir`` towards ``goal_dir``.

    Directions are the integer codes 0-3 used throughout this module, where
    each right turn adds one (mod 4) and each left turn subtracts one.

    Args:
        agent_dir: Current direction code of the agent.
        goal_dir: Direction code the agent should end up facing.

    Returns:
        ``None`` when the agent already faces ``goal_dir``; otherwise the
        action index of the turn to take: 1 (right) when the goal is a single
        clockwise step away, 0 (left) in every other case.
    """
    if agent_dir == goal_dir:
        return None
    # Always turning left costs three turns when the goal is one step
    # clockwise; take the single right turn in that case instead.
    if (goal_dir - agent_dir) % 4 == 1:
        return 1
    return 0

def go_to_obj(obj, state):
    """Pick the next primitive action that moves the agent towards ``obj``.

    Args:
        obj: Object exposing ``x`` and ``y`` coordinates in the state grid.
        state: Grid state, as accepted by ``find_obj_by_id`` and
            ``shortest_path``.

    Returns:
        An action index -- 0 (turn left), 1 (turn right) or 2 (forward) --
        or ``None`` when there is no path to the object, or when the final
        step of the path leads into a cell that cannot be entered.
    """
    agent_x, agent_y, _, agent_dir = find_obj_by_id(state, 10)

    # NOTE(review): the path is presumably a sequence of (dx, dy) unit steps;
    # confirm against utils.shortest_path.
    path = shortest_path(state, (agent_x, agent_y), (obj.x, obj.y))
    if path is None:
        # There is no path to the object.
        return None

    step = path[0]
    heading_for_step = {(1, 0): 0, (0, 1): 1, (-1, 0): 2, (0, -1): 3}
    wanted_dir = heading_for_step[step]
    is_last_step = len(path) == 2

    if agent_dir != wanted_dir:
        # Rotate first; how to rotate depends on the number of clockwise
        # quarter turns separating the two headings.
        turn = (wanted_dir - agent_dir) % 4
        if turn == 3:
            return 0  # rotate counterclockwise
        if turn in (1, 2):
            return 1
        return None

    # From here on the agent already faces the direction of the next step.
    if not is_last_step:
        return 2

    # On the last step only move if the cell ahead can be walked on: ids 1,
    # 8 or 4 (empty space, goal, door per the original note) with channel 2
    # equal to 0.
    ahead = state[agent_x + step[0]][agent_y + step[1]]
    if ahead[0] in (1, 8, 4) and ahead[2] == 0:
        return 2
    return None

def at_(obj, state):
    """Tell whether ``obj`` sits in the cell directly in front of the agent.

    Only objects exactly one cell away, in the direction the agent is
    facing, count as "at".

    Args:
        obj: Object exposing ``x`` and ``y`` coordinates in the state grid.
        state: Grid state, as accepted by ``find_obj_by_id``.

    Returns:
        ``True`` when the cell the agent faces has the object's coordinates,
        ``False`` otherwise.
    """
    agent_x, agent_y, _, agent_dir = find_obj_by_id(state, 10)
    target_x, target_y = obj.x, obj.y

    # (dx, dy) step for each direction code:
    # 0 = right (+x), 1 = down (+y), 2 = left (-x), 3 = up (-y).
    DIR_TO_VEC = ((1, 0), (0, 1), (-1, 0), (0, -1))
    dx, dy = DIR_TO_VEC[agent_dir]

    return bool(agent_x + dx == target_x and agent_y + dy == target_y)


# Wrap the helper functions above as named RLang groundings.
# Note that step_towards and go_to are both backed by the same go_to_obj
# function; they differ only in the grounding type that wraps it.
at = Predicate(at_, name='at')
step_towards = ParameterizedAction(go_to_obj, name='step_towards')
go_to = Plan(function=go_to_obj, name='go_to')

def get_stable_knowledge():
    """Build the knowledge base of fixed groundings declared in this module.

    Returns:
        A new ``rlang.knowledge.RLangKnowledge`` updated with the object
        classes, the primitive actions and the direction constants, each
        keyed by its name. The ``at``, ``step_towards`` and ``go_to``
        groundings are not part of it.
    """
    knowledge = rlang.knowledge.RLangKnowledge()

    object_classes = {
        "Agent": Agent,
        "Wall": Wall,
        "GoalTile": GoalTile,
        "Lava": Lava,
        "Key": Key,
        "Door": Door,
        "Box": Box,
        "Ball": Ball,
    }
    actions = {
        "left": left,
        "right": right,
        "forward": forward,
        "pickup": pickup,
        "drop": drop,
        "toggle": toggle,
        "done": done,
    }
    directions = {
        "pointing_right": pointing_right,
        "pointing_down": pointing_down,
        "pointing_left": pointing_left,
        "pointing_up": pointing_up,
    }

    # Single update call with the entries in their original order.
    knowledge.update({**object_classes, **actions, **directions})
    return knowledge
