'''Contains everything related to the networks.'''
import pdb
import torch.nn as nn
import torch
import utils_Jacnorm as utils
import torchdiffeq as tdeq  # Only need to odeint part, not the entire adjoint
from torch_geometric.nn import ChebConv
from torch_geometric.nn import Sequential as Graph_Sequential
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


''' Define classifier net '''

def def_classifier_net(args):
    '''
        Build a fully connected classifier with a single scalar output and
        move it to the module-level `device`.

        args must provide:
            Xdim: input feature dimension.
            classifier_hidden_dim_str: hidden widths joined by "-" (e.g. "64-64");
                one hidden Linear layer is created per entry.
            classifier_activation: one of 'elu', 'tanh', 'softplus', 'relu', 'prelu'.

        Returns the LinearClassifier module. No activation follows the final
        Linear layer.

        Raises KeyError for an unsupported activation name and ValueError when a
        hidden width cannot be parsed as an integer.
    '''

    class LinearClassifier(torch.nn.Module):
        def __init__(self, args):
            super(LinearClassifier, self).__init__()
            activation_dict = {'elu': nn.ELU(), 'tanh': nn.Tanh(),
                               'softplus': nn.Softplus(beta=20),
                               'relu': nn.ReLU(), 'prelu': nn.PReLU()}
            hidden_dims = tuple(
                map(int, args.classifier_hidden_dim_str.split("-")))
            dims = (args.Xdim,) + hidden_dims + (1,)
            # NOTE: one activation instance is reused after every hidden layer
            # (so 'prelu' shares its learnable slope across layers), as before.
            activation = activation_dict[args.classifier_activation]
            # Decide "is this the output layer?" by position, not by width:
            # testing `out_dim != 1` would silently drop the activation after
            # any hidden layer that happens to have width 1.
            last_idx = len(dims) - 2
            layers_in_block = []
            for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
                layers_in_block.append(nn.Linear(in_dim, out_dim))
                if i < last_idx:
                    layers_in_block.append(activation)
            self.classifier = nn.Sequential(*layers_in_block)

        def forward(self, x):
            return self.classifier(x)

    classifier = LinearClassifier(args).to(device)
    return classifier

'''1. Specific to problems'''


class GNN(nn.Module):
    '''
        Vector field for graph data, built with build_net(..., layer='Chebnet').

        forward reads `self.edge_index`, which is NOT assigned in __init__;
        it has to be set on the instance before the first call.
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward: whether to also compute the divergence term
        self.logpx = logpx
        self.net = build_net(
            args.Xdim, args.hidden_dim_str, args.activation, layer='Chebnet')

    def forward(self, t, x):
        graph_edges = self.edge_index
        return net_forward(self, t, x, layer='Chebnet', edge_index=graph_edges)


class FCnet(nn.Module):
    '''
        Fully connected vector field (the "force"): wraps the default
        build_net network and evaluates it through net_forward with layer='FC'.
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward: whether to also compute the divergence term
        self.logpx = logpx
        # self.net may be any module; net_forward differentiates through it
        self.net = build_net(args.Xdim, args.hidden_dim_str, args.activation)

    def forward(self, t, x):
        result = net_forward(self, t, x, layer='FC')
        return result


class ODEnet(nn.Module):
    '''
        Time-dependent vector field. `layer_type` is forwarded both to
        build_net (to pick the layer class) and to net_forward (to pick how
        the network is called).
    '''

    def __init__(self, args, logpx=True, layer_type = 'ODE'):
        super().__init__()
        self.layer_type = layer_type
        # Flag read by net_forward: whether to also compute the divergence term
        self.logpx = logpx
        self.net = build_net(args.Xdim, args.hidden_dim_str,
                             activation=args.activation, layer=layer_type)

    def forward(self, t, x):
        kind = self.layer_type
        return net_forward(self, t, x, layer=kind)



