
import torch
import torch.nn as nn

from EDGP.layers import EDGPLayer


class Linear_mean_func:
    """Callable mean function that right-multiplies its input by a fixed matrix.

    The matrix ``A`` is stored as given (no copy is made). Calling the
    instance evaluates ``x @ A`` with autograd recording disabled.
    """

    def __init__(self, A):
        # Keep a reference to the projection matrix; it is used as-is on call.
        self.A = A

    def __call__(self, x):
        """Return ``x @ self.A`` computed under ``torch.no_grad()``."""
        with torch.no_grad():
            projected = x @ self.A
        return projected


def init_layers_linear(num_samples, num_inducing, kernels, layer_sizes, mean_function=None,
                       num_outputs=None, Layer=EDGPLayer, whiten=False):
    """Build a list of ``Layer`` objects, each with a fresh ``nn.Linear`` mean function.

    One layer is created per entry of ``kernels``. For every kernel except the
    last, layer ``i`` maps ``layer_sizes[i]`` to ``layer_sizes[i + 1]``. The
    last kernel gets a layer whose output size is ``num_outputs``.

    Args:
        num_samples: Forwarded unchanged as the first argument of ``Layer``.
        num_inducing: Number of rows of the random ``Z`` tensor given to each layer.
        kernels: Indexable sequence of kernels, one per layer.
        layer_sizes: Indexable sequence of layer widths.
        mean_function: Accepted but never read by this function.
        num_outputs: Output size of the final layer; must not be ``None``.
        Layer: Callable used to construct each layer.
        whiten: Forwarded to ``Layer`` as the ``white`` keyword argument.

    Returns:
        A list with ``len(kernels)`` layers, in order from first to last.

    Raises:
        AssertionError: If ``num_outputs`` is ``None`` (when asserts are enabled).
    """
    assert num_outputs is not None

    built = []
    # Width feeding the final layer's Z; stays at layer_sizes[0] when there
    # is only a single kernel and the loop below does not run.
    last_width = layer_sizes[0]

    for depth, kernel in enumerate(kernels[:-1]):
        width_in = layer_sizes[depth]
        last_width = layer_sizes[depth + 1]

        # Draw Z before creating the linear map so the random-number
        # consumption order stays: randn, then nn.Linear initialisation.
        inducing = torch.randn(num_inducing, width_in)
        linear_mean = nn.Linear(width_in, last_width)

        hidden_layer = Layer(num_samples, kernel, inducing, last_width, linear_mean, white=whiten)
        built.append(hidden_layer)

    # Final layer: Z uses the last hidden width, while the linear mean reads
    # its input width from layer_sizes[-1].
    # NOTE(review): these two widths only agree when
    # len(layer_sizes) == len(kernels) — confirm against callers.
    final_inducing = torch.randn(num_inducing, last_width)
    output_kernel = kernels[-1]
    final_mean = nn.Linear(layer_sizes[-1], num_outputs)
    built.append(Layer(num_samples, output_kernel, final_inducing, num_outputs, final_mean,
                       white=whiten))
    return built
