import torch


def divergence_bf(f, z):
    """Compute the exact divergence of f w.r.t. z by brute force.

    Naive implementation: each diagonal Jacobian term d f_i / d z_i is
    obtained with its own backward pass and the terms are summed.

    Args:
        f: output tensor computed from ``z``; indexed as ``f[:, i]``
            (expected [batch_size, z_dim]).
        z: input tensor that ``f`` was computed from (same layout as ``f``).

    Returns:
        Tensor holding sum_i d f_i / d z_i for each batch entry
        ([batch_size] for 2-D inputs). All zeros when ``z.shape[1] == 0``.
    """
    # Start from a tensor rather than the float 0. so an empty feature
    # dimension still yields a tensor of the right shape/dtype/device.
    sum_diag = z.new_zeros(z.shape[0])

    for i in range(z.shape[1]):
        # create_graph=True keeps the graph for second-order derivatives,
        # i.e. d(Tr(df/dz))/dz.
        df_i_dz = torch.autograd.grad(f[:, i].sum(), z, create_graph=True)[0]
        sum_diag = sum_diag + df_i_dz.contiguous()[:, i].contiguous()

    return sum_diag.contiguous()

def divergence_approx(f, z, e=None):
    """Estimate the divergence of f w.r.t. z as e^T (df/dz) e.

    A single vector-Jacobian product with the probe ``e`` gives
    e^T (df/dz); multiplying elementwise by ``e`` and summing over all
    non-batch dimensions yields the stochastic trace estimate used by the
    "hutchinson" method.

    Args:
        f: output tensor computed from ``z``.
        z: input tensor that ``f`` was computed from; dim 0 is the batch.
        e: probe tensor used as ``grad_outputs`` for ``f`` and multiplied
            with the resulting gradient. When ``None``, a standard-normal
            probe is drawn with ``torch.randn_like(z)``.

    Returns:
        Tensor of shape [batch_size] with the per-sample trace estimate.
    """
    if e is None:
        # The default used to crash (None passed as grad_outputs and as a
        # multiplicand); sample the probe here instead.
        e = torch.randn_like(z)

    # create_graph=True so the estimate itself can be differentiated.
    e_dzdx = torch.autograd.grad(f, z, e, create_graph=True)[0]
    e_dzdx_e = e_dzdx * e
    approx_tr_dzdx = e_dzdx_e.view(z.shape[0], -1).sum(dim=1)
    return approx_tr_dzdx

def trace_df_dz(f, z, method="naive", e=None):
    """Dispatch to a trace-of-Jacobian routine, Tr(df/dz).

    from: https://github.com/rtqichen/ffjord/blob/master/lib/layers/odefunc.py#L13

    Args:
        f: dz/dt [batch_size, z_dim]
        z: [batch_size, z_dim]
        method: "naive" (exact, via ``divergence_bf``) or "hutchinson"
            (stochastic, via ``divergence_approx``). "hollownet" is
            recognised but not implemented.
        e: optional probe for "hutchinson"; it is used only when its shape
            equals ``z.shape``, otherwise fresh Gaussian noise is drawn.

    Returns:
        Tr(df/dz) viewed as [batch_size, 1].

    Raises:
        ValueError: for "hollownet" or any unknown ``method``.
    """
    assert f.shape == z.shape

    if method == "naive":
        exact = divergence_bf(f, z)
        return exact.view(f.shape[0], 1)

    if method == "hutchinson":
        # modified from ffjord
        probe_ok = e is not None and e.shape == z.shape
        probe = e if probe_ok else torch.randn_like(z)
        estimate = divergence_approx(f, z, e=probe)
        return estimate.view(f.shape[0], 1)

    if method == "hollownet":
        raise ValueError("HollowNet method is not implemented yet")

    raise ValueError("method only support for ['naive', 'hutchinson', 'hollownet']")

def jacobian_df_dz(f, z):
    """Compute the per-sample Jacobian df/dz using torch.autograd.grad.

    This is a wrapper working directly on tensors (the built-in PyTorch
    Jacobian implementation requires a Python function to be given).

    Args:
        f: dz/dt [batch_size, z_dim]
        z: [batch_size, z_dim]

    Returns:
        jacobian: [batch_size, z_dim, z_dim], where ``jacobian[:, i, :]``
        is the gradient of ``f[:, i]`` w.r.t. ``z``. It is allocated with
        ``z``'s dtype on ``f``'s device.
    """

    assert f.shape == z.shape
    batch_size = f.shape[0]
    z_dim = f.shape[-1]
    # Allocate with the gradients' dtype (they follow z) directly on the
    # target device; a default float32 buffer would silently downcast
    # float64 gradients on assignment.
    jacobian = torch.zeros(batch_size, z_dim, z_dim, dtype=z.dtype, device=f.device)
    for ii in range(z_dim):
        # ii-th row w.r.t. all z elements; create_graph=True keeps the
        # result differentiable.
        jacobian[:, ii, :] = torch.autograd.grad(f[:, ii].sum(), z, create_graph=True)[0].contiguous()

    return jacobian.contiguous()

def grad_trace_df_dz(f, z, method="naive", e=None):
    """Calculate the gradient of the Jacobian trace, grad(Tr(df/dz)) = grad(div(f)).

    Args:
        f: dz/dt [batch_size, z_dim]
        z: [batch_size, z_dim]
        method: trace method forwarded to ``trace_df_dz``.
        e: optional probe forwarded to ``trace_df_dz``.

    Returns:
        grad_div_f: [batch_size, z_dim]. All zeros when the divergence does
        not depend on ``z`` (e.g. ``f`` is linear in ``z``).
    """

    assert f.shape == z.shape
    div_f = trace_df_dz(f, z, method=method, e=e)

    if not div_f.requires_grad:
        # The divergence carries no autograd graph at all (e.g. f is linear
        # in z and has no trainable parameters). torch.autograd.grad would
        # raise here even with allow_unused=True, so return the zero
        # gradient explicitly.
        return torch.zeros_like(z)

    # [batch_size, z_dim]
    grad_div_f = torch.autograd.grad(div_f.sum(), z, create_graph=True, allow_unused=True)[0]
    if grad_div_f is None:
        # div_f has a graph (e.g. through parameters) but it does not
        # reach z, so the gradient w.r.t. z is zero.
        return torch.zeros_like(z)

    assert grad_div_f.shape == z.shape
    return grad_div_f.contiguous()
