import math
from .lop import PositiveDefiniteMatrix
import torch as t
import torch.nn as nn
from torch.distributions import Normal, Gamma
from .wishart_dist import InverseWishart

"""
To give valid, comparable ELBO values, we need to ensure that there are no optimized model parameters ---
all parameters need to be given priors and approximate posteriors.

While we could just use factorised distributions, they work very poorly in combination with correlated priors.  
Intuitively, one key reason is that the approximate posterior can't capture the prior, and in general the approximate 
posterior can't capture prior-induced correlations.  This can give rise to poor evidence lower bounds, and even slow 
learning.

Instead, we define a small library of "probabilistic programming" primitives, which define approximate posteriors by 
conditioning on pseudo-data.
In particular, each class's forward takes latent variables describing the prior.
These could either be fixed, or they could be generated by another one of these modules, thus capturing prior-induced 
correlations.
Then, forward computes the approximate posterior by conditioning on pseudo data, or a minimal description thereof.
For instance, for a Gaussian, we only need the center and precision of the likelihood.

Note that all of these classes will sum out all but the first dimension, which is assumed to be the sample size, S.
"""


class VI_Scalar(nn.Module):
    """
    Base class for approximate posteriors over tensors of independent scalars.

    Subclasses must set ``self.shape`` (the event shape, excluding the sample
    dimension) and implement ``_forward(*args)``, returning a pair ``(P, Q)`` of
    ``torch.distributions`` objects: the prior and the approximate posterior.

    ``forward`` draws a reparameterised sample ``x`` from Q, stores
    ``log P(x) - log Q(x)`` summed over every dimension except the leading
    sample dimension in ``self.logpq`` (shape ``(S,)``), and returns ``x``.
    """
    def forward(self, S, *args):
        """
        Args:
            S: number of samples (size of the leading dimension of the result).
            *args: prior parameters forwarded unchanged to ``_forward``. Tensor
                arguments may have rank at most ``len(self.shape) + 1``; those
                with exactly one extra leading dimension are treated as already
                per-sample, and that dimension must have size ``S``.

        Returns:
            The sample ``x``, with leading dimension ``S``.
        """
        # If any argument already carries the sample dimension, Q's batch shape
        # includes it, so we must not prepend another sample dimension.
        sample_shape = t.Size([S])
        for arg in args:
            if isinstance(arg, t.Tensor):
                assert len(arg.shape) <= len(self.shape) + 1, \
                    f"argument rank {len(arg.shape)} exceeds event rank {len(self.shape)} + 1"
                if len(arg.shape) == len(self.shape) + 1:
                    assert arg.shape[0] == S, \
                        f"leading (sample) dimension {arg.shape[0]} != S={S}"
                    sample_shape = t.Size([])

        P, Q = self._forward(*args)

        x = Q.rsample(sample_shape=sample_shape)
        assert len(x.shape) == len(self.shape) + 1
        assert S == x.shape[0]
        logPQ = (P.log_prob(x) - Q.log_prob(x))

        # Reduce to one value per sample by summing all non-sample dimensions.
        sum_idxs = [*range(1, len(logPQ.shape))]
        if 0 < len(sum_idxs):
            logPQ = logPQ.sum(sum_idxs)

        assert logPQ.shape == t.Size([S])
        self.logpq = logPQ
        return x


class VI_Normal(VI_Scalar):
    """
    Gaussian approximate posterior formed by conditioning a Gaussian prior on a
    learned Gaussian pseudo-likelihood.

    The prior is ``N(mean, 1/prec)``. The pseudo-likelihood has a learned center
    ``m_like`` and a learned precision ``exp(log_L_like)``. The posterior
    precision is the sum of the two precisions, and the posterior mean is the
    precision-weighted average of the two centers.
    """
    def __init__(self, shape, init_log_prec=-2., init_mean=0.):
        super().__init__()
        self.shape = shape
        # log-precision of the pseudo-likelihood (exponentiated to stay positive)
        self.log_L_like = nn.Parameter(init_log_prec * t.ones(shape))
        # center of the pseudo-likelihood
        self.m_like = nn.Parameter(init_mean * t.ones(shape))

    def _forward(self, mean, prec):
        """Return the ``(prior, posterior)`` pair of ``Normal`` distributions."""
        like_prec = self.log_L_like.exp()
        post_prec = prec + like_prec
        post_mean = (prec * mean + like_prec * self.m_like) / post_prec

        # Multiplying by a 0-dim tensor lets ``prec`` be a plain Python number,
        # which ``t.rsqrt`` would not accept directly.
        unit = t.ones((), device=like_prec.device)
        prior = Normal(mean, t.rsqrt(unit * prec))
        posterior = Normal(post_mean, t.rsqrt(post_prec))
        return prior, posterior


