import argparse
import copy
import json
import os
from collections import Counter

def _parse_gold_labels(text):
    """Split a reference answer into its non-empty ';'-separated labels.

    All whitespace is removed first, so ``"a : B; c:D"`` yields
    ``["a:B", "c:D"]``.
    """
    compact = ''.join(text.split())
    return [label for label in compact.split(';') if label]


def _parse_pred_labels(text):
    """Extract predicted labels from a model completion.

    All whitespace is removed. Then only the text between the first
    ``'Answer:'`` marker and the next one (or the end of the string) is kept.
    A completion without the marker yields no labels. A label is kept only if
    it is non-empty, contains ``':'`` and does not contain the substring
    ``'NA'``.
    """
    compact = ''.join(text.split())
    segments = compact.split('Answer:')
    if len(segments) < 2:
        # No "Answer:" marker: the model produced no parsable prediction.
        return []
    # NOTE(review): this is a substring test, so any label that merely contains
    # "NA" (not only a bare "NA" placeholder) is discarded -- confirm intended.
    return [
        label
        for label in segments[1].split(';')
        if label and 'NA' not in label and ':' in label
    ]


def comput_f1(input_file):
    """Compute micro-averaged precision, recall and F1 over ``<input_file>/test.json``.

    Args:
        input_file: directory that contains ``test.json``. The parameter name
            is kept for backward compatibility; it is a directory, not a file.

    Returns:
        A dict with float values under ``"precision"``, ``"recall"`` and
        ``"f1"``. A tiny epsilon in each denominator makes empty inputs
        yield 0.0 instead of raising ZeroDivisionError.

    Only request states whose ``result`` has a truthy ``success`` flag are
    scored. The others contribute nothing to any count.
    """
    path = os.path.join(input_file, 'test.json')
    # JSON is UTF-8 by specification; do not depend on the locale's default.
    with open(path, encoding='utf-8') as f:
        request_states = json.load(f)["request_states"]

    tp = n_gold = n_pred = 0
    for state in request_states:
        result = state["request"]["result"]
        if not result.get('success', False):
            continue
        gold_labels = _parse_gold_labels(
            state["instance"]["references"][0]["output"]["text"]
        )
        pred_labels = _parse_pred_labels(result["completions"][0]["text"])
        # Multiset intersection: each gold label can be matched at most as many
        # times as it occurs, so repeated predictions are not over-rewarded.
        tp += sum((Counter(pred_labels) & Counter(gold_labels)).values())
        n_gold += len(gold_labels)
        n_pred += len(pred_labels)

    precision = tp / (n_pred + 1e-10)
    recall = tp / (n_gold + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate Event Detection")
    # IO
    # Required: without it comput_f1 would crash inside os.path.join(None, ...).
    parser.add_argument("--input_dir", type=str, required=True,
                        help="directory containing test.json")
    # Flag name kept for compatibility; the value is a file path, not a directory.
    parser.add_argument("--output_dir", type=str, default='result.json',
                        help="path of the JSON file the metrics are written to")

    args = parser.parse_args()

    result = comput_f1(args.input_dir)
    with open(args.output_dir, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4)
    print(result)

