import torch.nn as nn
import layers as l


def dwp_layer(inducing_batch, width, trainable_noise=True):
    """Build one hidden block: a kernel Gram module followed by a Wishart layer.

    Args:
        inducing_batch: Forwarded positionally to ``l.ImprovedWishartLayer``.
        width: Forwarded positionally to ``l.ImprovedWishartLayer``.
        trainable_noise: Forwarded to ``l.SqExpKernelGram`` as its
            ``trainable_noise`` keyword.

    Returns:
        An ``nn.Sequential`` that applies ``l.SqExpKernelGram`` and then
        ``l.ImprovedWishartLayer``.
    """
    # The kernel is constructed before the Wishart layer, matching the
    # order in which the two modules are applied.
    kernel = l.SqExpKernelGram(trainable_noise=trainable_noise)
    wishart = l.ImprovedWishartLayer(inducing_batch, width)
    return nn.Sequential(kernel, wishart)


def idwp_net(inducing_batch, inducing_data, inducing_targets, width, depth, in_features):
    """Assemble the full network from ``l.ImprovedWishartLayer`` blocks.

    The module pipeline is, in order:

    1. ``l.FeaturesToKernelARD``
    2. ``l.SqExpKernelGram`` (``trainable_noise=True``, ``lengthscale=False``)
    3. ``l.ImprovedWishartLayer``
    4. ``depth - 2`` blocks produced by ``dwp_layer`` (none when ``depth <= 2``)
    5. ``l.SqExpKernelGram`` with default arguments
    6. ``l.GIGP`` with ``out_features=1``

    The pipeline is wrapped in ``l.InducingWrapper`` and followed by
    ``l.NormalLearnedScale``.

    Args:
        inducing_batch: Passed to the feature layer, every Wishart layer,
            the ``GIGP`` head and the ``InducingWrapper``.
        inducing_data: Passed to ``l.InducingWrapper``.
        inducing_targets: Passed to ``l.GIGP``.
        width: Passed to every Wishart layer.
        depth: Controls the number of repeated ``dwp_layer`` blocks
            (``depth - 2`` of them).
        in_features: Passed to ``l.FeaturesToKernelARD``.

    Returns:
        An ``nn.Sequential`` of the wrapped pipeline and
        ``l.NormalLearnedScale``.
    """
    # Input stage plus the first Wishart layer.
    stages = [
        l.FeaturesToKernelARD(inducing_batch=inducing_batch, in_features=in_features),
        l.SqExpKernelGram(trainable_noise=True, lengthscale=False),
        l.ImprovedWishartLayer(inducing_batch, width),
    ]

    # Repeated hidden blocks; an empty range adds nothing.
    stages.extend(dwp_layer(inducing_batch, width) for _ in range(depth - 2))

    # Output stage: final kernel followed by the GIGP head.
    stages.append(l.SqExpKernelGram())
    stages.append(
        l.GIGP(out_features=1, inducing_targets=inducing_targets, inducing_batch=inducing_batch)
    )

    pipeline = nn.Sequential(*stages)
    wrapped = l.InducingWrapper(pipeline, inducing_batch=inducing_batch, inducing_data=inducing_data)
    return nn.Sequential(wrapped, l.NormalLearnedScale())


def dwp_layer_b(inducing_batch, width, trainable_noise=True):
    """Build one hidden block using the "B" variant of the Wishart layer.

    Args:
        inducing_batch: Forwarded positionally to ``l.ImprovedWishartLayerB``.
        width: Forwarded positionally to ``l.ImprovedWishartLayerB``.
        trainable_noise: Forwarded to ``l.SqExpKernelGram`` as its
            ``trainable_noise`` keyword.

    Returns:
        An ``nn.Sequential`` that applies ``l.SqExpKernelGram`` and then
        ``l.ImprovedWishartLayerB``.
    """
    # The kernel is constructed before the Wishart layer, matching the
    # order in which the two modules are applied.
    kernel = l.SqExpKernelGram(trainable_noise=trainable_noise)
    wishart = l.ImprovedWishartLayerB(inducing_batch, width)
    return nn.Sequential(kernel, wishart)


def ibdwp_net(inducing_batch, inducing_data, inducing_targets, width, depth, in_features):
    """Assemble the full network from ``l.ImprovedWishartLayerB`` blocks.

    The module pipeline is, in order:

    1. ``l.FeaturesToKernelARD``
    2. ``l.SqExpKernelGram`` (``trainable_noise=True``, ``lengthscale=False``)
    3. ``l.ImprovedWishartLayerB``
    4. ``depth - 2`` blocks produced by ``dwp_layer_b`` (none when ``depth <= 2``)
    5. ``l.SqExpKernelGram`` with default arguments
    6. ``l.GIGP`` with ``out_features=1``

    The pipeline is wrapped in ``l.InducingWrapper`` and followed by
    ``l.NormalLearnedScale``.

    Args:
        inducing_batch: Passed to the feature layer, every Wishart layer,
            the ``GIGP`` head and the ``InducingWrapper``.
        inducing_data: Passed to ``l.InducingWrapper``.
        inducing_targets: Passed to ``l.GIGP``.
        width: Passed to every Wishart layer.
        depth: Controls the number of repeated ``dwp_layer_b`` blocks
            (``depth - 2`` of them).
        in_features: Passed to ``l.FeaturesToKernelARD``.

    Returns:
        An ``nn.Sequential`` of the wrapped pipeline and
        ``l.NormalLearnedScale``.
    """
    # Input stage plus the first Wishart layer.
    stages = [
        l.FeaturesToKernelARD(inducing_batch=inducing_batch, in_features=in_features),
        l.SqExpKernelGram(trainable_noise=True, lengthscale=False),
        l.ImprovedWishartLayerB(inducing_batch, width),
    ]

    # Repeated hidden blocks; an empty range adds nothing.
    stages.extend(dwp_layer_b(inducing_batch, width) for _ in range(depth - 2))

    # Output stage: final kernel followed by the GIGP head.
    stages.append(l.SqExpKernelGram())
    stages.append(
        l.GIGP(out_features=1, inducing_targets=inducing_targets, inducing_batch=inducing_batch)
    )

    pipeline = nn.Sequential(*stages)
    wrapped = l.InducingWrapper(pipeline, inducing_batch=inducing_batch, inducing_data=inducing_data)
    return nn.Sequential(wrapped, l.NormalLearnedScale())

