{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "nQnmcm0oI1Q-"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0vekhJpsOxLK"
   },
   "source": [
    "#SBSim: A tutorial of using Reinforcement Learning for Optimizing Energy Use and Minimizing Carbon Emission in Office Buildings\n",
    "\n",
    "___\n",
    "\n",
    "Commercial office buildings contribute 17 percent of Carbon Emissions in the US, according to the US Energy Information Administration (EIA), and improving their efficiency will reduce their environmental burden and operating cost. A major contributor of energy consumption in these buildings are the Heating, Ventilation, and Air Conditioning (HVAC) devices. HVAC devices form a complex and interconnected thermodynamic system with the building and outside weather conditions, and current setpoint control policies are not fully optimized for minimizing energy use and carbon emission. Given a suitable training environment, a Reinforcement Learning (RL) agent is able to improve upon these policies, but training such a model, especially in a way that scales to thousands of buildings, presents many practical challenges. Most existing work on applying RL to this important task either makes use of proprietary data, or focuses on expensive and proprietary simulations that may not be grounded in the real world. We present the Smart Buildings Control Suite, the first open source interactive HVAC control dataset extracted from live sensor measurements of devices in real office buildings. The dataset consists of two components: real-world historical data from two buildings, for offline RL, and a lightweight interactive simulator for each of these buildings, calibrated using the historical data, for online and model-based RL. For ease of use, our RL environments are all compatible with the OpenAI gym environment standard. We believe this benchmark will accelerate progress and collaboration on HVAC optimization.\n",
    "\n",
    "---\n",
    "\n",
    "This notebook accompanies the paper titled, **Real-World Data and Calibrated Simulation Suite for Offline Training of Reinforcement Learning Agents to Optimize Energy and Emission in Office Buildings** by Judah Goldfeder and John Sipple."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L7w-mjPcH7u6"
   },
   "source": [
    "#Smart Buildings Simulator Twin Delayed DDPG Demo\n",
    "\n",
    "This notebook runs through training a Twin Delayed DDPG (TD3) agent on an HVAC building simulator that has been calibrated from real world data."
   ]
  },
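  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TD3 (Fujimoto et al., 2018) extends DDPG with three changes: clipped double-Q learning (two critics, with the smaller target estimate used for bootstrapping), target policy smoothing (clipped Gaussian noise added to the target action), and delayed policy updates (the actor and target networks are updated less often than the critics). The NumPy sketch below illustrates only the critic target. It is not the `tf_agents` implementation used in this notebook. `td3_target` and the stand-in networks are hypothetical names chosen for illustration, and the noise defaults (0.2 and 0.5) follow the values commonly used with TD3.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def td3_target(reward, discount, done, next_obs, target_actor, target_q1,\n",
    "               target_q2, rng, action_low, action_high,\n",
    "               noise_std=0.2, noise_clip=0.5):\n",
    "  \"\"\"Illustrative TD3 critic target for a batch of transitions.\"\"\"\n",
    "  next_action = target_actor(next_obs)\n",
    "  # Target policy smoothing: clipped Gaussian noise on the target action.\n",
    "  noise = np.clip(rng.normal(0.0, noise_std, size=next_action.shape),\n",
    "                  -noise_clip, noise_clip)\n",
    "  next_action = np.clip(next_action + noise, action_low, action_high)\n",
    "  # Clipped double-Q: bootstrap from the smaller of the two target critics.\n",
    "  next_q = np.minimum(target_q1(next_obs, next_action),\n",
    "                      target_q2(next_obs, next_action))\n",
    "  # Terminal transitions (done == 1) do not bootstrap.\n",
    "  return reward + discount * (1.0 - done) * next_q\n",
    "\n",
    "\n",
    "# Hypothetical stand-ins for the target actor and the two target critics.\n",
    "def target_actor(obs):\n",
    "  return np.tanh(obs)\n",
    "\n",
    "\n",
    "def target_q1(obs, action):\n",
    "  return (obs + action).sum(axis=1)\n",
    "\n",
    "\n",
    "def target_q2(obs, action):\n",
    "  return (obs + action).sum(axis=1) - 0.1\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "y = td3_target(\n",
    "    reward=np.array([1.0, 0.0]), discount=0.99, done=np.array([0.0, 1.0]),\n",
    "    next_obs=np.array([[0.5], [-0.2]]), target_actor=target_actor,\n",
    "    target_q1=target_q1, target_q2=target_q2, rng=rng,\n",
    "    action_low=-1.0, action_high=1.0)\n",
    "print(y)  # The second entry is 0.0 because that transition is terminal.\n",
    "```\n",
    "\n",
    "In `tf_agents`, these behaviors appear to be configured through `td3_agent.Td3Agent` arguments such as `target_policy_noise`, `target_policy_noise_clip`, and `actor_update_period`; check the TF-Agents API reference for their exact defaults."
   ]
  },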
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "YchP7JXbSXS1"
   },
   "outputs": [],
   "source": [
    "# @title Imports\n",
    "from dataclasses import dataclass\n",
    "import datetime, pytz\n",
    "import enum\n",
    "import functools\n",
    "import os\n",
    "import time\n",
    "from typing import Final, Sequence\n",
    "from typing import Optional\n",
    "from typing import Union, cast\n",
    "os.environ['WRAPT_DISABLE_EXTENSIONS'] = 'true'\n",
    "\n",
    "from absl import logging\n",
    "import gin\n",
    "from matplotlib import patches\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.dates as mdates\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import reverb\n",
    "import mediapy as media\n",
    "from IPython.display import clear_output\n",
    "from smart_control.environment import environment\n",
    "from smart_control.proto import smart_control_building_pb2\n",
    "from smart_control.proto import smart_control_normalization_pb2\n",
    "from smart_control.reward import electricity_energy_cost\n",
    "from smart_control.reward import natural_gas_energy_cost\n",
    "from smart_control.reward import setpoint_energy_carbon_regret\n",
    "from smart_control.reward import setpoint_energy_carbon_reward\n",
    "from smart_control.simulator import randomized_arrival_departure_occupancy\n",
    "from smart_control.simulator import rejection_simulator_building\n",
    "from smart_control.simulator import simulator_building\n",
    "from smart_control.simulator import step_function_occupancy\n",
    "from smart_control.simulator import stochastic_convection_simulator\n",
    "from smart_control.utils import bounded_action_normalizer\n",
    "from smart_control.utils import building_renderer\n",
    "from smart_control.utils import controller_reader\n",
    "from smart_control.utils import controller_writer\n",
    "from smart_control.utils import conversion_utils\n",
    "from smart_control.utils import observation_normalizer\n",
    "from smart_control.utils import reader_lib\n",
    "from smart_control.utils import writer_lib\n",
    "from smart_control.utils import histogram_reducer\n",
    "from smart_control.utils import environment_utils\n",
    "import tensorflow as tf\n",
    "from tf_agents.agents.td3 import td3_agent # TD3 import\n",
    "from tf_agents.agents.ddpg import critic_network\n",
    "from tf_agents.agents.ddpg import actor_network\n",
    "from tf_agents.drivers import py_driver\n",
    "from tf_agents.keras_layers import inner_reshape\n",
    "from tf_agents.metrics import py_metrics\n",
    "from tf_agents.networks import nest_map\n",
    "from tf_agents.networks import sequential\n",
    "from tf_agents.networks import network        # added to fix input_tensor_spec error by inheriting from networks.Network class\n",
    "from tf_agents.networks import utils as network_utils\n",
    "from tf_agents.policies import greedy_policy\n",
    "from tf_agents.policies import py_tf_eager_policy\n",
    "from tf_agents.policies import random_py_policy\n",
    "from tf_agents.policies import tf_policy\n",
    "from tf_agents.replay_buffers import reverb_replay_buffer\n",
    "from tf_agents.replay_buffers import reverb_utils\n",
    "from tf_agents.specs import tensor_spec\n",
    "from tf_agents.train import actor\n",
    "from tf_agents.train import actor\n",
    "from tf_agents.train import learner\n",
    "from tf_agents.train import triggers\n",
    "from tf_agents.train.utils import spec_utils\n",
    "from tf_agents.train.utils import train_utils\n",
    "from tf_agents.trajectories import policy_step\n",
    "from tf_agents.trajectories import time_step as ts\n",
    "from tf_agents.trajectories import trajectory as trajectory_lib\n",
    "from tf_agents.trajectories import trajectory\n",
    "from tf_agents.typing import types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "sDDU5FmLkYo-"
   },
   "outputs": [],
   "source": [
    "# @title Set local runtime configurations\n",
    "\n",
    "\n",
    "def logging_info(*args):\n",
    "  logging.info(*args)\n",
    "  print(*args)\n",
    "\n",
    "data_path = \"/home/ron/Projects/sbsim/smart_control/configs/resources/sb1/\" #@param {type:\"string\"}\n",
    "metrics_path = \"/home/ron/Projects/sbsim/outputs/metrics\" #@param {type:\"string\"}\n",
    "output_data_path = \"/home/ron/Projects/sbsim/smart_control/sb_colab_demo\" #@param {type:\"string\"}\n",
    "root_dir = \"/home/ron/Projects/sbsim\" #@param {type:\"string\"}\n",
    "\n",
    "\n",
    "@gin.configurable\n",
    "def get_histogram_reducer():\n",
    "\n",
    "\n",
    "    reader = controller_reader.ProtoReader(data_path)\n",
    "\n",
    "    hr = histogram_reducer.HistogramReducer(\n",
    "        histogram_parameters_tuples=histogram_parameters_tuples,\n",
    "        reader=reader,\n",
    "        normalize_reduce=True,\n",
    "        )\n",
    "    return hr\n",
    "\n",
    "!mkdir -p $root_dir\n",
    "!mkdir -p $output_data_path\n",
    "!mkdir -p $metrics_path\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def remap_filepath(filepath) -> str:\n",
    "    return filepath\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "JV_2oCn2uQU4"
   },
   "outputs": [],
   "source": [
    "# @title Plotting Utities\n",
    "reward_shift = 0\n",
    "reward_scale = 1.0\n",
    "person_productivity_hour = 300.0\n",
    "\n",
    "KELVIN_TO_CELSIUS = 273.15\n",
    "\n",
    "\n",
    "def render_env(env: environment.Environment):\n",
    "  \"\"\"Renders the environment.\"\"\"\n",
    "  building_layout = env.building._simulator._building._floor_plan\n",
    "\n",
    "  # create a renderer\n",
    "  renderer = building_renderer.BuildingRenderer(building_layout, 1)\n",
    "\n",
    "  # get the current temps to render\n",
    "  # this also is not ideal, since the temps are not fully exposed.\n",
    "  # V Ideally this should be a publicly accessable field\n",
    "  temps = env.building._simulator._building.temp\n",
    "\n",
    "  input_q = env.building._simulator._building.input_q\n",
    "\n",
    "  # render\n",
    "  vmin = 285\n",
    "  vmax = 305\n",
    "  image = renderer.render(\n",
    "      temps,\n",
    "      cmap='bwr',\n",
    "      vmin=vmin,\n",
    "      vmax=vmax,\n",
    "      colorbar=False,\n",
    "      input_q=input_q,\n",
    "      diff_range=0.5,\n",
    "      diff_size=1,\n",
    "  ).convert('RGB')\n",
    "  media.show_image(\n",
    "      image, title='Environment %s' % env.current_simulation_timestamp\n",
    "  )\n",
    "\n",
    "\n",
    "def get_energy_timeseries(reward_infos, time_zone: str) -> pd.DataFrame:\n",
    "  \"\"\"Returns a timeseries of energy rates.\"\"\"\n",
    "\n",
    "  start_times = []\n",
    "  end_times = []\n",
    "\n",
    "  device_ids = []\n",
    "  device_types = []\n",
    "  air_handler_blower_electrical_energy_rates = []\n",
    "  air_handler_air_conditioner_energy_rates = []\n",
    "  boiler_natural_gas_heating_energy_rates = []\n",
    "  boiler_pump_electrical_energy_rates = []\n",
    "\n",
    "  for reward_info in reward_infos:\n",
    "    end_timestamp = conversion_utils.proto_to_pandas_timestamp(\n",
    "        reward_info.end_timestamp\n",
    "    ).tz_convert(time_zone)\n",
    "    start_timestamp = end_timestamp - pd.Timedelta(300, unit='second')\n",
    "\n",
    "    for air_handler_id in reward_info.air_handler_reward_infos:\n",
    "      start_times.append(start_timestamp)\n",
    "      end_times.append(end_timestamp)\n",
    "\n",
    "      device_ids.append(air_handler_id)\n",
    "      device_types.append('air_handler')\n",
    "\n",
    "      air_handler_blower_electrical_energy_rates.append(\n",
    "          reward_info.air_handler_reward_infos[\n",
    "              air_handler_id\n",
    "          ].blower_electrical_energy_rate\n",
    "      )\n",
    "      air_handler_air_conditioner_energy_rates.append(\n",
    "          reward_info.air_handler_reward_infos[\n",
    "              air_handler_id\n",
    "          ].air_conditioning_electrical_energy_rate\n",
    "      )\n",
    "      boiler_natural_gas_heating_energy_rates.append(0)\n",
    "      boiler_pump_electrical_energy_rates.append(0)\n",
    "\n",
    "    for boiler_id in reward_info.boiler_reward_infos:\n",
    "      start_times.append(start_timestamp)\n",
    "      end_times.append(end_timestamp)\n",
    "\n",
    "      device_ids.append(boiler_id)\n",
    "      device_types.append('boiler')\n",
    "\n",
    "      air_handler_blower_electrical_energy_rates.append(0)\n",
    "      air_handler_air_conditioner_energy_rates.append(0)\n",
    "\n",
    "      boiler_natural_gas_heating_energy_rates.append(\n",
    "          reward_info.boiler_reward_infos[\n",
    "              boiler_id\n",
    "          ].natural_gas_heating_energy_rate\n",
    "      )\n",
    "      boiler_pump_electrical_energy_rates.append(\n",
    "          reward_info.boiler_reward_infos[boiler_id].pump_electrical_energy_rate\n",
    "      )\n",
    "\n",
    "  df_map = {\n",
    "      'start_time': start_times,\n",
    "      'end_time': end_times,\n",
    "      'device_id': device_ids,\n",
    "      'device_type': device_types,\n",
    "      'air_handler_blower_electrical_energy_rate': (\n",
    "          air_handler_blower_electrical_energy_rates\n",
    "      ),\n",
    "      'air_handler_air_conditioner_energy_rate': (\n",
    "          air_handler_air_conditioner_energy_rates\n",
    "      ),\n",
    "      'boiler_natural_gas_heating_energy_rate': (\n",
    "          boiler_natural_gas_heating_energy_rates\n",
    "      ),\n",
    "      'boiler_pump_electrical_energy_rate': boiler_pump_electrical_energy_rates,\n",
    "  }\n",
    "  df = pd.DataFrame(df_map).sort_values('start_time')\n",
    "  return df\n",
    "\n",
    "\n",
    "def get_outside_air_temperature_timeseries(\n",
    "    observation_responses,\n",
    "    time_zone: str,\n",
    ") -> pd.Series:\n",
    "  \"\"\"Returns a timeseries of outside air temperature.\"\"\"\n",
    "  temps = []\n",
    "  for i in range(len(observation_responses)):\n",
    "    temp = [\n",
    "        (\n",
    "            conversion_utils.proto_to_pandas_timestamp(\n",
    "                sor.timestamp\n",
    "            ).tz_convert(time_zone)\n",
    "            - pd.Timedelta(300, unit='second'),\n",
    "            sor.continuous_value,\n",
    "        )\n",
    "        for sor in observation_responses[i].single_observation_responses\n",
    "        if sor.single_observation_request.measurement_name\n",
    "        == 'outside_air_temperature_sensor'\n",
    "    ][0]\n",
    "    temps.append(temp)\n",
    "\n",
    "  res = list(zip(*temps))\n",
    "  return pd.Series(res[1], index=res[0]).sort_index()\n",
    "\n",
    "\n",
    "def get_reward_timeseries(\n",
    "    reward_infos,\n",
    "    reward_responses,\n",
    "    time_zone: str,\n",
    ") -> pd.DataFrame:\n",
    "  \"\"\"Returns a timeseries of reward values.\"\"\"\n",
    "  cols = [\n",
    "      'agent_reward_value',\n",
    "      'electricity_energy_cost',\n",
    "      'carbon_emitted',\n",
    "      'occupancy',\n",
    "  ]\n",
    "  df = pd.DataFrame(columns=cols)\n",
    "\n",
    "  for i in range(min(len(reward_responses), len(reward_infos))):\n",
    "    step_start_timestamp = conversion_utils.proto_to_pandas_timestamp(\n",
    "        reward_infos[i].start_timestamp\n",
    "    ).tz_convert(time_zone)\n",
    "    step_end_timestamp = conversion_utils.proto_to_pandas_timestamp(\n",
    "        reward_infos[i].end_timestamp\n",
    "    ).tz_convert(time_zone)\n",
    "    delta_time_sec = (step_end_timestamp - step_start_timestamp).total_seconds()\n",
    "    occupancy = np.sum([\n",
    "        reward_infos[i].zone_reward_infos[zone_id].average_occupancy\n",
    "        for zone_id in reward_infos[i].zone_reward_infos\n",
    "    ])\n",
    "\n",
    "    df.loc[\n",
    "        conversion_utils.proto_to_pandas_timestamp(\n",
    "            reward_infos[i].start_timestamp\n",
    "        ).tz_convert(time_zone)\n",
    "    ] = [\n",
    "        reward_responses[i].agent_reward_value,\n",
    "        reward_responses[i].electricity_energy_cost,\n",
    "        reward_responses[i].carbon_emitted,\n",
    "        occupancy,\n",
    "    ]\n",
    "\n",
    "  df = df.sort_index()\n",
    "  df['cumulative_reward'] = df['agent_reward_value'].cumsum()\n",
    "  logging_info('Cumulative reward: %4.2f' % df.iloc[-1]['cumulative_reward'])\n",
    "  return df\n",
    "\n",
    "\n",
    "def format_plot(\n",
    "    ax1, xlabel: str, start_time: int, end_time: int, time_zone: str\n",
    "):\n",
    "  \"\"\"Formats a plot with common attributes.\"\"\"\n",
    "  ax1.set_facecolor('black')\n",
    "  ax1.xaxis.tick_top()\n",
    "  ax1.tick_params(axis='x', labelsize=12)\n",
    "  ax1.tick_params(axis='y', labelsize=12)\n",
    "  ax1.xaxis.set_major_formatter(\n",
    "      mdates.DateFormatter('%a %m/%d %H:%M', tz=pytz.timezone(time_zone))\n",
    "  )\n",
    "  ax1.grid(color='gray', linestyle='-', linewidth=1.0)\n",
    "  ax1.set_ylabel(xlabel, color='blue', fontsize=12)\n",
    "  ax1.set_xlim(left=start_time, right=end_time)\n",
    "  ax1.yaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "  ax1.legend(prop={'size': 10})\n",
    "\n",
    "\n",
    "def plot_occupancy_timeline(\n",
    "    ax1, reward_timeseries: pd.DataFrame, time_zone: str\n",
    "):\n",
    "  local_times = [ts.tz_convert(time_zone) for ts in reward_timeseries.index]\n",
    "  ax1.plot(\n",
    "      local_times,\n",
    "      reward_timeseries['occupancy'],\n",
    "      color='cyan',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=2,\n",
    "      linestyle='-',\n",
    "      label='Num Occupants',\n",
    "  )\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      'Occupancy',\n",
    "      reward_timeseries.index.min(),\n",
    "      reward_timeseries.index.max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def plot_energy_cost_timeline(\n",
    "    ax1,\n",
    "    reward_timeseries: pd.DataFrame,\n",
    "    time_zone: str,\n",
    "    cumulative: bool = False,\n",
    "):\n",
    "  local_times = [ts.tz_convert(time_zone) for ts in reward_timeseries.index]\n",
    "  if cumulative:\n",
    "    feature_timeseries_cost = reward_timeseries[\n",
    "        'electricity_energy_cost'\n",
    "    ].cumsum()\n",
    "  else:\n",
    "    feature_timeseries_cost = reward_timeseries['electricity_energy_cost']\n",
    "  ax1.plot(\n",
    "      local_times,\n",
    "      feature_timeseries_cost,\n",
    "      color='magenta',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=2,\n",
    "      linestyle='-',\n",
    "      label='Electricity',\n",
    "  )\n",
    "\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      'Energy Cost [$]',\n",
    "      reward_timeseries.index.min(),\n",
    "      reward_timeseries.index.max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def plot_reward_timeline(ax1, reward_timeseries, time_zone):\n",
    "\n",
    "  local_times = [ts.tz_convert(time_zone) for ts in reward_timeseries.index]\n",
    "\n",
    "  ax1.plot(\n",
    "      local_times,\n",
    "      reward_timeseries['cumulative_reward'],\n",
    "      color='royalblue',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=6,\n",
    "      linestyle='-',\n",
    "      label='reward',\n",
    "  )\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      'Agent Reward',\n",
    "      reward_timeseries.index.min(),\n",
    "      reward_timeseries.index.max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def plot_energy_timeline(ax1, energy_timeseries, time_zone, cumulative=False):\n",
    "\n",
    "  def _to_kwh(\n",
    "      energy_rate: float,\n",
    "      step_interval: pd.Timedelta = pd.Timedelta(5, unit='minute'),\n",
    "  ) -> float:\n",
    "    kw_power = energy_rate / 1000.0\n",
    "    hwh_power = kw_power * step_interval / pd.Timedelta(1, unit='hour')\n",
    "    return hwh_power.cumsum()\n",
    "\n",
    "  timeseries = energy_timeseries[\n",
    "      energy_timeseries['device_type'] == 'air_handler'\n",
    "  ]\n",
    "\n",
    "  if cumulative:\n",
    "    feature_timeseries_ac = _to_kwh(\n",
    "        timeseries['air_handler_air_conditioner_energy_rate']\n",
    "    )\n",
    "    feature_timeseries_blower = _to_kwh(\n",
    "        timeseries['air_handler_blower_electrical_energy_rate']\n",
    "    )\n",
    "  else:\n",
    "    feature_timeseries_ac = (\n",
    "        timeseries['air_handler_air_conditioner_energy_rate'] / 1000.0\n",
    "    )\n",
    "    feature_timeseries_blower = (\n",
    "        timeseries['air_handler_blower_electrical_energy_rate'] / 1000.0\n",
    "    )\n",
    "\n",
    "  ax1.plot(\n",
    "      timeseries['start_time'],\n",
    "      feature_timeseries_ac,\n",
    "      color='magenta',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=4,\n",
    "      linestyle='-',\n",
    "      label='AHU Electricity',\n",
    "  )\n",
    "  ax1.plot(\n",
    "      timeseries['start_time'],\n",
    "      feature_timeseries_blower,\n",
    "      color='magenta',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=4,\n",
    "      linestyle='--',\n",
    "      label='FAN Electricity',\n",
    "  )\n",
    "\n",
    "  timeseries = energy_timeseries[energy_timeseries['device_type'] == 'boiler']\n",
    "  if cumulative:\n",
    "    feature_timeseries_gas = _to_kwh(\n",
    "        timeseries['boiler_natural_gas_heating_energy_rate']\n",
    "    )\n",
    "    feature_timeseries_pump = _to_kwh(\n",
    "        timeseries['boiler_pump_electrical_energy_rate']\n",
    "    )\n",
    "  else:\n",
    "    feature_timeseries_gas = (\n",
    "        timeseries['boiler_natural_gas_heating_energy_rate'] / 1000.0\n",
    "    )\n",
    "    feature_timeseries_pump = (\n",
    "        timeseries['boiler_pump_electrical_energy_rate'] / 1000.0\n",
    "    )\n",
    "\n",
    "  ax1.plot(\n",
    "      timeseries['start_time'],\n",
    "      feature_timeseries_gas,\n",
    "      color='lime',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=4,\n",
    "      linestyle='-',\n",
    "      label='BLR Gas',\n",
    "  )\n",
    "  ax1.plot(\n",
    "      timeseries['start_time'],\n",
    "      feature_timeseries_pump,\n",
    "      color='lime',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=4,\n",
    "      linestyle='--',\n",
    "      label='Pump Electricity',\n",
    "  )\n",
    "\n",
    "  if cumulative:\n",
    "    label = 'HVAC Energy Consumption [kWh]'\n",
    "  else:\n",
    "    label = 'HVAC Power Consumption [kW]'\n",
    "\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      label,\n",
    "      timeseries['start_time'].min(),\n",
    "      timeseries['end_time'].max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def plot_carbon_timeline(ax1, reward_timeseries, time_zone, cumulative=False):\n",
    "  \"\"\"Plots carbon-emission timeline.\"\"\"\n",
    "\n",
    "  if cumulative:\n",
    "    feature_timeseries_carbon = reward_timeseries['carbon_emitted'].cumsum()\n",
    "  else:\n",
    "    feature_timeseries_carbon = reward_timeseries['carbon_emitted']\n",
    "  ax1.plot(\n",
    "      reward_timeseries.index,\n",
    "      feature_timeseries_carbon,\n",
    "      color='white',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=4,\n",
    "      linestyle='-',\n",
    "      label='Carbon',\n",
    "  )\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      'Carbon emission [kg]',\n",
    "      reward_timeseries.index.min(),\n",
    "      reward_timeseries.index.max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def get_zone_timeseries(reward_infos, time_zone):\n",
    "  \"\"\"Converts reward infos to a timeseries dataframe.\"\"\"\n",
    "\n",
    "  start_times = []\n",
    "  end_times = []\n",
    "  zones = []\n",
    "  heating_setpoints = []\n",
    "  cooling_setpoints = []\n",
    "  zone_air_temperatures = []\n",
    "  air_flow_rate_setpoints = []\n",
    "  air_flow_rates = []\n",
    "  average_occupancies = []\n",
    "\n",
    "  for reward_info in reward_infos:\n",
    "    start_timestamp = conversion_utils.proto_to_pandas_timestamp(\n",
    "        reward_info.end_timestamp\n",
    "    ).tz_convert(time_zone) - pd.Timedelta(300, unit='second')\n",
    "    end_timestamp = conversion_utils.proto_to_pandas_timestamp(\n",
    "        reward_info.end_timestamp\n",
    "    ).tz_convert(time_zone)\n",
    "\n",
    "    for zone_id in reward_info.zone_reward_infos:\n",
    "      zones.append(zone_id)\n",
    "      start_times.append(start_timestamp)\n",
    "      end_times.append(end_timestamp)\n",
    "\n",
    "      heating_setpoints.append(\n",
    "          reward_info.zone_reward_infos[zone_id].heating_setpoint_temperature\n",
    "      )\n",
    "      cooling_setpoints.append(\n",
    "          reward_info.zone_reward_infos[zone_id].cooling_setpoint_temperature\n",
    "      )\n",
    "\n",
    "      zone_air_temperatures.append(\n",
    "          reward_info.zone_reward_infos[zone_id].zone_air_temperature\n",
    "      )\n",
    "      air_flow_rate_setpoints.append(\n",
    "          reward_info.zone_reward_infos[zone_id].air_flow_rate_setpoint\n",
    "      )\n",
    "      air_flow_rates.append(\n",
    "          reward_info.zone_reward_infos[zone_id].air_flow_rate\n",
    "      )\n",
    "      average_occupancies.append(\n",
    "          reward_info.zone_reward_infos[zone_id].average_occupancy\n",
    "      )\n",
    "\n",
    "  df_map = {\n",
    "      'start_time': start_times,\n",
    "      'end_time': end_times,\n",
    "      'zone': zones,\n",
    "      'heating_setpoint_temperature': heating_setpoints,\n",
    "      'cooling_setpoint_temperature': cooling_setpoints,\n",
    "      'zone_air_temperature': zone_air_temperatures,\n",
    "      'air_flow_rate_setpoint': air_flow_rate_setpoints,\n",
    "      'air_flow_rate': air_flow_rates,\n",
    "      'average_occupancy': average_occupancies,\n",
    "  }\n",
    "  return pd.DataFrame(df_map).sort_values('start_time')\n",
    "\n",
    "\n",
    "def get_action_timeseries(action_responses):\n",
    "  \"\"\"Converts action responses to a dataframe.\"\"\"\n",
    "  timestamps = []\n",
    "  device_ids = []\n",
    "  setpoint_names = []\n",
    "  setpoint_values = []\n",
    "  response_types = []\n",
    "  for action_response in action_responses:\n",
    "\n",
    "    timestamp = conversion_utils.proto_to_pandas_timestamp(\n",
    "        action_response.timestamp\n",
    "    )\n",
    "    for single_action_response in action_response.single_action_responses:\n",
    "      device_id = single_action_response.request.device_id\n",
    "      setpoint_name = single_action_response.request.setpoint_name\n",
    "      setpoint_value = single_action_response.request.continuous_value\n",
    "      response_type = single_action_response.response_type\n",
    "\n",
    "      timestamps.append(timestamp)\n",
    "      device_ids.append(device_id)\n",
    "      setpoint_names.append(setpoint_name)\n",
    "      setpoint_values.append(setpoint_value)\n",
    "      response_types.append(response_type)\n",
    "\n",
    "  return pd.DataFrame({\n",
    "      'timestamp': timestamps,\n",
    "      'device_id': device_ids,\n",
    "      'setpoint_name': setpoint_names,\n",
    "      'setpoint_value': setpoint_values,\n",
    "      'response_type': response_types,\n",
    "  })\n",
    "\n",
    "\n",
    "def plot_action_timeline(ax1, action_timeseries, action_tuple, time_zone):\n",
    "  \"\"\"Plots action timeline.\"\"\"\n",
    "\n",
    "  single_action_timeseries = action_timeseries[\n",
    "      (action_timeseries['device_id'] == action_tuple[0])\n",
    "      & (action_timeseries['setpoint_name'] == action_tuple[1])\n",
    "  ]\n",
    "  single_action_timeseries = single_action_timeseries.sort_values(\n",
    "      by='timestamp'\n",
    "  )\n",
    "\n",
    "  if action_tuple[1] in [\n",
    "      'supply_water_setpoint',\n",
    "      'supply_air_heating_temperature_setpoint',\n",
    "  ]:\n",
    "    single_action_timeseries['setpoint_value'] = (\n",
    "        single_action_timeseries['setpoint_value'] - KELVIN_TO_CELSIUS\n",
    "    )\n",
    "\n",
    "  ax1.plot(\n",
    "      single_action_timeseries['timestamp'],\n",
    "      single_action_timeseries['setpoint_value'],\n",
    "      color='lime',\n",
    "      marker=None,\n",
    "      alpha=1,\n",
    "      lw=4,\n",
    "      linestyle='-',\n",
    "      label=action_tuple[1],\n",
    "  )\n",
    "  title = '%s %s' % (action_tuple[0], action_tuple[1])\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      'Action',\n",
    "      single_action_timeseries['timestamp'].min(),\n",
    "      single_action_timeseries['timestamp'].max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def get_outside_air_temperature_timeseries(observation_responses, time_zone):\n",
    "  temps = []\n",
    "  for i in range(len(observation_responses)):\n",
    "    temp = [\n",
    "        (\n",
    "            conversion_utils.proto_to_pandas_timestamp(\n",
    "                sor.timestamp\n",
    "            ).tz_convert(time_zone),\n",
    "            sor.continuous_value,\n",
    "        )\n",
    "        for sor in observation_responses[i].single_observation_responses\n",
    "        if sor.single_observation_request.measurement_name\n",
    "        == 'outside_air_temperature_sensor'\n",
    "    ][0]\n",
    "    temps.append(temp)\n",
    "\n",
    "  res = list(zip(*temps))\n",
    "  return pd.Series(res[1], index=res[0]).sort_index()\n",
    "\n",
    "\n",
    "def plot_temperature_timeline(\n",
    "    ax1, zone_timeseries, outside_air_temperature_timeseries, time_zone\n",
    "):\n",
    "  zone_temps = pd.pivot_table(\n",
    "      zone_timeseries,\n",
    "      index=zone_timeseries['start_time'],\n",
    "      columns='zone',\n",
    "      values='zone_air_temperature',\n",
    "  ).sort_index()\n",
    "  zone_temps.quantile(q=0.25, axis=1)\n",
    "  zone_temp_stats = pd.DataFrame({\n",
    "      'min_temp': zone_temps.min(axis=1),\n",
    "      'q25_temp': zone_temps.quantile(q=0.25, axis=1),\n",
    "      'median_temp': zone_temps.median(axis=1),\n",
    "      'q75_temp': zone_temps.quantile(q=0.75, axis=1),\n",
    "      'max_temp': zone_temps.max(axis=1),\n",
    "  })\n",
    "\n",
    "  zone_heating_setpoints = (\n",
    "      pd.pivot_table(\n",
    "          zone_timeseries,\n",
    "          index=zone_timeseries['start_time'],\n",
    "          columns='zone',\n",
    "          values='heating_setpoint_temperature',\n",
    "      )\n",
    "      .sort_index()\n",
    "      .min(axis=1)\n",
    "  )\n",
    "  zone_cooling_setpoints = (\n",
    "      pd.pivot_table(\n",
    "          zone_timeseries,\n",
    "          index=zone_timeseries['start_time'],\n",
    "          columns='zone',\n",
    "          values='cooling_setpoint_temperature',\n",
    "      )\n",
    "      .sort_index()\n",
    "      .max(axis=1)\n",
    "  )\n",
    "\n",
    "  ax1.plot(\n",
    "      zone_cooling_setpoints.index,\n",
    "      zone_cooling_setpoints - KELVIN_TO_CELSIUS,\n",
    "      color='yellow',\n",
    "      lw=1,\n",
    "  )\n",
    "  ax1.plot(\n",
    "      zone_cooling_setpoints.index,\n",
    "      zone_heating_setpoints - KELVIN_TO_CELSIUS,\n",
    "      color='yellow',\n",
    "      lw=1,\n",
    "  )\n",
    "\n",
    "  ax1.fill_between(\n",
    "      zone_temp_stats.index,\n",
    "      zone_temp_stats['min_temp'] - KELVIN_TO_CELSIUS,\n",
    "      zone_temp_stats['max_temp'] - KELVIN_TO_CELSIUS,\n",
    "      facecolor='green',\n",
    "      alpha=0.8,\n",
    "  )\n",
    "  ax1.fill_between(\n",
    "      zone_temp_stats.index,\n",
    "      zone_temp_stats['q25_temp'] - KELVIN_TO_CELSIUS,\n",
    "      zone_temp_stats['q75_temp'] - KELVIN_TO_CELSIUS,\n",
    "      facecolor='green',\n",
    "      alpha=0.8,\n",
    "  )\n",
    "  ax1.plot(\n",
    "      zone_temp_stats.index,\n",
    "      zone_temp_stats['median_temp'] - KELVIN_TO_CELSIUS,\n",
    "      color='white',\n",
    "      lw=3,\n",
    "      alpha=1.0,\n",
    "  )\n",
    "  ax1.plot(\n",
    "      outside_air_temperature_timeseries.index,\n",
    "      outside_air_temperature_timeseries - KELVIN_TO_CELSIUS,\n",
    "      color='magenta',\n",
    "      lw=3,\n",
    "      alpha=1.0,\n",
    "  )\n",
    "  format_plot(\n",
    "      ax1,\n",
    "      'Temperature [C]',\n",
    "      zone_temp_stats.index.min(),\n",
    "      zone_temp_stats.index.max(),\n",
    "      time_zone,\n",
    "  )\n",
    "\n",
    "\n",
    "def plot_timeseries_charts(reader, time_zone):\n",
    "  \"\"\"Plots timeseries charts.\"\"\"\n",
    "\n",
    "  observation_responses = reader.read_observation_responses(\n",
    "      pd.Timestamp.min, pd.Timestamp.max\n",
    "  )\n",
    "  action_responses = reader.read_action_responses(\n",
    "      pd.Timestamp.min, pd.Timestamp.max\n",
    "  )\n",
    "  reward_infos = reader.read_reward_infos(pd.Timestamp.min, pd.Timestamp.max)\n",
    "  reward_responses = reader.read_reward_responses(\n",
    "      pd.Timestamp.min, pd.Timestamp.max\n",
    "  )\n",
    "\n",
    "  if len(reward_infos) == 0 or len(reward_responses) == 0:\n",
    "    return\n",
    "\n",
    "  action_timeseries = get_action_timeseries(action_responses)\n",
    "  # Sort so the action subplots appear in a deterministic order.\n",
    "  action_tuples = sorted({\n",
    "      (row['device_id'], row['setpoint_name'])\n",
    "      for _, row in action_timeseries.iterrows()\n",
    "  })\n",
    "\n",
    "  reward_timeseries = get_reward_timeseries(\n",
    "      reward_infos, reward_responses, time_zone\n",
    "  ).sort_index()\n",
    "  outside_air_temperature_timeseries = get_outside_air_temperature_timeseries(\n",
    "      observation_responses, time_zone\n",
    "  )\n",
    "  zone_timeseries = get_zone_timeseries(reward_infos, time_zone)\n",
    "  fig, axes = plt.subplots(\n",
    "      nrows=6 + len(action_tuples),\n",
    "      ncols=1,\n",
    "      gridspec_kw={\n",
    "          'height_ratios': [1, 1, 1, 1, 1, 1] + [1] * len(action_tuples)\n",
    "      },\n",
    "      squeeze=True,\n",
    "  )\n",
    "  fig.set_size_inches(24, 25)\n",
    "\n",
    "  energy_timeseries = get_energy_timeseries(reward_infos, time_zone)\n",
    "  plot_reward_timeline(axes[0], reward_timeseries, time_zone)\n",
    "  plot_energy_timeline(axes[1], energy_timeseries, time_zone, cumulative=True)\n",
    "  plot_energy_cost_timeline(\n",
    "      axes[2], reward_timeseries, time_zone, cumulative=True\n",
    "  )\n",
    "  plot_carbon_timeline(axes[3], reward_timeseries, time_zone, cumulative=True)\n",
    "  plot_occupancy_timeline(axes[4], reward_timeseries, time_zone)\n",
    "  plot_temperature_timeline(\n",
    "      axes[5], zone_timeseries, outside_air_temperature_timeseries, time_zone\n",
    "  )\n",
    "\n",
    "  for i, action_tuple in enumerate(action_tuples):\n",
    "    plot_action_timeline(\n",
    "        axes[6 + i], action_timeseries, action_tuple, time_zone\n",
    "    )\n",
    "\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kTtVb9wbRsKU"
   },
   "source": [
    "# Load up the environment\n",
    "\n",
    "In this section we load up the Smart Buildings simulator environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "2fcYS1VBrvia"
   },
   "outputs": [],
   "source": [
    "# @title Utils for importing the environment.\n",
    "\n",
    "def load_environment(gin_config_file: str):\n",
    "  \"\"\"Returns an Environment from a config file.\"\"\"\n",
    "  # Global definition is required by Gin library to instantiate Environment.\n",
    "  global environment  # pylint: disable=global-variable-not-assigned\n",
    "  with gin.unlock_config():\n",
    "    gin.parse_config_file(gin_config_file)\n",
    "    return environment.Environment()  # pylint: disable=no-value-for-parameter\n",
    "\n",
    "\n",
    "def get_latest_episode_reader(\n",
    "    metrics_path: str,\n",
    ") -> controller_reader.ProtoReader:\n",
    "  \"\"\"Returns a reader for the most recent episode written to metrics_path.\"\"\"\n",
    "  episode_infos = controller_reader.get_episode_data(metrics_path).sort_index()\n",
    "  selected_episode = episode_infos.index[-1]\n",
    "  episode_path = os.path.join(metrics_path, selected_episode)\n",
    "  reader = controller_reader.ProtoReader(episode_path)\n",
    "  return reader\n",
    "\n",
    "@gin.configurable\n",
    "def get_histogram_path():\n",
    "  return data_path\n",
    "\n",
    "\n",
    "@gin.configurable\n",
    "def get_reset_temp_values():\n",
    "  reset_temps_filepath = remap_filepath(\n",
    "      os.path.join(data_path, \"reset_temps.npy\")\n",
    "  )\n",
    "\n",
    "  return np.load(reset_temps_filepath)\n",
    "\n",
    "\n",
    "@gin.configurable\n",
    "def get_zone_path():\n",
    "  return remap_filepath(\n",
    "      os.path.join(data_path, \"double_resolution_zone_1_2.npy\")\n",
    "  )\n",
    "\n",
    "\n",
    "@gin.configurable\n",
    "def get_metrics_path():\n",
    "  return os.path.join(metrics_path, \"metrics\")\n",
    "\n",
    "\n",
    "@gin.configurable\n",
    "def get_weather_path():\n",
    "  return remap_filepath(\n",
    "      os.path.join(\n",
    "          data_path, \"local_weather_moffett_field_20230701_20231122.csv\"\n",
    "      )\n",
    "  )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "10THzl_rSgFW"
   },
   "source": [
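    "In the cell below, we load the collect and eval environments, along with an `initial_collect_env` used for initial data collection. Here all three are loaded from the same configuration, but in practice it would be useful to load the same building over nearby, non-overlapping time windows, so that the agent is evaluated on a period it was not trained on."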
   ]
  },
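  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sketchTimeWindowsMd"
   },
   "source": [
    "As a minimal sketch of the non-overlapping windows suggested above, the helper below splits a date range into a collect window followed by an eval window that starts exactly where the collect window ends. `adjacent_windows` is a hypothetical helper, not part of the Smart Buildings API. The example dates come from the weather file name used in this notebook (`local_weather_moffett_field_20230701_20231122.csv`). The sketch does not show how the windows would be passed to the simulator configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sketchTimeWindows"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def adjacent_windows(start, end, eval_days, tz=None):\n",
    "  # Hypothetical helper: split [start, end) into a collect window and a final\n",
    "  # eval window of `eval_days` days. The two windows share a boundary\n",
    "  # timestamp but do not overlap.\n",
    "  start_ts = pd.Timestamp(start, tz=tz)\n",
    "  end_ts = pd.Timestamp(end, tz=tz)\n",
    "  split_ts = end_ts - pd.Timedelta(days=eval_days)\n",
    "  if not start_ts < split_ts < end_ts:\n",
    "    raise ValueError('eval_days must leave a non-empty collect window.')\n",
    "  return (start_ts, split_ts), (split_ts, end_ts)\n",
    "\n",
    "\n",
    "collect_window, eval_window = adjacent_windows(\n",
    "    '2023-07-01', '2023-11-22', eval_days=14\n",
    ")\n",
    "print('collect:', collect_window[0], '->', collect_window[1])\n",
    "print('eval:   ', eval_window[0], '->', eval_window[1])"
   ]
  },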
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "XFeGO2TLRS1o"
   },
   "outputs": [],
   "source": [
    "# @title Load the environments\n",
    "\n",
    "# Histogram bins for selected measurements; zone air temperatures are in Kelvin.\n",
    "histogram_parameters_tuples = (\n",
    "    ('zone_air_temperature_sensor',\n",
    "     (285.0, 286.0, 287.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0,\n",
    "      295.0, 296.0, 297.0, 298.0, 299.0, 300.0, 301.0, 302.0, 303.0)),\n",
    "    ('supply_air_damper_percentage_command', (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)),\n",
    "    ('supply_air_flowrate_setpoint',\n",
    "     (0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9)),\n",
    ")\n",
    "\n",
    "time_zone = 'US/Pacific'\n",
    "collect_scenario_config = os.path.join(data_path, \"sim_config.gin\")\n",
    "print(collect_scenario_config)\n",
    "eval_scenario_config = os.path.join(data_path, \"sim_config.gin\")\n",
    "print(eval_scenario_config)\n",
    "\n",
    "collect_env = load_environment(collect_scenario_config)\n",
    "\n",
    "# For efficiency, disable metrics logging in the collect environment.\n",
    "collect_env._metrics_path = None\n",
    "collect_env._occupancy_normalization_constant = 125.0\n",
    "\n",
    "eval_env = load_environment(eval_scenario_config)\n",
    "# eval_env._label += \"_eval\"\n",
    "eval_env._metrics_path = metrics_path\n",
    "eval_env._occupancy_normalization_constant = 125.0\n",
    "\n",
    "initial_collect_env = load_environment(eval_scenario_config)\n",
    "\n",
    "initial_collect_env._metrics_path = metrics_path\n",
    "initial_collect_env._occupancy_normalization_constant = 125.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c55CehnYR8lY"
   },
   "source": [
    "In the cell below, we define a function that accepts an environment and a policy, runs a fixed number of episodes, and returns the average return. The policy can be either a rules-based policy or an RL policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "bitzHo5_UbXy"
   },
   "outputs": [],
   "source": [
    "# @title Define a method to execute the policy on the environment.\n",
    "\n",
    "\n",
    "def get_trajectory(time_step, current_action: policy_step.PolicyStep):\n",
    "  \"\"\"Get the trajectory for the current action and time step.\"\"\"\n",
    "  observation = time_step.observation\n",
    "  action = current_action.action\n",
    "  policy_info = ()\n",
    "  reward = time_step.reward\n",
    "  discount = time_step.discount\n",
    "\n",
    "  if time_step.is_first():\n",
    "    traj = trajectory.first(observation, action, policy_info, reward, discount)\n",
    "\n",
    "  elif time_step.is_last():\n",
    "    traj = trajectory.last(observation, action, policy_info, reward, discount)\n",
    "\n",
    "  else:\n",
    "    traj = trajectory.mid(observation, action, policy_info, reward, discount)\n",
    "  return traj\n",
    "\n",
    "\n",
    "def compute_avg_return(\n",
    "    environment,\n",
    "    policy,\n",
    "    num_episodes=1,\n",
    "    time_zone: str = \"US/Pacific\",\n",
    "    render_interval_steps: int = 24,\n",
    "    trajectory_observers=None,\n",
    "):\n",
    "  \"\"\"Computes the average return of the policy on the environment.\n",
    "\n",
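    "  Args:\n",
    "    environment: The environment.Environment to run the policy on.\n",
    "    policy: The policy to execute.\n",
    "    num_episodes: Total number of episodes to run.\n",
    "    time_zone: Time zone of the building, used to display simulation times.\n",
    "    render_interval_steps: Number of steps between renderings.\n",
    "    trajectory_observers: Optional list of callables, each called with the\n",
    "      trajectory of every step (for example, a replay buffer observer).\n",
    "\n",
    "  Returns:\n",
    "    The average (undiscounted) return per episode.\n",
    "  \"\"\"\n",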
    "\n",
    "  total_return = 0.0\n",
    "  for _ in range(num_episodes):\n",
    "\n",
    "    time_step = environment.reset()\n",
    "\n",
    "    episode_return = 0.0\n",
    "    t0 = time.time()\n",
    "    epoch = t0\n",
    "\n",
    "    step_id = 0\n",
    "    execution_times = []\n",
    "\n",
    "    while not time_step.is_last():\n",
    "\n",
    "      action_step = policy.action(time_step)\n",
    "      time_step = environment.step(action_step.action)\n",
    "\n",
    "      if trajectory_observers is not None:\n",
    "        traj = get_trajectory(time_step, action_step)\n",
    "        for observer in trajectory_observers:\n",
    "          observer(traj)\n",
    "\n",
    "      episode_return += time_step.reward\n",
    "      t1 = time.time()\n",
    "      dt = t1 - t0\n",
    "      episode_seconds = t1 - epoch\n",
    "      execution_times.append(dt)\n",
    "      sim_time = environment.current_simulation_timestamp.tz_convert(time_zone)\n",
    "\n",
    "      print(\n",
    "          \"Step %5d Sim Time: %s, Reward: %8.2f, Return: %8.2f, Mean Step Time:\"\n",
    "          \" %8.2f s, Episode Time: %8.2f s\"\n",
    "          % (\n",
    "              step_id,\n",
    "              sim_time.strftime(\"%Y-%m-%d %H:%M\"),\n",
    "              time_step.reward,\n",
    "              episode_return,\n",
    "              np.mean(execution_times),\n",
    "              episode_seconds,\n",
    "          )\n",
    "      )\n",
    "\n",
    "      if (step_id > 0) and (step_id % render_interval_steps == 0):\n",
    "        if environment._metrics_path:\n",
    "          clear_output(wait=True)\n",
    "          reader = get_latest_episode_reader(environment._metrics_path)\n",
    "          plot_timeseries_charts(reader, time_zone)\n",
    "        render_env(environment)\n",
    "\n",
    "      t0 = t1\n",
    "      step_id += 1\n",
    "    total_return += episode_return\n",
    "\n",
    "  avg_return = total_return / num_episodes\n",
    "  return avg_return"
   ]
  },
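  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sketchEpisodeLoopMd"
   },
   "source": [
    "The loop in `compute_avg_return` follows the usual reset/step pattern: reset the environment, ask the policy for an action, step the environment, and accumulate the reward until the last time step. The cell below is a minimal sketch of that control flow that runs without the simulator. It uses hypothetical stand-ins (`ToyTimeStep`, `ToyEnv`, `ConstantPolicy`, and `average_return` are illustrative names, not Smart Buildings or TF-Agents classes) and omits logging, rendering, and trajectory observers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sketchEpisodeLoop"
   },
   "outputs": [],
   "source": [
    "# Minimal sketch of the reset/step loop in compute_avg_return.\n",
    "# ToyTimeStep, ToyEnv, ConstantPolicy, and average_return are hypothetical\n",
    "# stand-ins, not Smart Buildings or TF-Agents classes.\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ToyTimeStep:\n",
    "  reward: float\n",
    "  last: bool\n",
    "\n",
    "  def is_last(self) -> bool:\n",
    "    return self.last\n",
    "\n",
    "\n",
    "class ToyEnv:\n",
    "  # Fixed-length episodes; each step's reward is -|action|.\n",
    "\n",
    "  def __init__(self, episode_length: int):\n",
    "    self.episode_length = episode_length\n",
    "    self.t = 0\n",
    "\n",
    "  def reset(self) -> ToyTimeStep:\n",
    "    self.t = 0\n",
    "    return ToyTimeStep(reward=0.0, last=False)\n",
    "\n",
    "  def step(self, action: float) -> ToyTimeStep:\n",
    "    self.t += 1\n",
    "    return ToyTimeStep(reward=-abs(action), last=self.t >= self.episode_length)\n",
    "\n",
    "\n",
    "class ConstantPolicy:\n",
    "\n",
    "  def __init__(self, value: float):\n",
    "    self.value = value\n",
    "\n",
    "  def action(self, time_step: ToyTimeStep) -> float:\n",
    "    del time_step  # A constant policy ignores the observation.\n",
    "    return self.value\n",
    "\n",
    "\n",
    "def average_return(env, policy, num_episodes: int = 1) -> float:\n",
    "  total_return = 0.0\n",
    "  for _ in range(num_episodes):\n",
    "    time_step = env.reset()\n",
    "    episode_return = 0.0\n",
    "    # As in compute_avg_return, only rewards returned by step() are summed.\n",
    "    while not time_step.is_last():\n",
    "      time_step = env.step(policy.action(time_step))\n",
    "      episode_return += time_step.reward\n",
    "    total_return += episode_return\n",
    "  return total_return / num_episodes\n",
    "\n",
    "\n",
    "# Four steps of reward -0.5 per episode, so this prints -2.0.\n",
    "print(average_return(ToyEnv(episode_length=4), ConstantPolicy(0.5), num_episodes=2))"
   ]
  },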
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "86IIF7FrfJ_2"
   },
   "source": [
    "# Rules-based Control (RBC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "X9JR8qze6Yvb"
   },
   "outputs": [],
   "source": [
    "# @title Utils for RBC\n",
    "\n",
    "# We're concerned with controlling Heatpumps/ACs and Hot Water Systems (HWS).\n",
    "class DeviceType(enum.Enum):\n",
    "  AC = 0\n",
    "  HWS = 1\n",
    "\n",
    "\n",
    "SetpointName = str  # Identify the setpoint\n",
    "# Setpoint value.\n",
    "SetpointValue = Union[float, int, bool]\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ScheduleEvent:\n",
    "  start_time: pd.Timedelta\n",
    "  device: DeviceType\n",
    "  setpoint_name: SetpointName\n",
    "  setpoint_value: SetpointValue\n",
    "\n",
    "\n",
    "# A schedule is a list of times and setpoints for a device.\n",
    "Schedule = list[ScheduleEvent]\n",
    "ActionSequence = list[tuple[DeviceType, SetpointName]]\n",
    "\n",
    "\n",
    "def to_rad(sin_theta: float, cos_theta: float) -> float:\n",
    "  \"\"\"Recovers the angle theta (radians, 0 to 2*pi) from its sine and cosine.\"\"\"\n",
    "\n",
    "  if sin_theta >= 0 and cos_theta >= 0:  # First quadrant.\n",
    "    return np.arccos(cos_theta)\n",
    "  elif sin_theta >= 0 and cos_theta < 0:  # Second quadrant.\n",
    "    return np.pi - np.arcsin(sin_theta)\n",
    "  elif sin_theta < 0 and cos_theta < 0:  # Third quadrant.\n",
    "    return np.pi - np.arcsin(sin_theta)\n",
    "  else:  # Fourth quadrant.\n",
    "    return 2 * np.pi - np.arccos(cos_theta)\n",
    "\n",
    "\n",
    "def to_dow(sin_theta: float, cos_theta: float) -> float:\n",
    "  \"\"\"Converts a sin and cos theta to days to extract day of week.\"\"\"\n",
    "  theta = to_rad(sin_theta, cos_theta)\n",
    "  return np.floor(7 * theta / 2 / np.pi)\n",
    "\n",
    "\n",
    "def to_hod(sin_theta: float, cos_theta: float) -> float:\n",
    "  \"\"\"Converts a sin and cos theta to hours to extract hour of day.\"\"\"\n",
    "  theta = to_rad(sin_theta, cos_theta)\n",
    "  return np.floor(24 * theta / 2 / np.pi)\n",
    "\n",
    "\n",
    "def find_schedule_action(\n",
    "    schedule: Schedule,\n",
    "    device: DeviceType,\n",
    "    setpoint_name: SetpointName,\n",
    "    timestamp: pd.Timedelta,\n",
    ") -> SetpointValue:\n",
    "  \"\"\"Finds the action for a schedule event for a time and schedule.\"\"\"\n",
    "\n",
    "  # Get all the schedule events for the device and the setpoint, and turn it\n",
    "  # into a series.\n",
    "  device_schedule_dict = {}\n",
    "  for schedule_event in schedule:\n",
    "    if (\n",
    "        schedule_event.device == device\n",
    "        and schedule_event.setpoint_name == setpoint_name\n",
    "    ):\n",
    "      device_schedule_dict[schedule_event.start_time] = (\n",
    "          schedule_event.setpoint_value\n",
    "      )\n",
    "  # Sort by start time so that positional lookups below are chronological.\n",
    "  device_schedule = pd.Series(device_schedule_dict).sort_index()\n",
    "\n",
    "  # Get the indexes of the schedule events at or before the timestamp.\n",
    "  device_schedule_indexes = device_schedule.index[\n",
    "      device_schedule.index <= timestamp\n",
    "  ]\n",
    "\n",
    "  # If no events precede the timestamp, the last event of the previous day\n",
    "  # is still in effect (the schedule wraps around midnight).\n",
    "  if device_schedule_indexes.empty:\n",
    "    return device_schedule.iloc[-1]\n",
    "  else:\n",
    "    return device_schedule.loc[device_schedule_indexes[-1]]"
   ]
  },
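  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sketchCyclicTimeMd"
   },
   "source": [
    "`to_rad`, `to_hod`, and `to_dow` invert the sine/cosine encoding of time in observation fields such as `hod_sin_000` and `hod_cos_000`. As a quick sanity check of that idea, the sketch below encodes an hour of day as an angle and decodes it with `np.arctan2`, a single-call alternative to the quadrant cases in `to_rad`. The helper names are hypothetical, and the sketch assumes that hour 0 corresponds to angle 0. The rounding step is an added guard against floating-point error at exact hour boundaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sketchCyclicTime"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Assumption: hour 0 maps to angle 0 and the angle increases linearly\n",
    "# through the day. encode_hour, decode_angle, and decode_hour are\n",
    "# hypothetical helpers for illustration.\n",
    "\n",
    "\n",
    "def encode_hour(hour: float) -> tuple[float, float]:\n",
    "  # Map an hour in [0, 24) onto the unit circle as (sin, cos).\n",
    "  theta = 2 * np.pi * hour / 24\n",
    "  return np.sin(theta), np.cos(theta)\n",
    "\n",
    "\n",
    "def decode_angle(sin_theta: float, cos_theta: float) -> float:\n",
    "  # np.arctan2 returns an angle between -pi and pi; the modulo shifts it\n",
    "  # into the range 0 to 2*pi, covering the same cases as to_rad.\n",
    "  return np.arctan2(sin_theta, cos_theta) % (2 * np.pi)\n",
    "\n",
    "\n",
    "def decode_hour(sin_theta: float, cos_theta: float) -> int:\n",
    "  theta = decode_angle(sin_theta, cos_theta)\n",
    "  # Round before flooring so floating-point error at an exact hour boundary\n",
    "  # (e.g. 6.9999999999 instead of 7) does not drop to the previous hour.\n",
    "  return int(np.floor(np.round(24 * theta / (2 * np.pi), 6))) % 24\n",
    "\n",
    "\n",
    "for hour in (0, 6, 7.5, 13, 19, 23):\n",
    "  print(hour, decode_hour(*encode_hour(hour)))"
   ]
  },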
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "QZON8t8w2KF5"
   },
   "outputs": [],
   "source": [
    "# @title Define a schedule policy\n",
    "\n",
    "class SchedulePolicy(tf_policy.TFPolicy):\n",
    "  \"\"\"TF Policy implementation of the Schedule policy.\"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      time_step_spec,\n",
    "      action_spec: types.NestedTensorSpec,\n",
    "      action_sequence: ActionSequence,\n",
    "      weekday_schedule_events: Schedule,\n",
    "      weekend_holiday_schedule_events: Schedule,\n",
    "      dow_sin_index: int,\n",
    "      dow_cos_index: int,\n",
    "      hod_sin_index: int,\n",
    "      hod_cos_index: int,\n",
    "      action_normalizers,\n",
    "      local_start_time: pd.Timestamp,\n",
    "      policy_state_spec: types.NestedTensorSpec = (),\n",
    "      info_spec: types.NestedTensorSpec = (),\n",
    "      training: bool = False,\n",
    "      name: Optional[str] = None,\n",
    "  ):\n",
    "    self.weekday_schedule_events = weekday_schedule_events\n",
    "    self.weekend_holiday_schedule_events = weekend_holiday_schedule_events\n",
    "    self.dow_sin_index = dow_sin_index\n",
    "    self.dow_cos_index = dow_cos_index\n",
    "    self.hod_sin_index = hod_sin_index\n",
    "    self.hod_cos_index = hod_cos_index\n",
    "    self.action_sequence = action_sequence\n",
    "    self.action_normalizers = action_normalizers\n",
    "    self.local_start_time = local_start_time\n",
    "    self.norm_mean = 0.0\n",
    "    self.norm_std = 1.0\n",
    "\n",
    "    policy_state_spec = ()\n",
    "\n",
    "    super().__init__(\n",
    "        time_step_spec=time_step_spec,\n",
    "        action_spec=action_spec,\n",
    "        policy_state_spec=policy_state_spec,\n",
    "        info_spec=info_spec,\n",
    "        clip=False,\n",
    "        observation_and_action_constraint_splitter=None,\n",
    "        name=name,\n",
    "    )\n",
    "\n",
    "  def _normalize_action_map(\n",
    "      self, action_map: dict[tuple[DeviceType, SetpointName], SetpointValue]\n",
    "  ) -> dict[tuple[DeviceType, SetpointName], SetpointValue]:\n",
    "\n",
    "    normalized_action_map = {}\n",
    "\n",
    "    for k, v in action_map.items():\n",
    "      for normalizer_k, normalizer in self.action_normalizers.items():\n",
    "        if normalizer_k.endswith(k[1]):\n",
    "\n",
    "          normed_v = normalizer.agent_value(v)\n",
    "          normalized_action_map[k] = normed_v\n",
    "\n",
    "    return normalized_action_map\n",
    "\n",
    "  def _get_action(\n",
    "      self, time_step\n",
    "  ) -> dict[tuple[DeviceType, SetpointName], SetpointValue]:\n",
    "\n",
    "    observation = time_step.observation\n",
    "    action_spec = cast(tensor_spec.BoundedTensorSpec, self.action_spec)\n",
    "    dow_sin = (observation[self.dow_sin_index] * self.norm_std) + self.norm_mean\n",
    "    dow_cos = (observation[self.dow_cos_index] * self.norm_std) + self.norm_mean\n",
    "    hod_sin = (observation[self.hod_sin_index] * self.norm_std) + self.norm_mean\n",
    "    hod_cos = (observation[self.hod_cos_index] * self.norm_std) + self.norm_mean\n",
    "\n",
    "    dow = to_dow(dow_sin, dow_cos)\n",
    "    hod = to_hod(hod_sin, hod_cos)\n",
    "\n",
    "    timestamp = (\n",
    "        pd.Timedelta(hod, unit='hour') + self.local_start_time.utcoffset()\n",
    "    )\n",
    "\n",
    "    # Days 0-4 are weekdays; days 5-6 are the weekend.\n",
    "    if dow < 5:\n",
    "      schedule = self.weekday_schedule_events\n",
    "    else:\n",
    "      schedule = self.weekend_holiday_schedule_events\n",
    "\n",
    "    return {\n",
    "        (device, setpoint_name): find_schedule_action(\n",
    "            schedule, device, setpoint_name, timestamp\n",
    "        )\n",
    "        for device, setpoint_name in self.action_sequence\n",
    "    }\n",
    "\n",
    "  def _action(self, time_step, policy_state, seed):\n",
    "    del seed\n",
    "    action_map = self._get_action(time_step)\n",
    "    normalized_action_map = self._normalize_action_map(action_map)\n",
    "\n",
    "    # Emit the actions in the order given by the policy's action_sequence.\n",
    "    action = np.array(\n",
    "        [\n",
    "            normalized_action_map[device_setpoint]\n",
    "            for device_setpoint in self.action_sequence\n",
    "        ],\n",
    "        dtype=np.float32,\n",
    "    )\n",
    "\n",
    "    t_action = tf.convert_to_tensor(action)\n",
    "    return policy_step.PolicyStep(t_action, (), ())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UkQs64KT6qs-"
   },
   "source": [
    "Next, we parameterize the setpoint schedule.\n",
    "\n",
    "We distinguish between weekday and weekend/holiday schedules:\n",
    "\n",
    "* For **weekdays, between 6:00 am and 7:00 pm local time**, we maintain occupancy conditions:\n",
    "  * AC/Heatpump supply air heating setpoint is 292 K (about 19 C)\n",
    "  * Supply water temperature setpoint is 350 K (about 77 C)\n",
    "* For **weekdays, before 6:00 am and after 7:00 pm local time**, we maintain efficiency conditions (setback):\n",
    "  * AC/Heatpump supply air heating setpoint is 285 K (about 12 C)\n",
    "  * Supply water temperature setpoint is 315 K (about 42 C)\n",
    "\n",
    "* For **weekends and holidays**, all day, we maintain efficiency conditions (setback):\n",
    "  * AC/Heatpump supply air heating setpoint is 285 K (about 12 C)\n",
    "  * Supply water temperature setpoint is 315 K (about 42 C)\n"
   ]
  },
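  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sketchScheduleLookupMd"
   },
   "source": [
    "`find_schedule_action` applies a simple lookup rule to these schedules. The active setpoint comes from the latest event that starts at or before the query time. Times before the first event of the day wrap around to the last event. The plain-Python sketch below mirrors that rule for the hot water supply setpoint and converts the Kelvin values to Celsius using the standard 273.15 K offset. `weekday_supply_water` and `active_setpoint` are hypothetical names used only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sketchScheduleLookup"
   },
   "outputs": [],
   "source": [
    "import bisect\n",
    "\n",
    "KELVIN_OFFSET = 273.15  # 0 C expressed in Kelvin.\n",
    "\n",
    "# Hypothetical weekday schedule for the hot water supply setpoint, using the\n",
    "# same values as the configuration cell below:\n",
    "# (start hour in local time, setpoint in Kelvin), sorted by start hour.\n",
    "weekday_supply_water = [(6.0, 350.0), (19.0, 315.0)]\n",
    "\n",
    "\n",
    "def active_setpoint(schedule: list[tuple[float, float]], hour: float) -> float:\n",
    "  start_hours = [start for start, _ in schedule]\n",
    "  # Index of the latest event that starts at or before `hour`.\n",
    "  i = bisect.bisect_right(start_hours, hour) - 1\n",
    "  # i == -1 means no event has started yet today, so the last event of the\n",
    "  # previous day is still in effect; schedule[-1] handles that wraparound.\n",
    "  return schedule[i][1]\n",
    "\n",
    "\n",
    "for hour in (3.0, 6.0, 12.5, 19.0, 23.0):\n",
    "  kelvin = active_setpoint(weekday_supply_water, hour)\n",
    "  print(f'{hour:5.1f} h: {kelvin:.1f} K = {kelvin - KELVIN_OFFSET:.2f} C')"
   ]
  },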
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "SpveeGWUf5AK"
   },
   "outputs": [],
   "source": [
    "# @title Configure the schedule parameters\n",
    "\n",
    "hod_cos_index = collect_env._field_names.index('hod_cos_000')\n",
    "hod_sin_index = collect_env._field_names.index('hod_sin_000')\n",
    "dow_cos_index = collect_env._field_names.index('dow_cos_000')\n",
    "dow_sin_index = collect_env._field_names.index('dow_sin_000')\n",
    "\n",
    "\n",
    "# Note that temperatures are specified in Kelvin:\n",
    "weekday_schedule_events = [\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(6, unit='hour'),\n",
    "        DeviceType.AC,\n",
    "        'supply_air_heating_temperature_setpoint',\n",
    "        292.0,\n",
    "    ),\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(19, unit='hour'),\n",
    "        DeviceType.AC,\n",
    "        'supply_air_heating_temperature_setpoint',\n",
    "        285.0,\n",
    "    ),\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(6, unit='hour'),\n",
    "        DeviceType.HWS,\n",
    "        'supply_water_setpoint',\n",
    "        350.0,\n",
    "    ),\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(19, unit='hour'),\n",
    "        DeviceType.HWS,\n",
    "        'supply_water_setpoint',\n",
    "        315.0,\n",
    "    ),\n",
    "]\n",
    "\n",
    "\n",
    "weekend_holiday_schedule_events = [\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(6, unit='hour'),\n",
    "        DeviceType.AC,\n",
    "        'supply_air_heating_temperature_setpoint',\n",
    "        285.0,\n",
    "    ),\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(19, unit='hour'),\n",
    "        DeviceType.AC,\n",
    "        'supply_air_heating_temperature_setpoint',\n",
    "        285.0,\n",
    "    ),\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(6, unit='hour'),\n",
    "        DeviceType.HWS,\n",
    "        'supply_water_setpoint',\n",
    "        315.0,\n",
    "    ),\n",
    "    ScheduleEvent(\n",
    "        pd.Timedelta(19, unit='hour'),\n",
    "        DeviceType.HWS,\n",
    "        'supply_water_setpoint',\n",
    "        315.0,\n",
    "    ),\n",
    "]\n",
    "\n",
    "action_sequence = [\n",
    "    (DeviceType.HWS, 'supply_water_setpoint'),\n",
    "    (DeviceType.AC, 'supply_air_heating_temperature_setpoint'),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xOTP9p8-0N0H"
   },
   "source": [
    "We instantiate the schedule policy below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "jv-1GBFTieNc"
   },
   "outputs": [],
   "source": [
    "# @title Instantiate the Schedule RBC policy\n",
    "ts = collect_env.reset()\n",
    "local_start_time = collect_env.current_simulation_timestamp.tz_convert(\n",
    "    tz=time_zone\n",
    ")\n",
    "\n",
    "action_normalizers = collect_env._action_normalizers\n",
    "\n",
    "observation_spec, action_spec, time_step_spec = spec_utils.get_tensor_specs(\n",
    "    collect_env\n",
    ")\n",
    "schedule_policy = SchedulePolicy(\n",
    "    time_step_spec=time_step_spec,\n",
    "    action_spec=action_spec,\n",
    "    action_sequence=action_sequence,\n",
    "    weekday_schedule_events=weekday_schedule_events,\n",
    "    weekend_holiday_schedule_events=weekend_holiday_schedule_events,\n",
    "    dow_sin_index=dow_sin_index,\n",
    "    dow_cos_index=dow_cos_index,\n",
    "    hod_sin_index=hod_sin_index,\n",
    "    hod_cos_index=hod_cos_index,\n",
    "    local_start_time=local_start_time,\n",
    "    action_normalizers=action_normalizers,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pAYOf5Xtzi2u"
   },
   "source": [
    "Next, we run the schedule policy on the eval environment to establish a baseline for comparison with the RL agent.\n",
    "\n",
    "**Note:** This will take some time to execute. Feel free to skip this step if you want to jump directly to the RL section below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "3Zv-lSiWDp50"
   },
   "outputs": [],
   "source": [
    "# @title Optionally, execute the schedule policy on the environment\n",
    "compute_avg_return(\n",
    "    eval_env,\n",
    "    schedule_policy,\n",
    "    num_episodes=1,\n",
    "    time_zone=time_zone,\n",
    "    render_interval_steps=144,\n",
    "    trajectory_observers=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SDgizVLzRti1"
   },
   "source": [
    "# Reinforcement Learning Control\n",
    "In the previous section, we used a simple schedule to control the HVAC setpoints. In this section, we configure and train a Reinforcement Learning (RL) agent.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "zBjFBpkabFHR"
   },
   "outputs": [],
   "source": [
    "# @title Utilities to configure networks for the RL Agent.\n",
    "dense = functools.partial(\n",
    "    tf.keras.layers.Dense,\n",
    "    activation=tf.keras.activations.relu,\n",
    "    kernel_initializer='glorot_uniform',\n",
    ")\n",
    "\n",
    "\n",
    "def logging_info(*args):\n",
    "  logging.info(*args)\n",
    "  print(*args)\n",
    "\n",
    "\n",
    "def create_fc_network(layer_units):\n",
    "  return sequential.Sequential([dense(num_units) for num_units in layer_units])\n",
    "\n",
    "\n",
    "def create_identity_layer():\n",
    "  return tf.keras.layers.Lambda(lambda x: x)\n",
    "\n",
    "\n",
    "def create_sequential_critic_network(\n",
    "    obs_fc_layer_units, action_fc_layer_units, joint_fc_layer_units\n",
    "):\n",
    "  \"\"\"Create a sequential critic network.\"\"\"\n",
    "\n",
    "  # Split the inputs into observations and actions.\n",
    "  def split_inputs(inputs):\n",
    "    return {'observation': inputs[0], 'action': inputs[1]}\n",
    "\n",
    "  # Create an observation network.\n",
    "  obs_network = (\n",
    "      create_fc_network(obs_fc_layer_units)\n",
    "      if obs_fc_layer_units\n",
    "      else create_identity_layer()\n",
    "  )\n",
    "\n",
    "  # Create an action network.\n",
    "  action_network = (\n",
    "      create_fc_network(action_fc_layer_units)\n",
    "      if action_fc_layer_units\n",
    "      else create_identity_layer()\n",
    "  )\n",
    "\n",
    "  # Create a joint network.\n",
    "  joint_network = (\n",
    "      create_fc_network(joint_fc_layer_units)\n",
    "      if joint_fc_layer_units\n",
    "      else create_identity_layer()\n",
    "  )\n",
    "\n",
    "  # Final layer.\n",
    "  value_layer = tf.keras.layers.Dense(1, kernel_initializer='glorot_uniform')\n",
    "\n",
    "  return sequential.Sequential(\n",
    "      [\n",
    "          tf.keras.layers.Lambda(split_inputs),\n",
    "          nest_map.NestMap(\n",
    "              {'observation': obs_network, 'action': action_network}\n",
    "          ),\n",
    "          nest_map.NestFlatten(),\n",
    "          tf.keras.layers.Concatenate(),\n",
    "          joint_network,\n",
    "          value_layer,\n",
    "          inner_reshape.InnerReshape(current_shape=[1], new_shape=[]),\n",
    "      ],\n",
    "      name='sequential_critic',\n",
    "  )\n",
    "\n",
    "# Define the actor network\n",
    "class CustomActorNetwork(network.Network):\n",
    "  \"\"\"Deterministic actor: an MLP whose tanh output is scaled to the action bounds.\"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      input_tensor_spec,\n",
    "      output_tensor_spec,\n",
    "      fc_layer_params,\n",
    "      name='CustomActorNetwork',\n",
    "  ):\n",
    "    super(CustomActorNetwork, self).__init__(\n",
    "        input_tensor_spec=input_tensor_spec, state_spec=(), name=name\n",
    "    )\n",
    "    self._output_tensor_spec = output_tensor_spec\n",
    "\n",
    "    # Hidden layers.\n",
    "    self._layers = []\n",
    "    for num_units in fc_layer_params:\n",
    "      self._layers.append(\n",
    "          tf.keras.layers.Dense(\n",
    "              num_units,\n",
    "              activation=tf.keras.activations.relu,\n",
    "              kernel_initializer='glorot_uniform',\n",
    "          )\n",
    "      )\n",
    "\n",
    "    # Output layer: one tanh unit per action dimension, with values in [-1, 1].\n",
    "    self._layers.append(\n",
    "        tf.keras.layers.Dense(\n",
    "            output_tensor_spec.shape.num_elements(),\n",
    "            activation=tf.keras.activations.tanh,\n",
    "            kernel_initializer='glorot_uniform',\n",
    "        )\n",
    "    )\n",
    "\n",
    "  def call(self, observations, step_type=None, network_state=(), training=False):\n",
    "    del step_type  # Unused.\n",
    "    observations = tf.cast(observations, tf.float32)\n",
    "    batch_squash = network_utils.BatchSquash(1)\n",
    "    observations = batch_squash.flatten(observations)\n",
    "\n",
    "    output = observations\n",
    "    for layer in self._layers:\n",
    "      output = layer(output, training=training)\n",
    "\n",
    "    output = batch_squash.unflatten(output)\n",
    "\n",
    "    # Scale the [-1, 1] outputs to the [minimum, maximum] range of the action spec.\n",
    "    spec = self._output_tensor_spec\n",
    "    action_means = (spec.maximum + spec.minimum) / 2.0\n",
    "    action_magnitudes = (spec.maximum - spec.minimum) / 2.0\n",
    "    output = action_means + action_magnitudes * output\n",
    "\n",
    "    return output, network_state\n"
   ]
  },
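  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sketchActionScalingMd"
   },
   "source": [
    "The last step of `CustomActorNetwork.call` maps the `tanh` output from [-1, 1] into the action bounds using the midpoint and half-range of each dimension, i.e. `action = (max + min) / 2 + (max - min) / 2 * tanh(x)`. The NumPy sketch below checks that mapping on hypothetical bounds (`toy_minimum` and `toy_maximum` are made-up values, not the environment's action spec)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sketchActionScaling"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical bounds for a two-dimensional action; the real bounds come\n",
    "# from the environment's action spec.\n",
    "toy_minimum = np.array([-1.0, 0.0], dtype=np.float32)\n",
    "toy_maximum = np.array([1.0, 10.0], dtype=np.float32)\n",
    "\n",
    "\n",
    "def scale_tanh_output(z):\n",
    "  # Same midpoint/half-range mapping as the end of CustomActorNetwork.call.\n",
    "  action_means = (toy_maximum + toy_minimum) / 2.0\n",
    "  action_magnitudes = (toy_maximum - toy_minimum) / 2.0\n",
    "  return action_means + action_magnitudes * np.tanh(z)\n",
    "\n",
    "\n",
    "# Very negative, zero, and very positive pre-activations map to the lower\n",
    "# bound, the midpoint, and the upper bound of each action dimension.\n",
    "pre_activations = np.array([[-20.0, -20.0], [0.0, 0.0], [20.0, 20.0]])\n",
    "print(scale_tanh_output(pre_activations))"
   ]
  },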
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9g6pE6v2bb8O"
   },
   "source": [
    "Set the configuration parameters for the TD3 Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "CeVkerwYcng2"
   },
   "outputs": [],
   "source": [
    "# @title Set the RL Agent's parameters\n",
    "\n",
    "# Actor network fully connected layers.\n",
    "actor_fc_layers = (128, 128)\n",
    "# Critic network observation fully connected layers.\n",
    "critic_obs_fc_layers = (128, 64)\n",
    "# Critic network action fully connected layers.\n",
    "critic_action_fc_layers = (128, 64)\n",
    "# Critic network joint fully connected layers.\n",
    "critic_joint_fc_layers = (128, 64)\n",
    "\n",
    "batch_size = 256\n",
    "actor_learning_rate = 3e-4\n",
    "critic_learning_rate = 3e-4\n",
    "alpha_learning_rate = 3e-4\n",
    "gamma = 0.99\n",
    "target_update_tau= 0.005\n",
    "target_update_period= 2 # do TD3 delayed updates\n",
    "reward_scale_factor = 1.0\n",
    "\n",
    "# TD3 specific params\n",
    "exploration_noise_std = 0.1\n",
    "target_policy_noise = 0.2\n",
    "target_policy_noise_clip = 0.5\n",
    "\n",
    "# Replay params\n",
    "replay_capacity = 1000000\n",
    "\n",
    "# Summary params\n",
    "debug_summaries = True\n",
    "summarize_grads_and_vars = True\n"
   ]
  },
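  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`target_update_tau` and `target_update_period` control how the target networks track the online networks: every `target_update_period` train steps, each target weight moves a fraction `tau` of the way toward the corresponding online weight (Polyak averaging). The NumPy sketch below applies that rule to stand-in weight arrays. It is an illustration of the update rule, not the TF-Agents code, which performs the equivalent update on the network variables.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def soft_update(target, online, tau):\n",
    "    # Polyak averaging: target <- (1 - tau) * target + tau * online.\n",
    "    return [(1.0 - tau) * t + tau * o for t, o in zip(target, online)]\n",
    "\n",
    "\n",
    "tau = 0.005  # target_update_tau\n",
    "period = 2  # target_update_period\n",
    "online = [np.ones(3)]  # stand-in for the online network weights\n",
    "target = [np.zeros(3)]  # stand-in for the target network weights\n",
    "\n",
    "for step in range(1, 11):\n",
    "    # The target weights only move on every `period`-th train step.\n",
    "    if step % period == 0:\n",
    "        target = soft_update(target, online, tau)\n",
    "\n",
    "# After 5 updates each target weight equals 1 - (1 - tau) ** 5, roughly 0.025.\n",
    "print(target[0])\n",
    "```\n",
    "\n",
    "A small `tau` makes the targets change slowly, which stabilizes the bootstrapped critic targets at the cost of propagating improvements more gradually."
   ]
  },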
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EhTPXjtebMZD"
   },
   "source": [
    "## Initialize the TD3 agent\n",
    "\n",
    "We use the [TD3](https://arxiv.org/abs/1802.09477) (Twin Delayed Deep Deterministic Policy Gradient) Reinforcement Learning algorithm to learn a building controller.\n",
    "\n",
    "This notebook illustrates the building control environment using the TD3 implementation in [TF-Agents](https://www.tensorflow.org/agents)."
   ]
  },
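  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The TD3 paper changes DDPG in three ways: it trains twin critics and bootstraps from the smaller of their two target estimates (clipped double-Q learning), it delays policy and target updates, and it smooths the target policy by adding clipped Gaussian noise to the target action (`target_policy_noise`, `target_policy_noise_clip`). The NumPy sketch below computes the critic's regression target for a batch under those rules. The target networks are hypothetical stand-in functions, and the sketch illustrates the algorithm rather than reproducing the TF-Agents implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def td3_critic_target(reward, discount, next_obs, target_actor, target_q1, target_q2,\n",
    "                      gamma=0.99, noise_std=0.2, noise_clip=0.5,\n",
    "                      action_low=-1.0, action_high=1.0, rng=None):\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng(0)\n",
    "    next_action = target_actor(next_obs)\n",
    "    # Target policy smoothing: clipped Gaussian noise, then clip to the action bounds.\n",
    "    noise = np.clip(rng.normal(0.0, noise_std, size=next_action.shape), -noise_clip, noise_clip)\n",
    "    next_action = np.clip(next_action + noise, action_low, action_high)\n",
    "    # Clipped double-Q learning: bootstrap from the smaller target critic estimate.\n",
    "    next_q = np.minimum(target_q1(next_obs, next_action), target_q2(next_obs, next_action))\n",
    "    # A discount of 0 (episode end) stops bootstrapping.\n",
    "    return reward + gamma * discount * next_q\n",
    "\n",
    "\n",
    "# Hypothetical stand-ins for the target actor and the two target critics.\n",
    "def target_actor(obs):\n",
    "    return np.tanh(obs.sum(axis=1, keepdims=True))\n",
    "\n",
    "\n",
    "def target_q1(obs, act):\n",
    "    return obs.sum(axis=1) + act[:, 0]\n",
    "\n",
    "\n",
    "def target_q2(obs, act):\n",
    "    return obs.sum(axis=1) + act[:, 0] + 1.0\n",
    "\n",
    "\n",
    "next_obs = np.zeros((4, 3))\n",
    "y = td3_critic_target(np.ones(4), np.ones(4), next_obs, target_actor, target_q1, target_q2)\n",
    "print(y)\n",
    "```\n",
    "\n",
    "Taking the minimum of the two critics counteracts the overestimation bias of a single learned Q-function, and the smoothing noise keeps the critic from exploiting narrow peaks in its own estimate."
   ]
  },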
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "NW0pzLvjbSnP"
   },
   "outputs": [],
   "source": [
    "# @title Construct the TD3 agent\n",
    "\n",
    "\n",
    "_, action_tensor_spec, time_step_tensor_spec = spec_utils.get_tensor_specs(\n",
    "    collect_env\n",
    ")\n",
    "\n",
    "# Instantiate custom actor network\n",
    "actor_net = CustomActorNetwork(\n",
    "    input_tensor_spec=time_step_tensor_spec.observation,\n",
    "    output_tensor_spec=action_tensor_spec,\n",
    "    fc_layer_params=actor_fc_layers,\n",
    "    name='CustomActorNetwork'\n",
    ")\n",
    "\n",
    "# Define the critic network using critic_network.CriticNetwork\n",
    "critic_net = critic_network.CriticNetwork(\n",
    "    input_tensor_spec=(time_step_tensor_spec.observation, action_tensor_spec),\n",
    "    observation_fc_layer_params=critic_obs_fc_layers,\n",
    "    action_fc_layer_params=critic_action_fc_layers,\n",
    "    joint_fc_layer_params=critic_joint_fc_layers,\n",
    "    activation_fn=tf.keras.activations.relu,\n",
    "    kernel_initializer='glorot_uniform',\n",
    "    last_kernel_initializer='glorot_uniform',\n",
    "    name='CriticNetwork'\n",
    ")\n",
    "\n",
    "train_step = train_utils.create_train_step()\n",
    "agent = td3_agent.Td3Agent(\n",
    "    time_step_spec=time_step_tensor_spec,\n",
    "    action_spec=action_tensor_spec,\n",
    "    actor_network=actor_net,\n",
    "    critic_network=critic_net,\n",
    "    actor_optimizer=tf.keras.optimizers.Adam(learning_rate=actor_learning_rate),\n",
    "    critic_optimizer=tf.keras.optimizers.Adam(learning_rate=critic_learning_rate),\n",
    "    exploration_noise_std=exploration_noise_std,\n",
    "    target_update_tau=target_update_tau,\n",
    "    target_update_period=target_update_period,\n",
    "    target_policy_noise=target_policy_noise,\n",
    "    target_policy_noise_clip=target_policy_noise_clip,\n",
    "    gamma=gamma,\n",
    "    reward_scale_factor=reward_scale_factor,\n",
    "    gradient_clipping=None,\n",
    "    debug_summaries=debug_summaries,\n",
    "    summarize_grads_and_vars=summarize_grads_and_vars,\n",
    "    train_step_counter=train_step,\n",
    ")\n",
    "agent.initialize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J5hNdgZBG5BZ"
   },
   "source": [
    "Below we construct a replay buffer using Reverb. During collection, the replay buffer is populated with (state, action, reward, next state) transitions. This lets the agent reuse past experience and keeps it from overfitting to its most recent, local experience.\n",
    "\n",
    "During training, the agent samples uniformly at random from the replay buffer. This decorrelates the training data, much as shuffling a training set does in supervised learning; otherwise, the experience within a window of time is highly correlated in most environments."
   ]
  },
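  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Reverb table configured below keeps at most `max_size` items, evicts the oldest item first (`Fifo` remover), and samples uniformly at random (`Uniform` sampler). The pure-Python sketch below mimics that eviction and sampling behavior with a hypothetical `UniformFifoReplay` class built on `collections.deque`; it omits Reverb's rate limiting, checkpointing, and server.\n",
    "\n",
    "```python\n",
    "import collections\n",
    "import random\n",
    "\n",
    "\n",
    "class UniformFifoReplay:\n",
    "    # Hypothetical minimal replay buffer: FIFO eviction and uniform sampling.\n",
    "\n",
    "    def __init__(self, max_size, seed=0):\n",
    "        # A bounded deque silently drops its oldest item when it is full.\n",
    "        self._items = collections.deque(maxlen=max_size)\n",
    "        self._rng = random.Random(seed)\n",
    "\n",
    "    def add(self, item):\n",
    "        self._items.append(item)\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        # Each element is drawn uniformly at random, with replacement.\n",
    "        return [self._rng.choice(self._items) for _ in range(batch_size)]\n",
    "\n",
    "    def items(self):\n",
    "        return list(self._items)\n",
    "\n",
    "\n",
    "buffer = UniformFifoReplay(max_size=3)\n",
    "for t in range(5):\n",
    "    # (state, action, reward, next state) tuple with placeholder values.\n",
    "    buffer.add(('s%d' % t, 'a%d' % t, float(t), 's%d' % (t + 1)))\n",
    "\n",
    "print([item[0] for item in buffer.items()])  # the two oldest items were evicted\n",
    "print(buffer.sample(2))\n",
    "```"
   ]
  },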
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "vX2zGUWJGWAl"
   },
   "outputs": [],
   "source": [
    "# @title Set up the replay buffer\n",
    "replay_capacity = 50000  # Overrides the replay_capacity set in the agent parameters cell above.\n",
    "table_name = 'uniform_table'\n",
    "table = reverb.Table(\n",
    "    table_name,\n",
    "    max_size=replay_capacity,\n",
    "    sampler=reverb.selectors.Uniform(),\n",
    "    remover=reverb.selectors.Fifo(),\n",
    "    rate_limiter=reverb.rate_limiters.MinSize(1),\n",
    ")\n",
    "\n",
    "reverb_checkpoint_dir = output_data_path + \"/reverb_checkpoint\"\n",
    "\n",
    "reverb_port = None\n",
    "print('reverb_checkpoint_dir=%s' %reverb_checkpoint_dir)\n",
    "\n",
    "reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(\n",
    "    path=reverb_checkpoint_dir\n",
    ")\n",
    "reverb_server = reverb.Server(\n",
    "    [table], port=reverb_port, checkpointer=reverb_checkpointer\n",
    ")\n",
    "logging_info('reverb_server_port=%d' %reverb_server.port)\n",
    "reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(\n",
    "    agent.collect_data_spec,\n",
    "    sequence_length=2,\n",
    "    table_name=table_name,\n",
    "    local_server=reverb_server,\n",
    ")\n",
    "rb_observer = reverb_utils.ReverbAddTrajectoryObserver(\n",
    "    reverb_replay.py_client, table_name, sequence_length=2, stride_length=1\n",
    ")\n",
    "print('num_frames in replay buffer=%d' %reverb_replay.num_frames())"
   ]
  },
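  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `sequence_length=2` and `stride_length=1`, the observer writes overlapping two-step windows: each consecutive pair of steps (t, t+1) becomes one replay item, so every step except the first and last appears in two items. The sketch below shows that windowing over a list of step indices with a hypothetical `sliding_windows` helper, ignoring how the real observer handles episode boundaries.\n",
    "\n",
    "```python\n",
    "def sliding_windows(steps, sequence_length=2, stride_length=1):\n",
    "    # Start a new window every `stride_length` steps; each window holds\n",
    "    # `sequence_length` consecutive steps, so windows overlap when stride < length.\n",
    "    windows = []\n",
    "    for start in range(0, len(steps) - sequence_length + 1, stride_length):\n",
    "        windows.append(steps[start:start + sequence_length])\n",
    "    return windows\n",
    "\n",
    "\n",
    "print(sliding_windows([0, 1, 2, 3, 4]))  # [[0, 1], [1, 2], [2, 3], [3, 4]]\n",
    "```\n",
    "\n",
    "Each two-step window carries enough information to build one (s, a, r, s') transition, which is why the dataset is later requested with `num_steps=2`."
   ]
  },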
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SH7LQZ_Pd0vY"
   },
   "source": [
    "For simplicity, we'll grab eval and collect policies and give them short variable names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "BwY7StuMkuV4"
   },
   "outputs": [],
   "source": [
    "# @title Access the eval and collect policies\n",
    "eval_policy = agent.policy\n",
    "collect_policy = agent.collect_policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6klSPQeGsPLz"
   },
   "source": [
    "In the next section we define observer classes that print model and environment output as the scenario evolves, showing the percentage of the episode completed, the timestamp in the scenario, the cumulative reward, and the execution time.\n",
    "\n",
    "We also provide a plot observer that periodically outputs the performance charts and the temperature gradient across both floors of the building."
   ]
  },
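  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A TF-Agents `Actor` calls each of its observers once per environment step with the latest trajectory. Both observers below follow the same pattern: they update cheap running state on every call and do expensive work (printing, plotting) only every N calls. The stripped-down sketch below shows that pattern with a hypothetical `IntervalObserver` and a namedtuple stand-in for `trajectory_lib.Trajectory` that carries only a reward.\n",
    "\n",
    "```python\n",
    "import collections\n",
    "\n",
    "# Hypothetical stand-in for trajectory_lib.Trajectory (only the reward field).\n",
    "FakeTrajectory = collections.namedtuple('FakeTrajectory', ['reward'])\n",
    "\n",
    "\n",
    "class IntervalObserver:\n",
    "\n",
    "    def __init__(self, interval_steps):\n",
    "        self._interval_steps = interval_steps\n",
    "        self._counter = 0\n",
    "        self._cumulative_reward = 0.0\n",
    "        self.reports = []\n",
    "\n",
    "    def __call__(self, trajectory):\n",
    "        # Cheap bookkeeping happens on every step.\n",
    "        self._counter += 1\n",
    "        self._cumulative_reward += trajectory.reward\n",
    "        # Expensive reporting only happens every `interval_steps` steps.\n",
    "        if self._counter % self._interval_steps == 0:\n",
    "            self.reports.append((self._counter, self._cumulative_reward))\n",
    "\n",
    "\n",
    "observer = IntervalObserver(interval_steps=3)\n",
    "for r in [1.0, -0.5, 2.0, 0.0, 1.0, 1.5]:\n",
    "    observer(FakeTrajectory(reward=r))\n",
    "print(observer.reports)  # [(3, 2.5), (6, 5.0)]\n",
    "```"
   ]
  },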
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "dJ_EMQkZdw8q"
   },
   "outputs": [],
   "source": [
    "# @title Define Observers\n",
    "class RenderAndPlotObserver:\n",
    "  \"\"\"Renders and plots the environment.\"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      render_interval_steps: int = 10,\n",
    "      environment=None,\n",
    "  ):\n",
    "    self._counter = 0\n",
    "    self._render_interval_steps = render_interval_steps\n",
    "    self._environment = environment\n",
    "    self._cumulative_reward = 0.0\n",
    "\n",
    "    self._start_time = None\n",
    "    if self._environment is not None:\n",
    "      self._num_timesteps_in_episode = (\n",
    "          self._environment._num_timesteps_in_episode\n",
    "      )\n",
    "\n",
    "  def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:\n",
    "\n",
    "    reward = trajectory.reward\n",
    "    self._cumulative_reward += reward\n",
    "    self._counter += 1\n",
    "    if self._start_time is None:\n",
    "      self._start_time = pd.Timestamp.now()\n",
    "\n",
    "    if self._counter % self._render_interval_steps == 0 and self._environment:\n",
    "\n",
    "      execution_time = pd.Timestamp.now() - self._start_time\n",
    "      mean_execution_time = execution_time.total_seconds() / self._counter\n",
    "\n",
    "      clear_output(wait=True)\n",
    "      if self._environment._metrics_path is not None:\n",
    "        reader = get_latest_episode_reader(self._environment._metrics_path)\n",
    "        plot_timeseries_charts(reader, time_zone)\n",
    "\n",
    "      render_env(self._environment)\n",
    "\n",
    "\n",
    "class PrintStatusObserver:\n",
    "  \"\"\"Prints status information.\"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self, status_interval_steps: int = 1, environment=None, replay_buffer=None\n",
    "  ):\n",
    "    self._counter = 0\n",
    "    self._status_interval_steps = status_interval_steps\n",
    "    self._environment = environment\n",
    "    self._cumulative_reward = 0.0\n",
    "    self._replay_buffer = replay_buffer\n",
    "\n",
    "    self._start_time = None\n",
    "    if self._environment is not None:\n",
    "      self._num_timesteps_in_episode = (\n",
    "          self._environment._num_timesteps_in_episode\n",
    "      )\n",
    "\n",
    "  def __call__(self, trajectory: trajectory_lib.Trajectory) -> None:\n",
    "\n",
    "    reward = trajectory.reward\n",
    "    self._cumulative_reward += reward\n",
    "    self._counter += 1\n",
    "    if self._start_time is None:\n",
    "      self._start_time = pd.Timestamp.now()\n",
    "\n",
    "    if self._counter % self._status_interval_steps == 0 and self._environment:\n",
    "\n",
    "      execution_time = pd.Timestamp.now() - self._start_time\n",
    "      mean_execution_time = execution_time.total_seconds() / self._counter\n",
    "\n",
    "      sim_time = self._environment.current_simulation_timestamp.tz_convert(\n",
    "          time_zone\n",
    "      )\n",
    "      percent_complete = int(\n",
    "          100.0 * (self._counter / self._num_timesteps_in_episode)\n",
    "      )\n",
    "\n",
    "      if self._replay_buffer is not None:\n",
    "        rb_size = self._replay_buffer.num_frames()\n",
    "        rb_string = \" Replay Buffer Size: %d\" % rb_size\n",
    "      else:\n",
    "        rb_string = \"\"\n",
    "\n",
    "      print(\n",
    "          \"Step %5d of %5d (%3d%%) Sim Time: %s Reward: %2.2f Cumulative\"\n",
    "          \" Reward: %8.2f Execution Time: %s Mean Execution Time: %3.2fs %s\"\n",
    "          % (\n",
    "              self._environment._step_count,\n",
    "              self._num_timesteps_in_episode,\n",
    "              percent_complete,\n",
    "              sim_time.strftime(\"%Y-%m-%d %H:%M\"),\n",
    "              reward,\n",
    "              self._cumulative_reward,\n",
    "              execution_time,\n",
    "              mean_execution_time,\n",
    "              rb_string,\n",
    "          )\n",
    "      )\n",
    "\n",
    "\n",
    "initial_collect_render_plot_observer = RenderAndPlotObserver(\n",
    "    render_interval_steps=144, environment=initial_collect_env\n",
    ")\n",
    "initial_collect_print_status_observer = PrintStatusObserver(\n",
    "    status_interval_steps=1,\n",
    "    environment=initial_collect_env,\n",
    "    replay_buffer=reverb_replay,\n",
    ")\n",
    "collect_render_plot_observer = RenderAndPlotObserver(\n",
    "    render_interval_steps=144, environment=collect_env\n",
    ")\n",
    "collect_print_status_observer = PrintStatusObserver(\n",
    "    status_interval_steps=1,\n",
    "    environment=collect_env,\n",
    "    replay_buffer=reverb_replay,\n",
    ")\n",
    "eval_render_plot_observer = RenderAndPlotObserver(\n",
    "    render_interval_steps=144, environment=eval_env\n",
    ")\n",
    "eval_print_status_observer = PrintStatusObserver(\n",
    "    status_interval_steps=1, environment=eval_env, replay_buffer=reverb_replay\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "el732oZItQjO"
   },
   "source": [
    "In the following cell, we run the baseline control on the scenario to populate the replay buffer. We use the schedule policy built above so that the initial replay data resembles recorded telemetry from the existing controller, simulating off-policy training from historical data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "ZGq3SY0kKwsa"
   },
   "outputs": [],
   "source": [
    "# @title Populate the replay buffer with data from baseline control\n",
    "initial_collect_actor = actor.Actor(\n",
    "  initial_collect_env,\n",
    "  schedule_policy,\n",
    "  train_step,\n",
    "  steps_per_run=initial_collect_env._num_timesteps_in_episode,\n",
    "  observers=[rb_observer, initial_collect_print_status_observer, initial_collect_render_plot_observer])\n",
    "initial_collect_actor.run()\n",
    "reverb_replay.py_client.checkpoint()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y3ZzWxqIunCz"
   },
   "source": [
    "Next, wrap the replay buffer in a TF dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "ba7bilizt_qW"
   },
   "outputs": [],
   "source": [
    "# @title Make a TF Dataset\n",
    "# Dataset generates trajectories with shape [Bx2x...]\n",
    "dataset = reverb_replay.as_dataset(\n",
    "    num_parallel_calls=3,\n",
    "    sample_batch_size=batch_size,\n",
    "    num_steps=2).prefetch(50)\n",
    "\n",
    "dataset"
   ]
  },
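  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each element of `dataset` is a batch of two-step trajectories whose tensors have leading shape `[B, 2]`. Before an update, each window is split into one transition: the observation and action at index 0, the reward received for that action, and the observation at index 1 as the next state. The NumPy sketch below performs that split on hypothetical random arrays, assuming (as in TF-Agents `Trajectory`) that the reward at index 0 is the one received after the action at index 0. In TF-Agents the agent does the equivalent conversion internally (see `trajectory.to_transition`).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "B, obs_dim, act_dim = 4, 3, 2\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical batch laid out like the dataset elements: [B, 2, ...].\n",
    "observation = rng.normal(size=(B, 2, obs_dim))\n",
    "action = rng.normal(size=(B, 2, act_dim))\n",
    "reward = rng.normal(size=(B, 2))\n",
    "\n",
    "# Split each two-step window into (s, a, r, s').\n",
    "obs_t = observation[:, 0]  # s\n",
    "act_t = action[:, 0]  # a\n",
    "rew_t = reward[:, 0]  # r received after taking a in s\n",
    "obs_tp1 = observation[:, 1]  # s'\n",
    "print(obs_t.shape, act_t.shape, rew_t.shape, obs_tp1.shape)  # (4, 3) (4, 2) (4,) (4, 3)\n",
    "```"
   ]
  },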
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-YqfMl5FuQpf"
   },
   "source": [
    "Here, we wrap the collect and evaluation policies in `PyTFEagerPolicy` so that the Python-based TF-Agents actors can run them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "TzwSaxYkeTh5"
   },
   "outputs": [],
   "source": [
    "# @title Convert the policies into TF Eager Policies\n",
    "\n",
    "tf_collect_policy = agent.collect_policy\n",
    "agent_collect_policy = py_tf_eager_policy.PyTFEagerPolicy(\n",
    "    tf_collect_policy, use_tf_function=True\n",
    ")\n",
    "\n",
    "tf_policy = agent.policy\n",
    "agent_policy = py_tf_eager_policy.PyTFEagerPolicy(\n",
    "    tf_policy, use_tf_function=True\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qtoqyo8Ypn0Q"
   },
   "source": [
    "Next, we set how often the policy is saved and how often the actor and critic losses are written to summaries.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xums9Kxkxylw"
   },
   "outputs": [],
   "source": [
    "policy_save_interval = 1 # Save the policy after every learning step.\n",
    "learner_summary_interval = 1 # Write a summary of the actor and critic losses after every gradient update step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "al5HNoiwvYO-"
   },
   "source": [
    "In the following cell we define the agent learner, a TF-Agents wrapper around the process that performs gradient-based updates to the actor and critic networks in the agent.\n",
    "\n",
    "You should see a statement that shows where the policies will be saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "Ah4oS9HLwOid"
   },
   "outputs": [],
   "source": [
    "# @title Define an Agent Learner\n",
    "\n",
    "experience_dataset_fn = lambda: dataset\n",
    "\n",
    "saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)\n",
    "print('Policies will be saved to saved_model_dir: %s' %saved_model_dir)\n",
    "env_step_metric = py_metrics.EnvironmentSteps()\n",
    "learning_triggers = [\n",
    "      triggers.PolicySavedModelTrigger(\n",
    "          saved_model_dir,\n",
    "          agent,\n",
    "          train_step,\n",
    "          interval=policy_save_interval,\n",
    "          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric},\n",
    "      ),\n",
    "      triggers.StepPerSecondLogTrigger(train_step, interval=10),\n",
    "]\n",
    "\n",
    "agent_learner = learner.Learner(\n",
    "      root_dir,\n",
    "      train_step,\n",
    "      agent,\n",
    "      experience_dataset_fn,\n",
    "      triggers=learning_triggers,\n",
    "      strategy=None,\n",
    "      summary_interval=learner_summary_interval,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wAdbomqlyqpz"
   },
   "source": [
    "Set the number of collect (environment) steps in each training iteration, that is, how many steps the collect actor takes between rounds of gradient updates.\n",
    "\n",
    "Here we set it to the length of a full episode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6iWtSC-FKHMW"
   },
   "outputs": [],
   "source": [
    "collect_steps_per_treining_iteration = collect_env._num_timesteps_in_episode"
   ]
  },
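  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to reason about this setting is the number of replayed samples drawn per newly collected environment step in each iteration: gradient updates per iteration times the batch size, divided by collect steps per iteration. The sketch below computes that ratio with a hypothetical `samples_per_collected_step` helper, using this notebook's batch size (256) and gradient updates per iteration (100) together with a hypothetical episode length of 288 steps (one day at 5-minute intervals); the real episode length comes from `collect_env._num_timesteps_in_episode`.\n",
    "\n",
    "```python\n",
    "def samples_per_collected_step(gradient_updates, batch_size, collect_steps):\n",
    "    # Replay samples drawn per iteration divided by environment steps collected per iteration.\n",
    "    return gradient_updates * batch_size / collect_steps\n",
    "\n",
    "\n",
    "# 100 updates and a batch of 256 match this notebook; 288 collect steps is hypothetical.\n",
    "print(samples_per_collected_step(100, 256, 288))  # about 88.9\n",
    "```\n",
    "\n",
    "Raising the gradient updates or batch size relative to the collect steps reuses each stored transition more often, which can speed up learning but can also make it less stable."
   ]
  },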
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BdKA4Jy4YfJM"
   },
   "source": [
    "Next, we define a *collect actor* and an *eval actor*. Each wraps a policy and an environment, runs the policy in the environment, and collects metrics.\n",
    "\n",
    "The principal difference between the two is the policy they run. TD3 learns a deterministic actor, so the collect actor uses the agent's collect policy, which adds Gaussian exploration noise (with standard deviation `exploration_noise_std`) to the actor network's action. This noise lets the agent explore better actions and improve the policy.\n",
    "\n",
    "The eval actor, by contrast, uses the noise-free policy and always takes the action output by the actor network."
   ]
  },
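  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The NumPy sketch below contrasts the two policies using a hypothetical deterministic actor function: the eval policy returns the actor's action unchanged, while the collect policy adds zero-mean Gaussian noise with standard deviation `exploration_noise_std` and clips the result to the action bounds. It illustrates the idea behind the TD3 collect policy and is not the TF-Agents implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def deterministic_actor(observation):\n",
    "    # Hypothetical stand-in for the actor network; outputs lie in [-1, 1].\n",
    "    return np.tanh(observation)\n",
    "\n",
    "\n",
    "def eval_action(observation):\n",
    "    # Eval policy: the actor's action, unchanged.\n",
    "    return deterministic_actor(observation)\n",
    "\n",
    "\n",
    "def collect_action(observation, exploration_noise_std=0.1, low=-1.0, high=1.0, rng=None):\n",
    "    # Collect policy: add zero-mean Gaussian noise, then clip to the action bounds.\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng()\n",
    "    action = deterministic_actor(observation)\n",
    "    noise = rng.normal(0.0, exploration_noise_std, size=np.shape(action))\n",
    "    return np.clip(action + noise, low, high)\n",
    "\n",
    "\n",
    "obs = np.array([0.2, -1.5, 3.0])\n",
    "print(eval_action(obs))\n",
    "print(collect_action(obs, rng=np.random.default_rng(42)))\n",
    "```"
   ]
  },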
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "LWsI9znlqLvh"
   },
   "outputs": [],
   "source": [
    "# @title Define a TF-Agents Actor for collect and eval\n",
    "tf_collect_policy = agent.collect_policy\n",
    "collect_policy = py_tf_eager_policy.PyTFEagerPolicy(\n",
    "    tf_collect_policy, use_tf_function=True\n",
    ")\n",
    "collect_actor = actor.Actor(\n",
    "    collect_env,\n",
    "    collect_policy,\n",
    "    train_step,\n",
    "    steps_per_run=collect_steps_per_treining_iteration,\n",
    "    metrics=actor.collect_metrics(1),\n",
    "    summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),\n",
    "    summary_interval=1,\n",
    "    observers=[\n",
    "        rb_observer,\n",
    "        env_step_metric,\n",
    "        collect_print_status_observer,\n",
    "        collect_render_plot_observer,\n",
    "    ],\n",
    ")\n",
    "\n",
    "tf_policy = agent.policy\n",
    "eval_policy = py_tf_eager_policy.PyTFEagerPolicy(\n",
    "    tf_policy, use_tf_function=True\n",
    ")\n",
    "\n",
    "eval_actor = actor.Actor(\n",
    "    eval_env,\n",
    "    eval_policy,\n",
    "    train_step,\n",
    "    episodes_per_run=1,\n",
    "    metrics=actor.eval_metrics(1),\n",
    "    summary_dir=os.path.join(root_dir, 'eval'),\n",
    "    summary_interval=1,\n",
    "    observers=[\n",
    "        rb_observer, \n",
    "        eval_print_status_observer, \n",
    "        eval_render_plot_observer\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c_DN734lZAwE"
   },
   "source": [
    "Finally, we're ready to execute the RL training loop with TD3!\n",
    "\n",
    "You can specify the total number of training iterations and the number of gradient updates per iteration. With fewer updates per iteration, the agent learns more slowly, but more updates may make it less stable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "PAlT1f6SWYxq"
   },
   "outputs": [],
   "source": [
    "# @title Execute the training loop\n",
    "\n",
    "num_training_iterations = 10\n",
    "num_gradient_updates_per_training_iteration = 100\n",
    "\n",
    "# Collect the performance results with the untrained model.\n",
    "eval_actor.run_and_log()\n",
    "\n",
    "logging_info('Training.')\n",
    "for iteration in range(num_training_iterations):\n",
    "  print('Training iteration: ', iteration)\n",
    "  # Let the collect actor run, using its policy.\n",
    "  collect_actor.run()\n",
    "  logging_info(\n",
    "      'Executing %d gradient updates.'\n",
    "      %num_gradient_updates_per_training_iteration\n",
    "  )\n",
    "  # Now, with the additional collect steps in the replay buffer,\n",
    "  # allow the agent to make additional policy improvements.\n",
    "  loss_info = agent_learner.run(\n",
    "      iterations=num_gradient_updates_per_training_iteration\n",
    "  )\n",
    "  logging_info( # No alpha Loss for TD3\n",
    "      'Actor Loss: %6.2f, Critic Loss: %6.2f'\n",
    "      % (\n",
    "          loss_info.extra.actor_loss.numpy(),\n",
    "          loss_info.extra.critic_loss.numpy(),\n",
    "      )\n",
    "  )\n",
    "\n",
    "  logging_info('Evaluating.')\n",
    "\n",
    "  _ = eval_env.reset()\n",
    "  # Run the eval actor after the training iteration, and get its performance.\n",
    "  eval_actor.run_and_log()\n",
    "\n",
    "rb_observer.close()\n",
    "reverb_server.stop()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "L7w-mjPcH7u6",
    "kTtVb9wbRsKU",
    "86IIF7FrfJ_2",
    "SDgizVLzRti1"
   ],
   "last_runtime": {
    "build_target": "",
    "kind": "local"
   },
   "private_outputs": true,
   "provenance": [
    {
     "file_id": "1a2nzc-VcaGRTpsEFj3FgqRZY0Lk1dgrW",
     "timestamp": 1705074752110
    }
   ],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
