import json
import argparse
import sys
from generate import call_to_open_api
from collections import defaultdict
from functools import reduce
import logging
import os

def safe_mkdir(path):
    """Create the directory *path* if nothing exists there yet.

    Missing parent directories are created as well, and a directory that
    appears between the existence check and the creation call (e.g. made by
    a concurrent run) is tolerated instead of raising ``FileExistsError``.

    If *path* already exists -- as a directory or anything else -- this is a
    no-op, matching the previous behaviour.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the check-then-create race.
        os.makedirs(path, exist_ok=True)

if __name__ == "__main__":
    # Score one episode's predicted action/state predicates against its
    # annotation: build one judging prompt per action and one per ground-truth
    # state predicate, send them through call_to_open_api, then print action
    # accuracy and predicate precision / recall / F1 and save the verdicts.
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--episode', type=str, required=True)
    parser.add_argument('-v', '--prompt_version', type=str, required=True)
    parser.add_argument('--model', type=str, default='gpt-3.5-turbo')

    args = parser.parse_args()

    if 'gpt-4' in args.model:
        print("Evaluation is a O(N x M1 x M2) number of API calls, "
              "where N is the number of action annotations, "
              "M1 is the max number of state-predicates predicted per action, "
              "and M2 is the max number of predicates in ground truth per action. "
              "GPT-4 is an expensive model - HIGHLY SUGGEST to terminate this script here "
              "and use --model gpt-3.5-turbo.", file=sys.stderr)
        input("You have been warned. Press ENTER to continue")

    safe_mkdir(os.path.join("eval_data/", args.prompt_version))

    prediction_file = os.path.join("raw_data", args.prompt_version, f"{args.prompt_version}_{args.episode}_output.json")
    ground_truth_file = os.path.join("eval_data/annotation", f"{args.episode}_ann.json")

    # JSON interchange files are UTF-8; do not depend on the platform default.
    with open(prediction_file, 'r', encoding='utf-8') as fin:
        pd_json = json.load(fin)

    with open(ground_truth_file, 'r', encoding='utf-8') as fin:
        gt_json = json.load(fin)

    # Both files map some outer key to a list of records; flatten each into a
    # single {record id: record} lookup.
    predictions = {}
    for preds in pd_json.values():
        predictions.update({pred['id']: pred for pred in preds})
    ground_truths = {}
    for gts in gt_json.values():
        ground_truths.update({gt['id']: gt for gt in gts})

    prompts_json = {}
    prompts_json['header'] = ("You are a helpful AI assistant designed to evaluate state and action predicates "
                              "produced by an open-vocabulary language model that has visual context embedded\n"
                              "Each time you will see all predictions separated by '; ', and one ground-truth predicate\n"
                              "You need to find a predicate that matches the ground-truth predicate, or say there is none.\n"
                              "Actions and objects involved should be the same.\n"
                              "e.g. open('container_1') and take_off('lid_1', 'container_1') is a match (same action same objects)\n"
                              "e.g. on_top_of('container_1', 'kitchentop_1') and on('bottle_1', 'kitchentop_1') do not match (different object)\n"
                              "e.g. pick_up('mug_1') and pour('coffee', 'mug_1') do not match (different actions)\n"
                              "Output template if there is no match:\n"
                              "MATCH:<ground-truth>:none-of-the-above\n"
                              "REASON:no prediction has the same meaning as <ground-truth>\n"
                              "ANSWER:[[no]]\n"
                              "Output template (if there is a match):\n"
                              "MATCH:<ground-truth>:<matched-prediction>\n"
                              "REASON:<reason-for-match>\n"
                              "ANSWER:[[yes]]\n")
    prompts_json['example'] = {'conversations': {}}  # no need for examples
    prompts_json['prompts'] = []
    prompts = prompts_json['prompts']

    # Counters are keyed by annotation id rather than timestep, so two
    # annotations that happen to share a timestep cannot overwrite each other.
    # Only the totals over all keys are used below.
    num_retrieved = defaultdict(int)  # predicted state predicates per annotation
    num_relevant = defaultdict(int)   # ground-truth state predicates per annotation
    for pred_id, pred_data in predictions.items():
        pred = pred_data['prediction_data']
        t = pred['timestep']
        p_act = pred['action']
        # Attach the ground truth so it is saved next to the prediction.
        # A prediction id missing from the annotation raises KeyError here.
        pred_data['ground_truth_data'] = ground_truths[pred_id]['prediction_data']
        gt_act = pred_data['ground_truth_data']['action']
        prompts.append({"id": f"{pred_id}-{t}-act", "conversations": [{'from': 'human', 'value': f"PREDICTION: {p_act}\nGROUND TRUTH: {gt_act}"}]})
        p_st = pred['states']
        gt_st = pred_data['ground_truth_data']['states']
        num_retrieved[pred_id] = len(p_st)
        num_relevant[pred_id] = len(gt_st)
        # One prompt per ground-truth predicate, each listing every predicted
        # predicate plus an explicit "none-of-the-above" choice.
        for i, gt in enumerate(gt_st):
            prompts.append({"id": f"{pred_id}-{t}-{i}", "conversations": [{'from': 'human', 'value': f"PREDICTIONS: {'; '.join(p_st + ['none-of-the-above'])}\nGROUND TRUTH: {gt}"}]})

    temp_output_fp = os.path.join("eval_data/", args.prompt_version, f"{args.prompt_version}_{args.episode}_eval_temp.json")
    outputs = call_to_open_api(prompts_json, temp_output_fp, args.model, logging.WARNING)

    # Sets, so that an output appearing more than once is counted only once.
    correct_action_ids = set()
    # Indices of ground-truth predicates judged [[yes]], per annotation id.
    num_retrieved_relevant = defaultdict(set)
    for out in outputs:
        # Prompt ids are "<annotation id>-<timestep>-<kind>"; split from the
        # right so an annotation id containing '-' is still parsed correctly.
        pred_id, _timestep, kind = out['id'].rsplit('-', 2)
        eval_record = predictions[pred_id].setdefault('eval', {})
        # The first conversation entry is skipped (presumably the human prompt
        # echoed back); the remaining entries are treated as model replies.
        _, *gpts = out['conversations']
        gpt = gpts[0]
        # A prompt counts as matched only if every reply contains [[yes]].
        matched = all('[[yes]]' in reply['value'] for reply in gpts)
        # Compare the last id component exactly; a substring test on the whole
        # id would misfire for annotation ids that merely contain "act".
        if kind == 'act':
            eval_record['action'] = gpt['value']
            if matched:
                correct_action_ids.add(pred_id)
            continue
        state_idx = int(kind)
        eval_record[f"state-{state_idx}"] = gpt['value']
        if matched:
            num_retrieved_relevant[pred_id].add(state_idx)

    num_correct_actions = len(correct_action_ids)
    # NOTE: the numerator counts matched ground-truth predicates and is used
    # for both precision and recall, as in the original metric definition.
    num_matched = sum(len(matched_idxs) for matched_idxs in num_retrieved_relevant.values())
    total_retrieved = sum(num_retrieved.values())
    total_relevant = sum(num_relevant.values())
    # Fall back to 0 instead of raising ZeroDivisionError on empty inputs.
    precision = num_matched / total_retrieved if total_retrieved else 0.0
    recall = num_matched / total_relevant if total_relevant else 0.0
    accuracy = num_correct_actions / len(predictions) if predictions else 0.0
    print(f"accuracy in num actions = {100 * accuracy:.2f}%")
    print(f"precision in predicates = {100 * precision:.2f}%")
    print(f"recall in predicates = {100 * recall:.2f}%")
    if precision == 0 or recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    print(f"F1 in predicates = {100 * f1:.2f}%")

    output_fp = os.path.join("eval_data/", args.prompt_version, f"{args.prompt_version}_{args.episode}_eval.json")

    # Save the predictions, now carrying 'ground_truth_data' and the raw
    # judge replies under 'eval'.
    with open(output_fp, 'w') as fout:
        json.dump(predictions, fout, indent=4)