class Convnet(nn.Module):
    '''
        Convolutional vector field without explicit time input: wraps
        build_net(..., layer='Convnet') and is evaluated via net_forward.
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward: whether to also compute the divergence term
        self.logpx = logpx
        self.net = build_net(args.Xdim, args.hidden_dim_str,
                             activation=args.activation, layer='Convnet')

    def forward(self, t, x):
        result = net_forward(self, t, x, layer='Convnet')
        return result



class ODEConvnet(nn.Module):
    '''
        Time-dependent convolutional vector field: wraps build_ODE_convnet and
        is evaluated through net_forward with layer='ODE' (the network is
        called as net(t, x)).
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward: whether to also compute the divergence term
        self.logpx = logpx
        self.net = build_ODE_convnet(args.Xdim,
                                     args.hidden_dim_str,
                                     activation=args.activation,
                                     use_BN=args.use_BN)

    def forward(self, t, x):
        result = net_forward(self, t, x, layer='ODE')
        return result

class ODEConvnet32dim(nn.Module):
    '''
        Time-dependent convolutional vector field: wraps
        build_ODE_convnet32dim and is evaluated through net_forward with
        layer='ODE' (the network is called as net(t, x)).
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward: whether to also compute the divergence term
        self.logpx = logpx
        self.net = build_ODE_convnet32dim(args.Xdim,
                                          args.hidden_dim_str,
                                          activation=args.activation,
                                          use_BN=args.use_BN)

    def forward(self, t, x):
        result = net_forward(self, t, x, layer='ODE')
        return result



'''2. Generic to all nets'''


def net_forward(self, t, x_and_dlogpx, layer='FC', edge_index=None):
    '''
        Shared forward pass of every vector-field module in this file.

        self: the calling module. Reads self.net, self.logpx, self.stackX,
            self.div_bf, self.fix_e_ls, self.e_ls, self.num_e and increments
            self.counter (these are attached by CNF.forward).
        t: current time, passed to the network for 'ODE*' layers.
        x_and_dlogpx: 3-tuple state; only its first entry (x) is used here.
        layer: 'FC', 'Chebnet', 'Convnet', or any string containing 'ODE'.
        edge_index: graph connectivity, only used when layer == 'Chebnet'.

        Returns (out, -divf, Jac_norm_out). When self.logpx is False, divf and
        Jac_norm_out are zero placeholders of length x.shape[0].

        Raises ValueError for an unrecognised `layer` (previously this surfaced
        as an UnboundLocalError).
    '''
    x, _, _ = x_and_dlogpx
    if self.stackX:
        N, V, C = x.shape
        x = x.view(x.shape[0], -1)  # Transform to a matrix (N, VC), for graph data

    def get_output(layer):
        # out_grad is the flattened (N, VC) output kept for the trace
        # estimator; it is only produced on the FC + stackX path.
        out_grad = None
        if layer == 'FC':
            out = self.net(x)
            if self.stackX:
                out_grad = out
                out = out.reshape(N, V, C)
        elif 'ODE' in layer:
            # Time-dependent nets return (t, x); keep the state part only
            out = self.net(t, x)[1]
        elif layer == 'Chebnet':
            out = self.net(x, edge_index)
        elif layer == 'Convnet':
            out = self.net(x)
        else:
            raise ValueError(f'Unknown layer type: {layer!r}')
        return out_grad, out

    if self.logpx:
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            t.requires_grad_(True)
            out_grad, out = get_output(layer)
            if self.div_bf:
                # Brute-force divergence (helper lives in utils)
                divf = utils.divergence_bf(out, x).flatten()
                Jac_norm_out = torch.zeros_like(divf)
            else:
                # Hutchinson trace estimator with Gaussian probes, as in FFJORD.
                # With stackX the trace is taken over the (VC x VC) Jacobian,
                # so the flattened output is used instead of `out`.
                target = out_grad if self.stackX else out
                # Resample the probes on every call unless they are fixed and
                # have already been drawn.
                if not self.fix_e_ls or self.e_ls is None:
                    self.e_ls = get_e_ls(target, self.num_e)
                divf, Jac_norm_out = utils.divergence_approx(
                    target, x, self.e_ls)
    else:
        # Do not compute divf, just act as placeholder
        divf = torch.zeros(x.shape[0]).to(device)
        Jac_norm_out = torch.zeros_like(divf)
        _, out = get_output(layer)
    # Count how many times the function is called in one odeint along [t_k, t_{k+1}]
    self.counter += 1
    return out, -divf, Jac_norm_out


def get_e_ls(out, num_e):
    '''
        Draw `num_e` independent standard-normal tensors shaped like `out`
        (moved to the module-level `device`) and return them as a list.
    '''
    # torch.manual_seed(1103+i) could be set per draw for reproducibility
    return [torch.randn_like(out).to(device) for _ in range(num_e)]


class CNF(nn.Module):
    '''
        Integrates `odefunc` over the time grid
        linspace(args.Tk_1, args.Tk, args.num_int_pts + 1).

        odefunc can be any module whose forward takes (t, state) and returns
        'out, -divf, Jac_norm_out', where out has the same shape as x.
        Before integrating, forward attaches the run-time settings that
        net_forward reads (logpx, stackX, div_bf, num_e, e_ls, fix_e_ls,
        counter) directly onto odefunc.
    '''

    def __init__(self, odefunc):
        super().__init__()
        self.odefunc = odefunc

    def forward(self, x, args, reverse=False, test=False, mult_gpu=False):
        func = self.odefunc
        # At test time dlogpx need not be tracked
        func.logpx = not test
        # NOTE, if we actually build time into self.odefunc, we would have
        # time be accumulated, rather than starting at 0.0
        integration_times = torch.linspace(
            args.Tk_1, args.Tk, args.num_int_pts+1).to(device)
        if reverse:
            integration_times = torch.flip(integration_times, [0])
        # True only for using FCnet on graph node feature
        func.stackX = False
        # Brute-force divergence (used for the 2D examples) vs. trace estimator
        func.div_bf = args.div_bf
        # Number of random projections for the trace estimator
        func.num_e = args.num_e
        # Probes start unset; with fix_e_ls they are drawn once and reused
        # for the entire [0,T] (when breaking to sub-intervals)
        func.e_ls = None
        func.fix_e_ls = args.fix_e_ls
        func.counter = 0
        batch = x.shape[0]
        dlogpx = torch.zeros(batch).to(device)
        dJacnorm = torch.zeros(batch).to(device)
        # Naive backprop (no backward adjoint step) only when the flag is
        # literally False; anything else uses the adjoint solver.
        integrate = odeint if args.use_NeuralODE is False else odeint_adjoint
        predz, dlogpx, dJacnorm = integrate(
            func, (x, dlogpx, dJacnorm), integration_times, mtd=args.int_mtd,
            rtol=args.rtol, atol=args.atol)
        if not mult_gpu:
            return predz, dlogpx, dJacnorm
        # predz stores intermediate trajectories along dim 0, which would not
        # be correct in a multi-gpu context, so only the final state is returned
        return predz[-1], dlogpx[-1], dJacnorm[-1]
    
# Maps the integrator names accepted by odeint / odeint_adjoint (their `mtd`
# argument) to the method strings handed to torchdiffeq.
mtd_dict = {
    'RK4': 'rk4',
    'Euler': 'euler',
    'DOPRI5': 'dopri5',
}


def odeint(func, x_now, t_ls, mtd='RK4', rtol=1e-4, atol=1e-4):
    '''
        Solve the ODE defined by `func` from state `x_now` over the time grid
        `t_ls` using torchdiffeq's odeint (plain backprop through the solver).

        mtd must be a key of mtd_dict ('RK4', 'Euler', 'DOPRI5'); otherwise a
        KeyError is raised. The state is preprocessed by utils._check_inputs
        and the solution is reshaped by utils._flat_to_shape with leading
        dimension len(t_ls).
    '''
    shapes, func, x_now = utils._check_inputs(func, x_now)
    solver_name = mtd_dict[mtd]
    # atol and rtol are only used by the adaptive solver. Tight tolerances
    # (e.g. 1e-5) are very slow on image data, though they may help accuracy.
    flat_solution = tdeq.odeint(
        func, x_now, t_ls, method=solver_name, rtol=rtol, atol=atol)
    return utils._flat_to_shape(flat_solution, (len(t_ls),), shapes)

def odeint_adjoint(func, x_now, t_ls, mtd='RK4', rtol=1e-4, atol=1e-4):
    '''
        Solve the ODE defined by `func` from state `x_now` over the time grid
        `t_ls` using torchdiffeq's odeint_adjoint.

        mtd must be a key of mtd_dict ('RK4', 'Euler', 'DOPRI5'); otherwise a
        KeyError is raised. The state is preprocessed by utils._check_inputs
        and the solution is reshaped by utils._flat_to_shape with leading
        dimension len(t_ls).
    '''
    shapes, func, x_now = utils._check_inputs(func, x_now)
    solver_name = mtd_dict[mtd]
    # atol and rtol are only used by the adaptive solver. Tight tolerances
    # (e.g. 1e-5) are very slow on image data, though they may help accuracy.
    flat_solution = tdeq.odeint_adjoint(
        func, x_now, t_ls, method=solver_name, rtol=rtol, atol=atol)
    return utils._flat_to_shape(flat_solution, (len(t_ls),), shapes)

# def odeint_adjoint(func, y0, t, mtd='RK4', rtol = 1e-4, atol = 1e-4):
#     adjoint_params = tuple(list(func.parameters()))
#     shapes, func, y0 = utils._check_inputs(func, y0)
#     mtd = mtd_dict[mtd]
#     solution = OdeintAdjointMethod.apply(func, y0, t, mtd, rtol, atol, *adjoint_params)
#     solution = utils._flat_to_shape(solution, (len(t),), shapes)
#     return solution


# class OdeintAdjointMethod(torch.autograd.Function):
#     '''
#         Simplified from NeuralODE paper.
#     '''
#     @staticmethod
#     def forward(ctx, func, y0, t, mtd, rtol, atol, *adjoint_params):
#         ctx.func = func
#         ctx.mtd = mtd
#         ctx.rtol, ctx.atol = rtol, atol
#         with torch.no_grad():
#             ans = tdeq.odeint(
#                 func, y0, t, rtol=ctx.rtol, atol=ctx.atol, method=ctx.mtd)
#             y = ans
#             ctx.save_for_backward(t, y, *adjoint_params)
#         return ans

