import jax
from brax_ppo import ActorCritic
from wrappers import *
from typing import NamedTuple
import pickle 
import numpy as np

class Transition(NamedTuple):
    """Record of one environment step as stored by ``collect_trajectories``.

    Fields are filled positionally, so their order matters. The annotations
    are nominal: in this script ``value`` and ``log_prob`` are passed ``None``
    and ``info`` is indexed with a string key.
    """
    # Done flag returned by ``env.step``.
    done: jnp.ndarray
    # Action that was passed to ``env.step``.
    action: jnp.ndarray
    # Set to ``None`` by the rollout in this file.
    value: jnp.ndarray
    # Reward returned by ``env.step``.
    reward: jnp.ndarray
    # Set to ``None`` by the rollout in this file.
    log_prob: jnp.ndarray
    # Observation the step started from (un-normalised in this file).
    obs: jnp.ndarray
    # Same array as ``obs`` in this file.
    unnorm_obs: jnp.ndarray
    # Observation after the step; where ``done`` is set the rollout stores the
    # pre-step observation here instead of the one returned by ``env.step``.
    unnorm_next_obs: jnp.ndarray
    # ``info`` value returned by ``env.step``; read later via
    # ``info['returned_episode_returns']``, so it behaves like a mapping.
    info: jnp.ndarray

# Environment name; also selects the expert directory under ../experts_mjx/.
env_name = 'humanoid'
# Discount factor (not referenced in the visible part of this script).
gamma = 0.99

# Build the environment and apply the wrapper stack from `wrappers`, innermost
# first: LogWrapper -> ClipAction -> VecEnv. `env_params` is None and is passed
# through unchanged to every env call below.
env, env_params = ILBraxGymnaxWrapper(env_name, backend='mjx'), None
env = LogWrapper(env)
env = ClipAction(env)
env = VecEnv(env)

# Load the expert.
# NOTE: pickle.load can execute arbitrary code from the file; only load
# parameter files from a trusted source.
with open(f'../experts_mjx/{env_name}/params.pkl', 'rb') as params_file:
    params = pickle.load(params_file)
# Observation statistics; the rollout normalises observations as
# (obs - mean) / sqrt(var) before feeding them to the network.
mean = np.load(f'../experts_mjx/{env_name}/mean.npy')
var = np.load(f'../experts_mjx/{env_name}/var.npy')
network = ActorCritic(env.action_space(env_params).shape[0], jax.nn.relu)

# Rollout configuration: number of parallel environments and the root PRNG key.
n_envs = 2048
rng = jax.random.PRNGKey(42)

def stagger_timesteps(rng):
    """Draw one random integer in ``[0, 1000)`` per environment, as float32.

    Args:
        rng: JAX PRNG key used for the draw.

    Returns:
        Array of shape ``(n_envs,)`` and dtype ``float32``; ``n_envs`` is the
        module-level environment count.
    """
    offsets = jax.random.randint(rng, shape=(n_envs,), minval=0, maxval=1000)
    return offsets.astype(jnp.float32)

def collect_trajectories(network, params, env, n_envs, rng, *, n_steps=1000,
                         random_actions=True):
    """Roll out ``n_envs`` parallel environments and return every transition.

    Reads the module-level ``env_params``, ``mean`` and ``var``.

    Args:
        network: Model whose ``apply(params, obs)`` returns ``(pi, _)`` where
            ``pi`` supports ``sample(seed=...)``. Only used when
            ``random_actions`` is False.
        params: Parameters passed to ``network.apply``.
        env: Vectorised environment exposing ``reset``, ``step`` and
            ``action_space``.
        n_envs: Number of parallel environments (one PRNG key per env).
        rng: Root JAX PRNG key for resets, actions and steps.
        n_steps: Number of scan iterations, i.e. steps per environment.
            Must be a Python int. Defaults to 1000.
        random_actions: If True (default, the behaviour this script has
            always had) actions are drawn uniformly from ``[-1, 1)`` and the
            network is not evaluated. If False, actions are sampled from the
            policy on normalised observations.

    Returns:
        A ``Transition`` whose array fields are stacked over the ``n_steps``
        scan iterations on the leading axis. ``value`` and ``log_prob`` are
        ``None``.
    """
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, n_envs)
    obsv, env_state = env.reset(reset_rng, env_params)

    # The sub-key of this split is not needed, but the split itself is kept so
    # the random stream (and therefore the results for a given seed) is
    # unchanged from earlier versions of this function.
    rng, _ = jax.random.split(rng)

    action_dim = env.action_space(env_params).shape[0]

    def step_fn(runner_state, unused):
        env_state, last_obs, rng = runner_state

        # `random_actions` is a Python bool, so this branch is resolved when
        # the scan body is traced. Each branch consumes `_rng` exactly once.
        rng, _rng = jax.random.split(rng)
        if random_actions:
            action = jax.random.uniform(
                _rng, (n_envs, action_dim), minval=-1.0, maxval=1.0
            )
        else:
            norm_obs = (last_obs - mean) / jnp.sqrt(var)
            pi, _ = network.apply(params, norm_obs)
            action = pi.sample(seed=_rng)

        rng, _rng = jax.random.split(rng)
        rng_step = jax.random.split(_rng, n_envs)
        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, action, env_params
        )

        # Where the episode is done, store last_obs instead of obsv as the
        # next observation (presumably because obsv then belongs to a freshly
        # reset episode -- confirm against the wrappers).
        next_obs = jnp.where(done[..., None], last_obs, obsv)

        # value and log_prob are not computed here, hence None.
        transition = Transition(
            done, action, None, reward, None, last_obs, last_obs, next_obs, info
        )

        runner_state = (env_state, obsv, rng)
        return runner_state, transition

    runner_state = (env_state, obsv, rng)
    _, transitions = jax.lax.scan(step_fn, runner_state, None, length=n_steps)
    return transitions

# Run the rollout with the module-level configuration.
transitions = collect_trajectories(network, params, env, n_envs, rng)

# Logged episode returns at the last index of the leading axis, one entry per
# remaining index; report them along with their mean and standard deviation.
final_returns = transitions.info['returned_episode_returns'][-1, :]
print('Returns:', final_returns)
print(f'Mean: {jnp.mean(final_returns)} Std: {jnp.std(final_returns)}')
