import json
import argparse
import utils
import typing
import tqdm
import functools
import typing
import os
from meteor_reasoner.utils import parser as parser
from nl import convert_to_nl
from key import keys
import io
import time
import asyncio
import random


def default_argument_parser():
    """Build the command-line parser for this DatalogMTL evaluation script.

    The boolean feature switches and the optional integer qualifiers are
    turned into the dataset signature in ``main`` (via ``to_sign``); the
    remaining options pick the execution mode, the number of positive and
    negative instances, the model key set, and the prompting strategy.

    Returns:
        A configured ``argparse.ArgumentParser`` (not yet parsed).
    """
    arg_parser = argparse.ArgumentParser(description="generate DatalogMTL data")

    # Dataset feature switches; registration order is kept so --help is unchanged.
    feature_flags = (
        "--rational-number",
        "--multiple-body-atoms",
        "--recursive",
        "--variable",
        "--mixed-operators",
        "--multiple-rules",
        "--mixed-operators-2",
        "--mixed-operators-3",
        "--mixed-operators-4",
    )
    for flag in feature_flags:
        arg_parser.add_argument(flag, action='store_true')

    # Optional integer knobs; when given they are appended to the dataset signature.
    for flag in ("--fixed-new-nodes", "--fixed-rules", "--addl-rules", "--addl-data"):
        arg_parser.add_argument(flag, type=int)

    # NOTE(review): `default` is never used because `required=True` forces the
    # user to pass --mode explicitly.
    arg_parser.add_argument(
        "--mode", choices=["single", "multi", "batch", "eval"], default="batch", required=True)
    arg_parser.add_argument("--nl", action='store_true',
                            help="translate logic into natural language", default=False)
    arg_parser.add_argument("--size", type=int, required=True)
    arg_parser.add_argument(
        "--model", choices=keys.keys(), required=True)
    # NOTE(review): --url is not read anywhere in this file as visible — confirm it is still needed.
    arg_parser.add_argument("--url", default="https://api.openai.com/v1")
    arg_parser.add_argument(
        "--prompt-type", choices=["zero-shot", "few-shot", "zero-shot-cot"], required=True)
    return arg_parser


def to_sign(features: set, additional_args: dict) -> str:
    """Return the dataset signature string for a feature set and extra knobs.

    If one of the exclusive ``mix_operators_{2,3,4}`` markers is present it
    must be the only feature, and its name is returned verbatim (extra args
    are ignored). Otherwise the signature is the features (or ``"basic"``
    when there are none) plus ``key=value`` for every non-None extra arg,
    sorted and joined with ``-``.
    """
    # NOTE(review): main() derives feature names from argparse dests
    # ("mixed_operators_2", ...), so these "mix_operators_*" checks never match
    # on that path. Left unchanged because "fixing" them would change the
    # dataset file name that gets loaded — confirm intent.
    for exclusive in ("mix_operators_2", "mix_operators_3", "mix_operators_4"):
        if exclusive in features:
            assert (len(features) == 1)
            return exclusive

    parts = list(features) or ["basic"]
    parts.extend(
        "%s=%s" % (name, str(value))
        for name, value in additional_args.items()
        if value is not None
    )
    return "-".join(sorted(parts))


