import numpy as np
import torch.nn as nn
from EDGP.sampling import *


class Layer(nn.Module):
    """Base class for a GP layer.

    Subclasses implement ``conditional_ND``. The pathwise sampler
    ``sample_from_conditional`` additionally reads ``kern``, ``feature``,
    ``q_mu``, ``q_sqrt``, ``prior``, ``num_prio`` and ``mean_function`` (and,
    if present, ``white``) from the instance; this base class does not define
    them itself.
    """

    def __init__(self, input_prop_dim=None, **kwargs):
        super().__init__(**kwargs)
        # Stored for callers/subclasses; not read anywhere in this class.
        self.input_prop_dim = input_prop_dim

    def conditional_ND(self, X, full_cov=False):
        """Return ``(mean, var)`` of the layer output at 2-D inputs ``X``.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def KL(self):
        """Return this layer's KL contribution; the base class contributes ``0``."""
        return torch.as_tensor(0.0)

    def conditional_SND(self, X, full_cov=False):
        """Evaluate ``conditional_ND`` on a batch of samples ``X`` of shape ``[S, N, D]``.

        With ``full_cov=True`` each sample is processed separately through
        ``torch.vmap`` (full covariances are per-sample) and a ``(mean, var)``
        tuple is returned. Otherwise the sample and data axes are flattened,
        evaluated in a single call, and a list ``[mean, var]`` is returned with
        each entry reshaped to ``[S, N, num_outputs]``.
        """
        if full_cov:
            def per_sample(a):
                return self.conditional_ND(a, full_cov=full_cov)

            mean, var = torch.vmap(per_sample)(X)
            return mean, var

        S, N, D = X.shape[:3]
        mean, var = self.conditional_ND(torch.reshape(X, [S * N, D]))
        num_outputs = mean.shape[-1]
        return [torch.reshape(m, [S, N, num_outputs]) for m in (mean, var)]

    def sample_from_conditional(self, X, z=None, full_cov=False, last_layer=False):
        """Sample the layer output at ``X``, or return its marginals for the last layer.

        For ``last_layer=True`` this returns ``(None, mean, var)`` from
        ``conditional_SND``. Otherwise it draws pathwise posterior samples
        (Matheron's rule)::

            f = f_prior(X) + K_xz K_zz^{-1} (u - u_prior(Z))

        where ``u`` is drawn from the variational distribution and
        ``f_prior``/``u_prior`` are draws from ``self.prior``. It returns
        ``(samples, None, None)``.

        ``X`` must be 3-D ``[S, N, D]`` (the batched triangular solves and the
        ``'sxz,szd->sxd'`` contraction rely on a leading sample axis). ``S`` has
        to match the ``self.num_prio`` samples drawn for ``u``. ``z`` is
        unused, and ``full_cov`` only matters when ``last_layer`` is set.
        """
        if last_layer:
            mean, var = self.conditional_SND(X, full_cov=full_cov)
            return None, mean, var

        feature = self.feature
        q_mu = self.q_mu.transpose(0, 1)
        q_sqrt = torch.tril(self.q_sqrt)

        jitter = torch.eye(feature.shape[0]).to(feature) * 1e-6
        Lu = torch.linalg.cholesky(self.kern(feature) + jitter)
        K_xz = self.kern(X, feature)

        # alpha = K_zz^{-1} K_zx via two triangular solves, batched over samples.
        alpha = torch.linalg.solve_triangular(Lu, K_xz.transpose(1, 2), upper=False)
        alpha = torch.linalg.solve_triangular(Lu.transpose(0, 1), alpha, upper=True)

        # Prior draws at X (with mean function) and at Z (without). re_init is
        # only called after both draws, presumably so both use the same random
        # basis — confirm against the prior implementation.
        self.prior.set_parms(self.kern.lengthscales, self.kern.variance)
        f_p = self.prior(X, mean_func=self.mean_function)
        u_p = self.prior(feature)
        self.prior.re_init()

        u_q = sample_from_gaussian(q_sqrt, num_samples=self.num_prio, mean=q_mu)  # [sample seqz hidden]
        if getattr(self, "white", False):
            # Whitened parameterisation: q describes v with u = Lu v, so map the
            # draw back to inducing-output space before conditioning.
            u_q = torch.matmul(Lu, u_q)

        samples = f_p + contract('sxz,szd->sxd', alpha.transpose(1, 2), u_q - u_p)  # [sample seqx hidden]
        return samples, None, None


