import torch.nn as nn
import layers as l


def dgp_layer(inducing_batch, width, trainable_noise=True, neuron_prec=False):
    """Build one hidden block: kernel features followed by a GIGP layer.

    Args:
        inducing_batch: Forwarded unchanged to both ``l.SqExpKernelFeatures``
            and ``l.GIGP``.
        width: Passed to ``l.GIGP`` as ``out_features``.
        trainable_noise: Forwarded to ``l.SqExpKernelFeatures``.
        neuron_prec: Forwarded to ``l.GIGP``.

    Returns:
        An ``nn.Sequential`` holding the kernel-feature module followed by
        the GIGP module.
    """
    # Build the modules in the order they are applied.
    kernel_features = l.SqExpKernelFeatures(
        inducing_batch=inducing_batch,
        trainable_noise=trainable_noise,
    )
    gp_output = l.GIGP(
        out_features=width,
        inducing_batch=inducing_batch,
        neuron_prec=neuron_prec,
    )
    return nn.Sequential(kernel_features, gp_output)


def dgp_net(inducing_batch, inducing_data, inducing_targets, width, depth, in_features, neuron_prec=True):
    """Assemble the full network: input block, hidden blocks, output block.

    The stack consists of an ARD kernel-feature module plus GIGP layer, then
    ``depth - 2`` blocks built by ``dgp_layer``, then a final kernel-feature
    module plus a GIGP layer with a single output feature. The stack is
    wrapped in ``l.InducingWrapper`` and followed by ``l.NormalLearnedScale``.

    Args:
        inducing_batch: Forwarded to every kernel-feature module, every GIGP
            layer and the ``l.InducingWrapper``.
        inducing_data: Passed to ``l.InducingWrapper``.
        inducing_targets: Passed to the final ``l.GIGP`` layer only.
        width: ``out_features`` of every GIGP layer except the final one.
        depth: Controls the number of hidden blocks (``depth - 2``). Values
            of 2 or less produce no hidden blocks; the input and output
            blocks are always present.
        in_features: Passed to ``l.SqExpKernelFeaturesARD``.
        neuron_prec: Forwarded to every GIGP layer.

    Returns:
        An ``nn.Sequential`` of the wrapped stack followed by
        ``l.NormalLearnedScale()``.
    """
    # Input block: ARD kernel features feeding the first GIGP layer.
    modules = [
        l.SqExpKernelFeaturesARD(
            inducing_batch=inducing_batch,
            in_features=in_features,
            trainable_noise=True,
        ),
        l.GIGP(
            out_features=width,
            inducing_batch=inducing_batch,
            neuron_prec=neuron_prec,
        ),
    ]

    # Hidden blocks; the input and output blocks account for the "- 2".
    for _ in range(depth - 2):
        modules.append(dgp_layer(inducing_batch, width, neuron_prec=neuron_prec))

    # Output block: single output feature, the only layer given the targets.
    modules.append(l.SqExpKernelFeatures(inducing_batch=inducing_batch))
    modules.append(
        l.GIGP(
            out_features=1,
            inducing_targets=inducing_targets,
            inducing_batch=inducing_batch,
            neuron_prec=neuron_prec,
        )
    )

    stack = nn.Sequential(*modules)
    wrapped = l.InducingWrapper(
        stack,
        inducing_batch=inducing_batch,
        inducing_data=inducing_data,
    )
    return nn.Sequential(wrapped, l.NormalLearnedScale())