def generate_messages(instance: dict, prompt_type: typing.Literal["zero-shot", "few-shot", "zero-shot-cot"], sample=None, nl: bool = False) -> typing.List:
    """Build the chat-completion message list for one evaluation instance.

    Args:
        instance: Mapping with ``'data'`` (iterable of fact strings), ``'rule'``
            (iterable of rule strings) and ``'query'`` (the fact to judge).
        prompt_type: ``"zero-shot"``, ``"few-shot"`` or ``"zero-shot-cot"``.
        sample: Required for ``"few-shot"``: a two-element sequence whose first
            entry is a valid example and whose second is an invalid one, each
            shaped like *instance*.
        nl: When true, the DatalogMTL / MTL-operator preamble is left out of
            the system prompt (presumably because the instance has already been
            rendered as natural language).

    Returns:
        ``[system_message, user_message]`` as ``{"role", "content"}`` dicts.

    Raises:
        ValueError: if *prompt_type* is not a supported strategy.
        AssertionError: for ``"few-shot"`` when *sample* is missing or not
            ordered (valid, invalid).
    """
    # NOTE: the backslash continuations in the literals below splice the next
    # source line — including its leading indentation — into the prompt. The
    # text is kept byte-for-byte so results stay comparable with earlier runs.
    info_prompt = "You are given a dataset and a temporal rule, and your task is to judge whether the given fact is entailed by the dataset and the rule. \
The rules are expressed as DatalogMTL, a knowledge representation language that extends Datalog with operators from metric temporal logic (MTL). The semantics of four MTL operators are given as follows:\
If Diamondminus[a,b]A is true at the time t, it requires that A needs to be true at some time between t-b and t-a.\
If Boxminus[a,b]A is true at the time t, it requires that A needs to be true continuously between t-b and t-a.\
If Diamondplus[a,b]A is true at the time t, it requires that A needs to be true at some point between t+a and t+b.\
If Boxplus[a,b]A is true at the time t, it requires that A needs to be true continuously between t+a and t+b."
    if nl:
        info_prompt = ""
    if prompt_type == "zero-shot":
        system_prompt = info_prompt + \
            "You should not give any explanation and you should only output \"true\" or \"false\""
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": "Now we have some temporal data and some rules, data:\n %s \n\n rule: %s\n\n\
Is %s true or not?" % ("\n".join(instance['data']), "\n".join(instance['rule']), instance['query'])
            }
        ]
        return messages
    elif prompt_type == "few-shot":
        # The two demonstrations must be one entailed and one non-entailed example, in that order.
        assert (sample is not None)
        assert (sample[0]['valid'])
        assert (not sample[1]['valid'])
        system_prompt = info_prompt + \
            "You should not give any explanation and you should only output \"true\" or \"false\""
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": "To help you better understand the task, I will provide two examples.\n\
                    Example 1: data:\n %s \n\n rule: %s\n\n in this case you should output \"true\" for %s\n\n\
                    Example 2: data:\n %s \n\n rule: %s\n\n in this case you should output \"false\" for %s\n\n\
                    Now we have some temporal data and some rules, data:\n %s \n\n rule: %s\n\n\
                    Is %s true or not?" % ("\n".join(sample[0]['data']), "\n".join(sample[0]['rule']), sample[0]['query'],
                                           "\n".join(sample[1]['data']), "\n".join(
                                               sample[1]['rule']), sample[1]['query'],
                                           "\n".join(instance['data']), "\n".join(instance['rule']), instance['query'])
            }
        ]
        return messages
    elif prompt_type == "zero-shot-cot":
        # No answer-format constraint here: the caller asks for the bare verdict in a follow-up turn.
        system_prompt = info_prompt
        messages = [
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": "Now we have some temporal data and some rules, data:\n %s \n\n rule: %s\n\n\
                    Is %s true or not? Do not answer directly, think step by step." % ("\n".join(instance['data']), "\n".join(instance['rule']), instance['query'])
            }]
        return messages
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
        raise ValueError("strategy %s unknown" % prompt_type)


async def _chat_once(client, model: str, messages: list, max_tokens: int):
    """Send *messages* once to ``client.chat.completions.create`` and return the reply text.

    Uses ``temperature=0``. Prints a warning when the first choice's
    ``finish_reason`` is not ``"stop"``. The returned content may be None.
    """
    completion = await client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0
    )
    choice = completion.choices[0]
    if choice.finish_reason != "stop":
        print("[LLM WARNING] the LLM request doesn't stop gracefully, reason: ",
              choice.finish_reason)
    return choice.message.content


