{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build your hierarchical agent\n",
    "\n",
    "In this tutorial we'll build our first 2-level hierarchical agent with TAG.\n",
    "\n",
    "We'll see how to configure agents and levels, and how to stack levels on top of each other. Finally, we'll train the resulting hierarchy and evaluate it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment\n",
    "\n",
    "We'll use the [MPE Simple Spread environment](https://pettingzoo.farama.org/environments/mpe/simple_spread/) from PettingZoo with 2 agents.\n",
    "\n",
    "Note that we also provide our own wrapper for this environment, which colors the agents to help with visualization: `src/tame/envs/mpe_simple_spread/env.py`. In this example, however, we use the standard one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pettingzoo.mpe import simple_spread_v3\n",
    "\n",
    "TOTAL_AGENTS = 2\n",
    "# Note that we use the ParallelEnv here, as the library only supports that one for now.\n",
    "env = simple_spread_v3.parallel_env(N=TOTAL_AGENTS, max_cycles=100, continuous_actions=False, render_mode='rgb_array')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_obs_sizes = []\n",
    "for agent in env.possible_agents:\n",
    "    # observation_space(agent) replaces the deprecated observation_spaces dictionary\n",
    "    obs_space = env.observation_space(agent)\n",
    "    env_obs_sizes.append(obs_space.shape[0])"
   ]
  },
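  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the 12-dimensional observations collected above, the size can be derived from the per-agent layout that the PettingZoo documentation gives for Simple Spread: `[self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions, communication]`. The sketch below assumes 2-D positions and velocities, one landmark per agent and a 2-D communication vector per other agent; `simple_spread_obs_size` is a hypothetical helper, not part of PettingZoo or TAG.\n",
    "\n",
    "```python\n",
    "def simple_spread_obs_size(n_agents: int, dim: int = 2, comm_dim: int = 2) -> int:\n",
    "    # Assumed per-agent observation layout for Simple Spread.\n",
    "    self_vel = dim\n",
    "    self_pos = dim\n",
    "    landmark_rel_positions = n_agents * dim  # one landmark per agent\n",
    "    other_agent_rel_positions = (n_agents - 1) * dim\n",
    "    communication = (n_agents - 1) * comm_dim\n",
    "    return (self_vel + self_pos + landmark_rel_positions\n",
    "            + other_agent_rel_positions + communication)\n",
    "\n",
    "\n",
    "print(simple_spread_obs_size(2))  # 12, matching the shapes printed in this notebook\n",
    "print(simple_spread_obs_size(3))  # 18\n",
    "```\n",
    "\n",
    "Under these assumptions the size is 6N for N agents, so with `N=2` every bottom agent receives a 12-dimensional observation, which is also the size of the message it forwards to the top level in this tutorial."
   ]
  },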
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hierarchy structure\n",
    "\n",
    "We'll define our two-level hierarchy with PPO-based subagents.\n",
    "\n",
    "We start by defining the agent names and how they are connected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "bottom_agent_names = [f\"agent_{i}\" for i in range(len(env.possible_agents))]\n",
    "top_agent_name = \"top_ppo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bottom to env connections: {'agent_0': ['agent_0'], 'agent_1': ['agent_1']}\n",
      "Top to bottom connections: {'top_ppo': ['agent_0', 'agent_1']}\n"
     ]
    }
   ],
   "source": [
    "# Each bottom agent is connected to one agent in the environment\n",
    "bottom_env_links = {agent_name: [env_agent] for agent_name, env_agent in zip(bottom_agent_names, env.possible_agents)}\n",
    "print(f\"Bottom to env connections: {bottom_env_links}\")\n",
    "\n",
    "# The top agent is connected to the two bottom agents\n",
    "top_bottom_links = {top_agent_name: bottom_agent_names}\n",
    "print(f\"Top to bottom connections: {top_bottom_links}\")"
   ]
  },
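  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The link dictionaries map each parent agent to the list of agents it is connected to. A simple way to double-check such wiring is to invert it into a child-to-parent map, which also catches a child accidentally linked to two parents. The sketch below uses only the standard library; `invert_links` is a hypothetical helper rather than a TAG function, and the one-parent-per-child rule is an assumption that holds for the tree built in this tutorial.\n",
    "\n",
    "```python\n",
    "def invert_links(links: dict[str, list[str]]) -> dict[str, str]:\n",
    "    # Build a child -> parent map, rejecting children with more than one parent.\n",
    "    parent_of: dict[str, str] = {}\n",
    "    for parent, children in links.items():\n",
    "        for child in children:\n",
    "            if child in parent_of:\n",
    "                raise ValueError(f'{child} is linked to both {parent_of[child]} and {parent}')\n",
    "            parent_of[child] = parent\n",
    "    return parent_of\n",
    "\n",
    "\n",
    "# Same link dictionaries as in the cell above, written out by hand.\n",
    "bottom_env_links = {'agent_0': ['agent_0'], 'agent_1': ['agent_1']}\n",
    "top_bottom_links = {'top_ppo': ['agent_0', 'agent_1']}\n",
    "\n",
    "print(invert_links(top_bottom_links))  # {'agent_0': 'top_ppo', 'agent_1': 'top_ppo'}\n",
    "print(invert_links(bottom_env_links))  # {'agent_0': 'agent_0', 'agent_1': 'agent_1'}\n",
    "```"
   ]
  },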
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now instantiate the hierarchy itself, which organizes the levels and the communication between them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tame.hierarchy import Hierarchy\n",
    "\n",
    "hierarchy = Hierarchy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bottom level\n",
    "### Agents\n",
    "\n",
    "Now we can parametrize the agents in the bottom level.\n",
    "\n",
    "For this we use `AgentConfig`, which allows us to provide the configuration for each of the agents in the level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pprint import pprint\n",
    "\n",
    "from gymnasium.spaces import Dict as GymDict, Box, Discrete\n",
    "from tame.hierarchy import AgentConfig\n",
    "# We use our PPO implementation as the agent in the level. Note that this is a LevelAgent.\n",
    "from tame.agents.monolithic_ppo import Agent as PPO\n",
    "from tame.agents.monolithic_ppo import Args as PPOArgs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An important aspect of the agent definition is the communication space.\n",
    "In TAG, agents can communicate with the level above in multiple ways: no communication, passing along the observations from the level below, or learning what to communicate.\n",
    "\n",
    "In this simple agent, we just pass along the observations from the level below.\n",
    "This means that the communication space of each agent in the bottom level is equal to the observation space of the environment agent it is connected to."
   ]
  },
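  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, passing along the observations from below means that a bottom agent's message is the concatenation of the observations of the environment agents it is linked to, so its length is the sum of their sizes. The plain-Python sketch below illustrates that idea with dummy values; `build_message` is a hypothetical helper and not how TAG implements it internally.\n",
    "\n",
    "```python\n",
    "def build_message(agent_name: str,\n",
    "                  downlinks: dict[str, list[str]],\n",
    "                  observations: dict[str, list[float]]) -> list[float]:\n",
    "    # Concatenate, in link order, the observations of every agent below\n",
    "    # that this agent is connected to.\n",
    "    message: list[float] = []\n",
    "    for child in downlinks[agent_name]:\n",
    "        message.extend(observations[child])\n",
    "    return message\n",
    "\n",
    "\n",
    "downlinks = {'agent_0': ['agent_0'], 'agent_1': ['agent_1']}\n",
    "# Dummy 12-dimensional observations standing in for the environment output.\n",
    "observations = {'agent_0': [0.0] * 12, 'agent_1': [1.0] * 12}\n",
    "\n",
    "msg = build_message('agent_0', downlinks, observations)\n",
    "print(len(msg))  # 12: one linked env agent with a 12-dimensional observation\n",
    "```\n",
    "\n",
    "If an agent were linked to both environment agents, the same function would return a 24-dimensional message, which is why the next cell sizes the communication space by summing the connected observation shapes."
   ]
  },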
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agent agent_0 is connected to ['agent_0']\n",
      "Observation shapes of agents connected to agent_0: [(12,)]\n",
      "Agent agent_0 communication space: Dict('agent_0': Box(inf, inf, (12,), float32))\n",
      "Agent agent_0 config:\n",
      "AgentConfig(name='agent_0',\n",
      "            communication_space=Dict('agent_0': Box(inf, inf, (12,), float32)),\n",
      "            agent_class=<class 'tame.agents.monolithic_ppo.Agent'>,\n",
      "            agent_kwargs={'args': Args(subargs=None,\n",
      "                                       seed=1,\n",
      "                                       cuda=0,\n",
      "                                       exp_name='monolithic_ppo',\n",
      "                                       torch_deterministic=True,\n",
      "                                       save_model=True,\n",
      "                                       total_timesteps=500000,\n",
      "                                       learning_rate=0.00025,\n",
      "                                       gamma=0.99,\n",
      "                                       anneal_lr=True,\n",
      "                                       gae_lambda=0.95,\n",
      "                                       batch_size=2048,\n",
      "                                       num_minibatches=4,\n",
      "                                       update_epochs=4,\n",
      "                                       norm_adv=True,\n",
      "                                       clip_coef=0.2,\n",
      "                                       clip_vloss=True,\n",
      "                                       ent_coef=0.0,\n",
      "                                       vf_coef=0.5,\n",
      "                                       save_all_trace=False,\n",
      "                                       max_grad_norm=0.5,\n",
      "                                       target_kl=None,\n",
      "                                       verbose=True,\n",
      "                                       learn_comm=False,\n",
      "                                       ae_epochs=50)},\n",
      "            device=device(type='cpu'))\n",
      "===============================================\n",
      "Agent agent_1 is connected to ['agent_1']\n",
      "Observation shapes of agents connected to agent_1: [(12,)]\n",
      "Agent agent_1 communication space: Dict('agent_1': Box(inf, inf, (12,), float32))\n",
      "Agent agent_1 config:\n",
      "AgentConfig(name='agent_1',\n",
      "            communication_space=Dict('agent_1': Box(inf, inf, (12,), float32)),\n",
      "            agent_class=<class 'tame.agents.monolithic_ppo.Agent'>,\n",
      "            agent_kwargs={'args': Args(subargs=None,\n",
      "                                       seed=1,\n",
      "                                       cuda=0,\n",
      "                                       exp_name='monolithic_ppo',\n",
      "                                       torch_deterministic=True,\n",
      "                                       save_model=True,\n",
      "                                       total_timesteps=500000,\n",
      "                                       learning_rate=0.00025,\n",
      "                                       gamma=0.99,\n",
      "                                       anneal_lr=True,\n",
      "                                       gae_lambda=0.95,\n",
      "                                       batch_size=2048,\n",
      "                                       num_minibatches=4,\n",
      "                                       update_epochs=4,\n",
      "                                       norm_adv=True,\n",
      "                                       clip_coef=0.2,\n",
      "                                       clip_vloss=True,\n",
      "                                       ent_coef=0.0,\n",
      "                                       vf_coef=0.5,\n",
      "                                       save_all_trace=False,\n",
      "                                       max_grad_norm=0.5,\n",
      "                                       target_kl=None,\n",
      "                                       verbose=True,\n",
      "                                       learn_comm=False,\n",
      "                                       ae_epochs=50)},\n",
      "            device=device(type='cpu'))\n",
      "===============================================\n"
     ]
    }
   ],
   "source": [
    "\n",
    "bottom_agents = []\n",
    "for agent_name in bottom_agent_names:\n",
    "\n",
    "    # We first define the communication space that our agent uses.\n",
    "    # In this case we don't learn the communication space, but just have the agent concatenate the observations\n",
    "    # it receives from the level below (that is the environment) from the agents it's connected with.\n",
    "    # ---------------------------\n",
    "    # Get list of agents from env connected to this agent\n",
    "    env_agents = bottom_env_links[agent_name]\n",
    "    print(f\"Agent {agent_name} is connected to {env_agents}\")\n",
    "\n",
    "    # Get observation shapes of the connected agents\n",
    "    env_observation_shapes = [env.observation_space(env_agent).shape for env_agent in env_agents]\n",
    "    print(f\"Observation shapes of agents connected to {agent_name}: {env_observation_shapes}\")\n",
    "\n",
    "    # This is the size of the message that the agent sends to the level above\n",
    "    concatenated_obs_size = np.sum(env_observation_shapes)\n",
    "\n",
    "    # Define agent communication space\n",
    "    agent_comm_space = GymDict({agent_name: Box(-np.inf, np.inf, shape=(int(concatenated_obs_size),))})\n",
    "    print(f\"Agent {agent_name} communication space: {agent_comm_space}\")\n",
    "    # ---------------------------\n",
    "\n",
    "    # Make agent config\n",
    "    agent_config = AgentConfig(\n",
    "        name=agent_name,\n",
    "        agent_class=PPO,\n",
    "        agent_kwargs={'args': PPOArgs()},\n",
    "        communication_space=agent_comm_space\n",
    "        )\n",
    "    \n",
    "    print(f\"Agent {agent_name} config:\")\n",
    "    pprint(agent_config)\n",
    "    print(\"===============================================\")\n",
    "    bottom_agents.append(agent_config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Level\n",
    "\n",
    "Once we have the agents configurations, we can configure the level itself.\n",
    "\n",
    "Just as the communication space is the key part of an agent's configuration, the action space is the key part of a level's configuration: it defines which actions the level expects from the agents in the level above.\n",
    "\n",
    "The action space is a nested gym `Dict` space that defines what each agent in the level above sends to each agent in this level.\n",
    "It has the form `{top_agent_name: {bottom_agent_name: agent_action_space}}`.\n",
    "\n",
    "In this case, we'll use a discrete action space of size 5 for each connection."
   ]
  },
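  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the nested structure concrete: an action for the whole level is keyed first by the sending agent in the level above and then by the receiving agent in this level. The sketch below uses plain dicts and `random` instead of gym spaces to sample such an action and regroup it by recipient; `sample_level_action` and `route_to_recipients` are hypothetical helpers for illustration only, not TAG functions.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "def sample_level_action(links: dict[str, list[str]], n: int = 5,\n",
    "                        rng: Optional[random.Random] = None) -> dict[str, dict[str, int]]:\n",
    "    # One integer in [0, n) per (agent above, agent in this level) connection,\n",
    "    # mirroring Dict(top -> Dict(bottom -> Discrete(n))).\n",
    "    rng = rng or random.Random(0)\n",
    "    return {top: {bottom: rng.randrange(n) for bottom in bottoms}\n",
    "            for top, bottoms in links.items()}\n",
    "\n",
    "\n",
    "def route_to_recipients(action: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]:\n",
    "    # Regroup by the receiving agent: bottom agent -> {sender: action}.\n",
    "    routed: dict[str, dict[str, int]] = {}\n",
    "    for sender, per_bottom in action.items():\n",
    "        for bottom, value in per_bottom.items():\n",
    "            routed.setdefault(bottom, {})[sender] = value\n",
    "    return routed\n",
    "\n",
    "\n",
    "top_bottom_links = {'top_ppo': ['agent_0', 'agent_1']}\n",
    "print(route_to_recipients({'top_ppo': {'agent_0': 3, 'agent_1': 1}}))\n",
    "# {'agent_0': {'top_ppo': 3}, 'agent_1': {'top_ppo': 1}}\n",
    "print(sample_level_action(top_bottom_links))  # random values in [0, 5)\n",
    "```"
   ]
  },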
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dict('top_ppo': Dict('agent_0': Discrete(5), 'agent_1': Discrete(5)))\n"
     ]
    }
   ],
   "source": [
    "bottom_lev_action_space = {}\n",
    "for hl_agent, agent_names in top_bottom_links.items():\n",
    "    agent_spaces = {agent_name: Discrete(5) for agent_name in agent_names}\n",
    "    bottom_lev_action_space[hl_agent] = GymDict(agent_spaces)\n",
    "bottom_lev_action_space = GymDict(bottom_lev_action_space)\n",
    "\n",
    "print(bottom_lev_action_space)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This means that `top_ppo` sends a separate `Discrete(5)` action to each of the agents in the `bottom` level.\n",
    "\n",
    "We now have all the ingredients to parameterize the whole level.\n",
    "To do so, we use the `LevelConfig` class, to which we pass all the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tame.hierarchy import LevelConfig\n",
    "\n",
    "bottom_level_config = LevelConfig(\n",
    "                name=\"bottom\",\n",
    "                agents=bottom_agents,\n",
    "                uplinks=top_bottom_links, # How the agents in this level are connected to the level above\n",
    "                downlinks=bottom_env_links, # How the agents in this level are connected to the level below\n",
    "                action_frequency=1, # How often the agents in this level act with respect to the level below\n",
    "                trace_type=\"full\", # 'full' or 'reward'. 'full' saves everything, 'reward' saves only the rewards\n",
    "                concat_obs=True, # If True, the actions from the level above are concatenated to the observations from the level below before being sent to the agents in this level\n",
    "                action_space=bottom_lev_action_space, # type: ignore\n",
    "                env=env, # The environment that the agents in this level interact with\n",
    "            )"
   ]
  },
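  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `concat_obs=True`, each bottom agent sees the observation from the level below together with the action it receives from the level above. The hierarchy details printed at the end of this section report a `(13,)` observation space for each bottom agent, i.e. the 12 environment values plus one, which is consistent with the `Discrete(5)` action being appended as a single number. The sketch below illustrates that reading; `concat_observation` is a hypothetical helper, and TAG may order or encode the entries differently.\n",
    "\n",
    "```python\n",
    "def concat_observation(obs_from_below: list[float],\n",
    "                       actions_from_above: dict[str, int]) -> list[float]:\n",
    "    # Append each received discrete action as one float, in sorted sender order\n",
    "    # so that the layout is deterministic.\n",
    "    extra = [float(actions_from_above[sender]) for sender in sorted(actions_from_above)]\n",
    "    return list(obs_from_below) + extra\n",
    "\n",
    "\n",
    "env_obs = [0.1] * 12       # dummy 12-dimensional Simple Spread observation\n",
    "received = {'top_ppo': 3}  # Discrete(5) action sent by the top agent\n",
    "augmented = concat_observation(env_obs, received)\n",
    "print(len(augmented))  # 13\n",
    "print(augmented[-1])   # 3.0\n",
    "```"
   ]
  },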
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created level bottom with agents: ['agent_0', 'agent_1']\n"
     ]
    }
   ],
   "source": [
    "# Finally, we add the level to the hierarchy\n",
    "hierarchy.add_level_config(bottom_level_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note\n",
    "If you want to instantiate the level directly, you can use the `LevelEnv` from `src/tame/hierarchy/level_env.py` and then add it to the hierarchy through the `add_level()` function.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top level\n",
    "\n",
    "Now we can define the top level of the hierarchy.\n",
    "\n",
    "### Agent\n",
    "As before, we define the top agent as a PPO agent through `AgentConfig`.\n",
    "Note that here we don't have to define any communication space, as there is no level above!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Agent top_ppo config:\n",
      "AgentConfig(name='top_ppo',\n",
      "            communication_space=None,\n",
      "            agent_class=<class 'tame.agents.monolithic_ppo.Agent'>,\n",
      "            agent_kwargs={'args': Args(subargs=None,\n",
      "                                       seed=1,\n",
      "                                       cuda=0,\n",
      "                                       exp_name='monolithic_ppo',\n",
      "                                       torch_deterministic=True,\n",
      "                                       save_model=True,\n",
      "                                       total_timesteps=500000,\n",
      "                                       learning_rate=0.00025,\n",
      "                                       gamma=0.99,\n",
      "                                       anneal_lr=True,\n",
      "                                       gae_lambda=0.95,\n",
      "                                       batch_size=2048,\n",
      "                                       num_minibatches=4,\n",
      "                                       update_epochs=4,\n",
      "                                       norm_adv=True,\n",
      "                                       clip_coef=0.2,\n",
      "                                       clip_vloss=True,\n",
      "                                       ent_coef=0.0,\n",
      "                                       vf_coef=0.5,\n",
      "                                       save_all_trace=False,\n",
      "                                       max_grad_norm=0.5,\n",
      "                                       target_kl=None,\n",
      "                                       verbose=True,\n",
      "                                       learn_comm=False,\n",
      "                                       ae_epochs=50)},\n",
      "            device=device(type='cpu'))\n"
     ]
    }
   ],
   "source": [
    "top_agent_config = AgentConfig(\n",
    "            name=top_agent_name,\n",
    "            communication_space=None, # No communication space for the top agent\n",
    "            agent_class=PPO,\n",
    "            agent_kwargs={\"args\": PPOArgs()},\n",
    "        )\n",
    "print(f\"Agent {top_agent_name} config:\")\n",
    "pprint(top_agent_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Level\n",
    "\n",
    "Finally, we define the top level.\n",
    "Here we don't need to define an action space, as this level does not expect actions from above.\n",
    "\n",
    "Moreover, the environment of this level is the level below."
   ]
  },
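  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key idea is that a level exposes the same `reset`/`step` interface as an environment, so the level above can treat it as one. The toy sketch below illustrates that stacking pattern with two hypothetical classes, `CountingEnv` and `ToyLevel`; they are not TAG classes, and the fixed lambda policies stand in for learned agents. Calling `top.step()` with no argument loosely mirrors the `hierarchy.step(action=None)` call used later for training.\n",
    "\n",
    "```python\n",
    "class CountingEnv:\n",
    "    # Minimal environment: the observation is the step count; episodes end after max_steps.\n",
    "    def __init__(self, max_steps: int = 3):\n",
    "        self.max_steps = max_steps\n",
    "        self.t = 0\n",
    "\n",
    "    def reset(self):\n",
    "        self.t = 0\n",
    "        return self.t\n",
    "\n",
    "    def step(self, action):\n",
    "        self.t += 1\n",
    "        return self.t, float(action), self.t >= self.max_steps\n",
    "\n",
    "\n",
    "class ToyLevel:\n",
    "    # Wraps whatever lies below (an environment or another level) and exposes\n",
    "    # the same reset/step interface, so levels can be stacked.\n",
    "    def __init__(self, below, policy):\n",
    "        self.below = below\n",
    "        self.policy = policy  # (observation, action from above) -> action for below\n",
    "        self.last_obs = None\n",
    "\n",
    "    def reset(self):\n",
    "        self.last_obs = self.below.reset()\n",
    "        return self.last_obs\n",
    "\n",
    "    def step(self, action_from_above=None):\n",
    "        action = self.policy(self.last_obs, action_from_above)\n",
    "        self.last_obs, reward, done = self.below.step(action)\n",
    "        return self.last_obs, reward, done\n",
    "\n",
    "\n",
    "bottom = ToyLevel(CountingEnv(), policy=lambda obs, cmd: cmd if cmd is not None else 0)\n",
    "top = ToyLevel(bottom, policy=lambda obs, cmd: 1)  # top level: no action from above\n",
    "\n",
    "top.reset()\n",
    "done, total = False, 0.0\n",
    "while not done:\n",
    "    _, reward, done = top.step()\n",
    "    total += reward\n",
    "print(total)  # 3.0: the top sends 1 each step and the bottom forwards it to the env\n",
    "```"
   ]
  },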
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_level_config = LevelConfig(\n",
    "                name=\"top\",\n",
    "                agents=[top_agent_config],\n",
    "                uplinks=None, # No level above, so no uplinks\n",
    "                downlinks=top_bottom_links,\n",
    "                action_frequency=1,\n",
    "                trace_type=\"full\",\n",
    "                env=hierarchy.levels[-1], # The environment that the agents in this level interact with. In this case the bottom level we defined before!\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created level top with agents: ['top_ppo']\n"
     ]
    }
   ],
   "source": [
    "# Finally, we add the level to the hierarchy\n",
    "hierarchy.add_level_config(top_level_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Whole hierarchy\n",
    "\n",
    "We can see the structure of the hierarchy through the `print_tree()` function, and all the details through `print_hierarchy_details()`."
   ]
  },
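  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition on what `print_tree()` displays, the same tree can be rebuilt from the two link dictionaries alone. The sketch below is a hypothetical standard-library reimplementation written to match the output format shown in the next cell; it is not the code TAG uses.\n",
    "\n",
    "```python\n",
    "def render_tree(links_top_down: list[dict[str, list[str]]]) -> list[str]:\n",
    "    # links_top_down[0] maps the top agents to their children, links_top_down[1]\n",
    "    # maps those children to theirs, and so on; the deepest names are env agents.\n",
    "    lines: list[str] = []\n",
    "    depth_of_env = len(links_top_down)\n",
    "\n",
    "    def walk(name: str, depth: int, prefix: str, is_last: bool) -> None:\n",
    "        connector = '└── ' if is_last else '├── '\n",
    "        label = f'{name} (env)' if depth == depth_of_env else name\n",
    "        lines.append(prefix + connector + label)\n",
    "        if depth < depth_of_env:\n",
    "            children = links_top_down[depth].get(name, [])\n",
    "            child_prefix = prefix + ('    ' if is_last else '│   ')\n",
    "            for i, child in enumerate(children):\n",
    "                walk(child, depth + 1, child_prefix, i == len(children) - 1)\n",
    "\n",
    "    roots = list(links_top_down[0])\n",
    "    for i, root in enumerate(roots):\n",
    "        walk(root, 0, '', i == len(roots) - 1)\n",
    "    return lines\n",
    "\n",
    "\n",
    "for line in render_tree([\n",
    "    {'top_ppo': ['agent_0', 'agent_1']},               # top -> bottom links\n",
    "    {'agent_0': ['agent_0'], 'agent_1': ['agent_1']},  # bottom -> env links\n",
    "]):\n",
    "    print(line)\n",
    "```"
   ]
  },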
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "└── top_ppo\n",
      "    ├── agent_0\n",
      "    │   └── agent_0 (env)\n",
      "    └── agent_1\n",
      "        └── agent_1 (env)\n"
     ]
    }
   ],
   "source": [
    "hierarchy.print_tree()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Hierarchy Details ===\n",
      "\n",
      "Level 0: bottom\n",
      "────────────────────────────────────────────────────────────\n",
      "\n",
      "Level Spaces:\n",
      "  Observation Space: Dict('top_ppo': Dict('agent_0': Box(inf, inf, (12,), float32), 'agent_1': Box(inf, inf, (12,), float32)))\n",
      "  Action Space: Dict('top_ppo': Dict('agent_0': Discrete(5), 'agent_1': Discrete(5)))\n",
      "\n",
      "Uplinks:\n",
      "  top_ppo → ['agent_0', 'agent_1']\n",
      "\n",
      "Downlinks:\n",
      "  agent_0 → ['agent_0']\n",
      "  agent_1 → ['agent_1']\n",
      "\n",
      "Agents:\n",
      "  agent_0:\n",
      "    Observation space: Dict('agent_0': Box(-inf, inf, (13,), float32))\n",
      "    Action space: Dict('agent_0': Discrete(5))\n",
      "    Communication space: Dict('agent_0': Box(inf, inf, (12,), float32))\n",
      "  agent_1:\n",
      "    Observation space: Dict('agent_1': Box(-inf, inf, (13,), float32))\n",
      "    Action space: Dict('agent_1': Discrete(5))\n",
      "    Communication space: Dict('agent_1': Box(inf, inf, (12,), float32))\n",
      "\n",
      "============================================================\n",
      "\n",
      "Level 1: top\n",
      "────────────────────────────────────────────────────────────\n",
      "\n",
      "Level Spaces:\n",
      "  Observation Space: Dict()\n",
      "  Action Space: Dict()\n",
      "\n",
      "Downlinks:\n",
      "  top_ppo → ['agent_0', 'agent_1']\n",
      "\n",
      "Agents:\n",
      "  top_ppo:\n",
      "    Observation space: Dict('agent_0': Box(inf, inf, (12,), float32), 'agent_1': Box(inf, inf, (12,), float32))\n",
      "    Action space: Dict('agent_0': Discrete(5), 'agent_1': Discrete(5))\n",
      "    Communication space: None\n",
      "    Receives from below:\n",
      "      agent_0: Dict('agent_0': Box(inf, inf, (12,), float32))\n",
      "      agent_1: Dict('agent_1': Box(inf, inf, (12,), float32))\n",
      "\n",
      "============================================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "hierarchy.print_hierarchy_details()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training the hierarchy is simple: it is enough to call the `step()` function for the desired number of training steps (`total_timesteps` in the training loop below).\n",
    "This function takes care of stepping all the levels in the hierarchy, environment included, and of training the agents by calling their `update_step` function."
   ]
  },
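  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick back-of-the-envelope check of what the training loop further below does, assuming that each call to `hierarchy.step()` advances the environment by one step (`action_frequency=1` on both levels) and that each PPO agent updates once per `batch_size` collected steps (the default `batch_size=2048` printed in the agent configs). Under these assumptions, 10,000 steps give 100 episodes of `max_cycles=100` steps and four updates per agent, consistent with the four rounds of `Updating models...` messages in the training output.\n",
    "\n",
    "```python\n",
    "total_timesteps = 10_000  # steps run by the training loop\n",
    "max_cycles = 100          # episode length set when creating the env\n",
    "batch_size = 2048         # PPOArgs default shown in the agent configs\n",
    "\n",
    "episodes = total_timesteps // max_cycles\n",
    "updates_per_agent = total_timesteps // batch_size\n",
    "print(episodes, updates_per_agent)  # 100 4\n",
    "```"
   ]
  },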
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval episode: 0 - TS: 100 - Total Reward:\n",
      "agent_0   -109.396377\n",
      "agent_1   -109.396377\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# We evaluate the untrained agent first\n",
    "from tame.utils.utils import evaluate\n",
    "\n",
    "initial_reward = evaluate(agent=hierarchy, env=env, eval_runs=1)\n",
    "initial_reward = initial_reward.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training step::  20%|█▉        | 1995/10000 [00:04<00:19, 421.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agent_0: Updating models...\n",
      "agent_1: Updating models...\n",
      "top_ppo: Updating models...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training step::  40%|████      | 4010/10000 [00:11<00:13, 453.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agent_0: Updating models...\n",
      "agent_1: Updating models...\n",
      "top_ppo: Updating models...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training step::  61%|██████    | 6051/10000 [00:16<00:08, 459.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agent_0: Updating models...\n",
      "agent_1: Updating models...\n",
      "top_ppo: Updating models...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training step::  81%|████████  | 8069/10000 [00:22<00:04, 434.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agent_0: Updating models...\n",
      "agent_1: Updating models...\n",
      "top_ppo: Updating models...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training step:: 100%|██████████| 10000/10000 [00:27<00:00, 365.25it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, {'top_ppo': {'agent_0': {}, 'agent_1': {}}})"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# If you want to train with a new instance of the environment, you can connect it to the hierarchy like this.\n",
    "# Otherwise, training is done on the env that you passed to the bottom level.\n",
    "hierarchy.connect(env)\n",
    "\n",
    "# Training loop\n",
    "done = False\n",
    "episode = 0\n",
    "hierarchy.reset()\n",
    "total_timesteps = 10000\n",
    "for global_step in tqdm(range(total_timesteps), desc=\"Training step:\"):\n",
    "    if done:\n",
    "        done = False\n",
    "        hierarchy.reset()\n",
    "        episode += 1\n",
    "\n",
    "    _, reward, terminated, truncated, _ = hierarchy.step(action=None)\n",
    "\n",
    "    if any(terminated.values()) or any(truncated.values()):\n",
    "        done = True\n",
    "\n",
    "# Final cleanup\n",
    "hierarchy.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval episode: 0 - TS: 100 - Total Reward:\n",
      "agent_0   -85.299738\n",
      "agent_1   -85.299738\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# Reevaluate the agent after training\n",
    "final_reward = evaluate(agent=hierarchy, env=env, eval_runs=1)\n",
    "final_reward = final_reward.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
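    "Hopefully the final reward is higher than the initial reward, meaning that your agent has learned! Keep in mind that with `eval_runs=1` each number comes from a single evaluation episode, so the comparison is noisy."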
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "================ Reward Summary ===============\n",
      "Agent         Initial        Final  Improvement\n",
      "--------------------------------------------\n",
      "agent_0       -109.40       -85.30        22.0%\n",
      "agent_1       -109.40       -85.30        22.0%\n"
     ]
    }
   ],
   "source": [
    "print(\"\\n================ Reward Summary ===============\")\n",
    "print(f\"{'Agent':8} {'Initial':>12} {'Final':>12} {'Improvement':>12}\")\n",
    "print(\"-\" * 44)\n",
    "\n",
    "for agent in initial_reward.index:\n",
    "    init = initial_reward[agent]\n",
    "    fin = final_reward[agent]\n",
    "    pct = ((fin - init) / abs(init)) * 100\n",
    "    print(f\"{agent:8} {init:12.2f} {fin:12.2f} {pct:>11.1f}%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tame",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
