import torch
import numpy as np
import unittest
from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)

torch.use_deterministic_algorithms(True)


class TestESPNETMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.sample_scores = torch.randn(self.B, 1, self.T, self.T)
        self.MHA = ESPNETMultiHeadedAttention(self.C, 1, dropout=0)

    def test_forward(self):
        expected_scores = torch.tensor(
            [[[0.1713, -0.3776]], [[0.2263, -0.4486]], [[0.2243, -0.4538]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward_qkv(self):
        expected_query = torch.tensor(
            [[[[-1.0235, 0.0409], [0.4008, 1.3077], [0.5396, 2.0698]]]]
        )
        expected_key = torch.tensor(
            [[[[0.5053, -0.4965], [-0.3730, -0.9473], [-0.7019, -0.1935]]]]
        )
        expected_val = torch.tensor(
            [[[[-0.9940, 0.5403], [0.5924, -0.7619], [0.7504, -1.0892]]]]
        )
        sample_t = self.sample.transpose(0, 1)
        query, key, val = self.MHA.forward_qkv(sample_t, sample_t, sample_t)
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                query.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_key.cpu().detach().numpy(),
                key.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_val.cpu().detach().numpy(),
                val.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward_attention(self):
        expected_scores = torch.tensor(
            [[[0.1627, -0.6249], [-0.2547, -0.6487], [-0.0711, -0.8545]]]
        )
        scores = self.MHA.forward_attention(
            self.sample.transpose(0, 1).view(self.B, 1, self.T, self.C),
            self.sample_scores,
            mask=None,
        )
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


class TestRelPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.sample_x = torch.randn(self.B, 1, self.T, self.T * 2 - 1)
        self.sample_pos = torch.randn(self.B, self.T * 2 - 1, self.C)
        self.MHA = RelPositionMultiHeadedAttention(self.C, 1, dropout=0)

    def test_rel_shift(self):
        expected_x = torch.tensor(
            [
                [
                    [
                        [-0.7193, -0.4033, -0.5966],
                        [-0.8567, 1.1006, -1.0712],
                        [-0.5663, 0.3731, -0.8920],
                    ]
                ]
            ]
        )
        x = self.MHA.rel_shift(self.sample_x)
        self.assertTrue(
            np.allclose(
                expected_x.cpu().detach().numpy(),
                x.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward(self):
        expected_scores = torch.tensor(
            [
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
                [[-0.9609, -0.5020]],
                [[-0.9308, -0.4890]],
                [[-0.9473, -0.4948]],
            ]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample, self.sample_pos)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


class TestRotaryPositionMultiHeadedAttention(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.MHA = RotaryPositionMultiHeadedAttention(
            self.C, 1, dropout=0, precision=None
        )

    def test_forward(self):
        expected_scores = torch.tensor(
            [[[-0.3220, -0.4726]], [[-1.2813, -0.0979]], [[-0.3138, -0.4758]]]
        )
        scores, _ = self.MHA(self.sample, self.sample, self.sample)
        self.assertTrue(
            np.allclose(
                expected_scores.cpu().detach().numpy(),
                scores.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()
