import numpy as np

import torch
import torch.nn as nn

from DSDGP.utils import reparameterize


class Layer(nn.Module):
    """Base class for a layer that exposes a Gaussian conditional.

    Subclasses implement :meth:`conditional_ND`; this class supplies the
    batched variant over a leading sample axis, sampling, and the optional
    forwarding of the first ``input_prop_dim`` input features to the output.
    """

    def __init__(self, input_prop_dim=None, **kwargs):
        """Store the number of leading input features to forward.

        Args:
            input_prop_dim: Number of leading features of ``X`` that
                :meth:`sample_from_conditional` prepends to its outputs.
                ``None`` (or ``0``) disables forwarding.
            **kwargs: Passed through to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.input_prop_dim = input_prop_dim

    def conditional_ND(self, X, full_cov=False):
        """Return ``(mean, var)`` of the conditional at a 2-D input ``X``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def KL(self):
        """Return the layer's KL term; a zero scalar tensor in the base class."""
        return torch.as_tensor(0.0)

    def conditional_SND(self, X, full_cov=False):
        """Apply :meth:`conditional_ND` to every slice along the first axis.

        Args:
            X: Tensor whose first three axes are read as ``(S, N, D)``.
            full_cov: If true, ``conditional_ND`` is mapped over the first
                axis with ``torch.vmap`` so each slice keeps its own
                covariance. Otherwise the first two axes are flattened into
                one call and the results are reshaped back to
                ``(S, N, num_outputs)``.

        Returns:
            A ``(mean, var)`` pair.
        """
        if full_cov:
            mean, var = torch.vmap(
                lambda x: self.conditional_ND(x, full_cov=full_cov)
            )(X)
            return mean, var

        S, N, D = X.shape[:3]
        mean, var = self.conditional_ND(torch.reshape(X, [S * N, D]))
        num_outputs = mean.shape[-1]
        return [torch.reshape(m, [S, N, num_outputs]) for m in (mean, var)]

    def _prepend_propagated_inputs(self, X, samples, mean, var, full_cov=False):
        """Prepend the first ``input_prop_dim`` features of ``X`` to the outputs.

        The forwarded features are copied into ``samples`` and ``mean`` and
        are given zero (co)variance in ``var``.

        Args:
            X: Input of shape ``(S, N, >= input_prop_dim)``.
            samples: Tensor of shape ``(S, N, D)``.
            mean: Tensor of shape ``(S, N, D)``.
            var: ``(S, N, N, D)`` if ``full_cov`` else ``(S, N, D)``.
            full_cov: Selects which of the two ``var`` layouts is padded.

        Returns:
            ``(samples, mean, var)`` with ``input_prop_dim`` extra leading
            channels on the last axis of each.
        """
        S, N = X.shape[0], X.shape[1]
        P = self.input_prop_dim
        X_prop = torch.reshape(X[:, :, :P], [S, N, P])

        samples = torch.concat([X_prop, samples], 2)
        mean = torch.concat([X_prop, mean], 2)

        # The padding must have exactly P channels so that var keeps the same
        # channel count as mean/samples, and must live on the same
        # device/dtype as X_prop so the concat works off-CPU.
        if full_cov:
            zeros = X_prop.new_zeros((S, N, N, P))
            var = torch.concat([zeros, var], 3)
        else:
            var = torch.concat([torch.zeros_like(X_prop), var], 2)

        return samples, mean, var

    def sample_from_conditional(self, X, z=None, full_cov=False):
        """Draw a sample from the conditional at ``X``.

        Args:
            X: Tensor of shape ``(S, N, D_in)``.
            z: Optional noise with the same shape as the mean
                ``(S, N, D_out)``. When ``None`` it is drawn from a standard
                normal.
            full_cov: Whether the full ``N x N`` covariance is used.

        Returns:
            ``(samples, mean, var)``. ``samples`` comes from the imported
            ``reparameterize`` helper (presumably ``mean`` plus ``z`` scaled
            by a square root of ``var``; confirm against ``DSDGP.utils``).
            If ``input_prop_dim`` is set, each output additionally carries
            the forwarded input features as leading channels.
        """
        mean, var = self.conditional_SND(X, full_cov=full_cov)

        S, N = X.shape[0], X.shape[1]
        D = mean.shape[-1]

        mean = torch.reshape(mean, (S, N, D))
        var_shape = (S, N, N, D) if full_cov else (S, N, D)
        var = torch.reshape(var, var_shape)

        if z is None:
            z = torch.randn(mean.shape).to(mean)
        samples = reparameterize(mean, var, z, full_cov=full_cov)

        if self.input_prop_dim:
            samples, mean, var = self._prepend_propagated_inputs(
                X, samples, mean, var, full_cov=full_cov
            )

        return samples, mean, var


