{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b058b71",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3647192631.py:9: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Warning: Flow failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n",
      "No module named 'flow'\n",
      "/home/ggao5/anaconda3/envs/ope_py37/lib/python3.7/site-packages/glfw/__init__.py:906: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'\n",
      "  warnings.warn(message, GLFWError)\n",
      "Warning: CARLA failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n",
      "No module named 'carla'\n",
      "pybullet build time: Dec  1 2021 18:33:04\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import tensorflow_probability as tfp\n",
    "from matplotlib import pyplot as plt\n",
    "import os \n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "slim = tf.contrib.slim\n",
    "tfd = tfp.distributions\n",
    "session_config = tf.ConfigProto(log_device_placement=False)\n",
    "session_config.gpu_options.allow_growth = True\n",
    "\n",
    "import gym\n",
    "from gym import wrappers\n",
    "from gym.envs.classic_control.pendulum import angle_normalize, PendulumEnv\n",
    "import d4rl\n",
    "\n",
    "from tensorflow.nn.rnn_cell import LSTMStateTuple\n",
    "rnn = tf.contrib.rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "138951e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ggao5/anaconda3/envs/ope_py37/lib/python3.7/site-packages/gym/spaces/box.py:74: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  \"Box bound precision lowered by casting to {}\".format(self.dtype)\n",
      "load datafile: 100%|██████████| 8/8 [00:00<00:00, 355.36it/s]\n"
     ]
    }
   ],
   "source": [
    "# input data: D4RL dataset split into per-episode trajectories\n",
    "ENV_NAME = \"pen-human-v1\"\n",
    "env = gym.make(ENV_NAME)\n",
    "d4rl_original_data = list(d4rl.sequence_dataset(env))\n",
    "\n",
    "state_dim = d4rl_original_data[0]['observations'].shape[1]\n",
    "action_dim = d4rl_original_data[0]['actions'].shape[1]\n",
    "is_training = True\n",
    "CODE_SIZE = action_dim  # the policy outputs one (loc, scale) pair per action dimension\n",
    "horizon = d4rl_original_data[0]['observations'].shape[0]\n",
    "# ReplayBuffer_Trajectory stores fixed-length trajectories, so fail early if episode lengths differ\n",
    "traj_lengths = sorted({traj['observations'].shape[0] for traj in d4rl_original_data})\n",
    "assert traj_lengths == [horizon], \"episodes have different lengths: {}\".format(traj_lengths)\n",
    "buffer_size = 3000  # capacity, in trajectories\n",
    "MAX_EPISODES = 100  # 100 ~ 2000 according to data size\n",
    "MINIBATCH_SIZE = 4  # 4~64 according to data size\n",
    "\n",
    "lr = 0.001\n",
    "BEST_LOSS = 9999.  # initial threshold: a checkpoint is saved whenever the training loss beats it\n",
    "\n",
    "EPS = 1e-8  # small constant for numerical stability\n",
    "save_path = \"./saved_dist/state_action_dist.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9353c8a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "45"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "60aa015b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Box([-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.\n",
       " -1. -1. -1. -1. -1. -1.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], (24,), float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.action_space # check action bound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3dc0a6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get mean and std\n",
    "# One scalar mean/std is shared by all observation dimensions (and one pair for rewards).\n",
    "# Statistics are accumulated in float64 instead of summing element by element in the data's dtype.\n",
    "all_obs = np.concatenate([traj['observations'] for traj in d4rl_original_data], axis=0)\n",
    "obs_mean = all_obs.mean(dtype=np.float64)\n",
    "obs_std = all_obs.std(dtype=np.float64)\n",
    "\n",
    "all_rew = np.concatenate([traj['rewards'] for traj in d4rl_original_data], axis=0)\n",
    "rew_mean = all_rew.mean(dtype=np.float64)\n",
    "rew_std = all_rew.std(dtype=np.float64)\n",
    "assert obs_std > 0 and rew_std > 0, \"zero std: normalisation would divide by zero\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "80f5e175",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer_Trajectory(object):\n",
    "    \"\"\"Ring buffer that stores whole trajectories of a fixed length `horizon`.\n",
    "\n",
    "    Each slot holds one trajectory: observations (horizon, obs_dim), actions\n",
    "    (horizon, act_dim), rewards (horizon,) and done flags (horizon,), all float32.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, obs_dim, act_dim, horizon, size):\n",
    "        # size is the capacity in number of trajectories, not transitions\n",
    "        self.obs1_buf = np.zeros([size, horizon, obs_dim], dtype=np.float32)\n",
    "        self.acts_buf = np.zeros([size, horizon, act_dim], dtype=np.float32)\n",
    "        self.rews_buf = np.zeros([size, horizon], dtype=np.float32)\n",
    "        self.done_buf = np.zeros([size, horizon], dtype=np.float32)\n",
    "        # ptr0: slot written next; ptr1: time step within that slot (only advanced by add)\n",
    "        self.ptr0, self.ptr1, self.size, self.max_size, self.horizon = 0, 0, 0, size, horizon\n",
    "        self.count = 0  # number of trajectories written so far (not capped at max_size)\n",
    "\n",
    "    def port_d4rl_data(self, d4rl_data, obs_mean, obs_std, rew_mean, rew_std):\n",
    "        \"\"\"Copy D4RL sequence-dataset trajectories into the buffer, normalising on the way.\n",
    "\n",
    "        Observations are stored as (obs - obs_mean) / obs_std and rewards as\n",
    "        (rew - rew_mean) / rew_std; actions and terminals are stored as-is.\n",
    "        Only supports running this **before training starts**: trajectories are\n",
    "        written starting from the current slot `ptr0`.\n",
    "\n",
    "        Args:\n",
    "            d4rl_data: list or generator of dicts with 'observations', 'actions',\n",
    "                'rewards' and 'terminals' arrays, each `horizon` steps long.\n",
    "            obs_mean, obs_std, rew_mean, rew_std: normalisation statistics.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if there are more trajectories than the buffer capacity, or\n",
    "                any trajectory length differs from `horizon` (checked before writing).\n",
    "        \"\"\"\n",
    "        d4rl_data = list(d4rl_data)  # materialise a generator so it can be counted\n",
    "        if len(d4rl_data) > self.max_size:\n",
    "            # explicit raise rather than `assert`, so the check survives `python -O`\n",
    "            raise ValueError(\"Buffer size smaller than the size of d4rl data, cannot port in\")\n",
    "        bad_lengths = [traj['observations'].shape[0] for traj in d4rl_data\n",
    "                       if traj['observations'].shape[0] != self.horizon]\n",
    "        if bad_lengths:\n",
    "            raise ValueError(\"Trajectory lengths {} do not match buffer horizon {}\".format(\n",
    "                sorted(set(bad_lengths)), self.horizon))\n",
    "\n",
    "        for traj in d4rl_data:\n",
    "            self.obs1_buf[self.ptr0] = (traj['observations'].astype(np.float32) - obs_mean) / obs_std\n",
    "            self.acts_buf[self.ptr0] = traj['actions'].astype(np.float32)\n",
    "            self.rews_buf[self.ptr0] = (traj['rewards'].astype(np.float32) - rew_mean) / rew_std\n",
    "            self.done_buf[self.ptr0] = traj['terminals'].astype(np.float32)\n",
    "            self.size = min(self.size + 1, self.max_size)\n",
    "            self.ptr0 = (self.ptr0 + 1) % self.max_size\n",
    "            self.count += 1\n",
    "\n",
    "    def add(self, obs, act, rew, done):\n",
    "        \"\"\"Write one transition at (ptr0, ptr1); move to the next slot after `horizon` steps.\"\"\"\n",
    "        self.obs1_buf[self.ptr0, self.ptr1] = obs\n",
    "        self.acts_buf[self.ptr0, self.ptr1] = act\n",
    "        self.rews_buf[self.ptr0, self.ptr1] = rew\n",
    "        self.done_buf[self.ptr0, self.ptr1] = done\n",
    "        self.ptr1 = (self.ptr1 + 1) % self.horizon\n",
    "        if self.ptr1 == 0:\n",
    "            # the current trajectory slot is complete\n",
    "            self.size = min(self.size + 1, self.max_size)\n",
    "            self.ptr0 = (self.ptr0 + 1) % self.max_size\n",
    "            self.count += 1\n",
    "\n",
    "    def sample_batch(self, batch_size=32):\n",
    "        \"\"\"Sample `batch_size` stored trajectories uniformly, with replacement.\n",
    "\n",
    "        Returns a dict with keys 'obs1', 'acts', 'rews', 'done', all indexed by the\n",
    "        same random slot indices (leading axis of length batch_size).\n",
    "        \"\"\"\n",
    "        idxs = np.random.randint(0, self.size, size=batch_size)\n",
    "        return dict(obs1=self.obs1_buf[idxs],\n",
    "                    acts=self.acts_buf[idxs],\n",
    "                    rews=self.rews_buf[idxs],\n",
    "                    done=self.done_buf[idxs])\n",
    "\n",
    "    def save(self, path):\n",
    "        \"\"\"Save all buffer arrays (full capacity, including unused slots) with np.savez.\"\"\"\n",
    "        np.savez(path,\n",
    "                 obs1_buf=self.obs1_buf,\n",
    "                 acts_buf=self.acts_buf,\n",
    "                 rews_buf=self.rews_buf,\n",
    "                 done_buf=self.done_buf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dfd8b18a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def trun_normal_log_prob(x, mu, std, low, high):\n",
    "    \"\"\"Log-density of a normal N(mu, std) truncated to [low, high], summed over axis 1.\n",
    "\n",
    "    log p(x) = log phi((x - mu) / std) - log(std) - log(Phi((high - mu) / std) - Phi((low - mu) / std))\n",
    "    with phi / Phi the standard normal pdf / cdf. EPS is added to std and the normaliser is floored\n",
    "    at EPS, so the result stays finite when the predicted scale or the in-bounds mass is tiny.\n",
    "\n",
    "    Args:\n",
    "        x: values to score, same shape as mu.\n",
    "        mu, std: location and scale of the untruncated normal.\n",
    "        low, high: truncation bounds (scalars or tensors broadcastable to x).\n",
    "    Returns:\n",
    "        Tensor of log-probabilities summed over axis 1.\n",
    "    \"\"\"\n",
    "    std_safe = std + EPS\n",
    "    std_normal = tfd.Normal(0., 1.)\n",
    "    # The normaliser is the mass of N(mu, std) inside [low, high]; it depends on mu, not on x.\n",
    "    z = std_normal.cdf((high - mu) / std_safe) - std_normal.cdf((low - mu) / std_safe)\n",
    "    log_kernel = -0.5 * ((x - mu) / std_safe) ** 2 - 0.5 * tf.log(2 * np.pi)\n",
    "    log_norm = tf.log(std_safe) + tf.log(tf.maximum(z, EPS))\n",
    "    return tf.reduce_sum(log_kernel - log_norm, axis=1, name=\"log_prob\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8361b034",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3236672759.py:1: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3236672759.py:1: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3236672759.py:3: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3236672759.py:3: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty default graph so re-running the model cells does not add duplicate ops.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Per-time-step inputs; the leading (batch) dimension is left unspecified.\n",
    "state_holder = tf.placeholder(tf.float32, [None, state_dim], 'state_holder')\n",
    "action_holder = tf.placeholder(tf.float32, [None, action_dim], 'action_holder')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "253b9b09",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor 'state_holder:0' shape=(?, 45) dtype=float32>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state_holder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "db4e6925",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/1375984545.py:1: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/1375984545.py:1: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def learn_dist_from_s(state, code_size, reuse=tf.AUTO_REUSE, is_training=True, var_scope=\"BC\",\n",
    "                      action=None, low=-1., high=1.):\n",
    "    \"\"\"Behaviour-cloning policy: MLP from state to a truncated-normal distribution over actions.\n",
    "\n",
    "    Args:\n",
    "        state: float tensor of shape (batch, state_dim).\n",
    "        code_size: number of action dimensions produced.\n",
    "        reuse: variable-scope reuse mode (AUTO_REUSE shares weights across calls).\n",
    "        is_training: forwarded to slim.batch_norm; pass False for evaluation.\n",
    "        var_scope: variable scope that holds the network weights.\n",
    "        action: actions to score; defaults to the global `action_holder` placeholder.\n",
    "        low, high: truncation bounds of the action distribution (default [-1, 1]).\n",
    "\n",
    "    Returns:\n",
    "        (out_sample, out_log_prob): one sampled action per row, and the log-probability of\n",
    "        `action` under the predicted distribution (summed over action dimensions).\n",
    "    \"\"\"\n",
    "    if action is None:\n",
    "        action = action_holder  # backward-compatible default: the module-level placeholder\n",
    "    with tf.variable_scope(var_scope, reuse=reuse) as scope:\n",
    "        with slim.arg_scope([slim.fully_connected],\n",
    "                            activation_fn=tf.nn.relu,\n",
    "                            weights_initializer=tf.glorot_uniform_initializer,\n",
    "                            weights_regularizer=slim.l2_regularizer(0.001),\n",
    "                            biases_regularizer=slim.l2_regularizer(0.001),\n",
    "                            normalizer_fn=slim.batch_norm,\n",
    "                            normalizer_params={\"is_training\": is_training},\n",
    "                            reuse=reuse,\n",
    "                            scope=scope):\n",
    "            x = slim.fully_connected(state, 128, scope=\"fc1\")\n",
    "            x = slim.fully_connected(x, 64, scope=\"fc2\")\n",
    "            # the output heads override activation_fn but still inherit normalizer_fn from arg_scope\n",
    "            loc = slim.fully_connected(x, code_size, activation_fn=None, scope=\"loc\")\n",
    "            scale = slim.fully_connected(x, code_size, activation_fn=tf.nn.softplus, scope=\"scale\")\n",
    "            out_sample = tfd.TruncatedNormal(loc, scale, low, high).sample()\n",
    "            out_log_prob = trun_normal_log_prob(action, loc, scale, low, high)\n",
    "            return out_sample, out_log_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "23635a15",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/1375984545.py:2: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/1375984545.py:2: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ggao5/anaconda3/envs/ope_py37/lib/python3.7/site-packages/tensorflow_core/contrib/layers/python/layers/layers.py:1866: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ggao5/anaconda3/envs/ope_py37/lib/python3.7/site-packages/tensorflow_core/contrib/layers/python/layers/layers.py:1866: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/2797447731.py:3: The name tf.log is deprecated. Please use tf.math.log instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/2797447731.py:3: The name tf.log is deprecated. Please use tf.math.log instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/344484521.py:4: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/344484521.py:4: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ggao5/anaconda3/envs/ope_py37/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:1375: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ggao5/anaconda3/envs/ope_py37/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:1375: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/344484521.py:6: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/344484521.py:6: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "sample, log_prob = learn_dist_from_s(state_holder, CODE_SIZE)\n",
    "# behaviour-cloning objective: mean negative log-likelihood of the dataset actions (the logged metric)\n",
    "loss = tf.reduce_mean(-log_prob)\n",
    "# slim's weights/biases_regularizer only register terms in the REGULARIZATION_LOSSES collection;\n",
    "# they have no effect unless they are added to the optimised objective explicitly.\n",
    "total_loss = loss + tf.losses.get_regularization_loss()\n",
    "# slim.batch_norm (is_training=True) puts its moving mean/variance updates in UPDATE_OPS;\n",
    "# run them with every optimizer step, otherwise evaluation-mode batch norm uses the initial statistics.\n",
    "update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "with tf.control_dependencies(update_ops):\n",
    "    optimize = tf.train.AdamOptimizer(lr).minimize(total_loss)\n",
    "\n",
    "saver = tf.train.Saver()  # saves all global variables, including the batch-norm moving statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "50a4c0f2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3606388409.py:1: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3606388409.py:1: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3606388409.py:2: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-11-14 23:06:26.166320: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA\n",
      "2022-11-14 23:06:26.177941: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2397365000 Hz\n",
      "2022-11-14 23:06:26.179378: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55982e10a390 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
      "2022-11-14 23:06:26.179431: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version\n",
      "2022-11-14 23:06:26.180875: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1\n",
      "2022-11-14 23:06:26.199269: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1618] Found device 0 with properties: \n",
      "name: TITAN Xp major: 6 minor: 1 memoryClockRate(GHz): 1.582\n",
      "pciBusID: 0000:83:00.0\n",
      "2022-11-14 23:06:26.199531: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0\n",
      "2022-11-14 23:06:26.201234: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0\n",
      "2022-11-14 23:06:26.202908: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10.0\n",
      "2022-11-14 23:06:26.203222: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so.10.0\n",
      "2022-11-14 23:06:26.205281: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusolver.so.10.0\n",
      "2022-11-14 23:06:26.206848: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusparse.so.10.0\n",
      "2022-11-14 23:06:26.211368: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7\n",
      "2022-11-14 23:06:26.213866: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Adding visible gpu devices: 0\n",
      "2022-11-14 23:06:26.213923: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0\n",
      "2022-11-14 23:06:26.312848: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1159] Device interconnect StreamExecutor with strength 1 edge matrix:\n",
      "2022-11-14 23:06:26.312889: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1165]      0 \n",
      "2022-11-14 23:06:26.312899: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1178] 0:   N \n",
      "2022-11-14 23:06:26.315913: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1304] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 11306 MB memory) -> physical GPU (device: 0, name: TITAN Xp, pci bus id: 0000:83:00.0, compute capability: 6.1)\n",
      "2022-11-14 23:06:26.318596: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55982ebd0ba0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2022-11-14 23:06:26.318613: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): TITAN Xp, Compute Capability 6.1\n",
      "WARNING:tensorflow:From /tmp/ipykernel_29752/3606388409.py:2: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n",
      "2022-11-14 23:06:26.865526: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0\n",
      "2022-11-14 23:06:27.037930: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epi: 0, loss: 39.58692932128906\n",
      "epi: 1, loss: 25.042978286743164\n",
      "epi: 2, loss: 17.702434539794922\n",
      "epi: 3, loss: 13.543204307556152\n",
      "epi: 4, loss: 15.400638580322266\n",
      "epi: 5, loss: 12.138121604919434\n",
      "epi: 6, loss: 14.768411636352539\n",
      "epi: 7, loss: 9.277633666992188\n",
      "epi: 8, loss: 8.310900688171387\n",
      "epi: 9, loss: 8.571642875671387\n",
      "epi: 10, loss: 6.742438793182373\n",
      "epi: 11, loss: 11.217260360717773\n",
      "epi: 12, loss: 16.24464225769043\n",
      "epi: 13, loss: 9.781713485717773\n",
      "epi: 14, loss: 11.194061279296875\n",
      "epi: 15, loss: 11.077701568603516\n",
      "epi: 16, loss: 10.029929161071777\n",
      "epi: 17, loss: 10.472651481628418\n",
      "epi: 18, loss: 9.519920349121094\n",
      "epi: 19, loss: 6.719595909118652\n",
      "epi: 20, loss: 7.029800891876221\n",
      "epi: 21, loss: 13.228585243225098\n",
      "epi: 22, loss: 12.33446979522705\n",
      "epi: 23, loss: 9.70638656616211\n",
      "epi: 24, loss: 7.755185127258301\n",
      "epi: 25, loss: 9.388580322265625\n",
      "epi: 26, loss: 5.630100250244141\n",
      "epi: 27, loss: 4.08254337310791\n",
      "epi: 28, loss: 8.023987770080566\n",
      "epi: 29, loss: 11.464365005493164\n",
      "epi: 30, loss: 6.149938583374023\n",
      "epi: 31, loss: 3.870405912399292\n",
      "epi: 32, loss: 3.5955090522766113\n",
      "epi: 33, loss: 2.9530019760131836\n",
      "epi: 34, loss: 4.822900772094727\n",
      "epi: 35, loss: 6.9701666831970215\n",
      "epi: 36, loss: 2.4943501949310303\n",
      "epi: 37, loss: 1.1058002710342407\n",
      "epi: 38, loss: 1.5162769556045532\n",
      "epi: 39, loss: -0.5611898899078369\n",
      "epi: 40, loss: 2.7761919498443604\n",
      "epi: 41, loss: 9.770585060119629\n",
      "epi: 42, loss: 5.544129848480225\n",
      "epi: 43, loss: 5.064733982086182\n",
      "epi: 44, loss: 2.477437734603882\n",
      "epi: 45, loss: 8.005799293518066\n",
      "epi: 46, loss: 9.194964408874512\n",
      "epi: 47, loss: 5.050869941711426\n",
      "epi: 48, loss: 7.789857864379883\n",
      "epi: 49, loss: 3.173900842666626\n",
      "epi: 50, loss: 6.707129955291748\n",
      "epi: 51, loss: 2.2349371910095215\n",
      "epi: 52, loss: 6.700491905212402\n",
      "epi: 53, loss: 1.5279819965362549\n",
      "epi: 54, loss: 2.07889461517334\n",
      "epi: 55, loss: 3.2918028831481934\n",
      "epi: 56, loss: 3.5389652252197266\n",
      "epi: 57, loss: 3.5560202598571777\n",
      "epi: 58, loss: 3.170039176940918\n",
      "epi: 59, loss: 3.1129579544067383\n",
      "epi: 60, loss: 2.1282460689544678\n",
      "epi: 61, loss: 0.9162510633468628\n",
      "epi: 62, loss: -0.21727539598941803\n",
      "epi: 63, loss: 10.56550407409668\n",
      "epi: 64, loss: 6.9152302742004395\n",
      "epi: 65, loss: 2.399336099624634\n",
      "epi: 66, loss: 0.9894439578056335\n",
      "epi: 67, loss: -0.7547097206115723\n",
      "epi: 68, loss: 2.5901927947998047\n",
      "epi: 69, loss: 4.4673991203308105\n",
      "epi: 70, loss: -0.6391945481300354\n",
      "epi: 71, loss: 4.05164909362793\n",
      "epi: 72, loss: 3.038358211517334\n",
      "epi: 73, loss: 2.690361261367798\n",
      "epi: 74, loss: 1.0292493104934692\n",
      "epi: 75, loss: -0.8883960843086243\n",
      "epi: 76, loss: 2.424910306930542\n",
      "epi: 77, loss: -0.9165894985198975\n",
      "epi: 78, loss: -1.7329210042953491\n",
      "epi: 79, loss: 5.601339817047119\n",
      "epi: 80, loss: 3.6654608249664307\n",
      "epi: 81, loss: 7.625931262969971\n",
      "epi: 82, loss: 8.596928596496582\n",
      "epi: 83, loss: 1.640761375427246\n",
      "epi: 84, loss: 1.360178828239441\n",
      "epi: 85, loss: 1.559096336364746\n",
      "epi: 86, loss: 0.5589019656181335\n",
      "epi: 87, loss: 0.5139387249946594\n",
      "epi: 88, loss: 4.018736362457275\n",
      "epi: 89, loss: 0.13354384899139404\n",
      "epi: 90, loss: -0.2578454613685608\n",
      "epi: 91, loss: 3.781846523284912\n",
      "epi: 92, loss: 2.6530044078826904\n",
      "epi: 93, loss: 2.493283271789551\n",
      "epi: 94, loss: 13.844751358032227\n",
      "epi: 95, loss: 6.401968002319336\n",
      "epi: 96, loss: 2.0268585681915283\n",
      "epi: 97, loss: 4.376383304595947\n",
      "epi: 98, loss: 3.541147470474243\n",
      "epi: 99, loss: 4.4695634841918945\n"
     ]
    }
   ],
   "source": [
    "sess = tf.InteractiveSession(config=session_config)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "replay_buffer = ReplayBuffer_Trajectory(state_dim, action_dim, horizon, buffer_size)\n",
    "replay_buffer.port_d4rl_data(d4rl_original_data, obs_mean, obs_std, rew_mean, rew_std)\n",
    "# the buffer is not refilled during training, so check once instead of silently skipping every episode\n",
    "assert replay_buffer.size >= MINIBATCH_SIZE, \"need at least MINIBATCH_SIZE trajectories to train\"\n",
    "\n",
    "# tf.train.Saver.save fails if the checkpoint's parent directory does not exist\n",
    "save_dir = os.path.dirname(save_path)\n",
    "if save_dir:\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# the best loss is tracked locally so re-running this cell starts again from the BEST_LOSS threshold\n",
    "best_loss = BEST_LOSS\n",
    "for epi in range(MAX_EPISODES):\n",
    "    batch = replay_buffer.sample_batch(MINIBATCH_SIZE)\n",
    "    s_batch, a_batch = batch[\"obs1\"], batch[\"acts\"]\n",
    "\n",
    "    # one optimizer step per time index; each step sees that time index of every sampled trajectory\n",
    "    loss_list = []\n",
    "    for _t in range(horizon):\n",
    "        feed_dict = {action_holder: a_batch[:, _t, :], state_holder: s_batch[:, _t, :]}\n",
    "        loss_val, _ = sess.run([loss, optimize], feed_dict)\n",
    "        loss_list.append(loss_val)\n",
    "\n",
    "    epi_loss = np.mean(loss_list)\n",
    "    print('epi: {}, loss: {}'.format(epi, epi_loss))\n",
    "\n",
    "    # checkpoint whenever this episode's mean training loss improves on the best so far\n",
    "    if epi_loss < best_loss:\n",
    "        best_loss = epi_loss\n",
    "        saver.save(sess, save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
