# import virtualhome
import agents   # Trust me keep this
# from dynamic_grounding import get_knowledge_from_file, get_primitives_for
# from environment import SimpleRLVirtualHomeEnv
import openai
import json
import base64

# Local configuration loaded once at import time: the OpenAI API key ('key')
# plus VirtualHome paths ('recording_folder', 'executable') used by nl2rlang.
# NOTE(review): the path is relative, so this depends on the working directory.
with open('keys.json', 'r') as keys_file:
    localkeys = json.load(keys_file)

openai.api_key = localkeys['key']

def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    The file is opened in binary mode and the base64 result is decoded to a
    UTF-8 ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

# text-davinci-003 is deprecated; gpt-3.5-turbo-instruct is its replacement on
# the legacy text-completions endpoint.
def execute_completion(prompt, **kwargs):
    """Run a single text completion and return the stripped response text.

    Uses the legacy completions endpoint (raw prompt in, text out) through the
    openai>=1.0 client, the same client style as execute_completion_new.

    Args:
        prompt: the full prompt string.
        **kwargs: extra request parameters (e.g. ``max_tokens``). They may
            also override the defaults ``temperature=0.0`` and
            ``stop=["\\n\\n"]``.

    Returns:
        The text of the first choice with surrounding whitespace removed.
    """
    # openai.Completion.create / ``engine=`` belong to the pre-1.0 SDK and raise
    # under openai>=1.0, which execute_completion_new (openai.OpenAI) requires.
    # "gpt-4" is also a chat-only model that the completions endpoint rejects.
    client = openai.OpenAI(api_key=localkeys['key'])
    params = {"temperature": 0.0, "stop": ["\n\n"]}
    params.update(kwargs)  # caller-supplied values win over the defaults
    response = client.completions.create(
        model="gpt-3.5-turbo-instruct",
        prompt=prompt,
        **params,
    )
    return response.choices[0].text.strip()

def execute_completion_new(prompt=None, question=None, messages=None, images=None):
    """Send one chat-completions request and return the reply text.

    Two modes, selected by the truthiness of ``prompt``:

    * ``prompt`` truthy: a two-turn conversation (``prompt`` as the system
      message, ``question`` as the user message) is sent to ``gpt-4o``.
    * otherwise: the caller-built ``messages`` list is sent as-is to
      ``gpt-4o-mini``.

    ``images`` is accepted but currently ignored.

    Returns:
        The content of the first choice's message.
    """
    client = openai.OpenAI(api_key=localkeys['key'])

    if not prompt:
        model_name, conversation = "gpt-4o-mini", messages
    else:
        model_name = "gpt-4o"
        conversation = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": question},
        ]

    completion = client.chat.completions.create(model=model_name, messages=conversation)
    first_choice = completion.choices[0]
    return first_choice.message.content

def stage_0(advice, primitives, new_openai=True):
    """Query the chat model with the available entities and one piece of advice.

    The conversation stored in ./prompts/stage-0.txt (loaded as JSON and used
    as a list of chat messages) is extended with a final user turn containing
    the entities and the advice. The caller (nl2rlang_vlm) treats the reply as
    the entities extracted from the advice.

    Args:
        advice: natural-language advice string.
        primitives: available entity names; inserted into the message via str().
        new_openai: unused; kept so the signature matches the other stages.

    Returns:
        The model's reply text.
    """
    with open('./prompts/stage-0.txt', 'r') as prompt_file:
        messages = json.load(prompt_file)

    messages.append({"role": "user", "content": "Entities: " + str(primitives) + "\nAdvice: " + advice})
    return execute_completion_new(messages=messages)


