# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import tempfile
import unittest

import torch

from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):

    def test_finalize(self):
        txt = [
            'A B C D',
            'B C D',
            'C D',
            'D',
        ]
        ref_ids1 = list(map(torch.IntTensor, [
            [4, 5, 6, 7, 2],
            [5, 6, 7, 2],
            [6, 7, 2],
            [7, 2],
        ]))
        ref_ids2 = list(map(torch.IntTensor, [
            [7, 6, 5, 4, 2],
            [6, 5, 4, 2],
            [5, 4, 2],
            [4, 2],
        ]))

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)

    def test_overwrite(self):
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index('<pad>'), 1)
        self.assertEqual(d.index('foo'), 3)
        self.assertEqual(d.index('<unk>'), 4)
        self.assertEqual(d.index('<s>'), 5)
        self.assertEqual(d.index('</s>'), 6)
        self.assertEqual(d.index(','), 7)
        self.assertEqual(d.index('▁de'), 8)

    def test_no_overwrite(self):
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999\n"
            "<s> 999\n"
            "</s> 999\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, 'Duplicate'):
            d.add_from_file(dict_file)

    def test_space(self):
        # for example, character models treat space as a symbol
        dict_file = io.StringIO(
            "  999\n"
            "a 999\n"
            "b 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(' '), 4)
        self.assertEqual(d.index('a'), 5)
        self.assertEqual(d.index('b'), 6)


if __name__ == '__main__':
    unittest.main()