class EDGPLayer(Layer):
    """Sparse variational GP layer with trainable inducing inputs and a random-Fourier prior.

    For each output column ``d`` the variational distribution over the inducing
    outputs is ``N(q_mu[:, d], q_sqrt[d] @ q_sqrt[d].T)``. With ``white=True``
    these parameters describe whitened variables ``v`` with ``u = chol(Kzz) v``.
    That is how ``conditional_ND`` and ``KL`` below use them (prior covariance
    ``I`` instead of ``Kzz``).

    Args:
        num_prio: number of prior samples; passed as the prior's ``sample_shape``.
        kern: kernel; called as ``kern(A)``, ``kern(A, B)`` and ``kern.diag(A)``.
        Z: inducing inputs (presumably ``[M, D_in]``); becomes the trainable
            ``feature`` parameter.
        num_outputs: number of output columns.
        mean_function: callable whose value at ``X`` is added to the predictive mean.
        white: use the whitened parameterisation.
        input_prop_dim: forwarded to ``Layer``.
        num_bases: number of random Fourier bases of the prior (2048, the
            previously hard-coded value, by default).
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self, num_prio, kern, Z, num_outputs, mean_function,
                 white=False, input_prop_dim=None, num_bases=2048, **kwargs):

        super().__init__(input_prop_dim=input_prop_dim, **kwargs)
        self.num_inducing = Z.shape[0]
        self.num_prio = num_prio

        # Variational mean of the inducing outputs: [M, num_outputs].
        self.q_mu = nn.Parameter(torch.zeros((self.num_inducing, num_outputs)))

        if white:
            # Small scaled identity per output. Built with torch in q_mu's dtype:
            # a NumPy-built tensor would be float64, and torch.matmul/einsum do
            # not promote mixed dtypes.
            eye = torch.eye(self.num_inducing, dtype=self.q_mu.dtype)
            q_sqrt = eye[None, :, :].repeat(num_outputs, 1, 1) * 1e-3
        else:
            # Initialise q(u) to the prior: q_sqrt = chol(Kzz + jitter) for every output.
            Ku = kern(Z)
            Lu = torch.linalg.cholesky(Ku + torch.eye(Z.shape[0]).to(Ku) * 1e-6)
            q_sqrt = torch.tile(Lu[None, :, :], [num_outputs, 1, 1]).detach()
        # Registered before `feature`, matching the original parameter order.
        self.q_sqrt = nn.Parameter(q_sqrt)

        self.feature = nn.Parameter(Z)

        self.kern = kern
        self.mean_function = mean_function

        self.num_outputs = num_outputs
        self.white = white

        self.prior = Prior_random_fourier(
            sample_shape=[num_prio],
            num_bases=num_bases,
            in_dims=Z.shape[-1],
        )

    def get_cholesky(self):
        """Return ``(Ku, Ku_tiled, Lu, Lu_tiled)`` for the inducing inputs.

        ``Ku = kern(feature) + 1e-6 I`` and ``Lu = chol(Ku)``. The ``*_tiled``
        variants repeat them to ``[num_outputs, M, M]``.
        """
        Ku = self.kern(self.feature) + torch.eye(self.feature.shape[0]).to(self.feature) * 1e-6
        Lu = torch.linalg.cholesky(Ku)
        Ku_tiled = torch.tile(Ku[None, :, :], [self.num_outputs, 1, 1])
        Lu_tiled = torch.tile(Lu[None, :, :], [self.num_outputs, 1, 1])

        return Ku, Ku_tiled, Lu, Lu_tiled

    def conditional_ND(self, X, full_cov=False):
        """Predictive mean and (co)variance of q(f) at 2-D inputs ``X`` with ``N`` rows.

        Returns:
            f_mean: ``[N, num_outputs]``, i.e. ``alpha^T q_mu + mean_function(X)``.
            f_cov: ``[N, N, num_outputs]`` if ``full_cov`` else ``[N, num_outputs]``,
                i.e. ``Kff + alpha^T (S - K) alpha`` per output, with ``K = I``
                when whitened and ``Kzz`` otherwise.
        """
        _, Ku_tiled, Lu, _ = self.get_cholesky()

        Kuf = self.kern(self.feature, X)

        # alpha = Lu^{-1} Kuf (whitened) or Kzz^{-1} Kuf (non-whitened).
        alpha = torch.linalg.solve_triangular(Lu, Kuf, upper=False)
        if not self.white:
            alpha = torch.linalg.solve_triangular(torch.transpose(Lu, 0, 1), alpha, upper=True)

        f_mean = torch.matmul(alpha.transpose(0, 1), self.q_mu)
        f_mean = f_mean + self.mean_function(X)

        alpha_tiled = torch.tile(alpha[None, :, :], [self.num_outputs, 1, 1])

        # Start the middle term as -K (prior covariance of the inducing variables).
        if self.white:
            f_cov = -torch.eye(self.num_inducing).to(X)[None, :, :]
        else:
            f_cov = -Ku_tiled

        if self.q_sqrt is not None:
            q_sqrt = torch.tril(self.q_sqrt)
            S = torch.matmul(q_sqrt, q_sqrt.transpose(1, 2))  # variational covariance of the inducing variables
            # Out-of-place on purpose: in the whitened case f_cov is [1, M, M], and
            # an in-place add cannot broadcast its output up to S's [num_outputs, M, M].
            f_cov = f_cov + S

        f_cov = torch.matmul(f_cov, alpha_tiled)

        if full_cov:
            delta_cov = torch.matmul(alpha_tiled.transpose(1, 2), f_cov)
            Kff = self.kern(X)
        else:
            delta_cov = torch.sum(alpha_tiled * f_cov, dim=1)
            Kff = self.kern.diag(X)

        f_cov = Kff.unsqueeze(0) + delta_cov
        # Move the output axis last: [D, N] -> [N, D] or [D, N, N] -> [N, N, D].
        f_cov = torch.permute(f_cov, list(range(f_cov.ndim))[::-1])

        return f_mean, f_cov

    def KL(self):
        """KL(q(u) || p(u)) summed over output columns.

        Each column contributes
        ``0.5 * (tr(K^{-1} S) + m^T K^{-1} m - M + log|K| - log|S|)``, with
        ``K = Kzz + jitter`` (or ``I`` when whitened), ``m = q_mu[:, d]`` and
        ``S = q_sqrt[d] q_sqrt[d]^T``.
        """
        _, _, Lu, Lu_tiled = self.get_cholesky()
        q_sqrt = torch.tril(self.q_sqrt)

        KL = -0.5 * self.num_outputs * self.num_inducing

        # -0.5 * log|S| for every output, from the diagonal of its Cholesky factor.
        KL -= 0.5 * torch.sum(torch.log(torch.diagonal(q_sqrt, dim1=1, dim2=2) ** 2))

        if not self.white:
            # 0.5 * log|K| per output.
            KL += torch.sum(torch.log(torch.diag(Lu))) * self.num_outputs
            # 0.5 * tr(K^{-1} S) = 0.5 * ||Lu^{-1} q_sqrt||_F^2.
            KL += 0.5 * torch.sum(torch.square(torch.linalg.solve_triangular(Lu_tiled, q_sqrt, upper=False)))

            # 0.5 * m^T K^{-1} m, summed over outputs.
            Kinv_m = torch.cholesky_solve(self.q_mu, Lu)
            KL += 0.5 * torch.sum(self.q_mu * Kinv_m)
        else:
            KL += 0.5 * torch.sum(torch.square(q_sqrt))
            KL += 0.5 * torch.sum(self.q_mu ** 2)

        return KL
