import logging
import os
from federatedscope.register import register_metric
from federatedscope.nlp.metric.rouge.utils import test_rouge

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def load_cnndm_metrics(ctx, **kwargs):
    """Collect the three ROUGE F-score entries for the current context.

    ``test_rouge`` is invoked with a ``'temp'`` sub-directory of
    ``ctx.cfg.outdir`` together with ``ctx.pred_path`` and ``ctx.tgt_path``.
    From the mapping it returns, only the entries keyed
    ``'rouge_1_f_score'``, ``'rouge_2_f_score'`` and ``'rouge_l_f_score'``
    are kept; everything else is dropped.

    Args:
        ctx: Object exposing ``cfg.outdir``, ``pred_path`` and ``tgt_path``.
        **kwargs: Accepted but not used by this function.

    Returns:
        dict: The subset of ``test_rouge``'s result restricted to the three
        F-score keys above (keys absent from the result are simply omitted).
    """
    wanted_keys = {'rouge_1_f_score', 'rouge_2_f_score', 'rouge_l_f_score'}

    scratch_dir = os.path.join(ctx.cfg.outdir, 'temp')
    all_scores = test_rouge(scratch_dir, ctx.pred_path, ctx.tgt_path)

    selected = {}
    for name, score in all_scores.items():
        if name not in wanted_keys:
            continue
        selected[name] = score
    return selected


def call_cnndm_metric(types):
    """Return the 'cnndm' metric entry when it is requested in ``types``.

    Args:
        types: Container (or string) tested with ``'cnndm' in types``.

    Returns:
        tuple | None: ``('cnndm', load_cnndm_metrics, True)`` when
        ``'cnndm'`` is found in ``types``; the trailing ``True`` is the
        "larger is better" flag. ``None`` when ``'cnndm'`` is not present.
    """
    if 'cnndm' not in types:
        return None

    return 'cnndm', load_cnndm_metrics, True


# Import-time side effect: hand the builder to the metric registry under the
# name 'cnndm'. NOTE(review): presumably the registry later calls
# call_cnndm_metric with the configured metric types -- confirm against
# federatedscope.register.
register_metric('cnndm', call_cnndm_metric)
