import os
# run jax on cpu
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import sys
sys.path.append("..")
import jax
from minatar_ppo import ActorCritic
import gymnax
from typing import NamedTuple
import jax.numpy as jnp
import jax.nn as nn
from wrappers import FlattenObservationWrapper, LogWrapper, VecEnv
# jax.config.update('jax_platform_name', 'cpu')

class Transition(NamedTuple):
    """One environment step recorded for every parallel environment.

    Instances are built inside ``collect_trajectories``' ``step_fn`` and
    stacked by ``jax.lax.scan``, so each leaf gains a leading step axis in
    the returned trajectories.
    """

    # `done` flag returned by `env.step`.
    done: jnp.ndarray
    # Action sampled from the policy distribution.
    action: jnp.ndarray
    # Second output of `network.apply` for the observation acted on.
    value: jnp.ndarray
    # Reward returned by `env.step`.
    reward: jnp.ndarray
    # `pi.log_prob(action)` for the sampled action.
    log_prob: jnp.ndarray
    # Observation the policy was evaluated on.
    obs: jnp.ndarray
    # In this file it is filled with the same array as `obs`.
    unnorm_obs: jnp.ndarray
    # Observation returned by `env.step`, replaced by the pre-step observation
    # wherever `done` is set.
    unnorm_next_obs: jnp.ndarray
    # `info` returned by `env.step`. NOTE(review): it is indexed by string key
    # below (`info['returned_episode_returns']`), so it appears to be a
    # dict-like pytree rather than a single array despite the annotation.
    info: jnp.ndarray

# --- Experiment configuration -------------------------------------------
env_name = 'Asterix-MinAtar'
# Freeway gets a longer horizon than the other environments. In the code
# visible here this value is only referenced by the disabled wrapper below.
time_limit = 2500 if env_name == 'Freeway-MinAtar' else 1000
gamma = 0.99

# --- Environment ----------------------------------------------------------
env, env_params = gymnax.make(env_name)
# Disabled wrappers, kept for reference:
# env = AbsorbAfterDoneWrapper(env)
# env = TimeLimitAutoResetWrapper(env, time_limit)
env = FlattenObservationWrapper(env)
env = LogWrapper(env)
env = VecEnv(env)

import pickle

num_actions = env.action_space(env_params).n

# --- Expert policy --------------------------------------------------------
# load model
# NOTE: unpickling can execute arbitrary code; only load parameter files that
# come from a trusted source.
# The context manager closes the file handle deterministically instead of
# leaving it to the garbage collector.
with open(f'/workspace/il_discovery/src/experts_new/{env_name}/params.pkl', 'rb') as params_file:
    params = pickle.load(params_file)
network = ActorCritic(
    num_actions, activation=nn.relu
)

# --- Rollout settings -----------------------------------------------------
n_envs = 20
rng = jax.random.PRNGKey(0)

def stagger_timesteps(rng):
    """Draw one random integer in ``[0, 1000)`` per environment, as float32.

    Args:
        rng: JAX PRNG key used for the draw.

    Returns:
        A float32 array of shape ``(n_envs,)``, where ``n_envs`` is the
        module-level environment count.
    """
    offsets = jax.random.randint(rng, shape=(n_envs,), minval=0, maxval=1000)
    return offsets.astype(jnp.float32)

def collect_trajectories(network, params, env, n_envs, rng):
    """Roll the policy out in ``n_envs`` parallel environments.

    The number of steps is 2500 when the module-level ``env_name`` is
    ``'Freeway-MinAtar'`` and 1000 otherwise. The module-level ``env_params``
    is passed to every ``env.reset`` / ``env.step`` call.

    Args:
        network: Model whose ``apply(params, obs)`` returns ``(pi, value)``,
            where ``pi`` offers ``sample(seed=...)`` and ``log_prob(action)``.
        params: Parameters handed to ``network.apply``.
        env: Batched environment; ``reset`` and ``step`` each receive one
            PRNG key per environment.
        n_envs: Number of parallel environments.
        rng: JAX PRNG key seeding the whole rollout.

    Returns:
        A ``Transition`` whose leaves are stacked over the steps by
        ``jax.lax.scan`` (step axis first).
    """
    rng, reset_key = jax.random.split(rng)
    init_obs, init_state = env.reset(jax.random.split(reset_key, n_envs), env_params)
    # The sub-key of this split is never consumed; the split itself is kept
    # because removing it would alter the random stream of the rollout.
    rng, _ = jax.random.split(rng)

    def step_fn(runner_state, _):
        state, prev_obs, key = runner_state

        # Policy forward pass and action sampling.
        key, action_key = jax.random.split(key)
        pi, value = network.apply(params, prev_obs)
        action = pi.sample(seed=action_key)
        # (debug alternative: uniformly random actions)
        # action = jax.random.randint(rng, (n_envs,), 0, num_actions)
        log_prob = pi.log_prob(action)

        # Advance every environment with its own key.
        key, step_key = jax.random.split(key)
        per_env_keys = jax.random.split(step_key, n_envs)
        new_obs, state, reward, done, info = env.step(
            per_env_keys, state, action, env_params
        )

        # Where an episode finished, record the pre-step observation as the
        # "next" observation instead of the one returned by the step.
        next_obs = jnp.where(done[..., None], prev_obs, new_obs)

        transition = Transition(
            done=done,
            action=action,
            value=value,
            reward=reward,
            log_prob=log_prob,
            obs=prev_obs,
            unnorm_obs=prev_obs,
            unnorm_next_obs=next_obs,
            info=info,
        )
        return (state, new_obs, key), transition

    n_steps = 2500 if env_name == 'Freeway-MinAtar' else 1000
    initial_state = (init_state, init_obs, rng)
    _, transitions = jax.lax.scan(step_fn, initial_state, None, length=n_steps)
    return transitions

transitions = collect_trajectories(network, params, env, n_envs, rng)

# Logged returns at the final recorded step, one entry per environment.
final_returns = transitions.info['returned_episode_returns'][-1, :]
print('Returns:', final_returns)
print(f'Mean: {jnp.mean(final_returns)} Std: {jnp.std(final_returns)}')

# Rank the environments by that final return, best first.
sorted_idx = jnp.argsort(final_returns)[::-1]

def sort_leaf(x):
    """Keep the 10 best-ranked environments of one trajectory leaf.

    Axis 1 of ``x`` is reordered by the module-level ``sorted_idx`` and cut
    to its first 10 entries; every other axis is left untouched.
    """
    top_env_idx = sorted_idx[:10]
    return x[:, top_env_idx]

# Apply the ranking / top-10 selection to every leaf of the trajectories.
transitions_sorted = jax.tree_util.tree_map(sort_leaf, transitions)

# Report the same statistics for the environments that were kept.
sorted_final_returns = transitions_sorted.info['returned_episode_returns'][-1, :]
print('Returns:', sorted_final_returns)
print(f'Mean: {jnp.mean(sorted_final_returns)} Std: {jnp.std(sorted_final_returns)}')

# Write through a context manager so the file is flushed and closed before
# the script exits, rather than relying on garbage collection.
with open(f'/workspace/il_discovery/src/experts_new/{env_name}/transitions_sorted.pkl', 'wb') as out_file:
    pickle.dump(transitions_sorted, out_file)