from nl import convert_to_nl
import json
import utils
from key import OPENAI_API_KEY
from openai import OpenAI
import tqdm

# Shared OpenAI client used for every request issued by this script.
client = OpenAI(api_key=OPENAI_API_KEY)

# Load the symbolic samples; the context manager closes the file handle
# deterministically and the explicit encoding avoids depending on the
# platform default when decoding the JSON text.
with open("data/fixed_rules=2-multiple_rules.json", encoding="utf-8") as data_file:
    sym_data = json.load(data_file)

# Only the first 53 entries are used; each one is paired with the
# representation produced for it by convert_to_nl.
sym_samples = sym_data[0:53]
nl_samples = [convert_to_nl(sample) for sample in sym_samples]

def _extract_json_object(response):
    """Parse the outermost ``{...}`` span contained in *response*.

    Any text before the first ``{`` or after the last ``}`` is ignored.

    Raises:
        json.JSONDecodeError: if *response* contains no ``{...}`` span or the
            span is not valid JSON.
    """
    start = response.find('{')
    end = response.rfind('}')
    if start == -1 or end < start:
        # str.find/rfind return -1 when nothing matches; slicing with that
        # value would silently yield an empty or unrelated substring, so fail
        # here with a message that names the actual problem.
        raise json.JSONDecodeError("No JSON object found in model response", response, 0)
    return json.loads(response[start:end + 1])


def query(icl_nl_samples, icl_sym_samples, nl_query):
    """Ask the model for the symbolic (DatalogMTL) form of *nl_query*.

    The prompt consists of a system message describing the task, one user
    message per in-context example, and a final user message with the query.
    The request is sent through ``utils.ask_gpt`` using the module-level
    ``client`` and the ``gpt-4o`` model.

    Args:
        icl_nl_samples: verbalized representations used as in-context examples.
        icl_sym_samples: symbolic representations paired position-wise with
            ``icl_nl_samples`` (extra items on either side are ignored by zip).
        nl_query: verbalized representation to translate; must be
            JSON-serializable.

    Returns:
        The JSON object parsed from the model's reply.

    Raises:
        json.JSONDecodeError: if the reply contains no parseable JSON object.
    """
    messages = [{
        "role": "system",
        "content": "Convert the following logic rules in natural language to symbolic language. In the symbolic language,\
    the rules are expressed as DatalogMTL, a knowledge representation language that extends Datalog with operators from metric temporal logic (MTL).\
    The semantics of four MTL operators are given as follows:\
    If Diamondminus[a,b]A is true at the time t, it requires that A needs to be true at some time between t-b and t-a.\
    If Boxminus[a,b]A is true at the time t, it requires that A needs to be true continuously between t-b and t-a.\
    If Diamondplus[a,b]A is true at the time t, it requires that A needs to be true at some point between t+a and t+b.\
    If Boxplus[a,b]A is true at the time t, it requires that A needs to be true continuously between t+a and t+b.\
    I will give you few examples to help you better understand the expected output"
    }]
    # One user message per in-context (verbalized, symbolic) example pair.
    messages.extend(
        {
            "role": "user",
            "content": "For the verbalized representation %s, you should output %s"%(json.dumps(nl_sample), json.dumps(sym_sample))
        }
        for nl_sample, sym_sample in zip(icl_nl_samples, icl_sym_samples)
    )
    messages.append({
            "role": "user",
            "content": "For the verbalized representation %s, what is the its symbolic representation in DatalogMTL? You should only output json without addition symbols or texts"%(json.dumps(nl_query))
        })

    # NOTE(review): ask_gpt is assumed to return the reply text as a str.
    response = utils.ask_gpt(client, messages=messages, model='gpt-4o')
    # The reply may wrap the JSON in extra text, so only the outermost
    # brace-delimited span is parsed.
    return _extract_json_object(response)

# The first three samples are sent as in-context examples with every query;
# the remaining samples form the evaluation set.
icl_nl_samples = nl_samples[0:3]
icl_sym_samples = sym_samples[0:3]
nl_samples = nl_samples[3:]
sym_samples = sym_samples[3:]

assert len(nl_samples) == len(sym_samples)

tot = len(nl_samples)
correct = 0

for nl_sample, sym_sample in tqdm.tqdm(zip(nl_samples, sym_samples), total=tot):
    try:
        pred_sym_sample = query(icl_nl_samples, icl_sym_samples, nl_sample)
    except json.JSONDecodeError:
        # A reply that cannot be parsed as JSON counts as a wrong prediction
        # instead of aborting the whole evaluation run.
        continue
    # Exact match on the serialized form; sort_keys makes the comparison
    # independent of object key order, which carries no meaning in JSON.
    if json.dumps(pred_sym_sample, sort_keys=True) == json.dumps(sym_sample, sort_keys=True):
        correct += 1

# Report accuracy; guard against an empty evaluation set, which would
# otherwise raise ZeroDivisionError.
if tot:
    print(correct / tot)
else:
    print("No evaluation samples left after removing the in-context examples.")