import torch
from torch.autograd import Function, Variable
from matrix_sqrt_utils import sqrt_newton_schulz, lyap_newton_schulz


class MatrixSquareRoot(Function):
    """Autograd op: square root of a positive definite matrix.

    Forward runs 10 Newton-Schulz iterations (``sqrt_newton_schulz``) and
    keeps the resulting root for the backward pass. Backward runs 10
    iterations of an iterative Lyapunov solver (``lyap_newton_schulz``) on
    that saved root to turn dL/dZ into dL/dA.

    NOTE: the matrix square root is not differentiable at matrices with
          zero eigenvalues. Page 4 of
          http://vis-www.cs.umass.edu/bcnn/docs/improved_bcnn.pdf has an
          excellent discussion of this.
    """

    @staticmethod
    def forward(ctx, input):
        """Return the approximate square root of ``input``.

        ``input`` is presumably a batch of square matrices, e.g. (B, N, N)
        as used in ``tests()`` -- TODO confirm against sqrt_newton_schulz.
        The second value returned by ``sqrt_newton_schulz`` is ignored.
        """
        root, _ = sqrt_newton_schulz(input, numIters=10, dtype=input.dtype)
        # The output itself (not the input) is what backward needs.
        ctx.save_for_backward(root)
        return root

    @staticmethod
    def backward(ctx, grad_output):
        """Map the upstream gradient dL/dZ to dL/dA.

        For Z = sqrt(A), dA = Z dZ + dZ Z, so dL/dA is the X solving the
        Lyapunov equation Z X + X Z = dL/dZ. Presumably that is the system
        ``lyap_newton_schulz`` solves with these arguments -- confirm there.
        """
        (root,) = ctx.saved_tensors
        return lyap_newton_schulz(
            root, dldz=grad_output, numIters=10, dtype=grad_output.dtype
        )

# Custom autograd Functions are invoked through their ``apply`` classmethod
# (not by instantiating the class or calling ``forward`` directly) so that
# autograd records the op and routes gradients to ``backward``.
# ``sqrtm`` is simply a functional alias for that call: ``z = sqrtm(a)``.
sqrtm = MatrixSquareRoot.apply


def tests():
    """Check the gradients produced by ``MatrixSquareRoot``.

    1. ``gradcheck`` compares the analytic gradient from ``backward()``
       with a finite-difference estimate on a random symmetric batch.
    2. Sanity check at the identity: sqrt(I + tE) = I + tE/2 + O(t^2), so
       the upstream gradient dL/dZ must come back halved as dL/dA.

    Raises:
        AssertionError: if either check fails.
    """
    from matrix_sqrt_utils import create_symm_matrix
    from torch.autograd import gradcheck

    # gradcheck: finite-difference gradients must match those from backward().
    # It expects double precision and inputs that require grad.
    a = create_symm_matrix(batchSize=2, dim=3, numPts=5, tau=1.0, dtype=torch.double)
    a = a.clone().detach().requires_grad_(True)
    assert gradcheck(sqrtm, (a,))

    # Sanity check at A = I. repeat() (unlike expand()) gives every batch
    # element its own storage, so the leaf and its .grad are not aliased views.
    a = torch.eye(3, dtype=torch.double).repeat(2, 1, 1).requires_grad_(True)
    # Stand-in for the upstream gradient dL/dZ that autograd would supply.
    dldz = torch.rand(2, 3, 3, dtype=torch.double)
    z = sqrtm(a)
    z.backward(dldz)  # feeds dL/dZ into MatrixSquareRoot.backward
    dlda = a.grad
    assert torch.allclose(0.5 * dldz, dlda), \
        'when input a is identity, dlda should equal 0.5 * dldz'



if __name__ == '__main__':
    # Running this file as a script executes the gradient checks in tests().
    tests()