import jax
import jax.scipy as jsc
from jax import random, grad, value_and_grad, jit
import jax.numpy as jnp
from objective import VIBasic

class PVIBasic:
    """Predictive variational objective built on a Gaussian predictive density.

    The objective is the summed Gaussian log-density of the observations
    ``y``. Its mean is the variational posterior mean. Its scale combines the
    posterior scale and the likelihood scale in quadrature,
    ``sqrt(s_lik**2 + s_post**2)``.

    Args:
        model: object exposing ``likelihood_parameters(mean)``, which returns
            a ``(mean, scale)`` pair.
        posterior: object exposing ``posterior_parameters(params)``, which
            returns a ``(mean, scale)`` pair.
        y: observations evaluated under the Gaussian log-pdf.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, model, posterior, y, *args, **kwargs):
        self.model, self.posterior, self.y = model, posterior, y
        # Differentiate with respect to ``params`` (positional argument 1 of
        # the bound ``objective``), then compile the value-and-gradient pair.
        # NOTE: ``self.y`` and the model/posterior are read while the jitted
        # function is traced. Reassigning them later may not be reflected in
        # calls that reuse an already-compiled trace.
        objective_vg = value_and_grad(self.objective, argnums=1)
        self.vg_o = jit(objective_vg)

    def objective(self, key, params):
        """Return the summed predictive Gaussian log-density of ``self.y``.

        Args:
            key: accepted for interface compatibility; not used here.
            params: variational parameters forwarded to the posterior.

        Returns:
            The scalar ``sum(log N(y | m, s))``. Here ``m`` is the posterior
            mean and ``s = sqrt(s_lik**2 + s_post**2)``.
        """
        post_mean, post_scale = self.posterior.posterior_parameters(params)
        # NOTE(review): the likelihood mean is discarded and the posterior
        # mean is used as the predictive mean. Confirm this is intended.
        _lik_mean, lik_scale = self.model.likelihood_parameters(post_mean)
        predictive_var = lik_scale * lik_scale + post_scale * post_scale
        log_dens = jsc.stats.norm.logpdf(self.y, post_mean, jnp.sqrt(predictive_var))
        return jnp.sum(log_dens)

    def value_and_grad(self, key, params):
        """Evaluate ``objective`` and its gradient with respect to ``params``.

        Returns:
            The ``(value, grad)`` pair produced by the jitted
            ``jax.value_and_grad`` wrapper built in ``__init__``.
        """
        vg_fn = self.vg_o
        return vg_fn(key, params)