class VI_Gamma(VI_Scalar):
    """
    Gamma approximate posterior formed by adding learned pseudo-likelihood
    statistics to the parameters of a Gamma prior.

    The prior is ``Gamma(shape, rate)``, using the concentration/rate
    parameterisation of ``torch.distributions.Gamma``. The learned positive
    increments ``exp(log_shape_like)`` and ``exp(log_rate_like)`` are added to
    the prior's concentration and rate, respectively, to give the posterior.
    """
    def __init__(self, shape, init_log_shape=-2., init_log_rate=-2.):
        super().__init__()
        # event shape of the sampled tensor (not the Gamma concentration)
        self.shape = shape
        # log of the concentration ("shape") contributed by the pseudo-likelihood
        self.log_shape_like = nn.Parameter(init_log_shape*t.ones(shape))
        # log of the rate contributed by the pseudo-likelihood
        self.log_rate_like = nn.Parameter(init_log_rate*t.ones(shape))

    def _forward(self, shape, rate):
        """
        Args:
            shape: prior concentration; must broadcast with ``self.shape``.
            rate: prior rate; must broadcast with ``self.shape``.

        Returns:
            The ``(prior, posterior)`` pair of ``Gamma`` distributions.
        """
        shape_like = self.log_shape_like.exp()
        rate_like = self.log_rate_like.exp()

        post_shape = shape + shape_like
        post_rate  = rate  + rate_like

        return Gamma(shape, rate), Gamma(post_shape, post_rate)


class VI_Scale(nn.Module):
    """
    Approximate posterior over a positive scale, parameterised via its
    precision: ``prec`` is Gamma-distributed and ``scale = 1/sqrt(prec)``,
    so ``scale**2 == 1/prec``.

    The ``log P - log Q`` term for the drawn precision is stored on
    ``self.gamma.logpq`` by the wrapped ``VI_Gamma``; this module does not set
    a ``logpq`` attribute of its own.
    """
    def __init__(self, shape, init_log_shape=2., init_scale=1.):
        super().__init__()
        self.shape = shape
        # Choosing log_rate = log_shape + 2*log(scale) makes the pseudo-likelihood's
        # concentration/rate ratio start at 1/init_scale**2.
        init_log_rate = init_log_shape + 2*math.log(init_scale)
        self.gamma = VI_Gamma(
            shape,
            init_log_shape=init_log_shape,
            init_log_rate=init_log_rate,
        )

    def forward(self, S, shape, scale):
        """
        Args:
            S: number of samples.
            shape: prior Gamma concentration for the precision.
            scale: prior scale; the prior Gamma rate is ``shape * scale**2``,
                which gives a prior mean precision of ``1/scale**2``.

        Returns:
            The sampled scale, ``1/sqrt(prec)``.
        """
        unit = t.ones((), device=self.gamma.log_shape_like.device)
        concentration = unit * shape
        rate = concentration * scale**2
        prec = self.gamma(S, concentration, rate)
        return t.rsqrt(prec)


class VI_InverseWishart(nn.Module):
    """
    Inverse-Wishart approximate posterior over a p x p positive-definite
    matrix, formed by conditioning the prior on learned, weighted
    pseudo-observations (the columns of ``self.y``).
    """
    def __init__(self, p):
        super().__init__()
        self.p = p

        # number of pseudo-observations (one per column of y)
        self.y_features = p

        # pseudo-observations: a (p, y_features) matrix, one p-vector per column
        self.y = nn.Parameter(t.randn(p, self.y_features))
        # log-weight per pseudo-observation. Each one adds its weight to the
        # degrees of freedom, so weight == 1 behaves like one standard
        # observation. Initialised to exp(-5), so the posterior starts close
        # to the prior.
        self.log_weight = nn.Parameter(-5*t.ones(self.y_features))

    def dd_kwargs(self):
        """Return ``device``/``dtype`` keyword arguments matching ``self.y``."""
        return {'device': self.y.device, 'dtype': self.y.dtype}

    def forward(self, Psi, nu, S=1):
        """
        Args:
            Psi: prior scale matrix; must provide ``.full()`` returning the dense
                (p, p) tensor (presumably a ``PositiveDefiniteMatrix`` -- confirm).
            nu: prior degrees of freedom.
            S: number of samples.

        Returns:
            The posterior sample ``x``. Also sets ``self.logpq`` to
            ``log P(x) - log Q(x)``, with shape ``(S,)``.
        """
        sample_shape = t.Size([S])

        weight = self.log_weight.exp()
        post_nu = nu + weight.sum()

        # Weighted sum of outer products: sum_f weight[f] * y[:, f] y[:, f]^T.
        A = (self.y * weight) @ self.y.transpose(-1, -2)
        # NOTE(review): the prior scale is multiplied by nu and the result divided
        # by post_nu, which implies InverseWishart takes a nu-normalised scale --
        # confirm against wishart_dist.
        post_Psi = (nu*Psi.full() + A)/post_nu
        post_Psi = PositiveDefiniteMatrix(post_Psi)

        P = InverseWishart(Psi, nu)
        Q = InverseWishart(post_Psi, post_nu)
        x, logQ = Q.rsample_log_prob(sample_shape=sample_shape)
        logP = P.log_prob(x)

        assert logP.shape == t.Size([S])
        assert logQ.shape == t.Size([S])

        self.logpq = logP-logQ

        return x


if __name__ == "__main__":
    # Scratch setup for manual experimentation: an inverse-Wishart posterior
    # module and a prior (nu, Psi) to feed it.
    p = 100
    iw1 = VI_InverseWishart(p)

    nu = 120.
    A = t.randn(p, p)
    Psi = A @ A.t() / math.sqrt(p)



