
import torch
import torch.nn as nn
import abc


class Likelihood(nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for likelihoods p(Y | F).

    Each public method validates the trailing dimension of its tensor
    arguments (and, where applicable, of its results) against the sizes given
    at construction, then delegates to a ``_``-prefixed hook. Subclasses must
    implement ``_log_prob``, ``_predict_mean_and_var``,
    ``_predict_log_density`` and ``_variational_expectations``;
    ``_conditional_mean`` and ``_conditional_variance`` are optional and raise
    ``NotImplementedError`` unless overridden.

    Validation is done with ``assert`` and is therefore skipped when Python
    runs with ``-O``.

    Args:
        input_dim: Size of the input space; stored but not used by this class.
        latent_dim: Required size of the last dimension of ``F`` / ``Fmu`` /
            ``Fvar``.
        observation_dim: Required size of the last dimension of ``Y`` and of
            the tensors returned by the conditional / predictive hooks.
    """

    def __init__(self, input_dim: int, latent_dim: int, observation_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.observation_dim = observation_dim

    def _check_last_dims_valid(self, F, Y):
        """Assert that ``F`` and ``Y`` both have the expected trailing dimension."""
        self._check_latent_dims(F)
        self._check_data_dims(Y)

    def _check_latent_dims(self, F):
        """Assert that ``F.shape[-1] == self.latent_dim``."""
        assert F.shape[-1] == self.latent_dim, (
            f"expected latent tensor with last dimension {self.latent_dim}, "
            f"got shape {tuple(F.shape)}"
        )

    def _check_data_dims(self, Y):
        """Assert that ``Y.shape[-1] == self.observation_dim``."""
        assert Y.shape[-1] == self.observation_dim, (
            f"expected data tensor with last dimension {self.observation_dim}, "
            f"got shape {tuple(Y.shape)}"
        )

    def _check_same_shape(self, Fmu, Fvar):
        """Assert that ``Fmu`` and ``Fvar`` have identical shapes."""
        assert Fmu.shape == Fvar.shape, (
            f"Fmu and Fvar must have the same shape, got {tuple(Fmu.shape)} "
            f"and {tuple(Fvar.shape)}"
        )

    def log_prob(self, F, Y):
        """Validate dims of ``F`` and ``Y`` and return ``self._log_prob(F, Y)``."""
        self._check_last_dims_valid(F, Y)
        return self._log_prob(F, Y)

    @abc.abstractmethod
    def _log_prob(self, F, Y):
        """Compute log p(Y | F); must be implemented by subclasses."""
        raise NotImplementedError

    def conditional_mean(self, X, F):
        """Validate dims and return ``self._conditional_mean(F)``.

        Args:
            X: Accepted but not used by this method.
            F: Latent values; last dimension must equal ``latent_dim``.

        Returns:
            The hook's result, whose last dimension is checked against
            ``observation_dim``.
        """
        self._check_latent_dims(F)
        expected_Y = self._conditional_mean(F)
        self._check_data_dims(expected_Y)
        return expected_Y

    def _conditional_mean(self, F):
        """Optional hook for the conditional mean; not implemented by default."""
        raise NotImplementedError

    def conditional_variance(self, F):
        """Validate dims and return ``self._conditional_variance(F)``.

        The last dimension of ``F`` is checked against ``latent_dim`` and that
        of the result against ``observation_dim``.
        """
        self._check_latent_dims(F)
        var_Y = self._conditional_variance(F)
        self._check_data_dims(var_Y)
        return var_Y

    def _conditional_variance(self, F):
        """Optional hook for the conditional variance; not implemented by default."""
        raise NotImplementedError

    def predict_mean_and_var(self, Fmu, Fvar):
        """Validate dims and return ``(mu, var)`` from ``_predict_mean_and_var``.

        ``Fmu`` and ``Fvar`` must end in ``latent_dim``; both returned tensors
        must end in ``observation_dim``.
        """
        self._check_latent_dims(Fmu)
        self._check_latent_dims(Fvar)
        mu, var = self._predict_mean_and_var(Fmu, Fvar)
        self._check_data_dims(mu)
        self._check_data_dims(var)
        return mu, var

    @abc.abstractmethod
    def _predict_mean_and_var(self, Fmu, Fvar):
        """Compute predictive moments; must be implemented by subclasses."""
        raise NotImplementedError

    def predict_log_density(self, Fmu, Fvar, Y):
        """Validate shapes/dims and return ``self._predict_log_density(Fmu, Fvar, Y)``.

        ``Fmu`` and ``Fvar`` must have identical shapes.
        """
        self._check_same_shape(Fmu, Fvar)
        self._check_last_dims_valid(Fmu, Y)
        return self._predict_log_density(Fmu, Fvar, Y)

    @abc.abstractmethod
    def _predict_log_density(self, Fmu, Fvar, Y):
        """Compute the predictive log density; must be implemented by subclasses."""
        raise NotImplementedError

    def variational_expectations(self, Fmu, Fvar, Y):
        """Validate shapes/dims and return ``self._variational_expectations(Fmu, Fvar, Y)``.

        ``Fmu`` and ``Fvar`` must have identical shapes.
        """
        self._check_same_shape(Fmu, Fvar)
        self._check_last_dims_valid(Fmu, Y)
        return self._variational_expectations(Fmu, Fvar, Y)

    @abc.abstractmethod
    def _variational_expectations(self, Fmu, Fvar, Y):
        """Compute variational expectations; must be implemented by subclasses."""
        raise NotImplementedError


class Gaussian(nn.Module):
    """Gaussian likelihood with a learnable noise-variance parameter.

    The variance is stored unconstrained in ``self.variance``; every method
    uses ``variance * (variance > 0)``, so non-positive values act as zero.

    Args:
        variance: Initial noise variance.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, variance=1.0, **kwargs):
        super().__init__(**kwargs)
        self.variance = nn.Parameter(torch.as_tensor(variance))

    def _clamped_variance(self):
        """Return the noise variance with non-positive values replaced by zero."""
        # Multiplying by the boolean mask zeroes both the value and the
        # gradient wherever variance <= 0.
        return self.variance * (self.variance > 0)

    def predict_mean_and_var(self, Fmu, Fvar):
        """Return ``(copy of Fmu, Fvar + noise variance)``."""
        return torch.clone(Fmu), Fvar + self._clamped_variance()

    def predict_log_density(self, Fmu, Fvar, Y):
        """Return log N(Y | Fmu, Fvar + noise variance), summed over the last dimension."""
        var = Fvar + self._clamped_variance()
        Y = torch.as_tensor(Y)
        feat = -0.5 * (torch.log(torch.as_tensor(2 * torch.pi)) + torch.log(var) + torch.square(Y - Fmu) / var)
        return torch.sum(feat, dim=-1)

    def variational_expectations(self, Fmu, Fvar, Y):
        """Return the expected Gaussian log density under N(Fmu, Fvar), summed over the last dim.

        Computes ``-0.5*log(2*pi) - 0.5*log(v) - 0.5*((Y - Fmu)**2 + Fvar) / v``
        with ``v`` the clamped noise variance. A 4-D ``Fvar`` is treated as
        full covariances laid out ``S x N x N x D``; only their diagonals
        (the marginal variances) are used.
        """
        variance = self._clamped_variance()
        Y = torch.as_tensor(Y)
        if Fvar.ndim == 4:
            # torch.diagonal appends the extracted diagonal as the LAST axis,
            # giving S x D x N; transpose back to S x N x D so it lines up
            # with Fmu and Y instead of silently broadcasting to S x N x N.
            Fvar = torch.diagonal(Fvar, dim1=1, dim2=2).transpose(-1, -2)
        return torch.sum(
            -0.5 * torch.log(torch.as_tensor(2 * torch.pi))
            - 0.5 * torch.log(variance)
            - 0.5 * ((Y - Fmu) ** 2 + Fvar) / variance,
            dim=-1,
        )


