import sys
sys.path.append('../')
import unittest
from DMB_1side_text import Q_model
import torch
from torch.testing import assert_close
from math import e
class TestQMethod(unittest.TestCase):
    """Unit tests for ``Q_model`` (imported from ``DMB_1side_text``)."""

    def setUp(self):
        '''
        Build a small ``Q_model`` fixture under the following conditions:
            vocab_size = 5
            seqlen = 2

        Row 0 of the permutation swaps tokens 0 and 1; row 1 is the identity.
        Both rows are involutions, so ``m_permute_inverse`` has the same
        values as ``m_permute``.
        '''
        # Single source of truth for the fixture sizes (previously the
        # constructor hard-coded 5 while this local went unused).
        vocab_size = 5
        seqlen = 2

        # NOTE(review): presumably one permutation row per sequence position
        # (2 rows == seqlen) -- confirm against Q_model.
        m_permute = torch.tensor([[1, 0, 2, 3, 4],
                                  [0, 1, 2, 3, 4]])
        m_permute_inverse = torch.tensor([[1, 0, 2, 3, 4],
                                          [0, 1, 2, 3, 4]])

        self.Q_model = Q_model(vocab_size=vocab_size,
                               seqlen=seqlen,
                               m_permute=m_permute,
                               m_permute_inverse=m_permute_inverse,
                               initialization="random")

    def test_cal_p_T(self):
        '''
        Test the calculation of p_T = p_0 @ exp(h_sigma_t * Q).

        Every Lambda entry is set to -1 and h_sigma_t to 0.5, so the expected
        probabilities are closed-form expressions in powers of e. Each
        expected row sums to 1.
        '''
        # Same values/dtype as the original ``torch.tensor(int) * 1.0``: a
        # (2, 4) tensor of -1.0 in the default float dtype.
        # NOTE(review): 4 == vocab_size - 1 and 2 == seqlen is assumed, not
        # verified here -- confirm against Q_model.Lambda's definition.
        self.Q_model.Lambda.data = torch.full((2, 4), -1.0)

        h_sigma_t = torch.tensor([0.5, 0.5])

        # Case 1: a one-hot distribution and an even split over tokens 0/1.
        p_0 = torch.tensor([[  1,   0, 0, 0, 0],
                            [0.5, 0.5, 0, 0, 0]])
        p_T = self.Q_model.cal_p_T(p_0, h_sigma_t)

        # Expected output rows, one per row of p_0.
        expected_rows = [
            [e**(-1.5), 0,
             e**(-1) - e**(-1.5), e**(-0.5) - e**(-1), 1 - e**(-0.5)],
            [0.5*e**(-2), e**(-1.5) - 0.5*e**(-2),
             e**(-1) - e**(-1.5), e**(-0.5) - e**(-1), 1 - e**(-0.5)],
        ]
        # The same rows are expected for both outer entries of p_T.
        # Built as float32 and then cast to float64, exactly as before, so
        # assert_close's tolerances behave identically.
        true_p_T = torch.tensor([expected_rows, expected_rows]).to(torch.float64)
        assert_close(p_T, true_p_T, msg=f"Expected {true_p_T} but got {p_T}")

if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main()

