{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"LunarLanderContinuous_Public.ipynb","provenance":[],"collapsed_sections":["SyPl97Rht-E3","f75f934c-6921-43aa-8389-6df4b993eca4","MOlw1cZpnJf5","NVcUoeRUnJf6","kyKnz5OUnJf7"],"authorship_tag":"ABX9TyNpYBa8/sXaUNJgCKr614LR"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["## Imports"],"metadata":{"id":"SyPl97Rht-E3"}},{"cell_type":"code","source":["!apt install python-opengl\n","!apt install ffmpeg\n","!apt install xvfb\n","!pip install pyvirtualdisplay\n","from pyvirtualdisplay import Display\n","\n","# Start virtual display\n","dis = Display(visible=0, size=(600, 400))\n","dis.start()"],"metadata":{"id":"ponykHh19KBb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!pip3 install gym\n","!pip3 install box2d-py\n","!pip3 install gym[Box_2D]\n","!pip3 install pygame\n","\n","import random\n","from typing import Dict, List, Tuple\n","\n","from collections import deque, namedtuple\n","import gym\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","from IPython.display import clear_output\n","from torch.distributions import Normal"],"metadata":{"id":"0FZXYLBxRmj-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"f75f934c-6921-43aa-8389-6df4b993eca4"},"source":["## Instantiate the Environment\n","\n","Initialize the environment."]},{"cell_type":"code","source":["#@title Custom LunarLander Environment\n","import math\n","import sys\n","from typing import Optional\n","\n","import Box2D\n","import numpy as np\n","from Box2D.b2 import (\n","    circleShape,\n","    contactListener,\n","    edgeShape,\n","    fixtureDef,\n","    polygonShape,\n","    revoluteJointDef,\n",")\n","\n","import gym\n","from gym import error, spaces\n","from gym.utils import EzPickle, seeding\n","\n","FPS = 50\n","SCALE = 30.0  # affects how fast-paced the game is, forces should be adjusted as well\n","\n","MAIN_ENGINE_POWER = 13.0\n","SIDE_ENGINE_POWER = 2.0\n","\n","INITIAL_RANDOM = 500.0  # Set 1500 to make game harder\n","\n","LANDER_POLY = [(-14, +17), (-17, 0), (-17, -10), (+17, -10), (+17, 0), (+14, +17)]\n","LEG_AWAY = 20\n","LEG_DOWN = 18\n","LEG_W, LEG_H = 2, 8\n","LEG_SPRING_TORQUE = 40\n","\n","SIDE_ENGINE_HEIGHT = 14.0\n","SIDE_ENGINE_AWAY = 12.0\n","\n","VIEWPORT_W = 600\n","VIEWPORT_H = 400\n","\n","\n","class ContactDetector(contactListener):\n","    def __init__(self, env):\n","        contactListener.__init__(self)\n","        self.env = env\n","\n","    def BeginContact(self, contact):\n","        if (\n","            self.env.lander == contact.fixtureA.body\n","            or self.env.lander == contact.fixtureB.body\n","        ):\n","            self.env.game_over = True\n","        for i in range(2):\n","            if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:\n","                self.env.legs[i].ground_contact = True\n","\n","    def EndContact(self, contact):\n","        for i in range(2):\n","            if self.env.legs[i] in [contact.fixtureA.body, contact.fixtureB.body]:\n","                self.env.legs[i].ground_contact = False\n","\n","\n","class LunarLander(gym.Env, EzPickle):\n","    \"\"\"\n","    ### Description\n","    This environment is a classic rocket trajectory optimization problem.\n","    According to Pontryagin's maximum principle, it is optimal to fire the\n","    engine at full throttle or turn it off. This is the reason why this\n","    environment has discrete actions: engine on or off.\n","\n","    There are two environment versions: discrete or continuous.\n","    The landing pad is always at coordinates (0,0). The coordinates are the\n","    first two numbers in the state vector.\n","    Landing outside of the landing pad is possible. Fuel is infinite, so an agent\n","    can learn to fly and then land on its first attempt.\n","\n","    ### Action Space\n","    There are four discrete actions available: do nothing, fire left\n","    orientation engine, fire main engine, fire right orientation engine.\n","\n","    ### Observation Space\n","    There are 8 dimensions: the relative coordinates of the lander in `x` & `y`, its linear\n","    velocities in `x` & `y`, its angle, its angular velocity, two booleans\n","    that represent whether each leg is in contact with the ground or not\n","\n","    ### Rewards\n","    Reward for moving from the top of the screen to the landing pad and coming\n","    to rest is about 100-140 points.\n","    If the lander moves away from the landing pad, it loses reward.\n","    If the lander crashes, it receives an additional -100 points. If it comes\n","    to rest, it receives an additional +100 points. Each leg with ground\n","    contact is +10 points.\n","    Firing the main engine is -0.3 points each frame. Firing the side engine\n","    is -0.03 points each frame. Solved is 200 points.\n","\n","    ### Starting State\n","    The lander starts at the top center of the viewport with a random initial\n","    force applied to its center of mass.\n","\n","    ### Episode Termination\n","    The episode finishes if:\n","    1) the lander crashes (the lander body gets in contact with the moon);\n","    2) the lander gets outside of the viewport (`x` coordinate is greater than 1);\n","    3) the lander is not awake. From the [Box2D docs](https://box2d.org/documentation/md__d_1__git_hub_box2d_docs_dynamics.html#autotoc_md61),\n","        a body which is not awake is a body which doesn't move and doesn't\n","        collide with any other body:\n","    > When Box2D determines that a body (or group of bodies) has come to rest,\n","    > the body enters a sleep state which has very little CPU overhead. If a\n","    > body is awake and collides with a sleeping body, then the sleeping body\n","    > wakes up. Bodies will also wake up if a joint or contact attached to\n","    > them is destroyed.\n","\n","    ### Arguments\n","    To use to the _continuous_ environment, you need to specify the\n","    `continuous=True` argument like below:\n","    \"\"\"\n","\n","    metadata = {\"render_modes\": [\"human\", \"rgb_array\"], \"render_fps\": FPS}\n","\n","    def __init__(self,\n","                 continuous: bool = False, \n","                 main_engine_power = MAIN_ENGINE_POWER,\n","                 side_engine_power = SIDE_ENGINE_POWER):\n","        EzPickle.__init__(self)\n","        self.screen = None\n","        self.clock = None\n","        self.isopen = True\n","        self.world = Box2D.b2World()\n","        self.moon = None\n","        self.lander = None\n","        self.particles = []\n","\n","        self.prev_reward = None\n","\n","        self.continuous = continuous\n","\n","        self.main_engine_power = main_engine_power\n","        self.side_engine_power = side_engine_power\n","\n","        # useful range is -1 .. +1, but spikes can be higher\n","        self.observation_space = spaces.Box(\n","            -np.inf, np.inf, shape=(8,), dtype=np.float32\n","        )\n","\n","        if self.continuous:\n","            # Action is two floats [main engine, left-right engines].\n","            # Main engine: -1..0 off, 0..+1 throttle from 50% to 100% power. Engine can't work with less than 50% power.\n","            # Left-right:  -1.0..-0.5 fire left engine, +0.5..+1.0 fire right engine, -0.5..0.5 off\n","            self.action_space = spaces.Box(-1, +1, (2,), dtype=np.float32)\n","        else:\n","            # Nop, fire left engine, main engine, right engine\n","            self.action_space = spaces.Discrete(4)\n","\n","    def _destroy(self):\n","        if not self.moon:\n","            return\n","        self.world.contactListener = None\n","        self._clean_particles(True)\n","        self.world.DestroyBody(self.moon)\n","        self.moon = None\n","        self.world.DestroyBody(self.lander)\n","        self.lander = None\n","        self.world.DestroyBody(self.legs[0])\n","        self.world.DestroyBody(self.legs[1])\n","\n","    def reset(\n","        self,\n","        *,\n","        seed: Optional[int] = None,\n","        return_info: bool = False,\n","        options: Optional[dict] = None,\n","        helipad_x = None,\n","    ):\n","        np.random.seed(seed)\n","        self._destroy()\n","        self.world.contactListener_keepref = ContactDetector(self)\n","        self.world.contactListener = self.world.contactListener_keepref\n","        self.game_over = False\n","        self.prev_shaping = None\n","\n","        W = VIEWPORT_W / SCALE\n","        H = VIEWPORT_H / SCALE\n","\n","        if helipad_x is None:\n","            self.helipad_x = np.random.uniform(W / 4, 3*W / 4)\n","        else:\n","            self.helipad_x = int(helipad_x * W)\n","\n","        # terrain\n","        CHUNKS = 11\n","        height = np.random.uniform(0, H / 2, size=(CHUNKS + 1,))\n","        chunk_x = np.array([W / (CHUNKS - 1) * i for i in range(CHUNKS)])\n","        chunk_x_idx = np.argmin(chunk_x < self.helipad_x)\n","        self.helipad_x1 = chunk_x[chunk_x_idx - 1]\n","        self.helipad_x2 = chunk_x[chunk_x_idx]\n","        self.helipad_y = H / 4\n","        height[max(0, chunk_x_idx - 2)] = self.helipad_y\n","        height[max(0, chunk_x_idx - 1)] = self.helipad_y\n","        height[chunk_x_idx + 0] = self.helipad_y\n","        height[min(CHUNKS - 1, chunk_x_idx + 1)] = self.helipad_y\n","        height[min(CHUNKS - 1, chunk_x_idx + 2)] = self.helipad_y\n","        smooth_y = [\n","            0.33 * (height[i - 1] + height[i + 0] + height[i + 1])\n","            for i in range(CHUNKS)\n","        ]\n","\n","        self.moon = self.world.CreateStaticBody(\n","            shapes=edgeShape(vertices=[(0, 0), (W, 0)])\n","        )\n","        self.sky_polys = []\n","        for i in range(CHUNKS - 1):\n","            p1 = (chunk_x[i], smooth_y[i])\n","            p2 = (chunk_x[i + 1], smooth_y[i + 1])\n","            self.moon.CreateEdgeFixture(vertices=[p1, p2], density=0, friction=0.1)\n","            self.sky_polys.append([p1, p2, (p2[0], H), (p1[0], H)])\n","\n","        self.moon.color1 = (0.0, 0.0, 0.0)\n","        self.moon.color2 = (0.0, 0.0, 0.0)\n","\n","        initial_y = VIEWPORT_H / SCALE\n","        self.lander = self.world.CreateDynamicBody(\n","            position=(VIEWPORT_W / SCALE / 2, initial_y),\n","            angle=0.0,\n","            fixtures=fixtureDef(\n","                shape=polygonShape(\n","                    vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]\n","                ),\n","                density=5.0,\n","                friction=0.1,\n","                categoryBits=0x0010,\n","                maskBits=0x001,  # collide only with ground\n","                restitution=0.0,\n","            ),  # 0.99 bouncy\n","        )\n","        self.lander.color1 = (128, 102, 230)\n","        self.lander.color2 = (77, 77, 128)\n","        self.lander.ApplyForceToCenter(\n","            (\n","                np.random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM),\n","                np.random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM),\n","            ),\n","            True,\n","        )\n","\n","        self.legs = []\n","        for i in [-1, +1]:\n","            leg = self.world.CreateDynamicBody(\n","                position=(VIEWPORT_W / SCALE / 2 - i * LEG_AWAY / SCALE, initial_y),\n","                angle=(i * 0.05),\n","                fixtures=fixtureDef(\n","                    shape=polygonShape(box=(LEG_W / SCALE, LEG_H / SCALE)),\n","                    density=1.0,\n","                    restitution=0.0,\n","                    categoryBits=0x0020,\n","                    maskBits=0x001,\n","                ),\n","            )\n","            leg.ground_contact = False\n","            leg.color1 = (128, 102, 230)\n","            leg.color2 = (77, 77, 128)\n","            rjd = revoluteJointDef(\n","                bodyA=self.lander,\n","                bodyB=leg,\n","                localAnchorA=(0, 0),\n","                localAnchorB=(i * LEG_AWAY / SCALE, LEG_DOWN / SCALE),\n","                enableMotor=True,\n","                enableLimit=True,\n","                maxMotorTorque=LEG_SPRING_TORQUE,\n","                motorSpeed=+0.3 * i,  # low enough not to jump back into the sky\n","            )\n","            if i == -1:\n","                rjd.lowerAngle = (\n","                    +0.9 - 0.5\n","                )  # The most esoteric numbers here, angled legs have freedom to travel within\n","                rjd.upperAngle = +0.9\n","            else:\n","                rjd.lowerAngle = -0.9\n","                rjd.upperAngle = -0.9 + 0.5\n","            leg.joint = self.world.CreateJoint(rjd)\n","            self.legs.append(leg)\n","\n","        self.drawlist = [self.lander] + self.legs\n","\n","        if not return_info:\n","            return self.step(np.array([0, 0]) if self.continuous else 0)[0]\n","        else:\n","            return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}\n","\n","    def _create_particle(self, mass, x, y, ttl):\n","        p = self.world.CreateDynamicBody(\n","            position=(x, y),\n","            angle=0.0,\n","            fixtures=fixtureDef(\n","                shape=circleShape(radius=2 / SCALE, pos=(0, 0)),\n","                density=mass,\n","                friction=0.1,\n","                categoryBits=0x0100,\n","                maskBits=0x001,  # collide only with ground\n","                restitution=0.3,\n","            ),\n","        )\n","        p.ttl = ttl\n","        self.particles.append(p)\n","        self._clean_particles(False)\n","        return p\n","\n","    def _clean_particles(self, all):\n","        while self.particles and (all or self.particles[0].ttl < 0):\n","            self.world.DestroyBody(self.particles.pop(0))\n","\n","    def step(self, action):\n","        if self.continuous:\n","            action = np.clip(action, -1, +1).astype(np.float32)\n","        else:\n","            assert self.action_space.contains(\n","                action\n","            ), f\"{action!r} ({type(action)}) invalid \"\n","\n","        # Engines\n","        tip = (math.sin(self.lander.angle), math.cos(self.lander.angle))\n","        side = (-tip[1], tip[0])\n","        dispersion = [np.random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]\n","\n","        m_power = 0.0\n","        if (self.continuous and action[0] > 0.0) or (\n","            not self.continuous and action == 2\n","        ):\n","            # Main engine\n","            if self.continuous:\n","                m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5  # 0.5..1.0\n","                assert m_power >= 0.5 and m_power <= 1.0\n","            else:\n","                m_power = 1.0\n","            ox = (\n","                tip[0] * (4 / SCALE + 2 * dispersion[0]) + side[0] * dispersion[1]\n","            )  # 4 is move a bit downwards, +-2 for randomness\n","            oy = -tip[1] * (4 / SCALE + 2 * dispersion[0]) - side[1] * dispersion[1]\n","            impulse_pos = (self.lander.position[0] + ox, self.lander.position[1] + oy)\n","            p = self._create_particle(\n","                3.5,  # 3.5 is here to make particle speed adequate\n","                impulse_pos[0],\n","                impulse_pos[1],\n","                m_power,\n","            )  # particles are just a decoration\n","            p.ApplyLinearImpulse(\n","                (ox * self.main_engine_power * m_power, oy * self.main_engine_power * m_power),\n","                impulse_pos,\n","                True,\n","            )\n","            self.lander.ApplyLinearImpulse(\n","                (-ox * self.main_engine_power * m_power, -oy * self.main_engine_power * m_power),\n","                impulse_pos,\n","                True,\n","            )\n","\n","        s_power = 0.0\n","        if (self.continuous and np.abs(action[1]) > 0.5) or (\n","            not self.continuous and action in [1, 3]\n","        ):\n","            # Orientation engines\n","            if self.continuous:\n","                direction = np.sign(action[1])\n","                s_power = np.clip(np.abs(action[1]), 0.5, 1.0)\n","                assert s_power >= 0.5 and s_power <= 1.0\n","            else:\n","                direction = action - 2\n","                s_power = 1.0\n","            ox = tip[0] * dispersion[0] + side[0] * (\n","                3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE\n","            )\n","            oy = -tip[1] * dispersion[0] - side[1] * (\n","                3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE\n","            )\n","            impulse_pos = (\n","                self.lander.position[0] + ox - tip[0] * 17 / SCALE,\n","                self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,\n","            )\n","            p = self._create_particle(0.7, impulse_pos[0], impulse_pos[1], s_power)\n","            p.ApplyLinearImpulse(\n","                (ox * self.side_engine_power * s_power, oy * self.side_engine_power * s_power),\n","                impulse_pos,\n","                True,\n","            )\n","            self.lander.ApplyLinearImpulse(\n","                (-ox * self.side_engine_power * s_power, -oy * self.side_engine_power * s_power),\n","                impulse_pos,\n","                True,\n","            )\n","\n","        self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)\n","\n","        pos = self.lander.position\n","        vel = self.lander.linearVelocity\n","        state = [\n","            (pos.x - self.helipad_x) / (VIEWPORT_W / SCALE / 2),\n","            (pos.y - (self.helipad_y + LEG_DOWN / SCALE)) / (VIEWPORT_H / SCALE / 2),\n","            vel.x * (VIEWPORT_W / SCALE) / FPS,\n","            vel.y * (VIEWPORT_H / SCALE) / FPS,\n","            self.lander.angle,\n","            20.0 * self.lander.angularVelocity / FPS,\n","            1.0 if self.legs[0].ground_contact else 0.0,\n","            1.0 if self.legs[1].ground_contact else 0.0,\n","        ]\n","        assert len(state) == 8\n","\n","        reward = 0\n","        shaping = (\n","            -100 * np.sqrt(state[0]**2 + state[1]**2)\n","            - 100 * np.sqrt(state[2]**2 + state[3]**2)\n","            - 100 * abs(state[4])\n","            + 10 * state[6]\n","            + 10 * state[7]\n","        )  # And ten points for legs contact, the idea is if you\n","        # lose contact again after landing, you get negative reward\n","        if self.prev_shaping is not None:\n","            reward = shaping - self.prev_shaping\n","        self.prev_shaping = shaping\n","\n","        reward -= (\n","            m_power * 0.30\n","        )  # less fuel spent is better, about -30 for heuristic landing\n","        reward -= s_power * 0.03\n","\n","        done = False\n","        if self.game_over or abs(state[0]) >= 1.0:\n","            done = True\n","            reward = -100\n","        if not self.lander.awake:\n","            done = True\n","            reward = +100\n","        return np.array(state, dtype=np.float32), reward, done, {}\n","\n","    def render(self, mode=\"human\"):\n","        import pygame\n","        from pygame import gfxdraw\n","\n","        if self.screen is None:\n","            pygame.init()\n","            pygame.display.init()\n","            self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))\n","        if self.clock is None:\n","            self.clock = pygame.time.Clock()\n","\n","        self.surf = pygame.Surface(self.screen.get_size())\n","\n","        pygame.transform.scale(self.surf, (SCALE, SCALE))\n","        pygame.draw.rect(self.surf, (255, 255, 255), self.surf.get_rect())\n","\n","        for obj in self.particles:\n","            obj.ttl -= 0.15\n","            obj.color1 = (\n","                int(max(0.2, 0.15 + obj.ttl) * 255),\n","                int(max(0.2, 0.5 * obj.ttl) * 255),\n","                int(max(0.2, 0.5 * obj.ttl) * 255),\n","            )\n","            obj.color2 = (\n","                int(max(0.2, 0.15 + obj.ttl) * 255),\n","                int(max(0.2, 0.5 * obj.ttl) * 255),\n","                int(max(0.2, 0.5 * obj.ttl) * 255),\n","            )\n","\n","        self._clean_particles(False)\n","\n","        for p in self.sky_polys:\n","            scaled_poly = []\n","            for coord in p:\n","                scaled_poly.append((coord[0] * SCALE, coord[1] * SCALE))\n","            pygame.draw.polygon(self.surf, (0, 0, 0), scaled_poly)\n","            gfxdraw.aapolygon(self.surf, scaled_poly, (0, 0, 0))\n","\n","        for obj in self.particles + self.drawlist:\n","            for f in obj.fixtures:\n","                trans = f.body.transform\n","                if type(f.shape) is circleShape:\n","                    pygame.draw.circle(\n","                        self.surf,\n","                        color=obj.color1,\n","                        center=trans * f.shape.pos * SCALE,\n","                        radius=f.shape.radius * SCALE,\n","                    )\n","                    pygame.draw.circle(\n","                        self.surf,\n","                        color=obj.color2,\n","                        center=trans * f.shape.pos * SCALE,\n","                        radius=f.shape.radius * SCALE,\n","                    )\n","\n","                else:\n","                    path = [trans * v * SCALE for v in f.shape.vertices]\n","                    pygame.draw.polygon(self.surf, color=obj.color1, points=path)\n","                    gfxdraw.aapolygon(self.surf, path, obj.color1)\n","                    pygame.draw.aalines(\n","                        self.surf, color=obj.color2, points=path, closed=True\n","                    )\n","\n","                for x in [self.helipad_x1, self.helipad_x2]:\n","                    x = x * SCALE\n","                    flagy1 = self.helipad_y * SCALE\n","                    flagy2 = flagy1 + 50\n","                    pygame.draw.line(\n","                        self.surf,\n","                        color=(255, 255, 255),\n","                        start_pos=(x, flagy1),\n","                        end_pos=(x, flagy2),\n","                        width=1,\n","                    )\n","                    pygame.draw.polygon(\n","                        self.surf,\n","                        color=(204, 204, 0),\n","                        points=[\n","                            (x, flagy2),\n","                            (x, flagy2 - 10),\n","                            (x + 25, flagy2 - 5),\n","                        ],\n","                    )\n","                    gfxdraw.aapolygon(\n","                        self.surf,\n","                        [(x, flagy2), (x, flagy2 - 10), (x + 25, flagy2 - 5)],\n","                        (204, 204, 0),\n","                    )\n","\n","        self.surf = pygame.transform.flip(self.surf, False, True)\n","        self.screen.blit(self.surf, (0, 0))\n","\n","        if mode == \"human\":\n","            pygame.event.pump()\n","            self.clock.tick(self.metadata[\"render_fps\"])\n","            pygame.display.flip()\n","\n","        if mode == \"rgb_array\":\n","            return np.transpose(\n","                np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)\n","            )\n","        else:\n","            return self.isopen\n","\n","    def close(self):\n","        if self.screen is not None:\n","            import pygame\n","\n","            pygame.display.quit()\n","            pygame.quit()\n","            self.isopen = False"],"metadata":{"id":"5Qr7oBcUtbjn"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GKAZtuv1nJf-"},"source":["*ActionNormalizer* is an action wrapper class to normalize the action values ranged in (-1. 1). Thanks to this class, we can make the agent simply select action values within the zero centered range (-1, 1)."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"LVlqsBObnJf_"},"outputs":[],"source":["class ActionNormalizer(gym.ActionWrapper):\n","    \"\"\"Rescale and relocate the actions.\"\"\"\n","\n","    def action(self, action: np.ndarray) -> np.ndarray:\n","        \"\"\"Change the range (-1, 1) to (low, high).\"\"\"\n","        low = self.action_space.low\n","        high = self.action_space.high\n","\n","        scale_factor = (high - low) / 2\n","        reloc_factor = high - scale_factor\n","\n","        action = action * scale_factor + reloc_factor\n","        action = np.clip(action, low, high)\n","\n","        return action\n","\n","    def reverse_action(self, action: np.ndarray) -> np.ndarray:\n","        \"\"\"Change the range (low, high) to (-1, 1).\"\"\"\n","        low = self.action_space.low\n","        high = self.action_space.high\n","\n","        scale_factor = (high - low) / 2\n","        reloc_factor = high - scale_factor\n","\n","        action = (action - reloc_factor) / scale_factor\n","        action = np.clip(action, -1.0, 1.0)\n","\n","        return action"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9EWd8JNXnJgA","executionInfo":{"status":"ok","timestamp":1652928614557,"user_tz":420,"elapsed":8,"user":{"displayName":"Joey Hong","userId":"05817993003903533998"}},"outputId":"5adedc3f-97c6-4c04-d436-d5b3d04f9d87","colab":{"base_uri":"https://localhost:8080/"}},"outputs":[{"output_type":"stream","name":"stdout","text":["State shape:  (8,)\n","Action shape:  (2,)\n"]},{"output_type":"execute_result","data":{"text/plain":["array([-0.04541585,  1.4051299 ,  0.6882406 , -0.5146879 , -0.00393014,\n","       -0.07794825,  0.        ,  0.        ], dtype=float32)"]},"metadata":{},"execution_count":8}],"source":["# Test environment\n","env = LunarLander(continuous=True)\n","print('State shape: ', env.observation_space.shape)\n","print('Action shape: ', env.action_space.shape)\n","\n","env = ActionNormalizer(env)\n","env.reset(seed=0)"]},{"cell_type":"markdown","metadata":{"id":"1IX-jdErnJf2"},"source":["# Soft Actor Critic (SAC)\n","\n","1. [T. Haarnoja et al., \"Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.\" arXiv preprint arXiv:1801.01290, 2018.](https://arxiv.org/pdf/1801.01290.pdf)\n","2. [T. Haarnoja et al., \"Soft Actor-Critic Algorithms and Applications.\" arXiv preprint arXiv:1812.05905, 2018.](https://arxiv.org/pdf/1812.05905.pdf)"]},{"cell_type":"markdown","metadata":{"id":"MOlw1cZpnJf5"},"source":["## Replay buffer"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"CJ_F0xxinJf6"},"outputs":[],"source":["class ReplayBuffer:\n","    \"\"\"A simple numpy replay buffer.\"\"\"\n","\n","    def __init__(self, obs_dim: int, act_dim: int, size: int, batch_size: int = 32):\n","        \"\"\"Initialize.\"\"\"\n","        self.obs_buf = np.zeros([size, obs_dim], dtype=np.float32)\n","        self.next_obs_buf = np.zeros([size, obs_dim], dtype=np.float32)\n","        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)\n","        self.rews_buf = np.zeros([size], dtype=np.float32)\n","        self.done_buf = np.zeros([size], dtype=np.float32)\n","        self.max_size, self.batch_size = size, batch_size\n","        self.ptr, self.size, = 0, 0\n","\n","    def store(self,\n","        obs: np.ndarray,\n","        act: np.ndarray, \n","        rew: float, \n","        next_obs: np.ndarray, \n","        done: bool,\n","    ):\n","        \"\"\"Store the transition in buffer.\"\"\"\n","        self.obs_buf[self.ptr] = obs\n","        self.next_obs_buf[self.ptr] = next_obs\n","        self.acts_buf[self.ptr] = act\n","        self.rews_buf[self.ptr] = rew\n","        self.done_buf[self.ptr] = done\n","        self.ptr = (self.ptr + 1) % self.max_size\n","        self.size = min(self.size + 1, self.max_size)\n","\n","    def sample_batch(self) -> Dict[str, np.ndarray]:\n","        \"\"\"Randomly sample a batch of experiences from memory.\"\"\"\n","        idxs = np.random.choice(self.size, size=self.batch_size, replace=False)\n","        return dict(obs=self.obs_buf[idxs],\n","                    next_obs=self.next_obs_buf[idxs],\n","                    acts=self.acts_buf[idxs],\n","                    rews=self.rews_buf[idxs],\n","                    done=self.done_buf[idxs])\n","\n","    def __len__(self) -> int:\n","        return self.size"]},{"cell_type":"markdown","metadata":{"id":"NVcUoeRUnJf6"},"source":["## Network\n","We are going to use three different networks for policy, Q-function, and V-function. We use two Q-functions to mitigate positive bias and softly update V-function for stable learning. One interesting thing is that the policy network works as Tanh Normal distribution which enforces action bounds. (The details are descibed in Appendix C of [2].)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"k6t4fOA9nJf7"},"outputs":[],"source":["#@title\n","def init_layer_uniform(layer: nn.Linear, init_w: float = 3e-3) -> nn.Linear:\n","    \"\"\"Init uniform parameters on the single layer.\"\"\"\n","    layer.weight.data.uniform_(-init_w, init_w)\n","    layer.bias.data.uniform_(-init_w, init_w)\n","\n","    return layer\n","\n","\n","class Actor(nn.Module):\n","    def __init__(\n","        self, \n","        in_dim: int, \n","        out_dim: int,\n","        log_std_min: float = -20,\n","        log_std_max: float = 2,\n","    ):\n","        \"\"\"Initialize.\"\"\"\n","        super(Actor, self).__init__()\n","        \n","        # set the log std range\n","        self.log_std_min = log_std_min\n","        self.log_std_max = log_std_max\n","        \n","        # set the hidden layers\n","        self.hidden1 = nn.Linear(in_dim, 128)\n","        self.hidden2 = nn.Linear(128, 128)\n","        \n","        # set log_std layer\n","        self.log_std_layer = nn.Linear(128, out_dim)\n","        self.log_std_layer = init_layer_uniform(self.log_std_layer)\n","\n","        # set mean layer\n","        self.mu_layer = nn.Linear(128, out_dim)\n","        self.mu_layer = init_layer_uniform(self.mu_layer)\n","\n","    def distribution(self, state: torch.Tensor) -> torch.distributions.distribution.Distribution:\n","        x = F.relu(self.hidden1(state))\n","        x = F.relu(self.hidden2(x))\n","\n","        # get mean\n","        mu = self.mu_layer(x).tanh()\n","\n","        # get std\n","        log_std = self.log_std_layer(x).tanh()\n","        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1)\n","        std = torch.exp(log_std)\n","\n","        return Normal(mu, std)\n","\n","    def forward(self, state: torch.Tensor) -> torch.Tensor:\n","        \"\"\"Forward method implementation.\"\"\"\n","        dist = self.distribution(state)\n","        # sample action\n","        z = dist.rsample()\n","        # normalize action\n","        action = z.tanh()\n","        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)\n","        log_prob = log_prob.sum(-1, keepdim=True)\n","        \n","        return action, log_prob\n","\n","    def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:\n","        dist = self.distribution(state)\n","        z = torch.atanh(action)\n","        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)\n","        log_prob = log_prob.sum(-1, keepdim=True)\n","        return log_prob\n","    \n","\n","class CriticQ(nn.Module):\n","    def __init__(self, in_dim: int):\n","        \"\"\"Initialize.\"\"\"\n","        super(CriticQ, self).__init__()\n","        \n","        self.hidden1 = nn.Linear(in_dim, 128)\n","        self.hidden2 = nn.Linear(128, 128)\n","        self.out = nn.Linear(128, 1)\n","        self.out = init_layer_uniform(self.out)\n","\n","    def forward(\n","        self, state: torch.Tensor, action: torch.Tensor\n","    ) -> torch.Tensor:\n","        \"\"\"Forward method implementation.\"\"\"\n","        x = torch.cat((state, action), dim=-1)\n","        x = F.relu(self.hidden1(x))\n","        x = F.relu(self.hidden2(x))\n","        value = self.out(x)\n","\n","        return value\n","\n","\n","class CriticV(nn.Module):\n","    def __init__(self, in_dim: int):\n","        \"\"\"Initialize.\"\"\"\n","        super(CriticV, self).__init__()\n","        \n","        self.hidden1 = nn.Linear(in_dim, 128)\n","        self.hidden2 = nn.Linear(128, 128)\n","        self.out = nn.Linear(128, 1)\n","        self.out = init_layer_uniform(self.out)\n","\n","    def forward(self, state: torch.Tensor) -> torch.Tensor:\n","        \"\"\"Forward method implementation.\"\"\"\n","        x = F.relu(self.hidden1(state))\n","        x = F.relu(self.hidden2(x))\n","        value = self.out(x)\n","\n","        return value"]},{"cell_type":"markdown","metadata":{"id":"kyKnz5OUnJf7"},"source":["## SAC Agent"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Rky77rpLnJf8"},"outputs":[],"source":["#@title\n","class SACAgent:\n","    \"\"\"SAC agent interacting with environment.\n","    \n","    Attributes:\n","        actor (nn.Module): actor model to select actions\n","        actor_optimizer (Optimizer): optimizer for training actor\n","        vf (nn.Module): critic model to predict state values\n","        vf_target (nn.Module): target critic model to predict state values\n","        vf_optimizer (Optimizer): optimizer for training vf\n","        qf_1 (nn.Module): critic model to predict state-action values\n","        qf_2 (nn.Module): critic model to predict state-action values\n","        qf_1_optimizer (Optimizer): optimizer for training qf_1\n","        qf_2_optimizer (Optimizer): optimizer for training qf_2\n","        env (gym.Env): openAI Gym environment\n","        memory (ReplayBuffer): replay memory\n","        batch_size (int): batch size for sampling\n","        gamma (float): discount factor\n","        tau (float): parameter for soft target update\n","        alpha (float): weight for entropy\n","        update_freq (int): parameter update frequency\n","    \"\"\"\n","    \n","    def __init__(\n","        self,\n","        env: gym.Env,\n","        memory_size: int,\n","        batch_size: int,\n","        gamma: float = 0.99,\n","        tau: float = 1e-3,\n","        alpha: float = 1,\n","        update_freq: int = 1):\n","        \"\"\"Initialize.\"\"\"\n","        obs_dim = env.observation_space.shape[0]\n","        action_dim = env.action_space.shape[0]\n","\n","        self.env = env\n","        self.batch_size = batch_size\n","        self.gamma = gamma\n","        self.tau = tau\n","        self.alpha = alpha\n","        self.update_freq = update_freq\n","        self.memory = ReplayBuffer(obs_dim, action_dim, memory_size, batch_size)\n","\n","        # device: cpu / gpu\n","        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","        # actor\n","        self.actor = Actor(obs_dim, action_dim).to(self.device)\n","        \n","        # v function\n","        self.vf = CriticV(obs_dim).to(self.device)\n","        self.vf_target = CriticV(obs_dim).to(self.device)\n","        self.vf_target.load_state_dict(self.vf.state_dict())\n","        \n","        # q function\n","        self.qf_1 = CriticQ(obs_dim + action_dim).to(self.device)\n","        self.qf_2 = CriticQ(obs_dim + action_dim).to(self.device)\n","\n","        # optimizers\n","        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)\n","        self.vf_optimizer = optim.Adam(self.vf.parameters(), lr=1e-3)\n","        self.qf_1_optimizer = optim.Adam(self.qf_1.parameters(), lr=1e-3)\n","        self.qf_2_optimizer = optim.Adam(self.qf_2.parameters(), lr=1e-3)\n","\n","        # Initialize time step (for updating every UPDATE_EVERY steps)\n","        self.t_step = 0\n","    \n","    def act(self, state: np.ndarray, eps=0.) -> np.ndarray:\n","        \"\"\"Select an action from the input state.\"\"\"\n","        state_tensor = torch.FloatTensor(state).to(self.device)\n","        # if initial random action should be conducted\n","        if random.random() > eps:\n","            selected_action = self.env.action_space.sample()\n","        else:\n","            selected_action = self.actor(state_tensor)[0].detach().cpu().numpy()\n","        return selected_action\n","\n","    def log_prob(self, state: np.ndarray, action: np.ndarray) -> np.ndarray:\n","        \"\"\"Get log-probability of taking particular action.\"\"\"\n","        state_tensor = torch.FloatTensor(state).to(self.device)\n","        action_tensor = torch.FloatTensor(action).to(self.device)\n","        return self.actor.log_prob(state_tensor, action_tensor).detach().cpu().numpy()\n","\n","    def step(self, state: np.ndarray, action: np.ndarray, \n","             reward: float, next_state: np.ndarray, done: bool):\n","        \"\"\"Take an action and return the response of the env.\"\"\"\n","        self.memory.store(state, action, reward, next_state, done)\n","\n","        self.t_step = (self.t_step + 1) % self.update_freq\n","        if self.t_step == 0:\n","          if (self.memory.size >= self.batch_size):\n","            self.learn()\n","\n","    def learn(self) -> Tuple[torch.Tensor, ...]:\n","        \"\"\"Update the model by gradient descent.\"\"\"\n","        device = self.device  # for shortening the following lines\n","        \n","        samples = self.memory.sample_batch()\n","        state = torch.FloatTensor(samples[\"obs\"]).to(device)\n","        next_state = torch.FloatTensor(samples[\"next_obs\"]).to(device)\n","        action = torch.FloatTensor(samples[\"acts\"]).to(device)\n","        reward = torch.FloatTensor(samples[\"rews\"].reshape(-1, 1)).to(device)\n","        done = torch.FloatTensor(samples[\"done\"].reshape(-1, 1)).to(device)\n","        new_action, log_prob = self.actor(state)\n","\n","        # Q function loss\n","        mask = 1 - done\n","        q_1_pred = self.qf_1(state, action)\n","        q_2_pred = self.qf_2(state, action)\n","        v_target = self.vf_target(next_state)\n","        q_target = reward + self.gamma * v_target * mask\n","        qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())\n","        qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())\n","\n","        # train Q functions\n","        self.qf_1_optimizer.zero_grad()\n","        qf_1_loss.backward()\n","        self.qf_1_optimizer.step()\n","\n","        self.qf_2_optimizer.zero_grad()\n","        qf_2_loss.backward()\n","        self.qf_2_optimizer.step()\n","\n","        # V function loss\n","        v_pred = self.vf(state)\n","        q_pred = torch.min(self.qf_1(state, new_action), self.qf_2(state, new_action))\n","        v_target = q_pred - self.alpha * log_prob\n","        vf_loss = F.mse_loss(v_pred, v_target.detach())\n","\n","        # train V function\n","        self.vf_optimizer.zero_grad()\n","        vf_loss.backward()\n","        self.vf_optimizer.step()\n","\n","        # actor loss\n","        advantage = q_pred - v_pred.detach()\n","        actor_loss = (self.alpha * log_prob - advantage).mean()\n","\n","        # train actor\n","        self.actor_optimizer.zero_grad()\n","        actor_loss.backward()\n","        self.actor_optimizer.step()\n","\n","        # ------------------- update target network ------------------- #\n","        self._target_soft_update()\n","\n","    def _target_soft_update(self):\n","        \"\"\"Soft-update: target = tau*local + (1-tau)*target.\"\"\"\n","        tau = self.tau\n","        \n","        for t_param, l_param in zip(self.vf_target.parameters(), self.vf.parameters()):\n","            t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)"]},{"cell_type":"markdown","source":["Test training\n","--"],"metadata":{"id":"OFBhwMEBCxdN"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"L4JEeTIInJgA"},"outputs":[],"source":["# parameters\n","memory_size = 100000\n","batch_size = 256"]},{"cell_type":"code","source":["def sac(agent, n_episodes=5000, max_t=1000, initial_random_episodes=0, eps_start=1.0, eps_end=0.01, eps_decay=0.999,\n","        save_dir=SYM_PATH):\n","    \"\"\"SAC:\n","    \n","    Params\n","    ======\n","        n_episodes (int): maximum number of training episodes\n","        max_t (int): maximum number of timesteps per episode\n","        eps_start (float): starting value of epsilon, for epsilon-greedy action selection\n","        eps_end (float): minimum value of epsilon\n","        eps_decay (float): multiplicative factor (per episode) for decreasing epsilon\n","    \"\"\"\n","    env = agent.env\n","    scores = []                        # list containing scores from each episode\n","    scores_window = deque(maxlen=100)  # last 100 scores\n","    eps = eps_start                    # initialize epsilon\n","    for i_episode in range(1, n_episodes+1):\n","        state = env.reset()\n","        score = 0\n","        eps = 1. if i_episode < initial_random_episodes else eps\n","        for t in range(max_t):\n","            action = agent.act(state, eps)\n","            next_state, reward, done, _ = env.step(action)\n","            agent.step(state, action, reward, next_state, done)\n","            state = next_state\n","            score += reward\n","            if done:\n","                break \n","        scores_window.append(score)       # save most recent score\n","        scores.append(score)              # save most recent score\n","        # decrease epsilon\n","        if i_episode >= initial_random_episodes:\n","          eps = max(eps_end, eps_decay * eps)\n","        print('\\rEpisode {}\\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end=\"\")\n","        if i_episode % 100 == 0:\n","            print('\\rEpisode {}\\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))\n","            torch.save(agent.actor.state_dict(), os.path.join(save_dir, 'checkpoint_{}.pth'.format(i_episode)))\n","        if np.mean(scores_window) >= 200.0:\n","            print('\\nEnvironment solved in {:d} episodes!\\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))\n","            torch.save(agent.actor.state_dict(), os.path.join(save_dir, 'checkpoint.pth'))\n","            break\n","    return scores"],"metadata":{"id":"ctGHzSh8GkAE"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for side_engine_power in np.linspace(0.2, SIDE_ENGINE_POWER, num=11):\n","  env = LunarLander(continuous=True, side_engine_power=side_engine_power)\n","  env = ActionNormalizer(env)\n","\n","  agent = SACAgent(env, memory_size, batch_size)\n","  save_dir = os.path.join(SYM_PATH, 'checkpoints_{}'.format(side_engine_power))\n","  if not os.path.exists(save_dir):\n","    os.makedirs(save_dir)\n","  print(save_dir)\n","  _ = sac(agent, n_episodes=600, max_t=1000, save_dir=save_dir)"],"metadata":{"id":"BbYfHOjufl03"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Collect demonstrations\n","--"],"metadata":{"id":"AKhfahVS9d7A"}},{"cell_type":"code","source":["def collect_demonstrations(agent, checkpoint_path, helipad_x=0.5, dataset_size=10000, max_t=1000):\n","    env = agent.env\n","    agent.actor.load_state_dict(\n","        torch.load(os.path.join(SYM_PATH, checkpoint_path), map_location=agent.device))\n","    dataset = []\n","    while len(dataset) < dataset_size:\n","      state = env.reset(helipad_x=helipad_x)\n","      for t in range(max_t):\n","          action = agent.act(state)\n","          dataset.append((state, action))\n","          state, reward, done, _ = env.step(action)\n","          if done:\n","            break\n","    env.close()\n","    return dataset"],"metadata":{"id":"5EpQQNldJfwo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["env = LunarLander(continuous=True, side_engine_power=SIDE_ENGINE_POWER)\n","agent = SACAgent(env, memory_size, batch_size)\n","dataset = collect_demonstrations(agent, 'checkpoints_2.0/checkpoint_600.pth', helipad_x=0.5)"],"metadata":{"id":"o5J-pE9LKCWx"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Experiment 1: Error in training iteration.\n","--"],"metadata":{"id":"n2mydOZ1Qy6Q"}},{"cell_type":"code","source":["env = LunarLander(continuous=True, side_engine_power=SIDE_ENGINE_POWER)\n","optimal_agent = SACAgent(env, memory_size, batch_size)\n","optimal_agent.actor.load_state_dict(\n","        torch.load(\n","            os.path.join(SYM_PATH, 'checkpoints_2.0/checkpoint_600.pth'), map_location=optimal_agent.device))\n","helipad_xs = np.linspace(0.4, 1.0, num=32)\n","\n","true_helipad_xs = [0.5, 0.7, 0.9]\n","iters = [100, 200, 300, 400, 500, 600]\n","num_simulations = 100\n","policy_errs = np.empty((len(true_helipad_xs), len(iters), num_simulations))\n","\n","for i, helipad_x_star in enumerate(true_helipad_xs):\n","  for j, num_iter in enumerate(iters):\n","    for run in range(num_simulations):\n","      biased_agent = SACAgent(env, memory_size, batch_size)\n","      dataset = collect_demonstrations(\n","          biased_agent, 'checkpoints_2.0/checkpoint_{}.pth'.format(num_iter), helipad_x=helipad_x_star, dataset_size=10000)\n","      policy_err = 0\n","      for (state, _) in dataset:\n","        state_tensor = torch.FloatTensor(state).to(biased_agent.device)\n","        p = biased_agent.actor.distribution(state_tensor)\n","        q = optimal_agent.actor.distribution(state_tensor)\n","        kl_div = torch.distributions.kl_divergence(p, q).detach().cpu().numpy().mean()\n","        policy_err += kl_div / len(dataset)\n"," \n","      policy_errs[i, j, run] = policy_err\n","    print(helipad_x_star, num_iter, policy_errs[i, j].mean(axis=-1))"],"metadata":{"id":"BMit1EMqUipA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["env = LunarLander(continuous=True, side_engine_power=SIDE_ENGINE_POWER)\n","optimal_agent = SACAgent(env, memory_size, batch_size)\n","optimal_agent.actor.load_state_dict(\n","        torch.load(os.path.join(SYM_PATH, 'checkpoints_2.0/checkpoint_600.pth'), map_location=agent.device))\n","helipad_xs = np.linspace(0.4, 1.0, num=32)\n","\n","true_helipad_xs = [0.5, 0.7, 0.9]\n","iters = [100, 200, 300, 400, 500, 600]\n","num_simulations = 10\n","reward_errs = np.empty((len(true_helipad_xs), len(iters), num_simulations))\n","\n","for i, helipad_x_star in enumerate(true_helipad_xs):\n","  for j, num_iter in enumerate(iters):\n","    for run in range(num_simulations):\n","      biased_agent = SACAgent(env, memory_size, batch_size)\n","      dataset = collect_demonstrations(\n","          biased_agent, 'checkpoints_2.0/checkpoint_{}.pth'.format(num_iter), helipad_x=helipad_x_star, dataset_size=10000)\n","\n","      lls = np.zeros_like(helipad_xs)\n","      for (state, action) in dataset:\n","        state_batch = np.tile(state[None, ...], (len(helipad_xs), 1))\n","        action_batch = np.tile(action[None, ...], (len(helipad_xs), 1))\n","        state_batch[:, 0] += (helipad_x_star - helipad_xs) / 2\n","        lls += optimal_agent.log_prob(state_batch, action_batch)[:, 0]\n","\n","      helipad_x_hat = helipad_xs[np.argmax(lls)]\n","      reward_errs[i, j, run] = (helipad_x_star - helipad_x_hat)**2\n","    print(helipad_x_star, num_iter, reward_errs[i, j].mean(axis=-1))"],"metadata":{"id":"NR6M--bKKfFb"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Experiment 2: Error in transition bias\n","--"],"metadata":{"id":"MxdrtW21Q1w-"}},{"cell_type":"code","source":["env = LunarLander(continuous=True, side_engine_power=SIDE_ENGINE_POWER)\n","optimal_agent = SACAgent(env, memory_size, batch_size)\n","optimal_agent.actor.load_state_dict(\n","        torch.load(\n","            os.path.join(SYM_PATH, 'checkpoints_2.0/checkpoint_600.pth'), map_location=optimal_agent.device))\n","helipad_xs = np.linspace(0.4, 1.0, num=32)\n","\n","true_helipad_xs = [0.5, 0.7, 0.9]\n","side_engine_powers = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0]\n","num_simulations = 100\n","policy_errs = np.empty((len(true_helipad_xs), len(side_engine_powers), num_simulations))\n","\n","for i, helipad_x_star in enumerate(true_helipad_xs):\n","  for j, side_engine_power in enumerate(side_engine_powers):\n","    for run in range(num_simulations):\n","      biased_agent = SACAgent(env, memory_size, batch_size)\n","      dataset = collect_demonstrations(\n","          biased_agent, 'checkpoints_{}/checkpoint_600.pth'.format(side_engine_power), helipad_x=helipad_x_star, dataset_size=10000)\n","      policy_err = 0\n","      for (state, _) in dataset:\n","        state_tensor = torch.FloatTensor(state).to(biased_agent.device)\n","        p = biased_agent.actor.distribution(state_tensor)\n","        q = optimal_agent.actor.distribution(state_tensor)\n","        kl_div = torch.distributions.kl_divergence(p, q).detach().cpu().numpy().mean()\n","        policy_err += kl_div / len(dataset)\n"," \n","      policy_errs[i, j, run] = policy_err\n","    print(helipad_x_star, side_engine_power, policy_errs[i, j].mean(axis=-1))"],"metadata":{"id":"lYk3llIM9uTO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["env = LunarLander(continuous=True, side_engine_power=SIDE_ENGINE_POWER)\n","optimal_agent = SACAgent(env, memory_size, batch_size)\n","optimal_agent.actor.load_state_dict(\n","        torch.load(os.path.join(SYM_PATH, 'checkpoints_2.0/checkpoint_600.pth'), map_location=optimal_agent.device))\n","helipad_xs = np.linspace(0.4, 1.0, num=32)\n","\n","true_helipad_xs = [0.5, 0.7, 0.9]\n","side_engine_powers = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0]\n","num_simulations = 10\n","reward_errs = np.empty((len(true_helipad_xs), len(side_engine_powers), num_simulations))\n","\n","for i, helipad_x_star in enumerate(true_helipad_xs):\n","  for j, side_engine_power in enumerate(side_engine_powers):\n","    for run in range(num_simulations):\n","      biased_agent = SACAgent(env, memory_size, batch_size)\n","      dataset = collect_demonstrations(\n","          biased_agent, 'checkpoints_{}/checkpoint_600.pth'.format(side_engine_power), helipad_x=helipad_x_star, dataset_size=10000)\n","\n","      lls = np.zeros_like(helipad_xs)\n","      for (state, action) in dataset:\n","        state_batch = np.tile(state[None, ...], (len(helipad_xs), 1))\n","        action_batch = np.tile(action[None, ...], (len(helipad_xs), 1))\n","        state_batch[:, 0] += (helipad_x_star - helipad_xs) / 2\n","        lls += optimal_agent.log_prob(state_batch, action_batch)[:, 0]\n","\n","      helipad_x_hat = helipad_xs[np.argmax(lls)]\n","      reward_errs[i, j, run] = (helipad_x_star - helipad_x_hat)**2\n","    print(helipad_x_star, side_engine_power, reward_errs[i, j].mean(axis=-1))"],"metadata":{"id":"Gh-QyB-F9uTV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[""],"metadata":{"id":"0F2YGLUFPtbP"},"execution_count":null,"outputs":[]}]}