async def single_query(client, instance: dict, prompt_type: typing.Literal["zero-shot", "few-shot", "zero-shot-cot"],
                       model: str, sample, nl: bool, max_tokens: int, semaphore: asyncio.Semaphore, task_id: int):
    """Query the model for one instance while holding *semaphore*.

    The conversation is built by ``generate_messages`` and the assistant
    reply is appended to it. For ``"zero-shot-cot"`` with a non-None first
    reply, a second turn asks the model to state only ``true``/``false``, and
    that reply is appended as well.

    Returns:
        ``(messages, task_id)``: the full conversation and the caller's id,
        so results gathered out of order can be put back in place.
    """
    async with semaphore:
        messages = generate_messages(instance, prompt_type, sample, nl)
        response = await _chat_once(client, model, messages, max_tokens)
        messages.append({"role": "assistant", "content": response})
        if prompt_type == "zero-shot-cot" and response is not None:
            # Second turn: collapse the step-by-step reasoning into a bare verdict.
            messages.append(
                {"role": "user", "content": "Based on your previous response, without any explanation state your answers to the question. You should only output \"true\" or \"false\""})
            response = await _chat_once(client, model, messages, max_tokens)
            messages.append({"role": "assistant", "content": response})
        return messages, task_id


async def multi_query(keys: typing.List[dict], prompt_type: str, few_shot_samples, nl: bool, instances_to_eval,
                      concurrency: int = 16) -> typing.List[typing.List[dict]]:
    """Query all instances concurrently, spreading them round-robin over *keys*.

    Args:
        keys: Key configs; each needs ``'NAME'`` (model name) and
            ``'MAX_TOKEN'`` and is turned into a client by
            ``utils.build_async_client``.
        prompt_type, few_shot_samples, nl: Forwarded to ``single_query``.
        instances_to_eval: Instances to evaluate.
        concurrency: Maximum number of in-flight requests (default 16, the
            previously hard-coded value).

    Returns:
        One conversation per instance, in the same order as *instances_to_eval*.
    """
    semaphore = asyncio.Semaphore(concurrency)
    clients = [utils.build_async_client(key) for key in keys]
    tasks = []
    for i, instance in enumerate(instances_to_eval):
        key_id = i % len(keys)  # round-robin assignment of instances to keys
        key = keys[key_id]
        tasks.append(single_query(clients[key_id], instance, prompt_type, key['NAME'],
                                  few_shot_samples, nl, key['MAX_TOKEN'], semaphore, i))

    results = [None] * len(tasks)
    for pending in tqdm.tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Fetching URLs"):
        # Completions arrive out of order; task_id restores the input position.
        conversation, task_id = await pending
        results[task_id] = conversation

    return results


def _wait_for_batch(client, batch_id: str, poll_interval: float = 5):
    """Poll ``client.batches.retrieve(batch_id)`` until it reports a terminal status.

    Prints every status object received and returns the last one.
    """
    # "expired" is terminal too; without it the loop never exits for a batch
    # that missed its completion window.
    terminal_statuses = ("completed", "failed", "cancelled", "expired")
    while True:
        batch_status = client.batches.retrieve(batch_id)
        print(batch_status)
        if batch_status.status in terminal_statuses:
            return batch_status
        time.sleep(poll_interval)


def _collect_batch_results(output_text: str, messages: list) -> list:
    """Pair each batch output line with the request it answers.

    *output_text* is the JSONL body of the batch output file. Each line's
    ``custom_id`` is the index of its request in *messages*; position ``i`` of
    the result holds ``messages[i] + [assistant_message]``.
    """
    response_objs = [json.loads(line) for line in output_text.splitlines() if line.strip()]
    assert len(response_objs) == len(messages), \
        "batch returned %d responses for %d requests" % (len(response_objs), len(messages))
    results_messages = [None] * len(messages)
    for response_obj in response_objs:
        # Output order need not follow input order, so look the prompt up by
        # custom_id instead of zipping positionally.
        idx = int(response_obj['custom_id'])
        response_msg = response_obj['response']['body']['choices'][0]['message']
        results_messages[idx] = messages[idx] + [response_msg]
    return results_messages


