'''Contains everything related to the networks'''
import pdb
import torch.nn as nn
import torch
import utils
import torchdiffeq as tdeq  # Only need to odeint part, not the entire adjoint
from torch_geometric.nn import ChebConv
from torch_geometric.nn import Sequential as Graph_Sequential
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''1. Specific to problems'''


class GNN(nn.Module):
    '''
        Force function whose network is assembled by build_net with layer='Chebnet'.
        Note: forward reads self.edge_index, which is NOT set in __init__; it has to be
        attached to the instance before the first call.
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward to decide whether the divergence is computed
        self.logpx = logpx
        # self.net can be ANYTHING, as long as the grad operation is available
        # (which can of course live in a separate function)
        self.net = build_net(args.Xdim, args.hidden_dim_str,
                             param_ls=args.param_ls,
                             activation=args.activation,
                             layer='Chebnet')

    def forward(self, t, x):
        '''Delegate to net_forward, passing the graph stored on the instance.'''
        graph = self.edge_index
        return net_forward(self, t, x, layer='Chebnet', edge_index=graph)


class FCnet(nn.Module):
    '''
        The function that gives the gradient or force, backed by the
        fully connected network that build_net returns for its default layer type.
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward to decide whether the divergence is computed
        self.logpx = logpx
        # self.net can be ANYTHING, as long as the grad operation is available
        # (which can of course live in a separate function)
        self.net = build_net(args.Xdim, args.hidden_dim_str,
                             param_ls=args.param_ls,
                             activation=args.activation)

    def forward(self, t, x):
        '''Delegate to net_forward with layer='FC'.'''
        result = net_forward(self, t, x, layer='FC')
        return result


class ODEnet(nn.Module):
    '''
        Force function whose network is assembled by build_net with layer='ODE',
        i.e. the layers also receive the time t (see net_forward, which calls self.net(t, x)).
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward to decide whether the divergence is computed
        self.logpx = logpx
        self.net = build_net(args.Xdim, args.hidden_dim_str,
                             activation=args.activation, layer='ODE')

    def forward(self, t, x):
        '''Delegate to net_forward with layer='ODE'.'''
        result = net_forward(self, t, x, layer='ODE')
        return result


class Convnet(nn.Module):
    '''
        Force function whose network is assembled by build_net with layer='Convnet'.
        The first argument of build_net (Xdim) is fixed to 1 here.
    '''

    def __init__(self, args, logpx=True):
        super().__init__()
        # Flag read by net_forward to decide whether the divergence is computed
        self.logpx = logpx
        self.net = build_net(1, args.hidden_dim_str,
                             activation=args.activation, layer='Convnet')

    def forward(self, t, x):
        '''Delegate to net_forward with layer='Convnet'.'''
        result = net_forward(self, t, x, layer='Convnet')
        return result


'''2. Generic to all nets'''


class CNF(nn.Module):
    '''
        Wrapper that integrates an ODE function over [0, args.T].
        odefunc can be any function, as long as its forward mapping takes t,x and outputs 'out, -divf'
    '''

    def __init__(self, odefunc):
        super().__init__()
        # When set to True from outside, dlogpx is always tracked (see forward)
        self.final_viz = False
        self.odefunc = odefunc

    def forward(self, x, args, reverse=False, test=False):
        '''
            Configure self.odefunc from args, then integrate (x, 0) along the time grid.
            reverse: integrate from args.T down to 0 instead of 0 up to args.T
            test: skip tracking of dlogpx (unless self.final_viz is set)
            Returns whatever the chosen solver returns, unpacked as (predz, dlogpx).
        '''
        func = self.odefunc
        # dlogpx is tracked by default; test/reverse switch it off; final_viz overrides both
        if self.final_viz:
            func.logpx = True
        else:
            func.logpx = not (test or reverse)

        # args.num_int_pts sub-intervals, hence num_int_pts+1 grid points
        time_grid = torch.linspace(0.0, args.T, args.num_int_pts+1).to(device)
        if reverse:
            time_grid = torch.flip(time_grid, [0])

        # stackX is True only for using FCnet on graph node feature
        func.stackX = bool(args.netname == 'FCnet' and args.cond_gen)
        # Default is True for using brute force divergence for 2D examples
        func.div_bf = args.div_bf
        # Number of random projection to get trace estimator
        func.num_e = args.num_e

        dlogpx_init = torch.zeros(x.shape[0]).to(device)
        # Identity test on purpose: only the literal False selects the plain solver
        # (no backward adjoint step); anything else goes through the adjoint version.
        solver = odeint if args.use_NeuralODE is False else odeint_adjoint
        predz, dlogpx = solver(func, (x, dlogpx_init), time_grid,
                               mtd=args.int_mtd)
        return predz, dlogpx