class BroadcastingLikelihood(Likelihood):
    """Wrap a likelihood so it can be evaluated on sample-batched latents.

    Latent tensors (``F``, ``Fmu``, ``Fvar``) carry a leading sample axis
    ``S``, while data tensors such as ``Y`` have no sample axis and are shared
    across samples.

    With ``needs_broadcasting`` False (the default), each data tensor just
    gets a leading singleton axis and the wrapped likelihood is relied on to
    broadcast it against the sample axis. With ``needs_broadcasting`` True,
    latents are treated as ``S x N x D``: data tensors are tiled ``S`` times,
    everything is flattened to ``(S*N) x ...`` before the call, and the
    result is reshaped back to ``S x N x ...`` afterwards.

    Args:
        likelihood: The likelihood to wrap. It must expose ``input_dim``,
            ``latent_dim`` and ``observation_dim`` attributes plus the public
            ``Likelihood`` methods delegated to below.
    """

    def __init__(self, likelihood):
        super().__init__(likelihood.input_dim, likelihood.latent_dim, likelihood.observation_dim)
        self.likelihood = likelihood
        # Default: let the wrapped likelihood broadcast over the sample axis.
        self.needs_broadcasting = False

    def _broadcast(self, f, vars_SND, vars_ND):
        """Call ``f(latent_tensors, data_tensors)`` with data aligned to the sample axis.

        Args:
            f: Callable taking ``(list_of_latent_tensors, list_of_data_tensors)``.
            vars_SND: Latent tensors with a leading sample axis.
            vars_ND: Data tensors without a sample axis.

        Returns:
            Whatever ``f`` returns (non-tiled path). In the tiled path a tuple
            result becomes a tuple of ``S x N x -1`` tensors, and a single
            tensor is reshaped to ``S x N`` followed by whatever trailing
            dimensions ``f`` produced.
        """
        if not self.needs_broadcasting:
            # Prepend a singleton sample axis and let ``f`` broadcast it.
            return f(vars_SND, [torch.unsqueeze(v, 0) for v in vars_ND])

        S, N, D = [vars_SND[0].shape[i] for i in range(3)]
        # Repeat every data tensor once per sample: N x ... -> S x N x ...
        vars_tiled = [torch.tile(x[None, :, :], [S, 1, 1]) for x in vars_ND]

        flattened_SND = [torch.reshape(x, [S * N, D]) for x in vars_SND]
        flattened_tiled = [torch.reshape(x, [S * N, -1]) for x in vars_tiled]

        flattened_result = f(flattened_SND, flattened_tiled)
        if isinstance(flattened_result, tuple):
            # Return a tuple, matching what the non-tiled path passes through.
            return tuple(torch.reshape(x, [S, N, -1]) for x in flattened_result)
        # Unflatten only the leading S*N axis so results that keep a trailing
        # dimension (e.g. conditional moments with D outputs) reshape cleanly.
        return torch.reshape(flattened_result, (S, N) + tuple(flattened_result.shape[1:]))

    def _variational_expectations(
            self, Fmu, Fvar, Y
    ):
        f = lambda vars_SND, vars_ND: self.likelihood.variational_expectations(
            vars_SND[0], vars_SND[1], vars_ND[0])
        return self._broadcast(f, [Fmu, Fvar], [Y])

    def _log_prob(self, F, Y):
        # The Likelihood interface names this method ``log_prob``.
        f = lambda vars_SND, vars_ND: self.likelihood.log_prob(vars_SND[0], vars_ND[0])
        return self._broadcast(f, [F], [Y])

    def _conditional_mean(self, F):
        # Likelihood.conditional_mean has signature (X, F) and does not use X;
        # pass None explicitly so that F binds to the F parameter.
        f = lambda vars_SND, vars_ND: self.likelihood.conditional_mean(None, vars_SND[0])
        return self._broadcast(f, [F], [])

    def _conditional_variance(self, F):
        f = lambda vars_SND, vars_ND: self.likelihood.conditional_variance(vars_SND[0])
        return self._broadcast(f, [F], [])

    def _predict_mean_and_var(
            self, Fmu, Fvar
    ):
        f = lambda vars_SND, vars_ND: self.likelihood.predict_mean_and_var(vars_SND[0], vars_SND[1])
        return self._broadcast(f, [Fmu, Fvar], [])

    def _predict_log_density(
            self, Fmu, Fvar, Y
    ):
        # The Likelihood interface names this method ``predict_log_density``.
        f = lambda vars_SND, vars_ND: self.likelihood.predict_log_density(vars_SND[0], vars_SND[1], vars_ND[0])
        return self._broadcast(f, [Fmu, Fvar], [Y])


