import functools
import os
import warnings

import embodied
import numpy as np


class LocoNav(embodied.Env):
  """Maze navigation with a dm_control locomotion walker.

  Environment names have the form ``<walker>_<arena>``, optionally followed
  by a control-frequency suffix ``_<freq>hz`` (default 50), e.g.
  ``ant_maze_m`` or ``quadruped_maze_s_100hz``. ``<walker>`` must be ``ant``
  or ``quadruped``; ``<arena>`` must be a key of the module-level ``MAPS``.

  Besides the observations of the wrapped ``dmc.DMC`` env, every step reports
  ``log_coverage``: the number of distinct (x, y) walker positions, rounded to
  integers in world coordinates, seen since the last ``is_first`` step.
  """

  # Camera index used when ``camera=-1``; other walkers fall back to 0.
  DEFAULT_CAMERAS = dict(
      ant=4,
      quadruped=5,
  )

  def __init__(
      self, name, repeat=1, size=(64, 64), camera=-1, again=False,
      termination=False, weaker=1.0):
    """Builds the walker, maze arena, task and wrapped environment.

    Args:
      name: Environment name, see the class docstring.
      repeat: Action repeat, forwarded to ``dmc.DMC``.
      size: Image size, forwarded to ``dmc.DMC``.
      camera: Camera index forwarded to ``dmc.DMC``; -1 selects the walker's
        entry in ``DEFAULT_CAMERAS`` (or 0 if it has none).
      again: If True, the task allows up to 1000 goal repeats per episode.
        Otherwise it allows only one, and the task's ``after_step`` is
        replaced so the step's reward flags follow the target's current
        activation.
      termination: If True, use an aliveness threshold of -0.5 instead of
        -1.0 (presumably ending episodes earlier when the walker falls).
      weaker: Factor every action is multiplied by in ``step``.

    Raises:
      NotImplementedError: If the walker is not supported.
      KeyError: If the arena is not a key of ``MAPS``.
    """
    if name.endswith('hz'):
      name, freq = name.rsplit('_', 1)
      freq = int(freq[:-2])  # Drop the 'hz' suffix.
    else:
      freq = 50
    # TODO: This env variable is meant for headless GPU machines but may fail
    # on CPU-only machines.
    if 'MUJOCO_GL' not in os.environ:
      os.environ['MUJOCO_GL'] = 'egl'
    # Imported only after MUJOCO_GL is set, presumably because dm_control
    # picks its rendering backend at import time.
    from dm_control import composer
    from dm_control.locomotion.props import target_sphere
    from dm_control.locomotion.tasks import random_goal_maze
    walker, arena = name.split('_', 1)
    if camera == -1:
      camera = self.DEFAULT_CAMERAS.get(walker, 0)
    self._walker = self._make_walker(walker)
    arena = self._make_arena(arena)
    target = target_sphere.TargetSphere(radius=1.2, height_above_ground=0.0)
    task = random_goal_maze.RepeatSingleGoalMaze(
        walker=self._walker, maze_arena=arena, target=target,
        max_repeats=1000 if again else 1,
        randomize_spawn_rotation=True,
        target_reward_scale=1.0,
        aliveness_threshold=-0.5 if termination else -1.0,
        contact_termination=False,
        # At least four physics substeps per control step, and never coarser
        # than 0.02.
        physics_timestep=min(1 / freq / 4, 0.02),
        control_timestep=1 / freq)
    if not again:
      # Single-goal mode: bypass RepeatSingleGoalMaze.after_step (calling its
      # parent's implementation instead) and set the reward flags directly
      # from whether the target is activated on this step.
      def after_step(maze_task, physics, random_state):
        super(random_goal_maze.RepeatSingleGoalMaze, maze_task).after_step(
            physics, random_state)
        maze_task._rewarded_this_step = maze_task._target.activated
        maze_task._targets_obtained = int(maze_task._target.activated)
      task.after_step = functools.partial(after_step, task)
    env = composer.Environment(
        time_limit=60, task=task, random_state=None,
        strip_singleton_obs_buffer_dim=True)
    from . import dmc
    self._env = dmc.DMC(env, repeat, size=size, camera=camera)
    self._visited = None  # Set of rounded (x, y) cells; reset on is_first.
    self._weaker = weaker

  @property
  def obs_space(self):
    """Wrapped observation space plus the integer ``log_coverage`` entry."""
    return {
        **self._env.obs_space,
        'log_coverage': embodied.Space(np.int64, low=0),
    }

  @property
  def act_space(self):
    """Action space of the wrapped environment, unchanged."""
    return self._env.act_space

  def step(self, action):
    """Scales and applies ``action``, then adds ``log_coverage`` to the obs.

    The caller's action mapping and the arrays inside it are left unmodified.

    Args:
      action: Mapping with at least an ``'action'`` entry, passed on to the
        wrapped environment after scaling by ``weaker``.

    Returns:
      The wrapped environment's observation dict with ``log_coverage`` added.
    """
    with warnings.catch_warnings():
      warnings.filterwarnings('ignore', '.*is a deprecated alias for.*')
      # dict.copy() is shallow, so an in-place `*=` would also rescale the
      # caller's own array; build a new scaled array instead.
      action = action.copy()
      action['action'] = action['action'] * self._weaker
      obs = self._env.step(action)
    if obs['is_first'] or self._visited is None:
      self._visited = set()
    global_pos = self._walker.get_pose(
        self._env._dmenv._physics)[0].reshape(-1)
    # Coverage counts distinct integer-rounded world (x, y) positions.
    self._visited.add(tuple(np.round(global_pos[:2]).astype(int).tolist()))
    obs['log_coverage'] = len(self._visited)
    return obs

  def _make_walker(self, name):
    """Returns a new walker for ``name`` ('ant' or 'quadruped').

    Raises:
      NotImplementedError: For any other walker name.
    """
    if name == 'ant':
      from dm_control.locomotion.walkers import ant
      return ant.Ant()
    elif name == 'quadruped':
      from . import loconav_quadruped
      return loconav_quadruped.Quadruped()
    else:
      raise NotImplementedError(name)

  def _make_arena(self, name):
    """Builds a ``mazes.MazeWithTargets`` arena from the ``MAPS[name]`` layout.

    Wall characters get flat textures. ``'*'`` uses RGB (0.8, 0.8, 0.8), and
    ``'1'`` to ``'9'`` take the first nine colors of matplotlib's ``tab10``
    colormap.

    Raises:
      KeyError: If ``name`` is not a key of ``MAPS``.
    """
    import labmaze
    from dm_control import mjcf
    from dm_control.locomotion.arenas import labmaze_textures
    from dm_control.locomotion.arenas import mazes
    import matplotlib.pyplot as plt

    class WallTexture(labmaze_textures.WallTextures):
      """Wall texture consisting of a single flat color."""

      # Tuple default rather than a list, to avoid a shared mutable default.
      def _build(self, color=(0.8, 0.8, 0.8), model='labmaze_style_01'):
        self._mjcf_root = mjcf.RootElement(model=model)
        self._textures = [self._mjcf_root.asset.add(
            'texture', type='2d', name='wall', builtin='flat',
            rgb1=color, width=100, height=100)]

    wall_textures = {'*': WallTexture([0.8, 0.8, 0.8])}
    cmap = plt.get_cmap('tab10')
    for index in range(9):
      wall_textures[str(index + 1)] = WallTexture(cmap(index)[:3])
    # MAPS rows separate their cells with single spaces: keep every other
    # character and turn '.' into ' ' before handing the layout to labmaze.
    layout = ''.join(
        line[::2].replace('.', ' ') + '\n' for line in MAPS[name])
    maze = labmaze.FixedMazeWithRandomGoals(
        entity_layer=layout,
        num_spawns=1, num_objects=1, random_state=None)
    arena = mazes.MazeWithTargets(
        maze, xy_scale=1.2, z_height=2.0, aesthetic='default',
        wall_textures=wall_textures, name='maze')
    return arena


