import virtualhome
from simple_rl.run_experiments import run_agents_on_mdp
import agents
from environment import SimpleRLVirtualHomeEnv
from environment import utils
import os
import json
from dynamic_grounding import VHStateFeaturizer, get_knowledge_from_file
from grounding import get_stable_knowledge

# Absolute directory of this script (os.path.realpath resolves symlinks), so the
# data file below is found regardless of the current working directory.
curr_dir = os.path.dirname(os.path.realpath(__file__))

# JSON table loaded at import time from object_action_info.json next to this
# script (presumably per-object action restrictions, judging by the name — confirm).
# NOTE(review): main() below loads a different file, restrict_dict.json, from the
# CWD and never references this module-level restriction_dict; confirm whether
# other modules import it before removing or consolidating the two.
restriction_dict_path = f'{curr_dir}/object_action_info.json'
with open(restriction_dict_path, 'r') as f:
    restriction_dict = json.load(f)

# Switch between environments by changing `observation_types` in main().
# The simple and medium reward functions work with both the full_trimmed and full_trimmed_large observation types.
def simple_reward_fn_1(state, action, next_state):
    """Reward placing pie 319 inside fridge 305.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with an ``'edges'`` list.

    Returns:
        A ``(reward, done)`` tuple: ``(5, True)`` once the pie has an INSIDE
        edge to the fridge, otherwise ``(0, False)``. ``done`` ends the episode.
    """
    graph = next_state.data[0]
    fridge_contents = [
        e['from_id']
        for e in graph['edges']
        if e['to_id'] == 305 and e['relation_type'] == 'INSIDE'
    ]
    if 319 not in fridge_contents:
        return 0, False  # goal not reached yet; keep the episode running
    return 5, True  # goal reached; terminate the episode

def simple_reward_fn_2(state, action, next_state):
    """Reward placing pie 319 inside fridge 305 with the fridge closed.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with ``'edges'`` and ``'nodes'`` lists.

    Returns:
        ``(5, True)`` when the pie is INSIDE the fridge and the fridge node's
        ``states`` contain ``'CLOSED'``; otherwise ``(0, False)``.

    Raises:
        IndexError: if the graph has no node with id 305.
    """
    graph = next_state.data[0]
    fridge_contents = [
        e['from_id']
        for e in graph['edges']
        if e['to_id'] == 305 and e['relation_type'] == 'INSIDE'
    ]
    fridge_states = [n["states"] for n in graph['nodes'] if n['id'] == 305][0]
    fridge_closed = 'CLOSED' in fridge_states
    pie_stored = 319 in fridge_contents
    if not (pie_stored and fridge_closed):
        return 0, False  # episode continues
    return 5, True  # success ends the episode

def medium_reward_fn_1(state, action, next_state):
    """Reward pie 319 inside fridge 305 together with salmon 327 inside microwave 313.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with an ``'edges'`` list.

    Returns:
        ``(5, True)`` when both objects have INSIDE edges to their targets,
        otherwise ``(0, False)``.
    """
    graph = next_state.data[0]

    def contents_of(container_id):
        # ids of every object with an INSIDE edge pointing at container_id
        return [
            e['from_id']
            for e in graph['edges']
            if e['to_id'] == container_id and e['relation_type'] == 'INSIDE'
        ]

    fridge_contents = contents_of(305)
    microwave_contents = contents_of(313)
    if 319 in fridge_contents and 327 in microwave_contents:
        return 5, True  # both placed: terminate the episode
    return 0, False  # keep going

def medium_reward_fn_2(state, action, next_state):  # AKA pie_salmon_closed.rlang
    """Reward pie 319 in fridge 305 and salmon 327 in microwave 313, both appliances closed.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with ``'edges'`` and ``'nodes'`` lists.

    Returns:
        ``(5, True)`` when both objects are INSIDE their targets and the
        ``states`` of nodes 305 and 313 both contain ``'CLOSED'``; otherwise
        ``(0, False)``.

    Raises:
        IndexError: if node 305 or node 313 is missing from the graph.
    """
    graph = next_state.data[0]

    def contents_of(container_id):
        # ids of every object with an INSIDE edge pointing at container_id
        return [
            e['from_id']
            for e in graph['edges']
            if e['to_id'] == container_id and e['relation_type'] == 'INSIDE'
        ]

    def states_of(node_id):
        # states of the first node with this id (IndexError when absent)
        return [n["states"] for n in graph['nodes'] if n['id'] == node_id][0]

    fridge_contents = contents_of(305)
    microwave_contents = contents_of(313)
    fridge_closed = 'CLOSED' in states_of(305)
    microwave_closed = 'CLOSED' in states_of(313)

    goal_met = (319 in fridge_contents and 327 in microwave_contents
                and fridge_closed and microwave_closed)
    return (5, True) if goal_met else (0, False)