def reparameterize(mean, var, z, full_cov=False, jitter=1e-6):
    """Form ``mean + sqrt(var) * z`` (reparameterization trick).

    Args:
        mean: Mean tensor; indexed as ``S x N x D`` in the full-covariance case.
        var: Variance, or ``None`` to return ``mean`` unchanged. When
            ``full_cov`` is falsy it is an elementwise variance broadcastable
            with ``mean``; otherwise a covariance laid out ``S x N x N x D``
            (one ``N x N`` matrix per sample and output dimension).
        z: Noise tensor (presumably standard normal) shaped like ``mean``.
        full_cov: Whether ``var`` holds full covariances. Any falsy value
            (``False``, ``0``, ``None``, NumPy ``False``) selects the
            elementwise path.
        jitter: Added to the variance, or to the covariance diagonal, to keep
            the square root / Cholesky factorization numerically stable.

    Returns:
        A tensor with the same shape as ``mean``.
    """
    if var is None:
        return mean

    # Truthiness test rather than ``is False`` so 0 / NumPy bools also work.
    if not full_cov:
        return mean + z * (var + jitter) ** 0.5

    N = mean.shape[1]
    mean_SDN = torch.permute(mean, (0, 2, 1))  # SND -> SDN
    var_SDNN = torch.permute(var, (0, 3, 1, 2))  # SNND -> SDNN
    I = jitter * torch.eye(N).to(mean)[None, None, :, :]  # 11NN, matches mean's dtype/device
    chol = torch.linalg.cholesky(var_SDNN + I)  # SDNN
    z_SDN1 = torch.permute(z, (0, 2, 1))[:, :, :, None]  # SND -> SDN1
    f_SDN = mean_SDN + torch.matmul(chol, z_SDN1)[:, :, :, 0]  # SDN
    return torch.permute(f_SDN, (0, 2, 1))  # SDN -> SND