class SVGPLayer(Layer):
    """Sparse variational GP layer with inducing inputs ``Z``.

    The variational distribution over the inducing outputs has mean ``q_mu``
    of shape ``(M, num_outputs)`` and covariance ``q_sqrt @ q_sqrt^T`` with
    ``q_sqrt`` of shape ``(num_outputs, M, M)`` (its lower triangle is used),
    where ``M = Z.shape[0]``.
    """

    def __init__(self, kern, Z, num_outputs, mean_function,
                 white=False, input_prop_dim=None, **kwargs):
        """Create the variational parameters.

        Args:
            kern: Kernel; called as ``kern(A)``, ``kern(A, B)`` and
                ``kern.diag(A)``.
            Z: Tensor of inducing inputs with ``M`` rows; becomes the
                trainable parameter ``feature``.
            num_outputs: Number of output dimensions.
            mean_function: Callable applied to ``X`` and added to the mean.
            white: If true, use the whitened parameterisation and start
                ``q_sqrt`` at ``1e-3 * I``. Otherwise ``q_sqrt`` starts at
                the Cholesky factor of ``kern(Z) + 1e-6 * I``.
            input_prop_dim: Forwarded to ``Layer.__init__``.
            **kwargs: Forwarded to ``Layer.__init__``.
        """
        super().__init__(input_prop_dim=input_prop_dim, **kwargs)
        self.num_inducing = Z.shape[0]

        # Create the variational parameters with Z's dtype/device so they can
        # be combined with kernel matrices built from Z without dtype or
        # device mismatches.
        self.q_mu = nn.Parameter(
            torch.zeros((self.num_inducing, num_outputs),
                        dtype=Z.dtype, device=Z.device)
        )

        if white:
            q_sqrt = torch.eye(self.num_inducing,
                               dtype=Z.dtype, device=Z.device) * 1e-3
        else:
            # Initialise to the prior. No graph is needed: the result is
            # wrapped in a fresh leaf Parameter.
            with torch.no_grad():
                Ku = kern(Z)
                q_sqrt = torch.linalg.cholesky(
                    Ku + torch.eye(Z.shape[0]).to(Ku) * 1e-6
                )
        self.q_sqrt = nn.Parameter(
            torch.tile(q_sqrt[None, :, :], [num_outputs, 1, 1])
        )

        self.feature = nn.Parameter(Z)

        self.kern = kern
        self.mean_function = mean_function

        self.num_outputs = num_outputs
        self.white = white

    def get_cholesky(self):
        """Return the jittered inducing kernel matrix and its Cholesky factor.

        Returns:
            ``(Ku, Ku_tiled, Lu, Lu_tiled)`` where ``Ku`` is
            ``kern(feature) + 1e-6 * I``, ``Lu`` its lower Cholesky factor,
            and the ``*_tiled`` versions repeat them ``num_outputs`` times
            along a new leading axis.
        """
        jitter = torch.eye(self.feature.shape[0]).to(self.feature) * 1e-6
        Ku = self.kern(self.feature) + jitter
        Lu = torch.linalg.cholesky(Ku)
        Ku_tiled = torch.tile(Ku[None, :, :], [self.num_outputs, 1, 1])
        Lu_tiled = torch.tile(Lu[None, :, :], [self.num_outputs, 1, 1])

        return Ku, Ku_tiled, Lu, Lu_tiled

    def conditional_ND(self, X, full_cov=False):
        """Mean and (co)variance of the variational posterior at ``X``.

        Args:
            X: 2-D tensor of ``N`` inputs.
            full_cov: Return the full ``N x N`` covariance per output instead
                of only the marginal variances.

        Returns:
            ``(f_mean, f_cov)``. ``f_mean`` is ``(N, num_outputs)``;
            ``f_cov`` is ``(N, N, num_outputs)`` if ``full_cov`` else
            ``(N, num_outputs)``.
        """
        _, Ku_tiled, Lu, _ = self.get_cholesky()

        Kuf = self.kern(self.feature, X)

        # alpha = Lu^{-1} Kuf (whitened) or Ku^{-1} Kuf (non-whitened).
        alpha = torch.linalg.solve_triangular(Lu, Kuf, upper=False)
        if not self.white:
            alpha = torch.linalg.solve_triangular(
                torch.transpose(Lu, 0, 1), alpha, upper=True
            )

        f_mean = torch.matmul(alpha.transpose(0, 1), self.q_mu)
        f_mean = f_mean + self.mean_function(X)

        alpha_tiled = torch.tile(alpha[None, :, :], [self.num_outputs, 1, 1])

        if self.white:
            f_cov = -torch.eye(self.num_inducing).to(X)[None, :, :]
        else:
            f_cov = -Ku_tiled

        if self.q_sqrt is not None:
            q_sqrt = torch.tril(self.q_sqrt)
            # Variational covariance of the inducing outputs.
            S = torch.matmul(q_sqrt, q_sqrt.transpose(1, 2))
            # Out-of-place add: in the whitened case f_cov has a leading axis
            # of size 1 and must broadcast up to num_outputs, which an
            # in-place `+=` cannot do.
            f_cov = f_cov + S

        f_cov = torch.matmul(f_cov, alpha_tiled)

        if full_cov:
            delta_cov = torch.matmul(alpha_tiled.transpose(1, 2), f_cov)
            Kff = self.kern(X)
        else:
            delta_cov = torch.sum(alpha_tiled * f_cov, dim=1)
            Kff = self.kern.diag(X)

        f_cov = Kff.unsqueeze(0) + delta_cov
        # Reverse the axis order so the output dimension comes last.
        f_cov = torch.permute(f_cov, list(range(f_cov.ndim))[::-1])

        return f_mean, f_cov

    def KL(self):
        """KL divergence from the variational distribution to the prior.

        The prior over the inducing outputs is ``N(0, I)`` when ``white`` is
        true and ``N(0, Ku)`` otherwise, summed over all outputs.

        Returns:
            Scalar tensor.
        """
        _, _, Lu, Lu_tiled = self.get_cholesky()
        q_sqrt = torch.tril(self.q_sqrt)

        # Constant term and -0.5 * log|S|.
        log_det_q = torch.sum(
            torch.log(torch.diagonal(q_sqrt, dim1=1, dim2=2) ** 2)
        )
        KL = -0.5 * self.num_outputs * self.num_inducing - 0.5 * log_det_q

        if not self.white:
            # 0.5 * log|Ku| per output.
            KL = KL + torch.sum(torch.log(torch.diag(Lu))) * self.num_outputs
            # 0.5 * tr(Ku^{-1} S).
            whitened_sqrt = torch.linalg.solve_triangular(
                Lu_tiled, q_sqrt, upper=False
            )
            KL = KL + 0.5 * torch.sum(torch.square(whitened_sqrt))
            # 0.5 * q_mu^T Ku^{-1} q_mu.
            Kinv_m = torch.cholesky_solve(self.q_mu, Lu)
            KL = KL + 0.5 * torch.sum(self.q_mu * Kinv_m)
        else:
            KL = KL + 0.5 * torch.sum(torch.square(q_sqrt))
            KL = KL + 0.5 * torch.sum(self.q_mu ** 2)

        return KL