# The harder reward functions below work with the full_trimmed_large observation type ONLY.
def hard_reward_fn_1(state, action, next_state):
    """Reward remotecontrol 452 on sofa 368; punish cereal 334 on the sofa.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with an ``'edges'`` list.

    Returns:
        ``(5, True)`` if the remote has an ON edge to the sofa. This is checked
        first, so it wins even when the cereal is also there.
        ``(-5, True)`` if the cereal is ON the sofa.
        Otherwise ``(0, False)``.
    """
    graph = next_state.data[0]
    on_sofa = [
        e['from_id']
        for e in graph['edges']
        if e['to_id'] == 368 and e['relation_type'] == 'ON'
    ]
    if 452 in on_sofa:
        return 5, True  # success ends the episode
    if 334 in on_sofa:
        return -5, True  # failure: cereal landed on the sofa, episode ends
    return 0, False

def hard_reward_fn_2(state, action, next_state):
    """Reward remote 452 on sofa 368 plus cereal 334 in cabinet 415; penalize holding toothpaste 62.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with an ``'edges'`` list.

    Returns:
        ``(5, True)`` when the remote is ON the sofa and the cereal is INSIDE
        the cabinet.
        ``(-1, False)`` when node 1 has an edge whose relation type contains
        ``'HOLDS'`` to toothpaste 62. The episode is not terminated.
        Otherwise ``(0, False)``.
    """
    edges = next_state.data[0]['edges']
    on_sofa = [e['from_id'] for e in edges
               if e['to_id'] == 368 and e['relation_type'] == 'ON']
    in_cabinet = [e['from_id'] for e in edges
                  if e['to_id'] == 415 and e['relation_type'] == 'INSIDE']
    # Node 1 is presumably the character/agent, and the substring test presumably
    # catches per-hand relations (e.g. HOLDS_RH / HOLDS_LH) — confirm.
    held = [e['to_id'] for e in edges
            if e['from_id'] == 1 and 'HOLDS' in e['relation_type']]

    if 452 in on_sofa and 334 in in_cabinet:
        return 5, True  # goal reached; terminate
    if 62 in held:
        return -1, False  # penalized, but the episode continues
    return 0, False

def hard_reward_fn_3(state, action, next_state):
    """Reward salmon 327 in a switched-on microwave 313; punish cereal 334 in it.

    Args:
        state: Previous state (not used).
        action: Action taken (not used).
        next_state: Resulting state; ``next_state.data[0]`` is read as a scene
            graph with ``'edges'`` and ``'nodes'`` lists.

    Returns:
        ``(5, True)`` when the salmon is INSIDE microwave 313 and node 313's
        ``states`` contain ``'ON'``.
        ``(-5, True)`` when the cereal is INSIDE microwave 313.
        Otherwise ``(0, False)``.

    Raises:
        IndexError: if the graph has no node with id 313.
    """
    next_state_graph = next_state.data[0]
    things_in_microwave_313 = [edge['from_id'] for edge in next_state_graph['edges']
                               if edge['to_id'] == 313 and edge['relation_type'] == 'INSIDE']
    # Look up microwave 313 by id, like every other reward function here.
    # Selecting the first node with class_name == 'microwave' would read the
    # power state of a different appliance whenever the scene has more than one
    # microwave node, so the "on" check and the "inside" check could disagree.
    microwave_313_states = [node['states'] for node in next_state_graph['nodes'] if node['id'] == 313][0]
    microwave_is_on = 'ON' in microwave_313_states
    if 327 in things_in_microwave_313 and microwave_is_on:
        return 5, True  # goal reached; terminate the episode
    elif 334 in things_in_microwave_313:
        return -5, True  # failure: cereal in the microwave ends the episode
    else:
        return 0, False  # keep going


