# from __future__ import division
import tensorflow as tf
import numpy as np
from collections import deque
import random
import gym
from gym import wrappers
from core import *
from vlm_h import *
import os
import tensorflow_probability as tfp
import multiprocessing as mp
import os
import json
import pandas as pd
import collections
import datetime
from scipy.stats import spearmanr
# Expose only the CUDA device with index 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Short aliases for the TF1 contrib modules and the TFP distributions namespace.
slim = tf.contrib.slim
rnn = tf.contrib.rnn
tfd = tfp.distributions
# Session configuration passed to every tf.Session created in main():
# device-placement logging is off and GPU memory growth is enabled.
config=tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth = True

# Per-episode log line; written to the stats file and echoed to stdout.
_STATS_LINE_TEMPLATE = (
    '| Reward: {:d} | Episode: {:d}  | ELBO: {:.4f} | DIV1: {:.4f} | DIV2: {:.4f} | DIV3: {:.4f} | P_ns: {:.4f} | P_r: {:.4f} | P_t: {:.4f} | MSE: {:.4f} \n'
)

# (metric key, OPE_Model attribute read after train()), in log-column order.
_LOGGED_METRICS = (
    ("elbo", "elbo_evaluated"),
    ("divergence1", "divergence1_evaluated"),
    ("divergence2", "divergence2_evaluated"),
    ("divergence3", "divergence3_evaluated"),
    ("likelihood_s", "likelihood_s_evaluated"),
    ("likelihood_r", "likelihood_r_evaluated"),
    ("likelihood_t", "likelihood_t_evaluated"),
    ("mse", "encoder_decoder_lstm_states_mse_evaluated"),
)


def _pad_actions(actions, action_dim):
    """Flatten every per-step action and right-pad it with zeros to ``action_dim``."""
    padded = []
    for step_action in actions:
        flat = np.asarray(step_action).reshape(-1)
        # Use the element count (an int), not the shape tuple, for the pad length.
        padded.append(np.concatenate([flat, np.zeros(action_dim - flat.size)]))
    return np.asarray(padded)


def _count_logged_episodes(stats_path):
    """Return the number of non-blank lines in ``stats_path`` (0 if it is missing)."""
    if not os.path.exists(stats_path):
        return 0
    with open(stats_path) as stats_file:
        return sum(1 for line in stats_file if line.strip())


def _format_stats_line(ep_reward, episode, metrics):
    """Render one log line from the per-episode metric lists in ``metrics``."""
    # np.mean of an empty list yields nan, which is what gets logged when no
    # training step ran in this episode.
    means = [np.mean(metrics[key]) for key, _ in _LOGGED_METRICS]
    return _STATS_LINE_TEMPLATE.format(int(ep_reward), episode, *means)


