import os
import numpy as np
import jax
import jax.numpy as jnp
import pdb

#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".10"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from implicit_q_learning.common import Model
import implicit_q_learning.value_net as value_net

def load_q(env, loadpath, hidden_dims=(256, 256), seed=42):
    """Build a ``DoubleCritic`` model for ``env`` and load it from ``loadpath``.

    Args:
        env: object exposing ``observation_space`` and ``action_space``, each
            with a ``sample()`` method; one sample of each is used as an
            example input when creating the model.
        loadpath: location handed to the model's ``load`` method. A leading
            ``~`` is expanded to the user's home directory first.
        hidden_dims: forwarded to ``value_net.DoubleCritic``.
        seed: integer seed for the JAX PRNG key used during model creation.

    Returns:
        Whatever ``Model.load`` returns for ``loadpath`` (presumably the
        critic with its saved parameters restored -- confirm against
        ``implicit_q_learning.common.Model``).
    """
    ## only the second half of the split is consumed, as the creation key
    _, init_key = jax.random.split(jax.random.PRNGKey(seed))

    ## a single sampled observation / action, each given a leading axis
    obs_sample = env.observation_space.sample()
    example_obs = obs_sample[np.newaxis]
    act_sample = env.action_space.sample()
    example_act = act_sample[np.newaxis]

    network = value_net.DoubleCritic(hidden_dims)
    model = Model.create(network, inputs=[init_key, example_obs, example_act])

    ## expanduser resolves a leading `~`; other paths pass through unchanged
    return model.load(os.path.expanduser(loadpath))

# def to_np(*xs):
#     return [x.detach().cpu().numpy() for x in xs]

# def grad(model, state, action):

#     def fn(state, action):
#         return jnp.sum(model(state, action))

#     grad = jax.grad(fn, argnums=(0, 1))
#     dfds, dfda = grad((state, action))

class JaxWrapper:
    """Callable around a model loaded with ``load_q``.

    Calling an instance evaluates the model and returns the element-wise
    minimum of its outputs as a NumPy array.
    """

    def __init__(self, env, loadpath, *args, **kwargs):
        """Load the model for ``env`` from ``loadpath``.

        Extra positional and keyword arguments are accepted but ignored;
        ``load_q`` is always called with its default ``hidden_dims`` and
        ``seed``.
        """
        self.model = load_q(env, loadpath)

    def __call__(self, *xs):
        """Evaluate the model on ``xs`` and return the minimum as ``np.ndarray``.

        The model output is star-unpacked into ``jnp.minimum``, so it must
        supply the two operands of that call (presumably the two Q-value
        estimates of the double critic -- confirm against ``DoubleCritic``).
        """
        q_estimates = self.model(*xs)
        ## convert the JAX result to a plain NumPy array for the caller
        return np.array(jnp.minimum(*q_estimates))
