from openai import OpenAI
import json
import os
import ast
import random
from argparse import ArgumentParser
from tenacity import retry, stop_after_attempt, wait_random_exponential

from utils import *


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def call_openai(client, gpt_model, messages, temperature):
    """Request a single chat completion and return the first choice's text.

    The ``retry`` decorator re-invokes this function when it raises, waiting
    a randomized exponential backoff (between 1 and 60 seconds) between
    tries and giving up after 6 attempts.

    Args:
        client: Object exposing ``chat.completions.create`` (an ``OpenAI``
            client in this script).
        gpt_model: Model identifier forwarded as ``model``.
        messages: Chat messages forwarded unchanged as ``messages``.
        temperature: Sampling temperature forwarded unchanged.

    Returns:
        The ``message.content`` of the first entry in the response's
        ``choices``.
    """
    response = client.chat.completions.create(
        model=gpt_model,
        messages=messages,
        temperature=temperature,
    )
    first_choice = response.choices[0]
    return first_choice.message.content


def main(args):
    """Run the one-shot planning experiment described by a JSON config file.

    The config is read from ``experiment_configs/<args.config>`` and must
    provide the keys ``run_name``, ``yaml``, ``original_prompt``,
    ``instances``, ``random_example``, ``model``, ``temperature`` and
    ``output_file``.

    Every file in the ``instances`` directory except the first (after
    ``sort_nicely`` ordering) is treated as the instance to solve. Its
    in-context example is the file immediately before it, or a randomly
    chosen file when ``random_example`` is truthy. The model is queried
    until the evaluator returns feedback that is not flagged ``malformed``;
    that feedback is appended to the results and the whole result list is
    rewritten to ``output_file`` after each instance.

    Args:
        args: Parsed command-line namespace; only ``args.config`` is read.
    """
    with open(f"experiment_configs/{args.config}", "r") as f:
        config = json.load(f)

    cache = {}
    run_name = config["run_name"]

    def get_ex(file_path):
        """Return the (prompt, init, goal, plan, data) tuple for file_path, memoized."""
        if file_path not in cache:
            prompt, init, goal, plan, data = build_in_context_ex(
                run_name, file_path, config["yaml"], original=config["original_prompt"]
            )
            cache[file_path] = (prompt, init, goal, plan, data)
        return cache[file_path]

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    instances_dir = config["instances"]
    filenames = list(os.listdir(instances_dir))
    # The return value is ignored, so this relies on sort_nicely sorting in place.
    sort_nicely(filenames)
    output_json = []

    # Start at the second file so that every instance has a predecessor that
    # can serve as its in-context example.
    for i, instance_filename in enumerate(filenames[1:], 1):
        cache_key = f"{instances_dir}/{instance_filename}"
        actual_prompt, _, _, actual_plan, _ = get_ex(cache_key)

        if config["random_example"]:
            # NOTE(review): the random pick may be the instance under test
            # itself, which would put its own plan in the prompt -- confirm
            # whether that is intended.
            example_name = random.choice(filenames)
        else:
            example_name = filenames[i - 1]
        example_key = f"{instances_dir}/{example_name}"
        example_prompt = get_ex(example_key)[0]

        # The prompt format decides both the splitter and the evaluator.
        if config["original_prompt"]:
            split_prompt, evaluate = split_prompt_old, evaluate_old
        else:
            split_prompt, evaluate = split_prompt_new, evaluate_new
        context, instance1, plan1 = split_prompt(example_prompt)
        _, instance2, _ = split_prompt(actual_prompt)

        messages = [
            {"role": "system", "content": "You are a planning agent"},
            {"role": "user", "content": context + "\n\n" + instance1},
            {"role": "assistant", "content": plan1},
            {"role": "user", "content": instance2},
        ]

        # Re-query until the evaluator accepts the completion as well-formed.
        # NOTE(review): there is no upper bound on the number of attempts.
        while True:
            completion = call_openai(
                client=client,
                gpt_model=config["model"],
                messages=messages,
                temperature=config["temperature"],
            )

            # NOTE(review): the domain YAML path is hard-coded to blocksworld
            # even though config["yaml"] is also passed -- confirm.
            feedback = evaluate(
                run_name,
                actual_plan,
                completion,
                read_config("yamls/blocksworld.yaml"),
                cache_key,
                config["yaml"],
            )
            feedback["in_context_example_pddl"] = example_key
            feedback["instance_pddl"] = cache_key

            if "malformed" in feedback and feedback["malformed"]:
                continue

            output_json.append(feedback)
            # Rewrite the full result list after every instance so partial
            # progress survives an interrupted run.
            with open(config["output_file"], "w") as f:
                json.dump(output_json, f)
            break


if __name__ == "__main__":
    parser = ArgumentParser()
    # Required: without it args.config is None and main() would try to open
    # the non-existent file "experiment_configs/None".
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="File name of the experiment config inside experiment_configs/",
    )

    main(parser.parse_args())
