import json
import argparse
import utils
import tqdm
import typing
from tqdm.contrib.concurrent import process_map
import functools
from multiprocessing import Pool
import typing
from meteor_reasoner.utils.loader import load_dataset, load_program
from meteor_reasoner.utils import parser as parser


# NOTE(review): module-level placeholder; nothing visible in this file reads
# or assigns a non-None value to it. Kept in case other code imports it —
# confirm before removing.
client = None


def default_argument_parser():
    """Build the command-line parser used to pick the dataset to verify.

    The boolean switches select dataset features. ``--fixed-new-nodes`` and
    ``--fixed-rules`` take integers and default to ``None`` when omitted.
    ``--single`` is a boolean switch that ``main`` uses to run instances
    sequentially instead of through a process pool.

    Returns:
        argparse.ArgumentParser: the configured (not yet invoked) parser.
    """
    # Named ``arg_parser`` so it does not shadow the imported ``parser`` module.
    arg_parser = argparse.ArgumentParser(description="verify DatalogMTL data")

    # Registration order is kept identical so the resulting Namespace (and
    # --help output) lists options in the same order as before.
    feature_switches = (
        "--rational-number",
        "--multiple-body-atoms",
        "--recursive",
        "--variable",
        "--mixed-operators",
        "--multiple-rules",
        "--mixed-operators-2",
        "--mixed-operators-3",
        "--mixed-operators-4",
    )
    for switch in feature_switches:
        arg_parser.add_argument(switch, action='store_true')

    for numeric_option in ("--fixed-new-nodes", "--fixed-rules"):
        arg_parser.add_argument(numeric_option, type=int)

    arg_parser.add_argument("--single", action='store_true')
    return arg_parser


def to_sign(features: set, additional_args: dict) -> str:
    """Turn a feature set plus optional numeric settings into a dataset signature.

    If ``features`` contains one of ``mix_operators_2``/``_3``/``_4`` (checked in
    that order), that name is returned on its own and ``features`` must contain
    nothing else (enforced with ``assert``). Otherwise the signature is the
    sorted, ``-``-joined list of the feature names (``"basic"`` when there are
    none) plus a ``key=value`` entry for every ``additional_args`` item whose
    value is not ``None``.

    Args:
        features: feature names selected for the dataset.
        additional_args: extra settings; ``None`` values are ignored.

    Returns:
        str: the signature, e.g. ``"fixed_rules=2-recursive"``.
    """
    for exclusive in ("mix_operators_2", "mix_operators_3", "mix_operators_4"):
        if exclusive in features:
            # These variants cannot be combined with any other feature.
            assert (len(features) == 1)
            return exclusive

    # An empty feature selection is labelled "basic".
    parts = list(features) or ["basic"]
    parts += [
        "%s=%s" % (key, str(value))
        for key, value in additional_args.items()
        if value is not None
    ]
    return "-".join(sorted(parts))

def infer(instance:dict) -> bool:
    """Run the reasoner on one instance and test whether its query is entailed.

    Reads the ``'data'``, ``'rule'`` and ``'query'`` entries of ``instance``.
    The data and rules are loaded with ``load_dataset``/``load_program``,
    ``utils.infer`` extends the dataset with the program, and ``utils.entail``
    checks the parsed query fact against the result.

    NOTE(review): the exact format of the ``'data'``/``'rule'``/``'query'``
    values is defined by the meteor_reasoner loaders/parser — confirm there.

    Args:
        instance: one record from the dataset JSON file.

    Returns:
        bool: the value reported by ``utils.entail``.
    """
    facts = load_dataset(instance['data'])
    rules = load_program(instance['rule'])

    # The second value returned by utils.infer is not needed here.
    materialized, _ = utils.infer(facts, rules)

    pred, ent, span = parser.parse_str_fact(instance['query'])
    query_fact = (pred, ent, span)
    return utils.entail(materialized, query_fact)

def main():
    """Check the reasoner's answers against the labels of a stored dataset.

    Parses the CLI flags, derives the dataset signature with ``to_sign`` and
    loads ``data/<signature>.json``. It then runs ``infer`` on every instance,
    sequentially with ``--single`` and otherwise through ``process_map`` with
    up to 8 workers. Each result is compared against the instance's
    ``'valid'`` entry.

    Raises:
        AssertionError: if any instance's result differs from its ``'valid'``
            label. Raised explicitly rather than via ``assert``, so the check
            still runs under ``python -O``, and the message lists the
            offending instance indices.
    """
    args = default_argument_parser().parse_args()

    # Boolean flags that select which dataset variant is verified.
    # NOTE(review): "multiple_rules_2"/"multiple_rules_3" have no matching CLI
    # flag, and the "mixed_operators_{2,3,4}" names produced by the flags do
    # not match the "mix_operators_{2,3,4}" special cases in `to_sign` —
    # confirm the intended data file names.
    FEATURES_LIST = ["rational_number", "multiple_body_atoms",
                     "recursive", "variable", "mixed_operators",
                     "multiple_rules", "multiple_rules_2", "multiple_rules_3",
                     "mixed_operators_2", "mixed_operators_3", "mixed_operators_4"]
    # Optional integer settings folded into the signature when provided.
    ADDITIONAL_ARGS_KEYS = ["fixed_new_nodes", "fixed_rules"]

    arg_values = vars(args)
    features = {name for name, value in arg_values.items()
                if value and name in FEATURES_LIST}
    additional_args = {name: value for name, value in arg_values.items()
                       if name in ADDITIONAL_ARGS_KEYS}

    data_path = "data/%s.json" % to_sign(features, additional_args)
    # Close the file promptly instead of leaking the handle; decode as UTF-8
    # rather than relying on the platform's default encoding.
    with open(data_path, encoding="utf-8") as data_file:
        instances = json.load(data_file)

    if args.single:
        results = [infer(instance) for instance in tqdm.tqdm(instances)]
    else:
        # tqdm's process_map distributes the instances over worker processes.
        results = process_map(infer, instances, max_workers=8)

    mismatches = [
        idx
        for idx, (result, instance) in enumerate(zip(results, instances))
        if result != instance['valid']
    ]
    if mismatches:
        raise AssertionError(
            "%d of %d instances disagree with their 'valid' label "
            "(first indices: %s)"
            % (len(mismatches), len(instances), mismatches[:20]))


if __name__ == "__main__":
    # Run the verification only when executed as a script, not on import.
    main()
