import sys
sys.path.append('../')
import os
import pickle
import tensorflow as tf
import numpy as np
import gym
import joblib
import argparse
from aril_utils import *
from norm_env import * 


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- evaluates to ``True``.  This
    parser interprets the text instead.

    Args:
        value: the raw argument (a string, or an actual bool).

    Returns:
        The parsed boolean.  The empty string maps to ``False``, matching
        ``bool('')``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def main():
    """Train a student policy to imitate a saved PPO2 expert using DAgger.

    Pipeline:
      1. Parse the command line and build the teacher environment
         (``NormalizeEnv``, statistics loaded from ``--load_path``) and the
         student environment (``VanillaEnv``).
      2. Build the TF graph: a student mean network plus a trainable log-std,
         and the expert network whose weights are restored from ``load_path``.
      3. Roll out the (noisy) expert for ``--init_samples`` steps to seed the
         dataset of (state, expert action) pairs.
      4. For ``--num_dagger_itr`` iterations: fit the student with an L2 loss,
         roll out the student, relabel the visited states with the expert,
         and aggregate them into the dataset.
      5. Save the student variables, optionally pickle the per-iteration
         statistics to ``--save_path``, and plot the returns.

    All configuration comes from ``sys.argv``; the function returns ``None``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--load_path", type=str, default="swimmer_20M_ppo2_norm")
    # type=bool would treat "--render False" as True; parse the text instead.
    parser.add_argument("--render", type=_str2bool, default=False)
    parser.add_argument("--num_rollouts", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--env", type=str, default='Swimmer-v2')
    parser.add_argument("--num_dagger_itr", type=int, default=500)
    parser.add_argument("--train_steps", type=int, default=1000)
    parser.add_argument("--save_path", type=str, default='')
    parser.add_argument("--init_samples", type=int, default=100000)
    parser.add_argument("--plot_save_path", type=str, default='./plot/swimmer.png')
    parser.add_argument("--model_save_path", type=str, default='./student_models/swimmer_student')

    args = parser.parse_args()

    #===========================================================================
    # set the parameters and the environment
    #===========================================================================
    render = args.render
    num_rollouts = args.num_rollouts
    batch_size = args.batch_size
    load_path = os.path.join('../openai_baselines/models', args.load_path)
    # Teacher env: reset()/step() return both a state and an observation; the
    # observation is what the expert network consumes.
    # NOTE(review): presumably obs is the normalised state -- confirm in norm_env.
    t_env = NormalizeEnv(args.env, update=False)
    t_env.obs_rms.load(load_path)
    # Student env: reset()/step() return the state only, fed to the student as-is.
    s_env = VanillaEnv(args.env)

    obs_dim = t_env.observation_space.shape[0]
    act_dim = t_env.action_space.shape[0]
    args.obs_dim = obs_dim
    args.act_dim = act_dim
    print('observation dimension:', args.obs_dim)
    print('action dimension:', args.act_dim)

    # architecture of the MLP policy function
    x_student = tf.placeholder(tf.float32, shape=[None, obs_dim])
    x_teacher = tf.placeholder(tf.float32, shape=[None, obs_dim])
    y_teacher = tf.placeholder(tf.float32, shape=[None, act_dim])

    # Student: mean from the MLP, state-independent trainable log-std, and a
    # sampled action y_student = mean + std * N(0, I).
    student_action_mean = student_pi(x_student, args.act_dim, 'student')
    student_action_logstd = tf.Variable(np.zeros(args.act_dim), name='student/logstd', dtype=tf.float32)
    student_action_std = tf.exp(student_action_logstd)
    y_student = student_action_mean + student_action_std * tf.random_normal(tf.shape(student_action_mean))

    # Expert: network output plus the log-std stored in the saved parameters.
    expert = pi(x_teacher, act_dim, 'ppo2_model/pi')
    para = joblib.load(load_path)
    expert_pi_logstd = para['ppo2_model/pi/logstd:0']
    expert_pi_std = np.exp(expert_pi_logstd)

    loss_l2 = tf.reduce_mean(tf.square(y_student - y_teacher))
    train_step = tf.train.AdamOptimizer().minimize(loss_l2)

    # device_count={'GPU': 0} keeps the whole run on the CPU.
    with tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) as sess:
        sess.run(tf.global_variables_initializer())

        print_var = [v.name for v in tf.trainable_variables() if "student" in v.name]
        print('student variables:', print_var)

        # Restore the expert weights after the global initialiser so they are
        # not overwritten by the random initial values.
        variables = [v for v in tf.trainable_variables() if "ppo2_model/pi/" in v.name]
        load_variables(load_path, variables=variables, sess=sess)

        state, obs = t_env.reset()
        done = False

        episode_rew = 0
        expert_test_rew = []

        counter = 0
        state_data = np.empty((args.init_samples, obs_dim))
        act_data = np.empty((args.init_samples, act_dim))
        print('Generating samples from expert!')

        #===========================================================================
        # generate expert data
        #===========================================================================
        while counter < args.init_samples:
            act = expert.eval(feed_dict={x_teacher: obs[None, :]})
            act = act + expert_pi_std * np.random.normal(size=act.shape)

            # Record the pair BEFORE stepping so each action is stored with
            # the state it was chosen in (recording after the step would pair
            # it with the successor state instead).
            state_data[counter, :] = state
            act_data[counter, :] = act

            state, obs, rew, done, __ = t_env.step(act)
            episode_rew += rew
            if render:
                t_env.render()
            counter += 1

            if done:
                expert_test_rew.append(episode_rew)
                print('reward:', episode_rew)
                episode_rew = 0
                state, obs = t_env.reset()

        # Statistics over completed expert episodes only; a trailing partial
        # episode is not counted.
        save_expert_mean = np.mean(expert_test_rew)
        save_expert_std = np.std(expert_test_rew)

        #===========================================================================
        # run DAgger alg
        #===========================================================================
        mean_history = []
        std_history = []
        train_size_history = []

        # Step cap for student rollouts; constant across iterations.
        max_steps = s_env._max_episode_steps

        for i_dagger in range(args.num_dagger_itr):
            print('DAgger iteration ', i_dagger)

            # train a policy by fitting the MLP on random mini-batches
            for step in range(args.train_steps):
                batch_i = np.random.randint(0, state_data.shape[0], size=batch_size)
                train_step.run(feed_dict={x_student: state_data[batch_i, :], y_teacher: act_data[batch_i, :]})

                if ((step + 1) % 200 == 0):
                    print('opmization step ', step)
                    # Loss over the full aggregated dataset, not just the batch.
                    print('obj value is ', loss_l2.eval(feed_dict={x_student:state_data, y_teacher:act_data}))
            print('Optimization Finished!')

            # use trained MLP to perform
            student_returns = []
            student_states = []
            student_actions = []
            print('rollouts!')
            for i in range(num_rollouts):
                state = s_env.reset()
                done = False
                totalr = 0.
                steps = 0
                while not done:
                    act = y_student.eval(feed_dict={x_student:state[None, :]})
                    student_states.append(state)
                    student_actions.append(act)
                    state, r, done, __ = s_env.step(act)
                    totalr += r
                    steps += 1
                    if render:
                        s_env.render()
                    if steps >= max_steps:
                        break
                student_returns.append(totalr)
            print('student mean return', np.mean(student_returns))
            print('student std of return', np.std(student_returns))

            # expert labeling: query the expert on every state the student visited
            expert_actions = []
            print('expert labeling!')
            for visited_state in student_states:
                label_obs = t_env.get_obs(visited_state)
                act = expert.eval(feed_dict={x_teacher: label_obs[None, :]})
                act = act + expert_pi_std * np.random.normal(size=act.shape)
                expert_actions.append(act)

            # record training size (the size used for THIS iteration's fit)
            train_size = state_data.shape[0]

            # data aggregation
            # Each expert action has shape (1, act_dim).  reshape keeps a
            # (num_states, act_dim) layout even when act_dim == 1 or only one
            # state was collected, where np.squeeze would drop an axis and
            # break the concatenation.
            new_actions = np.array(expert_actions).reshape(-1, act_dim)
            state_data = np.concatenate((state_data, np.array(student_states)), axis=0)
            act_data = np.concatenate((act_data, new_actions), axis=0)

            # record mean return & std
            mean_history.append(np.mean(student_returns))
            std_history.append(np.std(student_returns))
            train_size_history.append(train_size)

        save_var = [v for v in tf.trainable_variables() if "student" in v.name]
        save_variables(args.model_save_path, save_var, sess)

    # float64 arrays, matching what repeated np.append onto [] would produce.
    save_mean = np.asarray(mean_history, dtype=np.float64)
    save_std = np.asarray(std_history, dtype=np.float64)
    save_train_size = np.asarray(train_size_history, dtype=np.float64)

    dagger_results = {'means': save_mean, 'stds': save_std, 'train_size': save_train_size,
                      'expert_mean':save_expert_mean, 'expert_std':save_expert_std}
    # --save_path was parsed but never used and the results dict was dropped;
    # persist it when a path is given (the default '' keeps the old behaviour).
    if args.save_path:
        with open(args.save_path, 'wb') as results_file:
            pickle.dump(dagger_results, results_file)

    plot_reward(args, save_mean, save_std, save_expert_mean)

    print('expert reward mean {} std {}'.format(save_expert_mean, save_expert_std))
    print('DAgger iterations finished!')


# Run the DAgger pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
