from DSDGP.layer_initializations import init_layers_linear
from DSDGP.utils import Gaussian
from EDGP.Kernels import *
from utils import *


class DeepGPBase(OptimModule):
    """Deep Gaussian process assembled from a stack of layers and a likelihood.

    Inputs are tiled ``num_samples`` times and pushed through every layer via
    ``layer.sample_from_conditional``; the outputs of the last layer are then
    handed to ``likelihood`` for prediction and for the training objective.

    Args:
        likelihood: Object providing ``predict_mean_and_var``,
            ``predict_log_density`` and ``variational_expectations``.
        layers: Iterable of layer modules (stored in an ``nn.ModuleList``).
            Each must provide ``sample_from_conditional(F, z=..., full_cov=...)``
            and ``KL()``.
        num_samples: Default number of samples used by ``forward``,
            ``expected_data_log_likelihood`` and ``fit``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, likelihood, layers, num_samples=1, **kwargs):
        super().__init__()
        self.num_samples = num_samples
        self.likelihood = likelihood
        self.layers = nn.ModuleList(layers)

    def propagate(self, X, full_cov=False, num_samples=1, zs=None):
        """Push ``num_samples`` copies of ``X`` through every layer in order.

        Args:
            X: Input tensor. It is repeated along a new leading dimension;
                the three-argument ``repeat`` rejects inputs with more than
                two dimensions.
            full_cov: Forwarded to each layer's ``sample_from_conditional``.
            num_samples: Number of copies of ``X`` to propagate.
            zs: Optional per-layer values passed as ``z`` to each layer, one
                entry per layer. ``None`` or an empty sequence means every
                layer receives ``z=None``.

        Returns:
            Three lists ``(Fs, Fmeans, Fvars)`` with one entry per layer,
            holding the three values returned by that layer.
        """
        sX = X.unsqueeze(0).repeat(num_samples, 1, 1)

        # Explicit None/empty test instead of ``zs or ...``: truth-testing a
        # multi-element tensor raises, so ``or`` breaks for tensor-valued zs.
        if zs is None or len(zs) == 0:
            zs = [None] * len(self.layers)

        Fs, Fmeans, Fvars = [], [], []
        F = sX
        for layer, z in zip(self.layers, zs):
            # The sample produced by one layer is the input of the next.
            F, Fmean, Fvar = layer.sample_from_conditional(F, z=z, full_cov=full_cov)

            Fs.append(F)
            Fmeans.append(Fmean)
            Fvars.append(Fvar)
        return Fs, Fmeans, Fvars

    def forward(self, X):
        """Return final-layer mean and variance averaged over dimension 0.

        A 3-D ``X`` is first flattened to ``[X.shape[0], -1]``. Dimension 0 of
        the layer outputs is presumably the sample axis created in
        ``propagate`` (assumes the layers keep that leading axis).
        """
        if X.ndim == 3:
            X = X.reshape(X.shape[0], -1)
        Fmeans, Fvars = self.predict_f_samples(X, num_samples=self.num_samples)
        return torch.mean(Fmeans, dim=0), torch.mean(Fvars, dim=0)

    def predict_f_samples(self, predict_at, num_samples, full_cov=False):
        """Return the mean and variance produced by the last layer."""
        Fs, Fmeans, Fvars = self.propagate(predict_at, full_cov=full_cov,
                                           num_samples=num_samples)
        return Fmeans[-1], Fvars[-1]

    def predict_all_layers(self, predict_at, num_samples, full_cov=False):
        """Return ``propagate``'s per-layer ``(Fs, Fmeans, Fvars)`` lists."""
        return self.propagate(predict_at, full_cov=full_cov,
                              num_samples=num_samples)

    def predict_y(self, predict_at, num_samples):
        """Return ``likelihood.predict_mean_and_var`` of the final-layer output."""
        Fmean, Fvar = self.predict_f_samples(predict_at, num_samples=num_samples,
                                             full_cov=False)
        return self.likelihood.predict_mean_and_var(Fmean, Fvar)

    def predict_log_density(self, data, num_samples):
        """Return the log of the sample-averaged predictive density.

        Args:
            data: Pair ``(X, Y)``; ``data[0]`` is propagated and ``data[1]``
                is passed to ``likelihood.predict_log_density``.
            num_samples: Number of samples to draw and to average over.

        Returns:
            ``logsumexp(l - log(num_samples), dim=0)`` where ``l`` is the
            likelihood's log density for each sample.
        """
        Fmean, Fvar = self.predict_f_samples(data[0], num_samples=num_samples,
                                             full_cov=False)
        l = self.likelihood.predict_log_density(Fmean, Fvar, data[1])

        # Normalise by the number of samples actually drawn in this call
        # (the argument), not by the model-wide default ``self.num_samples``.
        log_num_samples = torch.log(torch.as_tensor(num_samples).to(data[0]))

        return torch.logsumexp(l - log_num_samples, dim=0)

    def expected_data_log_likelihood(self, X, Y):
        """Return the variational expectations averaged over dimension 0.

        A 3-D ``X`` is flattened to ``[X.shape[0], -1]`` and ``Y`` is reduced
        to ``Y[:, -1]`` before propagation.
        """
        if X.ndim == 3:
            X = X.reshape(X.shape[0], -1)
            Y = Y[:, -1]
        F_mean, F_var = self.predict_f_samples(X, num_samples=self.num_samples,
                                               full_cov=False)
        var_exp = self.likelihood.variational_expectations(F_mean, F_var, Y)  # Shape [S, N, D]

        return torch.mean(var_exp, dim=0)  # Shape [N, D]

    def fit(self, X, Y):
        """Return the training loss: summed layer KL terms minus the data term.

        The data term is the sum of ``expected_data_log_likelihood(X, Y)``.
        It is multiplied by a scale that is fixed at 1, so no minibatch
        rescaling is applied here.
        """
        if X.ndim == 3:
            X = X.reshape(X.shape[0], -1)
            Y = Y[:, -1]
        likelihood = torch.sum(self.expected_data_log_likelihood(X, Y))
        scale = 1.
        KL = sum(layer.KL() for layer in self.layers)

        return -(scale * likelihood - KL)


