# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import torch
from tests.speech import TestFairseqSpeech
from fairseq.data.data_utils import post_process
from fairseq import utils
from omegaconf import open_dict

# Presumably the root URL under which fairseq's public assets are hosted.
# NOTE(review): this constant is not referenced anywhere in the visible part
# of this module — confirm whether something imports it before removing it.
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestWav2Vec2(TestFairseqSpeech):
    """Checkpoint tests for wav2vec 2.0 models on LibriSpeech ``test-other``.

    Each test hands a checkpoint name, a dataset name and a reference score
    with a tolerance to ``TestFairseqSpeech.base_test``. This class supplies
    the two hooks defined below, ``build_generator`` and
    ``postprocess_tokens``.

    NOTE(review): the score is presumably a word error rate and
    ``score_delta`` the allowed deviation from it — confirm against
    ``TestFairseqSpeech.base_test``.

    The whole class is skipped when no CUDA device is available.
    """

    def setUp(self):
        """Register the test assets with the base class and unpack the audio zip."""
        self._set_up(
            "librispeech_w2v2",
            "conformer/wav2vec2/librispeech",
            [
                "test_librispeech-other.ltr",
                "test_librispeech-other.tsv",
                "test_librispeech-other_small.ltr_100",
                "test_librispeech-other_small.tsv",
                "test-other.zip",
                "dict.ltr.txt",
                "dict.ltr_100.txt",
            ],
        )
        self.unzip_files(
            "test-other.zip",
        )

    def test_transformer_w2v2(self):
        """Run the transformer checkpoint on ``test_librispeech-other``."""
        self.base_test(
            ckpt_name="transformer_oss_small_100h.pt",
            reference_score=38,
            score_delta=1,
            dataset="test_librispeech-other",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr",
                "nbest": 1,
                "tpu": False,
            },
            # NOTE(review): presumably relaxes checkpoint/state-dict loading
            # for this older checkpoint — confirm in base_test.
            strict=False,
        )

    def test_conformer_w2v2(self):
        """Run the conformer (rope) checkpoint on ``test_librispeech-other_small``."""
        self.base_test(
            ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
            reference_score=4.5,
            score_delta=1,
            dataset="test_librispeech-other_small",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                # This dataset ships its own label files / dictionary
                # ("*.ltr_100", "dict.ltr_100.txt"), see setUp.
                "labels": "ltr_100",
                "nbest": 1,
                "tpu": False,
            },
            strict=True,
        )

    def build_generator(self, task, models, cfg):
        """Return a ``W2lViterbiDecoder`` built from ``cfg`` and the task dictionary.

        Args:
            task: task object; only ``task.target_dictionary`` is read.
            models: unused here; kept for the hook signature.
            cfg: OmegaConf config node; ``cfg.nbest`` is forced to 1 in place.

        Raises:
            Exception: if the decoder module cannot be imported. The original
                import error is attached as ``__cause__``.
        """
        try:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        except Exception as err:
            # Chain the underlying failure so the traceback shows *why* the
            # import failed (missing flashlight, broken build, ...) instead of
            # only this generic message.
            raise Exception(
                "Cannot run this test without flashlight dependency"
            ) from err
        # open_dict temporarily lifts OmegaConf struct mode so that `nbest`
        # can be assigned even when the key is not part of the config schema.
        with open_dict(cfg):
            cfg.nbest = 1
        return W2lViterbiDecoder(cfg, task.target_dictionary)

    def postprocess_tokens(self, task, target, hypo_tokens):
        """Convert target and hypothesis token ids into post-processed strings.

        Args:
            task: task object providing ``target_dictionary``.
            target: reference token tensor; padding is stripped before decoding.
            hypo_tokens: hypothesis tokens, decoded as-is.

        Returns:
            A ``(tgt_str, hypo_str)`` tuple, both passed through
            ``post_process(..., "letter")``.
        """
        dictionary = task.target_dictionary

        # Drop padding, then move to CPU ints before string conversion.
        tgt_tokens = utils.strip_pad(target, dictionary.pad()).int().cpu()
        tgt_str = post_process(dictionary.string(tgt_tokens), "letter")

        hypo_str = post_process(dictionary.string(hypo_tokens), "letter")
        return tgt_str, hypo_str


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
