{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OmniSafe Tutorial - Environment Customization From Scratch\n",
    "\n",
    "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n",
    "\n",
    "Documentation: https://omnisafe.readthedocs.io/en/latest/\n",
    "\n",
    "Safety-Gymnasium: https://www.safety-gymnasium.com/\n",
    "\n",
    "[Safety-Gymnasium](https://www.safety-gymnasium.com/) is a highly scalable and customizable Safe Reinforcement Learning library, aiming to deliver a good view of benchmarking Safe Reinforcement Learning (Safe RL) algorithms and a more standardized setting of environments. \n",
    "\n",
    "## 引言\n",
    "\n",
    "本节与[Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb)共同介绍了如何令定制化环境享受OmniSafe提供的全套训练、记录与保存框架。本节侧重于面向安全强化学习初学者介绍如何从零开始创建环境；而[Tutorial 4: Environment Customization from Community](./4.Environment%20Customization%20from%20Community.ipynb)关注如何将社区已有的环境，例如[Gymnasium](https://github.com/Farama-Foundation/Gymnasium)，作出最小适配，以嵌入OmniSafe中。\n",
    "\n",
    "具体而言，本节提供了一个用于定制化环境的最简单模版。通过该模版，您将了解：\n",
    "\n",
    "- 如何在OmniSafe中创建并注册一个环境。\n",
    "- 如何指定创建环境时的定制化参数。\n",
    "- 如何记录环境特定的信息。\n",
    "\n",
    "## 快速安装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 通过pip安装（如果您已经安装，请忽略此段代码）\n",
    "!pip install omnisafe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 通过源代码安装（如果您已经安装，请忽略此段代码）\n",
    "## 克隆仓库\n",
    "!git clone https://github.com/PKU-Alignment/omnisafe\n",
    "%cd omnisafe\n",
    "\n",
    "## 完成安装\n",
    "!pip install -e ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定制化环境最简模版\n",
    "OmniSafe的定制化环境可以仅通过单个文件实现。我们将为您介绍一个最简的定制化环境模版，它将作为您入门的起点。\n",
    "\n",
    "### 定制化环境设计\n",
    "我们将在此细致地介绍一个简易随机环境的设计过程。如果您是强化学习领域的专家或有经验的研究者，可以跳过该模块至[定制化环境嵌入](#定制化环境嵌入)或[Tutorial 4: Environment Customization from Community](./4.Gymnasium%20Customization.ipynb)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的包\n",
    "from __future__ import annotations\n",
    "\n",
    "import random\n",
    "import omnisafe\n",
    "from typing import Any, ClassVar\n",
    "\n",
    "import torch\n",
    "from gymnasium import spaces\n",
    "\n",
    "from omnisafe.envs.core import CMDP, env_register, env_unregister"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义环境类\n",
    "class ExampleEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Example-v0']  # 受支持的任务名称\n",
    "\n",
    "    need_auto_reset_wrapper = True  # 是否需要 `AutoReset` Wrapper\n",
    "    need_time_limit_wrapper = True  # 是否需要 `TimeLimit` Wrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "您需要关注上面这段代码的如下细节：\n",
    "\n",
    "- **任务名称定义** 在 `_support_envs`中提供环境受支持的任务名称。\n",
    "- **Wrapper配置** 通过设定 `need_auto_reset_wrapper`和 `need_time_limit_wrapper` 来定义自动重置和限制时间。\n",
    "- **并行环境数量** 如果您的环境支持向量化并行，请通过 `_num_envs` 参数进行设定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1']  # 受支持的任务名称\n",
    "\n",
    "    need_auto_reset_wrapper = True  # 是否需要 `AutoReset` Wrapper\n",
    "    need_time_limit_wrapper = True  # 是否需要 `TimeLimit` Wrapper\n",
    "\n",
    "    def __init__(self, env_id: str, **kwargs) -> None:\n",
    "        self._count = 0\n",
    "        self._num_envs = 1\n",
    "        self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n",
    "        self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成 `__init__`函数定义。此处需要给出环境的动作空间与观测空间。您需要根据您当前在设计的具体任务来定义。例如：\n",
    "```python\n",
    "if env_id == 'Example-v0':\n",
    "    self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n",
    "    self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n",
    "elif env_id == 'Example-v1':\n",
    "    self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,))\n",
    "    self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "```\n",
    "**请注意：** 由于需要为上层模块提供标准的接口，因此在设计环境时请遵循 `self._observation_space` 以及 `self._action_space` 这两个变量名**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成环境初始化相关函数的定义。`reset` 和 `set_seed` 是OmniSafe环境初始化的标准接口。其中 `reset` 重置环境状态与计步器。 `set_seed` 通过设定随机种子确保实验的可复现性。而带有`@property`装饰器的`max_episode_steps`函数用于为`TimeLimit` Wrapper传递需要限制的每幕最大步数。实现参考如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1']  # 受支持的任务名称\n",
    "\n",
    "    need_auto_reset_wrapper = True  # 是否需要 `AutoReset` Wrapper\n",
    "    need_time_limit_wrapper = True  # 是否需要 `TimeLimit` Wrapper\n",
    "\n",
    "    def __init__(self, env_id: str, **kwargs) -> None:\n",
    "        self._count = 0\n",
    "        self._num_envs = 1\n",
    "        self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n",
    "        self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n",
    "\n",
    "    def set_seed(self, seed: int) -> None:\n",
    "        random.seed(seed)\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        seed: int | None = None,\n",
    "        options: dict[str, Any] | None = None,\n",
    "    ) -> tuple[torch.Tensor, dict]:\n",
    "        if seed is not None:\n",
    "            self.set_seed(seed)\n",
    "        obs = torch.as_tensor(self._observation_space.sample())\n",
    "        self._count = 0\n",
    "        return obs, {}\n",
    "\n",
    "    @property\n",
    "    def max_episode_steps(self) -> None:\n",
    "        \"\"\"The max steps per episode.\"\"\"\n",
    "        return 10"
   ]
  },
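  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reproducibility that `set_seed` aims for can be checked in isolation. The sketch below is illustrative only: it uses Python's standard `random` module, which is the generator that the template's `set_seed` seeds, and the helper name `draw` is hypothetical rather than part of OmniSafe. Note that Gymnasium spaces sample from their own NumPy-based generator, so seeding `random` alone is unlikely to make `self._observation_space.sample()` reproducible; seeding the space as well (for example with its `seed` method) is one option to consider.\n",
    "\n",
    "```python\n",
    "from __future__ import annotations\n",
    "\n",
    "import random\n",
    "\n",
    "\n",
    "def draw(seed: int, n: int = 3) -> list[float]:\n",
    "    # Hypothetical helper: seed the global generator, then draw n values.\n",
    "    random.seed(seed)\n",
    "    return [random.random() for _ in range(n)]\n",
    "\n",
    "\n",
    "first = draw(0)\n",
    "second = draw(0)\n",
    "other = draw(1)\n",
    "\n",
    "print(first == second)  # same seed, same sequence\n",
    "print(first == other)  # a different seed gives a different sequence\n",
    "```"
   ]
  },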
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成功能性函数的定义。`render` 函数用于渲染环境；`close` 函数用于训练结束后的清理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1']  # 受支持的任务名称\n",
    "\n",
    "    need_auto_reset_wrapper = True  # 是否需要 `AutoReset` Wrapper\n",
    "    need_time_limit_wrapper = True  # 是否需要 `TimeLimit` Wrapper\n",
    "\n",
    "    def __init__(self, env_id: str, **kwargs) -> None:\n",
    "        self._count = 0\n",
    "        self._num_envs = 1\n",
    "        self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n",
    "        self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n",
    "\n",
    "    def set_seed(self, seed: int) -> None:\n",
    "        random.seed(seed)\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        seed: int | None = None,\n",
    "        options: dict[str, Any] | None = None,\n",
    "    ) -> tuple[torch.Tensor, dict]:\n",
    "        if seed is not None:\n",
    "            self.set_seed(seed)\n",
    "        obs = torch.as_tensor(self._observation_space.sample())\n",
    "        self._count = 0\n",
    "        return obs, {}\n",
    "\n",
    "    @property\n",
    "    def max_episode_steps(self) -> None:\n",
    "        \"\"\"The max steps per episode.\"\"\"\n",
    "        return 10\n",
    "\n",
    "    def render(self) -> Any:\n",
    "        pass\n",
    "\n",
    "    def close(self) -> None:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成 `step` 函数定义。此处是您定制化环境的核心交互逻辑。您只需按照本例中的数据输入与输出格式进行调整即可。您也可以直接将本例中的随机交互动态更改为您的环境动态。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Example-v0', 'Example-v1']  # 受支持的任务名称\n",
    "    metadata: ClassVar[dict[str, int]] = {}\n",
    "\n",
    "    need_auto_reset_wrapper = True  # 是否需要 `AutoReset` Wrapper\n",
    "    need_time_limit_wrapper = True  # 是否需要 `TimeLimit` Wrapper\n",
    "\n",
    "    def __init__(self, env_id: str, **kwargs) -> None:\n",
    "        self._count = 0\n",
    "        self._num_envs = 1\n",
    "        self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,))\n",
    "        self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,))\n",
    "\n",
    "    def set_seed(self, seed: int) -> None:\n",
    "        random.seed(seed)\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        seed: int | None = None,\n",
    "        options: dict[str, Any] | None = None,\n",
    "    ) -> tuple[torch.Tensor, dict]:\n",
    "        if seed is not None:\n",
    "            self.set_seed(seed)\n",
    "        obs = torch.as_tensor(self._observation_space.sample())\n",
    "        self._count = 0\n",
    "        return obs, {}\n",
    "\n",
    "    @property\n",
    "    def max_episode_steps(self) -> None:\n",
    "        \"\"\"The max steps per episode.\"\"\"\n",
    "        return 10\n",
    "\n",
    "    def render(self) -> Any:\n",
    "        pass\n",
    "\n",
    "    def close(self) -> None:\n",
    "        pass\n",
    "\n",
    "    def step(\n",
    "        self,\n",
    "        action: torch.Tensor,\n",
    "    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:\n",
    "        self._count += 1\n",
    "        obs = torch.as_tensor(self._observation_space.sample())\n",
    "        reward = 2 * torch.as_tensor(random.random())\n",
    "        cost = 2 * torch.as_tensor(random.random())\n",
    "        terminated = torch.as_tensor(random.random() > 0.9)\n",
    "        truncated = torch.as_tensor(self._count > 10)\n",
    "        return obs, reward, cost, terminated, truncated, {'final_observation': obs}"
   ]
  },
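  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between `terminated` and `truncated` concrete, the sketch below reproduces only the counter and flag logic of `step` using the standard library. `toy_step` and `run_episode` are hypothetical names, plain Python values stand in for tensors, and the `0.9` threshold and the `count > 10` check are copied from the example above.\n",
    "\n",
    "```python\n",
    "from __future__ import annotations\n",
    "\n",
    "import random\n",
    "\n",
    "\n",
    "def toy_step(count: int) -> tuple[int, bool, bool]:\n",
    "    # Advance the counter and compute the two end-of-episode flags.\n",
    "    count += 1\n",
    "    terminated = random.random() > 0.9  # the environment itself ends the episode\n",
    "    truncated = count > 10  # the step budget is exhausted\n",
    "    return count, terminated, truncated\n",
    "\n",
    "\n",
    "def run_episode(seed: int) -> tuple[int, bool, bool]:\n",
    "    # Step until either flag is raised and report how the episode ended.\n",
    "    random.seed(seed)\n",
    "    count = 0\n",
    "    while True:\n",
    "        count, terminated, truncated = toy_step(count)\n",
    "        if terminated or truncated:\n",
    "            return count, terminated, truncated\n",
    "\n",
    "\n",
    "steps, terminated, truncated = run_episode(seed=0)\n",
    "print(f'steps={steps}, terminated={terminated}, truncated={truncated}')\n",
    "```\n",
    "\n",
    "Because the check is `count > 10`, truncation in this sketch can only fire on the 11th step."
   ]
  },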
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们试着运行该环境10个时间步，观察交互信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------\n",
      "obs: tensor([-0.5552,  0.2905,  0.0094])\n",
      "reward: 1.6888437271118164\n",
      "cost: 1.5159088373184204\n",
      "terminated: False\n",
      "truncated: False\n",
      "********************\n",
      "--------------------\n",
      "obs: tensor([-0.0635, -0.9966, -0.4681])\n",
      "reward: 0.5178334712982178\n",
      "cost: 1.0225493907928467\n",
      "terminated: False\n",
      "truncated: False\n",
      "********************\n",
      "--------------------\n",
      "obs: tensor([ 0.4385,  0.0678, -0.3470])\n",
      "reward: 1.5675971508026123\n",
      "cost: 0.6066254377365112\n",
      "terminated: False\n",
      "truncated: False\n",
      "********************\n",
      "--------------------\n",
      "obs: tensor([ 0.8278, -0.5252, -0.1799])\n",
      "reward: 1.1667640209197998\n",
      "cost: 1.8162257671356201\n",
      "terminated: False\n",
      "truncated: False\n",
      "********************\n",
      "--------------------\n",
      "obs: tensor([ 0.1086, -0.5711,  0.7751])\n",
      "reward: 0.5636757016181946\n",
      "cost: 1.511608362197876\n",
      "terminated: False\n",
      "truncated: False\n",
      "********************\n",
      "--------------------\n",
      "obs: tensor([-0.3585,  0.8011,  0.2172])\n",
      "reward: 0.5010126829147339\n",
      "cost: 1.8194924592971802\n",
      "terminated: True\n",
      "truncated: False\n",
      "********************\n"
     ]
    }
   ],
   "source": [
    "env = ExampleEnv(env_id='Example-v0')\n",
    "env.reset(seed=0)\n",
    "while True:\n",
    "    action = env.action_space.sample()\n",
    "    obs, reward, cost, terminated, truncated, info = env.step(action)\n",
    "    print('-' * 20)\n",
    "    print(f'obs: {obs}')\n",
    "    print(f'reward: {reward}')\n",
    "    print(f'cost: {cost}')\n",
    "    print(f'terminated: {terminated}')\n",
    "    print(f'truncated: {truncated}')\n",
    "    print('*' * 20)\n",
    "    if terminated or truncated:\n",
    "        break\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "恭喜您！已经成功完成了基础的环境定义，接下来，我们将介绍如何将该环境注册入OmniSafe中，并实现环境参数传递、交互信息记录、算法训练以及结果保存等步骤。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定制化环境嵌入\n",
    "\n",
    "### 快速训练\n",
    "\n",
    "得益于OmniSafe精心设计的注册机制，我们只需一个装饰器即可将这个环境注册到OmniSafe的环境列表中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "@env_register\n",
    "class ExampleEnv(ExampleEnv):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注册同名环境将会报错，这是由于**环境名称冲突**。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@env_register\n",
    "class CustomExampleEnv(ExampleEnv):\n",
    "    example_configs = 1\n",
    "\n",
    "\n",
    "env = CustomExampleEnv('Example-v0')\n",
    "env.example_configs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这时，您需要先对环境手动取消注册。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "@env_unregister\n",
    "class CustomExampleEnv(ExampleEnv):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之后，您就可以重新注册该环境了。在本教程中，我们会同时嵌套 `env_register` 和 `env_unregister` 装饰器，这是为了避免环境重复注册造成报错，即确保该环境只被注册一次，以便用户在阅读本教程时多次修改与运行代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CustomExampleEnv has not been registered yet\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@env_register\n",
    "@env_unregister\n",
    "class CustomExampleEnv(ExampleEnv):\n",
    "    example_configs = 2\n",
    "\n",
    "\n",
    "env = CustomExampleEnv('Example-v0')\n",
    "env.example_configs"
   ]
  },
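  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stacked register/unregister pattern can be illustrated with a small stand-alone registry. This is a sketch of the general decorator technique under our own assumptions, not OmniSafe's implementation: `REGISTRY`, `register`, `unregister`, and `ToyEnv` are hypothetical names, and the duplicate-name check is only modeled on the behavior described above.\n",
    "\n",
    "```python\n",
    "from __future__ import annotations\n",
    "\n",
    "REGISTRY: dict[str, type] = {}  # hypothetical registry: class name -> class\n",
    "\n",
    "\n",
    "def register(cls: type) -> type:\n",
    "    # Refuse to register two classes under the same name.\n",
    "    if cls.__name__ in REGISTRY:\n",
    "        raise ValueError(f'{cls.__name__} is already registered')\n",
    "    REGISTRY[cls.__name__] = cls\n",
    "    return cls\n",
    "\n",
    "\n",
    "def unregister(cls: type) -> type:\n",
    "    # Remove the entry with this class name if present; do nothing otherwise.\n",
    "    REGISTRY.pop(cls.__name__, None)\n",
    "    return cls\n",
    "\n",
    "\n",
    "@register\n",
    "@unregister\n",
    "class ToyEnv:\n",
    "    version = 1\n",
    "\n",
    "\n",
    "# Re-running the definition does not raise, because the inner decorator\n",
    "# clears the old entry before the outer one registers the new class.\n",
    "@register\n",
    "@unregister\n",
    "class ToyEnv:\n",
    "    version = 2\n",
    "\n",
    "\n",
    "print(REGISTRY['ToyEnv'].version)  # 2\n",
    "```\n",
    "\n",
    "Decorators are applied bottom-up, so the unregister step always runs before the register step."
   ]
  },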
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "随后，您可以使用OmniSafe中的算法来训练这个自定义环境。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Example-v0}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-04-56/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-04-56/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">INFO: Start training</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mINFO: Start training\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
       "for Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
       "for Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                        </span>┃<span style=\"font-weight: bold\"> Value                   </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 5.625942230224609       │\n",
       "│ Metrics/EpCost                 │ 6.960921287536621       │\n",
       "│ Metrics/EpLen                  │ 5.0                     │\n",
       "│ Train/Epoch                    │ 0.0                     │\n",
       "│ Train/Entropy                  │ 1.4189385175704956      │\n",
       "│ Train/KL                       │ 0.00020748490351252258  │\n",
       "│ Train/StopIter                 │ 1.0                     │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
       "│ Train/LR                       │ 0.00019999999494757503  │\n",
       "│ Train/PolicyStd                │ 1.0                     │\n",
       "│ TotalEnvSteps                  │ 10.0                    │\n",
       "│ Loss/Loss_pi                   │ -1.4901161193847656e-08 │\n",
       "│ Loss/Loss_pi/Delta             │ -1.4901161193847656e-08 │\n",
       "│ Value/Adv                      │ 1.4901161193847656e-08  │\n",
       "│ Loss/Loss_reward_critic        │ 10.458966255187988      │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 10.458966255187988      │\n",
       "│ Value/reward                   │ -0.015489530749619007   │\n",
       "│ Loss/Loss_cost_critic          │ 19.141571044921875      │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 19.141571044921875      │\n",
       "│ Value/cost                     │ 0.05426764488220215     │\n",
       "│ Time/Total                     │ 0.034796953201293945    │\n",
       "│ Time/Rollout                   │ 0.01762533187866211     │\n",
       "│ Time/Update                    │ 0.01616811752319336     │\n",
       "│ Time/Epoch                     │ 0.03383183479309082     │\n",
       "│ Time/FPS                       │ 295.5858459472656       │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
       "└────────────────────────────────┴─────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                  \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 5.625942230224609       │\n",
       "│ Metrics/EpCost                 │ 6.960921287536621       │\n",
       "│ Metrics/EpLen                  │ 5.0                     │\n",
       "│ Train/Epoch                    │ 0.0                     │\n",
       "│ Train/Entropy                  │ 1.4189385175704956      │\n",
       "│ Train/KL                       │ 0.00020748490351252258  │\n",
       "│ Train/StopIter                 │ 1.0                     │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
       "│ Train/LR                       │ 0.00019999999494757503  │\n",
       "│ Train/PolicyStd                │ 1.0                     │\n",
       "│ TotalEnvSteps                  │ 10.0                    │\n",
       "│ Loss/Loss_pi                   │ -1.4901161193847656e-08 │\n",
       "│ Loss/Loss_pi/Delta             │ -1.4901161193847656e-08 │\n",
       "│ Value/Adv                      │ 1.4901161193847656e-08  │\n",
       "│ Loss/Loss_reward_critic        │ 10.458966255187988      │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 10.458966255187988      │\n",
       "│ Value/reward                   │ -0.015489530749619007   │\n",
       "│ Loss/Loss_cost_critic          │ 19.141571044921875      │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 19.141571044921875      │\n",
       "│ Value/cost                     │ 0.05426764488220215     │\n",
       "│ Time/Total                     │ 0.034796953201293945    │\n",
       "│ Time/Rollout                   │ 0.01762533187866211     │\n",
       "│ Time/Update                    │ 0.01616811752319336     │\n",
       "│ Time/Epoch                     │ 0.03383183479309082     │\n",
       "│ Time/FPS                       │ 295.5858459472656       │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
       "└────────────────────────────────┴─────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">Warning: trajectory cut off when rollout by epoch at </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10.0</span><span style=\"color: #008000; text-decoration-color: #008000\"> steps.</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m10.0\u001b[0m\u001b[32m steps.\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                        </span>┃<span style=\"font-weight: bold\"> Value                  </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 7.8531365394592285     │\n",
       "│ Metrics/EpCost                 │ 7.931504726409912      │\n",
       "│ Metrics/EpLen                  │ 6.666666507720947      │\n",
       "│ Train/Epoch                    │ 1.0                    │\n",
       "│ Train/Entropy                  │ 1.4192386865615845     │\n",
       "│ Train/KL                       │ 6.416345422621816e-05  │\n",
       "│ Train/StopIter                 │ 1.0                    │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
       "│ Train/LR                       │ 9.999999747378752e-05  │\n",
       "│ Train/PolicyStd                │ 1.0003000497817993     │\n",
       "│ TotalEnvSteps                  │ 20.0                   │\n",
       "│ Loss/Loss_pi                   │ -6.258487417198921e-08 │\n",
       "│ Loss/Loss_pi/Delta             │ -4.768371297814156e-08 │\n",
       "│ Value/Adv                      │ 1.341104507446289e-07  │\n",
       "│ Loss/Loss_reward_critic        │ 38.05686950683594      │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 27.59790325164795      │\n",
       "│ Value/reward                   │ -0.008213319815695286  │\n",
       "│ Loss/Loss_cost_critic          │ 23.737285614013672     │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 4.595714569091797      │\n",
       "│ Value/cost                     │ 0.17113244533538818    │\n",
       "│ Time/Total                     │ 0.0776519775390625     │\n",
       "│ Time/Rollout                   │ 0.015673398971557617   │\n",
       "│ Time/Update                    │ 0.011301994323730469   │\n",
       "│ Time/Epoch                     │ 0.027007579803466797   │\n",
       "│ Time/FPS                       │ 370.27294921875        │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
       "└────────────────────────────────┴────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                 \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 7.8531365394592285     │\n",
       "│ Metrics/EpCost                 │ 7.931504726409912      │\n",
       "│ Metrics/EpLen                  │ 6.666666507720947      │\n",
       "│ Train/Epoch                    │ 1.0                    │\n",
       "│ Train/Entropy                  │ 1.4192386865615845     │\n",
       "│ Train/KL                       │ 6.416345422621816e-05  │\n",
       "│ Train/StopIter                 │ 1.0                    │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
       "│ Train/LR                       │ 9.999999747378752e-05  │\n",
       "│ Train/PolicyStd                │ 1.0003000497817993     │\n",
       "│ TotalEnvSteps                  │ 20.0                   │\n",
       "│ Loss/Loss_pi                   │ -6.258487417198921e-08 │\n",
       "│ Loss/Loss_pi/Delta             │ -4.768371297814156e-08 │\n",
       "│ Value/Adv                      │ 1.341104507446289e-07  │\n",
       "│ Loss/Loss_reward_critic        │ 38.05686950683594      │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 27.59790325164795      │\n",
       "│ Value/reward                   │ -0.008213319815695286  │\n",
       "│ Loss/Loss_cost_critic          │ 23.737285614013672     │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 4.595714569091797      │\n",
       "│ Value/cost                     │ 0.17113244533538818    │\n",
       "│ Time/Total                     │ 0.0776519775390625     │\n",
       "│ Time/Rollout                   │ 0.015673398971557617   │\n",
       "│ Time/Update                    │ 0.011301994323730469   │\n",
       "│ Time/Epoch                     │ 0.027007579803466797   │\n",
       "│ Time/FPS                       │ 370.27294921875        │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
       "└────────────────────────────────┴────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">Warning: trajectory cut off when rollout by epoch at </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9.0</span><span style=\"color: #008000; text-decoration-color: #008000\"> steps.</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m9.0\u001b[0m\u001b[32m steps.\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                        </span>┃<span style=\"font-weight: bold\"> Value                   </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 6.297085762023926       │\n",
       "│ Metrics/EpCost                 │ 6.2187700271606445      │\n",
       "│ Metrics/EpLen                  │ 5.25                    │\n",
       "│ Train/Epoch                    │ 2.0                     │\n",
       "│ Train/Entropy                  │ 1.419387698173523       │\n",
       "│ Train/KL                       │ 5.490810053743189e-06   │\n",
       "│ Train/StopIter                 │ 1.0                     │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
       "│ Train/LR                       │ 0.0                     │\n",
       "│ Train/PolicyStd                │ 1.0004494190216064      │\n",
       "│ TotalEnvSteps                  │ 30.0                    │\n",
       "│ Loss/Loss_pi                   │ -1.9073486612342094e-07 │\n",
       "│ Loss/Loss_pi/Delta             │ -1.2814999195143173e-07 │\n",
       "│ Value/Adv                      │ 1.0728835775353218e-07  │\n",
       "│ Loss/Loss_reward_critic        │ 34.77037811279297       │\n",
       "│ Loss/Loss_reward_critic/Delta  │ -3.2864913940429688     │\n",
       "│ Value/reward                   │ 0.014150517992675304    │\n",
       "│ Loss/Loss_cost_critic          │ 27.43436050415039       │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 3.6970748901367188      │\n",
       "│ Value/cost                     │ 0.24021005630493164     │\n",
       "│ Time/Total                     │ 0.12173724174499512     │\n",
       "│ Time/Rollout                   │ 0.01879405975341797     │\n",
       "│ Time/Update                    │ 0.011112689971923828    │\n",
       "│ Time/Epoch                     │ 0.0299375057220459      │\n",
       "│ Time/FPS                       │ 334.039794921875        │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
       "└────────────────────────────────┴─────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                  \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 6.297085762023926       │\n",
       "│ Metrics/EpCost                 │ 6.2187700271606445      │\n",
       "│ Metrics/EpLen                  │ 5.25                    │\n",
       "│ Train/Epoch                    │ 2.0                     │\n",
       "│ Train/Entropy                  │ 1.419387698173523       │\n",
       "│ Train/KL                       │ 5.490810053743189e-06   │\n",
       "│ Train/StopIter                 │ 1.0                     │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                     │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                     │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                     │\n",
       "│ Train/LR                       │ 0.0                     │\n",
       "│ Train/PolicyStd                │ 1.0004494190216064      │\n",
       "│ TotalEnvSteps                  │ 30.0                    │\n",
       "│ Loss/Loss_pi                   │ -1.9073486612342094e-07 │\n",
       "│ Loss/Loss_pi/Delta             │ -1.2814999195143173e-07 │\n",
       "│ Value/Adv                      │ 1.0728835775353218e-07  │\n",
       "│ Loss/Loss_reward_critic        │ 34.77037811279297       │\n",
       "│ Loss/Loss_reward_critic/Delta  │ -3.2864913940429688     │\n",
       "│ Value/reward                   │ 0.014150517992675304    │\n",
       "│ Loss/Loss_cost_critic          │ 27.43436050415039       │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 3.6970748901367188      │\n",
       "│ Value/cost                     │ 0.24021005630493164     │\n",
       "│ Time/Total                     │ 0.12173724174499512     │\n",
       "│ Time/Rollout                   │ 0.01879405975341797     │\n",
       "│ Time/Update                    │ 0.011112689971923828    │\n",
       "│ Time/Epoch                     │ 0.0299375057220459      │\n",
       "│ Time/FPS                       │ 334.039794921875        │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                     │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                     │\n",
       "└────────────────────────────────┴─────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(6.297085762023926, 6.2187700271606445, 5.25)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_cfgs = {\n",
    "    'train_cfgs': {\n",
    "        'total_steps': 30,\n",
    "    },\n",
    "    'algo_cfgs': {\n",
    "        'steps_per_epoch': 10,\n",
    "        'update_iters': 1,\n",
    "    },\n",
    "}\n",
    "agent = omnisafe.Agent('PPOLag', 'Example-v0', custom_cfgs=custom_cfgs)\n",
    "agent.learn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "干得不错！我们已经完成了这个定制化环境的嵌入和训练。接下来，我们将进一步研究如何为环境指定超参数。\n",
    "\n",
    "### 参数设定\n",
    "\n",
    "我们从一个新的示例环境出发，假设这个环境需要传入一个名为 `num_agents` 的参数。我们将展示如何不修改OmniSafe的代码来完成参数设定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NewExampleEnv has not been registered yet\n"
     ]
    }
   ],
   "source": [
    "@env_register\n",
    "@env_unregister\n",
    "class NewExampleEnv(ExampleEnv):  # 创造一个新环境\n",
    "    _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n",
    "    num_agents: ClassVar[int] = 1\n",
    "\n",
    "    def __init__(self, env_id: str, **kwargs) -> None:\n",
    "        super(NewExampleEnv, self).__init__(env_id, **kwargs)\n",
    "        self.num_agents = kwargs.get('num_agents', 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此时，`num_agents` 参数为预设值：`1`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_env = NewExampleEnv('NewExample-v0')\n",
    "new_env.num_agents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面我们将展示如何通过 OmniSafe 的接口对该参数进行修改并训练："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{NewExample-v0}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-05-09/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-09/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_cfgs.update({'env_cfgs': {'num_agents': 2}})\n",
    "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n",
    "agent.agent._env._env.num_agents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "非常好！我们将 `num_agents` 设置为了2。这表示我们在未修改代码的情形下成功实现了超参数设定。\n",
    "\n",
    "### 训练信息记录\n",
    "\n",
    "在运行训练代码时，您可能已经发现 OmniSafe 通过 `Logger` 记录了训练信息，例如：\n",
    "\n",
    "```bash\n",
    "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
    "┃ Metrics                        ┃ Value                   ┃\n",
    "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
    "│ Metrics/EpRet                  │ 2.046875                │\n",
    "│ Metrics/EpCost                 │ 2.89453125              │\n",
    "│ Metrics/EpLen                  │ 3.25                    │\n",
    "│ Train/Epoch                    │ 3.0                     │\n",
    "...\n",
    "```\n",
    "那么我们可否将环境之中的信息输出到日志中呢？答案是肯定的，而且这个过程同样不需要修改OmniSafe的代码。只需要实现两个标准接口：\n",
    "1. 在 `__init__` 函数中，将需要输出的信息添加到`self.env_spec_log`中。\n",
    "2. 实例化 `spec_log` 函数，记录所需的信息。\n",
    "\n",
    "**请注意：** 目前OmniSafe仅支持在每一个epoch结束时记录这些信息，而不支持在每一个step结束时记录。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "@env_register\n",
    "@env_unregister\n",
    "class NewExampleEnv(ExampleEnv):\n",
    "    _support_envs: ClassVar[list[str]] = ['NewExample-v0', 'NewExample-v1']\n",
    "\n",
    "    # 定义需要记录的信息\n",
    "    def __init__(self, env_id: str, **kwargs) -> None:\n",
    "        super(NewExampleEnv, self).__init__(env_id, **kwargs)\n",
    "        self.env_spec_log = {'Env/Success_counts': 0}\n",
    "\n",
    "    # 通过step函数，与环境进行交互\n",
    "    def step(self, action):\n",
    "        obs, reward, cost, terminated, truncated, info = super().step(action)\n",
    "        success = int(reward > cost)\n",
    "        self.env_spec_log['Env/Success_counts'] += success\n",
    "        return obs, reward, cost, terminated, truncated, info\n",
    "\n",
    "    # 在logger中记录环境信息\n",
    "    def spec_log(self, logger) -> dict[str, Any]:\n",
    "        logger.store({'Env/Success_counts': self.env_spec_log['Env/Success_counts']})\n",
    "        self.env_spec_log['Env/Success_counts'] = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们简单训练观察该信息是否被成功记录。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{NewExample-v0}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-05-14/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mNewExample-v0\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-05-14/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">INFO: Start training</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mINFO: Start training\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                        </span>┃<span style=\"font-weight: bold\"> Value                  </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 5.625942230224609      │\n",
       "│ Metrics/EpCost                 │ 6.960921287536621      │\n",
       "│ Metrics/EpLen                  │ 5.0                    │\n",
       "│ Train/Epoch                    │ 0.0                    │\n",
       "│ Train/Entropy                  │ 1.4189385175704956     │\n",
       "│ Train/KL                       │ 0.00024281258811242878 │\n",
       "│ Train/StopIter                 │ 1.0                    │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
       "│ Train/LR                       │ 0.0                    │\n",
       "│ Train/PolicyStd                │ 1.0                    │\n",
       "│ TotalEnvSteps                  │ 10.0                   │\n",
       "│ Loss/Loss_pi                   │ -5.662441182607836e-08 │\n",
       "│ Loss/Loss_pi/Delta             │ -5.662441182607836e-08 │\n",
       "│ Value/Adv                      │ 1.2814999195143173e-07 │\n",
       "│ Loss/Loss_reward_critic        │ 10.477845191955566     │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 10.477845191955566     │\n",
       "│ Value/reward                   │ -0.0091781010851264    │\n",
       "│ Loss/Loss_cost_critic          │ 18.525999069213867     │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 18.525999069213867     │\n",
       "│ Value/cost                     │ 0.14141643047332764    │\n",
       "│ Time/Total                     │ 0.030597209930419922   │\n",
       "│ Time/Rollout                   │ 0.017596960067749023   │\n",
       "│ Time/Update                    │ 0.012219905853271484   │\n",
       "│ Time/Epoch                     │ 0.02985072135925293    │\n",
       "│ Time/FPS                       │ 335.00830078125        │\n",
       "│ Env/Success_counts             │ 1.5                    │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
       "└────────────────────────────────┴────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                 \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ 5.625942230224609      │\n",
       "│ Metrics/EpCost                 │ 6.960921287536621      │\n",
       "│ Metrics/EpLen                  │ 5.0                    │\n",
       "│ Train/Epoch                    │ 0.0                    │\n",
       "│ Train/Entropy                  │ 1.4189385175704956     │\n",
       "│ Train/KL                       │ 0.00024281258811242878 │\n",
       "│ Train/StopIter                 │ 1.0                    │\n",
       "│ Train/PolicyRatio/Mean         │ 1.0                    │\n",
       "│ Train/PolicyRatio/Min          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Max          │ 1.0                    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0                    │\n",
       "│ Train/LR                       │ 0.0                    │\n",
       "│ Train/PolicyStd                │ 1.0                    │\n",
       "│ TotalEnvSteps                  │ 10.0                   │\n",
       "│ Loss/Loss_pi                   │ -5.662441182607836e-08 │\n",
       "│ Loss/Loss_pi/Delta             │ -5.662441182607836e-08 │\n",
       "│ Value/Adv                      │ 1.2814999195143173e-07 │\n",
       "│ Loss/Loss_reward_critic        │ 10.477845191955566     │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 10.477845191955566     │\n",
       "│ Value/reward                   │ -0.0091781010851264    │\n",
       "│ Loss/Loss_cost_critic          │ 18.525999069213867     │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 18.525999069213867     │\n",
       "│ Value/cost                     │ 0.14141643047332764    │\n",
       "│ Time/Total                     │ 0.030597209930419922   │\n",
       "│ Time/Rollout                   │ 0.017596960067749023   │\n",
       "│ Time/Update                    │ 0.012219905853271484   │\n",
       "│ Time/Epoch                     │ 0.02985072135925293    │\n",
       "│ Time/FPS                       │ 335.00830078125        │\n",
       "│ Env/Success_counts             │ 1.5                    │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                    │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                    │\n",
       "└────────────────────────────────┴────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(5.625942230224609, 6.960921287536621, 5.0)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_cfgs.update({'train_cfgs': {'total_steps': 10}})\n",
    "agent = omnisafe.Agent('PPOLag', 'NewExample-v0', custom_cfgs=custom_cfgs)\n",
    "agent.learn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "漂亮！上述代码将在终端输出了环境特化的信息 `Env/Success_counts`。这一过程并不需要对原代码作出改动。\n",
    "\n",
    "## 总结\n",
    "OmniSafe旨在成为安全强化学习的基础软件。我们将持续完善OmniSafe的环境接口标准，使OmniSafe能够适应各种安全强化学习任务，赋能多元安全场景。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "omnisafe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
