import nf
import torch
import torch.nn as nn

class DimwiseMLP(nn.Module):
    """Scalar-output MLP evaluated on the concatenation of ``x`` and ``c``.

    The wrapped network is built with input width ``conditioner_dim + 1``.
    This presumably means ``x`` carries a single feature (last dimension of
    size 1) and ``c`` carries ``conditioner_dim`` conditioning features.
    NOTE(review): confirm against callers.

    Args:
        conditioner_dim: Number of conditioning features expected in ``c``.
        hidden_dim: Width used for every hidden layer of the MLP.
        num_layers: Length of the hidden-size list handed to ``nf.net.MLP``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, conditioner_dim, hidden_dim, num_layers, **kwargs):
        super().__init__()
        in_features = conditioner_dim + 1
        hidden_sizes = [hidden_dim] * num_layers
        self.f = nf.net.MLP(in_features, hidden_sizes, 1)

    def forward(self, t, x, c, **kwargs):
        """Apply the MLP to ``x`` and ``c`` joined along the last axis.

        Args:
            t: Time argument. It is not used by this module; it is presumably
                kept so the signature matches sibling networks (confirm).
            x: Input tensor, placed first in the concatenation.
            c: Conditioning tensor, placed after ``x``. Per ``torch.cat``, all
                dimensions except the last must match those of ``x``.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``self.f`` returns for the concatenated features.
        """
        features = torch.cat((x, c), dim=-1)
        return self.f(features)
