import jax
import jax.scipy as jsc
from jax import random, grad, value_and_grad, jit, vmap
import jax.numpy as jnp
from objective import VIBasic
from functools import partial
import numpy as np

class PVIRejection:
    """Monte Carlo predictive objective with a rejection-sampled gradient term.

    ``model`` must expose ``log_likelihoods(theta, y)`` and ``M(theta)``;
    ``posterior`` must expose ``sample(key, params, number=...)`` and
    ``get_grad(theta, params)``.  ``s`` is the number of posterior draws taken
    per proposal round.
    """

    # The retry loop in ``all`` performs at most this many additional rounds.
    _EXTRA_ROUND_CAP = 10000

    def __init__(self, model, posterior, y, s = 10, *args, **kwargs):
        """Store the collaborators; extra arguments are accepted and ignored."""
        self.model = model
        self.posterior = posterior
        self.y = y
        self.s = s
        # Deliberately not wrapped in jit: ``all`` branches in Python on
        # concrete array values inside its retry loop.
        self.vg_o = self.all
        self.M = self.model.M

    def _propose(self, key, params):
        """Draw one round of ``s`` posterior samples and evaluate them.

        Returns the log-likelihoods, the vmapped ``posterior.get_grad``
        values, the ``M`` values, and the PRNG key to carry forward.
        """
        draw_key, next_key = random.split(key)
        per_draw_keys = random.split(draw_key, (self.s,))

        def single_draw(k, p):
            theta = self.posterior.sample(k, p, number=len(self.y))
            loglik = self.model.log_likelihoods(theta, self.y)
            score = vmap(self.posterior.get_grad, in_axes=(0, None))(theta, p)
            return loglik, score, self.M(theta)

        loglik, score, bound = vmap(single_draw, in_axes=(0, None))(per_draw_keys, params)
        # A 3-D result is reduced to its first slice along axis 1
        # (presumably a singleton axis added by the model -- confirm).
        if loglik.ndim > 2:
            loglik = loglik[:, 0, :]
        return loglik, score, bound, next_key

    def _accept(self, key, loglik, log_bound):
        """Run the uniform accept/reject test for one round of draws.

        Returns the boolean acceptance map of shape ``(s, len(y))`` and the
        PRNG key to carry forward.
        """
        u_key, next_key = random.split(key)
        u = random.uniform(u_key, (self.s, len(self.y)))
        return jnp.log(u) + log_bound < loglik, next_key

    def all(self, key, params):
        """Return ``(pred_obj, pred_grad)`` for the given PRNG key and params.

        ``pred_obj`` sums, over data points, the log of the average likelihood
        across the ``s`` draws of the first round.  ``pred_grad`` sums, over
        data points, the mean of ``get_grad`` across the draws accepted for
        that point, rescaled by ``len(y)`` over the number of points that
        obtained at least one accepted draw.
        """
        n_obs = len(self.y)

        loglik, score, bound, key = self._propose(key, params)
        # The envelope constant is fixed from the first round only.
        log_bound = jnp.max(jnp.log(bound))
        pred_obj = jnp.sum(jsc.special.logsumexp(loglik, axis=0) - jnp.log(self.s))

        accepted, key = self._accept(key, loglik, log_bound)
        # Column vector: does data point n have at least one accepted draw?
        covered = jnp.any(accepted, axis=0)[:, None]

        for _round in range(self._EXTRA_ROUND_CAP):
            if jnp.sum(covered) >= n_obs:
                break
            loglik_retry, score_retry, _bound_retry, key = self._propose(key, params)
            accepted_retry, key = self._accept(key, loglik_retry, log_bound)
            covered_retry = jnp.any(accepted_retry, axis=0)[:, None]
            # Points covered for the first time in this round take over this
            # round's acceptance pattern and gradients wholesale.
            fresh = covered_retry & ~covered
            covered = covered | covered_retry
            accepted = accepted | (accepted_retry & fresh[:, 0])
            score = score * (~fresh) + score_retry * fresh

        # Uniform weights over the accepted draws of each data point; the
        # clamp avoids dividing by zero for points that were never covered.
        mask = accepted[:, :, None]
        weights = mask / jnp.maximum(jnp.sum(mask, axis=0), 1)
        per_point = jnp.sum(weights * score, axis=0)
        scale = n_obs / jnp.maximum(jnp.sum(covered), 1)
        pred_grad = scale * jnp.sum(covered * per_point, axis=0)
        return pred_obj, pred_grad

    def value_and_grad(self, key, params):
        """Evaluate the estimator through the callable stored in ``vg_o``."""
        return self.vg_o(key, params)