#     @staticmethod
#     def backward(ctx, *grad_y):
#         with torch.no_grad():
#             func = ctx.func
#             t, y, *adjoint_params = ctx.saved_tensors
#             grad_y = grad_y[0]
#             adjoint_params = tuple(adjoint_params)
#             ##################################
#             #      Set up initial state      #
#             ##################################

#             # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
#             aug_state = [torch.zeros((), dtype=y.dtype, device=device),
#                          y[-1], grad_y[-1]]  # vjp_t, y, vjp_y
#             aug_state.extend([torch.zeros_like(param)
#                              for param in adjoint_params])  # vjp_params

#             ##################################
#             #    Set up backward ODE func    #
#             ##################################

#             def augmented_dynamics(t, y_aug):
#                 # Dynamics of the original system augmented with
#                 # the adjoint wrt y, and an integrator wrt t and args.
#                 y = y_aug[1]
#                 adj_y = y_aug[2]
#                 with torch.enable_grad():
#                     t_ = t.detach()
#                     t = t_.requires_grad_(True)
#                     y = y.detach().requires_grad_(True)

#                     func_eval = func(t_, y)

#                     vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
#                         func_eval, (t, y) + adjoint_params, -adj_y,
#                         allow_unused=True, retain_graph=True
#                     )

#                 # autograd.grad returns None if no gradient, set to zero.
#                 vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
#                 vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
#                 vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
#                               for param, vjp_param in zip(adjoint_params, vjp_params)]

