import argparse
import jax.numpy as np
from jax import random
import numpy as vnp
import gymnasium as gym
from nn.policy import get_nnpolicy
from util.print import info

def _rollout_return(env, start_state, choose_action):
  """Roll out a policy from `start_state` and return the undiscounted return.

  Args:
    env: environment created by `gym.make`; its unwrapped `state` attribute
      is overwritten with `start_state` after a reset.
    start_state: array-like initial state to force the environment into.
    choose_action: callable mapping the current observation to the action
      passed to `env.step`.

  Returns:
    Sum of the rewards returned by `env.step` until the episode reports
    `terminated` or `truncated`.
  """
  # Reset first to begin a new episode, then overwrite the state so the
  # rollout starts from the dataset state rather than a random one.
  env.reset()
  env.state = env.unwrapped.state = vnp.array(start_state)

  total = 0
  obs = start_state
  terminated = truncated = False
  while not (terminated or truncated):
    obs, reward, terminated, truncated, _ = env.step(choose_action(obs))
    total += reward
  return total


def run(config):
  """Compute the return of the saved policy from every state in the dataset.

  Loads `data/dataset_cartpole_small[_test].npy`, rolls the policy stored in
  `data/policy_cartpole.npy` out from each dataset state in `CartPole-v1`, and
  saves the list of returns to `data/cartpole_true_returns[_test].npy`.

  Args:
    config: object with a boolean `test` attribute; when true the `_test`
      variants of the dataset and output files are used.
  """
  suffix = "_test" if config.test else ""
  dataset = np.load(f'data/dataset_cartpole_small{suffix}.npy')

  # Each dataset row is laid out as
  # [state (4) | action (1) | reward (1) | next state (4) | done (1)];
  # only the leading state columns are needed for the rollouts.
  ssize = 4
  # Copy the states to host memory once instead of converting row by row.
  start_states = vnp.asarray(dataset[:, :ssize])

  dummy_key = random.PRNGKey(0)

  # Policy NN
  pi_nn_size = [4, 256, 2]

  pinetwork, pinetwork_predict = get_nnpolicy(pi_nn_size, np.sqrt(2.), np.sqrt(0.1))

  piparams = pinetwork.init(random.PRNGKey(0), np.array([0., 0., 0., 0.]))
  # NOTE: allow_pickle=True executes pickled code on load; only use with a
  # trusted policy file.
  piparams['nn_policy'] = np.load('data/policy_cartpole.npy', allow_pickle=True).item()

  def greedy_action(obs):
    # Act greedily: pick the index of the largest policy output.
    scores = pinetwork_predict(piparams, dummy_key, np.array([obs]))
    return np.argmax(scores).item()

  env = gym.make('CartPole-v1')
  true_returns = []
  try:
    for start_state in start_states:
      true_returns.append(_rollout_return(env, start_state, greedy_action))
  finally:
    # Release the environment even if a rollout fails.
    env.close()

  save_path = f'data/cartpole_true_returns{suffix}.npy'
  info(f'Saving returns to {save_path}...')
  np.save(save_path, true_returns)

if __name__ == '__main__':
  # Script entry point: parse the command line and hand the parsed options
  # to run(). The only option is --test, which selects the test split.
  cli = argparse.ArgumentParser(
      description='Cartpole get true returns from rolling out the policy')
  cli.add_argument(
      '--test',
      action='store_true',
      help='use the test set. Default: train set')

  run(cli.parse_args())