import numpy as np
import rlang
from rlang.grounding import ActionReference, MDPObject, MDPObjectGrounding, Feature, ConstantGrounding, Domain, ParameterizedAction, Predicate, Plan
# from utils import shortest_path
from environment.utils import can_perform_action
import os
import json
from environment.utils import generate_all_available_actions

# Directory containing this module; restrict_dict.json is read from here.
curr_dir = os.path.dirname(os.path.realpath(__file__))

# Restriction table handed to generate_all_available_actions by the drop and
# can_drop groundings in this module.
restriction_dict_path = f'{curr_dir}/restrict_dict.json'
# JSON text is UTF-8 (RFC 8259); pin the encoding so loading does not depend on
# the platform's locale default.
with open(restriction_dict_path, 'r', encoding='utf-8') as f:
    restriction_dict = json.load(f)


def _walk_to(obj, state, **kwargs):
    """Ground the ``walk_to`` action for *obj*.

    Delegates to ``can_perform_action('walk', ...)`` on the first graph in
    ``state.dict_state.data`` with ``teleport=False`` and returns its result
    unchanged. The literal ``1`` is presumably the acting character's id --
    confirm against ``can_perform_action``.
    """
    target_id = obj.id
    graph = state.dict_state.data[0]
    return can_perform_action('walk', target_id, 1, graph, teleport=False)

def _open(obj, state, **kwargs):
    """Ground the ``open`` action for *obj*.

    Delegates to ``can_perform_action('open', ...)`` on the first graph in
    ``state.dict_state.data`` with ``teleport=False`` and returns its result
    unchanged.
    """
    target_id = obj.id
    graph = state.dict_state.data[0]
    return can_perform_action('open', target_id, 1, graph, teleport=False)

def _close(obj, state, **kwargs):
    """Ground the ``close`` action for *obj*.

    Delegates to ``can_perform_action('close', ...)`` on the first graph in
    ``state.dict_state.data`` with ``teleport=False`` and returns its result
    unchanged.
    """
    target_id = obj.id
    graph = state.dict_state.data[0]
    return can_perform_action('close', target_id, 1, graph, teleport=False)

def _putin(obj, obj2=None, state=None, **kwargs):
    """Ground the ``putin`` action.

    The target is *obj2* when it is supplied, otherwise *obj*. Delegates to
    ``can_perform_action('putin', ...)`` on the first graph in
    ``state.dict_state.data`` with ``teleport=False``.
    """
    container = obj if obj2 is None else obj2
    target_id = container.id
    graph = state.dict_state.data[0]
    return can_perform_action('putin', target_id, 1, graph, teleport=False)

def _puton(obj, obj2=None, state=None, **kwargs):
    """Ground the ``puton`` action via the environment's ``putback`` action.

    The target is *obj2* when it is supplied, otherwise *obj*. Delegates to
    ``can_perform_action('putback', ...)`` on the first graph in
    ``state.dict_state.data`` with ``teleport=False``.
    """
    surface = obj if obj2 is None else obj2
    target_id = surface.id
    graph = state.dict_state.data[0]
    return can_perform_action('putback', target_id, 1, graph, teleport=False)

def _drop(obj, state=None, **kwargs):
    """Ground ``drop``: return the first available ``put`` action.

    *obj* is not consulted; the candidates come from
    ``generate_all_available_actions`` restricted to the ``"put"`` high-level
    action. If there are none, the current graph is dumped to
    ``hard_rw_2_failstate.json`` in the working directory for debugging, and
    ``ValueError`` is raised.
    """
    candidates = generate_all_available_actions(
        state=state.dict_state,
        restriction_dict=restriction_dict,
        high_level_actions=["put"],
    )
    if len(candidates) > 0:
        return candidates[0]

    # No put-down target exists: keep a snapshot of the failing graph.
    with open('hard_rw_2_failstate.json', 'w') as dump_file:
        json.dump(state.dict_state.data[0], dump_file)
    raise ValueError("Can't put the object down anywhere")

