import datasets

import sys

# Usage: python <script> <output_dir>
# <output_dir> must contain target.txt (references, one per line) and
# hypotheses.txt (system outputs, one per line, aligned with target.txt).
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <output_dir>")
path = sys.argv[1]
ref_path = f"{path}/target.txt"
hyp_path = f"{path}/hypotheses.txt"

# Read as UTF-8 explicitly: the platform default encoding is not guaranteed
# to be UTF-8, which would corrupt or reject non-ASCII text.
with open(ref_path, "r", encoding="utf-8") as f:
    ref = f.read().splitlines()
with open(hyp_path, "r", encoding="utf-8") as f:
    hyp = f.read().splitlines()

# Every metric below pairs line i of one file with line i of the other;
# fail early with a clear message instead of deep inside a metric library.
if len(ref) != len(hyp):
    sys.exit(
        f"line count mismatch: {ref_path} has {len(ref)} lines, "
        f"{hyp_path} has {len(hyp)} lines"
    )

# Whitespace-tokenised copies for NLTK's corpus_nist: one list of reference
# token lists per segment (exactly one reference each here), and one token
# list per hypothesis.
ref_split = [[r.split()] for r in ref]
hyp_split = [h.split() for h in hyp]

# Run label: the second "/"-separated component of the path
# (e.g. "outputs/my_model" -> "my_model"). Fall back to the whole path when
# it contains no "/" instead of raising IndexError.
path_parts = path.split("/")
name = path_parts[1] if len(path_parts) > 1 else path

# Metric name -> rounded score; insertion order is the printed order.
result = {}
def round_num(val, digit=4):
    """Return *val* rounded to *digit* decimal places (4 unless overridden).

    Used so that every reported score shares the same precision.
    """
    return round(val, ndigits=digit)
# Corpus-level scoring. Each score is stored in `result`; the report at the
# bottom prints them in insertion order.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the `evaluate` library — confirm the pinned version.
import sacrebleu

# SacreBLEU: hypothesis stream plus a list holding the single reference stream.
sacrebleu_result = sacrebleu.corpus_bleu(hyp, [ref])
result["SACREBLEU"] = round_num(sacrebleu_result.score)

# METEOR, reported on a 0-100 scale.
meteor = datasets.load_metric('meteor')
meteor.add_batch(predictions=hyp, references=ref)
meteor_score = meteor.compute()["meteor"]
result["METEOR"] = round_num(meteor_score * 100)

# WER, left as the raw value returned by the metric (not scaled by 100).
wer = datasets.load_metric("wer")
wer.add_batch(predictions=hyp, references=ref)
result["WER"] = round_num(wer.compute())

# ROUGE: keep the F-measure of the aggregated "mid" estimate per variant,
# scaled to 0-100, under upper-cased keys (ROUGE1, ROUGE2, ROUGEL).
rouge = datasets.load_metric("rouge")
rouge.add_batch(predictions=hyp, references=ref)
rouge_result = rouge.compute()
result.update(
    (variant.upper(), round_num(rouge_result[variant].mid.fmeasure * 100))
    for variant in ("rouge1", "rouge2", "rougeL")
)

# NIST on the whitespace-tokenised corpora.
from nltk.translate.nist_score import corpus_nist
NIST = corpus_nist(ref_split, hyp_split)
result["NIST"] = round_num(NIST)

# Report: header line, then one "<metric> - <score>" line per entry.
print("RESULT -", name)
for metric_name, score in result.items():
    print(metric_name, "-", score)