import sys
sys.path.append('../')
import unittest
from DMB_1side_text import Q_model
import torch
from torch.testing import assert_close
from math import e
class TestQMethod(unittest.TestCase):
    """Unit tests for ``Q_model`` on a 5-token vocabulary and length-2 sequences.

    ``setUp`` builds a fresh model before every test, so a test may freely
    overwrite ``Lambda.data`` or the permutation tensors without affecting
    the other tests.

    In each case the ``true_Q`` / ``true_exp_Q`` tensor is the full
    hand-computed matrix for every sequence position.  It is kept purely as a
    reference for the reader: only the rows/columns selected by ``idx`` (the
    ``ans`` tensor) are actually asserted.
    """

    def setUp(self):
        """Create the model under test.

        Testing under the following conditions:
            n = 5 (vocabulary size), seqlen = 2
        """
        n = 5
        seqlen = 2
        m_permute = torch.tensor([[1, 0, 2, 3, 4],
                                  [0, 1, 2, 3, 4]])
        m_permute_inverse = torch.tensor([[1, 0, 2, 3, 4],
                                          [0, 1, 2, 3, 4]])
        self.Q_model = Q_model(vocab_size=n,
                               seqlen=seqlen,
                               m_permute=m_permute,
                               m_permute_inverse=m_permute_inverse,
                               initialization="random")

    def test_get_lambda(self):
        """``get_lambda`` returns a [seqlen, vocab] tensor of non-positive
        entries whose last column is exactly zero."""
        # [seqlen, vocab]
        Lambda = self.Q_model.get_lambda()
        self.assertEqual(
            Lambda.shape, torch.Size([2, 5]),
            f"Lambda is expected to have shape [2, 5] but got {Lambda.shape}")
        # The comparison must be element-wise *before* the reduction:
        # ``torch.all(Lambda) <= 0`` would reduce first and then always hold
        # because the last column of Lambda is zero.
        self.assertTrue(
            bool(torch.all(Lambda <= 0)),
            f"All entries of Lambda should be non positive, but got {Lambda}")
        expected_last = torch.zeros(Lambda.shape[0]).to(torch.float64)
        assert_close(Lambda[:, -1], expected_last,
                     msg=f"The last entry of Lambda is expected to be {torch.zeros(Lambda.shape[0])} but got {Lambda[:, -1]}")

    def test_get_row(self):
        """``get_row(idx)`` returns, for every batch element and sequence
        position, the row of that position's Q matrix selected by ``idx``."""
        # Case 1
        # a1=4, a2=3, a3=2, a4=1
        self.Q_model.Lambda.data = torch.tensor([[-4, -3, -2, -1],
                                                 [-4, -3, -2, -1]]) * 1.0
        # Reference only: full Q for position 0 (first two states swapped by
        # the permutation from setUp) and position 1 (identity permutation).
        true_Q = torch.tensor([
            [[ -6,   0,  3,  2, 1],
             [  4, -10,  3,  2, 1],
             [  0,   0, -3,  2, 1],
             [  0,   0,  0, -1, 1],
             [  0,   0,  0,  0, 0]],
            [[-10,   4,  3,  2, 1],
             [  0,  -6,  3,  2, 1],
             [  0,   0, -3,  2, 1],
             [  0,   0,  0, -1, 1],
             [  0,   0,  0,  0, 0]],
            ]).to(torch.float64)
        idx = torch.tensor([[4, 2]])
        ans = torch.tensor([
            [[0, 0,  0, 0, 0],
             [0, 0, -3, 2, 1]],
            ]).to(torch.float64)
        Q_row = self.Q_model.get_row(idx)
        assert_close(Q_row, ans, msg=f"Row {idx}, expected {ans} but got {Q_row}")

        # Case 2
        # a1=0.5, a2=3.7, a3=10.3, a4=1.1
        # Both positions use the same (negative) rates, matching true_Q below.
        self.Q_model.Lambda.data = torch.tensor([[-0.5, -3.7, -10.3, -1.1],
                                                 [-0.5, -3.7, -10.3, -1.1]])
        # Reference only.
        true_Q = torch.tensor([
            [[-15.1,     0,   3.7, 10.3, 1.1],
             [  0.5, -15.6,   3.7, 10.3, 1.1],
             [    0,     0, -11.4, 10.3, 1.1],
             [    0,     0,     0, -1.1, 1.1],
             [    0,     0,     0,    0,   0]],
            [[-15.6,   0.5,   3.7, 10.3, 1.1],
             [    0, -15.1,   3.7, 10.3, 1.1],
             [    0,     0, -11.4, 10.3, 1.1],
             [    0,     0,     0, -1.1, 1.1],
             [    0,     0,     0,    0,   0]],
            ]).to(torch.float64)
        idx = torch.tensor([[4, 2]])
        ans = torch.tensor([
            [[0, 0,     0,    0,   0],
             [0, 0, -11.4, 10.3, 1.1]],
            ]).to(torch.float64)
        Q_row = self.Q_model.get_row(idx)
        assert_close(Q_row, ans, msg=f"Row {idx}, expected {ans} but got {Q_row}")

        # Case 3
        # a1=4, a2=3, a3=2, a4=1, with non-trivial permutations at both positions
        self.Q_model.Lambda.data = torch.tensor([[-4, -3, -2, -1],
                                                 [-4, -3, -2, -1]]) * 1.0
        self.Q_model.m_permute = torch.tensor([[3, 0, 2, 1, 4],
                                               [1, 0, 2, 3, 4]])
        self.Q_model.m_permute_inverse = torch.tensor([[1, 3, 2, 0, 4],
                                                       [1, 0, 2, 3, 4]])
        # Reference only.
        true_Q = torch.tensor([
            [[ -1,   0,  0,  0, 1],
             [  2, -10,  3,  4, 1],
             [  2,   0, -3,  0, 1],
             [  2,   0,  3, -6, 1],
             [  0,   0,  0,  0, 0]],
            [[ -6,   0,  3,  2, 1],
             [  4, -10,  3,  2, 1],
             [  0,   0, -3,  2, 1],
             [  0,   0,  0, -1, 1],
             [  0,   0,  0,  0, 0]],
            ]).to(torch.float64)
        # Two batch elements this time: idx has shape [batch=2, seqlen=2].
        idx = torch.tensor([[0, 1],
                            [4, 2]])
        ans = torch.tensor([
            [[-1,   0,  0, 0, 1],
             [ 4, -10,  3, 2, 1]],
            [[ 0,   0,  0, 0, 0],
             [ 0,   0, -3, 2, 1]],
            ]).to(torch.float64)
        Q_row = self.Q_model.get_row(idx)
        assert_close(Q_row, ans, msg=f"Row {idx}, expected {ans} but got {Q_row}")

    def test_get_col(self):
        """``get_col(idx)`` returns, for every batch element and sequence
        position, the column of that position's Q matrix selected by ``idx``."""
        # Case 1
        # a1=4, a2=3, a3=2, a4=1
        self.Q_model.Lambda.data = torch.tensor([[-4, -3, -2, -1],
                                                 [-4, -3, -2, -1]]) * 1.0
        # Reference only.
        true_Q = torch.tensor([
            [[ -6,   0,  3,  2, 1],
             [  4, -10,  3,  2, 1],
             [  0,   0, -3,  2, 1],
             [  0,   0,  0, -1, 1],
             [  0,   0,  0,  0, 0]],
            [[-10,   4,  3,  2, 1],
             [  0,  -6,  3,  2, 1],
             [  0,   0, -3,  2, 1],
             [  0,   0,  0, -1, 1],
             [  0,   0,  0,  0, 0]],
            ]).to(torch.float64)
        idx = torch.tensor([[3, 2]])
        ans = torch.tensor([[
            [2, 2,  2, -1, 0],
            [3, 3, -3,  0, 0],
            ]]).to(torch.float64)
        Q_col = self.Q_model.get_col(idx)
        assert_close(Q_col, ans, msg=f"Col {idx}, expected {ans} but got {Q_col}")

        # Case 1*: same model state, columns that touch the permuted states
        idx = torch.tensor([[0, 1]])
        ans = torch.tensor([[
            [-6,  4, 0, 0, 0],
            [ 4, -6, 0, 0, 0],
            ]]).to(torch.float64)
        Q_col = self.Q_model.get_col(idx)
        assert_close(Q_col, ans, msg=f"Col {idx}, expected {ans} but got {Q_col}")

        # Case 2
        # a1=0.5, a2=3.7, a3=10.3, a4=1.1
        self.Q_model.Lambda.data = torch.tensor([[-0.5, -3.7, -10.3, -1.1],
                                                 [-0.5, -3.7, -10.3, -1.1]])
        # Reference only.
        true_Q = torch.tensor([
            [[-15.1,     0,   3.7, 10.3, 1.1],
             [  0.5, -15.6,   3.7, 10.3, 1.1],
             [    0,     0, -11.4, 10.3, 1.1],
             [    0,     0,     0, -1.1, 1.1],
             [    0,     0,     0,    0,   0]],
            [[-15.6,   0.5,   3.7, 10.3, 1.1],
             [    0, -15.1,   3.7, 10.3, 1.1],
             [    0,     0, -11.4, 10.3, 1.1],
             [    0,     0,     0, -1.1, 1.1],
             [    0,     0,     0,    0,   0]],
            ]).to(torch.float64)
        idx = torch.tensor([[3, 0]])
        ans = torch.tensor([[
            [ 10.3, 10.3, 10.3, -1.1, 0],
            [-15.6,    0,    0,    0, 0],
            ]]).to(torch.float64)
        Q_col = self.Q_model.get_col(idx)
        assert_close(Q_col, ans, msg=f"Col {idx}, expected {ans} but got {Q_col}")

        # Case 3
        # a1=4, a2=3, a3=2, a4=1, with a non-trivial permutation at position 0
        self.Q_model.Lambda.data = torch.tensor([[-4, -3, -2, -1],
                                                 [-4, -3, -2, -1]]).to(torch.float64)
        self.Q_model.m_permute = torch.tensor([[3, 0, 2, 1, 4],
                                               [0, 1, 2, 3, 4]])
        self.Q_model.m_permute_inverse = torch.tensor([[1, 3, 2, 0, 4],
                                                       [0, 1, 2, 3, 4]])
        # Reference only.
        true_Q = torch.tensor([
            [[ -1,   0,  0,  0, 1],
             [  2, -10,  3,  4, 1],
             [  2,   0, -3,  0, 1],
             [  2,   0,  3, -6, 1],
             [  0,   0,  0,  0, 0]],
            [[-10,   4,  3,  2, 1],
             [  0,  -6,  3,  2, 1],
             [  0,   0, -3,  2, 1],
             [  0,   0,  0, -1, 1],
             [  0,   0,  0,  0, 0]],
            ]).to(torch.float64)
        idx = torch.tensor([[0, 1],
                            [3, 2]])
        ans = torch.tensor([
            [[-1,  2,  2,  2, 0],
             [ 4, -6,  0,  0, 0]],
            [[ 0,  4,  0, -6, 0],
             [ 3,  3, -3,  0, 0]],
            ]).to(torch.float64)
        Q_col = self.Q_model.get_col(idx)
        assert_close(Q_col, ans, msg=f"Col {idx}, expected {ans} but got {Q_col}")

        idx = torch.tensor([[4, 2]])
        ans = torch.tensor([
            [[1, 1,  1, 1, 0],
             [3, 3, -3, 0, 0]],
            ]).to(torch.float64)
        Q_col = self.Q_model.get_col(idx)
        assert_close(Q_col, ans, msg=f"Col {idx}, expected {ans} but got {Q_col}")

    def test_get_exp_row(self):
        '''
        Test matrix exponential: ``get_exp_row(idx, h_sigma_t)`` is compared
        against hand-computed rows of exp(h_sigma_t * Q).
        '''
        # Case 1
        # a1=4, a2=3, a3=2, a4=1
        self.Q_model.Lambda.data = torch.tensor([[-4, -3, -2, -1],
                                                 [-4, -3, -2, -1]]) * 1.0
        # Reference only: exp(Q) at time 1 for both positions.
        true_exp_Q = torch.tensor([
            [[e**(-6),           0,         (-1 + e**3)/e**6, (-1 + e**2)/e**3, 1 - e**(-1)],
             [(-1 + e**4)/e**10, e**(-10),  (-1 + e**3)/e**6, (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,                 0,         e**(-3),          (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,                 0,         0,                e**(-1),          1 - e**(-1)],
             [0,                 0,         0,                0,                1]],

            [[e**(-10), (-1 + e**4)/e**(10), (-1 + e**3)/e**6, (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,        e**(-6),             (-1 + e**3)/e**6, (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,        0,                   e**(-3),          (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,        0,                   0,                e**(-1),          1 - e**(-1)],
             [0,        0,                   0,                0,                1]],
            ]).to(torch.float64)
        # The first batch element is evaluated at time 1, the second at 0.5
        # (see the exponents in ``ans`` below).
        h_sigma_t = torch.tensor([[1, 0.5]])
        idx = torch.tensor([[4, 2], [0, 1]])
        ans = torch.tensor([
            [[0, 0, 0,       0,                1],
             [0, 0, e**(-3), (-1 + e**2)/e**3, 1 - e**(-1)]],

            [[e**(-3), 0,       (-1 + e**1.5)/e**3, (-1 + e**1)/e**1.5, 1 - e**(-0.5)],
             [0,       e**(-3), (-1 + e**1.5)/e**3, (-1 + e**1)/e**1.5, 1 - e**(-0.5)]],
            ]).to(torch.float64)
        exp_Q_row = self.Q_model.get_exp_row(idx, h_sigma_t)
        assert_close(exp_Q_row,
                     ans,
                     msg=f"Case 1, Exp Row {idx}, expected {ans} but got {exp_Q_row}")

        # Case 2
        # position 0: a1=0.5, a2=3.7, a3=10.3, a4=1.1
        # position 1: a1=4, a2=3, a3=2, a4=1
        self.Q_model.Lambda.data = torch.tensor([[-0.5, -3.7, -10.3, -1.1],
                                                 [-4, -3, -2, -1]]) * 1.0
        # Reference only.
        true_exp_Q = torch.tensor([
            [[e**(-15.1),            0,          e**(-11.4)-e**(-15.1), e**(-1.1)-e**(-11.4), 1-e**(-1.1)],
             [e**(-15.1)-e**(-15.6), e**(-15.6), e**(-11.4)-e**(-15.1), e**(-1.1)-e**(-11.4), 1-e**(-1.1)],
             [0,                     0,          e**(-11.4),            e**(-1.1)-e**(-11.4), 1-e**(-1.1)],
             [0,                     0,          0,                     e**(-1.1),            1-e**(-1.1)],
             [0,                     0,          0,                     0,                    1]],

            [[e**(-10), (-1 + e**4)/e**(10), (-1 + e**3)/e**6, (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,        e**(-6),             (-1 + e**3)/e**6, (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,        0,                   e**(-3),          (-1 + e**2)/e**3, 1 - e**(-1)],
             [0,        0,                   0,                e**(-1),          1 - e**(-1)],
             [0,        0,                   0,                0,                1]],
            ]).to(torch.float64)
        idx = torch.tensor([[1, 3]])
        h_sigma_t = torch.tensor([1])
        ans = torch.tensor([
            [[e**(-15.1)-e**(-15.6), e**(-15.6), e**(-11.4)-e**(-15.1), e**(-1.1)-e**(-11.4), 1-e**(-1.1)],
             [0,                     0,          0,                     e**(-1),              1 - e**(-1)]],
            ]).to(torch.float64)
        exp_Q_row = self.Q_model.get_exp_row(idx, h_sigma_t)
        assert_close(exp_Q_row,
                     ans,
                     msg=f"Exp Row {idx}, expected {ans} but got {exp_Q_row}")

        # Case 3
        # a1=1, a2=1, a3=1, a4=1
        self.Q_model.Lambda.data = torch.tensor([[-1, -1, -1, -1],
                                                 [-1, -1, -1, -1]]) * 1.0
        # Reference only: exp(Q) at time 1 for both positions.
        true_exp_Q = torch.tensor([
            [[e**(-3),         0,       e**(-2)-e**(-3), e**(-1)-e**(-2), 1-e**(-1)],
             [e**(-3)-e**(-4), e**(-4), e**(-2)-e**(-3), e**(-1)-e**(-2), 1-e**(-1)],
             [0,               0,       e**(-2),         e**(-1)-e**(-2), 1-e**(-1)],
             [0,               0,       0,               e**(-1),         1-e**(-1)],
             [0,               0,       0,               0,               1]],
            [[e**(-4), e**(-3)-e**(-4), e**(-2)-e**(-3), e**(-1)-e**(-2), 1-e**(-1)],
             [0,       e**(-3),         e**(-2)-e**(-3), e**(-1)-e**(-2), 1-e**(-1)],
             [0,       0,               e**(-2),         e**(-1)-e**(-2), 1-e**(-1)],
             [0,       0,               0,               e**(-1),         1-e**(-1)],
             [0,       0,               0,               0,               1]],
            ]).to(torch.float64)
        h_sigma_t = torch.tensor([0.5])
        idx = torch.tensor([[2, 3]])
        ans = torch.tensor([
            [[0, 0, e**(-1), e**(-0.5)-e**(-1), 1-e**(-0.5)],
             [0, 0, 0,       e**(-0.5),         1-e**(-0.5)]],
            ]).to(torch.float64)
        exp_Q_row = self.Q_model.get_exp_row(idx, h_sigma_t)
        assert_close(exp_Q_row,
                     ans,
                     msg=f"Exp Row {idx} h_sigma_t {h_sigma_t}, expected {ans} but got {exp_Q_row}")


    
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

