import torch
import evaluate
from tqdm.auto import tqdm
import numpy as np

from collections import defaultdict
import argparse

# Metrics that are only computed when both strings are non-empty; otherwise
# each of them is recorded as 0 for that example.
_OVERLAP_METRICS = ('rouge1', 'chrf', 'chrfpp', 'bleu', 'sbleu', 'meteor')

# Report layout: one tuple per printed line, each entry being
# (key in the outputs dict, label shown to the user, direction/best value).
# "l - 0" = lower is better (best 0); "h - 1" = higher is better (best 1).
# chrF / chrF++ / sacreBLEU are divided by 100 when scored, so they are
# reported on the same 0-1 scale as the other "h - 1" metrics.
_REPORT_ROWS = (
    (('cer', 'cer', 'l - 0'), ('rouge1', 'rouge1', 'h - 1'),
     ('sbleu', 'sbleu', 'h - 1'), ('chrf', 'chrf', 'h - 1')),
    (('wer', 'wer', 'l - 0'), ('meteor', 'meteor', 'h - 1'),
     ('bleu', 'bleu', 'h - 1'), ('chrfpp', 'chrf++', 'h - 1')),
)


def _load_metrics():
    """Load every `evaluate` metric used by the report, keyed by a short name."""
    return {
        'wer': evaluate.load('wer'),
        'cer': evaluate.load('cer'),
        'bleu': evaluate.load('bleu'),
        'meteor': evaluate.load('meteor'),
        'sacrebleu': evaluate.load("sacrebleu"),
        'rouge': evaluate.load("rouge"),
        'chrf': evaluate.load("chrf"),
    }


def _lowercase(text):
    """Return *text* lowercased when it is a string, otherwise unchanged."""
    return text.lower() if isinstance(text, str) else text


def _score_pair(metrics, prediction, reference):
    """Score one prediction/reference pair.

    Args:
        metrics: Mapping produced by `_load_metrics`.
        prediction: Predicted text.
        reference: Ground-truth text.

    Returns:
        Dict with the keys 'cer', 'wer' and every name in `_OVERLAP_METRICS`.
        The overlap metrics are 0 when either text is empty/falsy.
    """
    batch = {'predictions': [prediction], 'references': [reference]}

    # CER/WER are always computed, exactly as before.
    # NOTE(review): depending on the installed `evaluate`/`jiwer` version these
    # may raise for an empty reference -- confirm against the data in use.
    scores = {
        'cer': metrics['cer'].compute(**batch),
        'wer': metrics['wer'].compute(**batch),
    }

    if not (prediction and reference):
        scores.update(dict.fromkeys(_OVERLAP_METRICS, 0))
        return scores

    scores['rouge1'] = metrics['rouge'].compute(**batch)['rouge1']
    scores['chrf'] = metrics['chrf'].compute(**batch)['score'] / 100
    # word_order=2 turns chrF into chrF++.
    scores['chrfpp'] = metrics['chrf'].compute(**batch, word_order=2)['score'] / 100
    scores['bleu'] = metrics['bleu'].compute(**batch)['bleu']
    scores['sbleu'] = metrics['sacrebleu'].compute(**batch, tokenize="char")['score'] / 100
    scores['meteor'] = metrics['meteor'].compute(**batch)['meteor']
    return scores


def _mean(values):
    """Return the mean of *values*, or nan (without a NumPy warning) if empty."""
    return float(np.mean(values)) if values else float('nan')


def _print_report(outputs):
    """Print the mean of every collected metric, cased first, then lowercased."""
    for suffix in ('', '_lower'):
        for row in _REPORT_ROWS:
            print(", ".join(
                f"{label}{suffix} ({hint}) = {_mean(outputs[key + suffix]):.4f}"
                for key, label, hint in row
            ))


def main():
    """Score the predictions stored in ``--file`` and print the mean metrics.

    The file is read with `torch.load` and iterated; each item is indexed with
    the keys 'latex_pred', 'latex_true' and 'pron'. Every example is scored
    twice: on the raw strings and on their lowercased versions (reported with
    a ``_lower`` suffix). Previously the ``_lower`` values were printed but
    never computed, so they always came out as nan.
    """
    parser = argparse.ArgumentParser()
    # Required: without it torch.load(None) fails with an unhelpful error.
    parser.add_argument('--file', type=str, required=True, help=".pth  with predicts")
    args = parser.parse_args()

    metrics = _load_metrics()
    outputs = defaultdict(list)

    # NOTE: torch.load may unpickle arbitrary objects -- only use trusted files.
    metrics_preds = torch.load(args.file)
    for example in tqdm(metrics_preds):
        generated_texts, latex, pron = example['latex_pred'], example['latex_true'], example['pron']

        if not (generated_texts and latex):
            # Surface the degenerate example; its overlap metrics count as 0.
            print(f"{pron=}, {latex=}, {generated_texts=}")

        variants = (
            ('', generated_texts, latex),
            ('_lower', _lowercase(generated_texts), _lowercase(latex)),
        )
        for suffix, prediction, reference in variants:
            for name, value in _score_pair(metrics, prediction, reference).items():
                outputs[name + suffix].append(value)

    _print_report(outputs)


if __name__ == "__main__":
    main()