def odeint_adjoint(func, y0, t, mtd='RK4'):
    '''
        Solve the ODE through OdeintAdjointMethod, so gradients come from the adjoint backward pass.
        mtd: 'RK4' selects the 'rk4' solver; ANY other value falls back to 'euler'
        Returns the solution after utils._flat_to_shape restores the shapes reported by utils._check_inputs.
    '''
    # func is rebound by utils._check_inputs below, so gather the parameters from the original object first
    params = tuple(func.parameters())
    shapes, func, y0 = utils._check_inputs(func, y0)
    solver_name = 'rk4' if mtd == 'RK4' else 'euler'
    flat_solution = OdeintAdjointMethod.apply(
        func, y0, t, solver_name, *params)
    return utils._flat_to_shape(flat_solution, (len(t),), shapes)


class OdeintAdjointMethod(torch.autograd.Function):
    '''
        Custom autograd Function: the forward pass solves the ODE without recording a graph,
        the backward pass recovers the gradients by integrating an augmented system backwards in time.
        Simplified from NeuralODE paper.
    '''
    @staticmethod
    def forward(ctx, func, y0, t, mtd, *adjoint_params):
        '''
            func: dynamics, called as func(t, y)
            y0: initial state handed to tdeq.odeint
            t: time points at which the solution is returned
            mtd: solver name forwarded to tdeq.odeint as `method`
            adjoint_params: tensors for which backward returns a gradient
            Returns the output of tdeq.odeint.
        '''
        # Non-tensor objects are kept directly on ctx for use in backward
        ctx.func = func
        ctx.mtd = mtd
        with torch.no_grad():
            ans = tdeq.odeint(
                func, y0, t, rtol=1e-3, atol=1e-6, method=ctx.mtd)
            y = ans
            # Keep time grid, forward solution and parameters for the backward pass
            ctx.save_for_backward(t, y, *adjoint_params)
        return ans

    @staticmethod
    def backward(ctx, *grad_y):
        '''
            grad_y: incoming gradients; only grad_y[0] (gradient wrt the forward output) is used.
            Returns one entry per forward input (func, y0, t, mtd, *adjoint_params):
            None for func, t and mtd, the adjoint state for y0, and one gradient per parameter.
        '''
        with torch.no_grad():
            func = ctx.func
            t, y, *adjoint_params = ctx.saved_tensors
            grad_y = grad_y[0]
            adjoint_params = tuple(adjoint_params)
            ##################################
            #      Set up initial state      #
            ##################################

            # [-1] because y and grad_y are both of shape (len(t), *y0.shape)
            aug_state = [torch.zeros((), dtype=y.dtype, device=y.device),
                         y[-1], grad_y[-1]]  # vjp_t, y, vjp_y
            aug_state.extend([torch.zeros_like(param)
                             for param in adjoint_params])  # vjp_params

            ##################################
            #    Set up backward ODE func    #
            ##################################

            def augmented_dynamics(t, y_aug):
                # Dynamics of the original system augmented with
                # the adjoint wrt y, and an integrator wrt t and args.
                # y_aug follows the layout of aug_state: (vjp_t, y, vjp_y, *vjp_params)
                y = y_aug[1]
                adj_y = y_aug[2]
                # The enclosing backward runs under no_grad, so grad mode is re-enabled locally
                with torch.enable_grad():
                    t_ = t.detach()
                    # requires_grad_ works in place, so t and t_ name the same tensor
                    t = t_.requires_grad_(True)
                    y = y.detach().requires_grad_(True)

                    func_eval = func(t_, y)

                    # Vector-Jacobian products of func_eval with -adj_y wrt t, y and every parameter
                    vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
                        func_eval, (t, y) + adjoint_params, -adj_y,
                        allow_unused=True, retain_graph=True
                    )

                # autograd.grad returns None if no gradient, set to zero.
                vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
                vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
                vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
                              for param, vjp_param in zip(adjoint_params, vjp_params)]

                return (vjp_t, func_eval, vjp_y, *vjp_params)

            ##################################
            #       Solve adjoint ODE        #
            ##################################

            # Walk over the intervals [t[i-1], t[i]] from the last one to the first
            for i in range(len(t) - 1, 0, -1):
                # Run the augmented system backwards in time.
                # The two-point grid is flipped, so integration goes from t[i] to t[i - 1]
                aug_state = tdeq.odeint(
                    augmented_dynamics, tuple(aug_state),
                    t[i - 1:i + 1].flip(0), rtol=1e-3, atol=1e-6, method=ctx.mtd)
                # extract just the t[i - 1] value
                aug_state = [a[1] for a in aug_state]
                # update to use our forward-pass estimate of the state
                aug_state[1] = y[i - 1]
                # update any gradients wrt state at this time point
                aug_state[2] += grad_y[i - 1]
            adj_y = aug_state[2]
            adj_params = aug_state[3:]
        # Order matches forward's inputs: func, y0, t, mtd, *adjoint_params
        # (aug_state[0], the accumulated vjp_t, is not returned: t gets None)
        return (None, adj_y, None, None, *adj_params)

