import sys
sys.path.append('../')
import unittest
from DMB_1side_text import Q_model
import torch
from torch.testing import assert_close
class TestQMethod(unittest.TestCase):
    """Unit tests for the counting and permutation statistics of ``Q_model``."""

    def setUp(self):
        '''
        Build a fresh ``Q_model`` before every test.

        Default conditions (individual tests may override ``seqlen``):
            vocab=5
            seqlen=2

        The initial permutation rows are [1, 0, 2, 3, 4] (swaps 0 and 1) and
        the identity; each row is its own inverse, which is why
        ``m_permute_inverse`` equals ``m_permute`` here.
        (Presumably one row per sequence position -- TODO confirm.)
        '''
        m_permute = torch.tensor([[1, 0, 2, 3, 4],
                                  [0, 1, 2, 3, 4]])
        m_permute_inverse = torch.tensor([[1, 0, 2, 3, 4],
                                          [0, 1, 2, 3, 4]])
        self.Q_model = Q_model(vocab_size=5,
                               seqlen=2,
                               m_permute=m_permute,
                               m_permute_inverse=m_permute_inverse,
                               initialization="random")

    def test_stat_count(self):
        '''
        ``stat_count`` tallies token occurrences per sequence position.

        The expected tables are indexed ``[position, token]``: entry ``(i, v)``
        is the number of rows of ``data`` whose token at position ``i`` is
        ``v``.
        '''
        # Both cases use length-5 sequences; set it once.
        self.Q_model.seqlen = 5

        cases = [
            (  # Case 1
                torch.tensor([[2, 3, 4, 4, 2],
                              [3, 4, 4, 2, 1]]),
                torch.tensor([[0, 0, 1, 1, 0],
                              [0, 0, 0, 1, 1],
                              [0, 0, 0, 0, 2],
                              [0, 0, 1, 0, 1],
                              [0, 1, 1, 0, 0]]),
            ),
            (  # Case 2
                torch.tensor([[3, 4, 4, 2, 1],
                              [3, 4, 4, 2, 1],
                              [0, 0, 0, 0, 2]]),
                torch.tensor([[1, 0, 0, 2, 0],
                              [1, 0, 0, 0, 2],
                              [1, 0, 0, 0, 2],
                              [1, 0, 2, 0, 0],
                              [0, 2, 1, 0, 0]]),
            ),
        ]

        for case_no, (data, true_count) in enumerate(cases, start=1):
            # subTest keeps running later cases if an earlier one fails.
            with self.subTest(case=case_no):
                # Move to CPU (as test_stat_m_permute does) so assert_close's
                # device check cannot fail if the model computes on a GPU.
                count = self.Q_model.stat_count(data).to('cpu')
                assert_close(count, true_count, msg=f"Expected {true_count} but got {count}")

    def test_stat_m_permute(self):
        '''
        ``stat_m_permute`` builds per-position permutations from two data sets.

        Derivation of the expected values (each column of the transposed data
        is one sequence position):

            token counts of phi_data (12 samples):
                pos 0: [2, 1, 3, 0, 6]
                pos 1: [4, 1, 2, 1, 4]
            token counts of mu_data (25 samples):
                pos 0: [7, 3, 5, 4, 6]
                pos 1: [7, 3, 5, 4, 6]
            ratio phi / mu:
                pos 0: [2/7, 1/3, 3/5, 0, 6/6]
                pos 1: [4/7, 1/3, 2/5, 1/4, 4/6]

        Each row of ``true_m_permute_inverse`` is the ascending argsort of the
        corresponding ratio row, and ``true_m_permute`` is its inverse
        permutation.
        '''
        self.Q_model.seqlen = 2
        phi_data = torch.tensor([
            [2, 4, 4, 4, 2, 1, 4, 2, 0, 0, 4, 4],
            [3, 4, 4, 2, 1, 0, 2, 0, 0, 4, 4, 0],
        ]).T
        mu_data = torch.tensor([
            [3, 4, 4, 2, 1, 0, 2, 0, 0, 4, 3, 1, 2, 3, 4, 4, 2, 1, 0, 2, 0, 0, 4, 3, 0],
            [3, 4, 4, 2, 1, 1, 2, 0, 0, 4, 3, 0, 2, 3, 4, 4, 2, 1, 0, 2, 0, 0, 4, 3, 0],
        ]).T

        true_m_permute_inverse = torch.tensor([[3, 0, 1, 2, 4],
                                               [3, 1, 2, 0, 4]])
        true_m_permute = torch.tensor([[1, 2, 3, 0, 4],
                                       [3, 1, 2, 0, 4]])

        m_permute, m_permute_inverse = self.Q_model.stat_m_permute(mu_data, phi_data)
        # Results may live on another device; expected tensors are on CPU.
        m_permute = m_permute.to('cpu')
        m_permute_inverse = m_permute_inverse.to('cpu')
        assert_close(m_permute, true_m_permute, msg=f"Permutation Matrix Expected {true_m_permute} but got {m_permute}")
        assert_close(m_permute_inverse, true_m_permute_inverse, msg=f"Inverse Permutation Matrix Expected {true_m_permute_inverse} but got {m_permute_inverse}")


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

