import cv2
import glob
import gym
import math
import numpy as np
import random
import time
from gym import spaces
from matplotlib import pyplot as plt
from matplotlib import transforms as mtransforms
from PIL import Image, ImageDraw, ImageFont

from rlm import utils

class MowerEnv(gym.Env):
    """
    TODO: add combined map which is the sum of coverage and obstacles
    TODO: edge map as input feature, i.e. highlight pixels where there is change in coverage
    TODO: add detection map (for unknown obstacles)
    TODO: refactor, add as much as possible to individual functions, and add unit tests for all

    An environment where the goal is to cover as much of an area as possible.
    The environment includes known and unknown obstacles that should be avoided.
    The input to the agent is a multi-scale map of the coverage and known obstacles.
    The maps are translated and rotated so that the mid-point of the maps
    correspond to the mowers position, and up is forward.
    The input also consists of a lidar sensor with distance measurements, which
    can sense the known and unknown obstacles.
    Multiple observations can be stacked.
    The actions that the agent can make are continuous throttle and steering.

    :param input_size: Size of the input maps [pixels], same resolution for all scales
    :param num_maps: Number of maps, i.e. scales
    :param scale_factor: Scale factor between each subsequent scale
    :param meters_per_pixel: World map resolution in meters per pixel
    :param min_size: Minimum size of free space for random maps [m]
    :param max_size: Maximum size of free space for random maps [m]
    :param stacks: Number of subsequent observations to stack
    :param step_size: Time per step [s]
    :param constant_lin_vel: Whether to use a constant linear velocity, i.e. agent only predicts steering
    :param max_lin_vel: Maximum linear velocity [m/s]
    :param max_ang_vel: Maximum angular velocity [rad/s]
    :param steering_limits_lin_vel: Whether to suppress linear velocity based on steering angle
    :param mower_radius: Radius of cutting disc [m]
    :param lidar_rays: Number of lidar rays
    :param lidar_range: Range of the lidar sensor [m]
    :param lidar_fov: Lidar field of view [degrees]
    :param position_noise: Standard deviation of position noise [m]
    :param exploration: Whether coverage is defined by lidar detections (True), or robot extent (False)
    :param overlap_observation: Whether to use overlap (True) or binary coverage (False) in the map observation
    :param frontier_observation: Whether to use a frontier map as observation
    :param eval: Eval mode, fixed eval maps, no training map progression, overrides next 6 parameters + some more
    :param p_use_known_obstacles: Probability per episode to randomly scatter known obstacles
    :param p_use_unknown_obstacles: Probability per episode to randomly scatter unknown obstacles
    :param p_use_floor_plans: Probability per episode to generate random floor plans
    :param max_known_obstacles: Maximum number of known obstacles
    :param max_unknown_obstacles: Maximum number of unknown obstacles
    :param obstacle_radius: Radius of (circular) obstacles [m]
    :param all_unknown: Whether all obstacles and map geometry should be unknown, overrides padding parameters
    :param max_episode_steps: Maximum episode length
    :param max_non_new_steps: Maximum steps in a row which did not cover new ground
    :param collision_ends_episode: Whether a collision ends the current episode
    :param flip_when_stuck: Whether to turn the agent 180 degrees when it becomes stuck
    :param max_stuck_steps: Consecutive steps being stuck before flipping the agent
    :param start_level: Starting level for progressive training
    :param use_goal_time_in_levels: Whether to use time as a criteria for completing a level
    :param goal_coverage: Target coverage considered for task completion
    :param goal_coverage_reward: Reward for reaching the target coverage
    :param wall_collision_reward: Reward for wall collision
    :param obstacle_collision_reward: Reward for obstacle collision
    :param newly_visited_reward_scale: Reward multiplier when covering new ground
    :param newly_visited_reward_max: Maximum reward from covering new ground
    :param overlap_reward_scale: Reward multiplier for overlapping a previously covered area
    :param overlap_reward_max: Maximum overlap reward
    :param overlap_reward_always: Whether to always use overlap reward, or only when the coverage reward is 0
    :param local_tv_reward_scale: Reward multiplier for local (incremental) total variation
    :param local_tv_reward_max: Maximum reward from local (incremental) total variation
    :param global_tv_reward_scale: Reward multiplier for global total variation
    :param global_tv_reward_max: Maximum reward from global total variation
    :param use_known_obstacles_in_tv: Whether to include known obstacles when computing total variation
    :param use_unknown_obstacles_in_tv: Whether to include unknown obstacles when computing total variation
    :param obstacle_dilation: Dilation kernel size around obstacles
    :param constant_reward: Constant reward at each time step
    :param truncation_reward_scale: Reward multiplier for episode truncation
    :param coverage_pad_value: Pad value for coverage maps outside environment borders (0 or 1)
    :param obstacle_pad_value: Pad value for obstacle maps outside environment borders (0 or 1)
    :param verbose: Whether to print reward/timing information
    """

    def __init__(
        self,
        input_size = 32,
        num_maps = 4,
        scale_factor = 4,
        meters_per_pixel = 0.0375,
        min_size = None,
        max_size = None,
        stacks = 1,
        step_size = 0.5,
        constant_lin_vel = True,
        max_lin_vel = 0.26,
        max_ang_vel = 1.0,
        steering_limits_lin_vel = True,
        mower_radius = 0.15,
        lidar_rays = 24,
        lidar_range = 3.5,
        lidar_fov = 180,
        position_noise = 0,
        exploration = False,
        overlap_observation = True,
        frontier_observation = True,
        eval = False,
        p_use_known_obstacles = 0.7,
        p_use_unknown_obstacles = 0.7,
        p_use_floor_plans = 0.7,
        max_known_obstacles = 100,
        max_unknown_obstacles = 100,
        obstacle_radius = 0.25,
        all_unknown = True,
        max_episode_steps = None,
        max_non_new_steps = 1000,
        collision_ends_episode = False,
        flip_when_stuck = False,
        max_stuck_steps = 5,
        start_level = 1,
        use_goal_time_in_levels = False,
        goal_coverage = 0.9,
        goal_coverage_reward = 0,
        wall_collision_reward = -10,
        obstacle_collision_reward = -10,
        newly_visited_reward_scale = 1,
        newly_visited_reward_max = 1,
        overlap_reward_scale = 0,
        overlap_reward_max = 5,
        overlap_reward_always = False,
        local_tv_reward_scale = 1,
        local_tv_reward_max = 5,
        global_tv_reward_scale = 0,
        global_tv_reward_max = 5,
        use_known_obstacles_in_tv = True,
        use_unknown_obstacles_in_tv = True,
        obstacle_dilation = 7,
        constant_reward = -0.1,
        truncation_reward_scale = 0,
        coverage_pad_value = 0,
        obstacle_pad_value = 1,
        verbose = False,
    ):
        """
        Configure the environment and prepare the first episode.

        Stores all configuration parameters (documented in the class
        docstring), derives pixel-space quantities, discovers the map files
        on disk, sets up the starting level (training mode only), initializes
        the first episode via ``_reset`` and defines the action and
        observation spaces.

        ``min_size``/``max_size`` are given in meters and converted to whole
        pixels here. The pixel values are later used as ``random.randint``
        bounds and as array dimensions in ``_reset``, and both require
        integers.

        :raises AssertionError: if both ``min_size`` and ``max_size`` are
            given and ``min_size > max_size``.
        """
        super().__init__()
        # Timestamp [ms]; presumably used for step timing in step()
        self.t1 = round(time.time() * 1000)

        # Environment parameters
        self.input_size = input_size
        self.num_maps = num_maps
        self.scale_factor = scale_factor
        self.meters_per_pixel = meters_per_pixel
        self.stacks = stacks
        self.step_size = step_size
        self.constant_lin_vel = constant_lin_vel
        self.max_lin_vel = max_lin_vel
        self.max_ang_vel = max_ang_vel
        self.steering_limits_lin_vel = steering_limits_lin_vel
        self.mower_radius = mower_radius
        self.lidar_rays = lidar_rays
        self.lidar_range = lidar_range
        self.lidar_fov = lidar_fov
        self.position_noise = position_noise
        self.exploration = exploration
        self.overlap_observation = overlap_observation
        self.frontier_observation = frontier_observation
        self.eval = eval
        self.p_use_known_obstacles = p_use_known_obstacles
        self.p_use_unknown_obstacles = p_use_unknown_obstacles
        self.p_use_floor_plans = p_use_floor_plans
        self.max_known_obstacles = max_known_obstacles
        self.max_unknown_obstacles = max_unknown_obstacles
        self.obstacle_radius = obstacle_radius
        self.all_unknown = all_unknown
        self.max_episode_steps = max_episode_steps
        self.max_non_new_steps = max_non_new_steps
        self.collision_ends_episode = collision_ends_episode
        self.flip_when_stuck = flip_when_stuck
        self.max_stuck_steps = max_stuck_steps
        self.start_level = start_level
        self.use_goal_time_in_levels = use_goal_time_in_levels
        self.goal_coverage = goal_coverage
        self.goal_coverage_reward = goal_coverage_reward
        self.wall_collision_reward = wall_collision_reward
        self.obstacle_collision_reward = obstacle_collision_reward
        self.newly_visited_reward_scale = newly_visited_reward_scale
        self.newly_visited_reward_max = newly_visited_reward_max
        self.overlap_reward_scale = overlap_reward_scale
        self.overlap_reward_max = overlap_reward_max
        self.overlap_reward_always = overlap_reward_always
        self.local_tv_reward_scale = local_tv_reward_scale
        self.local_tv_reward_max = local_tv_reward_max
        self.global_tv_reward_scale = global_tv_reward_scale
        self.global_tv_reward_max = global_tv_reward_max
        self.use_known_obstacles_in_tv = use_known_obstacles_in_tv
        self.use_unknown_obstacles_in_tv = use_unknown_obstacles_in_tv
        self.obstacle_dilation = obstacle_dilation
        self.constant_reward = constant_reward
        self.truncation_reward_scale = truncation_reward_scale
        self.coverage_pad_value = coverage_pad_value
        self.obstacle_pad_value = obstacle_pad_value
        self.verbose = verbose
        self.line_type = cv2.LINE_8

        # Derived parameters
        self.pixels_per_meter = 1 / meters_per_pixel
        # input_size scaled by the largest scale factor; presumably the extent
        # of the coarsest multi-scale map [pixels] / [meters]
        self.ms_reach_p = self.input_size * self.scale_factor ** (self.num_maps - 1)
        self.ms_reach_m = self.ms_reach_p * self.meters_per_pixel

        # Bounds on the random map size [pixels]. Sizes given in meters are
        # rounded to the nearest whole pixel, because they are used as
        # random.randint bounds and array dimensions (both require ints).
        if min_size is not None and max_size is not None:
            assert min_size <= max_size
        if min_size is not None:
            self.min_size_p = int(round(min_size * self.pixels_per_meter))
        elif exploration:
            self.min_size_p = 256
        else:
            self.min_size_p = 64
        if max_size is not None:
            self.max_size_p = int(round(max_size * self.pixels_per_meter))
            self.min_size_p = min(self.min_size_p, self.max_size_p)
        elif exploration:
            self.max_size_p = max(400, self.min_size_p)
        else:
            self.max_size_p = max(200, self.min_size_p)

        # Read map filenames. glob returns paths in an arbitrary,
        # filesystem-dependent order, so sort them to make the order in which
        # maps are cycled reproducible across machines.
        if exploration:
            self.eval_maps = sorted(glob.glob('maps/eval_exploration*'))
        else:
            self.eval_maps = sorted(glob.glob('maps/eval_mowing*'))
        self.train_maps_0 = sorted(glob.glob('maps/train_0_*'))
        self.train_maps_1 = sorted(glob.glob('maps/train_1_*'))
        self.train_maps_2 = sorted(glob.glob('maps/train_2_*'))
        self.train_maps_3 = sorted(glob.glob('maps/train_3_*'))
        self.train_maps_4 = sorted(glob.glob('maps/train_4_*'))
        self.train_maps_5 = sorted(glob.glob('maps/train_5_*'))

        # Additional variables, set up level, reset maps etc.
        self.axes = None
        self.current_episode = 0
        self.level = start_level
        self.next_train_map = 0
        # Levels (progressive training) only apply outside eval mode
        if not eval:
            self._set_level(self.level)
        self._reset()

        # Action space: steering only, or throttle + steering, all in [-1, 1]
        if constant_lin_vel:
            self.action_space = spaces.Box(low=-1, high=+1, shape=(1,), dtype=np.float32)
        else:
            self.action_space = spaces.Box(low=-1, high=+1, shape=(2,), dtype=np.float32)
        # Observation space: stacked multi-scale maps plus stacked lidar readings
        obs_shape = (self.stacks, self.num_maps, self.input_size, self.input_size)
        self.observation_space = spaces.Dict(
            coverage = spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32),
            obstacles = spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32),
            lidar = spaces.Box(low=0, high=1, shape=(self.stacks, self.lidar_rays), dtype=np.float32)
        )
        if frontier_observation:
            self.observation_space['frontier'] = spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32)

    def reset(self):
        """
        Begin a new episode.

        Increments the episode counter (which also selects the eval map in
        eval mode), re-initializes all per-episode state via ``_reset`` and
        returns the initial stacked observation.
        """
        self.current_episode = self.current_episode + 1
        self._reset()
        initial_observation = self._get_stacked_observation()
        return initial_observation

    def _reset(self):
        """
        Initialize all per-episode state.

        Performs these steps in order:

        - Resets the step/collision counters.
        - Builds or loads this episode's map: a fixed or randomly generated
          training map, or the next eval map in eval mode.
        - Samples a collision-free initial pose.
        - Runs an initial lidar scan.
        - Initializes the coverage/overlap/frontier maps, the coverage
          metrics, the stacked observation, the global total variation and
          the rendered path.

        Called from ``__init__`` and ``reset``.

        :raises AssertionError: if no collision-free start position is found
            within 100 attempts.
        """
        self.elapsed_steps = 0
        self.non_new_steps = 0
        self.elapsed_time = 0
        self.num_collisions = 0
        self.stuck_steps = 0
        self.filename = None

        # Create a training environment
        if not self.eval:
            self.current_map = None
            self.use_floor_plans = False
            self.use_known_obstacles = False
            self.use_unknown_obstacles = False

            # Load fixed training map: always when randomized environments are
            # disabled, otherwise with 50% probability. Fixed maps are cycled
            # in order. NOTE(review): use_randomized_envs and train_maps are
            # not set in __init__; presumably they come from _set_level.
            if not self.use_randomized_envs or random.uniform(0, 1) < 0.5:
                self.current_map = self.next_train_map
                self.next_train_map = (self.next_train_map + 1) % len(self.train_maps)
                self.filename = self.train_maps[self.current_map]
                self._load_map(self.filename)

            # Generate random map
            else:
                self.size_p = random.randint(self.min_size_p, self.max_size_p)
                self.size_m = self.size_p / self.pixels_per_meter
                self.known_obstacle_map = np.zeros((self.size_p, self.size_p), dtype=float)
                self.unknown_obstacle_map = np.zeros((self.size_p, self.size_p), dtype=float)
                self.use_floor_plans = random.uniform(0, 1) < self.p_use_floor_plans
                if self.use_floor_plans:
                    self._randomize_floor_plan()
                self.use_known_obstacles = self.max_known_obstacles > 0 and random.uniform(0, 1) < self.p_use_known_obstacles
                self.use_unknown_obstacles = self.max_unknown_obstacles > 0 and random.uniform(0, 1) < self.p_use_unknown_obstacles
                if self.use_known_obstacles or self.use_unknown_obstacles:
                    self._randomize_circular_obstacles()

        # Create an evaluation environment: eval maps are cycled by episode
        if self.eval:
            self.filename = self.eval_maps[(self.current_episode - 1) % len(self.eval_maps)]
            self._load_map(self.filename)

        # All unknown environment: move every obstacle (plus a one-pixel
        # border) into the unknown map and clear the known map. This also
        # overrides the configured padding values.
        if self.all_unknown:
            self.unknown_obstacle_map = np.maximum(self.known_obstacle_map, self.unknown_obstacle_map)
            self.unknown_obstacle_map[0, :] = 1; self.unknown_obstacle_map[-1, :] = 1
            self.unknown_obstacle_map[:, 0] = 1; self.unknown_obstacle_map[:, -1] = 1
            self.known_obstacle_map = np.zeros_like(self.known_obstacle_map)
            self.coverage_pad_value = 0
            self.obstacle_pad_value = 0

        # Initial pose. The heading is drawn uniformly from [0, 2*pi). The
        # bounds must be explicit: the first positional argument of
        # np.random.uniform is `low` (high defaults to 1), so
        # uniform(2 * pi) would only ever sample between 1 and 2*pi.
        self.heading = np.random.uniform(0, 2 * math.pi)
        self.position_m = np.random.uniform(size=2) * self.size_m
        # make sure we don't spawn in the wall or in an obstacle
        tries = 0
        while self._is_wall_collision(self.position_m, self.mower_radius) or \
              self._is_obstacle_collision(self.position_m):
            self.position_m = np.random.uniform(size=2) * self.size_m
            tries += 1
            assert tries < 100, 'Could not initialize random position in 100 tries, try reducing obstacle amount'
        self.position_p = self.position_m * self.pixels_per_meter
        if self.position_noise > 0:
            self.noisy_position_m = np.random.normal(
                loc=self.position_m, scale=self.position_noise)
        else:
            self.noisy_position_m = self.position_m
        self.noisy_position_p = self.noisy_position_m * self.pixels_per_meter

        # Initial lidar detections
        self.lidar_pts, pts_info, self.lidar_pts_max = self._compute_lidar_pts(self.position_m, self.heading)
        lidar_obs = self._compute_lidar_observation(self.lidar_pts, self.position_m)

        # Update known obstacle map based on detected unknown obstacles
        # (pts_info == 2 marks a ray that hit an unknown obstacle)
        for n in range(self.lidar_rays):
            if pts_info[n] == 2:
                i = self.lidar_pts[n, 0]
                j = self.lidar_pts[n, 1]
                self.known_obstacle_map[i, j] = 1

        # Initialize coverage/overlap/frontier maps. When mowing, coverage is
        # the mower disc. When exploring, it is the polygon spanned by the
        # lidar points and the robot position, minus obstacle pixels.
        # Points are flipped from (row, col) to OpenCV's (x, y) order.
        self.coverage_map = np.zeros((self.size_p, self.size_p), dtype=float)
        if not self.exploration:
            cv2.circle(
                self.coverage_map,
                center=np.flip(self.position_p).astype(np.int32),
                radius=int(self.mower_radius * self.pixels_per_meter),
                color=1,
                thickness=cv2.FILLED,
                lineType=self.line_type)
        else:
            cv2.fillPoly(
                self.coverage_map,
                [np.fliplr(np.concatenate((self.lidar_pts, [self.position_p.astype(np.int32)])))],
                color=1)
            idxs = np.logical_or(self.known_obstacle_map, self.unknown_obstacle_map)
            self.coverage_map[idxs] = 0
        self.overlap_map = self.coverage_map.copy()
        self.frontier_map = self._compute_frontier_map(self.coverage_map, self.known_obstacle_map)

        # Compute coverage metrics. Pixels near obstacles and the map border
        # (after dilation) are excluded from both free space and coverage.
        all_obstacles = np.maximum(
            self.known_obstacle_map,
            self.unknown_obstacle_map)
        if self.obstacle_dilation > 1:
            kernel = np.ones((self.obstacle_dilation,)*2, dtype=float)
            all_obstacles[0, :] = 1; all_obstacles[-1, :] = 1
            all_obstacles[:, 0] = 1; all_obstacles[:, -1] = 1
            all_obstacles = cv2.dilate(all_obstacles, kernel, iterations=1)
        self.total_obstacle_pixels_dilated = all_obstacles.sum()
        self.free_space_dilated = self.size_p ** 2 - self.total_obstacle_pixels_dilated
        coverage = self.coverage_map.copy()
        coverage[all_obstacles > 0] = 0
        overlap = self.overlap_map.copy()
        overlap[all_obstacles > 0] = 0
        self.coverage_in_pixels = coverage.sum()
        self.coverage_in_m2 = self.coverage_in_pixels / (self.pixels_per_meter ** 2)
        self.coverage_in_percent = self.coverage_in_pixels / self.free_space_dilated
        self.overlap_in_pixels = (overlap - coverage).sum()
        self.overlap_in_m2 = self.overlap_in_pixels / (self.pixels_per_meter ** 2)
        if self.coverage_in_pixels != 0:
            self.overlap_in_percent = self.overlap_in_pixels / self.coverage_in_pixels
        else:
            self.overlap_in_percent = 0

        # Initial observation: the same frame is repeated `stacks` times
        self.observation = {}
        self.observation['lidar'] = np.tile(lidar_obs, (self.stacks, 1))
        if self.overlap_observation:
            # squash overlap counts into [0, 1)
            cov_map = np.tanh(0.2 * self.overlap_map)
        else:
            cov_map = self.coverage_map
        self.observation['coverage'] = np.tile(self._get_multi_scale_map(cov_map, pad_value=self.coverage_pad_value), (self.stacks, 1, 1, 1))
        self.observation['obstacles'] = np.tile(self._get_multi_scale_map(self.known_obstacle_map, pad_value=self.obstacle_pad_value), (self.stacks, 1, 1, 1))
        if self.frontier_observation:
            self.observation['frontier'] = np.tile(self._get_multi_scale_map(self.frontier_map, pad_value=0, ceil=True), (self.stacks, 1, 1, 1))

        # Initial global total variation, computed on the neighborhood around
        # the start position (mower radius, or lidar range when exploring)
        radius_m = self.mower_radius
        if self.exploration:
            radius_m = max(self.mower_radius, self.lidar_range)
        local_coverage = self._get_local_neighborhood(self.coverage_map, self.position_m, self.position_m, radius_m)
        local_obstacles = None
        if self.use_known_obstacles_in_tv and self.use_unknown_obstacles_in_tv:
            local_known_obstacles = self._get_local_neighborhood(self.known_obstacle_map, self.position_m, self.position_m, radius_m)
            local_unknown_obstacles = self._get_local_neighborhood(self.unknown_obstacle_map, self.position_m, self.position_m, radius_m)
            local_obstacles = np.maximum(local_known_obstacles, local_unknown_obstacles)
        elif self.use_known_obstacles_in_tv:
            # copy so that setting the border below does not modify the map
            local_obstacles = self._get_local_neighborhood(self.known_obstacle_map, self.position_m, self.position_m, radius_m).copy()
        elif self.use_unknown_obstacles_in_tv:
            local_obstacles = self._get_local_neighborhood(self.unknown_obstacle_map, self.position_m, self.position_m, radius_m).copy()
        if local_obstacles is not None and self.obstacle_dilation > 1:
            kernel = np.ones((self.obstacle_dilation,)*2, dtype=float)
            local_obstacles[0, :] = 1; local_obstacles[-1, :] = 1
            local_obstacles[:, 0] = 1; local_obstacles[:, -1] = 1
            local_obstacles = cv2.dilate(local_obstacles, kernel, iterations=1)
        self.global_total_variation = utils.total_variation(local_coverage, local_obstacles)

        # Keep track of the path taken (only updated/used when rendering)
        self.path = [self.position_m]

    def step(self, action):
        self.t0 = round(time.time() * 1000)
        dt_other = str(self.t0 - self.t1)

        # Compute linear and angular velocities in this step (not per second)
        # (Update machine model here)
        if self.constant_lin_vel:
            throttle = 1
            steering = action[0]
        else:
            throttle = action[0]
            steering = action[1]
        lin_vel = throttle
        if self.steering_limits_lin_vel:
            lin_vel *= (1 - abs(steering))
        lin_vel *= self.max_lin_vel * self.step_size
        if lin_vel < 0:
            lin_vel = lin_vel / 2
        ang_vel = steering
        ang_vel *= self.max_ang_vel * self.step_size

        # Compute potential new pose
        new_heading = self.heading + ang_vel
        new_heading = new_heading % (2 * math.pi)
        pos_vec = lin_vel * np.array([math.cos(new_heading), math.sin(new_heading)])
        new_position_m = self.position_m + pos_vec

        # Check for collisions
        collided = False
        done = False
        reward_coll = 0
        if self._is_wall_collision(new_position_m, self.mower_radius):
            # out of bounds
            collided = True
            reward_coll = self.wall_collision_reward
            self.num_collisions += 1
            self.stuck_steps += 1
            if self.collision_ends_episode:
                done = True
        elif self._is_obstacle_collision(new_position_m):
            # collided with obstacle
            collided = True
            reward_coll = self.obstacle_collision_reward
            self.num_collisions += 1
            self.stuck_steps += 1
            if self.collision_ends_episode:
                done = True
        else:
            self.stuck_steps = 0

        # Flip the agent if stuck
        if self.flip_when_stuck and self.stuck_steps >= self.max_stuck_steps:
            self.stuck_steps = 0
            new_heading = (new_heading + math.pi) % (2 * math.pi)

        # Update pose
        old_heading = self.heading
        old_position_m = self.position_m.copy()
        self.heading = new_heading
        if not collided:
            self.position_m = new_position_m
        else:
            new_position_m = self.position_m
        self.position_p = self.position_m * self.pixels_per_meter
        old_position_p = old_position_m * self.pixels_per_meter
        new_position_p = self.position_p
        if self.position_noise > 0:
            self.noisy_position_m = np.random.normal(
                loc=self.position_m, scale=self.position_noise)
        else:
            self.noisy_position_m = self.position_m
        self.noisy_position_p = self.noisy_position_m * self.pixels_per_meter

        # Save local maps of the previous time step
        radius_m = self.mower_radius
        if self.exploration:
            radius_m = max(self.mower_radius, self.lidar_range)
        i1, i2, j1, j2 = self._get_local_neighborhood_indices(old_position_m, new_position_m, radius_m)
        local_overlap_old = self.overlap_map[i1:i2, j1:j2].copy()
        local_coverage_old = self.coverage_map[i1:i2, j1:j2].copy()
        local_known_obstacles_old = self.known_obstacle_map[i1:i2, j1:j2].copy()
        local_unknown_obstacles_old = self.unknown_obstacle_map[i1:i2, j1:j2].copy()

        # Compute lidar point cloud and observation
        old_lidar_pts = self.lidar_pts.copy()
        old_lidar_pts_max = self.lidar_pts_max.copy()
        self.lidar_pts, pts_info, self.lidar_pts_max = self._compute_lidar_pts(new_position_m, new_heading)
        lidar_obs = self._compute_lidar_observation(self.lidar_pts, new_position_m)

        # Update known obstacle map based on detected unknown obstacles
        for n in range(self.lidar_rays):
            if pts_info[n] == 2:
                i = self.lidar_pts[n, 0]
                j = self.lidar_pts[n, 1]
                self.known_obstacle_map[i, j] = 1

        # Update the coverage/overlap maps
        local_coverage_diff = np.zeros_like(local_coverage_old)
        # compute coverage diff based on mower position
        if not self.exploration:
            orth_vec_new = np.array([-math.sin(new_heading), math.cos(new_heading)])
            orth_vec_old = np.array([-math.sin(old_heading), math.cos(old_heading)])
            head_vec_old = np.array([math.cos(old_heading), math.sin(old_heading)])
            orth_vec_new *= self.mower_radius * self.pixels_per_meter
            orth_vec_old *= self.mower_radius * self.pixels_per_meter
            head_vec_old *= 1 + self.max_lin_vel * self.step_size * self.pixels_per_meter
            cv2.circle(
                local_coverage_diff,
                center=np.flip(new_position_p - [i1, j1]).astype(np.int32),
                radius=int(self.mower_radius * self.pixels_per_meter),
                color=1,
                thickness=cv2.FILLED,
                lineType=self.line_type)
            cv2.fillConvexPoly(
                local_coverage_diff,
                points=np.array(
                    [np.flip(old_position_p - [i1, j1] + orth_vec_new),
                     np.flip(new_position_p - [i1, j1] + orth_vec_new),
                     np.flip(new_position_p - [i1, j1] - orth_vec_new),
                     np.flip(old_position_p - [i1, j1] - orth_vec_new)]).astype(np.int32),
                color=1,
                lineType=self.line_type)
            cv2.circle(
                local_coverage_diff,
                center=np.flip(old_position_p - [i1, j1]).astype(np.int32),
                radius=int(self.mower_radius * self.pixels_per_meter),
                color=0,
                thickness=cv2.FILLED,
                lineType=self.line_type)
            cv2.fillConvexPoly(
                local_coverage_diff,
                points=np.array(
                    [np.flip(old_position_p - [i1, j1] + orth_vec_old),
                     np.flip(old_position_p - [i1, j1] + orth_vec_old - head_vec_old),
                     np.flip(old_position_p - [i1, j1] - orth_vec_old - head_vec_old),
                     np.flip(old_position_p - [i1, j1] - orth_vec_old)]).astype(np.int32),
                color=0,
                lineType=self.line_type)
        # compute coverage diff based on explored area
        if self.exploration:
            cv2.fillPoly(
                local_coverage_diff,
                [np.fliplr(np.concatenate((self.lidar_pts, [self.position_p.astype(np.int32)])) - [i1, j1])],
                color=1)
            cv2.fillPoly(
                local_coverage_diff,
                [np.fliplr(np.concatenate((old_lidar_pts, [old_position_p.astype(np.int32)])) - [i1, j1])],
                color=0)
            overlappable_pixels = np.ones_like(local_coverage_diff)
            cv2.fillPoly(
                overlappable_pixels,
                [np.fliplr(old_lidar_pts_max - [i1, j1])],
                color=0)
            overlappable_pixels[local_coverage_old == 0] = 1
            local_coverage_diff *= overlappable_pixels
        # update maps
        idxs = np.logical_or(
            self.known_obstacle_map[i1:i2, j1:j2],
            self.unknown_obstacle_map[i1:i2, j1:j2])
        local_coverage_diff[idxs] = 0
        self.overlap_map[i1:i2, j1:j2] += local_coverage_diff
        self.coverage_map[i1:i2, j1:j2] = self.overlap_map[i1:i2, j1:j2].clip(max=1)

        # Update the frontier map
        d = max(1, 2 + self.obstacle_dilation // 2)
        ii1 = max(0, i1 - d); ii2 = min(self.size_p, i2 + d)
        jj1 = max(0, j1 - d); jj2 = min(self.size_p, j2 + d)
        local_frontier = self._compute_frontier_map(
            self.coverage_map[ii1:ii2, jj1:jj2], self.known_obstacle_map[ii1:ii2, jj1:jj2])
        self.frontier_map[i1:i2, j1:j2] = local_frontier[i1-ii1:i2-ii1, j1-jj1:j2-jj1]

        # Compute newly visited positions in this time step
        local_coverage_new = self.coverage_map[i1:i2, j1:j2]
        newly_visited = local_coverage_new > local_coverage_old
        newly_visited_sum = newly_visited.sum()

        # Update current coverage metrics
        all_obstacles = np.maximum(
            self.known_obstacle_map[i1:i2, j1:j2],
            self.unknown_obstacle_map[i1:i2, j1:j2])
        if self.obstacle_dilation > 1:
            kernel = np.ones((self.obstacle_dilation,)*2, dtype=float)
            all_obstacles[0, :] = 1; all_obstacles[-1, :] = 1
            all_obstacles[:, 0] = 1; all_obstacles[:, -1] = 1
            all_obstacles = cv2.dilate(all_obstacles, kernel, iterations=1)
        newly_visited_dilated = newly_visited.copy()
        newly_visited_dilated[all_obstacles > 0] = 0
        newly_visited_dilated_sum = newly_visited_dilated.sum()
        local_coverage_diff_dilated = local_coverage_diff.copy()
        local_coverage_diff_dilated[all_obstacles > 0] = 0
        local_coverage_diff_dilated_sum = local_coverage_diff_dilated.sum()
        self.coverage_in_pixels += newly_visited_dilated_sum
        self.coverage_in_m2 = self.coverage_in_pixels / (self.pixels_per_meter ** 2)
        self.coverage_in_percent = self.coverage_in_pixels / self.free_space_dilated
        self.overlap_in_pixels += local_coverage_diff_dilated_sum - newly_visited_dilated_sum
        self.overlap_in_m2 = self.overlap_in_pixels / (self.pixels_per_meter ** 2)
        if self.coverage_in_pixels != 0:
            self.overlap_in_percent = self.overlap_in_pixels / self.coverage_in_pixels
        else:
            self.overlap_in_percent = 0

        # Update current global total variation metric
        local_obstacles_old = None
        local_obstacles_new = None
        if self.use_known_obstacles_in_tv and self.use_unknown_obstacles_in_tv:
            local_obstacles_old = np.maximum(local_known_obstacles_old, local_unknown_obstacles_old)
            local_obstacles_new = np.maximum(self.known_obstacle_map[i1:i2, j1:j2], self.unknown_obstacle_map[i1:i2, j1:j2])
        elif self.use_known_obstacles_in_tv:
            local_obstacles_old = local_known_obstacles_old
            local_obstacles_new = self.known_obstacle_map[i1:i2, j1:j2].copy()
        elif self.use_unknown_obstacles_in_tv:
            local_obstacles_old = local_unknown_obstacles_old
            local_obstacles_new = self.unknown_obstacle_map[i1:i2, j1:j2].copy()
        if local_obstacles_new is not None and self.obstacle_dilation > 1:
            kernel = np.ones((self.obstacle_dilation,)*2, dtype=float)
            local_obstacles_old[0, :] = 1; local_obstacles_old[-1, :] = 1
            local_obstacles_old[:, 0] = 1; local_obstacles_old[:, -1] = 1
            local_obstacles_new[0, :] = 1; local_obstacles_new[-1, :] = 1
            local_obstacles_new[:, 0] = 1; local_obstacles_new[:, -1] = 1
            local_obstacles_old = cv2.dilate(local_obstacles_old, kernel, iterations=1)
            local_obstacles_new = cv2.dilate(local_obstacles_new, kernel, iterations=1)
        total_variation_old = utils.total_variation(local_coverage_old, local_obstacles_old)
        total_variation_new = utils.total_variation(local_coverage_new, local_obstacles_new)
        total_variation_diff = total_variation_new - total_variation_old
        self.global_total_variation += total_variation_diff

        # Reset non-new steps if we visited any new position
        if newly_visited_sum > 0:
            self.non_new_steps = 0

        # Compute coverage-based rewards
        reward_area = 0
        reward_ovrlp = 0
        if not collided:

            # Normalization constant: maximum possible area covered per step
            # The maximum area corresponds to a rectangle based on the radius and maximum velocity
            width_m = 2 * self.mower_radius
            if self.exploration:
                width_m = 2 * self.lidar_range
            length_m = self.max_lin_vel * self.step_size
            max_newly_visited_sum = width_m * length_m * self.pixels_per_meter ** 2

            # Reward based on newly covered area
            reward_area = self.newly_visited_reward_scale * newly_visited_dilated_sum
            reward_area = reward_area / max_newly_visited_sum
            reward_area = min(reward_area, self.newly_visited_reward_max)

            # Reward based on overlap
            if self.overlap_reward_always or reward_area == 0:
                reward_ovrlp = (local_coverage_diff * local_overlap_old).sum()
                reward_ovrlp *= self.overlap_reward_scale / max_newly_visited_sum
                reward_ovrlp = min(reward_ovrlp, self.overlap_reward_max)
                reward_ovrlp = -reward_ovrlp

        # Reward based on total variation
        reward_tv = 0
        if not collided:

            # Local (incremental) total variation reward
            reward_itv = -total_variation_diff
            # normalize by the speed
            # to make the TV independent of pixel resolution and max speed etc.
            # (it is already independent of the radius)
            reward_itv *= self.meters_per_pixel / self.step_size / self.max_lin_vel
            # normalize such that TV of going straight forward at max speed = 1
            # the correct constant is probably 2 (2 sides of the mower = TV diff)
            # but 2.5 seemed more accurate due to the grid discretization
            reward_itv /= 2.5
            reward_itv *= self.local_tv_reward_scale
            reward_itv = np.sign(reward_itv) * min(abs(reward_itv), self.local_tv_reward_max)

            # Global total variation reward
            reward_gtv = -self.global_total_variation
            # normalize by the area
            reward_gtv /= math.sqrt(self.coverage_in_pixels)
            # normalize so that TV of a disk = 1
            # the correct constant is probably 2*sqrt(pi) (circumference over sqrt(area))
            # but 4 is more accurate due to the grid discretization
            reward_gtv /= 4
            reward_gtv *= self.global_tv_reward_scale
            reward_gtv = np.sign(reward_gtv) * min(abs(reward_gtv), self.global_tv_reward_max)

            # Final TV reward
            reward_tv = reward_itv + reward_gtv

        # Check if we reached the goal coverage
        reward_goal = 0
        if not done and self.coverage_in_percent >= self.goal_coverage:
            reward_goal = self.goal_coverage_reward
            done = True

            # Set the current map as completed
            if not self.eval:
                time_goal_reached = not self.use_goal_time_in_levels or self.elapsed_steps <= self.goal_steps
                if time_goal_reached:
                    if self.current_map is not None:
                        self.completed_maps[self.current_map] = True
                    else:
                        if self.use_floor_plans:
                            self.completed_floor_plan = True
                        if self.use_known_obstacles or self.use_unknown_obstacles:
                            self.completed_obstacles = True

                if np.all(self.completed_maps) and self.completed_floor_plan and self.completed_obstacles:
                    self.level += 1
                    self.next_train_map = 0
                    self._set_level(self.level)

        # Update the observation
        n = self.elapsed_steps % self.stacks
        self.observation['lidar'][n] = lidar_obs
        if self.overlap_observation:
            cov_map = np.tanh(0.2 * self.overlap_map)
        else:
            cov_map = self.coverage_map
        self.observation['coverage'][n] = self._get_multi_scale_map(cov_map, pad_value=self.coverage_pad_value)
        self.observation['obstacles'][n] = self._get_multi_scale_map(self.known_obstacle_map, pad_value=self.obstacle_pad_value)
        if self.frontier_observation:
            self.observation['frontier'][n] = self._get_multi_scale_map(self.frontier_map, pad_value=0, ceil=True)

        # Truncate episode if needed
        self.elapsed_steps += 1
        self.non_new_steps += 1
        self.elapsed_time = self.elapsed_steps * self.step_size
        info = {}
        reward_trunc = 0
        if self.max_episode_steps is not None:
            if not done and self.elapsed_steps >= self.max_episode_steps:
                reward_trunc = -self.truncation_reward_scale * (1 - self.coverage_in_percent)
                done = True
                info['TimeLimit.truncated'] = not done
        if self.max_non_new_steps is not None:
            if not done and self.non_new_steps >= self.max_non_new_steps:
                reward_trunc = -self.truncation_reward_scale * (1 - self.coverage_in_percent)
                done = True
                info['TimeLimit.non_new'] = True
        info['level'] = self.level

        # Compute observation and reward
        obs = self._get_stacked_observation()
        reward = reward_area + reward_ovrlp + reward_tv + reward_coll + reward_goal + reward_trunc + self.constant_reward

        # Print stuff
        if self.verbose:
            self.t1 = round(time.time() * 1000)
            dt_step = str(self.t1 - self.t0)
            print('Step:', str(self.elapsed_steps) + ',',
                  'Reward:' + str(round(reward, 2)).rjust(5),
                  '(area:' + str(round(reward_area, 2)).rjust(4) + ',',
                  'ovrlp:' + str(round(reward_ovrlp, 2)).rjust(5) + ',',
                  'TV:' + str(round(reward_tv, 2)).rjust(5) + ',',
                  'coll:' + str(reward_coll) + ',',
                  'const:' + str(self.constant_reward) + ',',
                  'goal:' + str(reward_goal) + ',',
                  'trunc:' + str(reward_trunc) + '),',
                  'Time:', dt_step + '/' + dt_other, 'ms step/other')
        return obs, reward, done, info

    def render(self, mode='human'):
        """
        Renders the environment with matplotlib.

        In 'human' mode a mosaic is drawn with the full map (A), the latest
        coverage observation (B), the latest obstacle observation (C) and, if
        enabled, the latest frontier observation (D); the figure is shown
        without blocking.
        In 'rgb_array' mode only the full map is drawn and returned as an
        (H, W, 3) uint8 image with a small text box of episode statistics in
        the top-left corner.

        The figure and axes are created on the first call and cached in
        ``self.fig`` / ``self.axes``.
        Side effect: the current position is appended to ``self.path``, so the
        drawn trajectory consists of the positions at which render was called.
        """
        assert mode in ['human', 'rgb_array']

        # Construct pyplot figure (only once, then reused)
        if self.axes is None:
            if mode == 'human':
                self.fig, self.axes = plt.subplot_mosaic('AABC;AAD.', constrained_layout=True)
                self.fig.set_size_inches(10, 5)
            else:
                self.fig, self.axes = plt.subplot_mosaic('A', constrained_layout=True)
                self.fig.set_size_inches(8, 8)
            for ax in self.axes:
                self.axes[ax].get_xaxis().set_visible(False)
                self.axes[ax].get_yaxis().set_visible(False)
        for ax in self.axes:
            self.axes[ax].clear()

        # Draw environment image, with coverage/obstacles/frontier/agent:
        # white where overlap_map is 0, increasingly dark green for larger
        # overlap values, gray = unknown obstacle, black = known obstacle,
        # magenta = frontier, red disk = mower
        img = np.ones((self.size_p, self.size_p, 3), dtype=float)
        overlap_tanh = np.tanh(0.5 * self.overlap_map)
        img[:, :, 0] = np.clip(1 - 1.5*overlap_tanh, 0, 1)
        img[:, :, 1] = np.clip(2 - 1.5*overlap_tanh, 0, 1)
        img[:, :, 2] = np.clip(1 - 1.5*overlap_tanh, 0, 1)
        img[self.unknown_obstacle_map > 0] = 0.5
        img[self.known_obstacle_map > 0] = 0
        img[self.frontier_map > 0, 0] = 1
        img[self.frontier_map > 0, 1] = 0
        img[self.frontier_map > 0, 2] = 1
        cv2.circle(
            img,
            center=np.flip(self.position_p).astype(np.int32),
            radius=int(self.mower_radius * self.pixels_per_meter),
            color=[1, 0, 0],
            thickness=cv2.FILLED,
            lineType=self.line_type)
        # Maps are indexed [x, y]; transpose and flip so that y points up
        self.axes['A'].imshow(np.flip(img.transpose(1, 0, 2), axis=0))

        # Draw lidar detections
        for n in range(self.lidar_rays):
            self.axes['A'].plot(
                [self.pixels_per_meter * self.position_m[0] - 0.5, self.lidar_pts[n, 0]],
                [self.size_p - self.pixels_per_meter * self.position_m[1] - 0.5,
                self.size_p - 1 - self.lidar_pts[n, 1]], '-b', linewidth=0.5)

        # Draw an arrow for the heading
        self.axes['A'].arrow(
            self.pixels_per_meter * self.position_m[0] - 0.5,
            self.pixels_per_meter * (self.size_m - self.position_m[1]) - 0.5,
            self.pixels_per_meter * 2 * self.mower_radius * math.cos(self.heading),
            self.pixels_per_meter * 2 * self.mower_radius * -math.sin(self.heading),
            head_width = 1, head_length = 1, zorder=10)

        # Draw the path. Store a copy of the position so that any later
        # in-place update of self.position_m cannot rewrite the recorded path.
        self.path.append(np.array(self.position_m))
        path = self.pixels_per_meter * np.array(self.path)
        self.axes['A'].plot(path[:, 0] - 0.5, self.size_p - path[:, 1] - 0.5, '-', color='yellow')
        self.axes['A'].set_xlim([-0.5, self.size_p - 0.5])
        self.axes['A'].set_ylim([self.size_p - 0.5, -0.5])

        # Return image if mode is rgb_array
        if mode == 'rgb_array':
            coverage = round(100 * self.coverage_in_percent)
            overlap = round(100 * self.overlap_in_percent)
            self.fig.canvas.draw()
            # buffer_rgba() replaces the deprecated (and in newer matplotlib
            # removed) tostring_rgb(); the returned array already has the
            # correct (height, width) shape, also on HiDPI canvases.
            rgb_img = np.asarray(self.fig.canvas.buffer_rgba())[:, :, :3].copy()
            # White box with a black border in the top-left corner for the stats text
            rgb_img[:81, :100] = 255
            rgb_img[:81, 100] = 0; rgb_img[:81, 0] = 0
            rgb_img[81, :100] = 0; rgb_img[0, :100] = 0
            rgb_img = Image.fromarray(rgb_img)
            draw = ImageDraw.Draw(rgb_img)
            font = ImageFont.truetype('rlm/fonts/FreeMono.ttf', 16)
            color = (0, 0, 0)
            draw.text((2, 1), f'{self.elapsed_steps} steps', color, font=font)
            draw.text((2, 1+16), f'{round(self.elapsed_time, 1)} s', color, font=font)
            draw.text((2, 1+16*2), f'{coverage}% cover', color, font=font)
            draw.text((2, 1+16*3), f'{overlap}% over', color, font=font)
            draw.text((2, 1+16*4), f'{self.num_collisions} coll', color, font=font)
            rgb_img = np.array(rgb_img)
            return rgb_img

        # Create images of observed coverage/obstacle/frontier maps
        ob = self._get_latest_observation()
        coverage_ms_img = self._get_image_from_multi_scale_map(ob['coverage'])
        obstacle_ms_img = self._get_image_from_multi_scale_map(ob['obstacles'])
        coverage_img = np.ones((self.ms_reach_p, self.ms_reach_p, 3), dtype=float)
        obstacle_img = np.ones((self.ms_reach_p, self.ms_reach_p, 3), dtype=float)
        coverage_img[:, :, 0] = np.clip(1 - 1.5*coverage_ms_img, 0, 1)
        coverage_img[:, :, 1] = np.clip(2 - 1.5*coverage_ms_img, 0, 1)
        coverage_img[:, :, 2] = np.clip(1 - 1.5*coverage_ms_img, 0, 1)
        obstacle_img[:, :, 0] = np.clip(1 - obstacle_ms_img, 0, 1)
        obstacle_img[:, :, 1] = np.clip(1 - obstacle_ms_img, 0, 1)
        obstacle_img[:, :, 2] = np.clip(1 - obstacle_ms_img, 0, 1)
        if self.frontier_observation:
            frontier_ms_img = self._get_image_from_multi_scale_map(ob['frontier'])
            frontier_img = np.ones((self.ms_reach_p, self.ms_reach_p, 3), dtype=float)
            # Only the green channel is reduced, so frontier pixels appear magenta
            frontier_img[:, :, 1] = np.clip(1 - frontier_ms_img, 0, 1)

        # Draw observed maps
        self.axes['B'].imshow(np.flip(coverage_img.transpose(1, 0, 2), axis=0))
        self.axes['C'].imshow(np.flip(obstacle_img.transpose(1, 0, 2), axis=0))
        if self.frontier_observation:
            self.axes['D'].imshow(np.flip(frontier_img.transpose(1, 0, 2), axis=0))

        # Add some labels
        coverage = round(100 * self.coverage_in_percent)
        overlap = round(100 * self.overlap_in_percent)
        self.fig.suptitle(
            f'{self.elapsed_steps} steps' +
            f', {round(self.elapsed_time, 1)} s' +
            f', {self.num_collisions} collisions' +
            f', {coverage}% coverage' +
            f', {overlap}% overlap', fontsize=16)
        labels = ['full map', 'coverage observation', 'obstacle observation']
        letters = ['A', 'B', 'C']
        if self.frontier_observation:
            labels += ['frontier observation']
            letters += ['D']
        trans = mtransforms.ScaledTranslation(10/72, -5/72, self.fig.dpi_scale_trans)
        for label, letter in zip(labels, letters):
            self.axes[letter].text(0.0, 1.0, label, transform=self.axes[letter].transAxes + trans,
            fontsize='medium', verticalalignment='top', fontfamily='serif',
            bbox=dict(facecolor='0.7', edgecolor='none', pad=3.0))
        plt.show(block=False)
        plt.pause(0.001)

    def close(self):
        """
        Closes all matplotlib figures (not only this environment's) and drops
        the cached figure/axes, so that a subsequent ``render()`` creates a
        fresh figure instead of drawing into a closed one.
        """
        plt.close('all')
        # render() only builds a new figure when self.axes is None
        self.fig = None
        self.axes = None

    def _is_wall_collision(self, position_m, radius_m):
        return position_m[0] <= radius_m or \
               position_m[1] <= radius_m or \
               position_m[0] >= self.size_m - radius_m or \
               position_m[1] >= self.size_m - radius_m

    def _is_obstacle_collision(self, position_m):
        """
        Returns True if a filled disk of radius ``mower_radius`` centred at
        ``position_m`` overlaps any known or unknown obstacle pixel.

        Only a local crop around the position is rasterised; the disk is
        drawn with ``self.line_type``.
        """
        # TODO: account for collisions in-between old and new positions
        i1, i2, j1, j2 = self._get_local_neighborhood_indices(position_m, position_m, self.mower_radius)
        window = np.s_[i1:i2, j1:j2]
        obstacles = np.logical_or(self.known_obstacle_map[window], self.unknown_obstacle_map[window])
        footprint = np.zeros_like(self.known_obstacle_map[window])
        # Centre relative to the crop; flipped because cv2 expects (x, y) = (column, row)
        centre_p = np.flip(position_m * self.pixels_per_meter - [i1, j1]).astype(np.int32)
        cv2.circle(
            footprint,
            center=centre_p,
            radius=int(self.mower_radius * self.pixels_per_meter),
            color=1,
            thickness=cv2.FILLED,
            lineType=self.line_type)
        return np.logical_and(footprint, obstacles).any()

    def _compute_frontier_map(self, coverage_map, obstacle_map):
        """
        Computes frontier points, i.e. free (neither covered nor obstacle)
        pixels that are 8-connected neighbours of covered pixels.

        If ``obstacle_dilation > 1`` the map border is treated as an obstacle
        and obstacles are grown with a square kernel of that size first.
        The input maps are not modified. Returns a float map of 0/1 values.
        """
        covered = coverage_map.copy()
        blocked = obstacle_map.copy()
        dilation = self.obstacle_dilation
        if dilation > 1:
            # Mark the outermost rows/columns as obstacles before dilating
            blocked[[0, -1], :] = 1
            blocked[:, [0, -1]] = 1
            blocked = cv2.dilate(blocked, np.ones((dilation, dilation), dtype=float), iterations=1)
        covered[blocked > 0] = 0
        free = np.logical_not(covered + blocked)
        # Grow coverage by one pixel; where it spills into free space is the frontier
        grown = cv2.dilate(covered, np.ones((3, 3), dtype=float), iterations=1)
        return np.logical_and(grown, free).astype(float)

    def _compute_lidar_pts(self, position_m, heading):
        """
        Computes a lidar point cloud from a given pose.

        Parameters
        ----------
        position_m : position of lidar sensor [m]
        heading : heading of lidar sensor [rad]

        Returns
        -------
        lidar_pts : array of detected 2D points in absolute pixel coordinates
        pts_info : detection info
            0 = max range (no detection)
            1 = known obstacle
            2 = unknown obstacle
            3 = out of bounds
        lidar_pts_max : array of 2D points at max range for each ray
        """
        lidar_pts = np.zeros((self.lidar_rays, 2), dtype=np.int32)
        lidar_pts_max = np.zeros((self.lidar_rays, 2), dtype=np.int32)
        pts_info = np.zeros(self.lidar_rays, dtype=np.int32)
        samples = int(self.lidar_range * self.pixels_per_meter) # number of samples per ray
        position_p = position_m * self.pixels_per_meter
        # TODO: parallelize this for loop
        for n, angle in enumerate(np.linspace(-self.lidar_fov/2, self.lidar_fov/2, num=self.lidar_rays)):
            ang = heading + angle * math.pi / 180
            search_vec = np.array([math.cos(ang), math.sin(ang)])
            for s in range(samples):
                offset_p = (s + 1) * search_vec
                pos_p = position_p + offset_p
                i = int(pos_p[0])
                j = int(pos_p[1])
                if i < 0 or i >= self.size_p or j < 0 or j >= self.size_p:
                    # lidar ray reaches beyond the area
                    pts_info[n] = 3
                    break
                if self.known_obstacle_map[i, j] > 0:
                    # lidar ray hits a known obstacle
                    pts_info[n] = 1
                    break
                if self.unknown_obstacle_map[i, j] > 0:
                    # lidar ray hits an unknown obstacle
                    pts_info[n] = 2
                    break
            # Store detection point
            lidar_pts[n] = [i, j]
            # Store max range point
            offset_p = samples * search_vec
            pos_p = position_p + offset_p
            lidar_pts_max[n] = [int(pos_p[0]), int(pos_p[1])]
        # Add max range points for the full 360 fov
        angle_diff = self.lidar_fov / (self.lidar_rays - 1)
        angle = -self.lidar_fov / 2 - angle_diff
        while angle > -180:
            ang = heading + angle * math.pi / 180
            search_vec = np.array([math.cos(ang), math.sin(ang)])
            offset_p = samples * search_vec
            pos_p = position_p + offset_p
            lidar_pts_max = np.concatenate(
                ([[int(pos_p[0]), int(pos_p[1])]], lidar_pts_max), axis=0)
            angle -= angle_diff
        angle = self.lidar_fov / 2 + angle_diff
        while angle < 180:
            ang = heading + angle * math.pi / 180
            search_vec = np.array([math.cos(ang), math.sin(ang)])
            offset_p = samples * search_vec
            pos_p = position_p + offset_p
            lidar_pts_max = np.concatenate(
                (lidar_pts_max, [[int(pos_p[0]), int(pos_p[1])]]), axis=0)
            angle += angle_diff
        return lidar_pts, pts_info, lidar_pts_max

    def _compute_lidar_observation(self, lidar_pts, position_m):
        """
        Computes normalized lidar distances from a given point cloud.

        Parameters
        ----------
        lidar_pts : array of 2D points in absolute pixel coordinates
        position_m : position of lidar sensor [m]

        Returns
        -------
        lidar_obs : array of lidar distances, normalized to [0, 1]
        """
        lidar_obs = np.ones(self.lidar_rays)
        # TODO: parallelize this for loop?
        for n in range(self.lidar_rays):
            offset_m = lidar_pts[n] / self.pixels_per_meter - position_m
            dist = math.sqrt(offset_m[0] ** 2 + offset_m[1] ** 2)
            lidar_obs[n] = dist / self.lidar_range
        return lidar_obs

    def _get_transform_matrix(self, scale):
        """
        Returns the 3x3 homogeneous affine matrix that maps pixel coordinates
        of a map downsampled by ``scale`` into the egocentric observation
        window of size ``input_size``.

        Applied right to left, it translates the noisy mower position to the
        origin, rotates by ``90 - heading`` degrees (via
        ``cv2.getRotationMatrix2D``), and moves the origin to the window centre.
        The position components are swapped in the translation, presumably
        because cv2 uses (x, y) = (column, row) ordering.
        """
        to_origin = np.eye(3)
        to_origin[:2, 2] = (-self.noisy_position_p[1] / scale,
                            -self.noisy_position_p[0] / scale)
        rotate = np.eye(3)
        rotate[:2] = cv2.getRotationMatrix2D(
            center=(0, 0),
            angle=90 - self.heading * 180 / math.pi,
            scale=1)
        to_centre = np.eye(3)
        to_centre[:2, 2] = self.input_size / 2
        return to_centre @ rotate @ to_origin

    def _get_relative_map(self, map, pad_value, scale=1, ceil=False, floor=False):
        """
        Returns an egocentric ``input_size`` x ``input_size`` view of ``map``.

        The map is first downsampled by ``scale`` (capped at ``self.size_p``)
        with area interpolation, then warped with ``_get_transform_matrix`` so
        that it is centred on the noisy mower position and rotated by the
        heading. Pixels outside the map are filled with ``pad_value``.
        Optionally the result is rounded up (``ceil``) or down (``floor``).
        """
        assert not (ceil and floor), "_get_relative_map cannot do both ceil and floor"
        # cv2.warpAffine scales poorly with the input size, so downsample to
        # the target resolution first and warp afterwards. Downsampling via
        # the scale in the warp matrix does not work well either, because
        # warpAffine ignores cv2.INTER_AREA, see:
        # https://stackoverflow.com/questions/57477478/opencv-warpaffine-ignores-flags-cv2-inter-area
        effective_scale = min(scale, self.size_p)
        warp = self._get_transform_matrix(effective_scale)
        side = int(0.5 + self.size_p / effective_scale)
        downsampled = cv2.resize(map, (side, side), interpolation=cv2.INTER_AREA)
        result = cv2.warpAffine(
            downsampled,
            M=warp[:2],
            dsize=(self.input_size, self.input_size),
            borderValue=pad_value,
            flags=cv2.INTER_AREA)
        if ceil:
            result = np.ceil(result)
        if floor:
            result = np.floor(result)
        return result

    def _get_multi_scale_map(self, map, pad_value, ceil=False, floor=False):
        """
        Stacks egocentric views of ``map`` at scales ``scale_factor ** k`` for
        k = 0 .. num_maps - 1 into a float array of shape
        (num_maps, input_size, input_size); index 0 is the finest scale.
        """
        stacked = np.zeros((self.num_maps, self.input_size, self.input_size))
        for level in range(self.num_maps):
            stacked[level] = self._get_relative_map(
                map, pad_value, scale=self.scale_factor ** level, ceil=ceil, floor=floor)
        return stacked

    def _get_image_from_multi_scale_map(self, ms_map):
        """
        Composes a multi-scale map into one (ms_reach_p, ms_reach_p) image.

        The last entry of ``ms_map`` (largest scale) is resized to fill the
        whole image, and each finer scale is pasted on top, centred, with a
        footprint ``scale_factor`` times smaller than the previous one.
        Nearest-neighbour resizing keeps the original pixel values.
        """
        img = np.zeros((self.ms_reach_p, self.ms_reach_p), dtype=float)
        for n in range(self.num_maps):
            size = self.ms_reach_p // (self.scale_factor ** n)
            scaled = cv2.resize(ms_map[self.num_maps-n-1], (size,)*2, interpolation=cv2.INTER_NEAREST)
            i1 = self.ms_reach_p // 2 - size // 2
            # i1 + size (rather than centre + size // 2) keeps the slice exactly
            # as long as the resized map, also when size is odd
            i2 = i1 + size
            img[i1:i2, i1:i2] = scaled
        return img

    def _get_local_neighborhood_indices(self, pos1_m, pos2_m, radius_m):
        """
        Returns the local neighborhood indices to be used for cropping.
        TODO: determine additional pixels based on dilation size etc.
        """
        i1 = min(pos1_m[0], pos2_m[0]) - radius_m
        i2 = max(pos1_m[0], pos2_m[0]) + radius_m
        j1 = min(pos1_m[1], pos2_m[1]) - radius_m
        j2 = max(pos1_m[1], pos2_m[1]) + radius_m
        i1 = max(0, min(self.size_p, int(i1 * self.pixels_per_meter - 10)))
        i2 = max(0, min(self.size_p, int(i2 * self.pixels_per_meter + 10)))
        j1 = max(0, min(self.size_p, int(j1 * self.pixels_per_meter - 10)))
        j2 = max(0, min(self.size_p, int(j2 * self.pixels_per_meter + 10)))
        return i1, i2, j1, j2

    def _get_local_neighborhood(self, map, pos1_m, pos2_m, radius_m):
        """
        Returns the local neighborhood of a map as a crop.
        """
        i1, i2, j1, j2 = self._get_local_neighborhood_indices(pos1_m, pos2_m, radius_m)
        return map[i1:i2, j1:j2]

    def _get_stacked_observation(self):
        """
        Returns observations of the latest consecutive time steps, oldest first
        """
        ob = self.observation
        n = self.elapsed_steps % self.stacks
        observation = \
            {'coverage': np.concatenate((ob['coverage'][n:], ob['coverage'][:n]), axis=0),
             'obstacles': np.concatenate((ob['obstacles'][n:], ob['obstacles'][:n]), axis=0),
             'lidar': np.concatenate((ob['lidar'][n:], ob['lidar'][:n]), axis=0)}
        if self.frontier_observation:
            observation['frontier'] = np.concatenate((ob['frontier'][n:], ob['frontier'][:n]), axis=0)
        return observation

    def _get_latest_observation(self):
        ob = self.observation
        n = (self.elapsed_steps - 1) % self.stacks
        observation = \
            {'coverage': ob['coverage'][n],
             'obstacles': ob['obstacles'][n],
             'lidar': ob['lidar'][n]}
        if self.frontier_observation:
            observation['frontier'] = ob['frontier'][n]
        return observation

    def _load_map(self, filename):
        """
        Loads a square grayscale map image and resets the map state from it.

        Pixel values: 0 = known obstacle, 128 = unknown obstacle, 255 = free.
        The image is transposed and flipped left-right to obtain the internal
        map indexing (``render`` applies the inverse transform for display).
        Sets ``size_p``, ``size_m``, ``known_obstacle_map`` and
        ``unknown_obstacle_map``.

        Raises FileNotFoundError if the image cannot be read.
        """
        img = cv2.imread(filename, flags=cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising
            raise FileNotFoundError(f'Could not read map image: {filename}')
        img = np.fliplr(img.transpose(1, 0))
        assert len(img.shape) == 2
        assert img.shape[0] == img.shape[1]
        # Vectorized check instead of building a Python set of all pixels
        assert np.isin(img, [0, 128, 255]).all(), 'Invalid map image'
        self.size_p = img.shape[0]
        self.size_m = self.size_p / self.pixels_per_meter
        self.known_obstacle_map = np.zeros((self.size_p, self.size_p), dtype=float)
        self.unknown_obstacle_map = np.zeros((self.size_p, self.size_p), dtype=float)
        self.known_obstacle_map[img == 0] = 1
        self.unknown_obstacle_map[img == 128] = 1

    def _randomize_floor_plan(self):
        """
        Draws a random grid of walls with gaps into ``self.known_obstacle_map``.

        The map is split into ``num_walls + 1`` equally sized rooms along each
        index axis. At every room boundary, a wall spanning ``[i1:i2, :]`` and
        a wall spanning ``[:, i1:i2]`` are each placed with probability 0.9.
        Gaps are then opened in them, one per room segment. For each boundary,
        exactly one segment (the "stop", chosen via ``p_stop``) gets a gap in
        only one of the two walls, selected by ``vertical_stop``.
        Nothing is drawn unless the map is larger than two minimum-sized rooms.
        All sizes derive from ``mower_radius`` and ``pixels_per_meter``. Uses
        the global ``random`` module state.
        """
        min_room_size_p = int(10 * self.mower_radius * self.pixels_per_meter)
        max_room_size_p = int(32 * self.mower_radius * self.pixels_per_meter)
        min_wall_thickness_p = 2
        max_wall_thickness_p = int(2 * self.mower_radius * self.pixels_per_meter)
        min_gap = int(4 * self.mower_radius * self.pixels_per_meter)
        max_gap = int(8 * self.mower_radius * self.pixels_per_meter)
        if self.size_p > 2 * min_room_size_p:
            room_size_p = random.randint(min_room_size_p, max_room_size_p)
            num_walls = max(1, int(self.size_p / room_size_p) - 1)
            # Re-derive the room size so the rooms exactly tile the map
            room_size_p = int(self.size_p / (num_walls + 1))
            wall_thickness_p = random.randint(min_wall_thickness_p, max_wall_thickness_p)
            # Decides which of the two walls stays closed in the "stop" segment
            vertical_stop = random.uniform(0, 1) < 0.5
            for n in range(num_walls):
                # NOTE(review): the wall extends wall_thickness_p // 2 before the
                # boundary but wall_thickness_p after it (~1.5x thick, off-centre);
                # confirm this asymmetry is intended
                i1 = room_size_p * (n + 1) - wall_thickness_p // 2
                i2 = room_size_p * (n + 1) + wall_thickness_p
                if random.uniform(0, 1) < 0.9:
                    # place vertical wall
                    self.known_obstacle_map[i1:i2, :] = 1
                if random.uniform(0, 1) < 0.9:
                    # place horizontal wall
                    self.known_obstacle_map[:, i1:i2] = 1
                stop_placed = False
                for m in range(num_walls + 1):
                    # open gaps in walls
                    g_min = min_gap
                    g_max = min(max_gap, room_size_p - 2 * wall_thickness_p)
                    # Gap range inside room segment m, keeping clear of the
                    # perpendicular walls at its ends
                    j_min = room_size_p * m + wall_thickness_p
                    j_max = room_size_p * (m + 1) - wall_thickness_p
                    # Probability reaches 1 in the last segment, so exactly one
                    # stop is placed per wall index n
                    p_stop = 1 / (num_walls + 1 - m)
                    place_stop = False
                    if not stop_placed and random.uniform(0, 1) < p_stop:
                        place_stop = True
                        stop_placed = True
                    if not vertical_stop or not place_stop:
                        # open gap in vertical wall
                        gap = random.randint(g_min, g_max)
                        j1 = random.randint(j_min, j_max - gap)
                        j2 = j1 + gap
                        self.known_obstacle_map[i1:i2, j1:j2] = 0
                    if vertical_stop or not place_stop:
                        # open gap in horizontal wall
                        gap = random.randint(g_min, g_max)
                        j1 = random.randint(j_min, j_max - gap)
                        j2 = j1 + gap
                        self.known_obstacle_map[j1:j2, i1:i2] = 0

    def _randomize_circular_obstacles(self):
        """
        Scatters circular obstacles of radius ``obstacle_radius`` at uniformly
        random positions, alternating known and unknown placements.

        Up to ``max_known_obstacles`` known obstacles (if ``use_known_obstacles``)
        and ``max_unknown_obstacles`` unknown ones (if ``use_unknown_obstacles``)
        are attempted. Each position is kept at least ``2 * clearance`` from the
        map border. A placement is skipped if any obstacle pixel already exists
        in its local neighborhood, where clearance = 2 * mower_radius + obstacle_radius.
        """
        known_pos = np.random.uniform(size=(self.max_known_obstacles, 2))
        unknown_pos = np.random.uniform(size=(self.max_unknown_obstacles, 2))
        clearance = 2 * self.mower_radius + self.obstacle_radius
        assert self.size_m > 4 * clearance

        def try_place(target_map, unit_pos):
            # Map a sample from the unit square into the allowed interior region
            pos_m = 2 * clearance + unit_pos * (self.size_m - 4 * clearance)
            nearby_known = self._get_local_neighborhood(
                self.known_obstacle_map, pos_m, pos_m, clearance)
            nearby_unknown = self._get_local_neighborhood(
                self.unknown_obstacle_map, pos_m, pos_m, clearance)
            if nearby_known.sum() != 0 or nearby_unknown.sum() != 0:
                return
            cv2.circle(
                target_map,
                center=(np.flip(pos_m * self.pixels_per_meter)).astype(np.int32),
                radius=int(self.obstacle_radius * self.pixels_per_meter),
                color=1,
                thickness=cv2.FILLED,
                lineType=self.line_type)

        for n in range(max(self.max_known_obstacles, self.max_unknown_obstacles)):
            # alternate placing known/unknown obstacles
            if self.use_known_obstacles and n < self.max_known_obstacles:
                try_place(self.known_obstacle_map, known_pos[n])
            if self.use_unknown_obstacles and n < self.max_unknown_obstacles:
                try_place(self.unknown_obstacle_map, unknown_pos[n])

    def _set_level(self, level):
        """
        Sets the followng variables based on the specified level:
          - self.train_maps
          - self.goal_coverage
          - self.goal_steps
          - self.use_randomized_envs
          - self.min_size_p
          - self.max_size_p
          - self.completed_maps
          - self.completed_floor_plan
          - self.completed_obstacles
        """
        self.goal_steps = 2000
        if self.exploration:
            if level == 1:
                self.goal_coverage = 0.9
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2
            elif level == 2:
                self.goal_coverage = 0.9
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_4
            elif level == 3:
                self.goal_coverage = 0.95
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_4
            elif level == 4:
                self.goal_coverage = 0.97
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_4
            elif level == 5:
                self.goal_coverage = 0.99
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_4
            elif level == 6:
                self.goal_coverage = 0.99
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_3 + \
                    self.train_maps_4
            elif level == 7:
                self.goal_coverage = 0.99
                self.use_randomized_envs = True
                self.min_size_p = 256
                self.max_size_p = 320
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_3 + \
                    self.train_maps_4
            elif level == 8:
                self.goal_coverage = 0.99
                self.use_randomized_envs = True
                self.min_size_p = 256
                self.max_size_p = 400
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_3 + \
                    self.train_maps_4
            else:
                self.goal_coverage = 0.99
                self.use_randomized_envs = True
                self.min_size_p = 256
                self.max_size_p = 400
                self.train_maps = \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_3 + \
                    self.train_maps_4 + \
                    self.train_maps_5
        else:
            if level == 1:
                self.goal_coverage = 0.9
                self.use_randomized_envs = False
                self.train_maps = self.train_maps_0
            elif level == 2:
                self.goal_coverage = 0.9
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1
            elif level == 3:
                self.goal_coverage = 0.95
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1
            elif level == 4:
                self.goal_coverage = 0.95
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1 + \
                    self.train_maps_2
            elif level == 5:
                self.goal_coverage = 0.97
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1 + \
                    self.train_maps_2
            elif level == 6:
                self.goal_coverage = 0.99
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1 + \
                    self.train_maps_2
            elif level == 7:
                self.goal_coverage = 0.99
                self.use_randomized_envs = False
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_3
            else:
                self.goal_coverage = 0.99
                self.use_randomized_envs = True
                self.train_maps = \
                    self.train_maps_0 + \
                    self.train_maps_1 + \
                    self.train_maps_2 + \
                    self.train_maps_3

        # Keep track of which maps have been completed
        self.completed_maps = [False]*len(self.train_maps)
        self.completed_floor_plan = True
        self.completed_obstacles = True
        if self.use_randomized_envs:
            if self.p_use_floor_plans > 0:
                self.completed_floor_plan = False
            else:
                self.completed_floor_plan = True
            use_known_obstacles = self.max_known_obstacles > 0 and self.p_use_known_obstacles > 0
            use_unknown_obstacles = self.max_unknown_obstacles > 0 and self.p_use_unknown_obstacles > 0
            if use_known_obstacles or use_unknown_obstacles:
                self.completed_obstacles = False
            else:
                self.completed_obstacles = True
