#!/usr/bin/env python3
"""Copied from https://github.com/ripl/diffusion-for-shared-autonomy/blob/main/diffusha/data_collection/env/lunar_lander.py"""

__credits__ = ["Andrea PIERRÉ"]

import math
import sys
from typing import Optional
import random   # BM: needed to subsample goal space

import numpy as np

import numpy as np
import pygame
from pygame import gfxdraw


import Box2D
from Box2D.b2 import (
    edgeShape,
    circleShape,
    fixtureDef,
    polygonShape,
    revoluteJointDef,
    contactListener,
)

import gym
import gymnasium
from gym import error, spaces
from gym.utils import seeding, EzPickle

# Simulation rate (frames per second) and the divisor that converts the
# pixel-sized constants below into physics-world units.
FPS = 50
SCALE = 30.0  # affects how fast-paced the game is, forces should be adjusted as well

# Multipliers applied to the engine impulses in ``step``.
MAIN_ENGINE_POWER = 13.0
SIDE_ENGINE_POWER = 0.6

# Bound of the random force applied to the lander on reset.
INITIAL_RANDOM = 1000.0  # Set 1500 to make game harder

# Lander hull outline; every coordinate is divided by SCALE when the body is built.
LANDER_POLY = [
    (-14, +17),
    (-17, 0),
    (-17, -10),
    (+17, -10),
    (+17, 0),
    (+14, +17),
]

# Leg geometry and the motor torque of the leg joints.
LEG_AWAY = 20
LEG_DOWN = 18
LEG_W = 2
LEG_H = 8
LEG_SPRING_TORQUE = 40

# Radius of the target circle used by the 'reach' task.
TARGET_R = 25

# Offsets used to place the side-engine impulse on the hull.
SIDE_ENGINE_HEIGHT = 14.0
SIDE_ENGINE_AWAY = 12.0

# Window size in pixels.
VIEWPORT_W = 600
VIEWPORT_H = 400


class ContactDetector(contactListener):
    """Contact listener that mirrors collisions onto the owning environment.

    A contact involving the lander body sets ``env.game_over``; contacts
    involving a leg toggle that leg's ``ground_contact`` flag.
    """

    def __init__(self, env):
        contactListener.__init__(self)
        self.env = env

    def _flag_legs(self, contact, touching):
        """Store *touching* on each of the two legs that takes part in *contact*."""
        for idx in (0, 1):
            leg = self.env.legs[idx]
            if leg in [contact.fixtureA.body, contact.fixtureB.body]:
                leg.ground_contact = touching

    def BeginContact(self, contact):
        hull = self.env.lander
        if hull == contact.fixtureA.body or hull == contact.fixtureB.body:
            # The hull itself touched something: the episode is lost.
            self.env.game_over = True
        self._flag_legs(contact, True)

    def EndContact(self, contact):
        self._flag_legs(contact, False)


