from utils import trace_df_dz, jacobian_df_dz, grad_trace_df_dz
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint



# Registry mapping an activation name (the ``activation`` argument of the
# networks below) to a single shared activation module instance. The same
# instance is handed out on every lookup; none of these PyTorch activations
# has learnable parameters, so reusing one object across layers is harmless.
NONLINEARITIES = dict(
    tanh=nn.Tanh(),
    relu=nn.ReLU(),
    softplus=nn.Softplus(),
    elu=nn.ELU(),
)



class ConcateNetwork(nn.Module):
    """Time-conditioned MLP computing f(z, t) by concatenating t onto z.

    With d = in_out_dim, h = hidden_dim and act = NONLINEARITIES[activation]:

        deeper=False: Linear(d+1, h) -> act -> Linear(h, h) -> act -> Linear(h, d)
        deeper=True:  Linear(d+1, h) -> act -> Linear(h, 3h) -> act
                      -> Linear(3h, 3h) -> act -> Linear(3h, h) -> act -> Linear(h, d)

    The stack is stored as ``self.linear_concat_net`` (an ``nn.Sequential``).
    """

    def __init__(self, in_out_dim, hidden_dim, deeper = False, activation = "tanh"):
        super().__init__()

        self.deeper = deeper
        self.in_out_dim = in_out_dim
        self.hidden_dim = hidden_dim

        d, h = in_out_dim, hidden_dim
        # Consecutive entries are the (in_features, out_features) of each Linear;
        # the leading "+1" is the extra time column concatenated in forward().
        if deeper:
            widths = [d + 1, h, 3 * h, 3 * h, h, d]
        else:
            widths = [d + 1, h, h, d]

        layers = []
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # An activation separates every pair of Linear layers (none before
            # the first one, none after the last one).
            if index:
                layers.append(NONLINEARITIES[activation])
            layers.append(nn.Linear(n_in, n_out))

        self.linear_concat_net = nn.Sequential(*layers)

    def forward(self, t, z):
        """Evaluate dz/dt.

        Args:
            t: scalar time (anything that broadcasts against a [batch, 1] tensor).
            z: state of shape [batch_size, z_dim].

        Returns:
            dz/dt of shape [batch_size, z_dim].
        """
        # One copy of t per row, with z's dtype/device, placed before z's columns.
        time_column = torch.ones_like(z[:, :1]) * t
        return self.linear_concat_net(torch.cat([time_column, z], dim=-1))
        
        

class HyperNetwork(nn.Module):
    """Hyper-network allowing f(z(t), t) to change with time.

    Adapted from the NumPy implementation at:
    https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52
    """
    def __init__(self, in_out_dim, hidden_dim, width, deeper = False, activation = "tanh"):
        """
            Hyper-network returns parameters [W, B, U]:
                W: [width, in_out_dim, 1]
                B: [width, 1, 1]
                U: [width, 1, in_out_dim]

            Args:
                in_out_dim: dimension of the state z.
                hidden_dim: hidden width of the MLP that maps t to parameters.
                width: number of units described by the generated parameters.
                deeper: if truthy, use a five-Linear MLP (hidden sizes
                    h, 10h, 10h, h) instead of the default three-Linear MLP
                    (hidden sizes h, h).
                activation: key into NONLINEARITIES; raises KeyError if unknown.
        """
        super().__init__()

        self.in_out_dim = in_out_dim
        self.hidden_dim = hidden_dim
        self.width = width
        blocksize = width * in_out_dim
        self.blocksize = blocksize
        self.deeper = deeper

        # Output holds three [width * in_out_dim] blocks (W, U, G) plus
        # [width] entries for B.
        out_dim = 3 * blocksize + width
        # t enters as a single input feature.
        if self.deeper:
            widths = [1, hidden_dim, hidden_dim * 10, hidden_dim * 10, hidden_dim, out_dim]
        else:
            widths = [1, hidden_dim, hidden_dim, out_dim]

        _hypernet = []
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Activation between consecutive Linear layers only.
            if index:
                _hypernet.append(NONLINEARITIES[activation])
            _hypernet.append(nn.Linear(n_in, n_out))

        self.hypernet = nn.Sequential(*_hypernet)

    def forward(self, t):
        """Predict the time-dependent parameters [W, B, U] at time ``t``.

        Args:
            t: scalar time, either a Python number or a one-element tensor. It
               is converted to the dtype/device of the hyper-network weights.
               For a tensor that already matches, this returns the same
               object, so gradients w.r.t. ``t`` flow exactly as before.

        Returns:
            [W, B, U] with shapes [width, in_out_dim, 1], [width, 1, 1] and
            [width, 1, in_out_dim]. U is already multiplied by the sigmoid gate.
        """
        # Previously `t.reshape(1, 1)` was called directly, which failed for a
        # Python float or for a tensor whose dtype/device differs from the
        # Linear weights.
        ref = self.hypernet[0].weight
        t = torch.as_tensor(t, dtype=ref.dtype, device=ref.device)

        # predict params, flattened to one vector
        params = self.hypernet(t.reshape(1, 1)).reshape(-1)

        bs = self.blocksize
        # first block for W [width, in_out_dim, 1]
        W = params[:bs].reshape(self.width, self.in_out_dim, 1)
        # second block for U [width, 1, in_out_dim]
        U = params[bs:2 * bs].reshape(self.width, 1, self.in_out_dim)
        # third block for the gate G [width, 1, in_out_dim]
        G = params[2 * bs:3 * bs].reshape(self.width, 1, self.in_out_dim)
        # U is the element-wise product of the second block and sigmoid(G)
        U = U * torch.sigmoid(G)
        # fourth (remaining width entries) block for B [width, 1, 1]
        B = params[3 * bs:].reshape(self.width, 1, 1)
        return [W, B, U]


