# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import functools
import json
import logging
import os
import re
import shutil
import operator
from collections import OrderedDict
from argparse import Namespace

import numpy as np
from fairseq import metrics, options, utils, distributed_utils, file_utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    data_utils,
    encoders,
    indexed_dataset,
    DynamicDatasetServer,
    DynamicLanguagePairDataset,
    ParallelReader,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.models import modular_transformer, ARCH_MODEL_REGISTRY


# BLEU statistics are accumulated for n-gram orders 1..EVAL_BLEU_ORDER
# (one "_bleu_counts_<i>" / "_bleu_totals_<i>" logging entry per order).
EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    skip_empty_lines=False,
):
    """Load a (possibly sharded) parallel split as a :class:`LanguagePairDataset`.

    Shards are named ``{split}``, ``{split}1``, ``{split}2``, ... and their files
    may use either the ``{shard}.{src}-{tgt}.*`` or the ``{shard}.{tgt}-{src}.*``
    naming. Only the first shard is read unless *combine* is true; several shards
    are concatenated, the first one being upsampled by *upsample_primary*.

    Args:
        data_path: directory containing the data files.
        split: split name (e.g. ``train``, ``valid``).
        src, tgt: source / target language codes.
        src_dict, tgt_dict: dictionaries used to load each side.
        combine: read every available shard instead of only the first one.
        dataset_impl: storage format forwarded to the indexed-dataset helpers.
        upsample_primary: sample ratio of the first shard when concatenating.
        left_pad_source, left_pad_target: padding side of each batch side.
        max_source_positions: with *truncate_source*, sources are cut so that
            they hold at most this many tokens including the final EOS.
        max_target_positions: accepted for API compatibility; not used here.
        prepend_bos: prefix both sides with the BOS symbol.
        load_alignments: also load ``{split}.align.{src}-{tgt}`` when present.
        truncate_source: enable source truncation (see *max_source_positions*).
        append_source_id: append ``[src]`` / ``[tgt]`` language tokens; the
            ``[tgt]`` token is also passed to the dataset as ``eos``.
        num_buckets, shuffle, pad_to_multiple: forwarded to the dataset.
        skip_empty_lines: run ``LanguagePairDataset.remove_empty_lines`` on
            every loaded shard.

    Raises:
        FileNotFoundError: if not even the first shard exists.
    """

    def find_prefix(shard):
        # File prefix "<data_path>/<shard>.<a>-<b>." of the shard, trying the
        # src-tgt naming before tgt-src; None when neither has a source file.
        for first, second in ((src, tgt), (tgt, src)):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(shard, first, second))
            if indexed_dataset.dataset_exists(prefix + src, impl=dataset_impl):
                return prefix
        return None

    src_shards = []
    tgt_shards = []
    for shard_id in itertools.count():
        shard = split if shard_id == 0 else split + str(shard_id)

        prefix = find_prefix(shard)
        if prefix is None:
            # a missing first shard is an error; a missing later one ends the scan
            if shard_id == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )
            break

        src_shard = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            # strip EOS, truncate to leave one slot, then re-append EOS
            src_shard = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_shard, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_shards.append(src_shard)

        tgt_shard = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_shard is not None:
            tgt_shards.append(tgt_shard)

        if skip_empty_lines:
            LanguagePairDataset.remove_empty_lines(src_shard, tgt_shard)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, shard, src, tgt, len(src_shard)
            )
        )

        if not combine:
            break

    # either every shard has a target side, or none of them does
    assert len(src_shards) == len(tgt_shards) or len(tgt_shards) == 0

    if len(src_shards) == 1:
        src_dataset = src_shards[0]
        tgt_dataset = tgt_shards[0] if tgt_shards else None
    else:
        ratios = [upsample_primary] + [1] * (len(src_shards) - 1)
        src_dataset = ConcatDataset(src_shards, ratios)
        tgt_dataset = ConcatDataset(tgt_shards, ratios) if tgt_shards else None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index("[{}]".format(src))
        )
        eos = tgt_dict.index("[{}]".format(tgt))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, eos)

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_sizes = None if tgt_dataset is None else tgt_dataset.sizes
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
        src_lang=src,
        tgt_lang=tgt,
    )


