import nf
from nf.net import MLP
import torch
import torch.nn as nn

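# Most networks here take a single input of shape (..., dim + 1), split it into
# time t = input[..., :1] and state x = input[..., 1:] (see `process_input`), and
# return f(t, x). `CNFFuncAndJac` instead takes t and x as separate arguments.
# Constructor arguments of the MLP-based networks are equivalent to `nf.net.MLP`.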

def process_input(input):
    """
    Split a time-augmented tensor into its time and state components.

    Args:
        input: Tensor whose last dimension holds the time in position 0
            followed by the state features.

    Returns:
        Tuple ``(t, x)`` where ``t`` is ``input[..., :1]`` (last dim kept, size 1)
        and ``x`` is ``input[..., 1:]``. Both are slices of ``input``.
    """
    time_part, state_part = input[..., :1], input[..., 1:]
    return time_part, state_part

class CNFSquash(nn.Module):
    """
    MLP on the state ``x`` whose output is scaled element-wise by a
    time-dependent gate ``sigmoid(Linear(t))``.

    Constructor arguments mirror ``nf.net.MLP``. ``in_dim`` counts the time
    column, so the MLP itself receives ``in_dim - 1`` features.
    """
    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh', final_activation=None, **kwargs):
        super().__init__()
        state_dim = in_dim - 1
        self.layer = MLP(state_dim, hidden_dims, out_dim, activation, final_activation)
        self.hyper = nn.Linear(1, out_dim)

    def forward(self, input, **kwargs):
        time, state = process_input(input)
        features = self.layer(state)
        scale = self.hyper(time).sigmoid()
        return features * scale

class CNFConcatSquash(nn.Module):
    """
    MLP on the state ``x`` whose output is gated by ``sigmoid(Linear(t))`` and
    shifted by a bias-free linear function of ``t``.

    Constructor arguments mirror ``nf.net.MLP``. ``in_dim`` counts the time
    column, so the MLP itself receives ``in_dim - 1`` features.
    """
    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh', final_activation=None, **kwargs):
        super().__init__()
        state_dim = in_dim - 1
        self.layer = MLP(state_dim, hidden_dims, out_dim, activation, final_activation)
        self.bias = nn.Linear(1, out_dim, bias=False)
        self.gate = nn.Linear(1, out_dim)

    def forward(self, input, **kwargs):
        time, state = process_input(input)
        features = self.layer(state)
        gated = features * self.gate(time).sigmoid()
        shift = self.bias(time)
        return gated + shift

class CNFGated(nn.Module):
    """
    Product of two MLPs on the state ``x``: ``layer(x) * sigmoid(gate(x))``.

    The time component of the input is discarded, so the output does not
    depend on ``t``. Constructor arguments mirror ``nf.net.MLP``. ``in_dim``
    counts the time column, so both MLPs receive ``in_dim - 1`` features.
    """
    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh', final_activation=None, **kwargs):
        super().__init__()
        state_dim = in_dim - 1
        self.layer = MLP(state_dim, hidden_dims, out_dim, activation, final_activation)
        self.gate = MLP(state_dim, hidden_dims, out_dim, activation, final_activation)

    def forward(self, input, **kwargs):
        state = process_input(input)[1]
        features = self.layer(state)
        weights = self.gate(state).sigmoid()
        return features * weights

class CNFBlend(nn.Module):
    """
    Linear interpolation in ``t`` between two MLPs on the state ``x``.

    The output is ``layer1(x)`` at ``t = 0`` and ``layer2(x)`` at ``t = 1``.
    Constructor arguments mirror ``nf.net.MLP``. ``in_dim`` counts the time
    column, so both MLPs receive ``in_dim - 1`` features.
    """
    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh', final_activation=None, **kwargs):
        super().__init__()
        state_dim = in_dim - 1
        self.layer1 = MLP(state_dim, hidden_dims, out_dim, activation, final_activation)
        self.layer2 = MLP(state_dim, hidden_dims, out_dim, activation, final_activation)

    def forward(self, input, **kwargs):
        time, state = process_input(input)
        start = self.layer1(state)
        end = self.layer2(state)
        delta = end - start
        return start + delta * time

class CNFLinearDiv(nn.Module):
    """
    Time-independent affine dynamics built on ``nf.Affine``.

    ``forward`` returns a 3-tuple ``(y, exp(log_jac), sqjacnorm)``, where ``y``
    and ``log_jac`` come from ``nf.Affine.forward`` and ``sqjacnorm`` is a zero
    tensor of length ``x.shape[0]``.

    The MLP-style constructor arguments (``hidden_dims``, ``out_dim``,
    ``activation``, ``final_activation``) are accepted so the constructor
    signature matches the other networks, but they are unused.
    """
    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh', final_activation=None, **kwargs):
        super().__init__()
        # NOTE(review): the MLP-based networks in this file use `in_dim - 1` for the
        # state size because in_dim includes the time column. Confirm that
        # nf.Affine(in_dim) is intended here.
        self.affine = nf.Affine(in_dim)

    def forward(self, input, **kwargs):
        _, x = process_input(input)  # t is ignored: these dynamics do not depend on time
        y, log_jac = self.affine.forward(x)
        # Allocate on x's device and dtype; a bare torch.zeros() would always be a
        # default-dtype CPU tensor and mismatch CUDA / float64 inputs.
        sqjacnorm = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype).requires_grad_(True)
        return y, log_jac.exp(), sqjacnorm

class CNFFuncAndJac(nn.Module):
    """
    Set-valued CNF dynamics that combine an exclusive set network with a
    dimension-wise MLP. ``forward`` delegates to ``nf.net.f_and_jac_reg_fn``.

    Args:
        in_dim: Input feature dimension passed to the exclusive network.
        hidden_dim: Passed as the hidden size to both sub-networks.
        out_dim: Output dimension of the exclusive network; also the first
            argument of ``nf.net.DimwiseMLP``.
        num_layers: Passed as the layer count to both sub-networks.
        interaction: ``'sum'`` or ``'max'`` selects ``nf.net.ExclusiveSetNet``;
            ``'attention'`` selects ``nf.net.ExclusiveAttentionNet``.
        induced: Forwarded to ``ExclusiveAttentionNet``; ignored otherwise.
        n_points: Forwarded to ``ExclusiveAttentionNet``; ignored otherwise.

    Raises:
        ValueError: If ``interaction`` is not one of the supported values.
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, interaction,
                 induced=False, n_points=None, **kwargs):
        super().__init__()
        if interaction in ['sum', 'max']:
            self.exclusive = nf.net.ExclusiveSetNet(in_dim, hidden_dim, out_dim, num_layers, interaction=interaction)
        elif interaction == 'attention':
            self.exclusive = nf.net.ExclusiveAttentionNet(in_dim, hidden_dim, out_dim, num_layers,
                                                          n_heads=1, n_points=n_points, induced=induced)
        else:
            # Fail here instead of with an AttributeError on `self.exclusive` in forward().
            raise ValueError(
                f"Unknown interaction {interaction!r}; expected 'sum', 'max' or 'attention'")
        self.dimwise = nf.net.DimwiseMLP(out_dim, hidden_dim, num_layers)

    def forward(self, t, x, mask=None, **kwargs):
        """Return ``nf.net.f_and_jac_reg_fn(self.exclusive, self.dimwise, t, x, mask=mask)``."""
        return nf.net.f_and_jac_reg_fn(self.exclusive, self.dimwise, t, x, mask=mask)

class CNFSetWrapper(nn.Module):
    """
    A wrapper for a set-valued net in a continuous NF. Handles reshaping.

    The incoming tensor is flat (time column plus combined set dims). The
    wrapper splits off the time, reshapes the state to ``self.shape``, calls the
    wrapped net, and reshapes the result back to the flat state shape.

    Args:
        net: Module called as ``net(t, x, mask=self.mask)`` when ``input_time``
            is True, otherwise as ``net(x, mask=self.mask)``.
        input_time: Whether ``net`` receives the time ``t``.

    ``shape`` and ``mask`` start as None. ``shape`` must be assigned by the caller
    (presumably ``(B, N, D)``) before ``forward`` is called. ``mask`` is passed to
    ``net`` unchanged.
    """
    def __init__(self, net, input_time=False):
        super().__init__()
        self.input_time = input_time
        self.net = net
        self.shape = None
        self.mask = None

    def forward(self, input, **kwargs):
        """
        Args:
            input: Input set with two last dims combined and the time prepended,
                (B, N * D + 1).

        Returns:
            If ``net`` returns a 3-tuple, its first two elements are reshaped to the
            flat state shape (B, N * D) and the third is returned unchanged.
            Otherwise the net output is reshaped to the flat state shape.

        Raises:
            RuntimeError: If ``self.shape`` has not been set.
        """
        if self.shape is None:
            raise RuntimeError('CNFSetWrapper.shape must be set before calling forward')
        t, x = process_input(input)
        flat_shape = x.shape
        # reshape (not view): the sliced state / net outputs need not be contiguous.
        x = x.reshape(*self.shape)
        if self.input_time:
            out = self.net(t, x, mask=self.mask)
        else:
            out = self.net(x, mask=self.mask)
        if isinstance(out, tuple):
            # Presumably (f, jacobian term, regularizer) as from f_and_jac_reg_fn.
            y, jac_term, reg_term = out
            return y.reshape(*flat_shape), jac_term.reshape(*flat_shape), reg_term
        return out.reshape(*flat_shape)
