# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import unittest
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    """Checks the masks produced by ``SparseMultiheadAttention.buffered_sparse_mask``."""

    def test_sparse_multihead_attention(self):
        """Compare generated 8x8 sparse masks against hand-written expectations.

        Both cases use ``stride=4`` and ``expressivity=1`` and differ only in
        ``is_bidirectional``.

        NOTE(review): earlier versions of this test computed
        ``torch.all(torch.eq(...))`` but discarded the result, so the
        expected masks were never actually enforced. If an assertion fails
        here, confirm the expected values against the implementation before
        assuming a regression.
        """
        ninf = float('-inf')
        attn_weights = torch.randn(1, 8, 8)

        # Rows 0-3 share one pattern and rows 4-7 share another.
        expected_bidirectional = torch.tensor(
            [[0, 0, 0, 0, 0, ninf, ninf, 0]] * 4
            + [[ninf, ninf, ninf, 0, 0, 0, 0, 0]] * 4
        )

        expected_unidirectional = torch.tensor([
            [0, ninf, ninf, ninf, ninf, ninf, ninf, ninf],
            [0, 0, ninf, ninf, ninf, ninf, ninf, ninf],
            [0, 0, 0, ninf, ninf, ninf, ninf, ninf],
            [0, 0, 0, 0, ninf, ninf, ninf, ninf],
            [0, 0, 0, 0, 0, ninf, ninf, ninf],
            [ninf, ninf, ninf, 0, 0, 0, ninf, ninf],
            [ninf, ninf, ninf, 0, 0, 0, 0, ninf],
            [ninf, ninf, ninf, 0, 0, 0, 0, 0],
        ])

        cases = (
            (True, expected_bidirectional),
            (False, expected_unidirectional),
        )
        for is_bidirectional, expected_mask in cases:
            # subTest keeps a failure in one case from hiding the other.
            with self.subTest(is_bidirectional=is_bidirectional):
                attention = SparseMultiheadAttention(
                    16, 1, stride=4, expressivity=1, is_bidirectional=is_bidirectional
                )
                actual_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
                # torch.equal checks shape as well as values; the result must
                # be asserted, otherwise the comparison is a no-op.
                self.assertTrue(
                    torch.equal(actual_mask, expected_mask),
                    msg="sparse mask mismatch:\nactual={}\nexpected={}".format(
                        actual_mask, expected_mask
                    ),
                )


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
