import torch
import unittest
import numpy



################################################################################
# unittest.TestCase subclass with extra assertion functions for NumPy arrays and torch tensors
class TestCase(unittest.TestCase):
    """``unittest.TestCase`` with equality assertions for NumPy arrays and torch tensors.

    ``setUp`` registers type-specific equality functions, so ``assertEqual``
    on two ``numpy.ndarray`` objects (or two ``torch.Tensor`` objects)
    compares them element-wise instead of calling ``==`` on them directly.
    """

    def setUp(self):
        super().setUp()
        # unittest only dispatches to these when type(first) == type(second).
        self.addTypeEqualityFunc(numpy.ndarray, self.assert_numpy_equal)
        self.addTypeEqualityFunc(torch.Tensor, self.assert_torch_equal)

    def _fail_unequal(self, lhs, rhs, msg):
        """Raise ``failureException`` with *msg*, or a default lhs/rhs dump if *msg* is None."""
        if msg is None:
            msg = "\n{}\n\t!=\n{}".format(lhs, rhs)
        raise self.failureException(msg)

    @staticmethod
    def _tensor_shapes_match(lhs, rhs):
        """Return False only when both operands are tensors with different shapes.

        torch comparison ops broadcast, so without this check e.g.
        ``tensor([1])`` would compare equal to ``tensor([1, 1])``. A non-tensor
        operand (e.g. a Python scalar) is still allowed to broadcast.
        """
        if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
            return lhs.shape == rhs.shape
        return True

    def assert_numpy_equal(self, lhs, rhs, msg=None):
        """Fail unless *lhs* and *rhs* have the same shape and equal elements.

        Args:
            lhs, rhs: array-likes compared with ``numpy.array_equal``.
            msg: optional failure message replacing the default lhs/rhs dump.

        Raises:
            self.failureException: if the arrays differ.
        """
        if not numpy.array_equal(lhs, rhs):
            self._fail_unequal(lhs, rhs, msg)

    def assert_torch_equal(self, lhs, rhs, msg=None):
        """Fail unless *lhs* and *rhs* have the same shape and equal elements.

        Args:
            lhs: tensor to compare.
            rhs: tensor (must match *lhs*'s shape) or scalar (broadcast).
            msg: optional failure message replacing the default lhs/rhs dump.

        Raises:
            self.failureException: on a shape mismatch or any unequal element.
        """
        # Shape check first: it avoids false passes from broadcasting and a
        # RuntimeError from torch.eq on non-broadcastable shapes.
        if not self._tensor_shapes_match(lhs, rhs) or not torch.all(torch.eq(lhs, rhs)):
            self._fail_unequal(lhs, rhs, msg)

    def assert_torch_allclose(self, lhs, rhs, msg=None, **kwds):
        """Fail unless *lhs* and *rhs* have the same shape and are element-wise close.

        Args:
            lhs, rhs: tensors compared with ``torch.allclose``.
            msg: optional failure message replacing the default lhs/rhs dump.
            **kwds: forwarded to ``torch.allclose`` (e.g. ``rtol``, ``atol``,
                ``equal_nan``).

        Raises:
            self.failureException: on a shape mismatch or values not close.
        """
        if not self._tensor_shapes_match(lhs, rhs) or not torch.allclose(lhs, rhs, **kwds):
            self._fail_unequal(lhs, rhs, msg)

################################################################################
# Poolings with padded sequences

def create_mask_0_1(original_lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    indices = torch.arange(start = 0, end = max_length, step = 1)
    return indices.view(1, -1) < original_lengths.view(-1,1)
