# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import unittest

import tests.utils as test_utils
import torch
from fairseq.sequence_scorer import SequenceScorer


class TestSequenceScorer(unittest.TestCase):
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                "source": torch.LongTensor([w1, w2, eos]),
                "target": torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.0
        args.beam_probs = [
            # step 0:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    [0.0, unk, 0.6, 0.4],  # sentence 1
                    [0.0, unk, 0.4, 0.6],  # sentence 2
                    [0.0, unk, 0.7, 0.3],  # sentence 3
                ]
            ),
            # step 1:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    [0.0, unk, 0.2, 0.7],  # sentence 1
                    [0.0, unk, 0.8, 0.2],  # sentence 2
                    [0.7, unk, 0.1, 0.2],  # sentence 3
                ]
            ),
            # step 2:
            torch.FloatTensor(
                [
                    # eos       w1    w2
                    [0.10, unk, 0.50, 0.4],  # sentence 1
                    [0.15, unk, 0.15, 0.7],  # sentence 2
                    [0.00, unk, 0.00, 0.0],  # sentence 3
                ]
            ),
            # step 3:
            torch.FloatTensor(
                [
                    # eos      w1    w2
                    [0.9, unk, 0.05, 0.05],  # sentence 1
                    [0.0, unk, 0.00, 0.0],  # sentence 2
                    [0.0, unk, 0.00, 0.0],  # sentence 3
                ]
            ),
        ]
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer(task.target_dictionary)
        for sample in data_itr:
            hypos = task.inference_step(scorer, [model], sample)
            for id, hypos_id in zip(sample["id"].tolist(), hypos):
                self.assertHypoTokens(hypos_id[0], data[id]["target"])
                self.assertHypoScore(hypos_id[0], expected_scores[id])

    def assertHypoTokens(self, hypo, tokens):
        self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo["score"]), 1e-6)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