#                 return (vjp_t, func_eval, vjp_y, *vjp_params)

#             ##################################
#             #       Solve adjoint ODE        #
#             ##################################

#             for i in range(len(t) - 1, 0, -1):
#                 # Run the augmented system backwards in time.
#                 aug_state = tdeq.odeint(
#                     augmented_dynamics, tuple(aug_state),
#                     t[i - 1:i + 1].flip(0), rtol=ctx.rtol, atol=ctx.atol, method=ctx.mtd)
#                 # extract just the t[i - 1] value
#                 aug_state = [a[1] for a in aug_state]
#                 # update to use our forward-pass estimate of the state
#                 aug_state[1] = y[i - 1]
#                 # update any gradients wrt state at this time point
#                 aug_state[2] = aug_state[2] + grad_y[i - 1]
#             adj_y = aug_state[2]
#             adj_params = aug_state[3:]
#         return (None, adj_y, None, None, *adj_params)

def build_ODE_convnet32dim(Xdim, hidden_dim_str, activation='elu', use_BN = False):
    '''
        Build a time-conditioned conv / transposed-conv network.

        Xdim: number of input and output channels.
        hidden_dim_str: hidden channel widths joined by "-" (e.g. "8-16-8").
        activation: 'elu', 'tanh' or 'softplus' (KeyError otherwise).
        use_BN: if True, a BatchNorm2d follows the activation of every layer
            except the last.

        The first half of the layers (index < len(dims) // 2) are 3x3 Conv2d,
        stride 2 except the last of them; the rest are 3x3 ConvTranspose2d,
        stride 1 for the first of them and stride 2 (output_padding 1) after.
        Every layer takes one extra input channel that carries the time t.

        Returns a module whose forward(t, x) returns (t, output).
    '''
    hidden_dims = tuple(map(int, hidden_dim_str.split("-")))
    dims = (Xdim,) + hidden_dims + (Xdim,)
    activation_dict = {'elu': nn.ELU(), 'tanh': nn.Tanh(),
                       'softplus': nn.Softplus(beta=20)}
    half = len(dims) // 2
    layers_in_block = []
    BNs = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        if i < half:
            # Downsampling convs; the last one keeps the resolution
            stride = 2 if i < half - 1 else 1
            layers_in_block.append(nn.Conv2d(
                in_dim + 1, out_dim, kernel_size=3, stride=stride, padding=1))
        else:
            # First transposed conv keeps the resolution, later ones upsample
            upsample = i > half
            layers_in_block.append(nn.ConvTranspose2d(
                in_dim + 1, out_dim, kernel_size=3,
                stride=2 if upsample else 1, padding=1,
                output_padding=1 if upsample else 0))
        if i < len(hidden_dims) and use_BN:
            BNs.append(nn.BatchNorm2d(out_dim))

    class customized_seq(nn.Module):
        def __init__(self, layers, BNs=None):
            super(customized_seq, self).__init__()
            self.layers = nn.ModuleList(layers)
            # One activation instance shared by all hidden layers
            self.act = activation_dict[activation]
            self.bns = nn.ModuleList(BNs) if BNs else None

        def forward(self, t, x):
            last = len(self.layers) - 1
            for s, layer in enumerate(self.layers):
                # Broadcast t into a constant channel and prepend it to x
                tt = torch.ones_like(x[:, :1, :, :]) * t
                x = layer(torch.cat([tt, x], 1))
                # Activation (then optional BN) on every layer but the last
                if s < last:
                    x = self.act(x)
                    if self.bns is not None:
                        x = self.bns[s](x)
            return t, x

    return customized_seq(layers_in_block, BNs)