class LunarLander(gym.Env, EzPickle):
    """
    ### Description
    This environment is a classic rocket trajectory optimization problem.
    According to Pontryagin's maximum principle, it is optimal to fire the
    engine at full throttle or turn it off. This is the reason why this
    environment has discrete actions: engine on or off.

    There are two environment versions: discrete or continuous.
    The landing pad is always at coordinates (0,0). The coordinates are the
    first two numbers in the state vector.
    Landing outside of the landing pad is possible. Fuel is infinite, so an agent
    can learn to fly and then land on its first attempt.

    To see a heuristic landing, run:
    ```
    python gym/envs/box2d/lunar_lander.py
    ```
    <!-- To play yourself, run: -->
    <!-- python examples/agents/keyboard_agent.py LunarLander-v2 -->

    ### Action Space
    There are four discrete actions available: do nothing, fire left
    orientation engine, fire main engine, fire right orientation engine.

    ### Observation Space
    There are 8 states: the coordinates of the lander in `x` & `y`, its linear
    velocities in `x` & `y`, its angle, its angular velocity, and two booleans
    that represent whether each leg is in contact with the ground or not.

    ### Rewards
    Reward for moving from the top of the screen to the landing pad and coming
    to rest is about 100-140 points.
    If the lander moves away from the landing pad, it loses reward.
    If the lander crashes, it receives an additional -100 points. If it comes
    to rest, it receives an additional +100 points. Each leg with ground
    contact is +10 points.
    Firing the main engine is -0.3 points each frame. Firing the side engine
    is -0.03 points each frame. Solved is 200 points.

    ### Starting State
    The lander starts at the top center of the viewport with a random initial
    force applied to its center of mass.

    ### Episode Termination
    The episode finishes if:
    1) the lander crashes (the lander body gets in contact with the moon);
    2) the lander gets outside of the viewport (the absolute value of the normalized `x` coordinate reaches 1);
    3) the lander is not awake. From the [Box2D docs](https://box2d.org/documentation/md__d_1__git_hub_box2d_docs_dynamics.html#autotoc_md61),
        a body which is not awake is a body which doesn't move and doesn't
        collide with any other body:
    > When Box2D determines that a body (or group of bodies) has come to rest,
    > the body enters a sleep state which has very little CPU overhead. If a
    > body is awake and collides with a sleeping body, then the sleeping body
    > wakes up. Bodies will also wake up if a joint or contact attached to
    > them is destroyed.

    ### Arguments
    To use the _continuous_ environment, specify the
    `continuous=True` argument like below:
    ```python
    import gym
    env = gym.make("LunarLander-v2", continuous=True)
    ```

    ### Version History
    - v2: Count energy spent
    - v1: Legs contact with ground added in state vector; contact with ground
        give +10 reward points, and -10 if then lose contact; reward
        renormalized to 200; harder initial random push.
    - v0: Initial version

    <!-- ### References -->

    ### Credits
    Created by Oleg Klimov
    """

    metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": FPS}

    def __init__(self, continuous: bool = False, fuel_penalty: bool = True, task: str = 'land', randomize_helipad: bool = False, spec = None, N=10, exploring_starts: bool = False):
        """Configure the environment; physics bodies are only created by ``reset``.

        Args:
            continuous: If True the action space is a 2-float ``Box`` in
                [-1, 1]; otherwise it is ``Discrete(4)``.
            fuel_penalty: Stored as ``self.fuel_penalty``; ``step`` subtracts
                engine usage from the reward when it is True.
            task: One of ``'reach'``, ``'land'`` or ``'float'``. ``'reach'``
                adds two entries (the target position) to the observation.
            randomize_helipad: If True, ``reset`` draws the helipad chunk from
                the goal subset and one extra observation entry is added.
            spec: Stored unchanged as ``self.spec``.
            N: Upper bound on the number of helipad chunks kept in the goal
                subset (see ``initialize_goal_space``).
            exploring_starts: If True, ``reset`` randomizes the initial pose.

        Raises:
            AssertionError: If ``task`` is not one of the supported names.
        """
        # EzPickle re-creates the object from the arguments recorded here, so
        # they must be passed on; with none recorded, a pickled or copied
        # environment would silently come back with the default configuration.
        # The goal subset is re-sampled when the object is rebuilt.
        EzPickle.__init__(
            self,
            continuous=continuous,
            fuel_penalty=fuel_penalty,
            task=task,
            randomize_helipad=randomize_helipad,
            spec=spec,
            N=N,
            exploring_starts=exploring_starts,
        )
        self.screen = None
        self.isopen = True
        self.world = Box2D.b2World()
        self.moon = None
        self.lander = None
        self.particles = []
        self.sky_polys = []  # BM: render_mode="human" causes crash w/o line
        self.drawlist = []   # BM: render_mode="human" causes crash w/o line
        self._exploring_starts = exploring_starts  # BM: added exploring starts for policy rollouts

        self.prev_reward = None

        self.continuous = continuous
        # BM: added to allow masking of goal info when collecting expert demonstrations
        # (eight True entries followed by a single False entry)
        self.goal_mask = np.array([True] * 8 + [False])
        # BM: used to determine if trial was successful
        self.solved_reward = 200

        # BM: randomly constrains the goal space to N goals
        self._N = N
        self.initialize_goal_space()

        # Eight base entries, plus the target position for 'reach' and the
        # helipad position when the helipad is randomized.
        _state_size = 8
        if task == 'reach':
            _state_size += 2
        if randomize_helipad:
            _state_size += 1

        # useful range is -1 .. +1, but spikes can be higher
        self.observation_space = spaces.Box(
            -np.inf, np.inf, shape=(_state_size,), dtype=np.float32
        )

        if self.continuous:
            # Action is two floats [main engine, left-right engines].
            # Main engine: -1..0 off, 0..+1 throttle from 50% to 100% power. Engine can't work with less than 50% power.
            # Left-right:  -1.0..-0.5 fire left engine, +0.5..+1.0 fire right engine, -0.5..0.5 off
            self.action_space = spaces.Box(-1, +1, (2,), dtype=np.float32)
        else:
            # Nop, fire left engine, main engine, right engine
            self.action_space = spaces.Discrete(4)

        self.fuel_penalty = fuel_penalty

        print('LunarLander\ttask:', task)
        assert task in ['reach', 'land', 'float'], f'Unknown task: {task}'
        self.task = task
        self.randomize_helipad = randomize_helipad

        self.spec = spec

    def initialize_goal_space(self):
        CHUNKS = 11
        self._chunk_options = [_ for _ in range(1, CHUNKS-1)]
        if self._N < len(self._chunk_options):
            self._chunk_options = random.sample(self._chunk_options, self._N)

    def insert_goals(self, goal_agnostic_obs):
        """_summary_

        Args:
            goal_agnostic_obs (_type_): _description_
        """        
        GOAL_AGNOSTIC_OBS_DIM = 8
        OBS_DIM = 9
        CHUNKS = 11
        W = 20

        helipad_chunks = self._chunk_options


        chunk_x = [W / (CHUNKS - 1) * i for i in range(CHUNKS)]
        num_goals, goal_dim = len(helipad_chunks), 1
        helipad_x1 = np.array(chunk_x)[np.array(self._chunk_options)-1]
        helipad_x2 = np.array(chunk_x)[np.array(self._chunk_options)+1]
        helipad_x = (helipad_x1 + helipad_x2) / 2
        helipad_x = (helipad_x - VIEWPORT_W/SCALE/2) / (VIEWPORT_W/SCALE/2)

        goal_agnostic_obs = goal_agnostic_obs.reshape(1, GOAL_AGNOSTIC_OBS_DIM)
        goal_agnostic_obs = goal_agnostic_obs.repeat(num_goals, axis=0)

        observations = np.hstack([goal_agnostic_obs, helipad_x.reshape(num_goals, goal_dim)])
        return observations

    def _destroy(self):
        if not self.moon:
            return
        self.world.contactListener = None
        self._clean_particles(True)
        self.world.DestroyBody(self.moon)
        self.moon = None
        self.world.DestroyBody(self.lander)
        self.lander = None
        self.world.DestroyBody(self.legs[0])
        self.world.DestroyBody(self.legs[1])

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Rebuild terrain, lander and legs, then take one no-op step.

        Args:
            seed: When given, ``self.np_random`` is replaced by
                ``np.random.default_rng(seed)`` after the parent reset.
            return_info: If True, return ``(observation, {})`` instead of the
                bare observation.
            options: Accepted for interface compatibility; not used here.

        Returns:
            The observation produced by the initial no-op ``step`` call, or a
            tuple of that observation and an empty dict when ``return_info``
            is True.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.np_random = np.random.default_rng(seed=seed)
        self._destroy()
        # Keep a reference on the world so the listener is not garbage collected.
        self.world.contactListener_keepref = ContactDetector(self)
        self.world.contactListener = self.world.contactListener_keepref
        self.game_over = False
        self.prev_shaping = None

        W = VIEWPORT_W / SCALE
        H = VIEWPORT_H / SCALE

        # terrain
        CHUNKS = 11
        height = self.np_random.uniform(0, H / 2, size=(CHUNKS + 1,))
        chunk_x = [W / (CHUNKS - 1) * i for i in range(CHUNKS)]

        # randomize helipad x coord
        if self.randomize_helipad:
            helipad_chunk = self.np_random.choice(self._chunk_options)
        else:
            helipad_chunk = CHUNKS // 2

        self.helipad_x1 = chunk_x[helipad_chunk - 1]
        self.helipad_x2 = chunk_x[helipad_chunk + 1]
        self.helipad_y = H / 4
        # Flatten the five height samples around the helipad chunk.
        for offset in (-2, -1, 0, 1, 2):
            height[helipad_chunk + offset] = self.helipad_y
        smooth_y = [
            0.33 * (height[i - 1] + height[i + 0] + height[i + 1])
            for i in range(CHUNKS)
        ]

        if self.task == 'reach':
            # Target
            radius = TARGET_R / SCALE
            # exclude middle half of the screen
            self.target_cx = self.np_random.uniform(radius, W/2. - radius)
            if self.target_cx >= W / 4.:
                self.target_cx += W / 2.
            self.target_cy = self.np_random.uniform(H/2 + radius, H - radius)

        self.moon = self.world.CreateStaticBody(
            shapes=edgeShape(vertices=[(0, 0), (W, 0)])
        )
        self.sky_polys = []
        for i in range(CHUNKS - 1):
            p1 = (chunk_x[i], smooth_y[i])
            p2 = (chunk_x[i + 1], smooth_y[i + 1])
            self.moon.CreateEdgeFixture(vertices=[p1, p2], density=0, friction=0.1)
            self.sky_polys.append([p1, p2, (p2[0], H), (p1[0], H)])

        self.moon.color1 = (0.0, 0.0, 0.0)
        self.moon.color2 = (0.0, 0.0, 0.0)

        # BM: added exploring starts
        initial_x = VIEWPORT_W / SCALE / 2
        initial_y = VIEWPORT_H / SCALE
        angle = 0.0  # radians

        if self._exploring_starts:
            # Draw from the environment's own generator (not the global
            # np.random state) so that reset(seed=...) reproduces the start pose.
            # NOTE(review): the x multiplier spans [-1.5, 1.5), so about half of
            # the starts land at x <= 0, which `step` treats as out of bounds.
            # Confirm whether a range inside the viewport was intended.
            initial_x *= self.np_random.uniform(-1.5, 1.5)
            initial_y *= self.np_random.uniform(0.5, 1.0)
            angle = 3.14 * self.np_random.uniform(-0.5, 0.5)

        self.lander = self.world.CreateDynamicBody(
            position=(initial_x, initial_y),
            angle=angle,
            fixtures=fixtureDef(
                shape=polygonShape(
                    vertices=[(x / SCALE, y / SCALE) for x, y in LANDER_POLY]
                ),
                density=5.0,
                friction=0.1,
                categoryBits=0x0010,
                maskBits=0x001,  # collide only with ground
                restitution=0.0,
            ),  # 0.99 bouncy
        )
        self.lander.color1 = (128, 102, 230)
        self.lander.color2 = (77, 77, 128)
        self.lander.ApplyForceToCenter(
            (
                self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM),
                self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM),
            ),
            True,
        )

        self.legs = []
        for i in [-1, +1]:
            leg = self.world.CreateDynamicBody(
                position=(initial_x - i * LEG_AWAY / SCALE, initial_y),
                angle=angle+(i * 0.05),
                fixtures=fixtureDef(
                    shape=polygonShape(box=(LEG_W / SCALE, LEG_H / SCALE)),
                    density=1.0,
                    restitution=0.0,
                    categoryBits=0x0020,
                    maskBits=0x001,
                ),
            )
            leg.ground_contact = False
            leg.color1 = (128, 102, 230)
            leg.color2 = (77, 77, 128)
            rjd = revoluteJointDef(
                bodyA=self.lander,
                bodyB=leg,
                localAnchorA=(0, 0),
                localAnchorB=(i * LEG_AWAY / SCALE, LEG_DOWN / SCALE),
                enableMotor=True,
                enableLimit=True,
                maxMotorTorque=LEG_SPRING_TORQUE,
                motorSpeed=+0.3 * i,  # low enough not to jump back into the sky
            )
            if i == -1:
                rjd.lowerAngle = (
                    +0.9 - 0.5
                )  # The most esoteric numbers here, angled legs have freedom to travel within
                rjd.upperAngle = +0.9
            else:
                rjd.lowerAngle = -0.9
                rjd.upperAngle = -0.9 + 0.5
            leg.joint = self.world.CreateJoint(rjd)
            self.legs.append(leg)

        self.drawlist = [self.lander] + self.legs

        # One no-op step yields the first observation.
        first_obs = self.step(np.array([0, 0]) if self.continuous else 0)[0]
        if return_info:
            return first_obs, {}
        return first_obs

    def _create_particle(self, mass, x, y, ttl):
        """Spawn one exhaust particle at ``(x, y)`` and return its body.

        ``mass`` is used as the fixture density and ``ttl`` is stored on the
        body. The new body is appended to ``self.particles``, after which
        expired particles at the front of that list are removed.
        """
        puff_fixture = fixtureDef(
            shape=circleShape(radius=2 / SCALE, pos=(0, 0)),
            density=mass,
            friction=0.1,
            categoryBits=0x0100,
            maskBits=0x001,  # collide only with ground
            restitution=0.3,
        )
        puff = self.world.CreateDynamicBody(
            position=(x, y),
            angle=0.0,
            fixtures=puff_fixture,
        )
        puff.ttl = ttl
        self.particles.append(puff)
        self._clean_particles(False)
        return puff

    def _clean_particles(self, all):
        while self.particles and (all or self.particles[0].ttl < 0):
            self.world.DestroyBody(self.particles.pop(0))

    def step(self, action):
        """Apply ``action``, advance the physics by one frame and score the result.

        Args:
            action: In continuous mode a pair ``[main, lateral]`` that is
                clipped to [-1, 1]; otherwise an integer accepted by
                ``self.action_space`` (2 fires the main engine, 1 and 3 fire
                the orientation engines).

        Returns:
            A tuple ``(observation, reward, done, info)``. The observation is
            a ``np.float32`` array with eight base entries, plus the target
            position for the 'reach' task and the helipad x coordinate when
            the helipad is randomized. ``info`` always contains
            ``'game_over_reason'`` and may contain ``'goal'`` and ``'crashed'``.
        """
        if self.continuous:
            action = np.clip(action, -1, +1).astype(np.float32)
        else:
            assert self.action_space.contains(
                action
            ), f"{action!r} ({type(action)}) invalid "

        # Engines
        # `tip` and `side` are derived from the lander angle and used to place
        # the engine impulses; `dispersion` adds random jitter to them.
        tip = (math.sin(self.lander.angle), math.cos(self.lander.angle))
        side = (-tip[1], tip[0])
        dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)]

        m_power = 0.0
        if (self.continuous and action[0] > 0.0) or (
            not self.continuous and action == 2
        ):
            # Main engine
            if self.continuous:
                m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5  # 0.5..1.0
                assert m_power >= 0.5 and m_power <= 1.0
            else:
                m_power = 1.0
            ox = (
                tip[0] * (4 / SCALE + 2 * dispersion[0]) + side[0] * dispersion[1]
            )  # 4 is move a bit downwards, +-2 for randomness
            oy = -tip[1] * (4 / SCALE + 2 * dispersion[0]) - side[1] * dispersion[1]
            impulse_pos = (self.lander.position[0] + ox, self.lander.position[1] + oy)
            p = self._create_particle(
                3.5,  # 3.5 is here to make particle speed adequate
                impulse_pos[0],
                impulse_pos[1],
                m_power,
            )  # particles are just a decoration
            # The particle and the lander receive opposite impulses.
            p.ApplyLinearImpulse(
                (ox * MAIN_ENGINE_POWER * m_power, oy * MAIN_ENGINE_POWER * m_power),
                impulse_pos,
                True,
            )
            self.lander.ApplyLinearImpulse(
                (-ox * MAIN_ENGINE_POWER * m_power, -oy * MAIN_ENGINE_POWER * m_power),
                impulse_pos,
                True,
            )

        s_power = 0.0
        if (self.continuous and np.abs(action[1]) > 0.5) or (
            not self.continuous and action in [1, 3]
        ):
            # Orientation engines
            if self.continuous:
                direction = np.sign(action[1])
                s_power = np.clip(np.abs(action[1]), 0.5, 1.0)
                assert s_power >= 0.5 and s_power <= 1.0
            else:
                # action 1 -> direction -1, action 3 -> direction +1
                direction = action - 2
                s_power = 1.0
            ox = tip[0] * dispersion[0] + side[0] * (
                3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
            )
            oy = -tip[1] * dispersion[0] - side[1] * (
                3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE
            )
            impulse_pos = (
                self.lander.position[0] + ox - tip[0] * 17 / SCALE,
                self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE,
            )
            p = self._create_particle(0.7, impulse_pos[0], impulse_pos[1], s_power)
            p.ApplyLinearImpulse(
                (ox * SIDE_ENGINE_POWER * s_power, oy * SIDE_ENGINE_POWER * s_power),
                impulse_pos,
                True,
            )
            self.lander.ApplyLinearImpulse(
                (-ox * SIDE_ENGINE_POWER * s_power, -oy * SIDE_ENGINE_POWER * s_power),
                impulse_pos,
                True,
            )

        # Advance the simulation by one frame (1 / FPS seconds).
        self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)

        pos = self.lander.position
        vel = self.lander.linearVelocity
        helipad_x = (self.helipad_x1 + self.helipad_x2) / 2

        # Base observation: x is measured from the viewport centre, y from the
        # height at which the legs would rest on the helipad.
        state = [
            (pos.x - VIEWPORT_W / SCALE / 2) / (VIEWPORT_W / SCALE / 2),
            (pos.y - (self.helipad_y + LEG_DOWN / SCALE)) / (VIEWPORT_H / SCALE / 2),
            vel.x * (VIEWPORT_W / SCALE / 2) / FPS,
            vel.y * (VIEWPORT_H / SCALE / 2) / FPS,
            self.lander.angle,
            20.0 * self.lander.angularVelocity / FPS,
            1.0 if self.legs[0].ground_contact else 0.0,
            1.0 if self.legs[1].ground_contact else 0.0,
            # (helipad_x - VIEWPORT_W/SCALE/2) / (VIEWPORT_W/SCALE/2)
        ]

        if self.task == 'reach':
            # Append the target position, normalized like the lander position.
            state += [
                (self.target_cx - VIEWPORT_W / SCALE / 2) / (VIEWPORT_W / SCALE / 2),
                (self.target_cy - (self.helipad_y + LEG_DOWN / SCALE)) / (VIEWPORT_H / SCALE / 2),
            ]

        if self.randomize_helipad:
            # Append the normalized helipad centre.
            state += [
                (helipad_x - VIEWPORT_W/SCALE/2) / (VIEWPORT_W/SCALE/2)
            ]


        reward = 0

        # Shaping potential: penalizes distance to the task goal, speed and
        # tilt; the 'reach' and 'land' tasks also reward leg contact.
        if self.task == 'reach':
            _dx = (pos.x - self.target_cx) / (VIEWPORT_W / SCALE / 2)
            _dy = (pos.y - self.target_cy) / (VIEWPORT_H / SCALE / 2)
            shaping = (
                - 100*np.sqrt(_dx*_dx + _dy*_dy) * 2
                - 100*np.sqrt(state[2]*state[2] + state[3]*state[3])
                - 100*abs(state[4])
                + 10*state[6]
                + 10*state[7]   # And ten points for legs contact, the idea is if you
                # lose contact again after landing, you get negative reward
            )
        elif self.task == 'float':
            shaping = (
                -100 * np.sqrt(state[0] * state[0] + state[1] * state[1])
                - 100 * np.sqrt(state[2] * state[2] + state[3] * state[3])
                - 100 * abs(state[4])
            )
        else:
            # 'land': horizontal distance is measured to the helipad centre.
            dx = (pos.x - helipad_x) / (VIEWPORT_W/SCALE/2)
            shaping = (
                -100 * np.sqrt(dx * dx + state[1] * state[1])
                - 100 * np.sqrt(state[2] * state[2] + state[3] * state[3])
                - 100 * abs(state[4])
                + 10 * state[6]
                + 10 * state[7]
            )  # And ten points for legs contact, the idea is if you
            # lose contact again after landing, you get negative reward
        # The reward is the change in shaping since the previous step; on the
        # first step after a reset there is no previous value and it stays 0.
        if self.prev_shaping is not None:
            reward = shaping - self.prev_shaping
        self.prev_shaping = shaping

        if self.fuel_penalty:
            reward -= (
                m_power * 0.30
            )  # less fuel spent is better, about -30 for heuristic landing
            reward -= s_power * 0.03


        # Copied from deepassist
        oob = abs(state[0]) >= 1.0
        not_awake = not self.lander.awake
        at_site = self.helipad_x1 <= pos.x and pos.x <= self.helipad_x2

        info = {'game_over_reason': ''}
        done = False

        # Task-specific termination
        if self.task == 'reach':
            _dx, _dy = pos.x - self.target_cx, pos.y - self.target_cy
            if np.sqrt(_dx*_dx + _dy*_dy) < TARGET_R / SCALE:
                # print('target is reached')
                # target is reached
                done = True
                reward = +1000
                self.lander.color1 = (0, 255, 0)
                info['game_over_reason'] = 'reach-goal-reached'
                info['goal'] = 'target-reached'

        elif self.task == 'land':
            if not_awake:
                # print('lander not awake')
                done = True
                info['game_over_reason'] = 'land-lander-not-awake'
                if at_site:
                    info['goal'] = 'landed'
                    reward = +1000

        # A crash or leaving the viewport overrides any reward set above.
        if self.game_over or oob:
            # print('game over', state[0])
            done = True
            reward = -100
            self.lander.color1 = (255, 0, 0)
            if oob:
                info['game_over_reason'] = 'abs(state[0])>=1.0'
            elif self.game_over:
                info['game_over_reason'] = 'body-contact'
            self.game_over = True
            info['crashed'] = True

        return np.array(state, dtype=np.float32), reward, done, info

    def render(self, mode="human"):
        """Draw the current scene and return it.

        The pygame display is created on the first call. Each call also ages
        the exhaust particles (their ``ttl`` drops by 0.15) and removes the
        expired ones.

        Args:
            mode: ``"human"`` updates the window and returns the frame;
                ``"rgb_array"`` only returns the frame.

        Returns:
            For ``"human"`` and ``"rgb_array"`` an array built from the
            surface pixels with the first two axes swapped; for any other
            mode, ``self.isopen``.
        """
        if self.screen is None:
            pygame.display.init()
            pygame.font.init()
            self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))

        self.surf = pygame.Surface(self.screen.get_size())
        pygame.draw.rect(self.surf, (255, 255, 255), self.surf.get_rect())

        for obj in self.particles:
            obj.ttl -= 0.15
            faded = (
                int(max(0.2, 0.15 + obj.ttl) * 255),
                int(max(0.2, 0.5 * obj.ttl) * 255),
                int(max(0.2, 0.5 * obj.ttl) * 255),
            )
            obj.color1 = faded
            obj.color2 = faded

        self._clean_particles(False)

        for p in self.sky_polys:
            scaled_poly = [(coord[0] * SCALE, coord[1] * SCALE) for coord in p]
            pygame.draw.polygon(self.surf, (0, 0, 0), scaled_poly)
            gfxdraw.aapolygon(self.surf, scaled_poly, (0, 0, 0))

        for obj in self.particles + self.drawlist:
            for f in obj.fixtures:
                trans = f.body.transform
                if type(f.shape) is circleShape:
                    pygame.draw.circle(
                        self.surf,
                        color=obj.color1,
                        center=trans * f.shape.pos * SCALE,
                        radius=f.shape.radius * SCALE,
                    )
                    pygame.draw.circle(
                        self.surf,
                        color=obj.color2,
                        center=trans * f.shape.pos * SCALE,
                        radius=f.shape.radius * SCALE,
                    )
                else:
                    path = [trans * v * SCALE for v in f.shape.vertices]
                    pygame.draw.polygon(self.surf, color=obj.color1, points=path)
                    gfxdraw.aapolygon(self.surf, path, obj.color1)
                    pygame.draw.aalines(
                        self.surf, color=obj.color2, points=path, closed=True
                    )

        # Helipad and target positions only exist once `reset` has built the
        # terrain, so skip them when rendering before the first reset.
        terrain_ready = self.moon is not None

        if terrain_ready:
            # Draw the two helipad flags once, on top of every body (they used
            # to be redrawn after each fixture).
            for x in [self.helipad_x1, self.helipad_x2]:
                x = x * SCALE
                flagy1 = self.helipad_y * SCALE
                flagy2 = flagy1 + 50
                flag_points = [(x, flagy2), (x, flagy2 - 10), (x + 25, flagy2 - 5)]
                pygame.draw.line(
                    self.surf,
                    color=(255, 255, 255),
                    start_pos=(x, flagy1),
                    end_pos=(x, flagy2),
                    width=1,
                )
                pygame.draw.polygon(
                    self.surf,
                    color=(204, 204, 0),
                    points=flag_points,
                )
                gfxdraw.aapolygon(self.surf, flag_points, (204, 204, 0))

        if self.task == 'reach' and terrain_ready:
            # draw target
            pygame.draw.circle(self.surf, color=(255, 0, 0), center=(self.target_cx * SCALE, self.target_cy * SCALE), radius=TARGET_R)

        # World coordinates have y pointing up; flip for the screen.
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))

        if mode == "human":
            pygame.display.flip()

        if mode in ("human", "rgb_array"):
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
            )
        return self.isopen

    def close(self):
        if self.screen is not None:
            pygame.quit()
            self.isopen = False


