{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "pycharm": {
     "name": "#%% Pendulum-v0 environment\n"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "path = os.path.dirname(os.path.abspath(\"__file__\"))\n",
    "sys.path.insert(0, path + '/..')\n",
    "\n",
    "import base64\n",
    "import IPython\n",
    "import importlib\n",
    "import logging\n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "import random\n",
    "import time\n",
    "from collections import namedtuple\n",
    "\n",
    "from tf_agents.environments import suite_gym, suite_dm_control, parallel_py_environment\n",
    "from tf_agents.environments import tf_py_environment, FlattenObservationsWrapper\n",
    "from tf_agents.metrics import tf_metrics\n",
    "from tf_agents.replay_buffers import tf_uniform_replay_buffer, episodic_replay_buffer\n",
    "from tf_agents.drivers import dynamic_episode_driver, dynamic_step_driver\n",
    "from tf_agents.trajectories import time_step as ts, policy_step, trajectory\n",
    "from tf_agents.utils import common\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "tf.autograph.set_verbosity(3)\n",
    "import tensorflow_probability as tfp\n",
    "tfd = tfp.distributions\n",
    "\n",
    "import numpy as np\n",
    "import json\n",
    "\n",
    "from reinforcement_learning import labeling_functions\n",
    "import reinforcement_learning.environments\n",
    "from reinforcement_learning.environments import EnvironmentLoader, perturbed_env\n",
    "from reinforcement_learning.metrics import AverageDiscountedReturnMetric\n",
    "from policies.saved_policy import SavedTFPolicy\n",
    "from policies.epsilon_mimic import EpsilonMimicPolicy\n",
    "from policies.latent_policy import LatentPolicyOverRealStateAndActionSpaces\n",
    "\n",
    "from verification import model, local_losses, binary_latent_space\n",
    "from verification.local_losses import compute_values_from_initial_distribution\n",
    "from util.io.dataset_generator import ergodic_batched_labeling_function, is_reset_state\n",
    "\n",
    "from util.io import video\n",
    "import wasserstein_mdp\n",
    "\n",
    "# set seed\n",
    "seed = 42\n",
    "os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_mp4(filename):\n",
    "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n",
    "  video = open(filename,'rb').read()\n",
    "  b64 = base64.b64encode(video)\n",
    "  tag = '''\n",
    "  <video width=\"640\" height=\"480\" controls>\n",
    "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n",
    "  Your browser does not support the video tag.\n",
    "  </video>'''.format(b64.decode())\n",
    "\n",
    "  return IPython.display.HTML(tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_state_space(py_env):\n",
    "    print(\"state space shape:\", py_env.observation_spec().shape)\n",
    "    try:\n",
    "        print(\"state space max values:\", py_env.observation_spec().maximum)\n",
    "        print(\"state space min values:\", py_env.observation_spec().minimum)\n",
    "    except AttributeError as e:\n",
    "        pass\n",
    "\n",
    "def display_action_space(py_env):\n",
    "    if py_env.action_spec().dtype in [np.int64, np.int32]:\n",
    "        print(\"discrete action space\")\n",
    "        print(\"number of discrete actions:\", py_env.action_spec().maximum + 1)\n",
    "    else:\n",
    "        print(\"continuous action space\")\n",
    "        print(\"action space shape:\", py_env.action_spec().shape)\n",
    "        print(\"action space max values:\", py_env.action_spec().maximum)\n",
    "        print(\"action space min values:\", py_env.action_spec().minimum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CartPole"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL policy (DQN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_path = 'policy_videos/cartpole_dqn'\n",
    "\n",
    "with suite_gym.load('CartPole-v0') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    \n",
    "    display_state_space(py_env)\n",
    "    display_action_space(py_env)\n",
    "\n",
    "    policy_dir = '../reinforcement_learning/saves/CartPole-v0/policy'\n",
    "    policy = SavedTFPolicy(policy_dir)\n",
    "    num_episodes=30\n",
    "\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    \n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes)\n",
    "    dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env,\n",
    "        policy,\n",
    "        num_episodes=num_episodes,\n",
    "        observers=[\n",
    "            reward_metric,\n",
    "            video_observer,\n",
    "        ]).run()\n",
    "\n",
    "    tf.print('avg. episode return:', reward_metric.result())\n",
    "\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distilled policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'self': '<wasserstein_mdp.WassersteinMarkovDecisionProcess object at 0x2addb46fe040>', 'state_shape': '(4,)', 'action_shape': '(2,)', 'reward_shape': '(1,)', 'label_shape': '(2,)', 'discretize_action_space': 'False', 'state_encoder_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='state_encoder_network_base')\", 'action_decoder_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='action_decoder_network_base')\", 'transition_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='transition_network_base')\", 'reward_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='reward_network_base')\", 'decoder_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='state_decoder_network_base')\", 'latent_policy_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='discrete_policy_network_base')\", 'steady_state_lipschitz_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='steady_state_network_base')\", 'transition_loss_lipschitz_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='transition_loss_network_base')\", 'latent_state_size': '9', 'number_of_discrete_actions': '16', 'action_encoder_network': \"ModelArchitecture(hidden_units=[64, 64, 64], activation='tanh', name='action_encoder_network_base')\", 'state_encoder_pre_processing_network': 'None', 'state_decoder_pre_processing_network': 'None', 'time_stacked_states': 'False', 'state_encoder_temperature': '0.3333333333333333', 'state_prior_temperature': '0.3333333333333333', 'action_encoder_temperature': '0.99', 'latent_policy_temperature': '0.6666666666666666', 'wasserstein_regularizer_scale_factor': 'WassersteinRegularizerScaleFactor(global_scaling=10.0, global_gradient_penalty_multiplier=20.0, steady_state_scaling=75.0, steady_state_gradient_penalty_multiplier=None, 
local_transition_loss_scaling=75.0, local_transition_loss_gradient_penalty_multiplier=None)', 'encoder_temperature_decay_rate': '1e-06', 'prior_temperature_decay_rate': '2e-06', 'reset_state_label': 'True', 'autoencoder_optimizer': 'None', 'wasserstein_regularizer_optimizer': 'None', 'entropy_regularizer_scale_factor': '0.0', 'entropy_regularizer_decay_rate': '0.0', 'entropy_regularizer_scale_factor_min_value': '0.0', 'importance_sampling_exponent': '0.4', 'importance_sampling_exponent_growth_rate': '1e-05', 'time_stacked_lstm_units': '128', 'reward_bounds': 'None', 'latent_stationary_network': 'None', 'action_entropy_regularizer_scaling': '0.0', 'enforce_upper_bound': 'False', 'squared_wasserstein': 'False', 'n_critic': '5', 'trainable_prior': 'False', 'state_encoder_type': 'EncodingType.DETERMINISTIC', 'policy_based_decoding': 'False', 'deterministic_state_embedding': 'True', 'state_encoder_softclipping': 'False', 'args': '()', 'kwargs': \"{'evaluation_window_size': 5}\", '__class__': \"<class 'wasserstein_mdp.WassersteinMarkovDecisionProcess'>\", 'eval_policy': '200.0', 'local_reward_loss': '0.0037750935', 'local_transition_loss': '0.40456417', 'training_step': '120000'}\n",
      "Model: \"state_encoder\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " state (InputLayer)          [(None, 4)]               0         \n",
      "                                                                 \n",
      " state_encoder_body (Sequent  (None, 60)               8380      \n",
      " ial)                                                            \n",
      "                                                                 \n",
      " dense_28 (Dense)            (None, 6)                 366       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 8,746\n",
      "Trainable params: 8,746\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "No action encoder\n",
      "Model: \"autoregressive_transition_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " logistic_layer_input (InputLay  [(None, 11)]        0           []                               \n",
      " er)                                                                                              \n",
      "                                                                                                  \n",
      " sequential_logistic_distributi  (None, 9)           0           ['logistic_layer_input[0][0]']   \n",
      " on_layer (Sequential)                                                                            \n",
      "                                                                                                  \n",
      " autoregressive_transform (Auto  ((None, 9),         11757       ['sequential_logistic_distributio\n",
      " regressiveTransform)            (None, 9))                      n_layer[0][0]',                  \n",
      "                                                                  'logistic_layer_input[0][0]']   \n",
      "                                                                                                  \n",
      " autoregressive_network_14 (Aut  (None, 9, 1)        11756       []                               \n",
      " oregressiveNetwork)                                                                              \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 11,757\n",
      "Trainable params: 11,756\n",
      "Non-trainable params: 1\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"latent_stationary_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " logistic_layer_input (Input  [(None, 0)]              0         \n",
      " Layer)                                                          \n",
      "                                                                 \n",
      " sequential_logistic_distrib  (None, 3)                0         \n",
      " ution_layer (Sequential)                                        \n",
      "                                                                 \n",
      " autoregressive_transform (A  ((None, 3),              8772      \n",
      " utoregressiveTransform)      (None, 3))                         \n",
      "                                                                 \n",
      " autoregressive_network_15 (  (None, 3, 1)             8771      \n",
      " AutoregressiveNetwork)                                          \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 8,778\n",
      "Trainable params: 8,771\n",
      "Non-trainable params: 7\n",
      "_________________________________________________________________\n",
      "Model: \"latent_policy_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " latent_state (InputLayer)   [(None, 9)]               0         \n",
      "                                                                 \n",
      " discrete_policy_network_bas  (None, 64)               8960      \n",
      " e (Sequential)                                                  \n",
      "                                                                 \n",
      " latent_policy_categorical_l  (None, 2)                130       \n",
      " ogits (Dense)                                                   \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 9,090\n",
      "Trainable params: 9,090\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"reward_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 9)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 2)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 9)]         0           []                               \n",
      "                                                                                                  \n",
      " reward_function_input (Concate  (None, 20)          0           ['latent_state[0][0]',           \n",
      " nate)                                                            'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " reward_network_base (Sequentia  (None, 64)          9664        ['reward_function_input[0][0]']  \n",
      " l)                                                                                               \n",
      "                                                                                                  \n",
      " reward_network_raw_output (Den  (None, 1)           65          ['reward_network_base[0][0]']    \n",
      " se)                                                                                              \n",
      "                                                                                                  \n",
      " reward (Reshape)               (None, 1)            0           ['reward_network_raw_output[0][0]\n",
      "                                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 9,729\n",
      "Trainable params: 9,729\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"state_reconstruction_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " next_latent_state (InputLay  [(None, 9)]              0         \n",
      " er)                                                             \n",
      "                                                                 \n",
      " state_decoder_network_base   (None, 64)               8960      \n",
      " (Sequential)                                                    \n",
      "                                                                 \n",
      " state_decoder (Sequential)  (None, 4)                 260       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 9,220\n",
      "Trainable params: 9,220\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"steady_state_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 9)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 2)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 9)]         0           []                               \n",
      "                                                                                                  \n",
      " concatenate_14 (Concatenate)   (None, 20)           0           ['latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " steady_state_network_base (Seq  (None, 64)          9664        ['concatenate_14[0][0]']         \n",
      " uential)                                                                                         \n",
      "                                                                                                  \n",
      " steady_state_lipschitz_network  (None, 1)           65          ['steady_state_network_base[0][0]\n",
      " _output (Dense)                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 9,729\n",
      "Trainable params: 9,729\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"transition_loss_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " state (InputLayer)             [(None, 4)]          0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 2)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_state (InputLayer)      [(None, 9)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 9)]         0           []                               \n",
      "                                                                                                  \n",
      " concatenate_15 (Concatenate)   (None, 24)           0           ['state[0][0]',                  \n",
      "                                                                  'action[0][0]',                 \n",
      "                                                                  'latent_state[0][0]',           \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " transition_loss_network_base (  (None, 64)          9920        ['concatenate_15[0][0]']         \n",
      " Sequential)                                                                                      \n",
      "                                                                                                  \n",
      " transition_loss_lipschitz_netw  (None, 1)           65          ['transition_loss_network_base[0]\n",
      " ork_output (Dense)                                              [0]']                            \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 9,985\n",
      "Trainable params: 9,985\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "WAE-MDP loaded\n"
     ]
    }
   ],
   "source": [
    "wae_model_path = 'saved_models/experiments/CartPole-v0/model/'\n",
    "\n",
    "with open(os.path.join(wae_model_path, 'model_infos.json'), 'r') as f:\n",
    "    wae_data = json.load(f)\n",
    "    print(wae_data)\n",
    "\n",
    "wae_mdp = wasserstein_mdp.load(wae_model_path)\n",
    "print(\"WAE-MDP loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WAE-MDP at training step 120000\n",
      "Size of the latent state space: 512\n"
     ]
    }
   ],
   "source": [
    "print(\"WAE-MDP at training step {:d}\".format(eval(wae_data['training_step'])))\n",
    "print(\"Size of the latent state space: {:d}\".format(2 ** wae_mdp.latent_state_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_path = 'policy_videos/cartpole_wae_distillation'\n",
    "with suite_gym.load('CartPole-v0') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    original_state = tf_env.current_time_step().observation\n",
    "    \n",
    "    tf_env = wae_mdp.wrap_tf_environment(tf_env, labeling_functions['CartPole-v0'])\n",
    "    policy =tf_env.wrap_latent_policy(wae_mdp.get_latent_policy(action_dtype=tf.int64))\n",
    "    \n",
    "    num_episodes=30\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    discounted_reward_metric = AverageDiscountedReturnMetric(\n",
    "        gamma=.99, reward_scale=wae_mdp._dynamic_reward_scaling)\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes)\n",
    "    \n",
    "    dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy, num_episodes=num_episodes,\n",
    "        observers=[\n",
    "            reward_metric,\n",
    "            discounted_reward_metric,\n",
    "            video_observer,\n",
    "        ]).run()\n",
    "    \n",
    "\n",
    "tf.print('avg. episode return:', reward_metric.result())\n",
    "tf.print('avg. discounted (scaled) return:', discounted_reward_metric.result())\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local reward loss: 0.00499653\n",
      "Local transition loss: 0.399636\n",
      "Local transition loss (freq. estimation): 0.421809\n",
      "Time metrics:\n",
      "    Fill in the Replay Buffer (100000 frames): 119.436\n",
      "    Estimate the local reward loss function (from 33424 transitions): 1.631\n",
      "    Transition model generation (empirical frequency estimation, from 33424 transitions): 4.192\n",
      "    Estimate the local transition loss function (from 33424 transitions): 0.065\n",
      "    Estimate the local transition loss function via the frequency-estimated transition function:: 27.531\n"
     ]
    }
   ],
   "source": [
    "# PAC bounds for local losses\n",
    "# the bound computed during training can already be found in the log file (wae_data)\n",
    "epsilon = 1e-2\n",
    "delta = 5e-3\n",
    "T = int(np.ceil(-np.log(delta / 4) / (2 * epsilon**2)))\n",
    "\n",
    "with suite_gym.load(\n",
    "    'CartPole-v0',\n",
    "    env_wrappers=[lambda env: perturbed_env.PerturbedEnvironment(env, .75)]\n",
    ") as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    local_losses_metrics = wae_mdp.estimate_local_losses_from_samples(\n",
    "        tf_env,\n",
    "        steps=T,\n",
    "        labeling_function=labeling_functions['CartPole-v0'],\n",
    "        estimate_transition_function_from_samples=True,\n",
    "        reward_scaling=wae_mdp._dynamic_reward_scaling,\n",
    "        estimate_value_difference=False)\n",
    "\n",
    "tf.print('Local reward loss: {:.6g}'.format(local_losses_metrics.local_reward_loss))\n",
    "tf.print('Local transition loss: {:.6g}'.format(local_losses_metrics.local_transition_loss))\n",
    "tf.print('Local transition loss (freq. estimation): {:.6g}'.format(\n",
    "    local_losses_metrics.local_transition_loss_transition_function_estimation))\n",
    "local_losses_metrics.print_time_metrics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local reward loss: 0.0038\n",
      "Local transition loss: 0.4\n",
      "Transition/reward model generation\n",
      "Time to generate the model: 1.5 sec\n"
     ]
    }
   ],
   "source": [
    "_latent_reward_fn = lambda latent_state, latent_action, next_latent_state: \\\n",
    "    wae_mdp._dynamic_reward_scaling * wae_mdp.reward_distribution(\n",
    "        latent_state=tf.cast(latent_state, dtype=tf.float32),\n",
    "        latent_action=tf.cast(latent_action, dtype=tf.float32),\n",
    "        next_latent_state=tf.cast(next_latent_state, dtype=tf.float32),\n",
    "    ).mode()\n",
    "# as the distribution is deterministic, taking the mode\n",
    "# allows to retrieve the Dirac impulsion \n",
    "\n",
    "_latent_transition_fn = lambda latent_state, latent_action: \\\n",
    "        wae_mdp.discrete_latent_transition(\n",
    "            tf.cast(latent_state, tf.float32),\n",
    "            tf.cast(latent_action, tf.float32))\n",
    "\n",
    "print('Local reward loss: {:.2g}'.format(eval(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.2g}'.format(eval(wae_data['local_transition_loss'])))\n",
    "\n",
    "print('Transition/reward model generation')\n",
    "start = time.time()\n",
    "\n",
    "#  write the transition/reward functions to tensors,\n",
    "#  to formally check the values in an efficient way\n",
    "latent_transition_fn = model.TransitionFunctionCopy(\n",
    "    num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "    num_actions=wae_mdp.number_of_discrete_actions,\n",
    "    transition_function=_latent_transition_fn,\n",
    "    epsilon=1e-6)\n",
    "\n",
    "latent_reward_fn = model.RewardFunctionCopy(\n",
    "    num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "    num_actions=wae_mdp.number_of_discrete_actions,\n",
    "    reward_function=_latent_reward_fn,\n",
    "    transition_function=_latent_transition_fn,\n",
    "    epsilon=1e-6)\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Time to generate the model: {:.2g} sec\".format(end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value difference: 3.71213\n",
      "Time to compute the value difference: 0.861585 sec\n"
     ]
    }
   ],
   "source": [
    "# Value of the latent policy in the latent MDP, computed (gamma=.99) from the\n",
    "# embedding of original_state, compared with the discounted return recorded\n",
    "# by discounted_reward_metric during the rollout.\n",
    "start = time.time()\n",
    "\n",
    "latent_mdp_values = compute_values_from_initial_distribution(\n",
    "    latent_state_size=wae_mdp.latent_state_size,\n",
    "    atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "    original_state=original_state,\n",
    "    number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "    latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "    latent_transition_fn=latent_transition_fn,\n",
    "    latent_reward_function=latent_reward_fn,\n",
    "    epsilon=1e-6,\n",
    "    gamma=.99,\n",
    "    # the embedding of the initial state is wrapped as a Dirac (deterministic) distribution\n",
    "    stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "        tfd.Deterministic(loc=wae_mdp.state_embedding_function(\n",
    "            original_state,\n",
    "            ergodic_batched_labeling_function(\n",
    "                labeling_functions['CartPole-v0']\n",
    "            )(original_state))),\n",
    "        reinterpreted_batch_ndims=1)\n",
    ")\n",
    "\n",
    "# absolute gap between the empirical discounted return and the latent MDP value\n",
    "value_difference = tf.abs(discounted_reward_metric.result() - latent_mdp_values)\n",
    "\n",
    "tf.print(\"Value difference: {:.6g}\".format(value_difference))\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Time to compute the value difference: {:.2g} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Time-to-failure property:\n",
    "$\\neg\\mathsf{Reset} \\, \\mathcal{U} \\, \\mathsf{Unsafe}$\n",
    "where $\\mathsf{Unsafe} \\in \\ell\\left(s\\right)$ iff the cart position is greater than 1.5 or the pole angle is greater than 9 degrees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0.0316655\n",
      "Time to compute the values of the property: 0.730968 sec\n"
     ]
    }
   ],
   "source": [
    "# Latent-state predicates used to encode the property as a reachability objective.\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "# True when the first or the second latent bit is not 1 (i.e., is 0 for binary latent states).\n",
    "# NOTE(review): assumes these two bits are the safety labels of the CartPole labeling function - confirm.\n",
    "unsafe_state_test_fn = lambda latent_state: tf.logical_or(\n",
    "    tf.cast(1. - latent_state[..., 0], tf.bool),\n",
    "    tf.cast(1. - latent_state[..., 1], tf.bool))\n",
    "\n",
    "# Reset and unsafe states are passed as absorbing states to the value computation.\n",
    "absorbing_states = lambda latent_state: tf.logical_or(\n",
    "    reset_state_test_fn(latent_state),\n",
    "    unsafe_state_test_fn(latent_state))\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "# gives a reward of 1 when transitioning to an unsafe state: the tensor has shape\n",
    "# [S, A, S'] and the per-next-state indicator is broadcast along the last axis.\n",
    "# NOTE(review): assumes state_space enumerates the 2 ** latent_state_size latent states along axis 0.\n",
    "reward_objective = tf.ones(\n",
    "    shape=(tf.pow(2, wae_mdp.latent_state_size),\n",
    "           wae_mdp.number_of_discrete_actions,\n",
    "           tf.pow(2, wae_mdp.latent_state_size))\n",
    ") * tf.cast(unsafe_state_test_fn(state_space), tf.float32)\n",
    "\n",
    "# Minimal stand-in for a reward function object: it only exposes to_dense().\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "latent_mdp_values = compute_values_from_initial_distribution(\n",
    "    latent_state_size=wae_mdp.latent_state_size,\n",
    "    atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "    original_state=original_state,\n",
    "    number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "    latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "    latent_transition_fn=latent_transition_fn,\n",
    "    latent_reward_function=reward_objective_fn,\n",
    "    epsilon=1e-6,\n",
    "    gamma=.99,\n",
    "    stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "        tfd.Deterministic(\n",
    "            loc=wae_mdp.state_embedding_function(\n",
    "                original_state,\n",
    "                ergodic_batched_labeling_function(\n",
    "                    labeling_functions['CartPole-v0']\n",
    "                )(original_state))),\n",
    "        reinterpreted_batch_ndims=1),\n",
    "    absorbing_states=absorbing_states)\n",
    "\n",
    "tf.print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "print(\"Time to compute the values of the property: {:.2g} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MountainCar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL policy (DQN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the saved DQN policy on MountainCar-v0 for num_episodes episodes,\n",
    "# report the average episode return and embed a video of the rollout.\n",
    "video_path = 'policy_videos/mountain_car_dqn'\n",
    "\n",
    "with suite_gym.load('MountainCar-v0') as py_env:\n",
    "    py_env.reset()\n",
    "    # TF wrapper around py_env, used by the driver below\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    \n",
    "    display_state_space(py_env)\n",
    "    display_action_space(py_env)\n",
    "\n",
    "    policy_dir = '../reinforcement_learning/saves/MountainCar-v0/dqn_policy'\n",
    "    policy = SavedTFPolicy(policy_dir)\n",
    "    num_episodes=30\n",
    "\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    \n",
    "    # the video observer is attached to py_env, the environment backing tf_env\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes)\n",
    "    dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env,\n",
    "        policy,\n",
    "        num_episodes=num_episodes,\n",
    "        observers=[\n",
    "            reward_metric,\n",
    "            video_observer,\n",
    "        ]).run()\n",
    "\n",
    "    tf.print('avg. episode return:', reward_metric.result())\n",
    "\n",
    "# embed the recorded video in the notebook output\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distilled policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'self': '<wasserstein_mdp.WassersteinMarkovDecisionProcess object at 0x2ad85f220190>', 'state_shape': '(2,)', 'action_shape': '(3,)', 'reward_shape': '(1,)', 'label_shape': '(3,)', 'discretize_action_space': 'False', 'state_encoder_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='state_encoder_network_base')\", 'action_decoder_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='action_decoder_network_base')\", 'transition_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='transition_network_base')\", 'reward_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='reward_network_base')\", 'decoder_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='state_decoder_network_base')\", 'latent_policy_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='discrete_policy_network_base')\", 'steady_state_lipschitz_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='steady_state_network_base')\", 'transition_loss_lipschitz_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='transition_loss_network_base')\", 'latent_state_size': '10', 'number_of_discrete_actions': '5', 'action_encoder_network': \"ModelArchitecture(hidden_units=[64, 64], activation='relu', name='action_encoder_network_base')\", 'state_encoder_pre_processing_network': 'None', 'state_decoder_pre_processing_network': 'None', 'time_stacked_states': 'False', 'state_encoder_temperature': '0.1', 'state_prior_temperature': '0.3333333333333333', 'action_encoder_temperature': '-1.0', 'latent_policy_temperature': '0.3333333333333333', 'wasserstein_regularizer_scale_factor': 'WassersteinRegularizerScaleFactor(global_scaling=None, global_gradient_penalty_multiplier=10, steady_state_scaling=100.0, steady_state_gradient_penalty_multiplier=None, local_transition_loss_scaling=25.0, 
local_transition_loss_gradient_penalty_multiplier=None)', 'encoder_temperature_decay_rate': '0.0', 'prior_temperature_decay_rate': '0.0', 'reset_state_label': 'True', 'autoencoder_optimizer': 'None', 'wasserstein_regularizer_optimizer': 'None', 'entropy_regularizer_scale_factor': '0.0', 'entropy_regularizer_decay_rate': '0.0', 'entropy_regularizer_scale_factor_min_value': '0.0', 'importance_sampling_exponent': '1.0', 'importance_sampling_exponent_growth_rate': '1.0', 'time_stacked_lstm_units': '128', 'reward_bounds': 'None', 'latent_stationary_network': 'None', 'action_entropy_regularizer_scaling': '0.0', 'enforce_upper_bound': 'False', 'squared_wasserstein': 'False', 'n_critic': '20', 'trainable_prior': 'False', 'state_encoder_type': 'EncodingType.DETERMINISTIC', 'policy_based_decoding': 'False', 'deterministic_state_embedding': 'True', 'state_encoder_softclipping': 'True', 'args': '()', 'kwargs': '{}', '__class__': \"<class 'wasserstein_mdp.WassersteinMarkovDecisionProcess'>\", 'eval_policy': '-92.1', 'local_reward_loss': '0.014176323', 'local_transition_loss': '0.38232273', 'training_step': '232000'}\n",
      "Model: \"state_encoder\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " state (InputLayer)          [(None, 2)]               0         \n",
      "                                                                 \n",
      " state_encoder_body (Sequent  (None, 60)               4092      \n",
      " ial)                                                            \n",
      "                                                                 \n",
      " dense_31 (Dense)            (None, 6)                 366       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 4,458\n",
      "Trainable params: 4,458\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "No action encoder\n",
      "Model: \"autoregressive_transition_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " logistic_layer_input (InputLay  [(None, 13)]        0           []                               \n",
      " er)                                                                                              \n",
      "                                                                                                  \n",
      " sequential_logistic_distributi  (None, 10)          0           ['logistic_layer_input[0][0]']   \n",
      " on_layer (Sequential)                                                                            \n",
      "                                                                                                  \n",
      " autoregressive_transform (Auto  ((None, 10),        7309        ['sequential_logistic_distributio\n",
      " regressiveTransform)            (None, 10))                     n_layer[0][0]',                  \n",
      "                                                                  'logistic_layer_input[0][0]']   \n",
      "                                                                                                  \n",
      " autoregressive_network_16 (Aut  (None, 10, 1)       7308        []                               \n",
      " oregressiveNetwork)                                                                              \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 7,309\n",
      "Trainable params: 7,308\n",
      "Non-trainable params: 1\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"latent_stationary_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " logistic_layer_input (Input  [(None, 0)]              0         \n",
      " Layer)                                                          \n",
      "                                                                 \n",
      " sequential_logistic_distrib  (None, 4)                0         \n",
      " ution_layer (Sequential)                                        \n",
      "                                                                 \n",
      " autoregressive_transform (A  ((None, 4),              4741      \n",
      " utoregressiveTransform)      (None, 4))                         \n",
      "                                                                 \n",
      " autoregressive_network_17 (  (None, 4, 1)             4740      \n",
      " AutoregressiveNetwork)                                          \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 4,747\n",
      "Trainable params: 4,740\n",
      "Non-trainable params: 7\n",
      "_________________________________________________________________\n",
      "Model: \"latent_policy_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " latent_state (InputLayer)   [(None, 10)]              0         \n",
      "                                                                 \n",
      " discrete_policy_network_bas  (None, 64)               4864      \n",
      " e (Sequential)                                                  \n",
      "                                                                 \n",
      " latent_policy_categorical_l  (None, 3)                195       \n",
      " ogits (Dense)                                                   \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 5,059\n",
      "Trainable params: 5,059\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"reward_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 10)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 10)]        0           []                               \n",
      "                                                                                                  \n",
      " reward_function_input (Concate  (None, 23)          0           ['latent_state[0][0]',           \n",
      " nate)                                                            'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " reward_network_base (Sequentia  (None, 64)          5696        ['reward_function_input[0][0]']  \n",
      " l)                                                                                               \n",
      "                                                                                                  \n",
      " reward_network_raw_output (Den  (None, 1)           65          ['reward_network_base[0][0]']    \n",
      " se)                                                                                              \n",
      "                                                                                                  \n",
      " reward (Reshape)               (None, 1)            0           ['reward_network_raw_output[0][0]\n",
      "                                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 5,761\n",
      "Trainable params: 5,761\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"state_reconstruction_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " next_latent_state (InputLay  [(None, 10)]             0         \n",
      " er)                                                             \n",
      "                                                                 \n",
      " state_decoder_network_base   (None, 64)               4864      \n",
      " (Sequential)                                                    \n",
      "                                                                 \n",
      " state_decoder (Sequential)  (None, 2)                 130       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 4,994\n",
      "Trainable params: 4,994\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"steady_state_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 10)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 10)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_16 (Concatenate)   (None, 23)           0           ['latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " steady_state_network_base (Seq  (None, 64)          5696        ['concatenate_16[0][0]']         \n",
      " uential)                                                                                         \n",
      "                                                                                                  \n",
      " steady_state_lipschitz_network  (None, 1)           65          ['steady_state_network_base[0][0]\n",
      " _output (Dense)                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 5,761\n",
      "Trainable params: 5,761\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"transition_loss_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " state (InputLayer)             [(None, 2)]          0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_state (InputLayer)      [(None, 10)]         0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 10)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_17 (Concatenate)   (None, 25)           0           ['state[0][0]',                  \n",
      "                                                                  'action[0][0]',                 \n",
      "                                                                  'latent_state[0][0]',           \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " transition_loss_network_base (  (None, 64)          5824        ['concatenate_17[0][0]']         \n",
      " Sequential)                                                                                      \n",
      "                                                                                                  \n",
      " transition_loss_lipschitz_netw  (None, 1)           65          ['transition_loss_network_base[0]\n",
      " ork_output (Dense)                                              [0]']                            \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 5,889\n",
      "Trainable params: 5,889\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "WAE-MDP loaded\n"
     ]
    }
   ],
   "source": [
    "# Location of the distilled WAE-MDP and of its training metadata.\n",
    "wae_model_path = 'saved_models/experiments/MountainCar-v0/model/'\n",
    "model_infos_path = os.path.join(wae_model_path, 'model_infos.json')\n",
    "\n",
    "# Hyper-parameters and training statistics saved alongside the model.\n",
    "with open(model_infos_path, 'r') as model_infos_file:\n",
    "    wae_data = json.load(model_infos_file)\n",
    "print(wae_data)\n",
    "\n",
    "wae_mdp = wasserstein_mdp.load(wae_model_path)\n",
    "print(\"WAE-MDP loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WAE-MDP at training step 232000\n",
      "Size of the latent state space: 1024\n",
      "Local reward loss: 0.0141763\n",
      "Local transition loss: 0.382323\n"
     ]
    }
   ],
   "source": [
    "# The metadata values are stored as strings in model_infos.json: parse them\n",
    "# explicitly instead of calling eval(), which would execute arbitrary code.\n",
    "print(\"WAE-MDP at training step {:d}\".format(int(wae_data['training_step'])))\n",
    "print(\"Size of the latent state space: {:d}\".format(2 ** wae_mdp.latent_state_size))\n",
    "print('Local reward loss: {:.6g}'.format(float(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.6g}'.format(float(wae_data['local_transition_loss'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the distilled latent policy in the real environment (wrapped by the\n",
    "# WAE-MDP), recording undiscounted and discounted (scaled) returns and a video.\n",
    "# NOTE(review): 'moutain_car' in the video path looks like a misspelling of 'mountain_car'.\n",
    "video_path = 'policy_videos/moutain_car_wae_distillation'\n",
    "\n",
    "def latent_labeling_fn(time_step):\n",
    "    \"\"\"Return the labels shown in the video: 'goal' is the first component of the latent state.\"\"\"\n",
    "    latent_state = time_step.observation['latent_state']\n",
    "    return {\n",
    "        'goal': latent_state.numpy()[..., 0],\n",
    "    }\n",
    "\n",
    "with suite_gym.load('MountainCar-v0') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    # initial real state, reused by the latent MDP value computations in later cells\n",
    "    original_state = tf_env.current_time_step().observation\n",
    "    \n",
    "    # tf_env is rebound to the WAE-MDP wrapper; the latent policy is wrapped to act in it\n",
    "    tf_env = wae_mdp.wrap_tf_environment(tf_env, labeling_functions['MountainCar-v0'])\n",
    "    policy =tf_env.wrap_latent_policy(wae_mdp.get_latent_policy(action_dtype=tf.int64))\n",
    "    \n",
    "    num_episodes=30\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    # same gamma and reward scaling as those used for the latent reward model / values\n",
    "    discounted_reward_metric = AverageDiscountedReturnMetric(\n",
    "        gamma=.99, reward_scale=wae_mdp._dynamic_reward_scaling)\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes, labeling_function=latent_labeling_fn)\n",
    "    \n",
    "    dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy, num_episodes=num_episodes,\n",
    "        observers=[\n",
    "            reward_metric,\n",
    "            discounted_reward_metric,\n",
    "            video_observer,\n",
    "        ]).run()\n",
    "    \n",
    "\n",
    "tf.print('avg. episode return:', reward_metric.result())\n",
    "tf.print('avg. discounted (scaled) return:', discounted_reward_metric.result())\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local reward loss: 0.014\n",
      "Local transition loss: 0.38\n",
      "Transition/reward model generation\n",
      "Time to generate the model: 1.2e+02 sec\n"
     ]
    }
   ],
   "source": [
    "# If True, the latent transition tensor is estimated from sampled environment steps\n",
    "# (with the learned latent transition function as backup); otherwise it is copied\n",
    "# directly from the learned latent transition function.\n",
    "frequency_estimation = True\n",
    "\n",
    "# Latent reward function, multiplied by the dynamic reward scaling of the WAE-MDP.\n",
    "_latent_reward_fn = lambda latent_state, latent_action, next_latent_state: \\\n",
    "    wae_mdp._dynamic_reward_scaling * wae_mdp.reward_distribution(\n",
    "        latent_state=tf.cast(latent_state, dtype=tf.float32),\n",
    "        latent_action=tf.cast(latent_action, dtype=tf.float32),\n",
    "        next_latent_state=tf.cast(next_latent_state, dtype=tf.float32),\n",
    "    ).mode() \n",
    "# as the distribution is deterministic, taking the mode\n",
    "# retrieves its Dirac impulse\n",
    "\n",
    "# Learned latent transition function (latent inputs cast to float32).\n",
    "_latent_transition_fn = lambda latent_state, latent_action: \\\n",
    "        wae_mdp.discrete_latent_transition(\n",
    "            tf.cast(latent_state, tf.float32),\n",
    "            tf.cast(latent_action, tf.float32))\n",
    "\n",
    "#  write the transition/reward functions to tensors,\n",
    "#  to formally check the values in an efficient way\n",
    "print('Transition/reward model generation')\n",
    "start = time.time()\n",
    "\n",
    "if frequency_estimation:\n",
    "    #  compute the transition tensor by frequency estimation and use the\n",
    "    #  latent transition function learned during the WAE optimization as backup function\n",
    "    with suite_gym.load(\n",
    "        'MountainCar-v0',\n",
    "        # NOTE(review): PerturbedEnvironment(env, .75) presumably perturbs the sampled\n",
    "        # dynamics/actions to improve coverage of the latent space - confirm\n",
    "        env_wrappers=[lambda env: perturbed_env.PerturbedEnvironment(env, .75)]\n",
    "    ) as py_env:\n",
    "        py_env.reset()\n",
    "        tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "        # n_steps environment steps are sampled under the latent policy\n",
    "        latent_transition_fn = model.estimate_latent_transition_function_from_samples(\n",
    "            environment=tf_env,\n",
    "            n_steps=100000,\n",
    "            state_embedding_function=wae_mdp.state_embedding_function,\n",
    "            action_embedding_function=wae_mdp.action_embedding_function,\n",
    "            labeling_function=labeling_functions['MountainCar-v0'],\n",
    "            latent_state_size=wae_mdp.latent_state_size,\n",
    "            number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "            latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "            backup_transition_fn=_latent_transition_fn)\n",
    "else:\n",
    "    # NOTE(review): epsilon=0. here, whereas the CartPole model copy uses epsilon=1e-6 - confirm intended\n",
    "    latent_transition_fn = model.TransitionFunctionCopy(\n",
    "        num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "        num_actions=wae_mdp.number_of_discrete_actions,\n",
    "        transition_function=_latent_transition_fn,\n",
    "        epsilon=0.)\n",
    "\n",
    "# NOTE(review): the reward copy is built with the learned _latent_transition_fn,\n",
    "# not with the (possibly estimated) latent_transition_fn - confirm intended\n",
    "latent_reward_fn = model.RewardFunctionCopy(\n",
    "    num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "    num_actions=wae_mdp.number_of_discrete_actions,\n",
    "    reward_function=_latent_reward_fn,\n",
    "    transition_function=_latent_transition_fn,\n",
    "    epsilon=1e-6)\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Time to generate the model: {:.2g} sec\".format(end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value difference: 2.83714\n",
      "Time to compute the value difference: 1.27584 sec\n"
     ]
    }
   ],
   "source": [
    "# Value of the latent policy in the latent MDP, computed (gamma=.99) from the\n",
    "# embedding of original_state, compared with the discounted (scaled) return\n",
    "# recorded by discounted_reward_metric during the rollout.\n",
    "start = time.time()\n",
    "\n",
    "latent_mdp_values = compute_values_from_initial_distribution(\n",
    "    latent_state_size=wae_mdp.latent_state_size,\n",
    "    atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "    original_state=original_state,\n",
    "    number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "    latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "    latent_transition_fn=latent_transition_fn,\n",
    "    latent_reward_function=latent_reward_fn,\n",
    "    epsilon=1e-6,\n",
    "    gamma=.99,\n",
    "    # the embedding of the initial state is wrapped as a Dirac (deterministic) distribution\n",
    "    stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "        tfd.Deterministic(loc=wae_mdp.state_embedding_function(\n",
    "            original_state,\n",
    "            ergodic_batched_labeling_function(\n",
    "                labeling_functions['MountainCar-v0']\n",
    "            )(original_state))),\n",
    "        reinterpreted_batch_ndims=1)\n",
    ")\n",
    "\n",
    "# absolute gap between the empirical discounted return and the latent MDP value\n",
    "value_difference = tf.abs(discounted_reward_metric.result() - latent_mdp_values)\n",
    "\n",
    "tf.print(\"Value difference: {:.6g}\".format(value_difference))\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Time to compute the value difference: {:.2g} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Time-to-failure property:\n",
    "$\\neg\\mathsf{Goal} \\, \\mathcal{U} \\, \\mathsf{Reset}$\n",
    "where $\\mathsf{Goal} \\in \\ell\\left(s\\right)$ iff the car reaches the top of the mountain, at the yellow flag position, and $\\mathsf{Reset}$ labels the end of an episode (i.e., the reset state)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0\n",
      "Time to compute the values of the property: 0.528304 sec\n"
     ]
    }
   ],
   "source": [
    "# Latent-state predicates used to encode the property as a reachability objective.\n",
    "# Goal: true when the first latent bit is set (non-zero).\n",
    "goal_test_fn = lambda latent_state: tf.cast(latent_state[..., 0], tf.bool)\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "\n",
    "# Goal and reset states are passed as absorbing states to the value computation.\n",
    "absorbing_states = lambda latent_state: tf.logical_or(\n",
    "    goal_test_fn(latent_state),\n",
    "    reset_state_test_fn(latent_state))\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "# give a reward of 1 at the end of an episode (i.e., when transitioning to the reset state):\n",
    "# the tensor has shape [S, A, S'] and the per-next-state indicator is broadcast along the last axis.\n",
    "# NOTE(review): assumes state_space enumerates the 2 ** latent_state_size latent states along axis 0.\n",
    "reward_objective = tf.ones(\n",
    "    shape=(tf.pow(2, wae_mdp.latent_state_size),\n",
    "           wae_mdp.number_of_discrete_actions,\n",
    "           tf.pow(2, wae_mdp.latent_state_size))\n",
    ") * tf.cast(reset_state_test_fn(state_space), tf.float32)\n",
    "\n",
    "# Minimal stand-in for a reward function object: it only exposes to_dense().\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "latent_mdp_values = compute_values_from_initial_distribution(\n",
    "    latent_state_size=wae_mdp.latent_state_size,\n",
    "    atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "    original_state=original_state,\n",
    "    number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "    latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "    latent_transition_fn=latent_transition_fn,\n",
    "    latent_reward_function=reward_objective_fn,\n",
    "    epsilon=1e-6,\n",
    "    gamma=.99,\n",
    "    stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "        tfd.Deterministic(\n",
    "            loc=wae_mdp.state_embedding_function(\n",
    "                original_state,\n",
    "                ergodic_batched_labeling_function(\n",
    "                    labeling_functions['MountainCar-v0']\n",
    "                )(original_state))),\n",
    "        reinterpreted_batch_ndims=1),\n",
    "    absorbing_states=absorbing_states)\n",
    "\n",
    "tf.print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "print(\"Time to compute the values of the property: {:.2g} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Episodic reachability:\n",
    "$\\neg\\mathsf{Reset} \\, \\mathcal{U} \\, \\mathsf{Goal}$\n",
    "where $\\mathsf{Goal} \\in \\ell\\left(s\\right)$ iff the car reaches the top of the mountain, at the yellow flag position."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0.37516\n",
      "Time to compute the values of the property: 0.614547 sec\n"
     ]
    }
   ],
   "source": [
    "# Episodic reachability property (not Reset) U Goal, evaluated on the latent MDP.\n",
    "# NOTE(review): assumes the first label bit of a latent state encodes Goal -- confirm against labeling_functions['MountainCar-v0'].\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "goal_test_fn = lambda latent_state: tf.cast(latent_state[..., 0], tf.bool)\n",
    "\n",
    "# reaching either Reset or Goal decides the until-formula, so both are made absorbing\n",
    "absorbing_states = lambda latent_state: tf.logical_or(\n",
    "    reset_state_test_fn(latent_state),\n",
    "    goal_test_fn(latent_state))\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "# give a reward of 1 when the goal is reached:\n",
    "# the per-state indicator is broadcast along the last axis (the next latent state)\n",
    "# of the (2**latent_state_size, actions, 2**latent_state_size) tensor\n",
    "reward_objective = tf.ones(\n",
    "    shape=(tf.pow(2, wae_mdp.latent_state_size),\n",
    "           wae_mdp.number_of_discrete_actions,\n",
    "           tf.pow(2, wae_mdp.latent_state_size))\n",
    ") * tf.cast(goal_test_fn(state_space), tf.float32)\n",
    "\n",
    "# minimal stand-in for a reward function object: it only exposes `to_dense()`\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "latent_mdp_values = compute_values_from_initial_distribution(\n",
    "    latent_state_size=wae_mdp.latent_state_size,\n",
    "    atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "    original_state=original_state,\n",
    "    number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "    latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "    latent_transition_fn=latent_transition_fn,\n",
    "    latent_reward_function=reward_objective_fn,\n",
    "    epsilon=1e-6,\n",
    "    gamma=.99,\n",
    "    # deterministic state embedding, wrapped as a distribution over latent states\n",
    "    stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "        tfd.Deterministic(\n",
    "            loc=wae_mdp.state_embedding_function(\n",
    "                original_state,\n",
    "                ergodic_batched_labeling_function(\n",
    "                    labeling_functions['MountainCar-v0']\n",
    "                )(original_state))),\n",
    "        reinterpreted_batch_ndims=1),\n",
    "    absorbing_states=absorbing_states)\n",
    "\n",
    "tf.print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "print(\"Time to compute the values of the property: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Acrobot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL Policy (DQN, trained in an environment with random initial states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the saved DQN policy in Acrobot-v1, report its average return\n",
    "# and embed a video of the episodes.\n",
    "video_path = 'policy_videos/acrobot_dqn'\n",
    "policy_dir = '../reinforcement_learning/saves/AcrobotRandomInit-v1/dqn_policy'\n",
    "num_episodes = 30\n",
    "\n",
    "with suite_gym.load('Acrobot-v1') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "\n",
    "    display_state_space(py_env)\n",
    "    display_action_space(py_env)\n",
    "\n",
    "    policy = SavedTFPolicy(policy_dir)\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes)\n",
    "\n",
    "    driver = dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy,\n",
    "        num_episodes=num_episodes,\n",
    "        observers=[reward_metric, video_observer])\n",
    "    driver.run()\n",
    "\n",
    "    tf.print('avg. episode return:', reward_metric.result())\n",
    "\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distilled policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'self': '<wasserstein_mdp.WassersteinMarkovDecisionProcess object at 0x2b65a1058fd0>', 'state_shape': '(6,)', 'action_shape': '(3,)', 'reward_shape': '(1,)', 'label_shape': '(7,)', 'discretize_action_space': 'False', 'state_encoder_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='state_encoder_network_base')\", 'action_decoder_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='action_decoder_network_base')\", 'transition_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='transition_network_base')\", 'reward_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='reward_network_base')\", 'decoder_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='state_decoder_network_base')\", 'latent_policy_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='discrete_policy_network_base')\", 'steady_state_lipschitz_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='steady_state_network_base')\", 'transition_loss_lipschitz_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='transition_loss_network_base')\", 'latent_state_size': '13', 'number_of_discrete_actions': '16', 'action_encoder_network': \"ModelArchitecture(hidden_units=[512, 512], activation='leaky_relu', name='action_encoder_network_base')\", 'state_encoder_pre_processing_network': 'None', 'state_decoder_pre_processing_network': 'None', 'time_stacked_states': 'False', 'state_encoder_temperature': '0.6666666666666666', 'state_prior_temperature': '0.1', 'action_encoder_temperature': '0.99', 'latent_policy_temperature': '0.5', 'wasserstein_regularizer_scale_factor': 'WassersteinRegularizerScaleFactor(global_scaling=10.0, global_gradient_penalty_multiplier=20.0, steady_state_scaling=10.0, steady_state_gradient_penalty_multiplier=None, 
local_transition_loss_scaling=10.0, local_transition_loss_gradient_penalty_multiplier=None)', 'encoder_temperature_decay_rate': '0.0', 'prior_temperature_decay_rate': '0.0', 'reset_state_label': 'True', 'autoencoder_optimizer': 'None', 'wasserstein_regularizer_optimizer': 'None', 'entropy_regularizer_scale_factor': '0.0', 'entropy_regularizer_decay_rate': '0.0', 'entropy_regularizer_scale_factor_min_value': '0.0', 'importance_sampling_exponent': '0.4', 'importance_sampling_exponent_growth_rate': '1e-05', 'time_stacked_lstm_units': '128', 'reward_bounds': 'None', 'latent_stationary_network': 'None', 'action_entropy_regularizer_scaling': '0.0', 'enforce_upper_bound': 'False', 'squared_wasserstein': 'True', 'n_critic': '20', 'trainable_prior': 'False', 'state_encoder_type': 'EncodingType.DETERMINISTIC', 'policy_based_decoding': 'False', 'deterministic_state_embedding': 'True', 'state_encoder_softclipping': 'True', 'args': '()', 'kwargs': \"{'evaluation_window_size': 0}\", '__class__': \"<class 'wasserstein_mdp.WassersteinMarkovDecisionProcess'>\", 'eval_policy': '-70.2', 'local_reward_loss': '0.034769785', 'local_transition_loss': '0.64947826', 'training_step': '430000'}\n",
      "Model: \"state_encoder\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " state (InputLayer)          [(None, 6)]               0         \n",
      "                                                                 \n",
      " state_encoder_body (Sequent  (None, 510)              265214    \n",
      " ial)                                                            \n",
      "                                                                 \n",
      " dense_34 (Dense)            (None, 5)                 2555      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 267,769\n",
      "Trainable params: 267,769\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "No action encoder\n",
      "Model: \"autoregressive_transition_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " logistic_layer_input (InputLay  [(None, 16)]        0           []                               \n",
      " er)                                                                                              \n",
      "                                                                                                  \n",
      " sequential_logistic_distributi  (None, 13)          0           ['logistic_layer_input[0][0]']   \n",
      " on_layer (Sequential)                                                                            \n",
      "                                                                                                  \n",
      " autoregressive_transform (Auto  ((None, 13),        293086      ['sequential_logistic_distributio\n",
      " regressiveTransform)            (None, 13))                     n_layer[0][0]',                  \n",
      "                                                                  'logistic_layer_input[0][0]']   \n",
      "                                                                                                  \n",
      " autoregressive_network_18 (Aut  (None, 13, 1)       293085      []                               \n",
      " oregressiveNetwork)                                                                              \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 293,086\n",
      "Trainable params: 293,085\n",
      "Non-trainable params: 1\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"latent_stationary_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " logistic_layer_input (Input  [(None, 0)]              0         \n",
      " Layer)                                                          \n",
      "                                                                 \n",
      " sequential_logistic_distrib  (None, 8)                0         \n",
      " ution_layer (Sequential)                                        \n",
      "                                                                 \n",
      " autoregressive_transform (A  ((None, 8),              271369    \n",
      " utoregressiveTransform)      (None, 8))                         \n",
      "                                                                 \n",
      " autoregressive_network_19 (  (None, 8, 1)             271368    \n",
      " AutoregressiveNetwork)                                          \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 271,374\n",
      "Trainable params: 271,368\n",
      "Non-trainable params: 6\n",
      "_________________________________________________________________\n",
      "Model: \"latent_policy_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " latent_state (InputLayer)   [(None, 13)]              0         \n",
      "                                                                 \n",
      " discrete_policy_network_bas  (None, 512)              269824    \n",
      " e (Sequential)                                                  \n",
      "                                                                 \n",
      " latent_policy_categorical_l  (None, 3)                1539      \n",
      " ogits (Dense)                                                   \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 271,363\n",
      "Trainable params: 271,363\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"reward_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 13)]        0           []                               \n",
      "                                                                                                  \n",
      " reward_function_input (Concate  (None, 29)          0           ['latent_state[0][0]',           \n",
      " nate)                                                            'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " reward_network_base (Sequentia  (None, 512)         278016      ['reward_function_input[0][0]']  \n",
      " l)                                                                                               \n",
      "                                                                                                  \n",
      " reward_network_raw_output (Den  (None, 1)           513         ['reward_network_base[0][0]']    \n",
      " se)                                                                                              \n",
      "                                                                                                  \n",
      " reward (Reshape)               (None, 1)            0           ['reward_network_raw_output[0][0]\n",
      "                                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 278,529\n",
      "Trainable params: 278,529\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"state_reconstruction_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " next_latent_state (InputLay  [(None, 13)]             0         \n",
      " er)                                                             \n",
      "                                                                 \n",
      " state_decoder_network_base   (None, 512)              269824    \n",
      " (Sequential)                                                    \n",
      "                                                                 \n",
      " state_decoder (Sequential)  (None, 6)                 3078      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 272,902\n",
      "Trainable params: 272,902\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"steady_state_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 13)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_18 (Concatenate)   (None, 29)           0           ['latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " steady_state_network_base (Seq  (None, 512)         278016      ['concatenate_18[0][0]']         \n",
      " uential)                                                                                         \n",
      "                                                                                                  \n",
      " steady_state_lipschitz_network  (None, 1)           513         ['steady_state_network_base[0][0]\n",
      " _output (Dense)                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 278,529\n",
      "Trainable params: 278,529\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"transition_loss_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " state (InputLayer)             [(None, 6)]          0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 13)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_19 (Concatenate)   (None, 35)           0           ['state[0][0]',                  \n",
      "                                                                  'action[0][0]',                 \n",
      "                                                                  'latent_state[0][0]',           \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " transition_loss_network_base (  (None, 512)         281088      ['concatenate_19[0][0]']         \n",
      " Sequential)                                                                                      \n",
      "                                                                                                  \n",
      " transition_loss_lipschitz_netw  (None, 1)           513         ['transition_loss_network_base[0]\n",
      " ork_output (Dense)                                              [0]']                            \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 281,601\n",
      "Trainable params: 281,601\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "WAE-MDP loaded\n"
     ]
    }
   ],
   "source": [
    "wae_model_path = 'saved_models/experiments/Acrobot-v1/model/'\n",
    "\n",
    "# hyper-parameters and evaluation metrics saved alongside the model\n",
    "with open(os.path.join(wae_model_path, 'model_infos.json'), 'r') as f:\n",
    "    wae_data = json.load(f)\n",
    "# pretty-print: the raw dict would be printed as a single, very long line\n",
    "print(json.dumps(wae_data, indent=2))\n",
    "\n",
    "wae_mdp = wasserstein_mdp.load(wae_model_path)\n",
    "\n",
    "print(\"WAE-MDP loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WAE-MDP at training step 430000\n",
      "Size of the latent state space: 8192\n",
      "Local reward loss: 0.0347698\n",
      "Local transition loss: 0.649478\n"
     ]
    }
   ],
   "source": [
    "# model_infos.json stores every value as a string: parse the numbers explicitly\n",
    "# instead of eval()-ing the file contents\n",
    "print(\"WAE-MDP at training step {:d}\".format(int(wae_data['training_step'])))\n",
    "print(\"Size of the latent state space: {:d}\".format(2 ** wae_mdp.latent_state_size))\n",
    "print('Local reward loss: {:.6g}'.format(float(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.6g}'.format(float(wae_data['local_transition_loss'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the distilled latent policy in Acrobot-v1 (through the WAE-MDP wrapper),\n",
    "# measure its (discounted, scaled) return and embed a video of the episodes.\n",
    "video_path = 'policy_videos/acrobot_wae_distillation'\n",
    "num_episodes = 30\n",
    "\n",
    "def latent_labeling_fn(time_step):\n",
    "    \"\"\"Label passed to the video observer: first bit of the latent state, reported as 'goal'.\"\"\"\n",
    "    return {'goal': time_step.observation['latent_state'].numpy()[..., 0]}\n",
    "\n",
    "with suite_gym.load('Acrobot-v1') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    # keep the initial original state: it is embedded later to evaluate the latent values\n",
    "    original_state = tf_env.current_time_step().observation\n",
    "\n",
    "    tf_env = wae_mdp.wrap_tf_environment(tf_env, labeling_functions['Acrobot-v1'])\n",
    "    policy = tf_env.wrap_latent_policy(\n",
    "        wae_mdp.get_latent_policy(action_dtype=tf.int64))\n",
    "\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    discounted_reward_metric = AverageDiscountedReturnMetric(\n",
    "        gamma=.99, reward_scale=wae_mdp._dynamic_reward_scaling)\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes,\n",
    "        labeling_function=latent_labeling_fn)\n",
    "\n",
    "    driver = dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy, num_episodes=num_episodes,\n",
    "        observers=[reward_metric, discounted_reward_metric, video_observer])\n",
    "    driver.run()\n",
    "\n",
    "tf.print('avg. episode return:', reward_metric.result())\n",
    "tf.print('avg. discounted (scaled) return:', discounted_reward_metric.result())\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local reward loss: 0.035\n",
      "Local transition loss: 0.65\n",
      "Transition/reward model generation\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-11 13:39:14.808659: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 4831838208 exceeds 10% of free system memory.\n",
      "2022-05-11 13:39:16.185816: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 4831838208 exceeds 10% of free system memory.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time to generate the model: 1.9e+02 sec\n"
     ]
    }
   ],
   "source": [
    "# If True, estimate the latent transition tensor from environment samples (with the\n",
    "# learned transition function as backup); otherwise copy the learned transition function.\n",
    "frequency_estimation = False\n",
    "\n",
    "# the reward distribution is deterministic, so taking its mode retrieves the Dirac impulse;\n",
    "# rewards are rescaled by the same factor as the discounted return metric\n",
    "_latent_reward_fn = lambda latent_state, latent_action, next_latent_state: \\\n",
    "    wae_mdp._dynamic_reward_scaling * wae_mdp.reward_distribution(\n",
    "        latent_state=tf.cast(latent_state, dtype=tf.float32),\n",
    "        latent_action=tf.cast(latent_action, dtype=tf.float32),\n",
    "        next_latent_state=tf.cast(next_latent_state, dtype=tf.float32),\n",
    "    ).mode()\n",
    "\n",
    "_latent_transition_fn = lambda latent_state, latent_action: \\\n",
    "    wae_mdp.discrete_latent_transition(\n",
    "        tf.cast(latent_state, tf.float32),\n",
    "        tf.cast(latent_action, tf.float32))\n",
    "\n",
    "# model_infos.json stores every value as a string: parse it instead of eval()-ing it\n",
    "print('Local reward loss: {:.2g}'.format(float(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.2g}'.format(float(wae_data['local_transition_loss'])))\n",
    "\n",
    "#  write the transition/reward functions to tensors,\n",
    "#  to formally check the values in an efficient way\n",
    "print('Transition/reward model generation')\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "if frequency_estimation:\n",
    "    #  compute the transition tensor by frequency estimation and use the\n",
    "    #  latent transition function learned during the WAE optimization as backup function\n",
    "    with suite_gym.load(\n",
    "        'Acrobot-v1',\n",
    "        env_wrappers=[lambda env: perturbed_env.PerturbedEnvironment(env, .75)]\n",
    "    ) as py_env:\n",
    "        py_env.reset()\n",
    "        tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "        latent_transition_fn = model.estimate_latent_transition_function_from_samples(\n",
    "            environment=tf_env,\n",
    "            n_steps=100000,\n",
    "            state_embedding_function=wae_mdp.state_embedding_function,\n",
    "            action_embedding_function=wae_mdp.action_embedding_function,\n",
    "            labeling_function=labeling_functions['Acrobot-v1'],\n",
    "            latent_state_size=wae_mdp.latent_state_size,\n",
    "            number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "            latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "            backup_transition_fn=_latent_transition_fn)\n",
    "else:\n",
    "    latent_transition_fn = model.TransitionFunctionCopy(\n",
    "        num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "        num_actions=wae_mdp.number_of_discrete_actions,\n",
    "        transition_function=_latent_transition_fn,\n",
    "        epsilon=0.)\n",
    "\n",
    "latent_reward_fn = model.RewardFunctionCopy(\n",
    "    num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "    num_actions=wae_mdp.number_of_discrete_actions,\n",
    "    reward_function=_latent_reward_fn,\n",
    "    transition_function=_latent_transition_fn,\n",
    "    epsilon=1e-6)\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Time to generate the model: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value difference: 2.22006\n",
      "Time to compute the value difference: 207.473 sec\n"
     ]
    }
   ],
   "source": [
    "# Gap between the discounted return observed in the original environment\n",
    "# (`discounted_reward_metric`) and the value of the latent MDP computed from `original_state`.\n",
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=latent_reward_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,\n",
    "        # the state embedding is deterministic: wrap it as a Dirac (Deterministic) distribution\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(loc=wae_mdp.state_embedding_function(\n",
    "                original_state,\n",
    "                ergodic_batched_labeling_function(\n",
    "                    labeling_functions['Acrobot-v1']\n",
    "                )(original_state))),\n",
    "            reinterpreted_batch_ndims=1)\n",
    "    )\n",
    "\n",
    "value_difference = tf.abs(discounted_reward_metric.result() - latent_mdp_values)\n",
    "\n",
    "# the message is already a formatted Python string, so plain `print` is enough\n",
    "print(\"Value difference: {:.6g}\".format(value_difference))\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "# `.2f` = two decimals (a bare `2f` would only set a minimum field width)\n",
    "print(\"Time to compute the value difference: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
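    "Time-to-failure property: reach the reset state (i.e., the end of an episode) without having visited a goal state beforehand:\n",
    "$\\neg\\mathsf{Goal} \\, \\mathcal{U} \\, \\mathsf{Reset}$"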
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0.0021911\n",
      "Time to compute the values of the property: 5.192409 sec\n"
     ]
    }
   ],
   "source": [
    "# Goal: first bit of the latent state is set; Reset: detected from the atomic propositions.\n",
    "goal_test_fn = lambda latent_state: tf.cast(latent_state[..., 0], tf.bool)\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "\n",
    "# Goal and reset states are made absorbing (presumably so that no reward is collected\n",
    "# after either is reached -- TODO confirm in compute_values_from_initial_distribution).\n",
    "absorbing_states = lambda latent_state: tf.logical_or(\n",
    "    goal_test_fn(latent_state),\n",
    "    reset_state_test_fn(latent_state))\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "num_latent_states = tf.pow(2, wae_mdp.latent_state_size)\n",
    "with tf.device('/CPU:0'):\n",
    "    # give a reward of 1 at the end of an episode (i.e., when transitioning to the reset state).\n",
    "    # The reset mask is indexed by the next latent state (last axis) and broadcast over the\n",
    "    # (state, action) axes; this avoids materializing an extra dense tensor of ones.\n",
    "    reward_objective = tf.broadcast_to(\n",
    "        tf.cast(reset_state_test_fn(state_space), tf.float32),\n",
    "        shape=(num_latent_states,\n",
    "               wae_mdp.number_of_discrete_actions,\n",
    "               num_latent_states))\n",
    "\n",
    "# Duck-typed reward function exposing only a `to_dense()` accessor to the dense reward tensor.\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=reward_objective_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(\n",
    "                loc=wae_mdp.state_embedding_function(\n",
    "                    original_state,\n",
    "                    ergodic_batched_labeling_function(\n",
    "                        labeling_functions['Acrobot-v1']\n",
    "                    )(original_state))),\n",
    "            reinterpreted_batch_ndims=1),\n",
    "        absorbing_states=absorbing_states)\n",
    "\n",
    "print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "# `.2f` = two decimals (a bare `2f` would only set a minimum field width)\n",
    "print(\"Time to compute the values of the property: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release the Acrobot latent model before moving on to Pendulum.\n",
    "# `reward_objective_fn` only looks up the global `reward_objective` when called, so\n",
    "# the dense reward tensor itself must be deleted as well for its memory to be released.\n",
    "del latent_transition_fn\n",
    "del latent_reward_fn\n",
    "del reward_objective_fn\n",
    "del reward_objective"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pendulum"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL policy (SAC, trained in an environment with random initial states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the SAC policy (trained with random initial states) and record the episodes.\n",
    "video_path = 'policy_videos/pendulum_sac'\n",
    "policy_dir = '../reinforcement_learning/saves/PendulumRandomInit-v0/sac_policy'\n",
    "num_episodes = 30\n",
    "\n",
    "with suite_gym.load('Pendulum-v1') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "\n",
    "    display_state_space(py_env)\n",
    "    display_action_space(py_env)\n",
    "\n",
    "    policy = SavedTFPolicy(policy_dir)\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes)\n",
    "\n",
    "    driver = dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy,\n",
    "        num_episodes=num_episodes,\n",
    "        observers=[reward_metric, video_observer])\n",
    "    driver.run()\n",
    "\n",
    "    tf.print('avg. episode return:', reward_metric.result())\n",
    "\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distilled policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'self': '<wasserstein_mdp.WassersteinMarkovDecisionProcess object at 0x2ba6945ccf10>', 'state_shape': '(3,)', 'action_shape': '(1,)', 'reward_shape': '(1,)', 'label_shape': '(4,)', 'discretize_action_space': 'True', 'state_encoder_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='state_encoder_network_base')\", 'action_decoder_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='action_decoder_network_base')\", 'transition_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='transition_network_base')\", 'reward_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='reward_network_base')\", 'decoder_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='state_decoder_network_base')\", 'latent_policy_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='discrete_policy_network_base')\", 'steady_state_lipschitz_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='steady_state_network_base')\", 'transition_loss_lipschitz_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='transition_loss_network_base')\", 'latent_state_size': '13', 'number_of_discrete_actions': '3', 'action_encoder_network': \"ModelArchitecture(hidden_units=[256, 256, 256], activation='relu', name='action_encoder_network_base')\", 'state_encoder_pre_processing_network': 'None', 'state_decoder_pre_processing_network': 'None', 'time_stacked_states': 'False', 'state_encoder_temperature': '0.6666666666666666', 'state_prior_temperature': '0.6666666666666666', 'action_encoder_temperature': '0.3333333333333333', 'latent_policy_temperature': '0.5', 'wasserstein_regularizer_scale_factor': 'WassersteinRegularizerScaleFactor(global_scaling=10.0, global_gradient_penalty_multiplier=10.0, steady_state_scaling=25.0, 
steady_state_gradient_penalty_multiplier=None, local_transition_loss_scaling=25.0, local_transition_loss_gradient_penalty_multiplier=None)', 'encoder_temperature_decay_rate': '0.0', 'prior_temperature_decay_rate': '0.0', 'reset_state_label': 'True', 'autoencoder_optimizer': 'None', 'wasserstein_regularizer_optimizer': 'None', 'entropy_regularizer_scale_factor': '0.0', 'entropy_regularizer_decay_rate': '0.0', 'entropy_regularizer_scale_factor_min_value': '0.0', 'importance_sampling_exponent': '0.4', 'importance_sampling_exponent_growth_rate': '7e-05', 'time_stacked_lstm_units': '128', 'reward_bounds': 'None', 'latent_stationary_network': 'None', 'action_entropy_regularizer_scaling': '0.0', 'enforce_upper_bound': 'False', 'squared_wasserstein': 'True', 'n_critic': '5', 'trainable_prior': 'False', 'state_encoder_type': 'EncodingType.DETERMINISTIC', 'policy_based_decoding': 'False', 'deterministic_state_embedding': 'True', 'state_encoder_softclipping': 'True', 'args': '()', 'kwargs': \"{'evaluation_window_size': 0}\", '__class__': \"<class 'wasserstein_mdp.WassersteinMarkovDecisionProcess'>\", 'eval_policy': '-107.5308', 'local_reward_loss': '0.026674531', 'local_transition_loss': '0.5395084', 'training_step': '370000'}\n",
      "Model: \"state_encoder\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " state (InputLayer)          [(None, 3)]               0         \n",
      "                                                                 \n",
      " state_encoder_body (Sequent  (None, 256)              132608    \n",
      " ial)                                                            \n",
      "                                                                 \n",
      " dense_38 (Dense)            (None, 8)                 2056      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 134,664\n",
      "Trainable params: 134,664\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"action_encoder\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 1)]          0           []                               \n",
      "                                                                                                  \n",
      " action_encoder_input (Concaten  (None, 14)          0           ['latent_state[0][0]',           \n",
      " ate)                                                             'action[0][0]']                 \n",
      "                                                                                                  \n",
      " action_encoder_network_base (S  (None, 256)         135424      ['action_encoder_input[0][0]']   \n",
      " equential)                                                                                       \n",
      "                                                                                                  \n",
      " action_encoder_categorical_log  (None, 3)           771         ['action_encoder_network_base[0][\n",
      " its (Dense)                                                     0]']                             \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 136,195\n",
      "Trainable params: 136,195\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"autoregressive_transition_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " logistic_layer_input (InputLay  [(None, 16)]        0           []                               \n",
      " er)                                                                                              \n",
      "                                                                                                  \n",
      " sequential_logistic_distributi  (None, 13)          0           ['logistic_layer_input[0][0]']   \n",
      " on_layer (Sequential)                                                                            \n",
      "                                                                                                  \n",
      " autoregressive_transform (Auto  ((None, 13),        151006      ['sequential_logistic_distributio\n",
      " regressiveTransform)            (None, 13))                     n_layer[0][0]',                  \n",
      "                                                                  'logistic_layer_input[0][0]']   \n",
      "                                                                                                  \n",
      " autoregressive_network_20 (Aut  (None, 13, 1)       151005      []                               \n",
      " oregressiveNetwork)                                                                              \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 151,006\n",
      "Trainable params: 151,005\n",
      "Non-trainable params: 1\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"latent_stationary_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " logistic_layer_input (Input  [(None, 0)]              0         \n",
      " Layer)                                                          \n",
      "                                                                 \n",
      " sequential_logistic_distrib  (None, 5)                0         \n",
      " ution_layer (Sequential)                                        \n",
      "                                                                 \n",
      " autoregressive_transform (A  ((None, 5),              134406    \n",
      " utoregressiveTransform)      (None, 5))                         \n",
      "                                                                 \n",
      " autoregressive_network_21 (  (None, 5, 1)             134405    \n",
      " AutoregressiveNetwork)                                          \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 134,414\n",
      "Trainable params: 134,405\n",
      "Non-trainable params: 9\n",
      "_________________________________________________________________\n",
      "Model: \"latent_policy_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " latent_state (InputLayer)   [(None, 13)]              0         \n",
      "                                                                 \n",
      " discrete_policy_network_bas  (None, 256)              135168    \n",
      " e (Sequential)                                                  \n",
      "                                                                 \n",
      " latent_policy_categorical_l  (None, 3)                771       \n",
      " ogits (Dense)                                                   \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 135,939\n",
      "Trainable params: 135,939\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"reward_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 13)]        0           []                               \n",
      "                                                                                                  \n",
      " reward_function_input (Concate  (None, 29)          0           ['latent_state[0][0]',           \n",
      " nate)                                                            'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " reward_network_base (Sequentia  (None, 256)         139264      ['reward_function_input[0][0]']  \n",
      " l)                                                                                               \n",
      "                                                                                                  \n",
      " reward_network_raw_output (Den  (None, 1)           257         ['reward_network_base[0][0]']    \n",
      " se)                                                                                              \n",
      "                                                                                                  \n",
      " reward (Reshape)               (None, 1)            0           ['reward_network_raw_output[0][0]\n",
      "                                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 139,521\n",
      "Trainable params: 139,521\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"state_reconstruction_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " next_latent_state (InputLay  [(None, 13)]             0         \n",
      " er)                                                             \n",
      "                                                                 \n",
      " state_decoder_network_base   (None, 256)              135168    \n",
      " (Sequential)                                                    \n",
      "                                                                 \n",
      " state_decoder (Sequential)  (None, 3)                 771       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 135,939\n",
      "Trainable params: 135,939\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"action_reconstruction_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " action_reconstruction_input (C  (None, 16)          0           ['latent_state[0][0]',           \n",
      " oncatenate)                                                      'latent_action[0][0]']          \n",
      "                                                                                                  \n",
      " action_decoder_network_base (S  (None, 256)         135936      ['action_reconstruction_input[0][\n",
      " equential)                                                      0]']                             \n",
      "                                                                                                  \n",
      " action_reconstruction_network_  (None, 1)           257         ['action_decoder_network_base[0][\n",
      " raw_output (Dense)                                              0]']                             \n",
      "                                                                                                  \n",
      " action_reconstruction_network_  (None, 1)           0           ['action_reconstruction_network_r\n",
      " output (Reshape)                                                aw_output[0][0]']                \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 136,193\n",
      "Trainable params: 136,193\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"steady_state_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 13)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_20 (Concatenate)   (None, 29)           0           ['latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " steady_state_network_base (Seq  (None, 256)         139264      ['concatenate_20[0][0]']         \n",
      " uential)                                                                                         \n",
      "                                                                                                  \n",
      " steady_state_lipschitz_network  (None, 1)           257         ['steady_state_network_base[0][0]\n",
      " _output (Dense)                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 139,521\n",
      "Trainable params: 139,521\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"transition_loss_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " state (InputLayer)             [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 1)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_state (InputLayer)      [(None, 13)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 13)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_21 (Concatenate)   (None, 33)           0           ['state[0][0]',                  \n",
      "                                                                  'action[0][0]',                 \n",
      "                                                                  'latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " transition_loss_network_base (  (None, 256)         140288      ['concatenate_21[0][0]']         \n",
      " Sequential)                                                                                      \n",
      "                                                                                                  \n",
      " transition_loss_lipschitz_netw  (None, 1)           257         ['transition_loss_network_base[0]\n",
      " ork_output (Dense)                                              [0]']                            \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 140,545\n",
      "Trainable params: 140,545\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "WAE-MDP loaded\n"
     ]
    }
   ],
   "source": [
    "# Load the distilled WAE-MDP and the metadata saved with it (model_infos.json).\n",
    "wae_model_path = 'saved_models/experiments/PendulumRandomInit-v1/model/'\n",
    "model_infos_path = os.path.join(wae_model_path, 'model_infos.json')\n",
    "\n",
    "with open(model_infos_path, 'r') as f:\n",
    "    wae_data = json.load(f)\n",
    "print(wae_data)\n",
    "\n",
    "wae_mdp = wasserstein_mdp.load(wae_model_path)\n",
    "print(\"WAE-MDP loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WAE-MDP at training step 370000\n",
      "Size of the latent state space: 8192\n",
      "Size of the latent action space: 3\n",
      "Local reward loss: 0.0266745\n",
      "Local transition loss: 0.539508\n"
     ]
    }
   ],
   "source": [
    "# Metadata values are stored as strings in model_infos.json: parse them with int/float\n",
    "# rather than `eval`, which would execute arbitrary code read from the file.\n",
    "print(\"WAE-MDP at training step {:d}\".format(int(wae_data['training_step'])))\n",
    "print(\"Size of the latent state space: {:d}\".format(2 ** wae_mdp.latent_state_size))\n",
    "print(\"Size of the latent action space: {:d}\".format(wae_mdp.number_of_discrete_actions))\n",
    "print('Local reward loss: {:.6g}'.format(float(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.6g}'.format(float(wae_data['local_transition_loss'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_path = 'policy_videos/pendulum_wae_distillation'\n",
    "num_episodes = 30\n",
    "\n",
    "def latent_labeling_fn(time_step):\n",
    "    \"\"\"Video label: the first component of the latent state carried by the observation.\"\"\"\n",
    "    first_latent_bit = time_step.observation['latent_state'].numpy()[..., 0]\n",
    "    return {'safe_region': first_latent_bit}\n",
    "\n",
    "with suite_gym.load('Pendulum-v1') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    # observation of the original (unwrapped) environment after reset\n",
    "    original_state = tf_env.current_time_step().observation\n",
    "\n",
    "    # wrap the environment with the WAE-MDP so that the latent policy can drive it\n",
    "    tf_env = wae_mdp.wrap_tf_environment(tf_env, labeling_functions['Pendulum-v1'])\n",
    "    policy = tf_env.wrap_latent_policy(wae_mdp.get_latent_policy(action_dtype=tf.int64))\n",
    "\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    discounted_reward_metric = AverageDiscountedReturnMetric(\n",
    "        gamma=.99, reward_scale=wae_mdp._dynamic_reward_scaling)\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env, video_path, num_episodes=num_episodes, labeling_function=latent_labeling_fn)\n",
    "\n",
    "    dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy, num_episodes=num_episodes,\n",
    "        observers=[reward_metric, discounted_reward_metric, video_observer]).run()\n",
    "\n",
    "tf.print('avg. episode return:', reward_metric.result())\n",
    "tf.print('avg. discounted (scaled) return:', discounted_reward_metric.result())\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local reward loss: 0.027\n",
      "Local transition loss: 0.54\n",
      "Transition/reward model generation\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-11 21:15:48.355988: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 4784408256 exceeds 10% of free system memory.\n",
      "2022-05-11 21:15:49.709997: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 4784408256 exceeds 10% of free system memory.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time to generate the model: 144.71 sec\n"
     ]
    }
   ],
   "source": [
    "# Copy the latent transition and reward functions of the WAE-MDP into tensors,\n",
    "# so that values in the latent MDP can be checked formally and efficiently.\n",
    "frequency_estimation = False  # True: estimate latent transitions from environment samples\n",
    "\n",
    "def _latent_reward_fn(latent_state, latent_action, next_latent_state):\n",
    "    \"\"\"Scaled latent reward of the transition (latent_state, latent_action, next_latent_state).\"\"\"\n",
    "    reward_dist = wae_mdp.reward_distribution(\n",
    "        latent_state=tf.cast(latent_state, dtype=tf.float32),\n",
    "        latent_action=tf.cast(latent_action, dtype=tf.float32),\n",
    "        next_latent_state=tf.cast(next_latent_state, dtype=tf.float32),\n",
    "    )\n",
    "    # the reward distribution is deterministic: its mode is the location of the Dirac\n",
    "    return wae_mdp._dynamic_reward_scaling * reward_dist.mode()\n",
    "\n",
    "def _latent_transition_fn(latent_state, latent_action):\n",
    "    \"\"\"Learned latent transition of the WAE-MDP, evaluated on float32-cast inputs.\"\"\"\n",
    "    return wae_mdp.discrete_latent_transition(\n",
    "        tf.cast(latent_state, tf.float32),\n",
    "        tf.cast(latent_action, tf.float32))\n",
    "\n",
    "print(f\"Local reward loss: {eval(wae_data['local_reward_loss']):.2g}\")\n",
    "print(f\"Local transition loss: {eval(wae_data['local_transition_loss']):.2g}\")\n",
    "\n",
    "print('Transition/reward model generation')\n",
    "start = time.time()\n",
    "\n",
    "if not frequency_estimation:\n",
    "    latent_transition_fn = model.TransitionFunctionCopy(\n",
    "        num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "        num_actions=wae_mdp.number_of_discrete_actions,\n",
    "        transition_function=_latent_transition_fn,\n",
    "        epsilon=0.)\n",
    "else:\n",
    "    # frequency estimation of the latent transition function, with the transition\n",
    "    # function learned during the WAE optimization as backup\n",
    "    with suite_gym.load(\n",
    "        'PendulumRandomInit-v1',\n",
    "        env_wrappers=[lambda env: perturbed_env.PerturbedEnvironment(env, .75)]\n",
    "    ) as py_env:\n",
    "        py_env.reset()\n",
    "        tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "        latent_transition_fn = model.estimate_latent_transition_function_from_samples(\n",
    "            environment=tf_env,\n",
    "            n_steps=100000,\n",
    "            state_embedding_function=wae_mdp.state_embedding_function,\n",
    "            action_embedding_function=wae_mdp.action_embedding_function,\n",
    "            labeling_function=labeling_functions['Pendulum-v1'],\n",
    "            latent_state_size=wae_mdp.latent_state_size,\n",
    "            number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "            latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "            backup_transition_fn=_latent_transition_fn)\n",
    "\n",
    "# NOTE: the reward copy is built from the learned transition function in both branches\n",
    "latent_reward_fn = model.RewardFunctionCopy(\n",
    "    num_states=tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32),\n",
    "    num_actions=wae_mdp.number_of_discrete_actions,\n",
    "    reward_function=_latent_reward_fn,\n",
    "    transition_function=_latent_transition_fn,\n",
    "    epsilon=1e-6)\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(f\"Time to generate the model: {end:.2f} sec\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value difference: 4.33006\n",
      "Time to compute the value difference: 168.700369 sec\n"
     ]
    }
   ],
   "source": [
    "# Gap between the discounted return observed when running the distilled policy in the\n",
    "# environment (discounted_reward_metric) and the value of the latent policy computed\n",
    "# formally in the latent MDP, starting from the embedding of original_state.\n",
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=latent_reward_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,  # must match the gamma of discounted_reward_metric\n",
    "        # the state embedding is deterministic: expose it as a Dirac distribution\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(loc=wae_mdp.state_embedding_function(\n",
    "                original_state,\n",
    "                ergodic_batched_labeling_function(\n",
    "                    labeling_functions['Pendulum-v1']\n",
    "                )(original_state))),\n",
    "            reinterpreted_batch_ndims=1)\n",
    "    )\n",
    "\n",
    "value_difference = tf.abs(discounted_reward_metric.result() - latent_mdp_values)\n",
    "\n",
    "tf.print(\"Value difference: {:.6g}\".format(value_difference))\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "# '{:.2f}' prints 2 decimals ('{:2f}' would only set a minimum field width of 2)\n",
    "print(\"Time to compute the value difference: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Target behavior:\n",
    "$\\varphi = \\mathsf{true} \\, \\mathcal{U} \\, \\left[\\mathsf{Upright} \\, \\mathcal{U} \\, \\mathsf{Reset} \\right] $, i.e., *reaching a safe region of the system*.\n",
    "\n",
    "The system reaches a safe region iff the pendulum eventually stays upright until the episode resets, i.e., its angle relative to the $y$-axis stays within a tight range of $60^{\\circ} = {\\pi}/{3}$ rad.\n",
    "A clear *failure* occurs when the pendulum never reaches this safe region, i.e., $\\neg \\varphi$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0.0348492\n",
      "Time to compute the values of the property: 81.342587 sec\n"
     ]
    }
   ],
   "source": [
    "# Latent-state predicates: the first latent-state bit encodes 'Upright'.\n",
    "upright_test_fn = lambda latent_state: tf.cast(latent_state[..., 0], tf.bool)\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "\n",
    "# reset states are made absorbing for the value computation\n",
    "absorbing_states = reset_state_test_fn\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "\n",
    "# Reward objective of shape (|S|, |A|, |S|), |S| = 2 ** latent_state_size: a reward of 1\n",
    "# when the system transitions from an unsafe (non-upright) state to a reset state, 0 otherwise.\n",
    "with tf.device('/CPU:0'):\n",
    "    all_states = tf.ones(\n",
    "        shape=(tf.pow(2, wae_mdp.latent_state_size),\n",
    "               wae_mdp.number_of_discrete_actions,\n",
    "               tf.pow(2, wae_mdp.latent_state_size)))\n",
    "    # the unsafe mask (assumed of shape (|S|,)) broadcasts over the last axis; tf.transpose\n",
    "    # reverses the axes, moving it onto the first (source-state) axis: unsafe-state rows are 1\n",
    "    reward_from_unsafe_states = tf.transpose(\n",
    "        all_states * tf.cast(\n",
    "            tf.logical_not(upright_test_fn(state_space)),\n",
    "            tf.float32))\n",
    "    # reset-state columns (last, next-state axis) are 1\n",
    "    reward_to_reset_states = all_states * tf.cast(\n",
    "        reset_state_test_fn(state_space),\n",
    "        tf.float32)\n",
    "    reward_objective = reward_from_unsafe_states * reward_to_reset_states\n",
    "\n",
    "# minimal object exposing to_dense(), used in place of a RewardFunctionCopy\n",
    "# (presumably the only method compute_values_from_initial_distribution needs; verify)\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=reward_objective_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,\n",
    "        # the state embedding is deterministic: expose it as a Dirac distribution\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(\n",
    "                loc=wae_mdp.state_embedding_function(\n",
    "                    original_state,\n",
    "                    ergodic_batched_labeling_function(\n",
    "                        labeling_functions['Pendulum-v1']\n",
    "                    )(original_state))),\n",
    "            reinterpreted_batch_ndims=1),\n",
    "        absorbing_states=absorbing_states)\n",
    "\n",
    "tf.print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "# '{:.2f}' prints 2 decimals ('{:2f}' would only set a minimum field width of 2)\n",
    "print(\"Time to compute the values of the property: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LunarLander Continuous"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL policy (SAC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the SAC policy trained on LunarLanderContinuous-v2, tracking returns and recording a video.\n",
    "video_path = 'policy_videos/lunar_lander_sac'\n",
    "\n",
    "with suite_gym.load('LunarLanderContinuous-v2') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "\n",
    "    display_state_space(py_env)\n",
    "    display_action_space(py_env)\n",
    "\n",
    "    # load the saved SAC policy\n",
    "    policy_dir = '../reinforcement_learning/saves/LunarLanderContinuous-v2/sac_policy'\n",
    "    policy = SavedTFPolicy(policy_dir)\n",
    "    num_episodes = 30\n",
    "\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    video_observer = video.VideoEmbeddingObserver(py_env, video_path, num_episodes=num_episodes)\n",
    "\n",
    "    dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env, policy, num_episodes=num_episodes,\n",
    "        observers=[reward_metric, video_observer],\n",
    "    ).run()\n",
    "\n",
    "    tf.print('avg. episode return:', reward_metric.result())\n",
    "\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distilled policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'self': '<wasserstein_mdp.WassersteinMarkovDecisionProcess object at 0x2b74bebf5040>', 'state_shape': '(8,)', 'action_shape': '(2,)', 'reward_shape': '(1,)', 'label_shape': '(6,)', 'discretize_action_space': 'True', 'state_encoder_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='state_encoder_network_base')\", 'action_decoder_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='action_decoder_network_base')\", 'transition_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='transition_network_base')\", 'reward_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='reward_network_base')\", 'decoder_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='state_decoder_network_base')\", 'latent_policy_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='discrete_policy_network_base')\", 'steady_state_lipschitz_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='steady_state_network_base')\", 'transition_loss_lipschitz_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='transition_loss_network_base')\", 'latent_state_size': '14', 'number_of_discrete_actions': '3', 'action_encoder_network': \"ModelArchitecture(hidden_units=[256], activation='relu', name='action_encoder_network_base')\", 'state_encoder_pre_processing_network': 'None', 'state_decoder_pre_processing_network': 'None', 'time_stacked_states': 'False', 'state_encoder_temperature': '0.6', 'state_prior_temperature': '0.75', 'action_encoder_temperature': '0.3333333333333333', 'latent_policy_temperature': '0.5', 'wasserstein_regularizer_scale_factor': 'WassersteinRegularizerScaleFactor(global_scaling=10.0, global_gradient_penalty_multiplier=20.0, steady_state_scaling=100.0, steady_state_gradient_penalty_multiplier=None, local_transition_loss_scaling=50.0, local_transition_loss_gradient_penalty_multiplier=None)', 
'encoder_temperature_decay_rate': '0.0', 'prior_temperature_decay_rate': '0.0', 'reset_state_label': 'True', 'autoencoder_optimizer': 'None', 'wasserstein_regularizer_optimizer': 'None', 'entropy_regularizer_scale_factor': '0.0', 'entropy_regularizer_decay_rate': '0.0', 'entropy_regularizer_scale_factor_min_value': '0.0', 'importance_sampling_exponent': '0.4', 'importance_sampling_exponent_growth_rate': '1e-05', 'time_stacked_lstm_units': '128', 'reward_bounds': 'None', 'latent_stationary_network': 'None', 'action_entropy_regularizer_scaling': '0.0', 'enforce_upper_bound': 'False', 'squared_wasserstein': 'True', 'n_critic': '15', 'trainable_prior': 'False', 'state_encoder_type': 'EncodingType.DETERMINISTIC', 'policy_based_decoding': 'False', 'deterministic_state_embedding': 'True', 'state_encoder_softclipping': 'False', 'args': '()', 'kwargs': \"{'evaluation_window_size': 0}\", '__class__': \"<class 'wasserstein_mdp.WassersteinMarkovDecisionProcess'>\", 'eval_policy': '282.56876', 'local_reward_loss': '0.020720486', 'local_transition_loss': '0.13135736', 'training_step': '320000'}\n",
      "Model: \"state_encoder\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " state (InputLayer)          [(None, 8)]               0         \n",
      "                                                                 \n",
      " state_encoder_body (Sequent  (None, 252)              2268      \n",
      " ial)                                                            \n",
      "                                                                 \n",
      " dense_40 (Dense)            (None, 7)                 1771      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 4,039\n",
      "Trainable params: 4,039\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"action_encoder\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 14)]         0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 2)]          0           []                               \n",
      "                                                                                                  \n",
      " action_encoder_input (Concaten  (None, 16)          0           ['latent_state[0][0]',           \n",
      " ate)                                                             'action[0][0]']                 \n",
      "                                                                                                  \n",
      " action_encoder_network_base (S  (None, 256)         4352        ['action_encoder_input[0][0]']   \n",
      " equential)                                                                                       \n",
      "                                                                                                  \n",
      " action_encoder_categorical_log  (None, 3)           771         ['action_encoder_network_base[0][\n",
      " its (Dense)                                                     0]']                             \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 5,123\n",
      "Trainable params: 5,123\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"autoregressive_transition_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " logistic_layer_input (InputLay  [(None, 17)]        0           []                               \n",
      " er)                                                                                              \n",
      "                                                                                                  \n",
      " sequential_logistic_distributi  (None, 14)          0           ['logistic_layer_input[0][0]']   \n",
      " on_layer (Sequential)                                                                            \n",
      "                                                                                                  \n",
      " autoregressive_transform (Auto  ((None, 14),        12029       ['sequential_logistic_distributio\n",
      " regressiveTransform)            (None, 14))                     n_layer[0][0]',                  \n",
      "                                                                  'logistic_layer_input[0][0]']   \n",
      "                                                                                                  \n",
      " autoregressive_network_22 (Aut  (None, 14, 1)       12028       []                               \n",
      " oregressiveNetwork)                                                                              \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 12,029\n",
      "Trainable params: 12,028\n",
      "Non-trainable params: 1\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"latent_stationary_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " logistic_layer_input (Input  [(None, 0)]              0         \n",
      " Layer)                                                          \n",
      "                                                                 \n",
      " sequential_logistic_distrib  (None, 7)                0         \n",
      " ution_layer (Sequential)                                        \n",
      "                                                                 \n",
      " autoregressive_transform (A  ((None, 7),              3848      \n",
      " utoregressiveTransform)      (None, 7))                         \n",
      "                                                                 \n",
      " autoregressive_network_23 (  (None, 7, 1)             3847      \n",
      " AutoregressiveNetwork)                                          \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 3,855\n",
      "Trainable params: 3,847\n",
      "Non-trainable params: 8\n",
      "_________________________________________________________________\n",
      "Model: \"latent_policy_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " latent_state (InputLayer)   [(None, 14)]              0         \n",
      "                                                                 \n",
      " discrete_policy_network_bas  (None, 256)              3840      \n",
      " e (Sequential)                                                  \n",
      "                                                                 \n",
      " latent_policy_categorical_l  (None, 3)                771       \n",
      " ogits (Dense)                                                   \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 4,611\n",
      "Trainable params: 4,611\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"reward_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 14)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 14)]        0           []                               \n",
      "                                                                                                  \n",
      " reward_function_input (Concate  (None, 31)          0           ['latent_state[0][0]',           \n",
      " nate)                                                            'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " reward_network_base (Sequentia  (None, 256)         8192        ['reward_function_input[0][0]']  \n",
      " l)                                                                                               \n",
      "                                                                                                  \n",
      " reward_network_raw_output (Den  (None, 1)           257         ['reward_network_base[0][0]']    \n",
      " se)                                                                                              \n",
      "                                                                                                  \n",
      " reward (Reshape)               (None, 1)            0           ['reward_network_raw_output[0][0]\n",
      "                                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 8,449\n",
      "Trainable params: 8,449\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"state_reconstruction_network\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " next_latent_state (InputLay  [(None, 14)]             0         \n",
      " er)                                                             \n",
      "                                                                 \n",
      " state_decoder_network_base   (None, 256)              3840      \n",
      " (Sequential)                                                    \n",
      "                                                                 \n",
      " state_decoder (Sequential)  (None, 8)                 2056      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 5,896\n",
      "Trainable params: 5,896\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"action_reconstruction_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 14)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " action_reconstruction_input (C  (None, 17)          0           ['latent_state[0][0]',           \n",
      " oncatenate)                                                      'latent_action[0][0]']          \n",
      "                                                                                                  \n",
      " action_decoder_network_base (S  (None, 256)         4608        ['action_reconstruction_input[0][\n",
      " equential)                                                      0]']                             \n",
      "                                                                                                  \n",
      " action_reconstruction_network_  (None, 2)           514         ['action_decoder_network_base[0][\n",
      " raw_output (Dense)                                              0]']                             \n",
      "                                                                                                  \n",
      " action_reconstruction_network_  (None, 2)           0           ['action_reconstruction_network_r\n",
      " output (Reshape)                                                aw_output[0][0]']                \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 5,122\n",
      "Trainable params: 5,122\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"steady_state_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " latent_state (InputLayer)      [(None, 14)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 14)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_22 (Concatenate)   (None, 31)           0           ['latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " steady_state_network_base (Seq  (None, 256)         8192        ['concatenate_22[0][0]']         \n",
      " uential)                                                                                         \n",
      "                                                                                                  \n",
      " steady_state_lipschitz_network  (None, 1)           257         ['steady_state_network_base[0][0]\n",
      " _output (Dense)                                                 ']                               \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 8,449\n",
      "Trainable params: 8,449\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "Model: \"transition_loss_lipschitz_network\"\n",
      "__________________________________________________________________________________________________\n",
      " Layer (type)                   Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      " state (InputLayer)             [(None, 8)]          0           []                               \n",
      "                                                                                                  \n",
      " action (InputLayer)            [(None, 2)]          0           []                               \n",
      "                                                                                                  \n",
      " latent_state (InputLayer)      [(None, 14)]         0           []                               \n",
      "                                                                                                  \n",
      " latent_action (InputLayer)     [(None, 3)]          0           []                               \n",
      "                                                                                                  \n",
      " next_latent_state (InputLayer)  [(None, 14)]        0           []                               \n",
      "                                                                                                  \n",
      " concatenate_23 (Concatenate)   (None, 41)           0           ['state[0][0]',                  \n",
      "                                                                  'action[0][0]',                 \n",
      "                                                                  'latent_state[0][0]',           \n",
      "                                                                  'latent_action[0][0]',          \n",
      "                                                                  'next_latent_state[0][0]']      \n",
      "                                                                                                  \n",
      " transition_loss_network_base (  (None, 256)         10752       ['concatenate_23[0][0]']         \n",
      " Sequential)                                                                                      \n",
      "                                                                                                  \n",
      " transition_loss_lipschitz_netw  (None, 1)           257         ['transition_loss_network_base[0]\n",
      " ork_output (Dense)                                              [0]']                            \n",
      "                                                                                                  \n",
      "==================================================================================================\n",
      "Total params: 11,009\n",
      "Trainable params: 11,009\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "WAE-MDP loaded\n"
     ]
    }
   ],
   "source": [
    "# directory holding the trained WAE-MDP together with its metadata file\n",
    "wae_model_path = 'saved_models/experiments/LunarLanderContinuous-v2/model/'\n",
    "model_infos_path = os.path.join(wae_model_path, 'model_infos.json')\n",
    "\n",
    "# training metadata (e.g., training step and local losses) saved alongside the model\n",
    "with open(model_infos_path, 'r') as model_infos_file:\n",
    "    wae_data = json.load(model_infos_file)\n",
    "print(wae_data)\n",
    "\n",
    "wae_mdp = wasserstein_mdp.load(wae_model_path)\n",
    "print(\"WAE-MDP loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WAE-MDP at training step 320000\n",
      "Size of the latent state space: 16384\n",
      "Size of the latent action space: 3\n",
      "Local reward loss: 0.0207205\n",
      "Local transition loss: 0.131357\n"
     ]
    }
   ],
   "source": [
    "# the metadata values are stored as strings in model_infos.json; literal_eval parses them\n",
    "# as Python literals without executing arbitrary code (unlike eval)\n",
    "from ast import literal_eval\n",
    "\n",
    "print(\"WAE-MDP at training step {:d}\".format(literal_eval(wae_data['training_step'])))\n",
    "# the latent state space is made of all binary vectors of size latent_state_size\n",
    "print(\"Size of the latent state space: {:d}\".format(2 ** wae_mdp.latent_state_size))\n",
    "print(\"Size of the latent action space: {:d}\".format(wae_mdp.number_of_discrete_actions))\n",
    "print('Local reward loss: {:.6g}'.format(literal_eval(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.6g}'.format(literal_eval(wae_data['local_transition_loss'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_path = 'policy_videos/lunar_lander_wae_distillation'\n",
    "\n",
    "\n",
    "def latent_labeling_fn(time_step):\n",
    "    \"\"\"Label a time step of the WAE-wrapped environment with the propositions shown in the video.\"\"\"\n",
    "    latent_bits = tf.cast(time_step.observation['latent_state'], tf.bool)\n",
    "    safe_angle = tf.logical_not(latent_bits[..., 0])\n",
    "    safe_landing = tf.logical_and(latent_bits[..., 1], latent_bits[..., 5])\n",
    "    return {'safe_angle': safe_angle.numpy(), 'safe_landing': safe_landing.numpy()}\n",
    "\n",
    "\n",
    "# NOTE: the following cells rely on `original_state` and `discounted_reward_metric` defined here\n",
    "with suite_gym.load('LunarLanderContinuous-v2') as py_env:\n",
    "    py_env.reset()\n",
    "    tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "    # initial observation of the original (non-latent) environment\n",
    "    original_state = tf_env.current_time_step().observation\n",
    "\n",
    "    # run the latent policy in the environment wrapped by the WAE-MDP\n",
    "    tf_env = wae_mdp.wrap_tf_environment(tf_env, labeling_functions['LunarLanderContinuous-v2'])\n",
    "    policy = tf_env.wrap_latent_policy(wae_mdp.get_latent_policy(action_dtype=tf.int64))\n",
    "\n",
    "    num_episodes = 30\n",
    "    reward_metric = tf_metrics.AverageReturnMetric()\n",
    "    discounted_reward_metric = AverageDiscountedReturnMetric(\n",
    "        gamma=.99,\n",
    "        reward_scale=wae_mdp._dynamic_reward_scaling)\n",
    "    video_observer = video.VideoEmbeddingObserver(\n",
    "        py_env,\n",
    "        video_path,\n",
    "        num_episodes=num_episodes,\n",
    "        labeling_function=latent_labeling_fn,\n",
    "        font_color='white')\n",
    "\n",
    "    driver = dynamic_episode_driver.DynamicEpisodeDriver(\n",
    "        tf_env,\n",
    "        policy,\n",
    "        num_episodes=num_episodes,\n",
    "        observers=[reward_metric, discounted_reward_metric, video_observer])\n",
    "    driver.run()\n",
    "\n",
    "tf.print('avg. episode return:', reward_metric.result())\n",
    "tf.print('avg. discounted (scaled) return:', discounted_reward_metric.result())\n",
    "embed_mp4(video_observer.file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local reward loss: 0.021\n",
      "Local transition loss: 0.13\n",
      "Transition/reward model generation\n",
      "Time to generate the model: 904.06 sec\n"
     ]
    }
   ],
   "source": [
    "# parses the numeric strings of model_infos.json without executing code (unlike eval)\n",
    "from ast import literal_eval\n",
    "\n",
    "frequency_estimation = False\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    # the latent state space is binary: it contains 2 ** latent_state_size states\n",
    "    num_latent_states = tf.cast(tf.pow(2, wae_mdp.latent_state_size), dtype=tf.int32)\n",
    "\n",
    "# as the reward distribution is deterministic, taking its mode\n",
    "# retrieves the Dirac impulse, i.e., the reward value itself\n",
    "_latent_reward_fn = lambda latent_state, latent_action, next_latent_state: \\\n",
    "    wae_mdp._dynamic_reward_scaling * wae_mdp.reward_distribution(\n",
    "        latent_state=tf.cast(latent_state, dtype=tf.float32),\n",
    "        latent_action=tf.cast(latent_action, dtype=tf.float32),\n",
    "        next_latent_state=tf.cast(next_latent_state, dtype=tf.float32),\n",
    "    ).mode()\n",
    "\n",
    "_latent_transition_fn = lambda latent_state, latent_action: \\\n",
    "    wae_mdp.discrete_latent_transition(\n",
    "        tf.cast(latent_state, tf.float32),\n",
    "        tf.cast(latent_action, tf.float32))\n",
    "\n",
    "print('Local reward loss: {:.2g}'.format(literal_eval(wae_data['local_reward_loss'])))\n",
    "print('Local transition loss: {:.2g}'.format(literal_eval(wae_data['local_transition_loss'])))\n",
    "\n",
    "print('Transition/reward model generation')\n",
    "# write the transition/reward functions to tensors,\n",
    "# to formally check the values in an efficient way\n",
    "start = time.time()\n",
    "\n",
    "if frequency_estimation:\n",
    "    # compute the latent transition function by frequency estimation and use the\n",
    "    # latent transition function learned during the WAE optimization as backup function\n",
    "    with suite_gym.load(\n",
    "        'LunarLanderContinuous-v2',\n",
    "        env_wrappers=[lambda env: perturbed_env.PerturbedEnvironment(env, .75)]\n",
    "    ) as py_env:\n",
    "        py_env.reset()\n",
    "        tf_env = tf_py_environment.TFPyEnvironment(py_env)\n",
    "        latent_transition_fn = model.estimate_latent_transition_function_from_samples(\n",
    "            environment=tf_env,\n",
    "            n_steps=100000,\n",
    "            state_embedding_function=wae_mdp.state_embedding_function,\n",
    "            action_embedding_function=wae_mdp.action_embedding_function,\n",
    "            labeling_function=labeling_functions['LunarLanderContinuous-v2'],\n",
    "            latent_state_size=wae_mdp.latent_state_size,\n",
    "            number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "            latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "            backup_transition_fn=_latent_transition_fn)\n",
    "else:\n",
    "    with tf.device('/CPU:0'):\n",
    "        latent_transition_fn = model.TransitionFunctionCopy(\n",
    "            num_states=num_latent_states,\n",
    "            num_actions=wae_mdp.number_of_discrete_actions,\n",
    "            transition_function=_latent_transition_fn,\n",
    "            epsilon=1e-6)\n",
    "\n",
    "# NOTE(review): the reward copy is always built from the learned transition function,\n",
    "# even when frequency_estimation is True -- confirm this is intended\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_reward_fn = model.RewardFunctionCopy(\n",
    "        num_states=num_latent_states,\n",
    "        num_actions=wae_mdp.number_of_discrete_actions,\n",
    "        reward_function=_latent_reward_fn,\n",
    "        transition_function=_latent_transition_fn,\n",
    "        epsilon=1e-6)\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "print(\"Time to generate the model: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-11 22:18:56.386137: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 3221225472 exceeds 10% of free system memory.\n",
      "2022-05-11 22:18:56.386206: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 3221225472 exceeds 10% of free system memory.\n",
      "2022-05-11 22:18:57.350524: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 3221225472 exceeds 10% of free system memory.\n",
      "2022-05-11 22:18:58.029953: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 3221225472 exceeds 10% of free system memory.\n",
      "2022-05-11 22:18:58.705602: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 3221225472 exceeds 10% of free system memory.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value difference: 0.0372883\n",
      "Time to compute the value difference: 480.655583 sec\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    # discounted value of the latent MDP under the latent policy,\n",
    "    # starting from the latent embedding of original_state\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=latent_reward_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(loc=wae_mdp.state_embedding_function(\n",
    "                original_state,\n",
    "                ergodic_batched_labeling_function(\n",
    "                    labeling_functions['LunarLanderContinuous-v2']\n",
    "                )(original_state))),\n",
    "            reinterpreted_batch_ndims=1)\n",
    "    )\n",
    "\n",
    "# gap between the discounted return averaged over the simulated episodes\n",
    "# and the value computed on the latent MDP\n",
    "value_difference = tf.abs(discounted_reward_metric.result() - latent_mdp_values)\n",
    "\n",
    "tf.print(\"Value difference: {:.6g}\".format(value_difference))\n",
    "\n",
    "end = time.time() - start\n",
    "\n",
    "# `{:.2f}` prints two decimals (`{:2f}` would only set a minimum field width of 2)\n",
    "print(\"Time to compute the value difference: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Time-to-failure property (reaching the reset state, i.e., the end of the episode, without ever landing safely):\n",
    "$\\neg \\mathsf{SafeLanding} \\, \\mathcal{U} \\, \\mathsf{Reset}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0.0702039\n",
      "Time to compute the values of the property: 386.048782 sec\n"
     ]
    }
   ],
   "source": [
    "safe_landing_test_fn = lambda latent_state: tf.logical_and(\n",
    "        tf.cast(latent_state[..., 1], tf.bool),\n",
    "        tf.cast(latent_state[..., 5], tf.bool))\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "\n",
    "# both safe-landing states and reset states are made absorbing\n",
    "absorbing_states = lambda latent_state: tf.logical_or(\n",
    "    safe_landing_test_fn(latent_state),\n",
    "    reset_state_test_fn(latent_state))\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "with tf.device('/CPU:0'):\n",
    "    num_latent_states = tf.pow(2, wae_mdp.latent_state_size)\n",
    "    # give a reward of 1 at the end of an episode (i.e., when transitioning to the reset state):\n",
    "    # reward_objective[s, a, s'] = 1 iff s' is a reset state, 0 otherwise.\n",
    "    # broadcast_to builds the (|S|, |A|, |S|) tensor in a single allocation instead of\n",
    "    # materializing a full tensor of ones and then a second full tensor for the product\n",
    "    reward_objective = tf.broadcast_to(\n",
    "        tf.cast(reset_state_test_fn(state_space), tf.float32),\n",
    "        shape=(num_latent_states,\n",
    "               wae_mdp.number_of_discrete_actions,\n",
    "               num_latent_states))\n",
    "\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=reward_objective_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(\n",
    "                loc=wae_mdp.state_embedding_function(\n",
    "                    original_state,\n",
    "                    ergodic_batched_labeling_function(\n",
    "                        labeling_functions['LunarLanderContinuous-v2']\n",
    "                    )(original_state))),\n",
    "            reinterpreted_batch_ndims=1),\n",
    "        absorbing_states=absorbing_states)\n",
    "\n",
    "tf.print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "# `{:.2f}` prints two decimals (`{:2f}` would only set a minimum field width of 2)\n",
    "print(\"Time to compute the values of the property: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Target behavior:\n",
    "$\\varphi = \\mathsf{true} \\, \\mathcal{U} \\, \\left[\\mathsf{SafeAngle} \\, \\mathcal{U} \\, \\mathsf{Reset} \\right] $, i.e., *reaching a safe region of the system*.\n",
    "\n",
    "Failure (the system never reaches the safe region): $\\neg \\varphi$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "property values: 0.0155664\n",
      "Time to compute the values of the property: 322.558286 sec\n"
     ]
    }
   ],
   "source": [
    "unsafe_angle_test = lambda latent_state: tf.cast(latent_state[..., 0], tf.bool)\n",
    "reset_state_test_fn = lambda latent_state: is_reset_state(latent_state, wae_mdp.atomic_prop_dims)\n",
    "\n",
    "absorbing_states = reset_state_test_fn\n",
    "\n",
    "state_space = binary_latent_space(wae_mdp.latent_state_size, dtype=tf.float32)\n",
    "\n",
    "# retrieve a reward of 1 when the system transitions from an unsafe state to the reset state:\n",
    "# reward_objective[s, a, s'] = unsafe_angle(s) * reset(s')\n",
    "with tf.device('/CPU:0'):\n",
    "    from_unsafe_states = tf.cast(unsafe_angle_test(state_space), tf.float32)\n",
    "    to_reset_states = tf.cast(reset_state_test_fn(state_space), tf.float32)\n",
    "    # outer product by broadcasting: (|S|, 1, 1) * (1, |A|, 1) * (1, 1, |S|) -> (|S|, |A|, |S|);\n",
    "    # only the final product has the full (|S|, |A|, |S|) size, so no full-size\n",
    "    # intermediate tensors are allocated nor kept alive in the notebook namespace\n",
    "    reward_objective = (\n",
    "        from_unsafe_states[:, tf.newaxis, tf.newaxis]\n",
    "        * tf.ones((1, wae_mdp.number_of_discrete_actions, 1))\n",
    "        * to_reset_states[tf.newaxis, tf.newaxis, :])\n",
    "\n",
    "reward_objective_fn = namedtuple('reward_fn', ['to_dense'])(lambda: reward_objective)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "with tf.device('/CPU:0'):\n",
    "    latent_mdp_values = compute_values_from_initial_distribution(\n",
    "        latent_state_size=wae_mdp.latent_state_size,\n",
    "        atomic_prop_dims=wae_mdp.atomic_prop_dims,\n",
    "        original_state=original_state,\n",
    "        number_of_discrete_actions=wae_mdp.number_of_discrete_actions,\n",
    "        latent_policy=wae_mdp.get_latent_policy(action_dtype=tf.int64),\n",
    "        latent_transition_fn=latent_transition_fn,\n",
    "        latent_reward_function=reward_objective_fn,\n",
    "        epsilon=1e-6,\n",
    "        gamma=.99,\n",
    "        stochastic_state_embedding=lambda original_state: tfd.Independent(\n",
    "            tfd.Deterministic(\n",
    "                loc=wae_mdp.state_embedding_function(\n",
    "                    original_state,\n",
    "                    ergodic_batched_labeling_function(\n",
    "                        labeling_functions['LunarLanderContinuous-v2']\n",
    "                    )(original_state))),\n",
    "            reinterpreted_batch_ndims=1),\n",
    "        absorbing_states=absorbing_states)\n",
    "\n",
    "tf.print(\"property values: {:.6g}\".format(latent_mdp_values))\n",
    "\n",
    "end = time.time() - start\n",
    "# `{:.2f}` prints two decimals (`{:2f}` would only set a minimum field width of 2)\n",
    "print(\"Time to compute the values of the property: {:.2f} sec\".format(end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
