import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import diffeq_layers
from .squeeze import squeeze, unsqueeze

# Public names exported by a star-import of this module.
__all__ = ["ODEnet", "CODEnet", "CIFARCODEnet", "AutoencoderDiffEqNet", "ODEfunc", "AutoencoderODEfunc"]


def divergence_bf(dx, y, **unused_kwargs):
    """Exact divergence of ``dx`` w.r.t. ``y``, one backward pass per column.

    For every column index ``i`` in ``range(y.shape[1])`` the gradient of
    ``dx[:, i].sum()`` w.r.t. ``y`` is taken and its ``i``-th column is added to
    the running total, i.e. the diagonal of the per-sample Jacobian is summed.

    Args:
      dx: tensor computed from ``y``; indexed as ``dx[:, i]``.
      y: tensor ``dx`` depends on; must be part of the autograd graph.
      **unused_kwargs: accepted and ignored.
    Returns:
      The accumulated diagonal terms (shape ``(N,)`` when ``y`` is 2-D).
    """
    trace = 0.
    for col in range(y.shape[1]):
        # create_graph=True keeps the result differentiable for training.
        (grad_wrt_y,) = torch.autograd.grad(dx[:, col].sum(), y, create_graph=True)
        trace = trace + grad_wrt_y.contiguous()[:, col].contiguous()
    return trace.contiguous()


# def divergence_bf(f, y, **unused_kwargs):
#     jac = _get_minibatch_jacobian(f, y)
#     diagonal = jac.view(jac.shape[0], -1)[:, ::jac.shape[1]]
#     return torch.sum(diagonal, 1)


def _get_minibatch_jacobian(y, x):
    """Computes the Jacobian of y wrt x assuming minibatch-mode.

    Args:
      y: (N, ...) with a total of D_y elements in ...
      x: (N, ...) with a total of D_x elements in ...
    Returns:
      The minibatch Jacobian matrix of shape (N, D_y, D_x)
    """
    assert y.shape[0] == x.shape[0]
    y = y.view(y.shape[0], -1)

    # Compute Jacobian row by row.
    jac = []
    for j in range(y.shape[1]):
        dy_j_dx = torch.autograd.grad(y[:, j], x, torch.ones_like(y[:, j]), retain_graph=True,
                                      create_graph=True)[0].view(x.shape[0], -1)
        jac.append(torch.unsqueeze(dy_j_dx, 1))
    jac = torch.cat(jac, 1)
    return jac


def divergence_approx(f, y, e=None):
    """Approximate the trace of df/dy with the probe ``e``.

    Computes the vector-Jacobian product ``e^T (df/dy)`` in one backward pass,
    multiplies it elementwise by ``e`` and sums over all non-batch dimensions.

    Args:
      f: tensor computed from ``y``.
      y: tensor ``f`` depends on; its first dimension is the batch.
      e: probe tensor used both as ``grad_outputs`` and as the right factor.
    Returns:
      Tensor of shape ``(y.shape[0],)``.
    """
    (vjp,) = torch.autograd.grad(f, y, e, create_graph=True)
    per_element = vjp * e
    return per_element.view(y.shape[0], -1).sum(dim=1)


def sample_rademacher_like(y):
    """Return a tensor of -1/+1 entries with the shape, dtype and device of ``y``."""
    # randint draws integers from {0, 1}; they are cast to match `y` and then
    # mapped to {-1, +1}.
    bits = torch.randint(low=0, high=2, size=y.shape)
    return bits.to(y) * 2 - 1


def sample_gaussian_like(y):
    """Return standard-normal noise shaped like ``y`` (via ``torch.randn_like``)."""
    noise = torch.randn_like(y)
    return noise


class Swish(nn.Module):
    """Activation ``x * sigmoid(beta * x)`` with a learnable scalar ``beta``."""

    def __init__(self):
        super().__init__()
        # Scalar parameter, initialised to 1.
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate


