#import safety_gym
import gym
import mujoco_maze
'''
from utils_new import mini_batch_train
from utils_new import mini_batch_train_plain
from utils_new import mini_batch_train_frames
'''
from utils_pt_full import mini_batch_train
from utils_pt_full import mini_batch_train_plain
from utils_pt_full import mini_batch_train_frames


from ddpg import DDPGAgent
import matplotlib.pyplot as plt
import argparse
import time
import torch
import numpy as np

# Environment under test (registered by the mujoco_maze import above).
env = gym.make("PointUMaze-v0")

# Arguments forwarded to the DDPGAgent constructor in the training loop.
gamma = 0.95
tau = 1e-2
buffer_maxlen = 100000
critic_lr = 1e-5
actor_lr = 1e-5

# Arguments forwarded to mini_batch_train; maxsteps_total is also used to
# truncate each run's reward trace before the traces are stacked.
max_steps = 50
batch_size = 32
maxsteps_total = 200000
stepcount = 0  # running counter, handed to and returned by mini_batch_train

# Number of training runs; also embedded in the results file name.
Nruns = 10

# Not referenced anywhere else in the visible part of this script.
# max_episodes = 10
max_frames = 100000
tracelim = 100
tracebuffersize = 500
#retall=[]
# Use the position of the last goal exposed by the maze task as the target
# handed to the training routine.
_goals = list(env.env._task.goals)
if not _goals:
    # Fail here with a clear message instead of a NameError on `gp` later.
    raise RuntimeError("The environment task defines no goals; cannot determine goal position.")
goal = _goals[-1]
gp = goal.pos
print("Goal pos:")
print(gp)
# Set to 1 to initialise every run's agent from the checkpoints stored in
# 'criticsaved.pt' and 'actorsaved.pt'; any other value leaves the freshly
# constructed agent untouched.
priors = 0

# Train Nruns agents one after another and stack their reward traces row-wise.
for runno in range(Nruns):
    agent = DDPGAgent(env, gamma, tau, buffer_maxlen, critic_lr, actor_lr)

    # The weights must be loaded here, after construction: a new agent is
    # built on every run, so there is no agent object to load into earlier.
    if priors == 1:
        agent.critic.load_state_dict(torch.load('criticsaved.pt'))
        agent.actor.load_state_dict(torch.load('actorsaved.pt'))

    # stepcount is threaded through the runs: the value returned by one call
    # is passed into the next one.
    episode_rewards, stepcount, allrewards = mini_batch_train(
        env, agent, max_steps, batch_size, maxsteps_total, gp, stepcount, runno
    )

    # Keep at most maxsteps_total entries of this run's trace.
    # NOTE(review): np.vstack below requires every run to yield the same
    # length after truncation -- confirm mini_batch_train guarantees that.
    allrewards = allrewards[0:maxsteps_total]

    if runno == 0:
        retall = allrewards
    else:
        print(np.shape(retall))
        print(np.shape(allrewards))
        retall = np.vstack((retall, allrewards))

# Positional arrays are stored by np.savez as 'arr_0' and 'arr_1'.
# Only the episode_rewards of the final run is saved; retall holds the
# stacked traces of all runs.
np.savez("pointmaze_results_final_" + str(Nruns) + ".npy.npz", episode_rewards, retall)
