import bnn
import torch as t
import torch.nn as nn


class resgp(nn.Module):
    """Stack of ``lin`` layers with learnably scaled skip connections.

    The network consists of ``depth - 1`` hidden layers mapping
    ``channels -> channels`` followed by one output layer mapping
    ``channels -> 1``. After every hidden layer the layer output is
    multiplied by ``exp(logit)`` (one learnable logit per hidden layer) and
    the *original* network input is added back.

    Args:
        lin: Callable used as ``lin(in_features, out_features, **kwargs)`` to
            build each layer; it must return an ``nn.Module``.
        channels: Width of every hidden layer. Because the input is added to
            each hidden output, the input's trailing dimension has to be
            compatible with ``channels``.
        depth: Total number of layers, counting the output layer.
        kwargs_lower: Extra keyword arguments for the hidden layers.
        kwargs_top: Extra keyword arguments for the output layer.
    """

    def __init__(self, lin, channels, depth, kwargs_lower, kwargs_top):
        super().__init__()
        num_hidden = depth - 1

        hidden_layers = [
            lin(channels, channels, **kwargs_lower) for _ in range(num_hidden)
        ]
        self.lins = nn.ModuleList(hidden_layers)
        self.last_lin = lin(channels, 1, **kwargs_top)

        # Logits start at -(depth - 1), ..., -1, so every residual branch is
        # initially scaled by a factor below one (exp of a negative number),
        # with the earliest layers damped the most.
        initial_logits = t.arange(num_hidden, dtype=t.float32) - num_hidden
        self.logits = nn.Parameter(initial_logits)

    def forward(self, x):
        """Apply the hidden layers with scaled skips, then the output layer."""
        # Use residual connections as in Salimbeni et al. 2017
        hidden = x
        for layer, logit in zip(self.lins, self.logits):
            # The skip term is always the network input ``x``, not the
            # previous hidden state.
            hidden = logit.exp() * layer(hidden) + x

        return self.last_lin(hidden)


def net(ap_spec, inducing_data, in_features, num_layers, kwargs_lower, kwargs_top, channels=50):
    """Build a ``resgp`` model followed by a ``bnn.NormalLearnedScale`` head.

    Args:
        ap_spec: Object providing ``lower_linear`` (the layer factory handed
            to ``resgp``) and ``top`` (a string; the value ``'gigp'`` enables
            the inducing wrapper).
        inducing_data: Only read when ``ap_spec.top == 'gigp'``; its leading
            dimension is passed to ``bnn.InducingWrapper`` as the second
            positional argument, and the object itself as ``inducing_data``.
        in_features: Not used by this function body.
            NOTE(review): presumably kept for signature parity with other
            network builders -- confirm against callers.
        num_layers: Forwarded to ``resgp`` as its ``depth``.
        kwargs_lower: Keyword arguments for the hidden layers of ``resgp``.
        kwargs_top: Keyword arguments for the output layer of ``resgp``.
        channels: Hidden width forwarded to ``resgp``.

    Returns:
        An ``nn.Sequential`` of the (optionally wrapped) ``resgp`` model and a
        ``bnn.NormalLearnedScale`` module.
    """
    model = resgp(ap_spec.lower_linear, channels, num_layers, kwargs_lower, kwargs_top)

    wrap_with_inducing = ap_spec.top == 'gigp'
    if wrap_with_inducing:
        num_inducing = inducing_data.shape[0]
        model = bnn.InducingWrapper(model, num_inducing, inducing_data=inducing_data)

    return nn.Sequential(model, bnn.NormalLearnedScale())