import torch
from torch import nn

from utils import flatten

# Short alias for torch.Tensor, used in the type annotations below.
T = torch.Tensor

# Names exported by `from <this module> import *`.
__all__ = ["full_hessian", "jacobian"]


# Hessian and Jacobian code from: https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7
def jacobian(y: T, x: T, create_graph: bool = False) -> T:
    """Compute the Jacobian of ``y`` with respect to ``x``, one output element at a time.

    Each element of ``y`` is selected with a one-hot ``grad_outputs`` vector and
    back-propagated to ``x`` via ``torch.autograd.grad``.

    Args:
        y: Tensor computed from ``x``; any shape.
        x: Tensor to differentiate with respect to; it must participate in the
            graph that produced ``y``.
        create_graph: Forwarded to ``torch.autograd.grad``. Set to ``True`` when
            the returned Jacobian must itself be differentiable (for example when
            nesting two calls to obtain a Hessian).

    Returns:
        Tensor of shape ``y.shape + x.shape``. If ``y`` has no elements, an
        empty tensor of that shape is returned.
    """
    flat_y = y.reshape(-1)
    n_out = flat_y.numel()

    # torch.stack rejects an empty list, so handle the no-output case explicitly.
    if n_out == 0:
        return x.new_zeros(y.shape + x.shape)

    rows = []
    for i in range(n_out):
        # Use a fresh seed for every row. With create_graph=True autograd may
        # save the seed (or a view of it) for double backward; mutating a shared
        # seed in place afterwards would invalidate those saved tensors.
        one_hot = torch.zeros_like(flat_y)
        one_hot[i] = 1.0
        grad_x, = torch.autograd.grad(flat_y, x, one_hot, retain_graph=True, create_graph=create_graph)
        rows.append(grad_x.reshape(x.shape))

    return torch.stack(rows).reshape(y.shape + x.shape)


# Inspired by the implementation at https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7,
# which cannot compute the full Hessian over all the parameters of a network.
# The full Hessian is only practical for the smallest networks; most networks will require
# Kronecker-factored Hessians instead.
def full_hessian(loss: T, module: nn.Module, retain_graph: bool = True) -> T:
    """Build the dense Hessian of ``loss`` with respect to all parameters of ``module``.

    The first-order gradient is computed with ``create_graph=True`` so that each
    of its scalar entries can be differentiated again; every such second
    differentiation yields one row of the Hessian.

    Args:
        loss: Tensor to differentiate; expected to be a scalar produced from
            ``module``'s parameters.
        module: Module whose ``parameters()`` define the Hessian's variables.
        retain_graph: Whether the autograd graph should still be alive when
            this function returns. The graph is always kept while rows are
            being computed; ``False`` releases it on the final row only.

    Returns:
        The stacked Hessian rows. NOTE(review): this is presumably a square
        ``(P, P)`` tensor with ``P`` the total number of scalar parameters,
        assuming ``utils.flatten`` concatenates its inputs into a 1-D tensor --
        confirm against ``utils.flatten``.
    """
    # Materialise once so every autograd call sees the same parameter sequence.
    params = list(module.parameters())

    flat_grad = flatten(torch.autograd.grad(loss, params, retain_graph=True, create_graph=True))  # type: ignore

    # For each scalar entry g_i of the gradient, differentiate it with respect to
    # all parameters to obtain row i: $\partial^2 L / \partial w_i \partial w_j$.
    last = len(flat_grad) - 1
    rows = []
    for i, g in enumerate(flat_grad):
        # Every row but the last needs the graph afterwards; freeing it earlier
        # (retain_graph=False on each call) would make the next row fail.
        keep_graph = retain_graph or i < last
        row = flatten(torch.autograd.grad(g, params, retain_graph=keep_graph))  # type: ignore
        rows.append(row)

    return torch.stack(rows)
