#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np
import torch
from examples.speech_recognition.data.collaters import Seq2SeqCollater


class TestSeq2SeqCollator(unittest.TestCase):
    """Unit tests for ``Seq2SeqCollater.collate``."""

    def test_collate(self):
        """Collate two samples of different lengths and check every batch field.

        The expected values below encode that the collater sorts samples by
        descending frame count, right-pads frames and targets with ``pad_idx``,
        and builds ``prev_output_tokens`` by moving ``eos_idx`` from the end of
        each target to its front.
        """
        eos_idx = 1
        pad_idx = 0
        collater = Seq2SeqCollater(
            feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
        )

        # 2 frames in the first sample and 3 frames in the second one
        frames1 = np.array([[7, 8], [9, 10]])
        frames2 = np.array([[1, 2], [3, 4], [5, 6]])
        target1 = np.array([4, 2, 3, eos_idx])
        target2 = np.array([3, 2, eos_idx])
        sample1 = {"id": 0, "data": [frames1, target1]}
        sample2 = {"id": 1, "data": [frames2, target2]}
        batch = collater.collate([sample1, sample2])

        # collate sorts inputs by frame length (longest first) before batching,
        # so sample 1 (3 frames) comes before sample 0 (2 frames)
        self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
        # total number of target tokens, eos included: 4 + 3
        self.assertEqual(batch["ntokens"], 7)
        # the shorter frame sequence is right-padded with pad_idx
        self.assertTensorEqual(
            batch["net_input"]["src_tokens"],
            torch.tensor(
                [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
            ),
        )
        # decoder input: each target with eos moved to the front, then padded
        self.assertTensorEqual(
            batch["net_input"]["prev_output_tokens"],
            torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]),
        )
        self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2]))
        self.assertTensorEqual(
            batch["target"],
            torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]),
        )
        self.assertEqual(batch["nsentences"], 2)

    def assertTensorEqual(self, t1, t2):
        """Assert that tensors ``t1`` and ``t2`` have the same size and elements.

        Only the size and the element values are checked; dtype is not.

        Raises:
            AssertionError: if the sizes differ or any element differs.
        """
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # .item() turns the count into a plain Python int, so the comparison
        # does not rely on Tensor.__eq__/__bool__ and the failure message can
        # show which tensors disagreed.
        num_mismatches = t1.ne(t2).long().sum().item()
        self.assertEqual(
            num_mismatches,
            0,
            "{} element(s) differ:\n{}\nvs\n{}".format(num_mismatches, t1, t2),
        )


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