def main():
    """Build the VirtualHome environment and agents, then run the experiment.

    Reads ``keys.json`` (``recording_folder`` and ``executable`` keys) and
    ``restrict_dict.json`` from the current working directory. Creates a
    ``SimpleRLVirtualHomeEnv`` on the ``full_trimmed_large`` observation type,
    rewarded by ``hard_reward_fn_2``. Constructs several agents and runs the
    RLang effect and combined agents with ``run_agents_on_mdp``.
    """
    # Read config through a context manager so the handle is closed promptly;
    # the previous json.load(open(...)) never closed keys.json.
    with open('keys.json', 'r') as keys_file:
        localkeys = json.load(keys_file)

    recording_options = {
        'recording': False,
        'output_folder': localkeys['recording_folder'],
        'file_name_prefix': "test",
        # 'cameras': 'PERSON_FROM_BACK',
        'modality': 'normal',
    }
    executable_args = {
        'file_name': localkeys['executable'],
        'no_graphics': True,
        'logging': False,
    }

    learning_params = {
        'alpha': 0.1,
        'gamma': 0.99,
        'epsilon': 0.1,
        'anneal': True,
    }

    restrict_dict_path = 'restrict_dict.json'
    with open(restrict_dict_path, 'r') as restrict_file:
        restrict_dict = json.load(restrict_file)

    mdp = SimpleRLVirtualHomeEnv(use_editor=False,
                                 executable_args=executable_args,
                                 recording_options=recording_options,
                                 observation_types=['full_trimmed_large'],
                                 handmade_reward_fn=hard_reward_fn_2)

    def available_actions(*args, **kwargs):
        # Action enumeration shared by every learning agent, restricted by restrict_dict.
        return utils.generate_all_available_actions(*args, restriction_dict=restrict_dict, **kwargs)

    # All RLang agents use this advice file. `filename` stays an overridable
    # keyword argument, matching the previous per-agent lambdas.
    rlang_advice_file = "rlang_advice/hard_rw_2.rlang"

    def get_rlang_knowledge(state, filename=rlang_advice_file):
        return get_knowledge_from_file(state, filename, mdp)

    def make_rlang_agent(name, use_policy, use_plan, use_effects):
        # The RLang agents differ only in their name and which advice channels are on.
        return agents.VHRLangDynaQLearningAgent(state_hash_func=utils.state_hash_fn,
                                                name=name,
                                                available_action_function=available_actions,
                                                get_knowledge=get_rlang_knowledge,
                                                use_policy=use_policy,
                                                use_plan=use_plan,
                                                use_effects=use_effects,
                                                **learning_params)

    random_agent = agents.RandomVHAgent(actions=mdp.get_action_space(),
                                        actions_available=mdp.actions_available,
                                        get_action_space=mdp.get_action_space,
                                        can_perform_action=utils.can_perform_action,
                                        get_graph=mdp.get_graph,
                                        name="Random")

    q_agent = agents.VHQLearningAgentSimple(state_hash_func=utils.state_hash_fn,
                                            available_action_function=available_actions)

    dyna_agent = agents.VHDynaQLearningAgent(state_hash_func=utils.state_hash_fn,
                                             available_action_function=available_actions,
                                             **learning_params)

    rlang_policy_agent = make_rlang_agent("RLang-Dyna-Q-policy",
                                          use_policy=True, use_plan=False, use_effects=False)
    rlang_plan_agent = make_rlang_agent("RLang-Dyna-Q-plan",
                                        use_policy=False, use_plan=True, use_effects=False)
    rlang_effect_agent = make_rlang_agent("RLang-Dyna-Q-effect",
                                          use_policy=False, use_plan=False, use_effects=True)
    rlang_combined_agent = make_rlang_agent("RLang-Dyna-Q-combined",
                                            use_policy=True, use_plan=True, use_effects=True)

    exp_params = {"instances": 10, "episodes": 70, "steps": 100}

    # Only the effect and combined RLang agents are run in this configuration.
    # The random, Q, Dyna-Q, and RLang policy/plan agents above are constructed
    # but not passed to run_agents_on_mdp.
    run_agents_on_mdp([rlang_effect_agent, rlang_combined_agent],
                      mdp, **exp_params, open_plot=False, verbose=True,
                      cumulative_plot=True, track_disc_reward=False,
                      seed_per_inst=False, seeds=list(range(1, 51)))
    


if __name__ == "__main__":
    # Launch the experiment only when executed as a script, not on import.
    main()