# New implementation. Entirely based on the more accurate odeint from torchdiffeq


def odeint(func, x_now, t_ls, mtd='RK4'):
    '''
        # New implementation based on torchdiffeq ODE
        # Old implementation based on no tolerance is below
        mtd: 'RK4' selects the 'rk4' solver; ANY other value falls back to 'euler'
        Returns the solution after utils._flat_to_shape restores the shapes reported by utils._check_inputs.
    '''
    shapes, func, x_now = utils._check_inputs(func, x_now)
    solver_name = 'rk4' if mtd == 'RK4' else 'euler'
    flat_solution = tdeq.odeint(
        func, x_now, t_ls, rtol=1e-3, atol=1e-6, method=solver_name)
    return utils._flat_to_shape(flat_solution, (len(t_ls),), shapes)

# # Old implementation
# def odeint(odefunc, x_now, t_ls, mtd='RK4'):
#     '''
#         Hand-crafted odeint, which in fact takes (nearly) the same input as torchdiffeq and gives the same output
#         mtd: Any numerical ODE solver. Currently implemented RK4 and Euler
#         x_now = (xn, dlogpxn)
#     '''
#
#     h_ls = torch.diff(t_ls)  # Steps
#     X_ls = []
#     dlogpx_ls = []
#     for i, t in enumerate(t_ls):
#         X_ls = X_ls + [x_now[0]]
#         dlogpx_ls = dlogpx_ls + [x_now[1]]
#         if i >= len(t_ls) - 1:
#             break
#         h = h_ls[i]
#         if mtd == 'Euler':
#             xdiff, negdivfdiff = odefunc(t, x_now)
#             x_now = (x_now[0]+h*xdiff, x_now[1]+h*negdivfdiff)
#         if mtd == 'RK4':
#             x_now_tmp = x_now  # It is tuple, so no .clone()
#             xdiff1, negdivfdiff1 = odefunc(t, x_now_tmp)
#             x_now_tmp = (x_now_tmp[0]+h*xdiff1/2,
#                          x_now_tmp[1]+h*negdivfdiff1/2)
#             xdiff2, negdivfdiff2 = odefunc(t+h/2, x_now_tmp)
#             x_now_tmp = (x_now_tmp[0]+h*xdiff2/2,
#                          x_now_tmp[1]+h*negdivfdiff2/2)
#             xdiff3, negdivfdiff3 = odefunc(t+h/2, x_now_tmp)
#             x_now_tmp = (x_now_tmp[0]+h*xdiff3, x_now_tmp[1]+h*negdivfdiff3)
#             xdiff4, negdivfdiff4 = odefunc(t+h, x_now_tmp)
#             x_now = (x_now[0]+h*(xdiff1+xdiff2+xdiff3+xdiff4)/6,
#                      x_now[1]+h*(negdivfdiff1+negdivfdiff2+negdivfdiff3+negdivfdiff4)/6)
#     return torch.stack(X_ls), torch.stack(dlogpx_ls)
#