def build_ODE_convnet(Xdim, hidden_dim_str, activation='elu', use_BN = False):
    '''
        Build a time-conditioned conv / transposed-conv network.

        Xdim: number of input and output channels.
        hidden_dim_str: hidden channel widths joined by "-" (e.g. "8-16-8").
        activation: 'elu', 'tanh' or 'softplus' (KeyError otherwise).
        use_BN: if True, a BatchNorm2d follows the activation of every layer
            except the last.

        The first half of the layers (index < len(dims) // 2) are Conv2d that
        alternate (kernel 3, stride 1) on even indices and (kernel 4, stride 2)
        on odd indices; the rest are ConvTranspose2d with the parity reversed.
        Every layer takes one extra input channel that carries the time t.

        Returns a module whose forward(t, x) returns (t, output).
    '''
    hidden_dims = tuple(map(int, hidden_dim_str.split("-")))
    dims = (Xdim,) + hidden_dims + (Xdim,)
    activation_dict = {'elu': nn.ELU(), 'tanh': nn.Tanh(),
                       'softplus': nn.Softplus(beta=20)}
    half = len(dims) // 2
    layers_in_block = []
    BNs = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        odd = i % 2 == 1
        if i < half:
            # Even index keeps the resolution, odd index downsamples
            ksize, stride = (4, 2) if odd else (3, 1)
            layers_in_block.append(nn.Conv2d(
                in_dim + 1, out_dim, kernel_size=ksize, stride=stride,
                padding=1))
        else:
            # Odd index keeps the resolution, even index upsamples
            ksize, stride = (3, 1) if odd else (4, 2)
            layers_in_block.append(nn.ConvTranspose2d(
                in_dim + 1, out_dim, kernel_size=ksize, stride=stride,
                padding=1))
        if i < len(hidden_dims) and use_BN:
            BNs.append(nn.BatchNorm2d(out_dim))

    class customized_seq(nn.Module):
        def __init__(self, layers, BNs=None):
            super(customized_seq, self).__init__()
            self.layers = nn.ModuleList(layers)
            # One activation instance shared by all hidden layers
            self.act = activation_dict[activation]
            self.bns = nn.ModuleList(BNs) if BNs else None

        def forward(self, t, x):
            last = len(self.layers) - 1
            for s, layer in enumerate(self.layers):
                # Broadcast t into a constant channel and prepend it to x
                tt = torch.ones_like(x[:, :1, :, :]) * t
                x = layer(torch.cat([tt, x], 1))
                # Activation (then optional BN) on every layer but the last
                if s < last:
                    x = self.act(x)
                    if self.bns is not None:
                        x = self.bns[s](x)
            return t, x

    return customized_seq(layers_in_block, BNs)


