import json
import argparse
import utils
from key import OPENAI_API_KEY
from openai import OpenAI
import tqdm
import typing
from tqdm.contrib.concurrent import process_map
import functools
from multiprocessing import Pool
import typing
import os
from meteor_reasoner.utils import parser as parser
from nl import convert_to_nl

# OpenAI client used by query_gpt via utils.ask_gpt; None until main() builds it
# from OPENAI_API_KEY and --url.
client = None


def default_argument_parser():
    """Build the command-line parser for this evaluation script.

    Dataset-shape switches and size options select which ``data/*.json`` file is
    evaluated; the remaining options pick the model, endpoint, prompt strategy
    and run mode (``--eval`` re-scores a saved result file, ``--single`` queries
    sequentially instead of through a worker pool).

    NOTE(review): the parser description still reads "generate DatalogMTL data",
    apparently copied from the data generator; it is left unchanged here.
    """
    cli = argparse.ArgumentParser(description="generate DatalogMTL data")

    # Boolean dataset-feature switches (order matters: it fixes --help output and
    # the attribute order of the parsed namespace).
    feature_switches = (
        "--rational-number",
        "--multiple-body-atoms",
        "--recursive",
        "--variable",
        "--mixed-operators",
        "--multiple-rules",
        "--mixed-operators-2",
        "--mixed-operators-3",
        "--mixed-operators-4",
    )
    for switch in feature_switches:
        cli.add_argument(switch, action='store_true')

    # Optional integer knobs; left as None when not given.
    for knob in ("--fixed-new-nodes", "--fixed-rules", "--addl-rules", "--addl-data"):
        cli.add_argument(knob, type=int)

    cli.add_argument("--single", action='store_true')
    cli.add_argument("--nl", action='store_true', help="translate logic into natural language", default=False)
    cli.add_argument("--size", type=int, required=True)
    cli.add_argument("--model", choices=["gpt-4", "gpt-4o", "llama3-8b-instruct"], required=True)
    cli.add_argument("--eval", action='store_true')
    cli.add_argument("--url", default="https://api.openai.com/v1")
    cli.add_argument("--prompt-type", choices=["zero-shot", "few-shot", "zero-shot-cot"], required=True)
    return cli


def to_sign(features: set, additional_args: dict) -> str:
    """Return the signature string identifying a dataset configuration.

    The signature is the "-"-joined, sorted list of feature names (or "basic"
    when there are none) plus one "key=value" entry per non-None item of
    ``additional_args``. The three exclusive "mix_operators_N" features short-
    circuit: they must appear alone and are returned verbatim, ignoring
    ``additional_args``.

    Raises:
        AssertionError: an exclusive feature is combined with other features.
    """
    # NOTE(review): these checks look for "mix_operators_N", but the CLI dest
    # names are "mixed_operators_N" (from --mixed-operators-N), so when called
    # from main() in this script these branches never fire and such runs take
    # the generic path below. Confirm which name the dataset files use before
    # changing this.
    for exclusive in ("mix_operators_2", "mix_operators_3", "mix_operators_4"):
        if exclusive in features:
            assert (len(features) == 1)
            return exclusive

    parts = list(features) or ["basic"]
    parts.extend(
        "%s=%s" % (key, str(value))
        for key, value in additional_args.items()
        if value is not None
    )
    return "-".join(sorted(parts))


