import functools
import jax
import jax.numpy as jnp  
from typing import Dict

class UncertaintyLearner:
    """Base interface for components that produce uncertainty estimates.

    The hook methods ``step``, ``eval`` and ``uncertainty`` are placeholders
    that do nothing and return ``None``; subclasses override the ones they
    need. ``save`` and ``restore`` read and write the ``_state`` attribute.
    """

    def step(self):
        """Update hook; a no-op in the base class."""

    def eval(self, x: jnp.ndarray, a: jnp.ndarray, info) -> Dict:
        """Evaluation hook; returns ``None`` in the base class.

        NOTE(review): the annotation promises a ``Dict``, but the default
        implementation returns ``None``; subclasses are expected to override.
        """

    def uncertainty(self, x: jnp.ndarray) -> jnp.ndarray:
        """Uncertainty hook; returns ``None`` in the base class."""

    def save(self):
        """Return whatever object is currently held in ``self._state``.

        Raises:
            AttributeError: if ``_state`` has never been assigned (e.g.
                ``restore`` was not called and no subclass set it).
        """
        current_state = self._state
        return current_state

    def restore(self, state):
        """Store ``state`` as ``self._state``. The object is kept by reference, not copied."""
        self._state = state


class UncertaintyWrapper(UncertaintyLearner):
    """Scale a per-state uncertainty by the inverse square root of ``behavior_fn``.

    ``uncertainty(x)`` averages ``uncertainty_fn(x)`` over its last axis
    (keeping that axis with size 1). It then divides by
    ``sqrt(behavior_fn(x)) + eps``, broadcasting the size-1 axis against
    ``behavior_fn``'s output.

    Args:
        uncertainty_fn: callable ``x -> array``. Presumably returns
            per-(state, action) uncertainties of shape ``(..., num_actions)``.
            TODO confirm against callers.
        behavior_fn: callable ``x -> array``. Presumably returns behavior-policy
            action probabilities (non-negative) broadcastable with
            ``(..., 1)``. TODO confirm against callers.
        eps: value added to ``sqrt(probs)`` in the divisor so the result stays
            finite where ``probs == 0``. Defaults to ``1e-5``, the constant
            previously hard-coded here.

    NOTE: ``uncertainty`` is jitted with ``self`` as a static argument. JAX
    therefore caches the traced computation per instance, using the default
    identity-based hash. Reassigning ``unc_fn``, ``b_fn`` or ``eps`` on an
    instance after its first ``uncertainty`` call will NOT trigger a retrace.
    Create a new wrapper instead.
    """

    def __init__(self, uncertainty_fn, behavior_fn, *, eps: float = 1e-5):
        self.unc_fn = uncertainty_fn
        self.b_fn = behavior_fn
        # Read at trace time inside the jitted method, so it is baked in as a
        # constant for this instance.
        self.eps = eps

    @functools.partial(jax.jit, static_argnums=0)
    def uncertainty(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return ``mean(unc_fn(x), axis=-1, keepdims=True) / (sqrt(b_fn(x)) + eps)``."""
        unc_sa = self.unc_fn(x)
        # Collapse the last axis to size 1 so it broadcasts against probs.
        unc_s = jnp.mean(unc_sa, axis=-1, keepdims=True)

        probs = self.b_fn(x)

        # eps keeps the divisor strictly positive where probs == 0.
        return unc_s / (jnp.sqrt(probs) + self.eps)


    