import unittest
import torch


class MyTestCase(unittest.TestCase):
    """Autograd consistency checks for ``torch.linalg.cholesky``.

    With ``S = p * A.T @ A`` (symmetric positive definite for a full-rank
    ``A``), ``cholesky(S)`` and ``cholesky(torch.tril(S))`` yield the same
    factor, because the factorization reads only the lower triangle
    (``test_forward_and_full_gradient`` checks this). The two functions of
    ``p`` are therefore identical, and their derivatives w.r.t. ``p`` should
    agree too.
    """

    N = 5  # side length of the random test matrix

    def _make_inputs(self):
        """Return ``(p, A)``: a float64 scalar parameter equal to 1 and a
        seeded ``N x N`` float64 Gaussian matrix.

        float64 keeps rounding error far below ``assertAlmostEqual``'s default
        7-decimal tolerance; the fixed seed makes failures reproducible.
        """
        generator = torch.Generator().manual_seed(0)
        A = torch.randn(self.N, self.N, generator=generator, dtype=torch.float64)
        p = torch.nn.Parameter(torch.ones(1, dtype=torch.float64))
        return p, A

    @staticmethod
    def _grad_wrt_p(p, A, *, use_tril):
        """Return ``(value, dvalue/dp)`` as Python floats, where
        ``value = sum(cholesky(M))``, ``M = p * A.T @ A`` and, if ``use_tril``
        is true, ``M`` is first replaced by ``torch.tril(M)``.

        ``torch.autograd.grad`` is used instead of ``.backward()`` so nothing
        accumulates into ``p.grad`` between calls.
        """
        S = p * A.T @ A
        if use_tril:
            S = torch.tril(S)
        value = torch.sum(torch.linalg.cholesky(S))
        (grad,) = torch.autograd.grad(value, p)
        return value.item(), grad.item()

    def test_forward_and_full_gradient(self):
        """Baseline checks the comparison in ``test_something`` relies on.

        1. Both variants give the same forward value.
        2. Without ``tril`` the gradient equals the analytic derivative:
           ``cholesky(p * S) == sqrt(p) * cholesky(S)``. So with
           ``c = sum(cholesky(S))``, ``d/dp`` at ``p = 1`` is ``c / 2``.
        """
        p, A = self._make_inputs()
        value_full, grad_full = self._grad_wrt_p(p, A, use_tril=False)
        value_tril, _ = self._grad_wrt_p(p, A, use_tril=True)
        self.assertAlmostEqual(value_tril, value_full)
        self.assertAlmostEqual(grad_full, value_full / 2)

    @unittest.expectedFailure
    def test_something(self):
        """The gradient w.r.t. ``p`` should not depend on a ``tril`` before
        ``cholesky``.

        Marked as an expected failure. torch's Cholesky backward returns a
        symmetrized gradient (NOTE(review): per torch's ``cholesky_backward``
        implementation; confirm on the installed version). Routing the input
        through ``tril`` keeps only the lower half of that gradient, which
        undercounts every off-diagonal contribution. If torch changes this,
        unittest reports an unexpected success, signalling this marker can go.
        """
        p, A = self._make_inputs()
        value_full, grad_full = self._grad_wrt_p(p, A, use_tril=False)
        value_tril, grad_tril = self._grad_wrt_p(p, A, use_tril=True)
        self.assertAlmostEqual(
            grad_tril,
            grad_full,
            msg=(
                f"f(full)={value_full}, f(tril)={value_tril}, "
                f"grad(full)={grad_full}, grad(tril)={grad_tril}"
            ),
        )
        # TODO: Dear torch developers ...


if __name__ == "__main__":
    # Script entry point: hand control to unittest's command-line runner,
    # which discovers and runs every TestCase defined in this module.
    unittest.main()
