{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OmniSafe Tutorial - Environment Customization from Community\n",
    "\n",
    "OmniSafe: https://github.com/PKU-Alignment/omnisafe\n",
    "\n",
    "Documentation: https://omnisafe.readthedocs.io/en/latest/\n",
    "\n",
    "Gymnasium: https://github.com/Farama-Foundation/Gymnasium\n",
    "\n",
    "[Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API.\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In this section, we will introduce how to embed an existing environment from the community into OmniSafe. The series of tasks provided by [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) have been widely applied in reinforcement learning. Specifically, this section will use [Pendulum-v1](https://gymnasium.farama.org/environments/classic_control/pendulum/) as an example to show how to embed Gymnasium's tasks into OmniSafe.\n",
    "\n",
    "## Quick Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install via pip (ignore it if you have already installed).\n",
    "!pip install omnisafe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install from source (ignore it if you have already installed).\n",
    "## clone the repo\n",
    "!git clone https://github.com/PKU-Alignment/omnisafe\n",
    "%cd omnisafe\n",
    "\n",
    "## install it\n",
    "!pip install -e ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gymnasium Task Embedding\n",
    "The core part for environment embedding is to provide sufficient static or dynamic information for SafeRL agent interaction and training. This section will detail the variables that must be defined for embedding environments and the corresponding standards. We will first present the entire embedding process in the order of code organization, giving a preliminary understanding. Then, we will review all the codes, summarize, and organize the adaptations you need to make when customizing your environment.\n",
    "\n",
    "### Quick Start\n",
    "First, import all external variables required for this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import all we need\n",
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, ClassVar\n",
    "import gymnasium\n",
    "import torch\n",
    "import numpy as np\n",
    "import omnisafe\n",
    "\n",
    "from omnisafe.envs.core import CMDP, env_register, env_unregister\n",
    "from omnisafe.typing import DEVICE_CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, create a class named `ExampleMuJoCoEnv`, which needs to inherit from `CMDP`. (This is because we want to transform the environment's interaction form into the CMDP paradigm. You can define new abstract classes as needed to implement new paradigms)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExampleMuJoCoEnv(CMDP):\n",
    "    _support_envs: ClassVar[list[str]] = ['Pendulum-v1']  # Supported task names\n",
    "\n",
    "    need_auto_reset_wrapper = True  # Whether `AutoReset` Wrapper is needed\n",
    "    need_time_limit_wrapper = True  # Whether `TimeLimit` Wrapper is needed\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        env_id: str,\n",
    "        num_envs: int = 1,\n",
    "        device: torch.device = DEVICE_CPU,\n",
    "        **kwargs: Any,\n",
    "    ) -> None:\n",
    "        super().__init__(env_id)\n",
    "        self._num_envs = num_envs\n",
    "        # Instantiate the environment object\n",
    "        self._env = gymnasium.make(id=env_id, autoreset=True, **kwargs)\n",
    "        # Specify the action space for initialization by the algorithm layer\n",
    "        self._action_space = self._env.action_space\n",
    "        # Specify the observation space for initialization by the algorithm layer\n",
    "        self._observation_space = self._env.observation_space\n",
    "        # Optional, for GPU acceleration. Default is CPU\n",
    "        self._device = device  # 可选项，使用GPU加速。默认为CPU\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        seed: int | None = None,\n",
    "        options: dict[str, Any] | None = None,\n",
    "    ) -> tuple[torch.Tensor, dict[str, Any]]:\n",
    "        # Reset the environment\n",
    "        obs, info = self._env.reset(seed=seed, options=options)\n",
    "        # Convert the reset observations to a torch tensor.\n",
    "        return (\n",
    "            torch.as_tensor(obs, dtype=torch.float32, device=self._device),\n",
    "            info,\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def max_episode_steps(self) -> int | None:\n",
    "        # Return the maximum number of interaction steps per episode in the environment\n",
    "        return self._env.env.spec.max_episode_steps\n",
    "\n",
    "    def set_seed(self, seed: int) -> None:\n",
    "        # Set the environment's random seed for reproducibility\n",
    "        self.reset(seed=seed)  # 设定环境的随机种子以实现可复现性\n",
    "\n",
    "    def render(self) -> Any:\n",
    "        # Return the image rendered by the environment\n",
    "        return self._env.render()\n",
    "\n",
    "    def close(self) -> None:\n",
    "        # Release the environment instance after training ends\n",
    "        self._env.close()\n",
    "\n",
    "    def step(\n",
    "        self,\n",
    "        action: torch.Tensor,\n",
    "    ) -> tuple[\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        torch.Tensor,\n",
    "        dict[str, Any],\n",
    "    ]:\n",
    "        # Read the dynamic information after interacting with the environment\n",
    "        obs, reward, terminated, truncated, info = self._env.step(\n",
    "            action.detach().cpu().numpy(),\n",
    "        )\n",
    "        # Gymnasium does not explicitly include safety constraints; this is just a placeholder.\n",
    "        cost = np.zeros_like(reward)\n",
    "        # Convert dynamic information into torch tensor.\n",
    "        obs, reward, cost, terminated, truncated = (\n",
    "            torch.as_tensor(x, dtype=torch.float32, device=self._device)\n",
    "            for x in (obs, reward, cost, terminated, truncated)\n",
    "        )\n",
    "        if 'final_observation' in info:\n",
    "            info['final_observation'] = np.array(\n",
    "                [\n",
    "                    array if array is not None else np.zeros(obs.shape[-1])\n",
    "                    for array in info['final_observation']\n",
    "                ],\n",
    "            )\n",
    "            # Convert the last observation recorded in info into a torch tensor.\n",
    "            info['final_observation'] = torch.as_tensor(\n",
    "                info['final_observation'],\n",
    "                dtype=torch.float32,\n",
    "                device=self._device,\n",
    "            )\n",
    "\n",
    "        return obs, reward, cost, terminated, truncated, info"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Regarding the specific meaning of the above code, we have provided detailed annotation explanations. For more detailed explanations, please refer to [Tutorial 3: Environment Customization from Scratch](./3.Environment%20Customization%20from%20Scratch.ipynb). We summarize the key points as follows:\n",
    "\n",
    "- **Static variables needed for OmniSafe initialization**\n",
    "\n",
    "| Static Information | Required | Definition | Type | Example |\n",
    "|:---:|:---:|:---:|:---:|:---:|\n",
    "| `need_auto_reset_wrapper` | Yes | Whether an `AutoReset` Wrapper is needed | `bool` variable | `True` |\n",
    "| `need_time_limit_wrapper` | Yes | Whether a `TimeLimit` Wrapper is needed | `bool` variable | `True` |\n",
    "| `_action_space` | Yes | Action space | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(2,)` |\n",
    "| `_observation_space` | Yes | Observation space | `gymnasium.space.Box` | `Box(low=-1.0, high=1.0, shape=(3,)` |\n",
    "| `max_episode_steps` | Yes | The maximum number of interaction steps per episode in the environment | Function with `@property` decorator, returning a variable of type `int` or `None` | Refer to the code block above |\n",
    "| `_num_envs` | No | Number of parallel environments | `int` variable | 5 |\n",
    "| `_device` | No | Torch computing device | `torch.device` variable | `DEVICE_CPU` |\n",
    "\n",
    "- **Dynamic variables required by the environment for OmniSafe**\n",
    "\n",
    "OmniSafe's agents mainly interact dynamically with the environment through the `reset` and `step` functions. You need to ensure that the return type, number, and order of your customized environment match the examples above, more specifically:\n",
    "\n",
    "| Dynamic Information | Type | Number | Order |\n",
    "|:---:|:---:|:---:|:---:|\n",
    "| `step` | `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]` | 6 | `obs`, `reward`, `cost`, `terminated`, `truncated`, `info` |\n",
    "| `reset` | `tuple[torch.Tensor, dict[str, Any]]` | 2 | `obs`, `info` |\n",
    "\n",
    "- **Precautions**\n",
    "\n",
    "1. Although `_num_envs` and `_device` are not mandatory, please retain the input interface for these two parameters in the `__init__` function.\n",
    "2. `_num_envs` is an advanced parameter for instantiating multiple environments for parallel sampling, representing the number of environments instantiated. If your customized environment also supports specifying the parallel number, please specify it through `_num_envs` instead of defining a new interface."
   ]
  },
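  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `cost` returned by `step` above is only a zero placeholder. As a minimal sketch of how a real constraint might be plugged in, the hypothetical helper `angular_velocity_cost` below returns a cost of 1 whenever the pendulum spins faster than an assumed threshold. It relies on the `Pendulum-v1` observation layout `[cos(theta), sin(theta), angular velocity]`; the helper name and the threshold of 4.0 are illustrative assumptions, not part of OmniSafe or Gymnasium.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def angular_velocity_cost(obs: np.ndarray, threshold: float = 4.0) -> np.float32:\n",
    "    # Hypothetical helper: 1.0 if |angular velocity| exceeds `threshold`, else 0.0.\n",
    "    # Pendulum-v1 observations are [cos(theta), sin(theta), angular velocity].\n",
    "    angular_velocity = float(obs[2])\n",
    "    return np.float32(abs(angular_velocity) > threshold)\n",
    "\n",
    "\n",
    "print(angular_velocity_cost(np.array([1.0, 0.0, 5.0])))   # above the threshold\n",
    "print(angular_velocity_cost(np.array([1.0, 0.0, -1.0])))  # below the threshold\n",
    "```\n",
    "\n",
    "Inside `step`, the placeholder line could then read `cost = angular_velocity_cost(obs)`, evaluated before `obs` is converted to a tensor."
   ]
  },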
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Subsequently, by registering the above environment into OmniSafe with the registration decorator `@env_register`, you can complete the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ExampleMuJoCoEnv has not been registered yet\n",
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-09-14/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-14/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">INFO: Start training</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mINFO: Start training\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
       "for Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/safepo/anaconda3/envs/dev-env/lib/python3.8/site-packages/rich/live.py:231: UserWarning: install \"ipywidgets\"\n",
       "for Jupyter support\n",
       "  warnings.warn('install \"ipywidgets\" for Jupyter support')\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">Warning: trajectory cut off when rollout by epoch at </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">200.0</span><span style=\"color: #008000; text-decoration-color: #008000\"> steps.</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                        </span>┃<span style=\"font-weight: bold\"> Value                 </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ -1616.242431640625    │\n",
       "│ Metrics/EpCost                 │ 0.0                   │\n",
       "│ Metrics/EpLen                  │ 200.0                 │\n",
       "│ Train/Epoch                    │ 0.0                   │\n",
       "│ Train/Entropy                  │ 1.4185898303985596    │\n",
       "│ Train/KL                       │ 0.0007516025798395276 │\n",
       "│ Train/StopIter                 │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean         │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Min          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Max          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0075334208086133    │\n",
       "│ Train/LR                       │ 0.0                   │\n",
       "│ Train/PolicyStd                │ 0.9996514320373535    │\n",
       "│ TotalEnvSteps                  │ 200.0                 │\n",
       "│ Loss/Loss_pi                   │ 0.08751548826694489   │\n",
       "│ Loss/Loss_pi/Delta             │ 0.08751548826694489   │\n",
       "│ Value/Adv                      │ -0.398242324590683    │\n",
       "│ Loss/Loss_reward_critic        │ 16605.1796875         │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 16605.1796875         │\n",
       "│ Value/reward                   │ 0.0049050007946789265 │\n",
       "│ Loss/Loss_cost_critic          │ 0.052194785326719284  │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 0.052194785326719284  │\n",
       "│ Value/cost                     │ 0.07966174930334091   │\n",
       "│ Time/Total                     │ 0.21084904670715332   │\n",
       "│ Time/Rollout                   │ 0.17566156387329102   │\n",
       "│ Time/Update                    │ 0.03439140319824219   │\n",
       "│ Time/Epoch                     │ 0.21008920669555664   │\n",
       "│ Time/FPS                       │ 951.9786987304688     │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                   │\n",
       "└────────────────────────────────┴───────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                  │ -1616.242431640625    │\n",
       "│ Metrics/EpCost                 │ 0.0                   │\n",
       "│ Metrics/EpLen                  │ 200.0                 │\n",
       "│ Train/Epoch                    │ 0.0                   │\n",
       "│ Train/Entropy                  │ 1.4185898303985596    │\n",
       "│ Train/KL                       │ 0.0007516025798395276 │\n",
       "│ Train/StopIter                 │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean         │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Min          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Max          │ 0.9966228604316711    │\n",
       "│ Train/PolicyRatio/Std          │ 0.0075334208086133    │\n",
       "│ Train/LR                       │ 0.0                   │\n",
       "│ Train/PolicyStd                │ 0.9996514320373535    │\n",
       "│ TotalEnvSteps                  │ 200.0                 │\n",
       "│ Loss/Loss_pi                   │ 0.08751548826694489   │\n",
       "│ Loss/Loss_pi/Delta             │ 0.08751548826694489   │\n",
       "│ Value/Adv                      │ -0.398242324590683    │\n",
       "│ Loss/Loss_reward_critic        │ 16605.1796875         │\n",
       "│ Loss/Loss_reward_critic/Delta  │ 16605.1796875         │\n",
       "│ Value/reward                   │ 0.0049050007946789265 │\n",
       "│ Loss/Loss_cost_critic          │ 0.052194785326719284  │\n",
       "│ Loss/Loss_cost_critic/Delta    │ 0.052194785326719284  │\n",
       "│ Value/cost                     │ 0.07966174930334091   │\n",
       "│ Time/Total                     │ 0.21084904670715332   │\n",
       "│ Time/Rollout                   │ 0.17566156387329102   │\n",
       "│ Time/Update                    │ 0.03439140319824219   │\n",
       "│ Time/Epoch                     │ 0.21008920669555664   │\n",
       "│ Time/FPS                       │ 951.9786987304688     │\n",
       "│ Metrics/LagrangeMultiplier/Mea │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std │ 0.0                   │\n",
       "└────────────────────────────────┴───────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(-1616.242431640625, 0.0, 200.0)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@env_register\n",
    "@env_unregister  # Avoid the \"environment has been registered\" error when rerunning cells\n",
    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
    "    pass\n",
    "\n",
    "\n",
    "custom_cfgs = {\n",
    "    'train_cfgs': {\n",
    "        'total_steps': 200,\n",
    "    },\n",
    "    'algo_cfgs': {\n",
    "        'steps_per_epoch': 200,\n",
    "        'update_iters': 1,\n",
    "    },\n",
    "}\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.learn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Advanced Usage\n",
    "In addition to the aforementioned methods, environments from the community can also take advantage of OmniSafe's capabilities for specifying environment-specific parameters and recording information. We will detail the specific operational methods."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Specifying Specific Parameters\n",
    "\n",
    "Taking `Pendulum-v1` as an example, according to the Gymnasium documentation, a specific parameter `g`, which stands for gravitational acceleration, can be specified when creating this task. Let's first take a look at its default value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-09-17/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-17/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "10.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@env_register\n",
    "@env_unregister  # Avoid the \"environment has been registered\" error when rerunning cells\n",
    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
    "    def __getattr__(self, name: str) -> Any:\n",
    "        \"\"\"Get the attribute of the environment.\"\"\"\n",
    "        if name.startswith('_'):\n",
    "            raise AttributeError(f'attempted to get missing private attribute {name}')\n",
    "        return getattr(self._env, name)\n",
    "\n",
    "\n",
    "custom_cfgs = {\n",
    "    'train_cfgs': {\n",
    "        'total_steps': 200,\n",
    "    },\n",
    "    'algo_cfgs': {\n",
    "        'steps_per_epoch': 200,\n",
    "        'update_iters': 1,\n",
    "    },\n",
    "}\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.agent._env._env.g"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We implemented a magic function named `__get_attr__` to call and view specific parameters in the currently instantiated environment. In this case, we find that the default value of the gravitational acceleration `g` is 10.0.\n",
    "\n",
    "By consulting the Gymnasium documentation, this parameter can be specified during the process of creating an environment with the `gymnasium.make` function. Does OmniSafe support the passing of specific parameters for customized environments? The answer is yes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-09-20/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-20/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "9.8"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_cfgs.update({'env_cfgs': {'g': 9.8}})\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.agent._env._env.g"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nice! The value of gravitational acceleration has been changed to 9.8. We just need to operate on `env_cfgs`, specifying the key and value of the parameter to be customized, to achieve the passing of specific parameters for the environment."
   ]
  },
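  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two mechanisms used above can be illustrated without OmniSafe or Gymnasium. The sketch below uses two hypothetical stand-in classes, `InnerEnv` and `Wrapper`, which are not part of either library. It assumes, as the behavior above suggests, that the entries of `env_cfgs` reach the environment's `__init__` as keyword arguments, and it shows how `__getattr__` delegates public attribute lookups to the wrapped object.\n",
    "\n",
    "```python\n",
    "from typing import Any\n",
    "\n",
    "\n",
    "class InnerEnv:\n",
    "    # Hypothetical stand-in for the environment built by `gymnasium.make`.\n",
    "    def __init__(self, g: float = 10.0) -> None:\n",
    "        self.g = g\n",
    "\n",
    "\n",
    "class Wrapper:\n",
    "    # Hypothetical stand-in for the registered environment class.\n",
    "    def __init__(self, **kwargs: Any) -> None:\n",
    "        # Assumption: entries of `env_cfgs` arrive here as keyword arguments.\n",
    "        self._env = InnerEnv(**kwargs)\n",
    "\n",
    "    def __getattr__(self, name: str) -> Any:\n",
    "        # Python calls this only when normal attribute lookup fails.\n",
    "        if name.startswith('_'):\n",
    "            raise AttributeError(f'attempted to get missing private attribute {name}')\n",
    "        return getattr(self._env, name)\n",
    "\n",
    "\n",
    "env_cfgs = {'g': 9.8}\n",
    "print(Wrapper().g)            # default taken from InnerEnv\n",
    "print(Wrapper(**env_cfgs).g)  # value overridden through env_cfgs\n",
    "```\n",
    "\n",
    "Rejecting names that start with an underscore keeps private attributes from being forwarded and avoids infinite recursion if `_env` itself has not been set yet."
   ]
  },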
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Information Recording\n",
    "\n",
    "The `Pendulum-v1` task contains many specific dynamic pieces of information. We will introduce how to record these pieces of information using OmniSafe's `Logger`. Specifically, we will explain using the maximum and cumulative values of the angular velocity `angular_velocity` per episode as examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PPOLag.yaml from /home/safepo/dev-env/omnisafe_zjy/omnisafe/utils/../configs/on-policy/PPOLag.yaml\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">Logging data to .</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/runs/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">PPOLag-</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">{Pendulum-v1}</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">/seed-000-2024-04-09-15-09-23/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff; font-weight: bold\">progress.csv</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36mLogging data to .\u001b[0m\u001b[1;35m/runs/\u001b[0m\u001b[1;95mPPOLag-\u001b[0m\u001b[1;36m{\u001b[0m\u001b[1;36mPendulum-v1\u001b[0m\u001b[1;36m}\u001b[0m\u001b[1;35m/seed-000-2024-04-09-15-09-23/\u001b[0m\u001b[1;95mprogress.csv\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">Save with config in config.json</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;33mSave with config in config.json\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">INFO: Start training</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mINFO: Start training\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000\">Warning: trajectory cut off when rollout by epoch at </span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">200.0</span><span style=\"color: #008000; text-decoration-color: #008000\"> steps.</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[32mWarning: trajectory cut off when rollout by epoch at \u001b[0m\u001b[1;36m200.0\u001b[0m\u001b[32m steps.\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Metrics                         </span>┃<span style=\"font-weight: bold\"> Value                 </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                   │ -1607.6717529296875   │\n",
       "│ Metrics/EpCost                  │ 0.0                   │\n",
       "│ Metrics/EpLen                   │ 200.0                 │\n",
       "│ Train/Epoch                     │ 0.0                   │\n",
       "│ Train/Entropy                   │ 1.418560266494751     │\n",
       "│ Train/KL                        │ 0.0005777678452432156 │\n",
       "│ Train/StopIter                  │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean          │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Min           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Max           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Std           │ 0.005412393249571323  │\n",
       "│ Train/LR                        │ 0.0                   │\n",
       "│ Train/PolicyStd                 │ 0.9996219277381897    │\n",
       "│ TotalEnvSteps                   │ 200.0                 │\n",
       "│ Loss/Loss_pi                    │ 0.09192709624767303   │\n",
       "│ Loss/Loss_pi/Delta              │ 0.09192709624767303   │\n",
       "│ Value/Adv                       │ -0.4177907109260559   │\n",
       "│ Loss/Loss_reward_critic         │ 16393.2265625         │\n",
       "│ Loss/Loss_reward_critic/Delta   │ 16393.2265625         │\n",
       "│ Value/reward                    │ 0.00719139538705349   │\n",
       "│ Loss/Loss_cost_critic           │ 0.05219484493136406   │\n",
       "│ Loss/Loss_cost_critic/Delta     │ 0.05219484493136406   │\n",
       "│ Value/cost                      │ 0.07949987053871155   │\n",
       "│ Time/Total                      │ 0.20513606071472168   │\n",
       "│ Time/Rollout                    │ 0.17486166954040527   │\n",
       "│ Time/Update                     │ 0.029330968856811523  │\n",
       "│ Time/Epoch                      │ 0.20422101020812988   │\n",
       "│ Time/FPS                        │ 979.3323364257812     │\n",
       "│ Env/Max_angular_velocity        │ 2.9994523525238037    │\n",
       "│ Env/Cumulative_angular_velocity │ 1.0643725395202637    │\n",
       "│ Metrics/LagrangeMultiplier/Mean │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std  │ 0.0                   │\n",
       "└─────────────────────────────────┴───────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mMetrics                        \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mValue                \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ Metrics/EpRet                   │ -1607.6717529296875   │\n",
       "│ Metrics/EpCost                  │ 0.0                   │\n",
       "│ Metrics/EpLen                   │ 200.0                 │\n",
       "│ Train/Epoch                     │ 0.0                   │\n",
       "│ Train/Entropy                   │ 1.418560266494751     │\n",
       "│ Train/KL                        │ 0.0005777678452432156 │\n",
       "│ Train/StopIter                  │ 1.0                   │\n",
       "│ Train/PolicyRatio/Mean          │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Min           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Max           │ 0.9981198310852051    │\n",
       "│ Train/PolicyRatio/Std           │ 0.005412393249571323  │\n",
       "│ Train/LR                        │ 0.0                   │\n",
       "│ Train/PolicyStd                 │ 0.9996219277381897    │\n",
       "│ TotalEnvSteps                   │ 200.0                 │\n",
       "│ Loss/Loss_pi                    │ 0.09192709624767303   │\n",
       "│ Loss/Loss_pi/Delta              │ 0.09192709624767303   │\n",
       "│ Value/Adv                       │ -0.4177907109260559   │\n",
       "│ Loss/Loss_reward_critic         │ 16393.2265625         │\n",
       "│ Loss/Loss_reward_critic/Delta   │ 16393.2265625         │\n",
       "│ Value/reward                    │ 0.00719139538705349   │\n",
       "│ Loss/Loss_cost_critic           │ 0.05219484493136406   │\n",
       "│ Loss/Loss_cost_critic/Delta     │ 0.05219484493136406   │\n",
       "│ Value/cost                      │ 0.07949987053871155   │\n",
       "│ Time/Total                      │ 0.20513606071472168   │\n",
       "│ Time/Rollout                    │ 0.17486166954040527   │\n",
       "│ Time/Update                     │ 0.029330968856811523  │\n",
       "│ Time/Epoch                      │ 0.20422101020812988   │\n",
       "│ Time/FPS                        │ 979.3323364257812     │\n",
       "│ Env/Max_angular_velocity        │ 2.9994523525238037    │\n",
       "│ Env/Cumulative_angular_velocity │ 1.0643725395202637    │\n",
       "│ Metrics/LagrangeMultiplier/Mean │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Min  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Max  │ 0.0                   │\n",
       "│ Metrics/LagrangeMultiplier/Std  │ 0.0                   │\n",
       "└─────────────────────────────────┴───────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(-1607.6717529296875, 0.0, 200.0)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from omnisafe.common.logger import Logger\n",
    "\n",
    "\n",
    "@env_register\n",
    "@env_unregister  # Avoid the \"environment has been registered\" error when rerunning cells\n",
    "class ExampleMuJoCoEnv(ExampleMuJoCoEnv):\n",
    "\n",
    "    def __init__(self, env_id, num_envs, device, **kwargs):\n",
    "        super().__init__(env_id, num_envs, device, **kwargs)\n",
    "        self.env_spec_log = {\n",
    "            'Env/Max_angular_velocity': 0.0,\n",
    "            'Env/Cumulative_angular_velocity': 0.0,\n",
    "        }  # Re-declare the environment-specific log entries in the constructor\n",
    "\n",
    "    def spec_log(self, logger: Logger) -> None:\n",
    "        for key, value in self.env_spec_log.items():\n",
    "            logger.store({key: value})\n",
    "            self.env_spec_log[key] = 0.0  # Reset the entry after it has been stored\n",
    "\n",
    "    def step(self, action):\n",
    "        obs, reward, cost, terminated, truncated, info = super().step(action=action)\n",
    "        # The last component of the Pendulum-v1 observation is the angular velocity\n",
    "        angular_velocity = obs[-1].item()\n",
    "        self.env_spec_log['Env/Max_angular_velocity'] = max(\n",
    "            self.env_spec_log['Env/Max_angular_velocity'], angular_velocity\n",
    "        )\n",
    "        self.env_spec_log['Env/Cumulative_angular_velocity'] += angular_velocity\n",
    "        return obs, reward, cost, terminated, truncated, info\n",
    "\n",
    "\n",
    "agent = omnisafe.Agent('PPOLag', 'Pendulum-v1', custom_cfgs=custom_cfgs)\n",
    "agent.learn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! The environment-specific metrics `Env/Max_angular_velocity` and `Env/Cumulative_angular_velocity` now appear in the `Logger` output. Note that we achieved this without modifying any of OmniSafe's source code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "Using Gymnasium's classic `Pendulum-v1` environment as an example, this section showed the interface adaptations and the information an environment must provide in order to be embedded into OmniSafe. We hope this tutorial helps you embed your own customized environment. If you would like your environment to be supported as one of the official OmniSafe environments, or if you run into difficulties while customizing an environment, you are welcome to reach us through [Issues](https://github.com/PKU-Alignment/omnisafe/issues), [Pull Requests](https://github.com/PKU-Alignment/omnisafe/pulls), and [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "omnisafe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
