import math
import torch



def zero_grads(params):
    """Zero, in place, the ``.grad`` buffer of every element of ``params``.

    Elements whose ``.grad`` is ``None`` are left untouched.

    Args:
        params: iterable of objects exposing a ``.grad`` attribute
            (e.g. tensors that require gradients).
    """
    for tensor in params:
        grad = tensor.grad
        if grad is None:
            continue
        grad.zero_()

def make_symmetric(H):
    """Build a symmetric matrix from the upper triangle of ``H``.

    The strictly-upper triangle of ``H`` is mirrored below the diagonal and
    the diagonal of ``H`` is kept as is; whatever ``H`` holds below the
    diagonal is discarded.

    Args:
        H: 2-D tensor.

    Returns:
        A new tensor equal to its own transpose.
    """
    upper = torch.triu(H, diagonal=1)
    off_diagonal = upper + upper.t()
    return off_diagonal + torch.diag(torch.diagonal(H))

def compute_hessian(params, score):
    """Compute the dense Hessian of a scalar ``score`` w.r.t. ``params``.

    The parameters are treated as one flat vector (each tensor flattened and
    concatenated in the order given); row ``i`` of the result is the gradient
    of the ``i``-th first-order partial derivative w.r.t. that flat vector.

    Args:
        params: a tensor, or an iterable of tensors, that require gradients
            and that ``score`` was computed from.
        score: scalar tensor to differentiate twice.

    Returns:
        Tensor of shape ``(n, n)`` where ``n`` is the total number of
        parameter elements. The result is not attached to the autograd graph.

    Raises:
        RuntimeError: propagated from ``torch.autograd.grad`` if ``score``
            does not depend on one of ``params`` at first order.
    """
    # Materialise once: the parameters are traversed several times below.
    params = [params] if torch.is_tensor(params) else list(params)

    first = torch.autograd.grad(score, params, create_graph=True)
    # reshape (not view) so non-contiguous gradients are handled as well
    flat_grad = torch.cat([g.reshape(-1) for g in first])

    if not flat_grad.requires_grad:
        # Every first derivative is constant (score is linear in all
        # parameters), so there is nothing to differentiate: H == 0.
        n = flat_grad.numel()
        return flat_grad.new_zeros(n, n)

    rows = []
    for dx in flat_grad:
        # allow_unused: a parameter that enters the score only linearly has a
        # constant gradient and is absent from dx's graph; autograd then
        # returns None for it, which corresponds to a block of zeros.
        second = torch.autograd.grad(
            dx, params, retain_graph=True, allow_unused=True
        )
        rows.append(torch.cat([
            torch.zeros_like(p).reshape(-1) if h is None else h.reshape(-1)
            for h, p in zip(second, params)
        ]))
    return torch.stack(rows)

def get_parse_laplacians(model, parses, img):
    """Compute a precision matrix per parse, plus the parses' log-likelihoods.

    For each parse, the Hessian of its negative log-likelihood w.r.t. that
    parse's ``stroke_params`` is computed and symmetrised; it serves as the
    precision matrix of a multivariate normal.

    Args:
        model: object exposing ``losses_fn(parses, img)``, which returns the
            negative log-likelihoods, indexable by parse position.
        parses: sequence of objects exposing ``stroke_params``.
        img: passed through to ``model.losses_fn``.

    Returns:
        Tuple ``(A_list, ll)``: the list of symmetrised Hessians (one per
        parse) and the negated output of ``model.losses_fn``.
    """
    nll = model.losses_fn(parses, img)  # negative log-likelihoods
    A_list = []
    for k, parse in enumerate(parses):
        stroke_params = parse.stroke_params
        zero_grads(stroke_params)
        # Hessian of the NLL is the MVN precision matrix; symmetrise it to
        # make sure it is exactly symmetric.
        precision = make_symmetric(compute_hessian(stroke_params, nll[k]))
        A_list.append(precision)
    # negate to convert to log-likelihoods
    return A_list, -nll

def logdet_fn(A):
    """Log pseudo-determinant of a symmetric matrix.

    Only strictly positive eigenvalues contribute; zero and negative ones
    are dropped.

    Args:
        A: symmetric 2-D tensor. Only its upper triangle is read.

    Returns:
        Tuple ``(logd, d)``: the sum of the logs of the positive eigenvalues
        (0-dim tensor) and the number ``d`` of eigenvalues that were kept.
    """
    eigvalsh = getattr(getattr(torch, "linalg", None), "eigvalsh", None)
    if eigvalsh is not None:
        # torch.symeig is deprecated and removed from recent PyTorch releases.
        # UPLO="U" reproduces symeig's default of reading the upper triangle.
        eigv = eigvalsh(A, UPLO="U")
    else:
        # Legacy PyTorch builds without torch.linalg.eigvalsh.
        eigv, _ = torch.symeig(A)
    eigv = eigv[eigv > 0]
    logd = torch.sum(torch.log(eigv))
    return logd, len(eigv)

def mvn_lognorm(A):
    """Log normalizing constant of a multivariate normal with precision ``A``.

    Evaluates ``log(Z) = 0.5 * d * log(2 * pi) - 0.5 * logdet(A)``, where
    ``logdet(A)`` and ``d`` are the two values returned by ``logdet_fn``.

    Args:
        A: precision matrix, passed unchanged to ``logdet_fn``.

    Returns:
        ``log(Z)``.
    """
    logdet, d = logdet_fn(A)
    gauss_term = 0.5 * d * math.log(2 * math.pi)
    return gauss_term - 0.5 * logdet

def log_marginal_fn(model, parses, img):
    """Estimate the log marginal likelihood of ``img`` from a set of parses.

    Each parse contributes its log-likelihood plus the log normalizing
    constant of the multivariate normal whose precision matrix is returned
    for it by ``get_parse_laplacians``; the contributions are combined with
    a log-sum-exp over parses.

    Args:
        model: forwarded to ``get_parse_laplacians``.
        parses: forwarded to ``get_parse_laplacians``.
        img: forwarded to ``get_parse_laplacians``.

    Returns:
        0-dim CPU tensor holding the log marginal estimate.
    """
    A_list, log_joints = get_parse_laplacians(model, parses, img)
    # log of the normalizing constant for each parse
    log_norms = torch.stack([mvn_lognorm(A) for A in A_list])
    log_joints = log_joints.detach().cpu()
    # log_norms inherits the device of the precision matrices, which need not
    # be the CPU; align it with log_joints so the sum cannot fail with a
    # device mismatch.
    log_norms = log_norms.to(log_joints.device)
    return torch.logsumexp(log_joints + log_norms, 0)