# Maze layouts for LocoNav, keyed by the arena part of the environment name.
# Each layout is a tuple of rows in which cells are separated by single spaces;
# LocoNav._make_arena keeps every other character and turns '.' into ' '.
# Characters:
#   '.'        open cell (becomes a space in the labmaze layout).
#   '*', '1'-'9'  wall cells; the character selects the wall texture ('*' gray,
#              digits a tab10 color).
#   'P', 'G'   passed through to labmaze unchanged; presumably the spawn and
#              goal markers of labmaze's layout format (TODO confirm).
# 'empty' has no walls and no 'P'/'G' markers.
MAPS = {

    'maze_s': (
        '            6 6 6 6 6',
        '            6 . . . 6',
        '            6 . G . 6',
        '            6 . . . 6',
        '            5 . . . 4',
        '            5 . . . 4',
        '1 1 1 1 5 5 5 . . . 4',
        '1 . . . . . . . . . 3',
        '1 . P . . . . . . . 3',
        '1 . . . . . . . . . 3',
        '1 1 1 1 2 2 2 3 3 3 3',
    ),

    'maze_m': (
        '6 6 6 6 8 8 8 7 7 7 7',
        '6 . . . . . . . . . 7',
        '6 . G . . . . . . . 7',
        '6 . . . . . . . . . 7',
        '6 6 6 5 5 5 5 . . . 4',
        '            5 . . . 4',
        '1 1 1 1 5 5 5 . . . 4',
        '1 . . . . . . . . . 3',
        '1 . P . . . . . . . 3',
        '1 . . . . . . . . . 3',
        '1 1 1 1 2 2 2 3 3 3 3',
    ),

    'maze_l': (
        '8 8 8 8 7 7 7 6 6 6 6 . . .',
        '8 . . . . . . . . . 6 . . .',
        '8 . G . . . . . . . 6 . . .',
        '8 . . . . . . . . . 6 5 5 5',
        '8 8 8 8 7 7 7 . . . . . . 5',
        '. . . . . . 7 . . . . . . 5',
        '1 1 1 1 1 . 7 . . . . . . 5',
        '1 . . . 1 . 7 9 9 9 . . . 5',
        '1 . . . 1 . . . . 9 . . . 5',
        '1 . . . 1 1 1 9 9 9 . . . 5',
        '2 . . . . . . . . . . . . 4',
        '2 . . . . P . . . . . . . 4',
        '2 . . . . . . . . . . . . 4',
        '2 2 2 2 3 3 3 3 3 3 4 4 4 4',
    ),

    'maze_xl': (
        '9 9 9 9 9 9 9 8 8 8 8 . 4 4 4 4 4',
        '9 . . . . . . . . . 8 . 4 . . . 4',
        '9 . . . . . . . G . 8 . 4 . . . 4',
        '9 . . . . . . . . . 8 . 4 . . . 4',
        '6 . . . 7 7 7 8 8 8 8 . 5 . . . 3',
        '6 . . . 7 . . . . . . . 5 . . . 3',
        '6 . . . 7 7 7 5 5 5 5 5 5 . . . 3',
        '5 . . . . . . . . . . . . . . . 3',
        '5 . . . . . . . . . . . . . . . 3',
        '5 . . . . . . . . . . . . . . . 3',
        '5 5 5 5 4 4 4 . . . 6 6 6 . . . 3',
        '. . . . . . 4 . . . 6 . 6 . . . 3',
        '1 1 1 1 4 4 4 . . . 6 . 6 . . . 3',
        '1 . . . . . . . . . 2 . 1 . . . 1',
        '1 . P . . . . . . . 2 . 1 . . . 1',
        '1 . . . . . . . . . 2 . 1 . . . 1',
        '1 1 1 1 1 1 1 2 2 2 2 . 1 1 1 1 1',
    ),

    'maze_xxl': (
        '7 7 7 7 * * * 6 6 6 * * * 9 9 9 9',
        '7 . . . . . . . . . . . . . . . 9',
        '7 . . . . . . . . . . . . . G . 9',
        '7 . . . . . . . . . . . . . . . 9',
        '* . . . 5 5 5 * * * * * * 9 9 9 9',
        '* . . . 5 . . . . . . . . . . . .',
        '* . . . 5 5 5 * * * * * * 3 3 3 3',
        '8 . . . . . . . . . . . . . . . 3',
        '8 . . . . . . . . . . . . . . . 3',
        '8 . . . . . . . . . . . . . . . 3',
        '8 8 8 8 * * * * * * 4 4 4 . . . *',
        '. . . . . . . . . . . . 4 . . . *',
        '1 1 1 1 * * * * * * 4 4 4 . . . *',
        '1 . . . . . . . . . . . . . . . 2',
        '1 . P . . . . . . . . . . . . . 2',
        '1 . . . . . . . . . . . . . . . 2',
        '1 1 1 1 * * * 6 6 6 * * * 2 2 2 2',
    ),

    'empty': (
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
        '. . . . . . . . . . . . . . . . .',
    ),

}
