import unittest

import torch
from torch import Tensor


class TestCase(unittest.TestCase):
    def assertEqualTensor(self, a: Tensor | float, b: Tensor | float, msg: str | None = None):
        if not torch.all(a == b):  # type: ignore
            self.fail(msg or f"Tensors are not equal:\n{a}\n !=\n{b}")

    def assertShape(self, arr: Tensor, shape: tuple[int, ...]):
        if arr.shape != shape:
            self.fail(f"Tensor shape is {arr.shape}, but expected {shape}.")

    def assertAlmostEqual(self, a: Tensor | float, b: Tensor | float, delta: float = 1e-3, msg: str | None = None):
        if not torch.all(torch.abs(a - b) < delta):  # type: ignore
            max_delta = torch.max(torch.abs(a - b))  # type: ignore
            self.fail(f"Tensors are not equal. Allowed delta: {delta}, maximum delta: {max_delta}.\n{a}\n != (within delta)\n{b}")

    def assertDtypesEqual(self, a: Tensor, b: Tensor):
        if not a.dtype == b.dtype:
            self.fail(f"Tensor dtypes are not equal. {a.dtype} != {b.dtype}")
