import bnn
import torch.nn as nn


def net(ap_spec, inducing_data, in_features, num_layers, kwargs_lower, kwargs_top, channels=50):
    """Build a fully connected ReLU network with a single output unit.

    Layer layout:

    * ``num_layers == 1``: only ``ap_spec.top_linear(in_features, 1)``.
    * ``num_layers > 1``: ``ap_spec.lower_linear(in_features, channels)``
      followed by ``num_layers - 1`` blocks of
      ``ap_spec.lower_linear(channels, channels)``, each lower layer followed
      by ``nn.ReLU()``, and finally ``ap_spec.top_linear(channels, 1)``.
      That is ``num_layers`` lower layers plus the top layer.

    NOTE(review): the two cases are not continuous (``num_layers == 2`` yields
    three linear layers, ``num_layers == 1`` yields one). This is kept as-is
    because callers may depend on the resulting architecture -- confirm
    whether it is intentional.

    Args:
        ap_spec: Object exposing ``lower_linear`` and ``top_linear`` layer
            factories (called as ``factory(in, out, **kwargs)``) and a ``top``
            attribute.
        inducing_data: Only read when ``ap_spec.top == 'gi'``; it must then
            have a ``shape`` whose first entry is passed to
            ``bnn.InducingAdd`` / ``bnn.InducingRemove``.
        in_features: Input width of the first layer.
        num_layers: Depth selector, must be at least 1 (see layout above).
        kwargs_lower: Keyword arguments forwarded to every
            ``ap_spec.lower_linear`` call.
        kwargs_top: Keyword arguments forwarded to ``ap_spec.top_linear``.
        channels: Width of the hidden layers.

    Returns:
        An ``nn.Sequential``. When ``ap_spec.top == 'gi'`` the layer stack is
        nested inside an outer ``nn.Sequential`` between ``bnn.InducingAdd``
        and ``bnn.InducingRemove``.

    Raises:
        ValueError: If ``num_layers`` is neither equal to 1 nor greater
            than 1.
    """
    if num_layers == 1:
        layers = [ap_spec.top_linear(in_features, 1, **kwargs_top)]
    elif 1 < num_layers:
        layers = [
            ap_spec.lower_linear(in_features, channels, **kwargs_lower),
            nn.ReLU(),
        ]
        for _ in range(num_layers - 1):
            layers.append(ap_spec.lower_linear(channels, channels, **kwargs_lower))
            layers.append(nn.ReLU())
        layers.append(ap_spec.top_linear(channels, 1, **kwargs_top))
    else:
        # Fail before any layer is constructed, with a message that names the
        # bad value (ValueError is still an Exception for existing handlers).
        raise ValueError(
            f"num_layers must be a positive integer, got {num_layers!r}"
        )

    # Named `model` rather than `net` so the function does not shadow itself.
    model = nn.Sequential(*layers)
    if ap_spec.top == 'gi':
        num_inducing = inducing_data.shape[0]
        model = nn.Sequential(
            bnn.InducingAdd(num_inducing, inducing_data=inducing_data),
            model,
            bnn.InducingRemove(num_inducing),
        )

    return model