class DeepGP(DeepGPBase):
    """Deep GP whose layer stack is created by ``init_layers_linear``.

    Args:
        num_inducing, kernels, layer_sizes: Passed positionally to
            ``init_layers_linear``.
        likelihood: Handed to ``DeepGPBase``.
        num_outputs, mean_function, whiten: Passed by keyword to
            ``init_layers_linear``.
        num_samples: Handed to ``DeepGPBase`` as its default sample count.
    """

    def __init__(self, num_inducing, kernels, layer_sizes, likelihood,
                 num_outputs=1, mean_function=None, whiten=False,
                 num_samples=1):
        # Collect the keyword options for the layer factory in one place.
        layer_options = {
            "mean_function": mean_function,
            "num_outputs": num_outputs,
            "whiten": whiten,
        }
        stack = init_layers_linear(num_inducing, kernels, layer_sizes,
                                   **layer_options)
        super().__init__(likelihood, stack, num_samples)


def DSDGP(in_dims, n_layers, num_inducing, num_prio, hidden_dims):
    """Build a ``DeepGP`` with squared-exponential kernels and a Gaussian likelihood.

    One kernel is created with ``in_list=in_dims`` and ``n_layers - 1`` further
    kernels with ``in_list=hidden_dims``; ``layer_sizes`` holds the matching
    sequence ``[in_dims, hidden_dims, ...]``.

    Args:
        in_dims: Value used for the first kernel and first layer size.
        n_layers: Total number of kernels to create.
        num_inducing: Forwarded to ``DeepGP``.
        num_prio: Forwarded to ``DeepGP`` as ``num_samples``.
        hidden_dims: Value used for every kernel and layer size after the first.

    Returns:
        The constructed ``DeepGP`` (always with ``num_outputs=1``).
    """
    first_kernel = SquaredExponential(in_list=in_dims)
    # Each hidden layer gets its own kernel instance.
    hidden_kernels = [SquaredExponential(in_list=hidden_dims)
                      for _ in range(n_layers - 1)]

    kernels = [first_kernel, *hidden_kernels]
    layer_sizes = [in_dims] + [hidden_dims for _ in hidden_kernels]

    return DeepGP(num_inducing, kernels, layer_sizes, Gaussian(),
                  num_outputs=1, num_samples=num_prio)
