import re
import json
from transformers import AutoTokenizer

# metric
# Matches one "(head; relation; tail)" triple. Raw string so the escaped
# parentheses are regex escapes rather than (invalid) string escapes.
# The groups are greedy: extra ';' separators end up in the first field, and
# a span containing several "(...)" groups yields a single match from the
# first "(" to the last ")" -- triples are expected to be separated by "|".
_TRIPLE_PATTERN = re.compile(r"\((.*);(.*);(.*)\)")


def parse_triples(text):
    """Parse ``(head; relation; tail)`` triples out of a generated string.

    The text is split on ``|`` and every span is searched for triples of the
    form ``(a; b; c)``. Each of the three fields is whitespace-stripped.

    Args:
        text: Decoded model output or reference label.

    Returns:
        A list of unique ``(head, relation, tail)`` string tuples in
        first-seen order, or ``None`` if no triple was found.
    """
    triples = []
    for span in text.split("|"):
        for head, relation, tail in _TRIPLE_PATTERN.findall(span):
            triples.append((head.strip(), relation.strip(), tail.strip()))
    if not triples:
        return None
    # dict.fromkeys de-duplicates while keeping a deterministic order
    # (list(set(...)) would return an arbitrary order).
    return list(dict.fromkeys(triples))

def evaluate(tokenizer, path="output/docred/flan-t5-small/predictions.json"):
    """Compute micro-averaged triple precision, recall and F1 over saved predictions.

    Reads a JSON list of records, each holding a ``"pred"`` and a ``"label"``
    string. Each string has the tokenizer's EOS/PAD token text removed and is
    stripped. The literal ``"NA"`` means "no triples". Otherwise the string is
    lowercased and parsed with ``parse_triples``. A triple counts as a true
    positive only on an exact match with a gold triple of the same record.

    Args:
        tokenizer: Object exposing ``eos_token`` and ``pad_token`` string
            attributes. Either may be ``None`` or empty, in which case it is
            skipped.
        path: Path to the predictions JSON file. The default keeps the
            original hard-coded location.

    Returns:
        A dict with ``"precision"``, ``"recall"`` and ``"f1"``. The dict is
        also printed. A metric whose denominator is zero is reported as 0.0
        instead of raising ``ZeroDivisionError``.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # Tokenizers may leave eos/pad undefined (None); str.replace(None, ...)
    # would raise TypeError, so keep only real token strings.
    special_tokens = [t for t in (tokenizer.eos_token, tokenizer.pad_token) if t]

    def clean_str(x_str):
        """Remove special-token text and surrounding whitespace."""
        for token in special_tokens:
            x_str = x_str.replace(token, "")
        return x_str.strip()

    def extract(text):
        """Return the set of parsed triples in *text* (empty for "NA")."""
        text = clean_str(text)
        # "NA" is checked before lowercasing, matching the label format.
        if text == "NA":
            return set()
        return set(parse_triples(text.lower()) or ())

    tp = n_gold = n_pred = 0
    for item in data:
        gold = extract(item["label"])
        pred = extract(item["pred"])
        tp += len(gold & pred)
        n_gold += len(gold)
        n_pred += len(pred)

    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    result = {
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
    print(result)
    return result


if __name__ == "__main__":
    # Script entry point: load the tokenizer and run the evaluation in one go.
    # evaluate() reads only the tokenizer's eos/pad token strings.
    evaluate(AutoTokenizer.from_pretrained("google/flan-t5-xxl"))