def heuristic(env, s):
    """
    Hand-written controller used for testing and for demonstration rollouts.

    Args:
        env: The environment; only ``env.continuous`` is read.
        s (list): The state. Attributes:
                  s[0] is the horizontal coordinate
                  s[1] is the vertical coordinate
                  s[2] is the horizontal speed
                  s[3] is the vertical speed
                  s[4] is the angle
                  s[5] is the angular speed
                  s[6] 1 if first leg has contact, else 0
                  s[7] 1 if second leg has contact, else 0
    returns:
         a: The action to feed into ``env.step``: a clipped 2-element array
            in continuous mode, otherwise an integer in 0..3.
    """
    x_pos, y_pos = s[0], s[1]
    x_vel, y_vel = s[2], s[3]
    tilt, spin = s[4], s[5]

    # Lean towards the centre, but never by more than 0.4 radians (22 degrees).
    desired_tilt = x_pos * 0.5 + x_vel * 1.0
    if desired_tilt > 0.4:
        desired_tilt = 0.4
    elif desired_tilt < -0.4:
        desired_tilt = -0.4

    # Target height is proportional to the horizontal offset.
    desired_height = 0.55 * np.abs(x_pos)

    if s[6] or s[7]:
        # Legs have contact: only reduce the fall speed from here on.
        tilt_fix = 0
        lift_fix = -y_vel * 0.5
    else:
        tilt_fix = (desired_tilt - tilt) * 0.5 - spin * 1.0
        lift_fix = (desired_height - y_pos) * 0.5 - y_vel * 0.5

    if env.continuous:
        return np.clip(np.array([lift_fix * 20 - 1, -tilt_fix * 20]), -1, +1)

    if lift_fix > np.abs(tilt_fix) and lift_fix > 0.05:
        return 2
    if tilt_fix < -0.05:
        return 3
    if tilt_fix > +0.05:
        return 1
    return 0


