import os
import numpy as np
import jax
import jax.numpy as jnp
import functools
import pdb

from data_augmentation.diffuser.iql.common import Model
from data_augmentation.diffuser.iql.value_net import DoubleCritic

def load_q(env, loadpath, hidden_dims=(256, 256), seed=42):
    """Build a `DoubleCritic` model and restore it from `loadpath`.

    One observation and one action are sampled from the spaces of `env` and
    indexed with a new leading axis; together with a PRNG key derived from
    `seed` they are handed to `Model.create` as its `inputs`.

    Args:
        env: object exposing `observation_space.sample()` and
            `action_space.sample()`.
        loadpath: location of the saved critic; a leading `~` is expanded
            before loading.
        hidden_dims: forwarded unchanged to `DoubleCritic`.
        seed: integer seed for `jax.random.PRNGKey`.

    Returns:
        The value returned by calling `load` on the freshly created model.
    """
    print(f'[ utils/iql ] Loading Q: {loadpath}')

    ## example inputs; the observation is sampled before the action
    sample_obs = env.observation_space.sample()[None]
    sample_act = env.action_space.sample()[None]

    ## only the second key of the split is consumed
    _, init_key = jax.random.split(jax.random.PRNGKey(seed))

    model = Model.create(
        DoubleCritic(hidden_dims),
        inputs=[init_key, sample_obs, sample_act],
    )

    ## expand a leading `~` to the user's home directory
    return model.load(os.path.expanduser(loadpath))

class JaxWrapper:
    """Callable wrapper around the critic returned by `load_q`.

    Calling an instance evaluates the wrapped model on the given inputs,
    reduces the values it returns with `jnp.minimum` and hands the result
    back as a NumPy array.
    """

    def __init__(self, env, loadpath, *args, **kwargs):
        """Load the critic for `env` from `loadpath`.

        Extra positional and keyword arguments are accepted but ignored, so
        `load_q` always runs with its default `hidden_dims` and `seed`.
        """
        self.model = load_q(env, loadpath)

    ## `self` is a static argument: jit hashes the instance instead of tracing
    ## it, and `self.model` is read while the function is being traced.
    ## `static_argnames` is spelled as a real one-element tuple; `('self')`
    ## is just the string 'self' in parentheses.
    ## `jax.devices('cpu')` is evaluated once, when the class body is executed.
    ## NOTE(review): the `device` argument of `jax.jit` is documented as
    ## experimental and newer JAX releases appear to deprecate it in favour of
    ## `jax.device_put` on the inputs -- confirm against the pinned version.
    @functools.partial(
        jax.jit,
        static_argnames=('self',),
        device=jax.devices('cpu')[0],
    )
    def forward(self, xs):
        """Return the minimum over the values produced by the model.

        Args:
            xs: sequence of inputs, unpacked as the positional arguments of
                `self.model`.

        Returns:
            `jnp.minimum` applied to the unpacked outputs of the model.
        """
        q_values = self.model(*xs)
        return jnp.minimum(*q_values)

    def __call__(self, *xs):
        """Evaluate `forward` on the positional inputs; return a NumPy array."""
        ## the inputs are bundled into one tuple because `forward` takes a
        ## single `xs` argument
        return np.array(self.forward(xs))