def _grab(obj, state, **kwargs):
    """Ground the ``grab`` action for *obj*.

    Delegates to ``can_perform_action('grab', ...)`` on the first graph in
    ``state.dict_state.data`` with ``teleport=False`` and returns its result
    unchanged.
    """
    target_id = obj.id
    graph = state.dict_state.data[0]
    return can_perform_action('grab', target_id, 1, graph, teleport=False)


def _can_drop(obj, state=None, **kwargs):
    """Return True when at least one ``put`` action is currently available.

    *obj* is not consulted. This mirrors the candidate query that ``_drop``
    performs.
    """
    candidates = generate_all_available_actions(
        state=state.dict_state,
        restriction_dict=restriction_dict,
        high_level_actions=["put"],
    )
    return not len(candidates) <= 0


def _inside(obj1, obj2, state, **kwargs):
    """Return True iff the graph has an ``INSIDE`` edge from *obj1* to *obj2*."""
    edges = state.dict_state.data[0]['edges']
    return any(
        edge['from_id'] == obj1.id
        and edge['to_id'] == obj2.id
        and edge['relation_type'] == "INSIDE"
        for edge in edges
    )

def _on(obj1, obj2, state, **kwargs):
    """Return True iff the graph has an ``ON`` edge from *obj1* to *obj2*."""
    edges = state.dict_state.data[0]['edges']
    return any(
        edge['from_id'] == obj1.id
        and edge['to_id'] == obj2.id
        and edge['relation_type'] == "ON"
        for edge in edges
    )

def _at(obj, state, **kwargs):
    """Return True iff node 1 has a ``CLOSE`` edge to *obj*.

    Node id 1 is presumably the agent/character -- confirm against the scene
    graph.
    """
    edges = state.dict_state.data[0]['edges']
    return any(
        edge['from_id'] == 1
        and edge['to_id'] == obj.id
        and edge['relation_type'] == "CLOSE"
        for edge in edges
    )

def _is_closed(obj, state, **kwargs):
    """Return True iff the first node whose id matches *obj* lists ``CLOSED``.

    Returns False when no node has that id.
    """
    nodes = state.dict_state.data[0]['nodes']
    match = next((node for node in nodes if node['id'] == obj.id), None)
    if match is None:
        return False
    return "CLOSED" in match['states']

def _is_open(obj, state, **kwargs):
    """Return True iff the first node whose id matches *obj* lists ``OPEN``.

    Returns False when no node has that id.
    """
    nodes = state.dict_state.data[0]['nodes']
    match = next((node for node in nodes if node['id'] == obj.id), None)
    if match is None:
        return False
    return "OPEN" in match['states']

def _holding(obj, state, **kwargs):
    """Return True iff some edge ending at *obj* is a hand-hold relation.

    A hand-hold relation is ``HOLDS_RH`` or ``HOLDS_LH``. Every edge is
    inspected, so other relations to *obj* (e.g. ``CLOSE``/``FACING``) no
    longer mask a hold that appears later in the edge list.
    """
    # Left-hand relation is HOLDS_LH (the old literal 'HOLD_LH' never matched).
    hold_relations = ('HOLDS_RH', 'HOLDS_LH')
    return any(
        edge['to_id'] == obj.id and edge['relation_type'] in hold_relations
        for edge in state.dict_state.data[0]['edges']
    )

def _near(obj, state, **kwargs):
    """Return True iff node 1 has a ``CLOSE`` edge to *obj*.

    Every edge from node 1 to *obj* is inspected, so an earlier edge with a
    different relation (e.g. ``FACING``) no longer short-circuits to False.
    Node id 1 is presumably the agent/character -- confirm against the scene
    graph.
    """
    return any(
        edge['from_id'] == 1
        and edge['to_id'] == obj.id
        and edge['relation_type'] == "CLOSE"
        for edge in state.dict_state.data[0]['edges']
    )

