import torch
import torch.nn as nn

class RegularizedODEfunc(nn.Module):
    """Wrap an ODE function and append regularization terms to its output.

    When the incoming state has more than one component, every function in
    ``regularization_fns`` is evaluated and its result is appended to the
    derivative returned by the wrapped ``odefunc``. With a single-component
    state only the wrapped derivative is returned.
    """

    def __init__(self, odefunc):
        super().__init__()
        self.odefunc = odefunc
        self.regularization_fns = (
            quadratic_cost,
            jacobian_frobenius_regularization_fn,
        )
        # NOTE(review): this copies the counter value present at construction
        # time; ``forward`` then increments this wrapper's own attribute
        # (assuming ``nfe`` is a plain number -- confirm).
        self.nfe = self.odefunc.nfe
        self.input_dim = self.odefunc.input_dim

    def forward(self, t, state):
        """Evaluate the wrapped dynamics and, if requested, the regularizers.

        Args:
            t: Time value; gradient tracking is switched on for it in place.
            state: Indexable state; ``state[0]`` is passed to ``odefunc``.
                Any further entries only signal that regularization terms
                should be produced.

        Returns:
            The output of ``odefunc(t, state[0])``, followed by one entry per
            regularization function when ``len(state) > 1``.
        """
        self.nfe += 1
        # Gradients are needed even if the caller runs under ``no_grad``.
        with torch.enable_grad():
            x = state[0]
            x.requires_grad_(True)
            t.requires_grad_(True)
            dstate = self.odefunc(t, x)

            if len(state) <= 1:
                # Inference path: no regularization terms are computed.
                return dstate

            dx = dstate[0]
            penalties = []
            for reg_fn in self.regularization_fns:
                penalties.append(reg_fn(x, t, dx, self.odefunc))
            # Presumably ``dstate`` is a tuple so this is tuple
            # concatenation -- confirm against the wrapped odefunc.
            return dstate + tuple(penalties)

    @property
    def _num_evals(self):
        """Evaluation counter reported by the wrapped ``odefunc``."""
        return self.odefunc._num_evals

# kinetic energy
def quadratic_cost(x, t, dx, unused_context):
    """Return half the mean squared value of ``dx`` for each batch entry.

    Only ``dx`` is used; ``x``, ``t`` and ``unused_context`` are accepted so
    the signature matches the other regularization functions.

    Args:
        x: Ignored.
        t: Ignored.
        dx: Tensor whose first dimension indexes the batch.
        unused_context: Ignored.

    Returns:
        A tensor of shape ``(dx.shape[0],)``.
    """
    # Collapse everything except the leading (batch) dimension.
    flat = dx.view(dx.size(0), -1)
    mean_sq = flat.pow(2).mean(dim=-1)
    return 0.5 * mean_sq

# Jacobian Frobenius norm
def jacobian_frobenius_regularization_fn(x, t, dx, context):
    """Return the ``sqjacnorm`` attribute stored on ``context``.

    Nothing is computed here: the value is read from ``context`` as is.
    Presumably the wrapped ODE function stores the squared Jacobian
    Frobenius norm there during its own forward pass -- confirm against the
    odefunc implementation.

    Args:
        x: Ignored.
        t: Ignored.
        dx: Ignored.
        context: Object exposing a ``sqjacnorm`` attribute.

    Returns:
        ``context.sqjacnorm``, unchanged.
    """
    # x, t and dx exist only to match the shared regularizer signature.
    del x, t, dx
    return context.sqjacnorm

def divergence_bf(dx, y, **unused_kwargs):
    """Sum the diagonal Jacobian entries of ``dx`` with respect to ``y``.

    One backward pass is run per column index of ``y``: column ``i`` of
    ``dx`` is summed, differentiated with respect to ``y``, and column ``i``
    of that gradient is added to the running total.

    Args:
        dx: Tensor computed from ``y``; indexed as ``dx[:, i]``.
        y: Tensor with gradient tracking enabled; ``y.shape[1]`` sets the
            number of backward passes.
        **unused_kwargs: Accepted and ignored.

    Returns:
        The accumulated diagonal sum, one value per row of ``y``.
    """
    trace = 0.
    for col in range(y.shape[1]):
        # create_graph=True keeps the result differentiable and retains the
        # graph for the next column's backward pass.
        (grad_y,) = torch.autograd.grad(dx[:, col].sum(), y, create_graph=True)
        diag_entry = grad_y.contiguous()[:, col].contiguous()
        trace += diag_entry
    return trace.contiguous()

def divergence_approx(f, y, e=None):
    """Estimate the trace of ``df/dy`` from vector-Jacobian products.

    For every ``e_`` in ``e`` the product ``e_^T (df/dy)`` is obtained with
    one backward pass. Its elementwise product with ``e_`` is summed per
    batch entry to give a trace sample, and its mean square per batch entry
    gives a norm sample. Both are then averaged over all ``e_``.

    Args:
        f: Tensor computed from ``y``.
        y: Tensor with gradient tracking enabled; its first dimension is
            treated as the batch dimension.
        e: Iterable of tensors used as ``grad_outputs`` for ``f``
            (presumably random probe vectors shaped like ``y`` -- confirm
            against the caller). Required despite the ``None`` default.

    Returns:
        A pair ``(approx_tr_dzdx, N)`` of tensors with one value per batch
        entry: the averaged trace samples and the averaged mean-square
        values of the vector-Jacobian products.

    Raises:
        TypeError: If ``e`` is ``None``.
    """
    if e is None:
        # Iterating None raised TypeError before; keep the type, improve the message.
        raise TypeError(
            "divergence_approx() requires `e`: an iterable of tensors to use "
            "as grad_outputs for `f`"
        )

    batch = y.shape[0]
    samples = []
    sqnorms = []
    for e_ in e:
        # create_graph=True keeps the product differentiable and retains the
        # graph so it can be reused for the next e_.
        e_dzdx = torch.autograd.grad(f, y, e_, create_graph=True)[0]
        sqnorms.append(e_dzdx.view(batch, -1).pow(2).mean(dim=1, keepdim=True))
        e_dzdx_e = e_dzdx * e_
        samples.append(e_dzdx_e.view(batch, -1).sum(dim=1, keepdim=True))

    approx_tr_dzdx = torch.cat(samples, dim=1).mean(dim=1)
    N = torch.cat(sqnorms, dim=1).mean(dim=1)

    return approx_tr_dzdx, N