import argparse
import jax.numpy as np
from jax import random
import gymnasium as gym

from nn.policy import get_nnpolicy


def run(config):
  """Roll out a pretrained CartPole policy and save the visited transitions.

  The policy weights are read from ``data/policy_cartpole.npy`` and plugged
  into the network built by ``get_nnpolicy``. The policy is run
  epsilon-greedily in ``CartPole-v1`` until ``config.dataset_size``
  transitions have been recorded. Each dataset row is laid out as
  ``(s[0:4], a, r, s_next[0:4], terminated)``, i.e.
  ``ssize + asize + rsize + ssize + 1`` columns.

  Outputs:
    ``data/dataset_cartpole.npy``: every collected row.
    ``data/dataset_cartpole_small.npy``: rows ``[0, 2048)``.
    ``data/dataset_cartpole_small_test.npy``: rows ``[2048, 4096)``. This
      slice is shorter, or empty, when ``dataset_size < 4096``.

  Args:
    config: object with ``seed`` (int), ``epsilon`` (float) and
      ``dataset_size`` (int) attributes, e.g. the parsed CLI namespace.
  """
  import numpy as onp  # host-side buffer that can be written in place

  dummy_key = random.PRNGKey(0)

  ssize = 4
  asize = 1
  rsize = 1

  seed_train = config.seed
  eps = config.epsilon
  dataset_size = config.dataset_size
  dataset_dim = ssize + asize + rsize + ssize + 1

  # Policy NN
  pi_nn_size = [4, 256, 2]
  pinetwork, _ = get_nnpolicy(pi_nn_size, np.sqrt(2.), np.sqrt(0.1))
  piparams = pinetwork.init(dummy_key, np.array([0., 0., 0., 0.]))
  # allow_pickle is required because the weights are stored as a pickled
  # object; only load trusted files here.
  piparams['nn_policy'] = np.load('data/policy_cartpole.npy', allow_pickle=True).item()

  # Writing rows into a NumPy buffer avoids re-copying the whole (immutable)
  # JAX array on every environment step, which `.at[].set` would do.
  buffer = onp.zeros((dataset_size, dataset_dim))
  Dptr = 0

  key = random.PRNGKey(seed_train)
  episode_lengths = []

  env = gym.make('CartPole-v1')
  # Seed the environment and its action space too, so that the same --seed
  # reproduces the same dataset (not only the same epsilon draws).
  env.action_space.seed(seed_train)
  try:
    s, _ = env.reset(seed=seed_train)

    # Run whole episodes until the buffer is full. A fixed episode budget can
    # leave trailing all-zero rows when episodes are short (e.g. large eps).
    while Dptr < dataset_size:
      epoch = len(episode_lengths)
      done = False
      truncated = False
      episode_length = 0

      while not (done or truncated):
        # Epsilon-greedy action: `>=` makes eps=0 strictly greedy (uniform
        # draws lie in [0, 1)) and eps=1 always random.
        key, subkey = random.split(key)
        eps_prob = random.uniform(subkey)
        if eps_prob >= eps:
          a = pinetwork.apply(piparams, dummy_key, s)
          a = np.argmax(a, axis=-1, keepdims=True).item()
        else:
          a = env.action_space.sample()

        # Interaction with env
        snew, r, done, truncated, _ = env.step(a)
        episode_length += 1

        # The final episode keeps running after the buffer fills, so its
        # length statistic stays comparable; extra steps are not stored.
        if Dptr < dataset_size:
          buffer[Dptr] = (*s, a, r, *snew, done)
          Dptr += 1

        s = snew
      s, _ = env.reset()

      episode_lengths.append(episode_length)
      print(f'Epoch {epoch:3}   Ep.len: {episode_length:3} (avg {sum(episode_lengths)/(epoch+1):6.2f})')
  finally:
    env.close()

  # Convert once, so the saved dtype is JAX's default float dtype as before.
  D = np.asarray(buffer)

  np.save('data/dataset_cartpole.npy', D)
  np.save('data/dataset_cartpole_small.npy', D[:2048, :])
  np.save('data/dataset_cartpole_small_test.npy', D[2048:4096, :])


if __name__ == '__main__':
  # Command-line entry point: parse the collection settings and hand them
  # to run() as a single namespace.
  cli = argparse.ArgumentParser(description='Cartpole environment collect dataset from a good policy')

  # (flag, default, type, help) for every supported option.
  option_specs = (
      ('--seed', 10, int, 'random seed for choosing an epsilon-greedy action every step'),
      ('--epsilon', 0., float, 'epsilon value for choosing an epsilon-greedy action every step'),
      ('--dataset_size', 10240, int, 'number of transitions to collect'),
  )
  for flag, fallback, caster, description in option_specs:
    cli.add_argument(flag, default=fallback, type=caster, help=description)

  parsed = cli.parse_args()

  # NOTE(review): run() also saves rows [2048, 4096) as a test split, which
  # is empty for dataset_size <= 2048; consider whether this floor is enough.
  if not parsed.dataset_size >= 1024:
    raise ValueError('dataset_size must be 1024 or greater')

  run(parsed)