def demo_heuristic_lander(env, seed=None, render=False):
    """Run one episode driven by ``heuristic`` and return the summed reward.

    Args:
        env: Environment providing ``reset(seed=...)``, ``step``, ``render``
            and ``close``.
        seed: Passed to ``env.reset``.
        render: If True, ``env.render()`` is called after every step and
            ``env.close()`` once the episode ends.

    Returns:
        The sum of the rewards collected during the episode.
    """
    total_reward = 0
    steps = 0
    s = env.reset(seed=seed)
    while True:
        a = heuristic(env, s)
        s, r, done, info = env.step(a)
        total_reward += r

        if render:
            still_open = env.render()
            # `render` may return an image array; comparing an array with
            # `== False` and using it in `if` raises ValueError, so test for
            # the literal False by identity instead.
            if still_open is False:
                break

        if steps % 20 == 0 or done:
            print("observations:", " ".join([f"{x:+0.2f}" for x in s]))
            print(f"step {steps} total_reward {total_reward:+0.2f}")
        steps += 1
        if done:
            break
    if render:
        env.close()
    return total_reward


class LunarLanderContinuous:
    """Placeholder class: constructing it always raises ``error.Error``."""

    def __init__(self):
        message_lines = (
            "Error initializing LunarLanderContinuous Environment.",
            "Currently, we do not support initializing this mode of environment by calling the class directly.",
            "To use this environment, instead create it by specifying the continuous keyword in gym.make, i.e.",
            'gym.make("LunarLander-v2", continuous=True)',
        )
        raise error.Error("\n".join(message_lines))

