from functools import partial

from jax import Array
from jax import jit
import jax.numpy as jnp
import jax.random as random

from .gaussian_density_ratios import IsotropicGaussianDensityRatio
from .util import trunc_gumbel


def slow_gauss_a_star(seed: int,
                      gauss_dr: IsotropicGaussianDensityRatio,
                      max_iter: int = 1_000):
  """Search sequentially for the best-scoring sample drawn via `gauss_dr`.

  A plain Python loop (not jitted) that, on iteration `k`:

    1. derives two PRNG keys from `fold_in(PRNGKey(seed), k)`;
    2. draws a new value with `trunc_gumbel`, passing the previous draw as
       `bound` (the initial bound is `+inf`);
    3. draws a candidate `x` with `gauss_dr.sample_p`;
    4. scores the candidate as `draw + log(gauss_dr.ratio(x))`;
    5. keeps the candidate if its score strictly exceeds the best so far;
    6. returns once the best score is at least `draw + gauss_dr.inf_div`.

  NOTE(review): the name suggests this is a reference implementation of A*
  sampling / A* coding, with `inf_div` acting as the log of the supremum of
  the density ratio -- confirm against `IsotropicGaussianDensityRatio`.

  Args:
    seed: Integer seed used to build the base PRNG key.
    gauss_dr: Object providing `sample_p(key, shape=...)`, `ratio(x)` and
      the attribute `inf_div`.
    max_iter: Maximum number of loop iterations to attempt.

  Returns:
    A tuple `(best_x, best_k)`: the retained candidate and the iteration
    index (a Python int, zero-based) at which it was drawn.

  Raises:
    ValueError: If the stopping condition is not met within `max_iter`
      iterations.
  """
  base_key = random.PRNGKey(seed)

  # State carried from one iteration to the next.
  gumbel_draw = jnp.inf          # most recent draw; fed back as the next bound
  incumbent_score = -jnp.inf     # best score seen so far
  incumbent_x = None
  incumbent_step = None

  for step in range(max_iter):
    # Keys depend only on (seed, step), so each iteration is reproducible.
    step_key = random.fold_in(base_key, step)
    gumbel_key, sample_key = random.split(step_key)

    # presumably a Gumbel(loc=0) truncated above at the previous draw, making
    # the sequence of draws non-increasing -- confirm in `.util.trunc_gumbel`.
    gumbel_draw = trunc_gumbel(gumbel_key, shape=(), loc=0., bound=gumbel_draw)

    candidate = gauss_dr.sample_p(sample_key, shape=())
    log_ratio = jnp.log(gauss_dr.ratio(candidate))
    candidate_score = gumbel_draw + log_ratio

    # Strict improvement only: on ties the earlier candidate is kept.
    if incumbent_score < candidate_score:
      incumbent_score, incumbent_x, incumbent_step = (
          candidate_score, candidate, step)

    # Stop once the incumbent's score reaches the current draw plus inf_div.
    if incumbent_score >= gumbel_draw + gauss_dr.inf_div:
      return incumbent_x, incumbent_step

  # The loop never breaks, so reaching this point means the budget ran out.
  raise ValueError("did not terminate!")
