{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to make your own Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "In this tutorial you'll learn:\n",
    "- How to implement agents that work with TAG's hierarchy system\n",
    "- The difference between BaseAgent and LevelAgent interfaces\n",
    "- How to create standalone vs hierarchical agents\n",
    "- Best practices for agent implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every agent that is part of a TAG hierarchy must implement the `LevelAgent` interface from `src/tame/hierarchy/base_agent.py`.\n",
    "An example of this is our `PPO` implementation in `src/tame/agents/monolithic_ppo.py`.\n",
    "\n",
    "In this tutorial we discuss the important methods to implement for a well-functioning agent.\n",
    "\n",
    "`src/tame/hierarchy/base_agent.py` provides two types of agent interfaces:\n",
    "- `BaseAgent`: a general interface for any agent that can be run by the `run_experiment.py` script\n",
    "- `LevelAgent`: the interface for agents that are part of a `LevelEnv`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BaseAgent\n",
    "\n",
    "The `BaseAgent` class is a general interface for standalone RL agents that can be run using the `run_experiment.py` script. It requires implementing these core methods:\n",
    "\n",
    "1. `__init__(env: ParallelEnv, args: None | Any = None)`:\n",
    "   - Initializes the agent with a PettingZoo parallel environment\n",
    "   - Takes optional arguments for customization\n",
    "   - Should set up networks, memory buffers, and other agent components\n",
    "\n",
    "2. `save_agent(save_path: str | Path, name: None | str = None)`:\n",
    "   - Saves the agent's state (networks, parameters) to disk\n",
    "   - Models are saved in `{save_path}/models/{name}.pth`\n",
    "   - The `name` parameter should not include the `.pth` extension\n",
    "\n",
    "3. `load_agent(load_path: Path | str, name: str = \"trained_model\") -> bool`:\n",
    "   - Loads the agent's state from disk\n",
    "   - Returns `True` if loading was successful\n",
    "   - Loads from `{load_path}/models/{name}.pth`\n",
    "\n",
    "4. `act(observation: Dict[str, np.ndarray]) -> dict`:\n",
    "   - Takes observations and returns actions for each agent\n",
    "   - Input is a dictionary: `{agent_name: observation}`\n",
    "   - Output is a dictionary: `{agent_name: action}`\n",
    "\n",
    "5. `train(env: Any, log_path: Path | str | None = None, run_name: str | None = None)`:\n",
    "   - Implements the training loop for the agent\n",
    "   - Uses provided environment for training\n",
    "   - Logs results to specified path if provided\n",
    "\n",
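    "The save/load path convention can be captured in a small helper so that `save_agent` and `load_agent` always agree on where a checkpoint lives. This is a minimal sketch: `checkpoint_path` is a hypothetical helper, not part of TAG's API, and silently dropping a stray `.pth` suffix is an assumption added for convenience.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def checkpoint_path(base_path, name):\n",
    "    # Hypothetical helper (not part of TAG): builds {base_path}/models/{name}.pth\n",
    "    if name.endswith('.pth'):\n",
    "        # Drop a '.pth' suffix passed by mistake, since the convention adds it\n",
    "        name = name[: -len('.pth')]\n",
    "    return Path(base_path) / 'models' / f'{name}.pth'\n",
    "\n",
    "\n",
    "print(checkpoint_path('runs/exp1', 'trained_model').as_posix())  # runs/exp1/models/trained_model.pth\n",
    "```\n",
    "\n",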
    "Here's a minimal example:\n",
    "\n",
    "```python\n",
    "from tame.hierarchy import BaseAgent\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "class MyBaseAgent(BaseAgent):\n",
    "    def __init__(self, env, args=None):\n",
    "        self.env = env\n",
    "        self.args = args or {}\n",
    "        \n",
    "        # Initialize networks, buffers etc.\n",
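    "        # Assumes all agents share one observation shape and one Discrete action space\n",
    "        first_agent = env.possible_agents[0]\n",
    "        obs_dim = int(np.prod(env.observation_space(first_agent).shape))\n",
    "        n_actions = env.action_space(first_agent).n\n",
    "        self.policy = torch.nn.Linear(obs_dim, n_actions)\n",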
    "        self.optimizer = torch.optim.Adam(self.policy.parameters())\n",
    "        \n",
    "    def save_agent(self, save_path, name=None):\n",
    "        name = name or \"model\"\n",
    "        save_dir = Path(save_path) / \"models\"\n",
    "        save_dir.mkdir(parents=True, exist_ok=True)\n",
    "        \n",
    "        torch.save({\n",
    "            'policy_state_dict': self.policy.state_dict(),\n",
    "            'optimizer_state_dict': self.optimizer.state_dict(),\n",
    "        }, save_dir / f\"{name}.pth\")\n",
    "        \n",
    "    def load_agent(self, load_path, name=\"trained_model\"):\n",
    "        try:\n",
    "            checkpoint = torch.load(Path(load_path) / \"models\" / f\"{name}.pth\")\n",
    "            self.policy.load_state_dict(checkpoint['policy_state_dict'])\n",
    "            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "            return True\n",
    "        except (OSError, KeyError, RuntimeError):  # missing file or mismatched checkpoint\n",
    "            return False\n",
    "            \n",
    "    def act(self, observation):\n",
    "        actions = {}\n",
    "        for agent, obs in observation.items():\n",
    "            with torch.no_grad():\n",
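    "                # Flatten so the input size matches the Linear layer's obs_dim\n",
    "                obs_t = torch.as_tensor(obs, dtype=torch.float32).flatten()\n",
    "                logits = self.policy(obs_t)\n",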
    "                action = torch.argmax(logits).item()\n",
    "            actions[agent] = action\n",
    "        return actions\n",
    "        \n",
    "    def train(self, env, log_path=None, run_name=None):\n",
    "        # Your training loop here\n",
    "        for episode in range(1000):\n",
    "            obs, _ = env.reset()\n",
    "            done = False\n",
    "            while not done:\n",
    "                actions = self.act(obs)\n",
    "                next_obs, rewards, terminated, truncated, _ = env.step(actions)\n",
    "                # Your update logic here\n",
    "                done = any(terminated.values()) or any(truncated.values())\n",
    "                obs = next_obs\n",
    "```"
   ]
  },
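  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring in a real PettingZoo environment, it can help to exercise an agent's `act` method and the episode loop from `train` against a tiny stand-in that mimics the return values of the parallel API's `reset` and `step`. The sketch below is hypothetical: `StubParallelEnv` and `ConstantAgent` are illustration-only names, the episode length is arbitrary, and the stub skips everything else a real `ParallelEnv` provides (spaces, rendering, agent bookkeeping). It is dependency-free, so it runs without TAG, PettingZoo, or PyTorch installed.\n",
    "\n",
    "```python\n",
    "class StubParallelEnv:\n",
    "    # Hypothetical stand-in: mimics only the shapes returned by ParallelEnv.reset/step\n",
    "    def __init__(self, agents=('agent_0', 'agent_1'), episode_len=5):\n",
    "        self.possible_agents = list(agents)\n",
    "        self.episode_len = episode_len\n",
    "        self._t = 0\n",
    "\n",
    "    def reset(self, seed=None, options=None):\n",
    "        self._t = 0\n",
    "        obs = {a: [0.0] for a in self.possible_agents}\n",
    "        infos = {a: {} for a in self.possible_agents}\n",
    "        return obs, infos\n",
    "\n",
    "    def step(self, actions):\n",
    "        self._t += 1\n",
    "        obs = {a: [float(self._t)] for a in actions}\n",
    "        rewards = {a: 1.0 for a in actions}\n",
    "        terminations = {a: False for a in actions}\n",
    "        # Truncate every agent once the fixed episode length is reached\n",
    "        truncations = {a: self._t >= self.episode_len for a in actions}\n",
    "        infos = {a: {} for a in actions}\n",
    "        return obs, rewards, terminations, truncations, infos\n",
    "\n",
    "\n",
    "class ConstantAgent:\n",
    "    # Hypothetical agent that always picks action 0, enough to test the loop\n",
    "    def act(self, observation):\n",
    "        return {agent: 0 for agent in observation}\n",
    "\n",
    "\n",
    "env = StubParallelEnv()\n",
    "agent = ConstantAgent()\n",
    "obs, _ = env.reset()\n",
    "steps, done = 0, False\n",
    "while not done:\n",
    "    actions = agent.act(obs)\n",
    "    obs, rewards, terminations, truncations, _ = env.step(actions)\n",
    "    done = any(terminations.values()) or any(truncations.values())\n",
    "    steps += 1\n",
    "print(steps)  # 5\n",
    "```\n",
    "\n",
    "A stub like this would need whatever space accessors your agent's constructor reads before it can drive a full `BaseAgent` subclass, but it is enough to check that `act` returns one action per agent and that the loop terminates."
   ]
  },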
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Note\n",
    "`BaseAgent` is primarily useful if you want to use the `run_experiment.py` script. \n",
    "If you're building an agent to be part of a hierarchy, you should implement `LevelAgent` instead."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LevelAgent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `LevelAgent` interface inherits from `BaseAgent` but adds several key methods needed for hierarchical learning. When implementing your own agent, you'll need to:\n",
    "\n",
    "1. Define the following attributes (typically set in `__init__`):\n",
    "   - `action_space`: Dict of action spaces for each agent the LevelAgent controls\n",
    "   - `observation_space`: Dict of observation spaces from which the agent receives input\n",
    "   - `communication_space`: Dict of spaces defining messages sent to higher levels (optional)\n",
    "\n",
    "2. Implement these required methods:\n",
    "   - `__init__(observation_space, action_space, communication_space, device, name, args, torch_compile)`:\n",
    "     - Initializes spaces, networks, and memory buffers\n",
    "     - Note: Unlike BaseAgent, LevelAgent doesn't take an environment in its constructor\n",
    "   \n",
    "   - `update_step(global_step, writer)`:\n",
    "     - Performs a training update during environment steps\n",
    "     - Called automatically by LevelEnv during training\n",
    "     - Use writer (TensorBoard) to log training metrics\n",
    "   \n",
    "   - `store(state, action, reward, done)`:\n",
    "     - Stores transitions in agent memory (e.g., replay buffer)\n",
    "     - Called automatically by LevelEnv after each action\n",
    "   \n",
    "   - `act_train(observation, global_step)`:\n",
    "     - Like `act()` but used during training\n",
    "     - Typically implements exploration (e.g., epsilon-greedy)\n",
    "\n",
    "   - `save_agent(save_path: str | Path, name: None | str = None)`:\n",
    "     - Saves the agent's state (networks, parameters) to disk\n",
    "     - Models are saved in `{save_path}/models/{name}.pth`\n",
    "     - The `name` parameter should not include the `.pth` extension\n",
    "   \n",
    "   - `load_agent(load_path: Path | str, name: str = \"trained_model\") -> bool`:\n",
    "     - Loads the agent's state from disk\n",
    "     - Returns `True` if loading was successful\n",
    "     - Loads from `{load_path}/models/{name}.pth`\n",
    "   \n",
    "   - `act(observation: Dict[str, np.ndarray]) -> dict`:\n",
    "     - Takes observations and returns actions for each agent\n",
    "     - Input is a dictionary: `{agent_name: observation}`\n",
    "     - Output is a dictionary: `{agent_name: action}`\n",
    "   \n",
    "   - `seed(seed)`:\n",
    "     - Sets random seeds for reproducibility\n",
    "\n",
    "3. Optionally override:\n",
    "   - `comm(observation)`:\n",
    "     - Generates messages sent to higher levels\n",
    "     - Default implementation concatenates observations\n",
    "     - Override to implement learned communication\n",
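    "\n",
    "The list above states that the default `comm` concatenates observations. As a dependency-free illustration of that idea, the sketch below flattens each agent's observation into one message. It is an assumption, not TAG's implementation: `concat_comm` is a hypothetical name, observations are treated as flat sequences of numbers, and agents are visited in sorted-name order so the message layout stays stable between steps.\n",
    "\n",
    "```python\n",
    "def concat_comm(observation):\n",
    "    # Hypothetical stand-in for a default comm(): one flat message built by\n",
    "    # concatenating every agent's observation in a fixed (sorted) agent order\n",
    "    message = []\n",
    "    for agent in sorted(observation):\n",
    "        message.extend(float(x) for x in observation[agent])\n",
    "    return message\n",
    "\n",
    "\n",
    "msg = concat_comm({'agent_1': [3, 4], 'agent_0': [1, 2]})\n",
    "print(msg)  # [1.0, 2.0, 3.0, 4.0]\n",
    "```\n",
    "\n",
    "A learned communication scheme would replace this body with a trainable mapping whose output fits `communication_space`.\n",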
    "\n",
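    "The methods above are invoked by `LevelEnv` rather than by your own code. One plausible ordering within a single training step is sketched below. It is an assumption for illustration, not `LevelEnv`'s actual source: `run_level_step`, `RecordingAgent`, and `fake_env_step` are hypothetical names, and the real environment may pass different arguments to `store` or call `update_step` on a different schedule.\n",
    "\n",
    "```python\n",
    "def run_level_step(agent, observation, env_step, global_step, writer=None):\n",
    "    # Hypothetical driver: one training step, calling the hooks in order\n",
    "    actions = agent.act_train(observation, global_step)  # explore while training\n",
    "    next_obs, rewards, dones = env_step(actions)\n",
    "    agent.store(observation, actions, rewards, dones)  # record the transition\n",
    "    agent.update_step(global_step, writer)  # let the agent learn\n",
    "    return next_obs\n",
    "\n",
    "\n",
    "class RecordingAgent:\n",
    "    # Hypothetical agent that only records which hook was called\n",
    "    def __init__(self):\n",
    "        self.calls = []\n",
    "\n",
    "    def act_train(self, observation, global_step):\n",
    "        self.calls.append('act_train')\n",
    "        return {a: 0 for a in observation}\n",
    "\n",
    "    def store(self, state, action=None, reward=None, done=None):\n",
    "        self.calls.append('store')\n",
    "\n",
    "    def update_step(self, global_step, writer):\n",
    "        self.calls.append('update_step')\n",
    "\n",
    "\n",
    "obs = {'agent_0': [0.0]}\n",
    "\n",
    "\n",
    "def fake_env_step(actions):\n",
    "    # Returns (next_obs, rewards, dones) for every agent that acted\n",
    "    return obs, {a: 0.0 for a in actions}, {a: False for a in actions}\n",
    "\n",
    "\n",
    "agent = RecordingAgent()\n",
    "run_level_step(agent, obs, fake_env_step, global_step=0)\n",
    "print(agent.calls)  # ['act_train', 'store', 'update_step']\n",
    "```\n",
    "\n",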
    "Here's a minimal example showing the key components:\n",
    "\n",
    "```python\n",
    "from tame.hierarchy import LevelAgent\n",
    "import numpy as np\n",
    "\n",
    "class MyAgent(LevelAgent):\n",
    "    def __init__(self, observation_space, action_space, communication_space, device, name=\"my_agent\", args=None, torch_compile=False):\n",
    "        self.observation_space = observation_space\n",
    "        self.action_space = action_space\n",
    "        self.communication_space = communication_space\n",
    "        self.device = device\n",
    "        self.name = name\n",
    "        \n",
    "        # Initialize your networks, buffers etc.\n",
    "        self.memory = []\n",
    "        self.batch_size = 64  # illustrative value, tune for your algorithm\n",
    "        self.epsilon = 0.1  # exploration rate used by act_train\n",
    "        \n",
    "    def seed(self, seed):\n",
    "        # Seed NumPy and each action space (distinct seed per agent) for reproducibility\n",
    "        np.random.seed(seed)\n",
    "        for offset, space in enumerate(self.action_space.values()):\n",
    "            space.seed(seed + offset)\n",
    "        \n",
    "    def act_train(self, observation, global_step):\n",
    "        # Add exploration during training (epsilon-greedy)\n",
    "        if np.random.random() < self.epsilon:\n",
    "            return {agent: self.action_space[agent].sample() for agent in observation}\n",
    "        return self.act(observation)\n",
    "        \n",
    "    def act(self, observation):\n",
    "        # Your action selection logic (random placeholder here)\n",
    "        return {agent: self.action_space[agent].sample() for agent in observation}\n",
    "        \n",
    "    def store(self, state, action=None, reward=None, done=None):\n",
    "        # Store transition in memory\n",
    "        self.memory.append((state, action, reward, done))\n",
    "        \n",
    "    def update_step(self, global_step, writer):\n",
    "        # Perform one training update once enough transitions are stored\n",
    "        if len(self.memory) >= self.batch_size:\n",
    "            loss = 0.0  # placeholder: replace with your training logic and actual loss\n",
    "            if writer:\n",
    "                writer.add_scalar(\"loss\", loss, global_step)\n",
    "```"
   ]
  },
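  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A fixed exploration rate is the simplest option. A common refinement is to decay epsilon as `global_step` grows, which is one use for the step counter that `act_train` receives. The linear schedule below is a minimal sketch; `linear_epsilon`, `epsilon_greedy`, and their default start/end/decay values are illustrative assumptions, not settings taken from TAG.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def linear_epsilon(global_step, start=1.0, end=0.05, decay_steps=10_000):\n",
    "    # Linearly anneal from `start` to `end` over `decay_steps`, then hold at `end`\n",
    "    fraction = min(global_step / decay_steps, 1.0)\n",
    "    return (1.0 - fraction) * start + fraction * end\n",
    "\n",
    "\n",
    "def epsilon_greedy(greedy_actions, sample_action, global_step, rng=random):\n",
    "    # With probability epsilon, replace each agent's greedy action by a random one\n",
    "    eps = linear_epsilon(global_step)\n",
    "    return {\n",
    "        agent: sample_action(agent) if rng.random() < eps else action\n",
    "        for agent, action in greedy_actions.items()\n",
    "    }\n",
    "\n",
    "\n",
    "print(linear_epsilon(0))  # 1.0\n",
    "print(linear_epsilon(20_000))  # 0.05\n",
    "```\n",
    "\n",
    "Inside `act_train`, `greedy_actions` would come from `self.act(observation)` and `sample_action` could be `lambda agent: self.action_space[agent].sample()`."
   ]
  },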
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When implementing your own agent, consider:\n",
    "- Using existing implementations like `monolithic_ppo.py` as templates\n",
    "- Testing agents independently before adding them to hierarchies\n",
    "- Adding thorough logging during training for debugging\n",
    "- Implementing proper cleanup in `__del__` if needed\n",
    "\n",
    "The agent will be automatically integrated into training and evaluation flows by the `LevelEnv`, so focus on implementing the core learning algorithms and let the framework handle the hierarchical aspects."
   ]
  },
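  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To follow the advice about testing agents independently, a quick smoke test can check the dictionary contract shared by `act` and `act_train`: one action per observed agent, each valid for that agent's action space. This is a sketch under assumptions: `check_action_dict` is a hypothetical helper, and action spaces are modelled with plain Python `range` objects so the example runs without Gymnasium.\n",
    "\n",
    "```python\n",
    "def check_action_dict(observation, actions, action_spaces):\n",
    "    # Hypothetical smoke test for the {agent_name: action} contract.\n",
    "    # Returns a list of problems; an empty list means the contract holds.\n",
    "    problems = []\n",
    "    missing = set(observation) - set(actions)\n",
    "    extra = set(actions) - set(observation)\n",
    "    if missing:\n",
    "        problems.append(f'missing actions for: {sorted(missing)}')\n",
    "    if extra:\n",
    "        problems.append(f'unexpected actions for: {sorted(extra)}')\n",
    "    for agent, action in actions.items():\n",
    "        space = action_spaces.get(agent)\n",
    "        # `range` stands in for a Discrete space; Gymnasium spaces also support `in`\n",
    "        if space is not None and action not in space:\n",
    "            problems.append(f'invalid action {action!r} for {agent}')\n",
    "    return problems\n",
    "\n",
    "\n",
    "spaces = {'agent_0': range(3), 'agent_1': range(3)}\n",
    "obs = {'agent_0': [0.0], 'agent_1': [1.0]}\n",
    "print(check_action_dict(obs, {'agent_0': 1, 'agent_1': 2}, spaces))  # []\n",
    "print(check_action_dict(obs, {'agent_0': 5}, spaces))\n",
    "```\n",
    "\n",
    "Running a check like this on a few sampled observations before adding the agent to a hierarchy catches key mismatches early, when they are easiest to trace."
   ]
  },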
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Note\n",
    "\n",
    "1. A `LevelAgent` does not need to implement `train`, because its training is performed through `update_step`.\n",
    "   If you implement `train` as well, you can also use a `LevelAgent` standalone with the `run_experiment.py` script.\n",
    "\n",
    "2. In general, a `BaseAgent` is used to instantiate multiple `LevelAgent`s across the hierarchy (see `src/tame/agents/ppo3.py` for an example)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examples\n",
    "You can find examples of how `BaseAgent` and `LevelAgent` are implemented in `src/tame/agents`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Now that you understand how to implement agents in TAG, you can:\n",
    "\n",
    "1. Study example implementations in `src/tame/agents/`:\n",
    "   - `monolithic_ppo.py` for a basic LevelAgent example\n",
    "   - `ppo3.py` for a complete hierarchical agent\n",
    "\n",
    "2. Check the full documentation for:\n",
    "   - Advanced communication mechanisms between levels\n",
    "   - Custom observation/action space configurations\n",
    "   - Performance optimization tips\n",
    "   - Debugging tools and techniques\n",
    "\n",
    "3. Try building your own agent:\n",
    "   - Start with a simple standalone BaseAgent\n",
    "   - Convert it to a LevelAgent\n",
    "   - Test it in increasingly complex hierarchies\n",
    "   - Experiment with different communication strategies\n",
    "\n",
    "4. See [Building your first hierarchy](tutorials/03_build_a_hierarchy.ipynb) to learn how to integrate your agent into a complete hierarchical system.\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