def build_net(Xdim, hidden_dim_str, param_ls=None, activation='elu', layer='FC', dropout=False):
    '''
        Assemble the network used as force function.
        Xdim: input AND output width of the network (not used when layer == 'Convnet')
        hidden_dim_str: hidden widths joined by '-', e.g. '64-64'
        param_ls: stores customized weight and matrices for the model, since I want them to be output from another network.
            Layout is [weight_0, bias_0, weight_1, bias_1, ...]. None or an empty list means ordinary nn.Linear layers.
            Note, it is only currently implemented for FC networks
        activation: 'elu' or 'tanh' (any other key raises KeyError once an activation is needed)
        layer: 'FC', 'ODE', 'Chebnet' or 'Convnet'
        dropout: only honoured for layer == 'FC'
        Returns nn.Sequential for 'FC'/'Convnet', the torch_geometric Sequential for 'Chebnet',
        and utils.mySequential otherwise.
    '''
    # None replaces the former mutable default ([]); an explicit None is accepted as well
    if param_ls is None:
        param_ls = []
    hidden_dims = tuple(map(int, hidden_dim_str.split("-")))
    dims = (Xdim,) + hidden_dims + (Xdim,)
    # One module instance per activation; the same instance is reused after every hidden layer
    activation_dict = {'elu': nn.ELU(), 'tanh': nn.Tanh()}
    layers_in_block = []
    if layer == 'Convnet':
        layers_in_block = construct_conv_net(
            hidden_dims, activation_dict[activation])
    else:
        n_hidden = len(hidden_dims)
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # The last (output) layer gets no activation
            is_hidden = i < n_hidden
            if layer == 'FC':
                if len(param_ls) == 0:
                    layers_in_block.append(nn.Linear(in_dim, out_dim))
                else:
                    # Assign "continuous gradient"
                    weight_i, bias_i = param_ls[2*i], param_ls[2*i+1]
                    layers_in_block.append(
                        Linear_customizedWeightBias(weight_i, bias_i))
                if is_hidden:
                    layers_in_block.append(activation_dict[activation])
                if dropout:
                    # Dropout with 100*prob % probability after activation
                    # NOTE(review): this also adds dropout after the output layer - confirm intended
                    prob = 0.1
                    layers_in_block.append(nn.Dropout(prob))
            elif layer == 'ODE':
                act = activation_dict[activation] if is_hidden else None
                layers_in_block.append(
                    ConcatSquashLinear(in_dim, out_dim, act))
            elif layer == 'Chebnet':
                # Compare with OneChebTwoFC as used in IGNN
                if i == 0:
                    # Only the first layer is a graph convolution; the signature string
                    # tells the graph Sequential which inputs this layer consumes
                    layers_in_block.append(
                        (ChebConv(in_dim, out_dim, K=2), 'x, edge_index ->x'))
                else:
                    layers_in_block.append(torch.nn.Linear(in_dim, out_dim))
                if is_hidden:
                    layers_in_block.append(activation_dict[activation])
    if layer == 'Chebnet':
        return Graph_Sequential('x, edge_index', layers_in_block)
    if layer in ['FC', 'Convnet']:
        seq_class = nn.Sequential
    else:
        seq_class = utils.mySequential
    return seq_class(*layers_in_block)


def net_forward(self, t, x, layer='FC', edge_index=None):
    '''
        Shared forward pass of the force networks.
        self: the network object; reads self.net, self.stackX, self.logpx and, when the
            divergence is needed, self.div_bf and self.num_e
        t: time tensor (handed to the net only when layer == 'ODE')
        x: tuple whose first entry is the state; the second entry is ignored
        layer: 'FC', 'ODE', 'Chebnet' or 'Convnet'
        edge_index: handed to the net only when layer == 'Chebnet'
        Returns (out, -divf); divf is a zero placeholder when self.logpx is False.
    '''
    state, _ = x
    with torch.enable_grad():
        if self.stackX:
            N, V, C = state.shape
            state = state.view(N, -1)  # Transform to a matrix (N, VC)
        state.requires_grad_(True)
        t.requires_grad_(True)

        if layer == 'FC':
            out = self.net(state)
            if self.stackX:
                # Keep the flat output for the divergence, hand back the (N, V, C) version
                out_grad = out
                out = out.reshape(N, V, C)
        elif layer == 'ODE':
            out = self.net(t, state)[1]
        elif layer == 'Chebnet':
            out = self.net(state, edge_index)
        elif layer == 'Convnet':
            out = self.net(state)

        if not self.logpx:
            # Do not compute divf, just act as placeholder
            divf = torch.zeros(state.shape[0]).to(device)
        elif len(state.shape) == 2 and state.view(state.shape[0], -1).shape[1] == 2 and self.div_bf:
            # Two-column input with brute force requested
            divf = utils.divergence_bf(out, state).flatten()
        else:
            # Sample from Gaussian distribution to get Hutchinson trace estimator
            # Used in FFJORD
            # With stackX we instead want the trace over a matrix of size VC by VC,
            # hence the flat output is used there
            target = out_grad if self.stackX else out
            e_ls = [torch.randn_like(target).to(device)
                    for _ in range(self.num_e)]
            divf = utils.divergence_approx(target, state, e_ls)
    return out, -divf