def query_gpt(instance: dict, prompt_type: typing.Literal["zero-shot", "few-shot", "zero-shot-cot"], model: str, sample=None, nl: bool = False) -> typing.Tuple[typing.Optional[bool], typing.List[dict]]:
    """Ask the model whether ``instance['query']`` follows from the instance's data and rules.

    Args:
        instance: dict with ``'data'`` and ``'rule'`` (sequences of strings,
            newline-joined into the prompt) and ``'query'`` (inserted with ``%s``).
        prompt_type: "zero-shot" (single answer-only turn), "few-shot" (one
            positive and one negative demonstration, then the question) or
            "zero-shot-cot" (a step-by-step turn followed by an answer-only turn).
        model: model name forwarded to ``utils.ask_gpt``.
        sample: for "few-shot" only: ``(positive, negative)`` instances shaped
            like ``instance``; ``sample[0]['valid']`` must be truthy and
            ``sample[1]['valid']`` falsy.
        nl: when True, the DatalogMTL operator primer is omitted from the
            system prompt.

    Returns:
        ``(verdict, messages)``. ``verdict`` is True/False when the final reply
        is "true"/"false" (case-insensitive, surrounding whitespace ignored),
        otherwise None, including when ``utils.ask_gpt`` returns None.
        ``messages`` is the whole conversation, assistant replies included.

    Raises:
        ValueError: ``prompt_type`` is not one of the supported strategies.
        AssertionError: few-shot ``sample`` is missing or mislabelled.
    """
    # NOTE(review): the prompt text below is kept byte-identical so results stay
    # comparable with earlier runs. Known wrinkles: the backslash-continued
    # sentences are joined without separators (e.g. "...as follows:If
    # Diamondminus..."), the output instruction is appended directly after
    # "t+a and t+b.", and the few-shot / CoT user messages embed the source
    # indentation of their continuation lines.
    info_prompt = "You are given a dataset and a temporal rule, and your task is to judge whether the given fact is entailed by the dataset and the rule. \
The rules are expressed as DatalogMTL, a knowledge representation language that extends Datalog with operators from metric temporal logic (MTL). The semantics of four MTL operators are given as follows:\
If Diamondminus[a,b]A is true at the time t, it requires that A needs to be true at some time between t-b and t-a.\
If Boxminus[a,b]A is true at the time t, it requires that A needs to be true continuously between t-b and t-a.\
If Diamondplus[a,b]A is true at the time t, it requires that A needs to be true at some point between t+a and t+b.\
If Boxplus[a,b]A is true at the time t, it requires that A needs to be true continuously between t+a and t+b."
    if nl:
        # In NL mode the operator primer is dropped; only the output
        # instruction (if any) remains in the system prompt.
        info_prompt = ""
    if prompt_type == "zero-shot":
        system_prompt = info_prompt + \
            "You should not give any explanation and you should only output \"true\" or \"false\""
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": "Now we have some temporal data and some rules, data:\n %s \n\n rule: %s\n\n\
Is %s true or not?" % ("\n".join(instance['data']), "\n".join(instance['rule']), instance['query'])
            }
        ]
        response = utils.ask_gpt(client, messages, model=model)
        messages.append({"role": "assistant", "content": response})
    elif prompt_type == "few-shot":
        # Demonstrations must be exactly one positive followed by one negative.
        assert (sample is not None)
        assert (sample[0]['valid'])
        assert (not sample[1]['valid'])
        system_prompt = info_prompt + \
            "You should not give any explanation and you should only output \"true\" or \"false\""
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": "To help you better understand the task, I will provide two examples.\n\
                    Example 1: data:\n %s \n\n rule: %s\n\n in this case you should output \"true\" for %s\n\n\
                    Example 2: data:\n %s \n\n rule: %s\n\n in this case you should output \"false\" for %s\n\n\
                    Now we have some temporal data and some rules, data:\n %s \n\n rule: %s\n\n\
                    Is %s true or not?" % ("\n".join(sample[0]['data']), "\n".join(sample[0]['rule']), sample[0]['query'],
                                           "\n".join(sample[1]['data']), "\n".join(sample[1]['rule']), sample[1]['query'],
                                           "\n".join(instance['data']), "\n".join(instance['rule']), instance['query'])
            }
        ]
        response = utils.ask_gpt(client, messages, model=model)
        messages.append({"role": "assistant", "content": response})
    elif prompt_type == "zero-shot-cot":
        system_prompt = info_prompt
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": "Now we have some temporal data and some rules, data:\n %s \n\n rule: %s\n\n\
                    Is %s true or not? Do not answer directly, think step by step." % ("\n".join(instance['data']), "\n".join(instance['rule']), instance['query'])
            }]
        response = utils.ask_gpt(client, messages, model=model)
        messages.append({"role": "assistant", "content": response})
        if response is None:
            # No reasoning to build on; skip the answer-extraction turn.
            return None, messages
        # Second turn: extract a bare verdict from the reasoning above.
        messages.append(
            {"role": "user", "content": "Based on your previous response, without any explanation state your answers to the question. You should only output \"true\" or \"false\""})
        response = utils.ask_gpt(client, messages, model=model)
        messages.append({"role": "assistant", "content": response})
    else:
        # Raising a plain str is itself a TypeError in Python 3; raise a real exception.
        raise ValueError("strategy %s unknown" % prompt_type)

    # The zero-shot-cot branch above already treats a None reply from
    # utils.ask_gpt as possible; handle it uniformly instead of crashing on .lower().
    if response is None:
        return None, messages
    verdict = response.strip().lower()
    if verdict == "true":
        return True, messages
    if verdict == "false":
        return False, messages
    return None, messages

