# Copied from https://github.com/huggingface/datasets/blob/d3c7b9481d427ce41256edaf6773c47570f06f3b/metrics/rouge/rouge.py
# Added multiprocessing

import multiprocessing
import nltk
from rouge_score import rouge_scorer
from multiprocessing import Pool


def compute_rouge(predictions, references, rouge_types=None, use_stemmer=False):
    """Score each prediction against its reference with ROUGE, in parallel.

    Args:
        predictions: Iterable of predicted texts.
        references: Iterable of reference texts, paired positionally with
            ``predictions``. If the two differ in length, the surplus items
            of the longer one are ignored (``zip`` semantics).
        rouge_types: ROUGE variants to compute. Defaults to
            ``["rouge1", "rouge2", "rougeL", "rougeLsum"]``.
        use_stemmer: Forwarded to ``rouge_scorer.RougeScorer``.

    Returns:
        A dict mapping each score key to a list holding one
        ``RougeScorer.score`` value per (reference, prediction) pair, in
        input order. For empty input, every requested ROUGE type maps to an
        empty list.
    """
    if rouge_types is None:
        rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

    # Materialize the pairs so emptiness can be checked even when the
    # arguments are one-shot iterables.
    pairs = list(zip(references, predictions))
    if not pairs:
        # Nothing to score: avoid spawning worker processes and avoid
        # indexing into an empty result list below.
        return {rouge_type: [] for rouge_type in rouge_types}

    scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
    with Pool() as pool:
        # The scorer takes the reference (target) first, then the prediction.
        scores = pool.starmap(scorer.score, pairs)

    # Transpose the list of per-pair dicts into a dict of per-key lists.
    return {key: [score[key] for score in scores] for key in scores[0]}


# Copied from https://github.com/huggingface/transformers/blob/3977b58437b8ce1ea1da6e31747d888efec2419b/examples/pytorch/summarization/run_summarization.py#L520
def postprocess_text(text):
    """Return *text* with each sentence placed on its own line.

    Sentences are found with ``nltk.sent_tokenize`` and joined with a
    newline character.
    """
    sentences = nltk.sent_tokenize(text)
    # rougeLsum expects a newline after each sentence.
    return "\n".join(sentences)