import os
import json
import re


# Matches "(<x>; <y>; <z>)". Raw string so "\(" / "\)" are regex escapes
# rather than invalid Python string escapes (SyntaxWarning since 3.12).
# Compiled once at import time instead of on every call.
_TRIPLE_PATTERN = re.compile(r"\((.*); (.*); (.*)\)")


def parse_triple(text):
    """Extract the first "(a; b; c)" triple from *text*.

    The three fields are separated by "; " and enclosed in parentheses.
    ``.*`` is greedy and does not cross newlines. As a result:

    - Extra "; " separators are absorbed by the earlier fields. For example,
      "(a; b; c; d)" yields ("a; b", "c", "d").
    - A match extends to the last ")" that still allows a match on that line.

    Args:
        text: Arbitrary text that may contain a triple.

    Returns:
        A 3-tuple of strings, or ``None`` if no triple is found.
    """
    match = _TRIPLE_PATTERN.search(text)
    if match is None:
        return None
    return match.groups()

def comput_f1(input_file):
    """Compute micro-averaged precision, recall and F1 of predicted triples.

    ``input_file`` is a JSON-lines file with one JSON object per line.

    - The gold text is read from ``["instance"]["references"][0]["output"]["text"]``.
    - The predicted text is read from ``["request"]["result"]["choices"][0]["text"]``.
    - After stripping, the literal ``"NA"`` means "no triple".
    - Any other text is parsed with ``parse_triple``.
    - A prediction that fails to parse counts as no prediction.
    - A gold text that fails to parse is an error.

    Counts are summed over all lines before the ratios are taken.

    Args:
        input_file: Path to the JSON-lines results file.

    Returns:
        A dict with float values under ``"precision"``, ``"recall"`` and
        ``"f1"``. A ratio whose denominator is zero is reported as 0.0 rather
        than raising: no predictions, no gold triples, or (for F1) zero
        precision and recall.

    Raises:
        ValueError: If a non-"NA" gold text cannot be parsed as a triple.
    """
    tp = 0
    n_gold = 0
    n_pred = 0
    with open(input_file, encoding="utf-8") as f:
        # Iterate lazily instead of readlines() so large files are streamed.
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a stray empty line) instead of
                # crashing in json.loads.
                continue
            instance = json.loads(line)

            # Gold triple(s).
            gold_triples = []
            gold_text = instance["instance"]["references"][0]["output"]["text"].strip()
            if gold_text != "NA":
                gold_triple = parse_triple(gold_text)
                # Explicit check instead of `assert`, which is stripped
                # under `python -O`.
                if gold_triple is None:
                    raise ValueError(
                        f"{input_file}:{line_no}: cannot parse gold triple {gold_text!r}"
                    )
                gold_triples.append(gold_triple)

            # Predicted triple(s); malformed model output counts as "no prediction".
            pred_triples = []
            pred_text = instance["request"]["result"]["choices"][0]["text"].strip()
            if pred_text != "NA":
                pred_triple = parse_triple(pred_text)
                if pred_triple is not None:
                    pred_triples.append(pred_triple)

            tp += sum(1 for triple in gold_triples if triple in pred_triples)
            n_gold += len(gold_triples)
            n_pred += len(pred_triples)

    # Guard every division: an empty or all-"NA" file must not crash scoring.
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the results file. The default is
    # the original hard-coded path, resolved against the current working
    # directory.
    results_path = (
        sys.argv[1] if len(sys.argv) > 1 else "../../results/tacred/results.jsonl"
    )
    result = comput_f1(results_path)
    print(result)