class Lambda(nn.Module):
    """Wrap an arbitrary callable ``f`` so it can be used as an ``nn.Module``."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        fn = self.f
        return fn(x)


# Activation lookup by name. Each value is a single module instance created
# when this module is imported, so every lookup of a key returns that same object.
NONLINEARITIES = dict(
    tanh=nn.Tanh(),
    relu=nn.ReLU(),
    softplus=nn.Softplus(),
    elu=nn.ELU(),
    swish=Swish(),
    square=Lambda(lambda x: x**2),
    identity=Lambda(lambda x: x),
)


class ODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows.

    Builds ``len(hidden_dims) + 1`` time-conditioned layers taken from
    ``diffeq_layers``. The last layer maps back to ``input_shape[0]``
    channels/features, and every layer except the last is followed by the
    selected nonlinearity.

    Args:
      hidden_dims: sequence (tuple or list) of output sizes of the hidden layers.
      path_dims: not used by this class; the parameter mirrors the signatures of
        CODEnet / CIFARCODEnet.
      input_shape: shape of one sample. ``input_shape[0]`` is the channel/feature
        count; entries 1 and 2 are only updated when a stride of 2 or -2 occurs.
      strides: one entry per layer (1, 2 or -2) when ``conv`` is true; ignored
        (replaced by ``None`` entries) otherwise.
      conv: true selects the Conv2d layer family, false the Linear one.
      layer_type: key selecting the concrete layer class (KeyError if unknown).
      nonlinearity: key into ``NONLINEARITIES`` (KeyError if unknown).
      num_squeeze: number of ``squeeze(., 2)`` calls applied to the input before
        the layers and of ``unsqueeze(., 2)`` calls applied to the output.
        NOTE(review): layer sizes come from ``input_shape`` only, so presumably
        the caller passes the post-squeeze shape - confirm against callers.

    Raises:
      ValueError: if a stride other than 1, 2 or -2 is given with ``conv``.
    """

    # Constructor keyword arguments for each supported stride value.
    _STRIDE_KWARGS = {
        None: {},
        1: {"ksize": 3, "stride": 1, "padding": 1, "transpose": False},
        2: {"ksize": 4, "stride": 2, "padding": 1, "transpose": False},
        -2: {"ksize": 4, "stride": 2, "padding": 1, "transpose": True},
    }

    def __init__(
        self, hidden_dims, path_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus", num_squeeze=0
    ):
        super(ODEnet, self).__init__()
        self.num_squeeze = num_squeeze
        # Allow lists as well as tuples; the concatenation below needs a tuple.
        hidden_dims = tuple(hidden_dims)
        if conv:
            assert len(strides) == len(hidden_dims) + 1
            base_layer = {
                "ignore": diffeq_layers.IgnoreConv2d,
                "hyper": diffeq_layers.HyperConv2d,
                "squash": diffeq_layers.SquashConv2d,
                "concat": diffeq_layers.ConcatConv2d,
                "concat_v2": diffeq_layers.ConcatConv2d_v2,
                "concatsquash": diffeq_layers.ConcatSquashConv2d,
                "blend": diffeq_layers.BlendConv2d,
                "concatcoord": diffeq_layers.ConcatCoordConv2d,
            }[layer_type]
        else:
            strides = [None] * (len(hidden_dims) + 1)
            base_layer = {
                "ignore": diffeq_layers.IgnoreLinear,
                "hyper": diffeq_layers.HyperLinear,
                "squash": diffeq_layers.SquashLinear,
                "concat": diffeq_layers.ConcatLinear,
                "concat_v2": diffeq_layers.ConcatLinear_v2,
                "concatsquash": diffeq_layers.ConcatSquashLinear,
                "blend": diffeq_layers.BlendLinear,
                "concatcoord": diffeq_layers.ConcatLinear,
            }[layer_type]

        # The same activation module object is used after every hidden layer.
        activation = NONLINEARITIES[nonlinearity]

        # build layers and add them
        layers = []
        hidden_shape = list(input_shape)
        for dim_out, stride in zip(hidden_dims + (input_shape[0],), strides):
            try:
                layer_kwargs = self._STRIDE_KWARGS[stride]
            except (KeyError, TypeError):
                raise ValueError('Unsupported stride: {}'.format(stride)) from None

            layers.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs))

            # Track the shape the next layer will receive.
            hidden_shape[0] = dim_out
            if stride == 2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2
            elif stride == -2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2

        self.layers = nn.ModuleList(layers)
        # No activation after the final layer.
        self.activation_fns = nn.ModuleList([activation] * (len(layers) - 1))

    def forward(self, t, y):
        """Evaluate the stack at time ``t`` on state ``y`` and return the result."""
        dx = y
        # squeeze
        for _ in range(self.num_squeeze):
            dx = squeeze(dx, 2)
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            dx = layer(t, dx)
            # if not last layer, use nonlinearity
            if idx < last:
                dx = self.activation_fns[idx](dx)
        # unsqueeze
        for _ in range(self.num_squeeze):
            dx = unsqueeze(dx, 2)
        return dx

class CODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows.

    Holds two independent stacks of time-conditioned layers (``layers1`` and
    ``layers2``, built identically) plus a small convolutional network ``path``.
    ``path`` receives the input with a constant time channel prepended and emits
    two scalars per sample; the output is
    ``stack1(t, y) * path[:, 0] + stack2(t, y) * path[:, 1]``.

    Args:
      hidden_dims: sequence (tuple or list) of output sizes of the hidden layers.
      path_dims: channel width of the two hidden convolutions in ``path``.
      input_shape: shape of one sample, read as (channels, height, width) when
        sizing ``path``. A two-entry shape is treated as a square map.
      strides: one entry per layer (1, 2 or -2) when ``conv`` is true; ignored
        otherwise.
      conv: true selects the Conv2d layer family, false the Linear one.
      layer_type: key selecting the concrete layer class (KeyError if unknown).
      nonlinearity: key into ``NONLINEARITIES`` (KeyError if unknown).
      num_squeeze: number of ``squeeze(., 2)`` calls applied to each stack's
        input and of ``unsqueeze(., 2)`` calls applied to the combined output.

    Raises:
      ValueError: if a stride other than 1, 2 or -2 is given with ``conv``.
    """

    # Constructor keyword arguments for each supported stride value.
    _STRIDE_KWARGS = {
        None: {},
        1: {"ksize": 3, "stride": 1, "padding": 1, "transpose": False},
        2: {"ksize": 4, "stride": 2, "padding": 1, "transpose": False},
        -2: {"ksize": 4, "stride": 2, "padding": 1, "transpose": True},
    }

    def __init__(
        self, hidden_dims, path_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus", num_squeeze=0
    ):
        super(CODEnet, self).__init__()
        self.num_squeeze = num_squeeze
        # Allow lists as well as tuples; the concatenation below needs a tuple.
        hidden_dims = tuple(hidden_dims)
        if conv:
            assert len(strides) == len(hidden_dims) + 1
        else:
            strides = [None] * (len(hidden_dims) + 1)
        base_layer = self._pick_base_layer(conv, layer_type)
        layer_dims = hidden_dims + (input_shape[0],)
        activation = NONLINEARITIES[nonlinearity]

        # Branches are built (and registered) one after the other so parameter
        # creation order is: all of branch 1, all of branch 2, then `path`.
        self.layers1, self.activation_fns1 = self._make_branch(
            base_layer, layer_dims, strides, input_shape, activation
        )
        self.layers2, self.activation_fns2 = self._make_branch(
            base_layer, layer_dims, strides, input_shape, activation
        )

        height = input_shape[1]
        # A two-entry input_shape is treated as a square map.
        width = input_shape[2] if len(input_shape) > 2 else height
        self.path = nn.Sequential(  # define the network for the integration path
            nn.Conv2d(input_shape[0] + 1, path_dims, 3, padding=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(path_dims, path_dims, 3, padding=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(path_dims, 3, 1),
            nn.Flatten(),
            # 3 channels at full resolution are flattened into two mixing scalars.
            nn.Linear(height * width * 3, 2),
        )

    @staticmethod
    def _pick_base_layer(conv, layer_type):
        """Return the diffeq layer class for the conv/linear family and ``layer_type``."""
        if conv:
            table = {
                "ignore": diffeq_layers.IgnoreConv2d,
                "hyper": diffeq_layers.HyperConv2d,
                "squash": diffeq_layers.SquashConv2d,
                "concat": diffeq_layers.ConcatConv2d,
                "concat_v2": diffeq_layers.ConcatConv2d_v2,
                "concatsquash": diffeq_layers.ConcatSquashConv2d,
                "blend": diffeq_layers.BlendConv2d,
                "concatcoord": diffeq_layers.ConcatCoordConv2d,
            }
        else:
            table = {
                "ignore": diffeq_layers.IgnoreLinear,
                "hyper": diffeq_layers.HyperLinear,
                "squash": diffeq_layers.SquashLinear,
                "concat": diffeq_layers.ConcatLinear,
                "concat_v2": diffeq_layers.ConcatLinear_v2,
                "concatsquash": diffeq_layers.ConcatSquashLinear,
                "blend": diffeq_layers.BlendLinear,
                "concatcoord": diffeq_layers.ConcatLinear,
            }
        return table[layer_type]

    @classmethod
    def _make_branch(cls, base_layer, layer_dims, strides, input_shape, activation):
        """Build one layer stack and the activations applied between its layers."""
        layers = []
        hidden_shape = list(input_shape)
        for dim_out, stride in zip(layer_dims, strides):
            try:
                layer_kwargs = cls._STRIDE_KWARGS[stride]
            except (KeyError, TypeError):
                raise ValueError('Unsupported stride: {}'.format(stride)) from None

            layers.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs))

            # Track the shape the next layer will receive.
            hidden_shape[0] = dim_out
            if stride == 2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2
            elif stride == -2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2
        # No activation after the final layer.
        return nn.ModuleList(layers), nn.ModuleList([activation] * (len(layers) - 1))

    @staticmethod
    def _run_branch(layers, activation_fns, t, x):
        """Apply ``layers`` in order, with an activation after all but the last."""
        last = len(layers) - 1
        for idx, layer in enumerate(layers):
            x = layer(t, x)
            if idx < last:
                x = activation_fns[idx](x)
        return x

    def forward(self, t, y):
        """Return the mixed output of both stacks for state ``y`` at time ``t``.

        ``y`` is indexed with ``size(2)`` and ``size(3)``, so it must have at
        least four dimensions; ``t`` must be a tensor expandable to (N, 1).
        """
        batch, height, width = y.size(0), y.size(2), y.size(3)
        # Constant-in-space channel carrying the time value for each sample.
        t_channel = t.expand(batch, 1).view(batch, 1, 1, 1).expand(batch, 1, height, width)
        path_input = torch.cat((t_channel, y), dim=1)  # concatenate the time and the image

        dx1 = y
        dx2 = y
        # squeeze
        for _ in range(self.num_squeeze):
            dx1 = squeeze(dx1, 2)
            dx2 = squeeze(dx2, 2)
        dx1 = self._run_branch(self.layers1, self.activation_fns1, t, dx1)
        dx2 = self._run_branch(self.layers2, self.activation_fns2, t, dx2)

        g_h_i = self.path(path_input)
        # One scalar per sample and branch, broadcast over channels and space.
        dg_dt = g_h_i[:, 0].view(-1, 1, 1, 1)
        dh_dt = g_h_i[:, 1].view(-1, 1, 1, 1)
        dx = dx1 * dg_dt + dx2 * dh_dt
        # unsqueeze
        for _ in range(self.num_squeeze):
            dx = unsqueeze(dx, 2)
        return dx
    
    
class CIFARCODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows.

    Holds three independent stacks of time-conditioned layers (``layers1``,
    ``layers2``, ``layers3``, built identically) plus a small convolutional
    network ``path``. ``path`` receives the input with a constant time channel
    prepended and emits three scalars per sample; the output is the sum of each
    stack's output scaled by its scalar.

    Args:
      hidden_dims: sequence (tuple or list) of output sizes of the hidden layers.
      path_dims: channel width of the two hidden convolutions in ``path``.
      input_shape: shape of one sample, read as (channels, height, width) when
        sizing ``path``. A two-entry shape is treated as a square map.
      strides: one entry per layer (1, 2 or -2) when ``conv`` is true; ignored
        otherwise.
      conv: true selects the Conv2d layer family, false the Linear one.
      layer_type: key selecting the concrete layer class (KeyError if unknown).
      nonlinearity: key into ``NONLINEARITIES`` (KeyError if unknown).
      num_squeeze: number of ``squeeze(., 2)`` calls applied to each stack's
        input and of ``unsqueeze(., 2)`` calls applied to the combined output.

    Raises:
      ValueError: if a stride other than 1, 2 or -2 is given with ``conv``.
    """

    # Constructor keyword arguments for each supported stride value.
    _STRIDE_KWARGS = {
        None: {},
        1: {"ksize": 3, "stride": 1, "padding": 1, "transpose": False},
        2: {"ksize": 4, "stride": 2, "padding": 1, "transpose": False},
        -2: {"ksize": 4, "stride": 2, "padding": 1, "transpose": True},
    }

    def __init__(
        self, hidden_dims, path_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus", num_squeeze=0
    ):
        super(CIFARCODEnet, self).__init__()
        self.num_squeeze = num_squeeze
        # Allow lists as well as tuples; the concatenation below needs a tuple.
        hidden_dims = tuple(hidden_dims)
        if conv:
            assert len(strides) == len(hidden_dims) + 1
        else:
            strides = [None] * (len(hidden_dims) + 1)
        base_layer = self._pick_base_layer(conv, layer_type)
        layer_dims = hidden_dims + (input_shape[0],)
        activation = NONLINEARITIES[nonlinearity]

        # Branches are built (and registered) one after the other so parameter
        # creation order is: branch 1, branch 2, branch 3, then `path`.
        self.layers1, self.activation_fns1 = self._make_branch(
            base_layer, layer_dims, strides, input_shape, activation
        )
        self.layers2, self.activation_fns2 = self._make_branch(
            base_layer, layer_dims, strides, input_shape, activation
        )
        self.layers3, self.activation_fns3 = self._make_branch(
            base_layer, layer_dims, strides, input_shape, activation
        )

        height = input_shape[1]
        # A two-entry input_shape is treated as a square map.
        width = input_shape[2] if len(input_shape) > 2 else height
        self.path = nn.Sequential(  # define the network for the integration path
            nn.Conv2d(input_shape[0] + 1, path_dims, 3, padding=1, bias=False),
            nn.Tanh(),
            nn.Conv2d(path_dims, path_dims, 3, padding=1, bias=False),
            nn.Tanh(),
            nn.Conv2d(path_dims, 3, 1),
            nn.Flatten(),
            # 3 channels at full resolution are flattened into three mixing scalars.
            nn.Linear(height * width * 3, 3),
        )

    @staticmethod
    def _pick_base_layer(conv, layer_type):
        """Return the diffeq layer class for the conv/linear family and ``layer_type``."""
        if conv:
            table = {
                "ignore": diffeq_layers.IgnoreConv2d,
                "hyper": diffeq_layers.HyperConv2d,
                "squash": diffeq_layers.SquashConv2d,
                "concat": diffeq_layers.ConcatConv2d,
                "concat_v2": diffeq_layers.ConcatConv2d_v2,
                "concatsquash": diffeq_layers.ConcatSquashConv2d,
                "blend": diffeq_layers.BlendConv2d,
                "concatcoord": diffeq_layers.ConcatCoordConv2d,
            }
        else:
            table = {
                "ignore": diffeq_layers.IgnoreLinear,
                "hyper": diffeq_layers.HyperLinear,
                "squash": diffeq_layers.SquashLinear,
                "concat": diffeq_layers.ConcatLinear,
                "concat_v2": diffeq_layers.ConcatLinear_v2,
                "concatsquash": diffeq_layers.ConcatSquashLinear,
                "blend": diffeq_layers.BlendLinear,
                "concatcoord": diffeq_layers.ConcatLinear,
            }
        return table[layer_type]

    @classmethod
    def _make_branch(cls, base_layer, layer_dims, strides, input_shape, activation):
        """Build one layer stack and the activations applied between its layers."""
        layers = []
        hidden_shape = list(input_shape)
        for dim_out, stride in zip(layer_dims, strides):
            try:
                layer_kwargs = cls._STRIDE_KWARGS[stride]
            except (KeyError, TypeError):
                raise ValueError('Unsupported stride: {}'.format(stride)) from None

            layers.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs))

            # Track the shape the next layer will receive.
            hidden_shape[0] = dim_out
            if stride == 2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2
            elif stride == -2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2
        # No activation after the final layer.
        return nn.ModuleList(layers), nn.ModuleList([activation] * (len(layers) - 1))

    @staticmethod
    def _run_branch(layers, activation_fns, t, x):
        """Apply ``layers`` in order, with an activation after all but the last."""
        last = len(layers) - 1
        for idx, layer in enumerate(layers):
            x = layer(t, x)
            if idx < last:
                x = activation_fns[idx](x)
        return x

    def forward(self, t, y):
        """Return the mixed output of the three stacks for state ``y`` at time ``t``.

        ``y`` is indexed with ``size(2)`` and ``size(3)``, so it must have at
        least four dimensions; ``t`` must be a tensor expandable to (N, 1).
        """
        batch, height, width = y.size(0), y.size(2), y.size(3)
        # Constant-in-space channel carrying the time value for each sample.
        t_channel = t.expand(batch, 1).view(batch, 1, 1, 1).expand(batch, 1, height, width)
        path_input = torch.cat((t_channel, y), dim=1)  # concatenate the time and the image

        dx1 = y
        dx2 = y
        dx3 = y
        # squeeze
        for _ in range(self.num_squeeze):
            dx1 = squeeze(dx1, 2)
            dx2 = squeeze(dx2, 2)
            dx3 = squeeze(dx3, 2)
        dx1 = self._run_branch(self.layers1, self.activation_fns1, t, dx1)
        dx2 = self._run_branch(self.layers2, self.activation_fns2, t, dx2)
        dx3 = self._run_branch(self.layers3, self.activation_fns3, t, dx3)

        g_h_i = self.path(path_input)
        # One scalar per sample and branch, broadcast over channels and space.
        dg_dt = g_h_i[:, 0].view(-1, 1, 1, 1)
        dh_dt = g_h_i[:, 1].view(-1, 1, 1, 1)
        di_dt = g_h_i[:, 2].view(-1, 1, 1, 1)
        dx = dx1 * dg_dt + dx2 * dh_dt + dx3 * di_dt
        # unsqueeze
        for _ in range(self.num_squeeze):
            dx = unsqueeze(dx, 2)
        return dx