def construct_conv_net(hidden_dim, activation_func):
    '''
        Layer list for the convolutional force network (one input and one output channel).
        hidden_dim: sequence of hidden widths; only hidden_dim[0] is used, as the channel count
        activation_func: module instance inserted (the same object) after each convolution
        Returns a plain list of 7 layers: conv-act-conv-act-conv-act-upsample.
    '''
    width = hidden_dim[0]
    conv_opts = dict(kernel_size=3, stride=1)
    return [
        nn.Conv2d(in_channels=1, out_channels=width, **conv_opts),
        activation_func,
        nn.Conv2d(in_channels=width, out_channels=width, **conv_opts),
        activation_func,
        nn.Conv2d(in_channels=width, out_channels=1, **conv_opts),
        activation_func,
        # Resize the output to a fixed 28 x 28 grid
        nn.Upsample(size=(28, 28), mode='bilinear'),
    ]


def get_n_params(model):
    '''
        Count the scalar entries of all parameters of `model`.
        model: any object exposing .parameters() (e.g. an nn.Module)
        Returns an int; 0 when the model has no parameters.
    '''
    # numel() replaces the hand-written product over p.size(), whose local
    # variable `nn` shadowed the module alias torch.nn inside this function
    return sum(p.numel() for p in model.parameters())


class SmallGenNet(nn.Module):
    '''
        Single linear map from the one-hot encoded response Y (width Y_dim) to a vector of width C,
        used as the conditional mean of the base distribution.
    '''

    def __init__(self, Y_dim, C):
        super(SmallGenNet, self).__init__()
        self.fc = nn.Linear(in_features=Y_dim, out_features=C)

    def forward(self, Y):
        '''Apply the linear layer to Y.'''
        cond_mean = self.fc(Y)
        return cond_mean


class param_net_class(nn.Module):
    '''
        Maps a scalar input (last dimension of size 1) to a vector of width out_dim with one linear layer.
        This can be ANYTHING, as long as it outputs some vector (out_dim is dynamically updated :))
    '''

    def __init__(self, out_dim):
        super().__init__()
        self.layer = nn.Linear(in_features=1, out_features=out_dim)

    def forward(self, t):
        '''Apply the linear layer to t.'''
        params = self.layer(t)
        return params

### Customized nets ###


class ConcatSquashLinear(nn.Module):
    '''
        Linear layer whose output is gated and shifted by functions of the time t:
            layer(x) * sigmoid(hyper_gate(t)) + hyper_bias(t)
        followed by `act` when it is not None. forward returns (t, output) so that
        layers of this kind can be chained.
    '''

    def __init__(self, dim_in, dim_out, act):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1, dim_out)
        # Activation module, or None for no activation
        self.act = act

    def forward(self, t, x):
        # t must hold exactly one element: it is viewed as a (1, 1) matrix
        t_mat = t.view(1, 1)
        scaled = self._layer(x) * torch.sigmoid(self._hyper_gate(t_mat))
        pre_act = scaled + self._hyper_bias(t_mat)
        if self.act is None:
            return t, pre_act
        return t, self.act(pre_act)


class Linear_customizedWeightBias(nn.Module):
    '''
        Exactly the same as nn.Linear, but the weight and bias are supplied by the caller
        and stored as given, instead of being created inside the layer.
        This is because we want them to be trainable.
        Note, if using OTHER networks, we need to do similar things by specifying weight and bias explicitly
    '''

    def __init__(self, weight, bias):
        super().__init__()
        self.bias = bias
        self.weight = weight

    def forward(self, x):
        '''Apply the affine map defined by the stored weight and bias to x.'''
        return nn.functional.linear(x, self.weight, self.bias)

############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
############
