import unittest

import torch

# The apex transducer extension is optional. If it cannot be imported, the ImportError is
# stored in SKIP_TEST (None otherwise) so the test class can be skipped with that message.
try:
    from apex.contrib.transducer import (
        TransducerJoint,
        _transducer_ref as transducer_ref,
    )
except ImportError as import_error:
    SKIP_TEST = import_error
else:
    SKIP_TEST = None


@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
@unittest.skipUnless(torch.cuda.is_available(), "TransducerJoint tests require a CUDA device")
class TransducerJointTest(unittest.TestCase):
    """Compare ``TransducerJoint`` against ``transducer_joint_reference``.

    Each test generates random half-precision inputs on the GPU, runs the forward and
    backward pass of ``TransducerJoint`` and checks its output and the gradients of ``f``
    and ``g`` against the reference implementation.

    The test names encode the configuration that is exercised:
    ``vec`` -> ``for_vector_kernel=True``, ``pack`` -> ``pack_output=True``,
    ``relu`` -> ``relu=True``, ``dropout`` -> ``dropout=True``; a missing tag means the
    corresponding flag is ``False``.
    """

    def setUp(self, seed=1234):
        """Seed torch's random number generator so every test sees the same inputs."""
        torch.manual_seed(seed)

    def gen_input(self, for_vector_kernel):
        """Create the random inputs, lengths and output gradients used by a test.

        The results are stored on ``self`` (``f_tst``, ``g_tst``, ``h_grad``,
        ``h_grad_packed``, ``f_len``, ``g_len``, ``B`` and ``dropout_prob``).

        Args:
            for_vector_kernel: selects the hidden size, 512 when true and 509 otherwise
                (presumably the odd size keeps the vectorized kernel from being used --
                confirm against the kernel dispatch).
        """
        self.B = 4
        T_min, T_max = 51, 101
        U_min, U_max = 12, 25
        H = 512 if for_vector_kernel else 509
        dtype = torch.float16
        device = "cuda"

        # NOTE: for a fixed seed the generated data depends on the order of these draws.
        self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
        self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
        self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
        self.f_len = torch.randint(T_min, T_max + 1, (self.B,), dtype=torch.int, device=device)
        self.g_len = torch.randint(U_min, U_max + 1, (self.B,), dtype=torch.int, device=device)
        # Force one randomly chosen sample to use the maximum length in each dimension.
        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
        self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
        self.dropout_prob = 0.5

        # Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
        # the loss function
        # The lengths are copied to the host once instead of indexing with GPU scalars.
        for b, (t_len, u_len) in enumerate(zip(self.f_len.tolist(), self.g_len.tolist())):
            self.h_grad[b, t_len:] = 0
            self.h_grad[b, :, u_len:] = 0
        self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)

    def _pack(self, x, f_len, g_len):
        """Return the valid entries of ``x`` packed along a single leading dimension.

        For every batch element ``b`` the slice ``x[b, :f_len[b], :g_len[b]]`` is flattened
        over its first two dimensions (``t`` major, ``u`` minor) and the per-sample chunks
        are concatenated in batch order.

        Args:
            x: tensor indexed as ``[batch, t, u, ...]``.
            f_len: 1-D integer tensor with the number of valid ``t`` positions per sample.
            g_len: 1-D integer tensor with the number of valid ``u`` positions per sample.

        Returns:
            A detached copy of the packed data with ``requires_grad`` set to ``True``.
        """
        chunks = [
            x[b, :t_len, :u_len].flatten(0, 1)
            for b, t_len, u_len in zip(range(x.size(0)), f_len.tolist(), g_len.tolist())
        ]
        x_packed = torch.cat(chunks).detach().clone()
        x_packed.requires_grad = True
        return x_packed

    def _unpack(self, x, f_len, g_len):
        """Scatter a packed tensor back into a zero-padded, batch-major layout.

        This is the inverse of the layout produced by ``_pack``. The result is a ``uint8``
        tensor with the shape of ``self.h_grad``; positions outside the valid
        ``[:f_len[b], :g_len[b]]`` region of each sample stay zero.

        Args:
            x: packed tensor whose leading dimension enumerates the valid positions.
            f_len: 1-D integer tensor with the number of valid ``t`` positions per sample.
            g_len: 1-D integer tensor with the number of valid ``u`` positions per sample.

        Returns:
            The unpacked ``uint8`` tensor.
        """
        # Cumulative sizes give the end offset of every sample; shifting by one gives the start.
        ends = torch.cumsum(f_len * g_len, dim=0).tolist()
        starts = [0] + ends[:-1]
        x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
        per_sample = zip(range(self.h_grad.size(0)), starts, f_len.tolist(), g_len.tolist())
        for b, start, t_len, u_len in per_sample:
            chunk = x[start:start + t_len * u_len]
            x_unpacked[b, :t_len, :u_len] = chunk.reshape(t_len, u_len, *x.shape[1:])
        return x_unpacked

    def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
        """Run ``TransducerJoint`` for one configuration and compare with the reference.

        Args:
            for_vector_kernel: forwarded to ``gen_input`` to choose the hidden size.
            pack_output: passed to ``TransducerJoint`` and to the reference; when true the
                module is additionally given ``batch_offset``/``packed_batch`` and the
                packed output gradient is used for the backward pass.
            relu: passed to ``TransducerJoint`` and to the reference.
            dropout: passed to ``TransducerJoint`` and to the reference; when true the mask
                read from ``mask_probe`` is handed to the reference.

        Raises:
            AssertionError: if the output or a gradient differs beyond the tolerances.
        """
        self.gen_input(for_vector_kernel=for_vector_kernel)

        # Independent leaf copies so the reference accumulates its own gradients.
        f_ref = self.f_tst.detach().clone()
        g_ref = self.g_tst.detach().clone()
        f_ref.requires_grad = True
        g_ref.requires_grad = True

        my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
                                   dropout_prob=self.dropout_prob, probe_mask=True)
        mask = None  # only populated when dropout is enabled
        if pack_output:
            batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
            h_tst = my_joint(f=self.f_tst,
                             g=self.g_tst,
                             f_len=self.f_len,
                             g_len=self.g_len,
                             batch_offset=batch_offset,
                             packed_batch=batch_offset[-1])
            h_tst.backward(self.h_grad_packed)
            if dropout:
                # The probed mask is packed here; the reference receives the unpacked layout.
                mask = self._unpack(my_joint.mask_probe[0], self.f_len, self.g_len)
        else:
            h_tst = my_joint(f=self.f_tst,
                             g=self.g_tst,
                             f_len=self.f_len,
                             g_len=self.g_len)
            h_tst.backward(self.h_grad)
            if dropout:
                mask = my_joint.mask_probe[0]

        # reference
        h_ref, f_grad_ref, g_grad_ref = transducer_ref.transducer_joint_reference(
            f=f_ref,
            g=g_ref,
            h_grad=self.h_grad,
            f_len=self.f_len,
            g_len=self.g_len,
            pack_output=pack_output,
            relu=relu,
            dropout=dropout,
            dropout_prob=self.dropout_prob,
            mask=mask)

        f_grad_tst = self.f_tst.grad
        g_grad_tst = self.g_tst.grad

        torch.testing.assert_close(h_ref, h_tst, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(f_grad_ref, f_grad_tst, atol=5e-5, rtol=1e-3)
        torch.testing.assert_close(g_grad_ref, g_grad_tst, atol=1e-3, rtol=1e-3)

    # Plain joint (no ReLU, no dropout).
    # The untagged test covers the H=509, unpacked configuration; it previously repeated
    # the arguments of the ``vec_pack`` test, leaving that configuration untested.
    def test_transducer_joint(self):
        self.run_transducer_joint(
            for_vector_kernel=False, pack_output=False, relu=False, dropout=False)

    def test_transducer_joint_vec(self):
        self.run_transducer_joint(
            for_vector_kernel=True, pack_output=False, relu=False, dropout=False)

    def test_transducer_joint_pack(self):
        self.run_transducer_joint(
            for_vector_kernel=False, pack_output=True, relu=False, dropout=False)

    def test_transducer_joint_vec_pack(self):
        self.run_transducer_joint(
            for_vector_kernel=True, pack_output=True, relu=False, dropout=False)

    # Joint followed by ReLU.
    def test_transducer_joint_relu(self):
        self.run_transducer_joint(
            for_vector_kernel=False, pack_output=False, relu=True, dropout=False)

    def test_transducer_joint_vec_relu(self):
        self.run_transducer_joint(
            for_vector_kernel=True, pack_output=False, relu=True, dropout=False)

    def test_transducer_joint_pack_relu(self):
        self.run_transducer_joint(
            for_vector_kernel=False, pack_output=True, relu=True, dropout=False)

    def test_transducer_joint_vec_pack_relu(self):
        self.run_transducer_joint(
            for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

    # Joint followed by ReLU and dropout. These are currently marked as expected failures.
    @unittest.expectedFailure
    def test_transducer_joint_relu_dropout(self):
        self.run_transducer_joint(
            for_vector_kernel=False, pack_output=False, relu=True, dropout=True)

    @unittest.expectedFailure
    def test_transducer_joint_vec_relu_dropout(self):
        self.run_transducer_joint(
            for_vector_kernel=True, pack_output=False, relu=True, dropout=True)

    @unittest.expectedFailure
    def test_transducer_joint_pack_relu_dropout(self):
        self.run_transducer_joint(
            for_vector_kernel=False, pack_output=True, relu=True, dropout=True)

    @unittest.expectedFailure
    def test_transducer_joint_vec_pack_relu_dropout(self):
        self.run_transducer_joint(
            for_vector_kernel=True, pack_output=True, relu=True, dropout=True)


if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