def _inside_something(obj, state, **kwargs):
    """Return True iff *obj* has an ``INSIDE`` edge to a node that is not a room.

    Room nodes are those whose ``category`` is ``"Rooms"``. Edges to ids that
    are missing from ``nodes`` count as non-room containers, matching the
    previous lookup's behaviour.
    """
    graph = state.dict_state.data[0]
    # Collect room ids once rather than rescanning the node list per edge.
    room_ids = {
        node['id'] for node in graph['nodes'] if node.get('category') == "Rooms"
    }
    # The container is the edge's *target*; the old code tested the source
    # (obj itself), so containment in a room was never excluded.
    return any(
        edge['from_id'] == obj.id
        and edge['relation_type'] == "INSIDE"
        and edge['to_id'] not in room_ids
        for edge in graph['edges']
    )


def _is_drop(i, **kwargs):
    """Return True when the first entry of ``kwargs['action']`` mentions ``put``.

    *i* is not used. The entry is converted to a Python scalar with ``.item()``
    before the substring test; this presumably expects a numpy string array --
    confirm at the call site.
    """
    first_action = kwargs['action'][0]
    return "put" in first_action.item()


# Predicate groundings. Each `name` matches the key the predicate is
# registered under in get_stable_knowledge.
inside = Predicate(_inside, name='inside')
on = Predicate(_on, name='on')
at = Predicate(_at, name='at')
is_closed = Predicate(_is_closed, name='is_closed')  # previously mis-named 'at'
is_open = Predicate(_is_open, name='is_open')
holding = Predicate(_holding, name='holding')
near = Predicate(_near, name='near')
can_drop = Predicate(_can_drop, name='can_drop')
is_drop = Predicate(_is_drop, name='is_drop')
inside_something = Predicate(_inside_something, name='inside_something')

# Parameterized action groundings.
walk_to = ParameterizedAction(_walk_to, name='walk_to')
open_ = ParameterizedAction(_open, name='open')  # trailing _ avoids shadowing builtin open
close = ParameterizedAction(_close, name='close')
putin = ParameterizedAction(_putin, name='putin')
puton = ParameterizedAction(_puton, name='puton')
grab = ParameterizedAction(_grab, name='grab')
drop = ParameterizedAction(_drop, name='drop')


def obj_class_constructor(name, properties):
    """Create an ``MDPObject`` subclass called ``name.capitalize()``.

    The class carries *properties* as its ``attr_list``. Its ``__init__``
    assigns positional arguments to those attributes in order, then overrides
    with any keyword argument whose name is in ``attr_list``. Surplus
    positional arguments and unknown keywords are ignored.
    """
    generated_name = name.capitalize()
    fields = properties

    def __init__(self, *args, **kwargs):
        for field_name, field_value in zip(fields, args):
            setattr(self, field_name, field_value)
        for field_name in fields:
            if field_name in kwargs:
                setattr(self, field_name, kwargs[field_name])

    namespace = {'attr_list': fields, '__init__': __init__}
    return type(generated_name, (MDPObject,), namespace)

def get_stable_knowledge():
    """Build and return the RLang knowledge object for this domain.

    Registers one ``MDPObject`` subclass per object class (keyed by the
    capitalized class name, with ``name`` and ``id`` attributes), followed by
    the action and predicate groundings defined in this module.
    """
    knowledge = rlang.knowledge.RLangKnowledge()

    classes = ["bookshelf", "fridge", "pie", "microwave", "kitchen", "salmon", "character", "toothpaste", "sofa", "cereal", "cabinet", "remotecontrol", "bathroom", "bedroom", "kitchentable", "livingroom"]
    props = ["name", "id"]

    # dict.fromkeys de-duplicates while keeping list order; iterating a set of
    # strings made registration order depend on the interpreter's hash seed.
    for class_name in dict.fromkeys(classes):
        knowledge.update({class_name.capitalize(): obj_class_constructor(class_name, props)})

    knowledge.update({
        "walk_to": walk_to,
        "open": open_,
        "close": close,
        "putin": putin,
        "puton": puton,
        "grab": grab,
        "drop": drop,
        "can_drop": can_drop,
        "is_drop": is_drop,
        "inside": inside,
        "inside_something": inside_something,
        "on": on,
        "at": at,
        "is_closed": is_closed,
        "is_open": is_open,
        "holding": holding,
        "near": near
    })

    return knowledge