def batch_query(key: dict, prompt_type: str, few_shot_samples, nl: bool, instances_to_eval):
    """Evaluate *instances_to_eval* through the Batch API of the client built from *key*.

    Builds one chat-completion request per instance (``custom_id`` = its
    index), uploads them as a JSONL file, creates a 24h batch, polls it every
    5 seconds until it terminates, and returns one conversation per instance
    (prompt messages plus the assistant reply) in input order.

    Raises:
        RuntimeError: if the batch does not finish as ``"completed"`` or has
            no output file.
    """
    model_name = key["NAME"]

    process_function = functools.partial(
        generate_messages, prompt_type=prompt_type, sample=few_shot_samples, nl=nl)
    messages = list(map(process_function, instances_to_eval))
    # NOTE(review): unlike single_query, no "temperature": 0 is sent here, so batch
    # runs presumably use the API's default temperature — confirm intended.
    request_list = "\n".join(
        json.dumps({"custom_id": str(idx), "method": "POST", "url": "/v1/chat/completions",
                    "body": {"model": model_name, "messages": msg, "max_tokens": key['MAX_TOKEN']}})
        for idx, msg in enumerate(messages))

    client = utils.build_client(key)
    batch_input_file = client.files.create(
        file=io.BytesIO(request_list.encode()),
        purpose="batch"
    )
    batch_request_result = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": "tBen eval"
        }
    )
    print(batch_request_result)

    batch_status = _wait_for_batch(client, batch_request_result.id)
    if batch_status.status != "completed" or batch_status.output_file_id is None:
        raise RuntimeError("batch %s finished with status %r; no output to collect"
                           % (batch_request_result.id, batch_status.status))
    output_text = client.files.content(batch_status.output_file_id).text
    return _collect_batch_results(output_text, messages)


def _load_json(path: str):
    """Return the JSON document stored at *path*, closing the file afterwards."""
    with open(path) as f:
        return json.load(f)


def _dump_json(obj, path: str) -> None:
    """Write *obj* as JSON to *path*, closing (and thus flushing) the file."""
    with open(path, "w") as f:
        json.dump(obj, f)


def _parse_verdict(content):
    """Interpret one model reply as a boolean verdict.

    Lower-cases *content*. If it contains ``</think>``, keeps the stripped
    segment that follows the first such tag. Returns True/False for exactly
    ``"true"``/``"false"``. Returns None for a None reply, or (after printing
    a notice) for any other text.
    """
    if content is None:
        return None
    response = content.lower()
    if "</think>" in response:
        response = response.split("</think>")[1].strip()
    if response == "true":
        return True
    if response == "false":
        return False
    print("Invalid response, %s, found" % response)
    return None


def _confusion_counts(results_messages, instances_to_eval):
    """Return ``(tp, fp, tn, fn)`` of parsed final replies vs. each instance's ``'valid'``.

    Conversations whose last reply is missing or unparseable are skipped and
    appear in no count.
    """
    tp = fp = tn = fn = 0
    for message, instance in zip(results_messages, instances_to_eval):
        result = _parse_verdict(message[-1]['content'])
        if result is None:
            continue
        expected = instance['valid']
        if expected == True and result == True:
            tp += 1
        elif expected == False and result == True:
            fp += 1
        elif expected == True and result == False:
            fn += 1
        elif expected == False and result == False:
            tn += 1
    return tp, fp, tn, fn


