
import pdb
import copy
import random
import hashlib
import json
from rlang import DictState
from simple_rl.mdp.StateClass import State

def convert_action(action_dict):
    """
    Merge per-agent action strings into one multi-agent script line.

    Args:
        action_dict: Mapping agent_id -> formatted action string (e.g.
            '[grab] <apple> (5)') or None when that agent does nothing.

    Returns:
        A one-element list holding the combined script, with agents joined by
        '|', e.g. ['<char0> [grab] <apple> (5)|<char1> [open] <fridge> (7)'].
        Returns [''] when no agent acts.

    If there is more than one agent, every agent acts, fewer than two actions
    contain 'walk', and all actions reference the same first object id, only
    one randomly chosen agent keeps its action. This stops two agents from
    interacting with the same object in the same step.
    """
    agent_do = [agent_id for agent_id, action in action_dict.items() if action is not None]

    # Make sure only one agent interacts with the same object.
    if len(action_dict) > 1:
        actions = list(action_dict.values())
        if None not in actions and sum('walk' in action for action in actions) < 2:
            # Text inside the first "(...)" of each script is the first object
            # id it mentions (for put actions, the held object).
            object_ids = {action.split('(')[1].split(')')[0] for action in actions}
            if len(object_ids) == 1:
                # Pick among the agents actually present instead of assuming
                # the ids are 0 and 1.
                agent_do = [random.choice(agent_do)]

    parts = ['<char{}> {}'.format(agent_id, action_dict[agent_id]) for agent_id in agent_do]
    return ['|'.join(parts)]


# Number of object arguments each supported action takes. 'run' is included
# because generate_all_available_actions enumerates it alongside 'walk'.
_ACTION_NUM_ARGS = {
    'turnleft': 0,
    'walkforward': 0,
    'turnright': 0,
    'walktowards': 1,
    'open': 1,
    'close': 1,
    'putback': 1,
    'putin': 1,
    'put': 1,
    'grab': 1,
    'no_action': 0,
    'walk': 1,
    'run': 1,
}


def args_per_action(action):
    """
    Return how many object arguments ``action`` takes (0 or 1).

    Raises:
        KeyError: if ``action`` is not a supported action name.
    """
    return _ACTION_NUM_ARGS[action]


def can_perform_action(action, o1_id, agent_id, graph,
                       object_restrictions=None, obj_prop=None, teleport=True):
    """
    Check whether the agent ``agent_id`` can perform ``action`` on node ``o1_id``.

    Args:
        action: Action name. It must be known to ``args_per_action`` (e.g.
            'walk', 'grab', 'open', 'close', 'put', 'putin', 'putback').
            'no_action' is always rejected.
        o1_id: Id of the target node in ``graph``.
        agent_id: Id of the acting character node.
        graph: Dict with a 'nodes' list and an 'edges' list. Nodes read here
            use 'id', 'class_name', 'states' and 'properties'. Edges use
            'from_id', 'to_id' and 'relation_type'.
        object_restrictions: Optional dict with 'objects_inside' and
            'objects_surface' collections of class names. When given, it
            limits which classes may be opened or closed and chooses between
            'putin' and 'putback'. Otherwise the target's 'CONTAINERS'
            property makes that choice.
        obj_prop: Accepted for caller compatibility; not used.
        teleport: When True, every 'walk*' action is emitted as 'walkto'.

    Returns:
        None if the action cannot be performed. Otherwise the formatted
        script string, e.g. '[grab] <apple> (5)' or
        '[putin] <apple> (5) <fridge> (7)'.
    """
    if action == 'no_action':
        return None

    obj2_str = ''
    obj1_str = ''
    id2node = {node['id']: node for node in graph['nodes']}
    o1 = id2node[o1_id]['class_name']
    num_args = 0 if o1 is None else 1
    if num_args != args_per_action(action):
        return None

    # Objects the agent holds in either hand. The left-hand relation is
    # 'HOLDS_LH'; the misspelling 'HOLD_LH' meant left-hand holdings were
    # never seen.
    grabbed_objects = [edge['to_id'] for edge in graph['edges']
                       if edge['from_id'] == agent_id
                       and edge['relation_type'] in ('HOLDS_RH', 'HOLDS_LH')]
    close_edge = any(edge['from_id'] == agent_id and edge['to_id'] == o1_id
                     and edge['relation_type'] == 'CLOSE'
                     for edge in graph['edges'])

    # Grabbing is only allowed with nothing already held.
    if action == 'grab' and grabbed_objects:
        return None

    # The agent cannot walk to an object it is carrying.
    if action.startswith('walk') and o1_id in grabbed_objects:
        return None

    if o1_id == agent_id:
        return None

    # Manipulation requires a CLOSE edge from the agent to the target.
    if action in ('grab', 'open', 'close') and not close_edge:
        return None

    if action in ('open', 'close'):
        if (object_restrictions is not None
                and o1 not in object_restrictions['objects_inside']):
            return None
        states = id2node[o1_id]['states']
        # 'open' needs CLOSED and not OPEN; 'close' needs the reverse.
        required, forbidden = ('CLOSED', 'OPEN') if action == 'open' else ('OPEN', 'CLOSED')
        if forbidden in states or required not in states:
            return None

    if 'put' in action:
        if not grabbed_objects:
            return None
        # The first held object is the one being placed.
        o2_id = grabbed_objects[0]
        if o2_id == o1_id:
            return None
        o2 = id2node[o2_id]['class_name']
        obj2_str = f'<{o2}> ({o2_id})'

    if o1 is not None:
        obj1_str = f'<{o1}> ({o1_id})'

    # Characters are never valid targets.
    if o1 == 'character':
        return None

    if action.startswith('put'):
        if object_restrictions is not None:
            if o1 in object_restrictions['objects_inside']:
                action = 'putin'
            # Checked second, so a class in both lists ends up as 'putback'.
            if o1 in object_restrictions['objects_surface']:
                action = 'putback'
        elif "CONTAINERS" in id2node[o1_id]['properties']:
            action = 'putin'
        else:
            action = 'putback'

    if action.startswith('walk') and teleport:
        action = 'walkto'

    return f'[{action}] {obj2_str} {obj1_str}'.strip()


