import  openai
import json

# Load the OpenAI API key from keys.json (a JSON object with a "key" field).
# The `with` block ensures the file handle is closed after parsing.
with open('keys.json', 'r') as _keys_file:
    openai.api_key = json.load(_keys_file)['key']

def execute_completion(prompt, *, engine="text-davinci-003", **kwargs):
    """Send ``prompt`` to the OpenAI Completions endpoint and return the text.

    NOTE(review): ``openai.Completion`` is the legacy (pre-1.0) openai-python
    interface; confirm the pinned library version supports it.

    Args:
        prompt: Full prompt string sent to the model.
        engine: Engine/model name passed to ``openai.Completion.create``.
            Defaults to ``"text-davinci-003"`` (the previously hard-coded
            value). It is keyword-only so callers can choose another model
            instead of getting a duplicate-keyword TypeError.
        **kwargs: Extra arguments forwarded unchanged to
            ``openai.Completion.create`` (callers in this file pass
            ``max_tokens``).

    Returns:
        The text of the first returned choice, with leading and trailing
        whitespace stripped.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        **kwargs
    )
    # Only the first choice is used; any additional choices are ignored.
    return response.choices[0].text.strip()


def stage_1(advice):
    """Ask the LLM which RLang grounding to attempt to translate the advice to.

    The few-shot prompt in ``./prompts/stage-1.txt`` is extended with the
    advice and a trailing ``"Grounding: "`` cue, and the model's completion
    (at most 20 tokens, whitespace-stripped) is returned as-is.

    NOTE(review): the result is presumably meant to be one of the grounding
    types accepted by ``stage_2`` ("Effect", "Policy", "Plan"), but it is not
    validated here. Confirm whether callers should check it.

    Args:
        advice: Natural-language advice. Callers in this file include a
            trailing newline in it.

    Returns:
        The model's answer as a string.
    """
    with open('./prompts/stage-1.txt', 'r') as prompt_file:
        prompt = prompt_file.read()
    augmented_prompt = prompt + "Advice = " + advice + "\n" + "Grounding: "
    return execute_completion(augmented_prompt, max_tokens=20)


def stage_2(advice, primitives, grounding_type):
    """Translate advice into RLang using the prompt for ``grounding_type``.

    The prompt file for the grounding type is extended with the advice and
    the ``str()`` of the primitive list. The model's completion (at most
    200 tokens, whitespace-stripped) is returned.

    Args:
        advice: Natural-language advice to translate.
        primitives: Collection of primitive names available to the RLang
            program. It is embedded in the prompt via ``str()``.
        grounding_type: One of ``"Effect"``, ``"Policy"`` or ``"Plan"``
            (exact, case-sensitive match).

    Returns:
        The model's completion as a string.

    Raises:
        ValueError: If ``grounding_type`` is not a supported grounding.
            Previously this surfaced as an UnboundLocalError on ``prompt``.
    """
    prompt_paths = {
        "Effect": './prompts/effect-advice.txt',
        "Policy": './prompts/policy-advice.txt',
        "Plan": './prompts/plan-advice.txt',
    }
    try:
        prompt_path = prompt_paths[grounding_type]
    except KeyError:
        raise ValueError(
            f"Unsupported grounding type {grounding_type!r}; "
            f"expected one of {sorted(prompt_paths)}"
        ) from None

    with open(prompt_path, 'r') as prompt_file:
        prompt = prompt_file.read()

    augmented_prompt = prompt + "Advice = " + advice + "\n" + "Primitives = " + str(primitives) + "\n"
    return execute_completion(augmented_prompt, max_tokens=200)


def test_prompt():
    """Manually exercise the effect-advice prompt on a hard-coded example.

    Builds the prompt from ``./prompts/effect-advice.txt`` plus a fixed
    advice string and primitive list, prints it, and returns the model's
    completion (at most 100 tokens).
    """
    # To try the general prompt instead, use './prompts/general-v1.txt'.
    # It was previously read here and then immediately overwritten.
    with open('./prompts/effect-advice.txt', 'r') as prompt_file:
        prompt = prompt_file.read()

    advice = "Don't step into Lava.\n"
    primitives = ['Agent', 'Wall', 'GoalTile', 'Lava', 'Key', 'Door', 'Box', 'Ball', 'left', 'right', 'forward', 'pickup', 'drop', 'toggle', 'done', 'pointing_right', 'pointing_down', 'pointing_left', 'pointing_up', 'go_to', 'step_towards', 'agent', 'goal', 'is_on_a', 'at', 'at_any', 'in_inventory']

    # The advice already ends with "\n", so no separator is added here.
    augmented_prompt = prompt + "Advice = " + advice + "Primitives = " + str(primitives)

    print(augmented_prompt)

    return execute_completion(augmented_prompt, max_tokens=100)

def test_dialogue():
    """Round-trip a hard-coded advice string: NL -> RLang -> NL.

    1. Translate the advice to RLang with the effect-advice prompt.
    2. Ask the model to turn that RLang back into advice with the
       ``./prompts/dialogue_general-v1.txt`` prompt.

    Prints the original advice, the generated RLang, and the back-translation.
    Each model call is capped at 100 tokens.
    """
    # Alternative RLang prompt: './prompts/general-v1.txt'.
    with open('./prompts/effect-advice.txt', 'r') as prompt_file:
        prompt_to_rlang = prompt_file.read()
    with open('./prompts/dialogue_general-v1.txt', 'r') as prompt_file:
        prompt_to_nl = prompt_file.read()

    advice = "Don't step into Lava.\n"
    primitives = ['Agent', 'Wall', 'GoalTile', 'Lava', 'Key', 'Door', 'Box', 'Ball', 'left', 'right', 'forward', 'pickup', 'drop', 'toggle', 'done', 'pointing_right', 'pointing_down', 'pointing_left', 'pointing_up', 'go_to', 'step_towards', 'agent', 'goal', 'is_on_a', 'at', 'at_any', 'in_inventory']

    augmented_prompt_to_rlang = prompt_to_rlang + "Advice = " + advice + "Primitives = " + str(primitives)
    rlang = execute_completion(augmented_prompt_to_rlang, max_tokens=100)

    # The prompt ends with an "Advice =" cue so the model completes it with
    # a natural-language restatement of the generated RLang.
    augmented_prompt_to_nl = f"{prompt_to_nl}\nPrimitives = {primitives}\n{rlang}\nAdvice ="
    nl = execute_completion(augmented_prompt_to_nl, max_tokens=100)

    print(f"Advice: {advice}")
    print(f"RLang generated: {rlang}")
    print(f"NL translated back: {nl}")

if __name__ == '__main__':
    # Run the NL -> RLang -> NL round-trip demo.
    # Other manual experiments (uncomment one to run it instead):
    #   stage_1("Don't step into Lava.\n")
    #   test_prompt()
    test_dialogue()