def build_net(Xdim, hidden_dim_str, activation='elu', layer='FC'):
    '''
        Assemble the network that maps an Xdim-dimensional state back to Xdim.

        hidden_dim_str: hidden widths joined by "-" (e.g. "64-64").
        activation: 'elu', 'tanh' or 'softplus'.
        layer selects the building block:
            'FC'         -> nn.Linear layers in an nn.Sequential
            'Convnet'    -> the layer list from complex_net in an nn.Sequential
            'ODE'        -> ConcatSquashLinear layers in utils.mySequential
            'ODE_concat' -> ConcatLinear layers in utils.mySequential
            'Chebnet'    -> one ChebConv (K=3) followed by Linear layers, in a
                            torch_geometric Sequential taking (x, edge_index)
        The activation follows every layer except the last one.
    '''
    hidden_dims = tuple(int(width) for width in hidden_dim_str.split("-"))
    dims = (Xdim,) + hidden_dims + (Xdim,)
    activation_dict = {'elu': nn.ELU(), 'tanh': nn.Tanh(),
                       'softplus': nn.Softplus(beta=20)}
    n_hidden = len(hidden_dims)

    if layer == 'Convnet':
        modules = complex_net(Xdim, activation_dict[activation])
    else:
        modules = []
        for idx in range(len(dims) - 1):
            d_in, d_out = dims[idx], dims[idx + 1]
            is_hidden = idx < n_hidden
            if layer == 'FC':
                modules.append(nn.Linear(d_in, d_out))
                if is_hidden:
                    modules.append(activation_dict[activation])
            elif layer == 'ODE':
                act = activation_dict[activation] if is_hidden else None
                modules.append(ConcatSquashLinear(d_in, d_out, act))
            elif layer == 'ODE_concat':
                act = activation_dict[activation] if is_hidden else None
                modules.append(ConcatLinear(d_in, d_out, act))
            elif layer == 'Chebnet':
                # Compare with OneChebTwoFC as used in IGNN
                if idx == 0:
                    modules.append(
                        (ChebConv(d_in, d_out, K=3), 'x, edge_index ->x'))
                else:
                    modules.append(torch.nn.Linear(d_in, d_out))
                if is_hidden:
                    modules.append(activation_dict[activation])

    if layer == 'Chebnet':
        return Graph_Sequential('x, edge_index', modules)
    if layer in ['FC', 'Convnet']:
        return nn.Sequential(*modules)
    # Time-dependent layers need a container that threads (t, x) through
    return utils.mySequential(*modules)


def get_n_params(model):
    '''
        Return the total number of scalar entries over all tensors yielded by
        model.parameters() (0 if the model has no parameters).
    '''
    # numel() replaces the manual product over p.size(), whose accumulator
    # was named `nn` and shadowed the torch.nn alias inside this function.
    return sum(p.numel() for p in model.parameters())


class SmallGenNet(nn.Module):
    '''
        Single linear layer Y_dim -> C. Per the original author, it yields the
        conditional mean of the base distribution from a one-hot encoded
        response Y.
    '''

    def __init__(self, Y_dim, C):
        super(SmallGenNet, self).__init__()
        self.fc = nn.Linear(in_features=Y_dim, out_features=C)

    def forward(self, Y):
        cond_mean = self.fc(Y)
        return cond_mean


class param_net_class(nn.Module):
    '''
        Maps a scalar time input (last dimension of size 1) to a vector of
        length out_dim through one linear layer. Any module producing such a
        vector could take its place.
    '''

    def __init__(self, out_dim):
        super().__init__()
        self.layer = nn.Linear(in_features=1, out_features=out_dim)

    def forward(self, t):
        params = self.layer(t)
        return params