def _print_metrics(tp: int, fp: int, tn: int, fn: int) -> None:
    """Print accuracy and F1 rounded to 3 places; an empty denominator yields 0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 0
    if precision + recall != 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    scored = tp + tn + fp + fn
    # Previously a ZeroDivisionError when no reply could be parsed.
    accuracy = (tp + tn) / scored if scored > 0 else 0

    print("Acc", round(accuracy, 3))
    print("F1", round(f1, 3))


async def main():
    """CLI entry point: load a generated dataset, obtain model replies, and score them.

    Steps:
      1. Parse arguments.
      2. Derive the dataset signature from the feature flags and integer
         qualifiers.
      3. Select the instances: the first ``--size`` instances (asserted valid)
         and the last ``--size`` instances (asserted invalid) are evaluated.
         The instance just after the first slice and the one just before the
         last slice serve as few-shot examples.
      4. Optionally convert everything to natural language.
      5. Obtain replies according to ``--mode``, or reload them in ``eval``
         mode.
      6. Save the replies, then print accuracy and F1.
    """
    arg_parser = default_argument_parser()
    args = arg_parser.parse_args()
    if args.size <= 0:
        # instances[-0:] would select the WHOLE dataset as the "negative" slice.
        arg_parser.error("--size must be a positive integer")
    selected_keys = keys[args.model]

    FEATURES_LIST = ["rational_number", "multiple_body_atoms",
                     "recursive", "variable", "mixed_operators",
                     "multiple_rules", "multiple_rules_2", "multiple_rules_3",
                     "mixed_operators_2", "mixed_operators_3", "mixed_operators_4"]
    ADDITIONAL_ARGS_KEYS = ["fixed_new_nodes",
                            "fixed_rules", "addl_data", "addl_rules"]
    features = set()
    additional_args = {}
    for arg, value in vars(args).items():
        if value and arg in FEATURES_LIST:
            features.add(arg)
        if arg in ADDITIONAL_ARGS_KEYS and value is not None:
            additional_args[arg] = value
    sign = to_sign(features, additional_args)
    instances = _load_json("data/%s.json" % sign)

    positives = instances[:args.size]
    negatives = instances[-args.size:]
    assert all(instance['valid'] for instance in positives)
    assert not any(instance['valid'] for instance in negatives)
    instances_to_eval = positives + negatives
    # Few-shot demonstrations sit immediately outside each evaluation slice.
    few_shot_samples = [instances[args.size], instances[-args.size - 1]]

    if args.nl:
        instances_to_eval = list(map(convert_to_nl, instances_to_eval))
        few_shot_samples = list(map(convert_to_nl, few_shot_samples))

    assert len(positives) == args.size
    assert len(negatives) == args.size
    prefix = "nl_" if args.nl else ""
    result_file = "evaluator/results/%s%s_%d_%s_%s.json" % (
        prefix, sign, args.size, args.model, args.prompt_type)

    if args.mode == "eval":
        print("Evaluating %s" % result_file)
        results_messages = _load_json(result_file)
    elif args.mode == "single":
        key = selected_keys[0]
        client = utils.build_async_client(key)
        # Sequential loop: the semaphore only satisfies single_query's signature.
        semaphore = asyncio.Semaphore()
        results_messages = []
        for instance_to_eval in tqdm.tqdm(instances_to_eval):
            # max_tokens was previously omitted, so this mode always raised TypeError.
            conversation, _ = await single_query(
                client, instance_to_eval, args.prompt_type, model=key['NAME'],
                sample=few_shot_samples, nl=args.nl, max_tokens=key['MAX_TOKEN'],
                semaphore=semaphore, task_id=0)
            results_messages.append(conversation)
    elif args.mode == "multi":
        results_messages = await multi_query(selected_keys, prompt_type=args.prompt_type,
                                             few_shot_samples=few_shot_samples, nl=args.nl, instances_to_eval=instances_to_eval)
    elif args.mode == "batch":
        if len(selected_keys) != 1:
            raise Exception("Only one key per model is supported at this time")
        if args.prompt_type == "zero-shot-cot":
            raise Exception("CoT is not supported in batch mode")
        results_messages = batch_query(selected_keys[0], prompt_type=args.prompt_type,
                                       few_shot_samples=few_shot_samples, nl=args.nl, instances_to_eval=instances_to_eval)
    else:
        # Unreachable while argparse restricts --mode to the four choices above.
        raise ValueError("unknown mode %s" % args.mode)

    if args.mode != "eval":
        _dump_json(results_messages, result_file)

    _print_metrics(*_confusion_counts(results_messages, instances_to_eval))


if __name__ == "__main__":
    # Script entry point: run the async main() to completion on a fresh event loop.
    asyncio.run(main())
