import numpy as np

from rouge_score import rouge_scorer
import nltk
# Fetch the NLTK Punkt sentence-tokenizer data at import time; compute_metrics
# in this file relies on it through nltk.sent_tokenize.
nltk.download('punkt')
# NOTE(review): 'punkt_tab' is presumably the resource name newer NLTK
# releases look up instead of 'punkt' — confirm against the pinned version.
nltk.download('punkt_tab')



def calc_rouge(preds, refs):
    """Compute per-example ROUGE F1 scores for paired predictions and references.

    Args:
        preds: Iterable of predicted summaries (strings).
        refs: Indexable collection of reference summaries; ``refs[i]`` is
            the reference scored against the i-th prediction.

    Returns:
        A tuple ``(rouge1, rouge2, rougeL)`` of three lists, each holding one
        F-measure per prediction in input order. The third list carries the
        ``rougeLsum`` scores.

    Raises:
        IndexError: If ``refs`` is a list shorter than ``preds``.
    """
    scorer = rouge_scorer.RougeScorer(
        ['rouge1', 'rouge2', 'rougeLsum'],
        use_stemmer=True,
        split_summaries=True,
    )
    # RougeScorer.score expects (target, prediction), i.e. the reference
    # first. The order does not matter for the rouge1/rouge2 F-measure, but
    # the rougeLsum computation walks the target's sentences and is not
    # symmetric, so the reference must be passed as the target.
    # Indexing refs (rather than zip) keeps a too-short refs an explicit
    # IndexError instead of silently dropping predictions.
    scores = [scorer.score(refs[i], pred) for i, pred in enumerate(preds)]
    rouge1 = [s['rouge1'].fmeasure for s in scores]
    rouge2 = [s['rouge2'].fmeasure for s in scores]
    rougeL = [s['rougeLsum'].fmeasure for s in scores]
    return rouge1, rouge2, rougeL


def compute_metrics(preds, goldens):
    """Score predictions against golden texts with ROUGE.

    Each text is stripped, split into sentences with ``nltk.sent_tokenize``
    and re-joined with one sentence per line before being handed to
    ``calc_rouge``.

    Args:
        preds: Iterable of predicted texts (strings).
        goldens: Iterable of golden/reference texts (strings).

    Returns:
        The three per-example score lists produced by ``calc_rouge``:
        ``(rouge1, rouge2, rougeL)``.
    """
    def _one_sentence_per_line(text):
        # Newline-separated sentences are the input layout used for ROUGE here.
        sentences = nltk.sent_tokenize(text.strip())
        return "\n".join(sentences)

    split_preds = list(map(_one_sentence_per_line, preds))
    split_goldens = list(map(_one_sentence_per_line, goldens))

    r1, r2, rl = calc_rouge(split_preds, split_goldens)
    return r1, r2, rl