def _collect_features(args) -> typing.Tuple[set, dict]:
    """Split parsed CLI args into (feature-flag names, given size options).

    Returns the set of names in the feature list whose value is truthy, and a
    dict of the size options (fixed_new_nodes, fixed_rules, addl_data,
    addl_rules) that were supplied, i.e. whose value is not None.
    """
    # NOTE(review): "multiple_rules_2"/"multiple_rules_3" have no matching flag in
    # default_argument_parser, so they can never be set from the command line.
    feature_names = {"rational_number", "multiple_body_atoms",
                     "recursive", "variable", "mixed_operators",
                     "multiple_rules", "multiple_rules_2", "multiple_rules_3",
                     "mixed_operators_2", "mixed_operators_3", "mixed_operators_4"}
    size_option_names = {"fixed_new_nodes", "fixed_rules", "addl_data", "addl_rules"}
    features = set()
    additional_args = {}
    for name, value in vars(args).items():
        if value and name in feature_names:
            features.add(name)
        if name in size_option_names and value is not None:
            additional_args[name] = value
    return features, additional_args


def _compute_metrics(results_messages, instances) -> typing.Tuple[float, float]:
    """Return ``(accuracy, f1)`` of predicted verdicts against ``instance['valid']``.

    ``results_messages`` holds ``(verdict, messages)`` pairs aligned with
    ``instances`` (zip stops at the shorter of the two). A truthy verdict is a
    positive prediction; anything else, including None, counts as negative.
    Each metric is 0 when undefined (no results, or a zero denominator).
    """
    # NOTE(review): a None verdict (unparseable/missing reply) is scored as a
    # negative prediction and therefore counts as correct on invalid instances.
    # Kept for comparability with earlier numbers; consider reporting it separately.
    tp = fp = tn = fn = 0  # true positive, false positive, true negative, false negative
    for (verdict, _messages), instance in zip(results_messages, instances):
        if instance['valid']:
            if verdict:
                tp += 1
            else:
                fn += 1
        elif verdict:
            fp += 1
        else:
            tn += 1
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 0
    if precision + recall != 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    return accuracy, f1


def main():
    """Query (or re-score) a model on one generated DatalogMTL dataset and print Acc/F1.

    Reads ``data/<signature>.json``. The first ``--size`` instances must be valid
    and the last ``--size`` invalid; those ``2 * size`` instances are evaluated.
    The instances right after the valid block and right before the invalid
    block serve as few-shot demonstrations. Unless ``--eval`` is given, model
    outputs are written to
    ``evaluator/results/[nl_]<signature>_<size>_<model>_<prompt-type>.json``;
    with ``--eval`` that file is read back instead of querying the model.
    """
    global client
    arg_parser = default_argument_parser()
    args = arg_parser.parse_args()
    if args.size <= 0:
        # instances[-0:] would select the whole list and break the split below.
        arg_parser.error("--size must be a positive integer")
    model = args.model
    client = OpenAI(api_key=OPENAI_API_KEY, base_url=args.url)

    features, additional_args = _collect_features(args)
    sign = to_sign(features, additional_args)
    with open("data/%s.json" % sign) as data_file:
        instances = json.load(data_file)

    # Expected layout (enforced here): valid instances first, invalid ones last.
    positives = instances[:args.size]
    negatives = instances[-args.size:]
    assert len(positives) == args.size
    assert len(negatives) == args.size
    assert all(instance['valid'] for instance in positives)
    assert not any(instance['valid'] for instance in negatives)
    instances_to_eval = positives + negatives
    # Demonstrations: the first instance after the evaluated positives and the
    # last one before the evaluated negatives.
    few_shot_samples = [instances[args.size], instances[-args.size - 1]]

    if args.nl:
        instances_to_eval = list(map(convert_to_nl, instances_to_eval))
        few_shot_samples = list(map(convert_to_nl, few_shot_samples))

    prefix = "nl_" if args.nl else ""
    result_file = "evaluator/results/%s%s_%d_%s_%s.json" % (
        prefix, sign, args.size, model, args.prompt_type)

    if args.eval:
        # Re-score a previous run; (verdict, messages) tuples come back as lists.
        with open(result_file) as result_fp:
            results_messages = json.load(result_fp)
    else:
        if os.path.exists(result_file):
            print("WARNING: result file is present")
        process_function = functools.partial(
            query_gpt, prompt_type=args.prompt_type, model=model,
            sample=few_shot_samples, nl=args.nl)
        if args.single:
            results_messages = [process_function(instance)
                                for instance in tqdm.tqdm(instances_to_eval)]
        else:
            # NOTE(review): process_map runs query_gpt in worker processes, which
            # rely on inheriting the module-level `client` assigned above. That
            # holds under the "fork" start method but not "spawn" (the default on
            # macOS and Windows) -- confirm on the target platform.
            results_messages = process_map(
                process_function, instances_to_eval, max_workers=8)
        with open(result_file, "w") as result_fp:
            json.dump(results_messages, result_fp)

    accuracy, f1 = _compute_metrics(results_messages, instances_to_eval)
    print("Acc", round(accuracy, 3))
    print("F1", round(f1, 3))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
