import os
import pickle
import tensorflow as tf
import numpy as np
import gym
import joblib
import argparse
import matplotlib.pyplot as plt
from run_dagger import *


def _str2bool(value):
    """Parse a command-line string into a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is True, so any
    non-empty value (including "False"/"0") would enable the flag.
    Raises ``argparse.ArgumentTypeError`` for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def _positive_int(value):
    """Parse a strictly positive integer (used for the episode count).

    Raises ``argparse.ArgumentTypeError`` when the value is <= 0; a
    non-integer string raises ``ValueError``, which argparse reports itself.
    """
    number = int(value)
    if number <= 0:
        raise argparse.ArgumentTypeError(
            "expected a positive integer, got {!r}".format(value))
    return number


parser = argparse.ArgumentParser(
    description="Evaluate a trained student policy in a gym environment.")
parser.add_argument("--load_path", type=str, default="swimmer_student",
                    help="checkpoint name inside ./student_models")
parser.add_argument("--render", type=_str2bool, default=True,
                    help="render the environment while testing (true/false)")
parser.add_argument("--env", type=str, default='Swimmer-v2',
                    help="gym environment id")
parser.add_argument("--num_test", type=_positive_int, default=50,
                    help="number of test episodes to run")
args = parser.parse_args()

# Pull the evaluation settings out of the parsed CLI arguments and build the
# environment the student policy will be tested in.
render, num_test = args.render, args.num_test
env = gym.make(args.env)

# Only shape[0] is read, so both spaces are presumably 1-D Box spaces.
obs_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]
print('observation dimension:', obs_dim)
print('action dimension:', act_dim)

# Student MLP policy: observations are fed as a batch (leading None axis)
# and passed to `student_pi`; 'student' is presumably its variable scope name.
x_student = tf.placeholder(tf.float32, shape=[None, obs_dim])
policy_student = student_pi(x_student, act_dim, 'student')


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore only the student network's weights from disk; the initializer
    # above covers every other variable in the graph.
    variables = [v for v in tf.trainable_variables() if "student" in v.name]
    var_path = os.path.join('./student_models', args.load_path)
    print(var_path)
    load_variables(var_path, variables=variables, sess=sess)

    # Per-episode returns of the student policy.
    student_test_rew = []
    try:
        # Run exactly `num_test` complete episodes.
        for _ in range(num_test):
            obs = env.reset()
            done = False
            episode_rew = 0
            while not done:
                # The graph is fed a batch of one observation; row 0 is taken
                # so env.step receives a single action (assumes the policy
                # output is shaped (batch, act_dim) — TODO confirm).
                act = policy_student.eval(feed_dict={x_student: obs[None, :]})[0]
                obs, rew, done, _ = env.step(act)
                episode_rew += rew
                if render:
                    env.render()
            student_test_rew.append(episode_rew)
            print('reward:', episode_rew)
    finally:
        # Release the simulator / render window even if testing is interrupted.
        env.close()

    if student_test_rew:
        rew_mean = np.mean(student_test_rew)
        rew_std = np.std(student_test_rew)
        print("test for {} times with reward mean {} and std {}".format(num_test, rew_mean, rew_std))
    else:
        print("no test episodes were run (num_test={})".format(num_test))