def generate_all_available_actions(state, restriction_dict,
                                   high_level_actions=('walk', 'open', 'close', 'put', 'grab'),
                                   agent_id=1):
    """
    List every formatted action the agent can currently perform.

    Args:
        state: Object whose ``data[0]`` is the environment graph, a dict with
            'nodes' and 'edges' lists.
        restriction_dict: Mapping with 'GRABBABLE', 'CAN_OPEN', 'CONTAINERS'
            and 'SURFACES' collections of class names. These limit which
            objects each action is tried on.
        high_level_actions: Action names to enumerate. 'walk' and 'run' are
            tried on every node. 'grab', 'open', 'close' and names starting
            with 'put' are tried only on nodes the agent has a CLOSE edge to.
            Other names produce nothing.
        agent_id: Id of the acting character node. Defaults to 1, the value
            this function always used.

    Returns:
        List of unique action strings, in first-seen order, accepted by
        ``can_perform_action`` with ``teleport=False``.
    """
    state_graph = state.data[0]
    # Nodes with a falsy id (e.g. 0) are skipped.
    nodes_by_id = {node['id']: node for node in state_graph['nodes'] if node['id']}
    close_ids = {edge['to_id'] for edge in state_graph['edges']
                 if edge['from_id'] == agent_id and edge['relation_type'] == 'CLOSE'}

    def _close_ids_of_class(class_names):
        # Ids of nodes near the agent whose class is in class_names.
        return [node_id for node_id, node in nodes_by_id.items()
                if node['class_name'] in class_names and node_id in close_ids]

    all_possible_actions = []

    def _try(action_name, object_ids, **kwargs):
        # Append every action can_perform_action accepts for these targets.
        for object_id in object_ids:
            action_str = can_perform_action(action_name, object_id, agent_id, state_graph,
                                            teleport=False, **kwargs)
            if action_str is not None:
                all_possible_actions.append(action_str)

    for action_name in high_level_actions:
        if action_name.startswith('put'):
            # Container candidates first, then surface candidates. They are
            # handled completely here; falling through to the generic pass
            # used to add every surface candidate a second time.
            for obj_type in ('CONTAINERS', 'SURFACES'):
                _try(action_name, _close_ids_of_class(restriction_dict[obj_type]),
                     obj_prop=obj_type)
        elif action_name == 'grab':
            _try(action_name, _close_ids_of_class(restriction_dict['GRABBABLE']))
        elif action_name in ('walk', 'run'):
            _try(action_name, list(nodes_by_id))
        elif action_name in ('open', 'close'):
            _try(action_name, _close_ids_of_class(restriction_dict['CAN_OPEN']))

    # A class listed in both CONTAINERS and SURFACES yields the same string
    # twice; keep one copy.
    return list(dict.fromkeys(all_possible_actions))

def state_hash_fn(state, is_unwrapped_state_graph=False):
    """
    Return a short, deterministic hash of a state graph.

    A ``DictState`` is unwrapped to ``dict_state[0]`` and a ``State`` to
    ``data[0]``; any other value is hashed as-is. The value is serialized with
    ``json.dumps(sort_keys=True)``, so dict key order does not change the
    result, and the JSON text is hashed with SHA-256.

    Args:
        state: A DictState, a State, or an already unwrapped JSON-serializable
            graph.
        is_unwrapped_state_graph: Unused; kept for caller compatibility.

    Returns:
        The first 16 hex characters (64 bits) of the SHA-256 digest.

    Raises:
        TypeError: if the unwrapped value is not JSON-serializable.
    """
    # isinstance rather than exact type equality, so that subclasses are
    # unwrapped too instead of being passed whole to json.dumps.
    if isinstance(state, DictState):
        state = state.dict_state[0]
    elif isinstance(state, State):
        state = state.data[0]

    payload = json.dumps(state, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()[:16]