def main():
    """Train the OPE model on offline trajectories and log one line per episode.

    Steps:
      1. Load ``./<env_name>.npy`` and push every trajectory with at least
         three observations into a ``ReplayBuffer_Trajectory``.
      2. If ``./rl_stats/<file_appendix>.txt`` already exists, resume: the
         number of lines in it is the number of episodes already done, the
         ``ope.ckpt`` checkpoint is restored and the global step is advanced
         by that many increments. Return immediately if all episodes are done.
      3. For each remaining episode, sample a minibatch, call
         ``ope_model.train`` and append the metrics to the stats file (also
         printed). The model is saved as ``ope_best.ckpt`` whenever the ELBO
         improves.

    Returns:
        None. Returns early when the run is already complete or when the ELBO
        becomes NaN.
    """

    ################# Hyper-parameters #########################################

    MINIBATCH_SIZE_OPE = 64
    RANDOM_SEED = 1234
    MAX_EPISODES = 3000

    NUM_OPE_MODELS = 1
    CODE_SIZE = 64
    REPEAT = 1
    BUFFER_SIZE_OPE = 10250

    OPE_LR = .001
    OPE_DS = 1000
    OPE_DR = .95

    beta = .01

    #############################################################################

    best_elbo = -999999.

    rl_params = {'env_name' : "low"}

    file_appendix = (
        "VLM-H_" + rl_params['env_name'] + "_" + str(MAX_EPISODES)
        + "epi_repeat"+ str(REPEAT) + "_"
        + str(OPE_LR) + "_"
        + str(OPE_DS) + "_"
        + str(OPE_DR) + "_"
        + str(CODE_SIZE) + "_"
        + str(beta) + "_"
        + str(RANDOM_SEED)
    )
    stats_path = "./rl_stats/"+file_appendix+".txt"

    # Load offline trajectories.
    # NOTE(security): allow_pickle=True unpickles the file; only load trusted data.
    data = np.load("./{}.npy".format(rl_params['env_name']), allow_pickle=True).item()

    MAX_EPISODE_LEN = np.max([len(traj) for traj in data['observations']])

    env_state_dim = 768
    env_action_dim = 768 * np.max(
        [action.shape[0] for traj in data['actions'] for action in traj])
    env_state_bound = None

    replay_buffer = ReplayBuffer_Trajectory(
        env_state_dim, env_action_dim, MAX_EPISODE_LEN, BUFFER_SIZE_OPE)

    for idx, observations in enumerate(data['observations']):
        # Trajectories shorter than three steps are skipped.
        if len(observations) < 3:
            continue
        states = np.asarray(observations)
        states_next = np.asarray(data['next_observations'][idx])
        actions = _pad_actions(data['actions'][idx], env_action_dim)
        rewards = np.asarray(data['rewards'][idx])
        final_rewards = np.asarray(data['final_rewards'][idx])
        # Terminal flag: all zeros except the last step of the trajectory.
        done = np.zeros_like(rewards)
        done[-1] = 1.
        last_idxs = np.asarray(data['last_idxs'][idx])
        replay_buffer.add_seq(len(states), states, actions,
                              rewards, done, states_next,
                              final_rewards, last_idxs)

    # The stats file doubles as the resume marker: one line per finished episode.
    resuming = os.path.exists(stats_path)
    epis_already_passed = _count_logged_episodes(stats_path)
    if epis_already_passed >= MAX_EPISODES:
        return

    # Make sure the log directory exists before the first append.
    stats_dir = os.path.dirname(stats_path)
    if not os.path.isdir(stats_dir):
        os.makedirs(stats_dir)

    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    graph_ope_models = tf.Graph()
    graph_ac = tf.Graph()
    graph_ope_models_eval = tf.Graph()
    graph_target_policies = tf.Graph()

    # All four sessions are entered in the original nesting order, including the
    # two that are never referenced, so the default session/graph visible to
    # library code inside this block is unchanged.
    with tf.Session(config=config, graph=graph_ac), \
            tf.Session(config=config, graph=graph_ope_models) as sess_ope_models, \
            tf.Session(config=config, graph=graph_ope_models_eval) as sess_ope_models_eval, \
            tf.Session(config=config, graph=graph_target_policies):

        with graph_ope_models.as_default():
            # The graph-level seed is per graph; the call above only affects
            # the global default graph, so set it here as well.
            tf.set_random_seed(RANDOM_SEED)

            ope_model = OPE_Model(
                graph_ope_models, sess_ope_models, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,
                env_state_dim, env_state_bound, env_action_dim, file_appendix,
                BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE,
                MAX_EPISODE_LEN, beta,
                lstm_hidden=128
            )
            ope_model.replay_buffer = replay_buffer

            sess_ope_models.run(tf.global_variables_initializer())

            if resuming:
                ope_model.saver.restore(
                    sess_ope_models,
                    os.path.join(
                        "./saved_model/",
                        file_appendix,
                        "ope.ckpt"
                    )
                )
                # Advance the global step once per episode already logged.
                for _ in range(epis_already_passed):
                    sess_ope_models.run(ope_model.global_step_increment)

        with graph_ope_models_eval.as_default():
            tf.set_random_seed(RANDOM_SEED)

            # NOTE(review): the eval model is built but never used below; the
            # construction is kept in case OPE_Model.__init__ has side effects.
            ope_model_eval = OPE_Model(
                graph_ope_models_eval, sess_ope_models_eval, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,
                env_state_dim, env_state_bound, env_action_dim, file_appendix,
                BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE, MAX_EPISODE_LEN,
                beta, is_training=False, lstm_hidden=128
            )

        for episode in range(epis_already_passed, MAX_EPISODES):

            ep_reward = 0
            metrics = {key: [] for key, _ in _LOGGED_METRICS}

            if ope_model.replay_buffer.size > MINIBATCH_SIZE_OPE:

                batch = ope_model.replay_buffer.sample_batch(MINIBATCH_SIZE_OPE)
                ope_model.train(batch)

                for key, attr in _LOGGED_METRICS:
                    value = getattr(ope_model, attr)
                    metrics[key].append(
                        np.mean([value for _ in range(NUM_OPE_MODELS)]))

                if np.isnan(metrics["elbo"][-1]):
                    print("ELBO is NaN at episode {:d}; stopping training.".format(episode))
                    return

                episode_elbo = np.mean(metrics["elbo"])
                if episode_elbo > best_elbo:
                    ope_model.saver.save(
                        ope_model.sess,
                        ope_model.save_appendix.replace("ope.ckpt", "ope_best.ckpt"))
                    best_elbo = episode_elbo

            stats_line = _format_stats_line(ep_reward, episode, metrics)

            with open(stats_path, "a") as myfile:
                myfile.write(stats_line)

            print(stats_line)
                                
# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()