class AutoencoderDiffEqNet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows.

    The ``len(hidden_dims) + 1`` layers are split into an encoder (layer indices
    ``0 .. len(hidden_dims) // 2``) and a decoder (the rest). ``forward`` returns
    both the encoder output and the decoder output.

    Args:
      hidden_dims: sequence (tuple or list) of output sizes of the hidden layers.
      input_shape: shape of one sample; ``input_shape[0]`` is the channel/feature
        count and also the output size of the last layer.
      strides: one entry per layer (1, 2 or -2) when ``conv`` is true; ignored
        otherwise.
      conv: true selects the Conv2d layer family, false the Linear one.
      layer_type: one of "ignore", "hyper", "concat", "concatcoord", "blend".
      nonlinearity: one of "tanh", "relu", "softplus", "elu".

    Raises:
      AssertionError: for an unsupported ``layer_type`` or ``nonlinearity``.
      ValueError: if a stride other than 1, 2 or -2 is given with ``conv``.
    """

    def __init__(self, hidden_dims, input_shape, strides, conv, layer_type="concat", nonlinearity="softplus"):
        super(AutoencoderDiffEqNet, self).__init__()
        assert layer_type in ("ignore", "hyper", "concat", "concatcoord", "blend")
        assert nonlinearity in ("tanh", "relu", "softplus", "elu")

        # torch.tanh is used instead of the deprecated F.tanh alias, which emits
        # a deprecation warning on a range of PyTorch releases.
        self.nonlinearity = {
            "tanh": torch.tanh,
            "relu": F.relu,
            "softplus": F.softplus,
            "elu": F.elu,
        }[nonlinearity]
        # Allow lists as well as tuples; the concatenation below needs a tuple.
        hidden_dims = tuple(hidden_dims)
        if conv:
            assert len(strides) == len(hidden_dims) + 1
            base_layer = {
                "ignore": diffeq_layers.IgnoreConv2d,
                "hyper": diffeq_layers.HyperConv2d,
                "squash": diffeq_layers.SquashConv2d,
                "concat": diffeq_layers.ConcatConv2d,
                "blend": diffeq_layers.BlendConv2d,
                "concatcoord": diffeq_layers.ConcatCoordConv2d,
            }[layer_type]
        else:
            strides = [None] * (len(hidden_dims) + 1)
            base_layer = {
                "ignore": diffeq_layers.IgnoreLinear,
                "hyper": diffeq_layers.HyperLinear,
                "squash": diffeq_layers.SquashLinear,
                "concat": diffeq_layers.ConcatLinear,
                "blend": diffeq_layers.BlendLinear,
                "concatcoord": diffeq_layers.ConcatLinear,
            }[layer_type]

        # build layers and add them
        encoder_layers = []
        decoder_layers = []
        hidden_shape = list(input_shape)
        encoder_last = len(hidden_dims) // 2
        for i, (dim_out, stride) in enumerate(zip(hidden_dims + (input_shape[0],), strides)):
            # The first half of the layers (inclusive of the middle one) encodes.
            target = encoder_layers if i <= encoder_last else decoder_layers

            if stride is None:
                layer_kwargs = {}
            elif stride == 1:
                layer_kwargs = {"ksize": 3, "stride": 1, "padding": 1, "transpose": False}
            elif stride == 2:
                layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": False}
            elif stride == -2:
                layer_kwargs = {"ksize": 4, "stride": 2, "padding": 1, "transpose": True}
            else:
                raise ValueError('Unsupported stride: {}'.format(stride))

            target.append(base_layer(hidden_shape[0], dim_out, **layer_kwargs))

            # Track the shape the next layer will receive.
            hidden_shape[0] = dim_out
            if stride == 2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] // 2, hidden_shape[2] // 2
            elif stride == -2:
                hidden_shape[1], hidden_shape[2] = hidden_shape[1] * 2, hidden_shape[2] * 2

        self.encoder_layers = nn.ModuleList(encoder_layers)
        self.decoder_layers = nn.ModuleList(decoder_layers)

    def forward(self, t, y):
        """Return ``(h, dx)``: the encoder output and the decoder output.

        Every encoder layer is followed by the nonlinearity; every decoder layer
        except the last is as well.
        """
        h = y
        for layer in self.encoder_layers:
            h = self.nonlinearity(layer(t, h))

        dx = h
        last = len(self.decoder_layers) - 1
        for i, layer in enumerate(self.decoder_layers):
            dx = layer(t, dx)
            # if not last layer, use nonlinearity
            if i < last:
                dx = self.nonlinearity(dx)
        return h, dx


class ODEfunc(nn.Module):
    """Wrap a dynamics network so it also returns the (negated) divergence.

    ``forward(t, states)`` evaluates ``diffeq`` on ``states[0]`` and computes the
    divergence of the result w.r.t. that state, either exactly
    (``"brute_force"``) or with a noise probe (``"approximate"``).

    Args:
      diffeq: callable ``diffeq(t, y, *extra_states)`` returning dy/dt.
      divergence_fn: ``"brute_force"`` or ``"approximate"``.
      residual: if true, ``y`` is subtracted from the dynamics and the number of
        elements per sample is subtracted from the divergence.
      rademacher: if true, the probe is drawn with ``sample_rademacher_like``
        instead of ``sample_gaussian_like`` when none was supplied.
    """

    def __init__(self, diffeq, divergence_fn="approximate", residual=False, rademacher=False):
        super(ODEfunc, self).__init__()
        assert divergence_fn in ("brute_force", "approximate")

        self.diffeq = diffeq
        self.residual = residual
        self.rademacher = rademacher

        if divergence_fn == "brute_force":
            self.divergence_fn = divergence_bf
        elif divergence_fn == "approximate":
            self.divergence_fn = divergence_approx

        # Probe noise; defined here so forward() works (and samples a probe
        # lazily) even when before_odeint() has not been called yet.
        self._e = None
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self, e=None):
        """Set the probe noise (``None`` means sample on first use) and reset the counter."""
        self._e = e
        self._num_evals.fill_(0)

    def num_evals(self):
        """Return the number of ``forward`` calls since the last ``before_odeint``."""
        return self._num_evals.item()

    def forward(self, t, states):
        """Evaluate the dynamics and the negated divergence.

        Args:
          t: time, as a number or tensor.
          states: sequence with at least two entries; ``states[0]`` is the state
            ``y`` and ``states[2:]`` are passed on to ``diffeq``. ``states[1]``
            is not read.
        Returns:
          Tuple ``(dy, -divergence, zeros...)`` with one zero tensor per entry of
          ``states[2:]``; ``divergence`` has shape ``(batch, 1)``.
        """
        assert len(states) >= 2
        y = states[0]

        # increment num evals
        self._num_evals += 1

        # Work on a private copy of t so requires_grad_ below never touches the
        # caller's tensor. detach().clone() is what torch.tensor(t) does for a
        # tensor input, minus the copy-construct UserWarning.
        if torch.is_tensor(t):
            t = t.detach().clone()
        else:
            t = torch.tensor(t)
        t = t.type_as(y)
        batchsize = y.shape[0]

        # Sample and fix the noise.
        if self._e is None:
            if self.rademacher:
                self._e = sample_rademacher_like(y)
            else:
                self._e = sample_gaussian_like(y)

        # Gradients are needed for the divergence even if the caller disabled them.
        with torch.set_grad_enabled(True):
            y.requires_grad_(True)
            t.requires_grad_(True)
            for s_ in states[2:]:
                s_.requires_grad_(True)
            dy = self.diffeq(t, y, *states[2:])
            # Hack for 2D data to use brute force divergence computation.
            if not self.training and dy.view(dy.shape[0], -1).shape[1] == 2:
                divergence = divergence_bf(dy, y).view(batchsize, 1)
            else:
                divergence = self.divergence_fn(dy, y, e=self._e).view(batchsize, 1)
        if self.residual:
            dy = dy - y
            # Subtract the number of elements per sample from the divergence.
            divergence -= float(np.prod(y.shape[1:]))
        return tuple([dy, -divergence] + [torch.zeros_like(s_).requires_grad_(True) for s_ in states[2:]])


class AutoencoderODEfunc(nn.Module):
    """Dynamics wrapper for an ``AutoencoderDiffEqNet`` with a probe-based divergence.

    The probe ``e`` has the shape of the encoder output ``h``. The divergence
    term is obtained by chaining two vector-Jacobian products, ``e^T dh/dy``
    followed by ``(.)^T d(dy)/dh``, and contracting the result with ``e``.

    Args:
      autoencoder_diffeq: an ``AutoencoderDiffEqNet`` instance.
      divergence_fn: must be ``"approximate"``.
      residual: if true, ``y`` is subtracted from the dynamics and the number of
        elements per sample is subtracted from the divergence.
      rademacher: if true, the probe is drawn with ``sample_rademacher_like``
        instead of ``sample_gaussian_like`` when none was supplied.

    Raises:
      AssertionError: if ``divergence_fn`` is not ``"approximate"`` or
        ``autoencoder_diffeq`` has the wrong type.
    """

    def __init__(self, autoencoder_diffeq, divergence_fn="approximate", residual=False, rademacher=False):
        # Equality check: `in ("approximate")` would be a substring test on a
        # plain string and accept values such as "approx" or "".
        assert divergence_fn == "approximate", "Only approximate divergence supported at the moment. (TODO)"
        assert isinstance(autoencoder_diffeq, AutoencoderDiffEqNet)
        super(AutoencoderODEfunc, self).__init__()
        self.residual = residual
        self.autoencoder_diffeq = autoencoder_diffeq
        self.rademacher = rademacher

        # Probe noise; defined here so forward() works (and samples a probe
        # lazily) even when before_odeint() has not been called yet.
        self._e = None
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self, e=None):
        """Set the probe noise (``None`` means sample on first use) and reset the counter."""
        self._e = e
        self._num_evals.fill_(0)

    def num_evals(self):
        """Return the number of ``forward`` calls since the last ``before_odeint``."""
        return self._num_evals.item()

    def forward(self, t, y_and_logpy):
        """Evaluate the dynamics and the negated divergence term.

        Args:
          t: time, as a number or tensor.
          y_and_logpy: pair ``(y, logpy)``; only ``y`` is used.
        Returns:
          Tuple ``(dy, -divergence)`` with ``divergence`` of shape ``(batch, 1)``.
        """
        y, _ = y_and_logpy  # remove logpy

        # increment num evals
        self._num_evals += 1

        # Work on a private copy of t so requires_grad_ below never touches the
        # caller's tensor. detach().clone() is what torch.tensor(t) does for a
        # tensor input, minus the copy-construct UserWarning.
        if torch.is_tensor(t):
            t = t.detach().clone()
        else:
            t = torch.tensor(t)
        t = t.type_as(y)
        batchsize = y.shape[0]

        # Gradients are needed for the divergence even if the caller disabled them.
        with torch.set_grad_enabled(True):
            y.requires_grad_(True)
            t.requires_grad_(True)
            h, dy = self.autoencoder_diffeq(t, y)

            # Sample and fix the noise.
            if self._e is None:
                if self.rademacher:
                    self._e = sample_rademacher_like(h)
                else:
                    self._e = sample_gaussian_like(h)

            e_vjp_dhdy = torch.autograd.grad(h, y, self._e, create_graph=True)[0]
            e_vjp_dfdy = torch.autograd.grad(dy, h, e_vjp_dhdy, create_graph=True)[0]
            divergence = torch.sum((e_vjp_dfdy * self._e).view(batchsize, -1), 1, keepdim=True)

        if self.residual:
            dy = dy - y
            # Subtract the number of elements per sample from the divergence.
            divergence -= float(np.prod(y.shape[1:]))

        return dy, -divergence