@register_task("translation")
class TranslationTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser.

        Registers, on *parser*: data/language options, BLEU-during-validation
        options, target-side encoder options (--target-args), DynamicDataset
        options and adapter options.
        """
        # fmt: off
        # NOTE: the backslash line continuations below are inside the string
        # literal, so the continuation lines' indentation is part of the help string.
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner; \
                            however, valid and test data are always in the first directory to \
                            avoid the need for repeating them in all directories')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        # the two padding flags are strings here; setup_task converts them with utils.eval_bool
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the target sequence')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')

        # options for reporting BLEU during validation
        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        # nargs='?' with const=0: absent -> None, flag without a value -> 0,
        # otherwise the given integer
        parser.add_argument('--eval-bleu-print-samples', type=int, nargs='?', const=0,
                            help='print sample generations during validation')
        parser.add_argument('--eval-bleu-save-samples', action='store_true',
                            help='save the decoding outputs to files')

        # JSON object whose keys override args when __init__ builds the target-side
        # tokenizer / BPE / tagger
        parser.add_argument('--target-args', default='{"tagger": "identity"}',
                            help='encoder options for the target sides '
                                 '(e.g., to use another BPE model or tokenizer); '
                                 'disables tagger by default')

        # DynamicDataset options
        # NOTE(review): --dynamic-dataset, --dynamic-dataset-port, --dynamic-dataset-verbose
        # and --device-id are read by this task but not declared here; presumably
        # they are registered elsewhere.
        parser.add_argument('--corpus-config',
                            help='path to YAML config file with the list of train and valid corpora')
        parser.add_argument('--skip-empty-lines', action='store_true',
                            help='ignore empty lines in the training and validation data (requires --dataset-impl raw)')
        parser.add_argument('--lines-per-epoch', type=int,
                            help='Manually define how many line pairs constitute a training epoch '
                                 '(default: total line count in training corpora)')
        parser.add_argument('--scale-epoch-length', type=float,
                            help='scale the number of lines in a training epoch by this amount '
                                 '(e.g., 0.1 for 10 times shorter epochs)')
        parser.add_argument('--lang-temperature', type=float, default=1,
                            help='sample line pairs by language depending on this temperature parameter, '
                                 'values higher than 1 mean closer to a uniform distribution over all languages')
        parser.add_argument('--dynamic-dataset-block-size', type=int, default=256,
                            help='buffer size in random parallel corpus readers, smaller values mean increased '
                                 'randomness, but decreased speed. A value of zero means sequential reading')
        parser.add_argument('--batch-by-lang-pair', action='store_true',
                            help='batches will only contain sentence pairs from one language pair')
        parser.add_argument('--batch-by-lang-pair-and-corpus', action='store_true',
                            help='batches will only contain sentence pairs from one (language pair, corpus) tuple')
        parser.add_argument('--batch-by-target-lang', action='store_true',
                            help='batches will only contain sentence pairs from one target language')

        # adapter options (not read by the code of this task that is visible here)
        parser.add_argument('--adapter-uids', nargs='+',
                            help='names of the adapters when using AdapterTransformerModel, only the first one'
                            ' is used, but several adapter names can be specified when training from an existing checkpoint,'
                            ' in order to keep existing but unused adapters in the future checkpoints.')
        parser.add_argument('--decoder-adapter-uids', nargs='+',
                            help='specify this parameter to define different adapter ids for the decoder than for the encoder.')
        parser.add_argument('--corpus-adapter-uids', nargs='+',
                            help='names of adapters for training decoder corpora-specific adapters')
        parser.add_argument('--enc-corpus-adapter-uids', nargs='+',
                            help='names of adapters for training encoder corpora-specific adapters')
        # fmt: on

    def __init__(self, args, src_dicts, tgt_dicts, train_corpora=None, valid_corpora=None):
        """Build the task from already-loaded dictionaries and corpora.

        Args:
            args (argparse.Namespace): parsed command-line arguments. When
                ``args.path`` is unset it is filled in with
                ``<first data dir>/checkpoint_best.pt``.
            src_dicts (dict): source language -> Dictionary. The first entry
                becomes ``self.src_dict``.
            tgt_dicts (dict): target language -> Dictionary. The first entry
                becomes ``self.tgt_dict``.
            train_corpora (dict, optional): corpus id -> ParallelCorpus, as
                returned by ``get_corpora`` (empty dict by default).
            valid_corpora (optional): as returned by ``get_corpora``. It is a
                corpus id -> ParallelCorpus mapping with --dynamic-dataset and a
                list of split names otherwise (empty dict by default).
        """
        super().__init__(args)
        # None defaults avoid sharing one mutable dict between task instances
        if train_corpora is None:
            train_corpora = {}
        if valid_corpora is None:
            valid_corpora = {}

        if not getattr(args, 'path', None):
            # automatically infer --path value in interactive.py
            paths = utils.split_paths(args.data)
            args.path = os.path.join(paths[0], 'checkpoint_best.pt')
        self.src_dicts = src_dicts
        self.tgt_dicts = tgt_dicts
        # for inference or monolingual translation
        self.src_dict = next(iter(src_dicts.values()))
        self.tgt_dict = next(iter(tgt_dicts.values()))
        self.dicts = {**src_dicts, **tgt_dicts}
        self.langs = list(self.dicts)

        # --target-args (JSON) overrides args when building target-side encoders
        target_args = json.loads(getattr(args, "target_args", None) or "{}")
        args_ = Namespace(**{**vars(args), **target_args})
        # same options, but with the tagger forced to 'denoising_tagger'
        mono_args_ = Namespace(**{**vars(args), **target_args, 'tagger': 'denoising_tagger'})
        # fall back to the source-side components (presumably set up by the base
        # task class) when --target-args does not override them
        self.tgt_tokenizer = encoders.build_tokenizer(args_) if 'tokenizer' in target_args else self.tokenizer
        self.tgt_bpe = encoders.build_bpe(args_) if 'bpe' in target_args else self.bpe
        self.tgt_tagger = encoders.build_tagger(args_) if 'tagger' in target_args else self.tagger
        self.mono_tagger = encoders.build_tagger(mono_args_)

        # back-compatibility with --eval-bleu-detok and --eval-bleu-remove-bpe, but prefer using --tokenizer, --bpe, and --target-args
        eval_bleu_detok = getattr(args, "eval_bleu_detok", None)
        if eval_bleu_detok and eval_bleu_detok != "space":
            eval_bleu_detok_args = json.loads(getattr(args, "eval_bleu_detok_args", None) or "{}")
            tokenizer = encoders.build_tokenizer(
                Namespace(
                    tokenizer=getattr(args, "eval_bleu_detok", None), **eval_bleu_detok_args
                )
            )
            self.detokenize = tokenizer.decode
        else:
            self.detokenize = self.tgt_tokenizer.decode

        remove_bpe = getattr(args, "eval_bleu_remove_bpe", None)
        if remove_bpe:
            self.remove_bpe = functools.partial(data_utils.post_process, symbol=args.eval_bleu_remove_bpe)
        else:
            self.remove_bpe = self.tgt_bpe.decode

        self.valid_corpora = valid_corpora  # dict of (corpus id: ParallelCorpus)
        self.train_corpora = train_corpora  # dict of (corpus id: ParallelCorpus)

        if args.dynamic_dataset:
            corpora = list(self.train_corpora.values())

            def unique(values):
                # Order-preserving de-duplication. Unlike list(set(...)), the result
                # does not depend on the per-process string hash seed, so every
                # distributed worker gets the same ordering.
                return list(dict.fromkeys(values))

            self.src_langs = unique(corpus.src_lang for corpus in corpora)
            self.tgt_langs = unique(corpus.tgt_lang for corpus in corpora)
            self.corpus_tags = unique(corpus.corpus_tag for corpus in corpora)
            self.lang_pairs = unique((corpus.src_lang, corpus.tgt_lang) for corpus in corpora)
        else:
            self.src_langs = [args.source_lang]
            self.tgt_langs = [args.target_lang]
            self.corpus_tags = None
            self.lang_pairs = [(args.source_lang, args.target_lang)]

    @classmethod
    def get_corpora(cls, args):
        """Return the ``(train_corpora, valid_corpora)`` described by *args*.

        With --dynamic-dataset, both are OrderedDicts mapping ``corpus.corpus_id``
        to corpora. They are read from --corpus-config when it is given (searched
        in the first data directory). Otherwise there is one training corpus for
        --train-subset and one validation corpus per comma-separated
        --valid-subset entry. Each is built from ``{split}.{lang}``,
        ``{split}.{src}-{tgt}.{lang}`` or ``{split}.{tgt}-{src}.{lang}`` files in
        the first data directory, as located by ``file_utils.find_file``.

        Without --dynamic-dataset, the training corpora are an empty dict and the
        validation "corpora" are simply the list of --valid-subset split names.
        """
        # only the first data directory is used here (the unused full path list
        # that used to be computed has been dropped)
        data_path = utils.split_paths(args.data)[0]
        default_lang_pairs = [(args.source_lang, args.target_lang)]

        logger.info(f"[{default_lang_pairs[0]}] lang")
        logger.info(f"{args.valid_subset} lang")
        if args.dynamic_dataset:
            if args.corpus_config is None:
                def get_corpus(split):
                    src, tgt = args.source_lang, args.target_lang
                    prefixes = [split, f'{split}.{src}-{tgt}', f'{split}.{tgt}-{src}']
                    src_path = file_utils.find_file(*[f'{prefix}.{src}' for prefix in prefixes], dirs=[data_path])
                    tgt_path = file_utils.find_file(*[f'{prefix}.{tgt}' for prefix in prefixes], dirs=[data_path])
                    return utils.ParallelCorpus(src_path, tgt_path)

                train_corpora = [get_corpus(args.train_subset)]
                valid_corpora = [get_corpus(split) for split in args.valid_subset.split(',')]
            else:
                train_corpora, valid_corpora = utils.parse_parallel_corpus_config(
                    file_utils.find_file(args.corpus_config, dirs=[data_path]),
                    default_lang_pairs=default_lang_pairs
                )

            train_corpora = OrderedDict((corpus.corpus_id, corpus) for corpus in train_corpora)
            valid_corpora = OrderedDict((corpus.corpus_id, corpus) for corpus in valid_corpora)
        else:
            train_corpora = {}
            valid_corpora = args.valid_subset.split(',')

        return train_corpora, valid_corpora

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Normalizes a few arguments in place: the padding flags become booleans,
        and the source/target languages default to ``'src'`` / ``'tgt'``. It then
        collects the corpora and loads the dictionaries. ModularTransformerModel
        architectures get one dictionary per source and per target language.
        Every other architecture gets a single source and a single target
        dictionary, which must agree on the pad/eos/unk indices.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            an instance of this task class.
        """

        assert not args.corpus_config or args.dynamic_dataset, "--corpus-config requires --dynamic-dataset"

        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        assert len(paths) == 1 or not args.dynamic_dataset, "--dynamic-dataset does not support multiple data directories"

        args.source_lang = args.source_lang or 'src'
        args.target_lang = args.target_lang or 'tgt'

        train_corpora, valid_corpora = cls.get_corpora(args)

        def find_dict(lang, prefix):
            # looks for dictionaries in data_dir in this specific order
            return file_utils.find_file(
                f"dict.{prefix}.{lang}.txt",
                f"dict.{prefix}.txt",
                f"dict.{lang}.txt",
                "dict.txt",
                dirs=paths[:1]
            )

        # load dictionaries
        if hasattr(args, 'arch') and issubclass(ARCH_MODEL_REGISTRY[args.arch], modular_transformer.ModularTransformerModel):
            # ModularTransformerModel can have one source or target dictionary per language
            assert args.dynamic_dataset, "ModularTransformerModel requires --dynamic-dataset"

            # dict.fromkeys de-duplicates while keeping first-seen corpus order. A
            # set's iteration order depends on the per-process hash seed, which
            # would make the first ("default") dictionary differ between workers.
            src_langs = dict.fromkeys(corpus.src_lang for corpus in train_corpora.values())
            tgt_langs = dict.fromkeys(corpus.tgt_lang for corpus in train_corpora.values())

            src_dicts = {}
            for lang in src_langs:
                src_dict = cls.load_dictionary(find_dict(lang, 'src'))
                src_dicts[lang] = src_dict
                logger.info(f"[{lang}] source dictionary: {len(src_dict)} types")

            tgt_dicts = {}
            for lang in tgt_langs:
                tgt_dict = cls.load_dictionary(find_dict(lang, 'tgt'))
                tgt_dicts[lang] = tgt_dict
                logger.info(f"[{lang}] target dictionary: {len(tgt_dict)} types")
        else:
            src_dict = cls.load_dictionary(find_dict(args.source_lang, 'src'))
            src_dicts = {args.source_lang: src_dict}
            logger.info(f"[{args.source_lang}] dictionary: {len(src_dict)} types")

            tgt_dict = cls.load_dictionary(find_dict(args.target_lang, 'tgt'))
            tgt_dicts = {args.target_lang: tgt_dict}
            logger.info(f"[{args.target_lang}] dictionary: {len(tgt_dict)} types")

            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()

        return cls(args, src_dicts, tgt_dicts, train_corpora, valid_corpora)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split into ``self.datasets[split]``.

        With --dynamic-dataset, the training split is a DynamicLanguagePairDataset
        fed through ``--dynamic-dataset-port``. Any other split is read as a raw
        validation corpus from ``self.valid_corpora[split]``, encoded with the
        task's line encoders and with its raw lines kept. Without
        --dynamic-dataset, the split is loaded from the data directories with
        :func:`load_langpair_dataset`; training cycles through the directories
        by epoch, and the other splits always come from the first directory.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch, used to pick the training data directory
            combine (bool): concatenate all shards of the split
        """
        if self.args.dynamic_dataset:
            if split == self.args.train_subset:
                dataset = DynamicLanguagePairDataset(
                    self.args.dynamic_dataset_port,
                    self.args.device_id,
                    pad_idx=self.src_dict.pad_index,
                    eos_idx=self.src_dict.eos_index,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    input_feeding=True
                )
            else:
                # set in eval mode to disable BPE dropout
                self.eval()
                corpus = self.valid_corpora[split]
                src_dict = self.src_dicts.get(corpus.src_lang, self.src_dict)
                tgt_dict = self.tgt_dicts.get(corpus.tgt_lang, self.tgt_dict)
                src_dataset = data_utils.load_indexed_dataset(
                    corpus.src_path, src_dict, dataset_impl='raw',
                    encode_fn=functools.partial(self.line_encoder, meta=corpus.meta),
                    keep_raw_lines=True
                )
                tgt_dataset = data_utils.load_indexed_dataset(
                    corpus.tgt_path, tgt_dict, dataset_impl='raw',
                    encode_fn=functools.partial(self.tgt_line_encoder, meta=corpus.meta),
                    keep_raw_lines=True
                )
                if self.args.skip_empty_lines:
                    LanguagePairDataset.remove_empty_lines(src_dataset, tgt_dataset)

                logger.info(
                    "{} {} {} examples".format(
                        os.path.dirname(corpus.src_path), corpus.corpus_id, len(src_dataset)
                    )
                )

                if self.args.dynamic_dataset_verbose:
                    size = src_dataset.size
                    # Log about 5 evenly spaced samples. The step is clamped to 1 so
                    # that corpora with fewer than 5 lines (including empty ones) do
                    # not call range() with a zero step, which raises ValueError.
                    step = max(1, size // 5)
                    for i in range(0, size, step):
                        logger.info(f'Sample | {corpus} | source {i + 1} | {src_dataset.lines[i]}')
                        logger.info(f'Sample | {corpus} | target {i + 1} | {tgt_dataset.lines[i]}')

                dataset = LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    src_dict,
                    tgt_dataset,
                    tgt_dataset.sizes,
                    tgt_dict,
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                    num_buckets=self.args.num_batch_buckets,
                    shuffle=False,
                    pad_to_multiple=self.args.required_seq_len_multiple,
                    src_lang=corpus.src_lang,
                    tgt_lang=corpus.tgt_lang,
                    corpus_tag=corpus.meta['corpus_tag'],
                )

            self.datasets[split] = dataset
            return

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        if split != getattr(self.args, "train_subset", None):
            # if not training data set, use the first shard for valid and test
            paths = paths[:1]
        # round-robin over data directories across epochs
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            pad_to_multiple=self.args.required_seq_len_multiple,
            skip_empty_lines=self.args.skip_empty_lines,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None, prefix_tokens=None):
        """Wrap already-encoded source sentences into a dataset for generation.

        Args:
            src_tokens: encoded source sentences.
            src_lengths: their lengths.
            constraints: optional decoding constraints, forwarded as-is.
            prefix_tokens: optional target-side prefixes, used as the dataset's
                target side.
        """
        source_dict = self.source_dictionary
        target_dict = self.target_dictionary
        return LanguagePairDataset(
            src_tokens,
            src_lengths,
            source_dict,
            tgt=prefix_tokens,
            tgt_dict=target_dict,
            constraints=constraints,
        )

    def build_model(self, args):
        """Build the model and, with --eval-bleu, one sequence generator per target dictionary.

        The generators are stored in ``self.sequence_generators``, keyed like
        ``self.tgt_dicts`` (target language). The generation options come from
        the JSON in --eval-bleu-args.
        """
        model = super().build_model(args)
        if getattr(args, "eval_bleu", False):
            gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")

            # build_generator() reads self.tgt_dict, so point it temporarily at each
            # target dictionary. The try/finally restores the default dictionary
            # even if building one of the generators raises.
            target_dict = self.tgt_dict
            self.sequence_generators = {}
            try:
                for lang, dict_ in self.tgt_dicts.items():
                    self.tgt_dict = dict_
                    sequence_generator = self.build_generator(
                        [model], Namespace(**gen_args)
                    )
                    self.sequence_generators[lang] = sequence_generator
            finally:
                self.tgt_dict = target_dict

        return model

    def valid_step(self, sample, model, criterion):
        """Run one validation step and, with --eval-bleu, add BLEU statistics.

        The BLEU sufficient statistics are written as separate ``_bleu_*``
        entries of the logging output: system/reference lengths, plus per-order
        n-gram counts and totals. Keeping them separate lets them be summed
        across workers.
        """
        # _inference_with_bleu decodes with self.tgt_dict. When it is swapped for a
        # per-corpus dictionary below, it is restored in `finally`, so the swap
        # does not leak into later steps or into self.target_dictionary.
        saved_tgt_dict = self.tgt_dict
        try:
            if self.args.eval_bleu:
                if self.args.dynamic_dataset and len(self.tgt_dicts) > 1:   # FIXME: ugly
                    corpus = self.valid_corpora[self._valid_subset]
                    sequence_generator = self.sequence_generators[corpus.tgt_lang]
                    # change self.tgt_dict for self._inference_with_bleu
                    self.tgt_dict = self.tgt_dicts[corpus.tgt_lang]
                elif len(self.sequence_generators) == 1:
                    # A single target dictionary means a single generator. Its key is
                    # a corpus language in the ModularTransformer setup, which need
                    # not equal --target-lang, so do not look it up by name.
                    sequence_generator = next(iter(self.sequence_generators.values()))
                else:
                    sequence_generator = self.sequence_generators[self.args.target_lang]

            loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
            if self.args.eval_bleu:
                bleu = self._inference_with_bleu(sequence_generator, sample, model)
                logging_output["_bleu_sys_len"] = bleu.sys_len
                logging_output["_bleu_ref_len"] = bleu.ref_len
                # we split counts into separate entries so that they can be
                # summed efficiently across workers using fast-stat-sync
                assert len(bleu.counts) == EVAL_BLEU_ORDER
                for i in range(EVAL_BLEU_ORDER):
                    logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
                    logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
        finally:
            self.tgt_dict = saved_tgt_dict
        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs into metrics.

        In addition to the base-class reduction, with --eval-bleu this sums the
        ``_bleu_*`` statistics produced by valid_step and registers a derived
        ``bleu`` metric computed from them with sacrebleu.
        """
        super().reduce_metrics(logging_outputs, criterion)
        if not self.args.eval_bleu:
            return

        def sum_logs(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        counts = [sum_logs("_bleu_counts_" + str(i)) for i in range(EVAL_BLEU_ORDER)]
        totals = [sum_logs("_bleu_totals_" + str(i)) for i in range(EVAL_BLEU_ORDER)]

        if max(totals) > 0:
            # log counts as numpy arrays -- log_scalar will sum them correctly
            metrics.log_scalar("_bleu_counts", np.array(counts))
            metrics.log_scalar("_bleu_totals", np.array(totals))
            metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
            metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

            def compute_bleu(meters):
                import inspect

                try:
                    # sacrebleu 2.x removed the top-level ``compute_bleu``;
                    # ``BLEU.compute_bleu`` is its replacement
                    from sacrebleu.metrics import BLEU
                    comp_bleu = BLEU.compute_bleu
                except ImportError:
                    # compatibility API for sacrebleu 1.x
                    import sacrebleu
                    comp_bleu = sacrebleu.compute_bleu

                # the smoothing keyword was renamed between sacrebleu versions
                fn_sig = inspect.getfullargspec(comp_bleu)[0]
                if "smooth_method" in fn_sig:
                    smooth = {"smooth_method": "exp"}
                else:
                    smooth = {"smooth": "exp"}
                bleu = comp_bleu(
                    correct=meters["_bleu_counts"].sum,
                    total=meters["_bleu_totals"].sum,
                    sys_len=int(meters["_bleu_sys_len"].sum),
                    ref_len=int(meters["_bleu_ref_len"].sum),
                    **smooth
                )
                return round(bleu.score, 2)

            metrics.log_derived("bleu", compute_bleu)

    @property
    def eval_samples_path(self):
        return os.path.join(self.args.save_dir, f"{self._valid_subset}.{self._valid_step}.out")

    def finalize_validation(self):
        """Called just after the end of validation.

        With ``--eval-bleu --eval-bleu-save-samples``, the master process merges
        the per-rank temporary files ``<eval_samples_path>.<rank>.tmp`` written by
        ``self._inference_with_bleu``. It orders their lines by sample id, writes
        them to ``eval_samples_path`` and deletes the temporary files. When
        --eval-bleu-print-samples is also set, it logs the first N hypotheses,
        or all of them when the flag is given without a value.
        """
        if not self.args.eval_bleu:
            return
        if not (self.args.eval_bleu_save_samples and distributed_utils.is_master(self.args)):
            return

        # Ultimately, it would be more flexible to store the decoding outputs in
        # "logging_outputs" and use "distributed_utils.all_gather_list" to aggregate
        # them. However, this requires more code changes and poses issues in terms
        # of GPU memory.
        lines_by_id = {}
        for rank in range(self.args.distributed_world_size):
            path = f"{self.eval_samples_path}.{rank}.tmp"
            with open(path) as f:
                for line in f:
                    id_, line = line.split('\t', maxsplit=1)
                    # if tmp file contains several candidates for one id, keep latest. This can happen when validation is interrupted
                    # and temporary files are not properly deleted.
                    lines_by_id[int(id_)] = line
            os.remove(path)

        # ids are unique dict keys, so sorting the items orders them by sample id
        lines = [line for _, line in sorted(lines_by_id.items())]
        with open(self.eval_samples_path, 'w') as f:
            f.writelines(lines)

        # --eval-bleu-print-samples is None when absent and 0 (argparse const) when
        # given without a value. 0 must still print, and means "all samples", so
        # test against None rather than truthiness.
        print_samples = self.args.eval_bleu_print_samples
        if print_samples is not None:
            # this gives us the possibility to log samples in the right order (first N samples), contrary
            # to what was done before in "self._inference_with_bleu"
            n = print_samples or None
            for line in lines[:n]:
                line = line.strip('\r\n')
                logger.info(f'example hypothesis: {line}')

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict

    def _inference_with_bleu(self, generator, sample, model):
        """Generate hypotheses for ``sample`` with ``generator`` and score them against the references with sacreBLEU.

        Hypotheses (best beam only) are converted to strings with ``self.tgt_dict.string`` and post-processed with
        ``self.line_decoder``. References are taken from the raw lines of the validation target dataset when it keeps
        them (``keep_raw_lines``); otherwise they are decoded from ``sample["target"]``, with unknown tokens rendered
        differently from the hypotheses so they never count as matches.

        When ``args.eval_bleu_save_samples`` is set, the non-post-processed hypotheses are appended as
        ``"<sample id>\\t<hypothesis>"`` lines to a per-rank temporary file ``<eval_samples_path>.<rank>.tmp``.

        Returns:
            the result of ``sacrebleu.corpus_bleu`` on this batch, with ``tokenize="none"`` when
            ``args.eval_tokenized_bleu`` is set.
        """
        import sacrebleu

        def decode(toks, escape_unk=False):
            raw = self.tgt_dict.string(
                toks.int().cpu(),
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string="▁UNKREF" if escape_unk else "▁UNK",
            )
            detok = self.line_decoder(raw)
            return raw, detok

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)

        # The validation target dataset is the same for every sentence of the batch: look it up once.
        tgt_dataset = self.dataset(self._valid_subset).tgt
        use_raw_refs = getattr(tgt_dataset, 'keep_raw_lines', False)

        hyps, refs = [], []
        raw_hyps = []
        for i, hypos in enumerate(gen_out):
            raw, detok = decode(hypos[0]["tokens"])
            raw_hyps.append(raw)
            hyps.append(detok)

            if use_raw_refs:
                # Use the untokenized references when possible, to compute true BLEU scores.
                # FIXME: might conflict with some tokenization and evaluation settings: for instance,
                # if one wants to compute tokenized BLEU on references that were initially not tokenized.
                ref = tgt_dataset.raw_lines[sample["id"][i]]
            else:
                # Detokenize the tokenized references and compute BLEU against this
                # This might not always reflect true BLEU scores, as tokenization is not always reversible,
                # especially in the presence of unknown symbols.
                _, ref = decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                )

            refs.append(ref)

        if self.args.eval_bleu_save_samples:
            # save the decoding outputs to temporary files that will be aggregated by "self.finalize_validation"
            path = f"{self.eval_samples_path}.{self.args.distributed_rank}.tmp"
            save_dir = os.path.dirname(self.eval_samples_path)
            # dirname() is '' when eval_samples_path is a bare filename (current directory), and
            # os.makedirs('') raises FileNotFoundError, so only create a directory when there is one.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            with open(path, 'a') as f:
                for id_, hyp in zip(sample['id'].tolist(), raw_hyps):
                    print(id_, hyp, sep='\t', file=f)

        if self.args.eval_tokenized_bleu:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])

    def reader(self):
        """Build the ParallelReader that DynamicDatasetServer uses to read line pairs from the training corpora."""
        corpora = [*self.train_corpora.values()]
        return ParallelReader(self.args, corpora)

    def line_encoder(self, line, meta=None):
        """Preprocess a source line by delegating to the parent class's ``line_encoder``.

        ``meta`` (optional) is copied, never modified. Before delegating, ``src_lang`` / ``tgt_lang`` default to
        ``args.source_lang`` / ``args.target_lang`` when missing or falsy. ``lang`` is set to the source language and
        ``corpus_tag`` defaults to None. When source and target languages are equal, ``self.src_dict`` is passed as
        ``meta['dictionary']``.

        Returns '' for blank lines when ``args.skip_empty_lines`` is set.
        """
        if self.args.skip_empty_lines and not line.strip():
            # Modify the preprocessing functions to avoid adding tokens to empty lines (e.g., tags)
            # It is important that they remain empty for the line pair filtering step
            return ''
        # TODO: retrieve corpus from corpus id and use corpus.src_meta
        # Copy so that the caller's dict (e.g. the per-line meta shared with tgt_line_encoder) is left untouched.
        meta = {} if meta is None else dict(meta)
        meta['src_lang'] = meta.get('src_lang') or self.args.source_lang
        meta['lang'] = meta['src_lang']
        meta['corpus_tag'] = meta.get('corpus_tag')
        meta['tgt_lang'] = meta.get('tgt_lang') or self.args.target_lang
        if meta['src_lang'] == meta['tgt_lang']:
            meta['dictionary'] = self.src_dict
        return super().line_encoder(line, meta=meta)

    def tgt_line_encoder(self, line, meta=None):
        """Preprocess a target line: ``tgt_tokenizer``, then ``tgt_bpe``, then ``tgt_tagger`` encoding.

        ``meta`` (optional) is copied, never modified. ``src_lang`` / ``tgt_lang`` default to ``args.source_lang`` /
        ``args.target_lang`` when missing or falsy. ``lang`` is set to the target language and ``corpus_tag`` defaults
        to None. The resulting dict is passed to each encoding stage.

        Returns '' for blank lines when ``args.skip_empty_lines`` is set, so that they stay empty for the line pair
        filtering step.
        """
        if self.args.skip_empty_lines and not line.strip():
            return ''
        meta = {} if meta is None else dict(meta)
        meta['src_lang'] = meta.get('src_lang') or self.args.source_lang
        meta['tgt_lang'] = meta.get('tgt_lang') or self.args.target_lang
        meta['lang'] = meta['tgt_lang']
        meta['corpus_tag'] = meta.get('corpus_tag')
        line = self.tgt_tokenizer.encode(line, meta=meta)
        line = self.tgt_bpe.encode(line, meta=meta)
        # TODO: for monolingual pairs (src_lang == tgt_lang), consider tagging with self.mono_tagger
        # (with meta['dictionary'] = self.tgt_dict) instead of self.tgt_tagger.
        line = self.tgt_tagger.encode(line, meta=meta)
        return line

    def src_line_decoder(self, line, meta=None):
        """Post-process a source-side line with the parent class's ``line_decoder``.

        This class overrides ``line_decoder`` for the target side, so the parent implementation is called explicitly.
        ``meta`` is optional and is forwarded only when given. The parent's exact ``line_decoder`` signature is not
        visible here; NOTE(review): confirm it accepts ``meta=`` if callers pass it.
        """
        # The previous body referenced an undefined ``meta`` name, which raised NameError on every call.
        if meta is None:
            return super().line_decoder(line)
        return super().line_decoder(line, meta=meta)

    def line_decoder(self, line):
        """Post-process a target-side line: remove tags, then undo BPE, then detokenize."""
        steps = (self.tgt_tagger.decode, self.remove_bpe, self.detokenize)
        for step in steps:
            line = step(line)
        return line

    def worker(self, line, meta, worker_id=0):
        """Used by DynamicDatasetServer to encode a pair of lines (tokenization + BPE + tagging + binarization).

        Args:
            line: a ``(source_line, target_line)`` pair.
            meta: per-line metadata. This method reads ``line_id``, ``src_lang`` and ``tgt_lang`` (and ``corpus_id``
                when verbose logging triggers). It is passed to the line encoders and stored in the returned sample.
            worker_id: not used by this implementation.

        Returns:
            ``(sample, ok)``, where ``sample`` is a dict with keys ``id``, ``source``, ``target`` and ``meta``. ``ok`` is
            False when either side is empty after encoding or longer than ``args.max_source_positions`` /
            ``args.max_target_positions``.
        """
        # TODO: add monolingual objective option
        src_line, tgt_line = line
        index = meta['line_id']

        # Per-language dictionaries take precedence; fall back to the shared source/target dictionary.
        src_dict = self.src_dicts.get(meta['src_lang'], self.src_dict)
        src_line = self.line_encoder(src_line, meta)
        src = data_utils.binarize(src_line.split(), src_dict.indices, src_dict.unk_index, src_dict.eos_index)

        tgt_dict = self.tgt_dicts.get(meta['tgt_lang'], self.tgt_dict)
        tgt_line = self.tgt_line_encoder(tgt_line, meta)
        tgt = data_utils.binarize(tgt_line.split(), tgt_dict.indices, tgt_dict.unk_index, tgt_dict.eos_index)

        if self.args.dynamic_dataset_verbose and (index + 1) % 100000 == 0:
            logger.info(f'Sample | {meta["corpus_id"]} | source {index + 1} | {src_line}')
            logger.info(f'Sample | {meta["corpus_id"]} | target {index + 1} | {tgt_line}')

        # skip line pairs that are too long or empty
        ok = (0 < len(src) <= self.args.max_source_positions and 0 < len(tgt) <= self.args.max_target_positions)
        sample = {
            'id': index,
            'source': src,
            'target': tgt,
            'meta': meta,
        }

        return sample, ok

    def batcher(self, buffer):
        """Used by DynamicDatasetServer to sort examples in this buffer by length and create efficient batches.

        If none of ``args.batch_by_lang_pair``, ``args.batch_by_target_lang`` or ``args.batch_by_lang_pair_and_corpus``
        is set, all samples form a single group. Otherwise samples are grouped by a key stored in
        ``sample['meta']['key']``, and a batch never mixes groups. The key is checked in this order:

        - ``"{src}-{tgt}"`` if ``batch_by_lang_pair``;
        - else ``"{corpus_tag}.{src}-{tgt}"`` if ``batch_by_lang_pair_and_corpus``;
        - else the target language.

        Within a group, samples are shuffled and then stably sorted by (source length, target length). They are then
        packed with ``data_utils.batch_by_size`` using ``args.max_tokens``, ``args.batch_size`` and
        ``args.required_batch_size_multiple``. A sample's size is the max of its source and target lengths.

        Returns:
            a list of batches, each a list of samples taken from ``buffer``.
        """
        if not self.args.batch_by_lang_pair and not self.args.batch_by_target_lang and not self.args.batch_by_lang_pair_and_corpus:
            groups = {'all': buffer}
        else:
            # Only put samples that share the same 'key' in the same batches ('key' can be the language pair or the target language).
            # This may result in some batches that are much smaller than the max batch size, and slow down training.
            # To mitigate this issue, lang_temperature can be increased (to reduce the imbalance between tasks),
            # or dynamic_dataset_buffer can be increased.
            groups = {}
            for sample in buffer:
                src = sample['meta']['src_lang']
                tgt = sample['meta']['tgt_lang']
                corpus = sample['meta']['corpus_tag']
                if self.args.batch_by_lang_pair:
                    key = f"{src}-{tgt}"
                elif self.args.batch_by_lang_pair_and_corpus:
                    key = f"{corpus}.{src}-{tgt}"
                else:
                    key = tgt
                sample['meta']['key'] = key
                groups.setdefault(key, []).append(sample)

        batches = []
        for group in groups.values():
            src_sizes = np.array([len(sample['source']) for sample in group])
            tgt_sizes = np.array([len(sample['target']) for sample in group])
            # shuffle, then sort the examples by length
            indices = np.random.permutation(len(group))
            indices = indices[np.argsort(tgt_sizes[indices], kind='mergesort')]  # merge sort is stable
            indices = indices[np.argsort(src_sizes[indices], kind='mergesort')]  # so this is equivalent to sorting by (src size, tgt size)

            # Called synchronously by batch_by_size below, so it always sees this group's size arrays.
            def num_tokens(index):
                return max(src_sizes[index], tgt_sizes[index])

            # then build batches
            for batch in data_utils.batch_by_size(indices, num_tokens, self.args.max_tokens,
                                                  self.args.batch_size, self.args.required_batch_size_multiple):
                batches.append([group[i] for i in batch])

        return batches

    @classmethod
    def get_files(cls, data_dir):
        """List the files of ``data_dir`` that are needed to run the model (dictionaries, BPE codes/vocab, SentencePiece
        models), as paths joined with ``data_dir``. Returns an empty list when ``data_dir`` is not a directory.
        """
        if not os.path.isdir(data_dir):
            return []
        # One alternation anchored with '$' accepts exactly the names that any single pattern + '$' would.
        needed = re.compile(r'(?:dict.*\.txt|bpecodes.*|bpe-vocab.*|spm.*)$')
        return [os.path.join(data_dir, name) for name in os.listdir(data_dir) if needed.match(name)]

    @classmethod
    def copy_files(cls, args):
        """Copy the model files found by ``get_files`` in the first data directory of ``args.data`` into ``args.save_dir``.

        Files that already exist in ``args.save_dir`` are left untouched. Copying is best-effort: a
        ``PermissionError`` is logged as a warning, and the remaining files are still processed.
        """
        data_dir = utils.split_paths(args.data)[0]
        for path in cls.get_files(data_dir):
            dst = os.path.join(args.save_dir, os.path.basename(path))
            if os.path.exists(dst):
                continue
            try:
                shutil.copy(path, dst)
            except PermissionError as e:
                # Not fatal (the model dir may be read-only for this process), but don't hide it.
                logger.warning(f'could not copy {path} to {dst}: {e}')

    def train(self):
        """Switch to training mode: parent components plus the target-side tokenizer, BPE and tagger
        (e.g., enables BPE dropout or noise generation)."""
        super().train()
        for component in (self.tgt_tokenizer, self.tgt_bpe, self.tgt_tagger):
            component.train = True

    def eval(self):
        """Switch to inference mode: parent components plus the target-side tokenizer, BPE and tagger
        (e.g., disables BPE dropout or noise generation)."""
        super().eval()
        for component in (self.tgt_tokenizer, self.tgt_bpe, self.tgt_tagger):
            component.train = False
