# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech
from fairseq import utils

# Base URL of the public fairseq file host. It is not referenced by the code
# in this module; presumably kept for parity with sibling speech tests or for
# external importers — confirm before removing.
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"


class TestS2STransformer(TestFairseqSpeech):
    """Checkpoint regression test for a speech-to-speech transformer.

    Decodes the ``dev_shuf200`` split with a pretrained checkpoint and scores
    it with BLEU against a stored reference value. Targets are discrete codes
    (``target_is_code=True`` with ``target_code_size=100``); the checkpoint
    name suggests a speech-to-unit model — TODO confirm.
    """

    def setUp(self):
        """Register the ``s2s`` test assets via ``TestFairseqSpeech._set_up``.

        The listed file names are associated with the ``speech_tests/s2s``
        location; presumably ``_set_up`` fetches or caches them — confirm in
        the base class.
        """
        asset_names = [
            "dev_shuf200.tsv",
            "src_feat.zip",
            "config_specaug_lb.yaml",
            "vocoder",
            "vocoder_config.json",
        ]
        self._set_up("s2s", "speech_tests/s2s", asset_names)

    def test_s2s_transformer_checkpoint(self):
        """Score ``s2u_transformer_reduced_fisher.pt`` on dev_shuf200 (BLEU).

        ``strict=False`` is forwarded to ``base_test``; presumably it relaxes
        the comparison against ``reference_score`` — confirm in the base class.
        """
        overrides = dict(
            config_yaml="config_specaug_lb.yaml",
            # Explicitly cleared, presumably so no multitask config from the
            # checkpoint is applied — confirm.
            multitask_config_yaml=None,
            target_is_code=True,
            target_code_size=100,
            eval_inference=False,
        )
        self.base_test(
            ckpt_name="s2u_transformer_reduced_fisher.pt",
            reference_score=38.3,
            dataset="dev_shuf200",
            arg_overrides=overrides,
            score_type="bleu",
            strict=False,
        )

    def postprocess_tokens(self, task, target, hypo_tokens):
        """Render reference and hypothesis token ids as strings.

        Args:
            task: object exposing ``tgt_dict``, whose ``pad()`` and
                ``string()`` methods are used here.
            target: reference tokens (presumably a torch tensor). Padding is
                stripped, and the result is cast to int and moved to CPU
                before conversion.
            hypo_tokens: hypothesis tokens, passed to ``tgt_dict.string``
                unchanged (no padding removal).

        Returns:
            A ``(tgt_str, hypo_str)`` tuple.
        """
        vocab = task.tgt_dict
        unpadded = utils.strip_pad(target, vocab.pad()).int().cpu()
        return vocab.string(unpadded), vocab.string(hypo_tokens)


# Allow running this module directly (e.g. `python <this file>`) in addition
# to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()
