{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project modules imported in the next cell (e.g. `core`) importable;\n",
    "# presumably they live one directory above this notebook.\n",
    "import sys\n",
    "\n",
    "sys.path.extend([\"../\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Restrict TensorFlow to GPU 1. Set before anything imports TensorFlow so the value is\n",
    "# guaranteed to be in place when CUDA is initialised (forked pool workers inherit it).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "import random\n",
    "import gym\n",
    "from gym import wrappers\n",
    "from gym.envs.classic_control.pendulum import angle_normalize, PendulumEnv\n",
    "from core import *\n",
    "from utils_latentPolicy_sac_lstm_zt_zt1 import *\n",
    "import tensorflow_probability as tfp\n",
    "import multiprocessing as mp\n",
    "import d4rl\n",
    "import json\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "rnn = tf.contrib.rnn\n",
    "tfd = tfp.distributions\n",
    "\n",
    "# Session config shared by every run; allow_growth lets TF grow GPU memory on demand\n",
    "# instead of reserving it all when the first session starts.\n",
    "config=tf.ConfigProto(log_device_placement=False)\n",
    "config.gpu_options.allow_growth = True\n",
    "\n",
    "# Training data, loaded once; main() reads it through the global name DATA.\n",
    "# NOTE(security): allow_pickle=True lets np.load unpickle (i.e. execute) arbitrary\n",
    "# objects stored in the file -- only point DATA_PATH at trusted data.\n",
    "DATA_PATH = './processed_data_mayo/train_pattern_15_15_15_15.npy'\n",
    "with open(DATA_PATH, 'rb') as data_file:\n",
    "    DATA = np.load(data_file, allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Need to modify: dimension, normalization, call data\n",
    "\n",
    "\n",
    "# Per-episode metric key -> OPE_Model attribute read after each `train` call\n",
    "# (same order as the original per-metric bookkeeping).\n",
    "_OPE_METRIC_ATTRS = (\n",
    "    ('elbo', 'elbo_evaluated'),\n",
    "    ('p_ns', 'likelihood_s_evaluated'),\n",
    "    ('p_r', 'likelihood_r_evaluated'),\n",
    "    ('div1', 'divergence1_evaluated'),\n",
    "    ('div2', 'divergence2_evaluated'),\n",
    "    ('div3', 'divergence3_evaluated'),\n",
    "    ('mse', 'encoder_decoder_lstm_states_mse_evaluated'),\n",
    ")\n",
    "\n",
    "# One format shared by the stats file and stdout so the two outputs cannot drift apart.\n",
    "_LOG_LINE_FMT = (\n",
    "    '| Reward: {:d} | Episode: {:d}  | ELBO: {:.4f} | DIV1: {:.4f} | DIV2: {:.4f} '\n",
    "    '| DIV3: {:.4f} | P_ns: {:.4f} | P_r: {:.4f} | MSE: {:.4f} \\n'\n",
    ")\n",
    "\n",
    "_STATS_DIR = './rl_stats'\n",
    "\n",
    "\n",
    "def _mean_or_nan(values):\n",
    "    \"\"\"Return np.mean(values), or NaN without numpy's empty-slice warning when values is empty.\"\"\"\n",
    "    return np.mean(values) if len(values) > 0 else float('nan')\n",
    "\n",
    "\n",
    "def main(args):\n",
    "    \"\"\"Train one OPE model for a single hyperparameter setting and log per-episode metrics.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    args : tuple\n",
    "        ``(lr, ope_lr, ope_ds, ope_dr, beta)``. ``lr`` only feeds ``rl_params`` and the run\n",
    "        name; ``ope_lr``, ``ope_ds``, ``ope_dr`` and ``beta`` are passed to ``OPE_Model``\n",
    "        (``ope_ds``/``ope_dr`` are presumably a decay step/rate -- confirm in ``OPE_Model``).\n",
    "\n",
    "    Side effects\n",
    "    ------------\n",
    "    * Appends one line per episode to ``./rl_stats/<file_appendix>.txt`` (the directory is\n",
    "      created if missing) and prints the same line.\n",
    "    * Saves a checkpoint at ``ope_model.save_appendix`` with ``ope.ckpt`` replaced by\n",
    "      ``aug_best.ckpt`` whenever the training ELBO improves on the best seen so far.\n",
    "    * Stops early, after logging that episode, once the ELBO becomes NaN.\n",
    "    \"\"\"\n",
    "    lr, ope_lr, ope_ds, ope_dr, beta = args\n",
    "    # lr, ope_lr, ope_ds, ope_dr, beta = 0.0003, 0.0001, 1000, 0.95, 1.\n",
    "\n",
    "    LR = lr\n",
    "    GAMMA = 0.995  # only referenced by the commented-out evaluation code\n",
    "    BUFFER_SIZE_SAC = 2*10**6\n",
    "    MINIBATCH_SIZE_SAC = 256\n",
    "    MINIBATCH_SIZE_OPE = 64\n",
    "    RANDOM_SEED = 2599\n",
    "    MAX_EPISODES = 2000\n",
    "    MAX_EPISODE_LEN = 3\n",
    "    NUM_OPE_MODELS = 1\n",
    "    CODE_SIZE = 16\n",
    "    EXPLORATION = .4\n",
    "    REPEAT = 1\n",
    "    BUFFER_SIZE_OPE = 3000\n",
    "\n",
    "    #     EVAL_EPIs = [400, 10]\n",
    "    #     EVAL_PATHs = [\n",
    "    #         \"./saved_model/SAC_online_HalfCheetah-v2_\"+str(target_epi)+\"epi_200epiLen_0.0003_1234/sac.ckpt\"\n",
    "    #         for target_epi in EVAL_EPIs\n",
    "    #     ]\n",
    "    #     EVAL_TRUTHs = [-129.99324919041212, -1180.3803198136773]\n",
    "\n",
    "    OPE_LR = ope_lr\n",
    "    OPE_DS = ope_ds\n",
    "    OPE_DR = ope_dr\n",
    "\n",
    "    # Start from -inf rather than a finite sentinel (previously -9999.) so the first finite\n",
    "    # ELBO is always checkpointed, however negative it is.\n",
    "    best_elbo = -np.inf\n",
    "\n",
    "    #     OPE_LR = 1e-03\n",
    "    #     OPE_DS = 1000\n",
    "    #     OPE_DR = .995\n",
    "\n",
    "    network_params = {\n",
    "    'hidden_sizes':[256, 256],\n",
    "    'activation':'relu',\n",
    "    'policy':mlp_gaussian_policy\n",
    "    }\n",
    "\n",
    "    rl_params = {\n",
    "        'env_name':'pyrenees',\n",
    "\n",
    "        # control params\n",
    "        'seed': RANDOM_SEED,\n",
    "        'epochs': MAX_EPISODES,\n",
    "        'actor_critic':mlp_actor_critic,\n",
    "        'steps_per_epoch': MAX_EPISODE_LEN,\n",
    "        'replay_size': BUFFER_SIZE_SAC,\n",
    "        'batch_size': MINIBATCH_SIZE_SAC,\n",
    "        'start_epis': 0,\n",
    "        'max_ep_len': MAX_EPISODE_LEN,\n",
    "        'save_freq': 10,\n",
    "        'render': False,\n",
    "\n",
    "        # rl params\n",
    "        'gamma': 0.99,\n",
    "        'polyak': 0.995,\n",
    "        'lr': LR,\n",
    "        'grad_clip_val':None,\n",
    "\n",
    "        # entropy params\n",
    "        'alpha': 'auto',\n",
    "        'target_entropy':'auto' # fixed or auto define with -act_dim\n",
    "    }\n",
    "\n",
    "    # Run name: encodes every swept hyperparameter so parallel runs write distinct files.\n",
    "    file_appendix = (\n",
    "        \"lstm_vae_\" + rl_params['env_name'] + \"_\" + str(MAX_EPISODES)\n",
    "        + \"epi_repeat\"+ str(REPEAT) + \"_\" + str(LR) + \"_\"\n",
    "        + str(OPE_LR) + \"_\"\n",
    "        + str(OPE_DS) + \"_\"\n",
    "        + str(OPE_DR) + \"_\"\n",
    "        + str(CODE_SIZE) + \"_\"\n",
    "        + str(beta) + \"_\"\n",
    "        + str(RANDOM_SEED)\n",
    "    )\n",
    "    stats_path = _STATS_DIR + '/' + file_appendix + '.txt'\n",
    "    # Appending to stats_path raises FileNotFoundError if the directory does not exist yet.\n",
    "    os.makedirs(_STATS_DIR, exist_ok=True)\n",
    "\n",
    "    #     env = gym.make(rl_params['env_name'])\n",
    "    np.random.seed(RANDOM_SEED)\n",
    "    tf.set_random_seed(RANDOM_SEED)\n",
    "    #     env.seed(RANDOM_SEED)\n",
    "\n",
    "    env_state_dim = 130 # NEED MOD\n",
    "    # state_dim = CODE_SIZE\n",
    "    env_action_dim = 3 # NEED MOD\n",
    "    env_action_bound = None # NEED MOD\n",
    "    env_state_bound = None\n",
    "    # Ensure action bound is symmetric\n",
    "    #     assert (env.action_space.high == -env.action_space.low)\n",
    "\n",
    "    # Two independent graphs/sessions: the model being trained, and a second OPE_Model built\n",
    "    # with is_training=False (used only by the commented-out evaluation code).\n",
    "    graph_ope_models = tf.Graph()\n",
    "    graph_ope_models_eval = tf.Graph()\n",
    "\n",
    "    with tf.Session(config=config, graph=graph_ope_models) as sess_ope_models:\n",
    "        with tf.Session(config=config, graph=graph_ope_models_eval) as sess_ope_models_eval:\n",
    "\n",
    "    #             d4rl_qlearning = d4rl.qlearning_dataset(env)\n",
    "            # NEED MOD: identity normalisation (mean 0, std 1) for observations and rewards.\n",
    "            obs_mean = 0.\n",
    "            obs_std = 1.\n",
    "\n",
    "            rew_mean = 0.\n",
    "            rew_std = 1.\n",
    "            # NEED MOD\n",
    "\n",
    "            with graph_ope_models.as_default():\n",
    "\n",
    "                ope_model = OPE_Model(\n",
    "                    graph_ope_models, sess_ope_models, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,\n",
    "                    env_state_dim, env_state_bound, env_action_dim, file_appendix,\n",
    "                    BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE, MAX_EPISODE_LEN, beta\n",
    "                )\n",
    "\n",
    "                sess_ope_models.run(tf.global_variables_initializer())\n",
    "\n",
    "                # DATA is the module-level array loaded in the imports cell.\n",
    "                ope_model.replay_buffer.port_d4rl_data(\n",
    "    #                     d4rl.sequence_dataset(env), # original D4RL data format !!!!\n",
    "                    DATA,\n",
    "                    obs_mean,\n",
    "                    obs_std,\n",
    "                    rew_mean,\n",
    "                    rew_std,\n",
    "                )\n",
    "\n",
    "            with graph_ope_models_eval.as_default():\n",
    "\n",
    "                ope_model_eval = OPE_Model(\n",
    "                    graph_ope_models_eval, sess_ope_models_eval, OPE_LR, OPE_DS, OPE_DR, CODE_SIZE,\n",
    "                    env_state_dim, env_state_bound, env_action_dim, file_appendix,\n",
    "                    BUFFER_SIZE_OPE, RANDOM_SEED, MINIBATCH_SIZE_OPE, MAX_EPISODE_LEN, \n",
    "                    beta, is_training=False\n",
    "                )\n",
    "\n",
    "            # NOTE(review): not used below; presumably left over from the SAC agent code.\n",
    "            actor_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(env_action_dim))\n",
    "\n",
    "            for i in range(MAX_EPISODES):\n",
    "                # No environment roll-outs happen here, so the logged reward is always 0.\n",
    "                ep_reward = 0\n",
    "                ep_metrics = {key: [] for key, _ in _OPE_METRIC_ATTRS}\n",
    "                diverged = False\n",
    "\n",
    "                # Train only once the buffer holds more than four OPE minibatches.\n",
    "                if ope_model.replay_buffer.size > MINIBATCH_SIZE_OPE * 4:\n",
    "                    batch = ope_model.replay_buffer.sample_batch(MINIBATCH_SIZE_OPE)\n",
    "                    ope_model.train(batch)\n",
    "                    for key, attr in _OPE_METRIC_ATTRS:\n",
    "                        ep_metrics[key].append(\n",
    "                            np.mean([getattr(ope_model, attr) for _ in range(NUM_OPE_MODELS)])\n",
    "                        )\n",
    "\n",
    "                    latest_elbo = ep_metrics['elbo'][-1]\n",
    "                    # NaN compares False here, so a diverged step never overwrites the best checkpoint.\n",
    "                    if latest_elbo > best_elbo:\n",
    "                        best_elbo = latest_elbo\n",
    "                        ope_model.saver.save(\n",
    "                            ope_model.sess,\n",
    "                            ope_model.save_appendix.replace('ope.ckpt', 'aug_best.ckpt'),\n",
    "                        )\n",
    "                    diverged = bool(np.isnan(latest_elbo))\n",
    "\n",
    "    #                     if (i+1) % 50 == 0 and ope_model.replay_buffer.size > MINIBATCH_SIZE_OPE * 4:\n",
    "    #                         mae = evaluate(\n",
    "    #                             ope_model_eval, \n",
    "    #                             graph_ope_models_eval, \n",
    "    #                             sess_ope_models_eval, \n",
    "    #                             MAX_EPISODE_LEN,\n",
    "    #                             REPEAT,\n",
    "    #                             env_state_dim,\n",
    "    #                             env_action_dim,\n",
    "    #                             RANDOM_SEED,\n",
    "    #                             obs_mean, \n",
    "    #                             obs_std, \n",
    "    #                             rew_mean, \n",
    "    #                             rew_std,\n",
    "    #                             GAMMA\n",
    "    #                         )\n",
    "    #                         if mae < BEST_MAE:\n",
    "    #                             ope_model.saver.save(ope_model.sess, ope_model.save_appendix.replace(\"ope.ckpt\", \"ope_best.ckpt\"))\n",
    "    #                             BEST_MAE = mae\n",
    "\n",
    "                log_line = _LOG_LINE_FMT.format(\n",
    "                    int(ep_reward),\n",
    "                    i,\n",
    "                    _mean_or_nan(ep_metrics['elbo']),\n",
    "                    _mean_or_nan(ep_metrics['div1']),\n",
    "                    _mean_or_nan(ep_metrics['div2']),\n",
    "                    _mean_or_nan(ep_metrics['div3']),\n",
    "                    _mean_or_nan(ep_metrics['p_ns']),\n",
    "                    _mean_or_nan(ep_metrics['p_r']),\n",
    "                    _mean_or_nan(ep_metrics['mse']),\n",
    "                )\n",
    "                with open(stats_path, 'a') as stats_file:\n",
    "                    stats_file.write(log_line)\n",
    "                print(log_line)\n",
    "\n",
    "                if diverged:\n",
    "                    # Abort this hyperparameter run; the NaN episode has been logged above.\n",
    "                    return\n",
    "                                \n",
    "# def evaluate(ope_eval, graph_ope_eval, sess_ope_eval, *args):\n",
    "#     MAX_EPISODE_LEN, REPEAT, env_state_dim, env_action_dim, RANDOM_SEED, obs_mean, obs_std, rew_mean, rew_std, GAMMA = args\n",
    "    \n",
    "#     with tf.io.gfile.GFile(\"../../d4rl/deep_ope/d4rl_policies.json\", 'r') as f:\n",
    "#         policy_database = json.load(f)\n",
    "\n",
    "#     policy_metadatas = [i for i in policy_database if i['task.task_names'][0].find(\"halfcheetah\")!=-1]\n",
    "    \n",
    "#     EVAL_TRUTHs = [i['return_mean'] for i in policy_metadatas]\n",
    "    \n",
    "#     pred = []\n",
    "#     class LearnedEnv(object):\n",
    "#         def __init__(self, model):\n",
    "\n",
    "# #             super(LearnedPendulum, self).__init__()\n",
    "#             self.model = model\n",
    "\n",
    "#         def reset(self):\n",
    "#             self.model.init_z0_s0()\n",
    "# #                 s0 = np.array([ 0.76898139, -0.63927117,  0.30185718])\n",
    "#             s0 = self.model.sess.run(self.model.decoder_state_sample, \n",
    "#                                feed_dict={self.model.decoder_zt_holder:self.model.zt}).reshape(-1)\n",
    "\n",
    "#             self.obs = s0\n",
    "#             return s0\n",
    "\n",
    "#         def step(self, u):\n",
    "#             new_obs, reward = self.model.get_zt1_s2_r(np.reshape(u, (1, env_action_dim)))\n",
    "#             self.obs = new_obs\n",
    "#             self.model.update_zt()\n",
    "\n",
    "#             return new_obs, reward, False, {}\n",
    "\n",
    "#     learned_env = LearnedEnv(ope_eval)\n",
    "# #     learned_env.seed(RANDOM_SEED)\n",
    "\n",
    "#     for _i in range(len(policy_metadatas)):\n",
    "#         with graph_ope_eval.as_default():\n",
    "#             ope_eval.saver.restore(sess_ope_eval, ope_eval.save_appendix)\n",
    "#         ep_rewards = []\n",
    "#         policy = D4RL_Policy(policy_metadatas[_i]['policy_path'])\n",
    "#         for i in range(5):\n",
    "\n",
    "#             terminal = 0\n",
    "\n",
    "#             s = learned_env.reset()\n",
    "#             s = s.reshape(env_state_dim)*obs_std + obs_mean\n",
    "#             ep_reward = 0\n",
    "\n",
    "#             for j in range(MAX_EPISODE_LEN):\n",
    "\n",
    "#                 if j % REPEAT == 0:\n",
    "#                     a, _ = policy.act(np.reshape(s, (env_state_dim,)), np.zeros((env_action_dim,)))\n",
    "#                 s2, r, terminal, info = learned_env.step(a)\n",
    "#                 r = r*rew_std + rew_mean\n",
    "#                 s2 = s2.reshape(env_state_dim)*obs_std + obs_mean\n",
    "\n",
    "#                 ep_reward += r*(GAMMA**j)\n",
    "\n",
    "#                 s = s2\n",
    "\n",
    "#                 if terminal or j == MAX_EPISODE_LEN-1:\n",
    "#                     ep_rewards += [ep_reward]\n",
    "\n",
    "#                     break\n",
    "#         pred += [np.mean(ep_rewards)]\n",
    "#     return np.mean(np.abs(np.asarray(EVAL_TRUTHs)-np.asarray(pred)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter grid; the pool cell runs main() once per combination of these lists.\n",
    "LRs = [3e-4]                                     # -> main's `lr`\n",
    "OPE_LRs = [3e-3, 1e-4, 3e-4, 5e-4, 7e-4, 9e-4]   # -> main's `ope_lr`\n",
    "OPE_DSs = [1000]                                 # -> main's `ope_ds`\n",
    "OPE_DRs = [0.9]                                  # -> main's `ope_dr`\n",
    "BETAs = [1.0, 0.1, 0.05, 0.01, 5.0, 10.0]        # -> main's `beta`\n",
    "# BETAs = [.5, .05, .005, 1., 5.]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One main() call per point of the hyperparameter grid, spread over 3 worker processes.\n",
    "param_grid = [\n",
    "    (lr, ope_lr, ope_ds, ope_dr, beta)\n",
    "    for lr in LRs\n",
    "    for ope_lr in OPE_LRs\n",
    "    for ope_ds in OPE_DSs\n",
    "    for ope_dr in OPE_DRs\n",
    "    for beta in BETAs\n",
    "]\n",
    "\n",
    "pool = mp.Pool(3)\n",
    "try:\n",
    "    pool.map(main, param_grid)\n",
    "except BaseException:\n",
    "    # A run raised or the kernel was interrupted: kill the workers instead of leaking them\n",
    "    # (close() + join() would wait for every still-queued run to finish).\n",
    "    pool.terminate()\n",
    "    raise\n",
    "else:\n",
    "    pool.close()\n",
    "finally:\n",
    "    pool.join()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