def stage_1(advice, new_openai=False):
    """Ask the LLM which RLang grounding to attempt to translate the advice to.

    Args:
        advice: natural-language advice; surrounding whitespace is stripped.
        new_openai: if True, use the chat-completions helper (stage-1 prompt as
            the system message). Otherwise use the legacy text-completion
            helper, capped at 20 tokens.

    Returns:
        The model's answer text, which nl2rlang passes to stage_2 as the
        grounding type.
    """
    with open('./prompts/stage-1.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    # Built once and shared by both code paths.
    query = "Advice = " + f'"{advice.strip()}"' + "\n" + "Grounding: "
    if new_openai:
        return execute_completion_new(prompt, query)
    return execute_completion(prompt + query, max_tokens=20)


def stage_2(advice, primitives, grounding_type, new_openai=False):
    """Translate *advice* into RLang using the prompt for *grounding_type*.

    Args:
        advice: natural-language advice; surrounding whitespace is stripped.
        primitives: environment primitives, inserted into the query via str().
        grounding_type: "Effect", "Policy" or "Plan" (typically stage_1's
            output). Surrounding whitespace is ignored.
        new_openai: if True, use the chat-completions helper. Otherwise use
            the legacy text-completion helper, capped at 200 tokens.

    Returns:
        The model's response text.

    Raises:
        ValueError: if grounding_type is not one of the known groundings.
            Previously this surfaced as an UnboundLocalError on ``prompt``.
    """
    prompt_files = {
        "Effect": './prompts/effect-advice.txt',
        "Policy": './prompts/policy-advice.txt',
        "Plan": './prompts/plan-advice.txt',
    }
    # LLM output may carry stray whitespace/newlines around the label.
    grounding_key = str(grounding_type).strip()
    if grounding_key not in prompt_files:
        raise ValueError(
            f"Unknown grounding type {grounding_type!r}; expected one of {sorted(prompt_files)}"
        )

    with open(prompt_files[grounding_key], 'r') as prompt_file:
        prompt = prompt_file.read()

    query = "\nAdvice = " + f'"{advice.strip()}"' + "\n" + "Primitives = " + str(primitives) + "\n"
    if new_openai:
        return execute_completion_new(prompt, query)
    return execute_completion(prompt + query, max_tokens=200)

def stage_3(advice, primitives):
    """Run the stage-3 prompt over an RLang program.

    Args:
        advice: RLang text (nl2rlang_vlm passes stage_2's output here).
        primitives: environment primitives, inserted into the prompt via str().

    Returns:
        None if the model replies exactly "no"; otherwise the reply split
        into a list of lines.
    """
    with open('./prompts/stage-3.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    augmented_prompt = prompt + "Primitives = " + str(primitives) + "\n" + "RLang: " + "\n" + advice + "\n"
    response = execute_completion(augmented_prompt, max_tokens=30)
    if response == "no":
        return None
    return response.split('\n')

def nl2rlang(advices, **kwargs):
    """Translate each natural-language advice into RLang via stages 1 and 2.

    Constructs a SimpleRLVirtualHomeEnv and derives the primitives from its
    reset() observation, printing them. Then, for every advice, it asks
    stage_1 for a grounding type and stage_2 for the RLang translation.

    Args:
        advices: iterable of advice strings.
        **kwargs: currently unused. Advice-dictionary logging keyed on
            ``kwargs['update_advice_dict']`` was disabled.

    Returns:
        List of stage_2 results, in the same order as *advices*.

    NOTE(review): SimpleRLVirtualHomeEnv and get_primitives_for come from
    imports that are commented out at the top of this file, so this call
    raises NameError unless those names are provided some other way. Confirm.
    NOTE(review): the env is always wired with hard_reward_fn_2, including
    when called from nl2rlang_exp1 (pie/salmon task). Confirm this is intended.
    """
    recording_options = {
        'recording': False,
        'output_folder': localkeys['recording_folder'],
        'file_name_prefix': "test",
        'modality': 'normal',
    }
    executable_args = {
        'file_name': localkeys['executable'],
        'no_graphics': True,
        'logging': False,
    }

    env = SimpleRLVirtualHomeEnv(
        use_editor=False,
        executable_args=executable_args,
        recording_options=recording_options,
        observation_types=['full_trimmed_large'],
        handmade_reward_fn=hard_reward_fn_2,
    )
    primitives = get_primitives_for(env.reset())
    print(primitives)

    translations = []
    for advice_text in advices:
        grounding = stage_1(advice_text)
        translations.append(stage_2(advice_text, primitives, grounding))
    return translations

def medium_reward_fn_2(state, action, next_state):
    """Reward for putting pie 319 in fridge 305 and salmon 327 in microwave 313.

    Returns ``(5, True)`` when, in ``next_state.data[0]``:

    * pie 319 is INSIDE fridge 305,
    * salmon 327 is INSIDE microwave 313, and
    * both appliances list 'CLOSED' in their states.

    Otherwise it returns ``(0, False)``. The boolean tells the environment
    whether the episode ends. ``state`` and ``action`` are unused.
    """
    graph = next_state.data[0]

    def inside(container_id):
        return [edge['from_id'] for edge in graph['edges']
                if edge['to_id'] == container_id and edge['relation_type'] == 'INSIDE']

    def states_of(node_id):
        # Raises IndexError when the node is absent from the graph.
        return [node["states"] for node in graph['nodes'] if node['id'] == node_id][0]

    fridge_contents = inside(305)
    microwave_contents = inside(313)
    fridge_closed = 'CLOSED' in states_of(305)
    microwave_closed = 'CLOSED' in states_of(313)

    solved = (319 in fridge_contents and 327 in microwave_contents
              and fridge_closed and microwave_closed)
    return (5, True) if solved else (0, False)
    


def hard_reward_fn_2(state, action, next_state):
    """Reward for putting remotecontrol 452 on sofa 368 and cereal 334 in cabinet 415.

    Returns a ``(reward, done)`` pair computed from ``next_state.data[0]``:

    * ``(5, True)``: 452 is ON 368 and 334 is INSIDE 415. The episode ends,
      and this takes precedence over the penalty.
    * ``(-1, False)``: node 1 has a relation containing 'HOLDS' to toothpaste
      62. This applies on every such step, and the episode continues.
    * ``(0, False)``: otherwise.

    ``state`` and ``action`` are unused.
    """
    edges = next_state.data[0]['edges']

    on_sofa = [e['from_id'] for e in edges if e['to_id'] == 368 and e['relation_type'] == 'ON']
    in_cabinet = [e['from_id'] for e in edges if e['to_id'] == 415 and e['relation_type'] == 'INSIDE']
    # Any relation whose name contains 'HOLDS', originating from node 1.
    held = [e['to_id'] for e in edges if e['from_id'] == 1 and 'HOLDS' in e['relation_type']]

    if 452 in on_sofa and 334 in in_cabinet:
        return 5, True    # task solved: the environment ends
    if 62 in held:
        return -1, False  # penalized but the episode keeps going
    return 0, False


def nl2rlang_exp1():
    """Translate the pie/salmon (fridge/microwave) task advices and print each RLang result."""
    # NOTE: the "friedge" typo below is part of the prompt text. It is left
    # as-is so that the advice sent to the model is unchanged.
    advices = [
        "Go to fridge and open it, and then go find the pie and pick it up, walk back to the fridge and put the pie in the fridge. You have to close the fridge too",
        "If the salmon is in the microwave, and you are at the microwave and it's open, close it. Otherwise if you are holding salmon, do the following: open the microwave if you are near it but it's closed, put the salmon into the microwave if it's open and you're near it, else walk to the microwave.",
        "If the pie is in the fridge, and the salmon is in the microwave, then closing the fridge if the microwave is closed or closing the microwave if the friedge is closed will give you reward and end the episode."
    ]
    # A plain loop for the side effect; the original built a throwaway list.
    for rlang_advice in nl2rlang(advices):
        print(rlang_advice)

def nl2rlang_exp2():
    """Translate the toothpaste/remote-control task advices and print each RLang result."""
    advices = [
        "If you're holding the toothpaste and can drop it, drop it.",
        "Go grab the remote control and put it on the sofa.",
        "If you're holding the toothpaste and not trying to drop it, you will be penalized. Also, nothing will happen if you try to walk to the remote control, cereal, toothpaste, or salmon, if you try to walk to them and they are contained inside anything."
    ]
    # A plain loop for the side effect; the original built a throwaway list.
    for rlang_advice in nl2rlang(advices):
        print(rlang_advice)


def nl2rlang_vlm(advices, primitives=None):
    """Run stage 0 over each advice, printing the advice and the model's entities.

    Args:
        advices: iterable of natural-language advice strings.
        primitives: entity names offered to the model. Defaults to the fixed
            object/room list below; it must be a list, because stage_0 embeds
            it via str().

    Returns:
        A list of RLang advices. NOTE: stages 1-3 are currently commented out,
        so nothing is appended and this always returns an empty list.
    """
    if primitives is None:
        primitives = ['Toothpaste', 'Cereal', 'Bathroom', 'Sofa', 'Salmon', 'Pie', 'Kitchentable', 'Remotecontrol', 'Fridge', 'Microwave', 'Kitchen', 'Bookshelf', 'Livingroom', 'Cabinet']
    rlang_advices = []
    for advice in advices:
        print("Advice: " + advice)
        entities = stage_0(advice, primitives)
        print(entities)
        # Later stages, currently disabled:
        # grounding_selection = stage_1(advice, new_openai=True)
        # print("Grounding type: "+ grounding_selection)
        # rlang_advice = stage_2(advice, primitives, grounding_selection, new_openai=True)
        # print(rlang_advice)
        # response = stage_3(rlang_advice, primitives)
        # print(response)
        # rlang_advices.append(rlang_advice)

    return rlang_advices


def nl2rlang_vlm_exp():
    """Run nl2rlang_vlm over the advices in vlm_advice.txt, one advice per line."""
    with open('vlm_advice.txt', 'r') as advice_file:
        # Skip blank lines so empty strings are not sent to the model as advice.
        advices = [line for line in advice_file.read().splitlines() if line.strip()]

    nl2rlang_vlm(advices)

if __name__ == "__main__":
    # Swap in nl2rlang_exp1() / nl2rlang_exp2() to run the simulator-based experiments.
    nl2rlang_vlm_exp()