### Customized nets ###


class ConcatSquashLinear(nn.Module):
    '''
        Linear layer modulated by time:
            layer(x) * sigmoid(hyper_gate(t)) + hyper_bias(t)
        followed by `act` when it is not None. `t` is reshaped with
        t.view(1, 1), so it must hold exactly one element.
        forward returns (t, output); t is passed through unchanged so layers
        can be chained.
    '''

    def __init__(self, dim_in, dim_out, act):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1, dim_out)
        self.act = act

    def forward(self, t, x):
        t_col = t.view(1, 1)
        gate = torch.sigmoid(self._hyper_gate(t_col))
        shift = self._hyper_bias(t_col)
        out = self._layer(x) * gate + shift
        if self.act is not None:
            out = self.act(out)
        return t, out

class ConcatLinear(nn.Module):
    '''
        Linear layer on [t, x]: a column filled with t is prepended to x
        (along dim 1) before a Linear(dim_in + 1, dim_out), followed by `act`
        when it is not None. forward returns (t, output).
    '''

    def __init__(self, dim_in, dim_out, act):
        super().__init__()
        self._layer = nn.Linear(dim_in + 1, dim_out)
        self.act = act

    def forward(self, t, x):
        time_col = torch.ones_like(x[:, :1]) * t
        out = self._layer(torch.cat([time_col, x], 1))
        if self.act is None:
            return t, out
        return t, self.act(out)

def complex_net(Xdim, activation_func):
    '''
        Return the layer list of a small convolutional encoder/decoder.

        Xdim: number of input and output channels. Channel widths are
            (32, 64, 128) when Xdim == 3 and (64, 64, 128) otherwise.
        activation_func: activation module; the SAME instance is inserted
            after every layer except the last.

        Layout (16 modules): three stride-2 3x3 Conv2d layers, then three
        stride-2 3x3 ConvTranspose2d layers with output_padding=1. Every layer
        except the final one is followed by activation_func and BatchNorm2d.
        For even spatial sizes divisible by 8 (e.g. 32x32) the output has the
        same spatial size as the input.
    '''
    if Xdim == 3:
        dim0, dim1, dim2 = 32, 64, 128  # No drastic inc. in computation time
    else:
        dim0, dim1, dim2 = 64, 64, 128
    ksize = 3
    down_channels = [Xdim, dim0, dim1, dim2]
    up_channels = [dim2, dim1, dim0]

    layers = []
    # Encoder: each conv halves the spatial resolution
    for c_in, c_out in zip(down_channels[:-1], down_channels[1:]):
        layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                kernel_size=ksize, stride=2, padding=1))
        layers.append(activation_func)
        layers.append(nn.BatchNorm2d(c_out))
    # Decoder: each transposed conv doubles the spatial resolution
    for c_in, c_out in zip(up_channels[:-1], up_channels[1:]):
        layers.append(nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out,
                                         kernel_size=ksize, stride=2,
                                         padding=1, output_padding=1))
        layers.append(activation_func)
        layers.append(nn.BatchNorm2d(c_out))
    # Final projection back to Xdim channels, with no activation or BN
    layers.append(nn.ConvTranspose2d(in_channels=dim0, out_channels=Xdim,
                                     kernel_size=ksize, stride=2,
                                     padding=1, output_padding=1))
    return layers

def ResFlow_block(Xdim, hidden_dim):
    '''
        Return a list of five layers from the local ResFlow_layer module:
        InducedNormConv2d (3x3) -> Swish -> InducedNormConv2d (1x1) -> Swish
        -> InducedNormConv2d (3x3), mapping Xdim -> hidden_dim -> hidden_dim
        -> Xdim channels. Per the original author it mimics FFJORD for CIFAR10.
    '''
    import ResFlow_layer as Rlayer

    def induced_conv(c_in, c_out, k, pad):
        # All convolutions in this block use stride 1
        return Rlayer.InducedNormConv2d(in_channels=c_in, out_channels=c_out,
                                        kernel_size=k, stride=1, padding=pad)

    return [
        induced_conv(Xdim, hidden_dim, 3, 1),
        Rlayer.Swish(),
        induced_conv(hidden_dim, hidden_dim, 1, 0),
        Rlayer.Swish(),
        induced_conv(hidden_dim, Xdim, 3, 1),
    ]



############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
