# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adapted from rllab maze_env.py."""

import os
import random
import tempfile
import time
import xml.etree.ElementTree as ET
import math
import numpy as np
import gym

import copy
import mujoco_py
from mujoco_py.generated import const
import glfw
import cv2
from ant import AntEnv
import maze_env_utils



# Directory intended to contain the mujoco xml files.
# NOTE(review): not referenced in the visible part of this file; MazeEnv
# currently loads the robot XML by bare file name relative to the working
# directory.
MODEL_DIR = 'assets'


class MazeEnv(gym.Env):
  """Maze environment embedding a MuJoCo robot in a grid maze.

  The maze layout comes from ``maze_env_utils.construct_maze``; wall blocks,
  movable blocks, doors, keys and balls are added to the robot's MJCF file
  before the model is loaded.
  """
  # Robot model class wrapped by the environment (Swimmer robots use a
  # different class in __init__).
  MODEL_CLASS = AntEnv

  # Class-level placeholders; __init__ sets both on the instance.
  MAZE_HEIGHT = None
  MAZE_SIZE_SCALING = None


  def __init__(
      self,
      maze_id=None,
      maze_height=0.5,
      maze_size_scaling=3,
      seed=0,
      resolution=(128, 128),
      window_size=(1500, 1500),
      subprocess_num=None,
      random_xy_pos=True,
      image_obs=True,
      view_flag=False,
      robot='Ant',
      *args,
      **kwargs):
    """Build the maze MJCF, load the robot model and set up rendering.

    Args:
      maze_id: Key passed to ``maze_env_utils.construct_maze``; also selects
        the per-maze goal list and step budgets.
      maze_height: Height of wall blocks, in units of ``maze_size_scaling``.
      maze_size_scaling: World size of one maze cell.
      seed: Seeds ``self.rng`` and is forwarded to the wrapped model class.
      resolution: Stored as ``self.resolution``.
      window_size: (width, height) of the on-screen viewer window.
      subprocess_num: If given, appended to the generated XML file name so
        parallel workers do not overwrite each other's file.
      random_xy_pos: Stored as ``self.random_xy_pos``.
      image_obs: If true, create an on-screen viewer and an offscreen
        render context (``self.see_render``).
      view_flag: Stored as ``self.view_flag``.
      robot: 'Ant', 'Point'/'Point-v2', 'Swimmer'/'Swimmer-v2' or
        'Humanoid'; any other value falls back to the Ant model files.
      *args, **kwargs: Forwarded to the wrapped model class.

    Raises:
      TypeError: If no model class is configured.
    """
    self._maze_id = maze_id
    self.robot = robot

    # robot name -> (MJCF file, number of robot qpos entries, qvel entries)
    robot_specs = {
        'Ant': ('ant_my.xml', 15, 14),
        'Point': ('point_my.xml', 3, 3),
        'Point-v2': ('point_my.xml', 3, 3),
        'Swimmer': ('swimmer_my.xml', 3, 5),
        'Swimmer-v2': ('swimmer_my.xml', 3, 5),
        'Humanoid': ('humanoid_my.xml', 24, 23),
    }
    if robot not in robot_specs:
      print('unknown robot class, choose default = Ant')
    file_name, self.qpos_length, self.qvel_length = robot_specs.get(
        robot, robot_specs['Ant'])

    if robot in ('Swimmer', 'Swimmer-v2'):
      # NOTE(review): SwimmerEnv is not imported in this module (only
      # AntEnv is), so this branch raises NameError unless it is provided
      # elsewhere.
      model_cls = SwimmerEnv
    else:
      model_cls = self.__class__.MODEL_CLASS

    if model_cls is None:
      # Raising a bare string is itself a TypeError in Python 3; raise a
      # real exception carrying the intended message.
      raise TypeError("MODEL_CLASS unspecified!")

    # The robot XML is resolved relative to the current working directory.
    xml_path = file_name
    tree = ET.parse(xml_path)
    worldbody = tree.find(".//worldbody")

    self.MAZE_HEIGHT = height = maze_height
    self.MAZE_SIZE_SCALING = size_scaling = maze_size_scaling
    self.window_size = window_size
    self.resolution = resolution

    self.MAZE_STRUCTURE = structure = maze_env_utils.construct_maze(
        maze_id=self._maze_id)
    # Elevate the maze to allow for falling.
    self.elevated = any(-1 in row for row in structure)
    # Are there any movable blocks?
    self.blocks = any(
        any(maze_env_utils.can_move(r) for r in row) for row in structure)

    torso_x, torso_y = self._find_robot()
    self._init_torso_x = torso_x
    self._init_torso_y = torso_y
    self._init_positions = [
        (x - torso_x, y - torso_y) for x, y in self._find_all_robots()]

    # The rllab original raised the torso / tuned contact settings for
    # elevated mazes and movable blocks; that is disabled here, so blocks
    # sit directly on the floor.
    height_offset = 0.

    self.rng = np.random.RandomState(seed)

    for i in range(len(structure)):
      for j in range(len(structure[0])):
        cell = structure[i][j]
        # Cell centre in world coordinates, offset so the robot starts at
        # the origin (row index grows downwards, world y grows upwards).
        cx = j * size_scaling - torso_x
        cy = -i * size_scaling + torso_y
        if self.elevated and cell != -1:
          # Create elevated platform.
          ET.SubElement(
              worldbody, "geom",
              name="elevated_%d_%d" % (i, j),
              pos="%f %f %f" % (cx, cy, height / 2 * size_scaling),
              size="%f %f %f" % (0.5 * size_scaling,
                                 0.5 * size_scaling,
                                 height / 2 * size_scaling),
              type="box",
              material="",
              contype="1",
              conaffinity="1",
              rgba="0.9 0.9 0.9 1",
          )
        if cell == 1:  # Unmovable block.
          ET.SubElement(
              worldbody, "geom",
              name="block_%d_%d" % (i, j),
              pos="%f %f %f" % (cx, cy,
                                height_offset + height / 2 * size_scaling),
              size="%f %f %f" % (0.5 * size_scaling,
                                 0.5 * size_scaling,
                                 height / 2 * size_scaling),
              type="box",
              material="",
              contype="1",
              conaffinity="1",
              rgba="0.4 0.4 0.4 1",
          )
        elif maze_env_utils.can_move(cell):  # Movable block.
          # The "falling" blocks are shrunk slightly and increased in mass
          # to ensure that they can fall easily through a gap in the
          # platform blocks.
          falling = maze_env_utils.can_move_z(cell)
          shrink = 0.99 if falling else 1.0
          moveable_body = ET.SubElement(
              worldbody, "body",
              name="moveable_%d_%d" % (i, j),
              pos="%f %f %f" % (cx, cy,
                                height_offset + height / 2 * size_scaling),
          )
          ET.SubElement(
              moveable_body, "geom",
              name="block_%d_%d" % (i, j),
              pos="0 0 0",
              size="%f %f %f" % (0.5 * size_scaling * shrink,
                                 0.5 * size_scaling * shrink,
                                 height / 2 * size_scaling),
              type="box",
              material="",
              mass="0.001" if falling else "0.0002",
              contype="1",
              conaffinity="1",
              rgba="0.9 0.1 0.1 1"
          )
          if maze_env_utils.can_move_x(cell):
            ET.SubElement(
                moveable_body, "joint",
                armature="0",
                axis="1 0 0",
                damping="0.0",
                limited="true" if falling else "false",
                range="%f %f" % (-size_scaling, size_scaling),
                margin="0.01",
                name="moveable_x_%d_%d" % (i, j),
                pos="0 0 0",
                type="slide"
            )
          if maze_env_utils.can_move_y(cell):
            ET.SubElement(
                moveable_body, "joint",
                armature="0",
                axis="0 1 0",
                damping="0.0",
                limited="true" if falling else "false",
                range="%f %f" % (-size_scaling, size_scaling),
                margin="0.01",
                name="moveable_y_%d_%d" % (i, j),
                pos="0 0 0",
                type="slide"
            )
          if maze_env_utils.can_move_z(cell):
            ET.SubElement(
                moveable_body, "joint",
                armature="0",
                axis="0 0 1",
                damping="0.0",
                limited="true",
                range="%f 0" % (-height_offset),
                margin="0.01",
                name="moveable_z_%d_%d" % (i, j),
                pos="0 0 0",
                type="slide"
            )

    self.obj_init_pos = {}
    self.init_set_all_object(worldbody, structure, size_scaling, torso_x,
                             torso_y, height_offset, height)

    torso = tree.find(".//body[@name='torso']")
    for geom in torso.findall(".//geom"):
      if 'name' not in geom.attrib:
        raise Exception("Every geom of the torso must have a name "
                        "defined")

    # Write the modified MJCF to a scratch file for the model class to load.
    if subprocess_num is not None:
      xml_file_name = 'tmpfile' + str(subprocess_num) + '.xml'
    else:
      xml_file_name = 'tmpfile.xml'
    if os.name == 'nt':  # windows
      xml_dir = '/temp_mujoco_xml/'
      if not os.path.isdir(xml_dir):
        os.mkdir(xml_dir)
      file_path_abs = xml_dir + xml_file_name
    else:
      # POSIX and any other platform: use the current working directory
      # (previously file_path_abs was left undefined off nt/posix).
      file_path_abs = os.getcwd() + '/' + xml_file_name
    # ElementTree.write creates/truncates the file itself.
    tree.write(file_path_abs)

    self.wrapped_env = model_cls(*args, file_path=file_path_abs, seed=seed,
                                 **kwargs)
    self.sim = self.wrapped_env.physics

    curr_i = int(torso_y / size_scaling)  # row index of the start cell
    curr_j = int(torso_x / size_scaling)  # column index of the start cell
    # Rooms are 6 cells wide; pick the room containing the start cell.
    room_idx_x = curr_j // 6
    room_idx_y = curr_i // 6
    lookat_x = (room_idx_x * 6 + 3 - curr_j) * size_scaling
    lookat_y = (curr_i - room_idx_y * 6 - 3) * size_scaling

    self.see_render = image_obs
    if self.see_render:
      self.wrapped_env._viewers['human'] = mujoco_py.MjViewer(self.sim)
      self.curr_viewer = self.wrapped_env._viewers['human']
      self.curr_viewer._hide_overlay = True
      glfw.set_window_size(self.curr_viewer.window, self.window_size[0],
                           self.window_size[1])
      glfw.set_window_pos(self.curr_viewer.window, 1000, 400)
      self.curr_viewer.cam.distance = 26
      self.curr_viewer.cam.elevation = -90
      self.curr_viewer.cam.lookat[0] = lookat_x
      self.curr_viewer.cam.lookat[1] = lookat_y

      # Render once (result discarded); presumably this is what makes
      # mujoco_py create the offscreen context configured below.
      self.sim.render(1500, 1500)

      offscreen = self.sim._render_context_offscreen
      if offscreen is not None:
        offscreen.cam.distance = 26
        offscreen.cam.elevation = -90
        offscreen.cam.lookat[0] = lookat_x
        offscreen.cam.lookat[1] = lookat_y

    self.robot_int_pos = []
    self.get_init_pos_all()
    self.torso_init_pos = self.sim.get_state().qpos
    self.alpha = 0.002  # init control loss coeff
    self.curr_room_id = [room_idx_x, room_idx_y]

    # NOTE(review): allow_pickle=True executes pickle data from a relative
    # path; only load trusted files here.
    self.offline_pos = np.load('pos_save.npy', allow_pickle=True)
    self.random_xy_pos = random_xy_pos
    self.max_action = 1.0
    self.view_flag = view_flag


  def updata_camera(self, view_flag = False):
      """Re-aim the render cameras at the room that holds the robot.

      The maze grid is split into 6x6-cell rooms. When the robot is in a
      different room than ``self.curr_room_id``, the offscreen render camera
      and the on-screen viewer are re-centred on that room. Cameras that do
      not exist (e.g. with image observations disabled) are skipped instead
      of raising AttributeError.

      Args:
        view_flag: If true, randomise the offscreen camera: elevation in
          (-90, -40], azimuth in [-15, 15), distance 30.
      """
      curr_agent_x, curr_agent_y = self.wrapped_env.get_xy()
      bias_r_x, bias_r_y = self.robot_int_pos
      # Robot position relative to maze cell (0, 0): x grows with the
      # column index, y grows with the row index.
      normal_r_x = curr_agent_x + bias_r_x * self.MAZE_SIZE_SCALING
      normal_r_y = bias_r_y * self.MAZE_SIZE_SCALING - curr_agent_y

      curr_i = int(normal_r_y / self.MAZE_SIZE_SCALING)  # row index
      curr_j = int(normal_r_x / self.MAZE_SIZE_SCALING)  # column index
      room_idx_x = curr_j // 6  # rooms are 6 cells wide
      room_idx_y = curr_i // 6

      init_i = bias_r_y
      init_j = bias_r_x

      # Without image observations no viewer is created, and the offscreen
      # context may still be None; reset() calls this method regardless.
      offscreen = getattr(self.sim, '_render_context_offscreen', None)
      viewer = getattr(self, 'curr_viewer', None)

      if view_flag and offscreen is not None:
          offscreen.cam.elevation = -40 - random.random() * 50  # up/down
          offscreen.cam.azimuth = (random.random() * 30 - 15)   # left/right
          offscreen.cam.distance = 30                           # zoom

      if [room_idx_x, room_idx_y] != self.curr_room_id:
          lookat_x = (room_idx_x * 6 + 3 - init_j) * self.MAZE_SIZE_SCALING
          lookat_y = (init_i - room_idx_y * 6 - 3) * self.MAZE_SIZE_SCALING
          if offscreen is not None:
              offscreen.cam.lookat[0] = lookat_x
              offscreen.cam.lookat[1] = lookat_y
          self.curr_room_id = [room_idx_x, room_idx_y]
          if viewer is not None:
              viewer.cam.lookat[0] = lookat_x
              viewer.cam.lookat[1] = lookat_y
              # Pause after the view jumps (presumably so a human watching
              # the viewer can follow the room change).
              time.sleep(1)


  def get_init_pos_all(self):
      """Store the grid cell of the robot marker 'r' as [column, row].

      The result goes to ``self.robot_int_pos``, which the other methods use
      to convert between simulator xy and maze-grid coordinates. Raises
      IndexError if the layout has no 'r' cell.
      """
      robot_rows, robot_cols = np.where(self.structure_room_obj == 'r')
      self.robot_int_pos = [robot_cols[0], robot_rows[0]]



  def init_set_all_object(self, worldbody, structure, size_scaling, torso_x, torso_y, height_offset, height,):
      """Add door, key and ball bodies to the maze MJCF and set maze tables.

      Side effects:
        * ``self.structure_room_obj``: string-array copy of ``structure`` in
          which each door cell ('d') is replaced by its body name
          ('door_yel_<k>').
        * ``self.obj_init_pos`` (must already exist): body name -> grid
          position used when the body was created.
        * ``self.objects``: names of the key/ball bodies.
        * ``self.obj_list``, ``self.max_step_list``,
          ``self.max_test_step_list``, ``self.max_step_continuous_list``,
          ``self.max_step`` and ``self.max_step_continuous``: per-maze goal
          sequences and step budgets.

      Args:
        worldbody: ``<worldbody>`` element of the MJCF tree; new bodies are
          appended to it.
        structure: Maze layout from ``maze_env_utils.construct_maze``.
        size_scaling: World size of one maze cell.
        torso_x, torso_y: World offset of the robot's start cell; body
          positions are expressed relative to it.
        height_offset: Extra z offset added to every body.
        height: Block height in units of ``size_scaling``.

      Raises:
        KeyError: If ``self._maze_id`` is not one of the known maze names.
      """
      structure_arr = np.array(structure)
      # numpy string arrays are fixed-width. Door cells are overwritten
      # below with names such as 'door_yel_0', so widen narrow string arrays
      # (e.g. a layout of single-character cells gives '<U1') to avoid
      # silently truncating those names.
      min_str_dtype = np.dtype('<U32')
      if (structure_arr.dtype.kind == 'U'
              and structure_arr.dtype.itemsize < min_str_dtype.itemsize):
          structure_arr = structure_arr.astype(min_str_dtype)
      self.structure_room_obj = structure_arr
      self.storage_bias = 200
      self.curr_door_list = []
      self.objects = []
      self.curr_obj_list = []

      door_str = """
              <body name="door" pos="0 0 0">
                  <geom size="0.05 0.15 0.15" type="box" rgba="1 1 0 1"/>
                  <joint axis="1 0 0" name="door:x" type="slide"/>
                  <joint axis="0 1 0" name="door:y" type="slide"/>
              </body>
          """

      color = 'yel'

      door_idx = np.where(self.structure_room_obj == 'd')
      for k, (row, col) in enumerate(zip(door_idx[0], door_idx[1])):
          body_name = 'door' + '_' + color + '_' + str(k)
          self.structure_room_obj[row][col] = body_name
          self.curr_door_list.append(body_name)
          self.obj_init_pos[body_name] = [col, row]

          temp_body = ET.fromstring(door_str)
          temp_body.set('name', body_name)
          temp_body.set('pos', '{} {} {}'.format(col * size_scaling - torso_x,
                                                 -row * size_scaling + torso_y,
                                                 height_offset + height / 2 * size_scaling))
          for geom_ in temp_body.findall('geom'):
              geom_.set('size', '{} {} {}'.format(0.47 * size_scaling,
                                                  0.47 * size_scaling,
                                                  height / 2 * size_scaling))

          joint_x, joint_y = temp_body.findall('joint')
          joint_x.set('name', body_name + ':x')
          joint_x.set('range', '{:f} {:f}'.format(-size_scaling * 200, size_scaling * 200))
          joint_y.set('name', body_name + ':y')
          joint_y.set('range', '{:f} {:f}'.format(-size_scaling * 200, size_scaling * 200))

          worldbody.append(temp_body)

      #########  set object  #########

      key_str = """
              <body name="key" pos="0 0 0">
                  <geom size="0.05 0.15 0.15" quat = "1 0 1 0" type="capsule" rgba="0 0 1 1" />
                  <geom size="0.05 0.15 0.15" quat = "1 1 0 0" type="capsule" rgba="0 0 1 1" />
                  <joint axis="1 0 0" name="key:x" type="slide"/>
                  <joint axis="0 1 0" name="key:y" type="slide"/>
              </body>
          """

      ball_str = """
              <body name="ball" pos="0 0 0">
                  <geom size="0.05 0.15 0.15" type="sphere" rgba="1 0 0 1"/>
                  <joint axis="1 0 0" name="ball:x" type="slide"/>
                  <joint axis="0 1 0" name="ball:y" type="slide"/>
              </body>
          """

      # Explicit template lookup instead of eval() on a built name.
      object_templates = {'key': key_str, 'ball': ball_str}
      obj_type_all = ['key', 'ball']

      blank_cdx = np.where(self.structure_room_obj == '0')
      # Candidate storage cells: every blank cell plus a copy shifted by six
      # rows ("2 room to store objects").
      blank_cdx_2_x = np.concatenate((blank_cdx[0], blank_cdx[0] + 6), axis=-1)
      blank_cdx_2_y = np.concatenate((blank_cdx[1], blank_cdx[1]), axis=-1)
      shuf_idx = np.array(range(len(blank_cdx_2_x)), dtype=np.int32)
      np.random.shuffle(shuf_idx)

      blank_x_all = blank_cdx_2_x[shuf_idx]
      blank_y_all = blank_cdx_2_y[shuf_idx]

      # Numbering starts at 1 (bodies 'key_1', 'ball_2'); candidate slot 0
      # is never used.
      for i, obj_type in enumerate(obj_type_all, start=1):
          body_name = obj_type + '_' + str(i)
          # Objects are created storage_bias cells away from the maze
          # (presumably parked out of play until moved).
          store_x = blank_x_all[i] + self.storage_bias
          store_y = blank_y_all[i]

          self.obj_init_pos[body_name] = [store_x, store_y]
          temp_body = ET.fromstring(object_templates[obj_type])
          temp_body.set('name', body_name)
          temp_body.set('pos', '{} {} {}'.format(store_x * size_scaling - torso_x,
                                                 -store_y * size_scaling + torso_y,
                                                 height_offset + height / 2 * size_scaling))
          first_size = 0.15 * size_scaling if obj_type == 'key' else 0.2 * size_scaling
          for geom_ in temp_body.findall('geom'):
              geom_.set('size', '{} {} {}'.format(first_size,
                                                  0.2 * size_scaling,
                                                  0.2 * size_scaling))

          joint_x, joint_y = temp_body.findall('joint')
          joint_x.set('name', body_name + ':x')
          joint_x.set('range', '{:f} {:f}'.format(-size_scaling * 200, size_scaling * 200))
          joint_y.set('name', body_name + ':y')
          joint_y.set('range', '{:f} {:f}'.format(-size_scaling * 200, size_scaling * 200))

          self.objects.append(body_name)
          worldbody.append(temp_body)

      ###############    set maze shape
      # Each maze lists its goals in order. 'obj_list_goal' maps a goal to a
      # predicate (robot x, robot y, goal x, goal y) -> bool.
      u_shape_maze = {
          'obj_list': ['goal_0', 'goal_1'],
          'obj_list_goal': {'goal_0': lambda x, y, a, b: y < 10,
                            'goal_1': lambda x, y, a, b: x < 4 and y < 10,
                            },
          'goal_xy': {'goal_0': [12, 9],
                      'goal_1': [3, 6],
                      }
      }

      def make_square_maze():
          # Fresh dict per call so the two square mazes do not share state.
          return {
              'obj_list': ['door_yel_0', 'door_yel_1', 'door_yel_2', 'door_yel_3'],
              'obj_list_goal': {'door_yel_0': lambda x, y, a, b: y < b + 0.5,
                                'door_yel_1': lambda x, y, a, b: x < a + 0.5,
                                'door_yel_2': lambda x, y, a, b: x > a - 0.5,
                                'door_yel_3': lambda x, y, a, b: y > b - 0.5,
                                }
          }

      square_random_maze = make_square_maze()
      square_blocked_maze = make_square_maze()

      s_shape_maze = {
          'obj_list': ['door_yel_2', 'door_yel_0', 'door_yel_1'],
          'obj_list_goal': {'door_yel_2': lambda x, y, a, b: x > a,
                            'door_yel_0': lambda x, y, a, b: x > a,
                            'door_yel_1': lambda x, y, a, b: x >= a - 6 and y < b - 1, }
      }

      spiral_maze = {
          'obj_list': ['door_yel_0', 'door_yel_1', 'door_yel_5', 'door_yel_4', 'door_yel_3'],
          'obj_list_goal': {'door_yel_0': lambda x, y, a, b: y < b - 1,
                            'door_yel_1': lambda x, y, a, b: x > a - 6 and y > b + 1,
                            'door_yel_5': lambda x, y, a, b: x < a - 1 and y > b - 6,
                            'door_yel_4': lambda x, y, a, b: x < a - 1 and y < b + 3,
                            'door_yel_3': lambda x, y, a, b: x < a + 3 and y < b + 1,
                            }
      }

      general_maze_1 = {
          'obj_list': ['door_yel_0', ],
          'obj_list_goal': {
              'door_yel_0': lambda x, y, a, b: x > a,
          }
      }
      general_maze_2 = {
          'obj_list': ['door_yel_0', ],
          'obj_list_goal': {
              'door_yel_0': lambda x, y, a, b: y < b,
          }
      }
      general_maze_3 = {
          'obj_list': ['door_yel_0', ],
          'obj_list_goal': {
              'door_yel_0': lambda x, y, a, b: x > a,
          }
      }

      self.obj_list = {'Maze_U_shape': u_shape_maze,
                       'Maze_square_random': square_random_maze,
                       'Maze_square_blocked': square_blocked_maze,
                       'Maze_S_shape': s_shape_maze,
                       'Maze_spiral_shape': spiral_maze,
                       'Maze_general_1': general_maze_1,
                       'Maze_general_2': general_maze_2,
                       'Maze_general_3': general_maze_3,
                       }
      self.max_step_list = {'Maze_U_shape': [25],
                            'Maze_square_random': [30],
                            'Maze_square_blocked': [30],
                            'Maze_S_shape': [50],
                            'Maze_spiral_shape': [200],
                            'Maze_general_1': [30],
                            'Maze_general_2': [30],
                            'Maze_general_3': [30],
                            }

      # Test-time budgets, one entry per goal in 'obj_list'.
      self.max_test_step_list = {'Maze_U_shape': [15, 10],
                                 'Maze_square_random': [30],
                                 'Maze_square_blocked': [30],
                                 'Maze_S_shape': [15, 15, 25],
                                 'Maze_spiral_shape': [40, 40, 40, 40, 40],
                                 'Maze_general_1': [30],
                                 'Maze_general_2': [30],
                                 'Maze_general_3': [30],
                                 }

      self.max_step = self.max_step_list[self._maze_id][0]

      self.max_step_continuous_list = {'Maze_U_shape': [600],
                                       'Maze_square_random': [600],
                                       'Maze_square_blocked': [600],
                                       'Maze_S_shape': [1000],
                                       'Maze_spiral_shape': [5000],
                                       'Maze_general_1': [600],
                                       'Maze_general_2': [600],
                                       'Maze_general_3': [600],
                                       }
      self.max_step_continuous = self.max_step_continuous_list[self._maze_id][0]




  def get_max_step_add(self):
      """Return the test-time step budget for the current goal.

      Square mazes have a single budget. For other mazes the budget is
      indexed by ``self.curr_goal_idx`` in ``self.max_test_step_list``; once
      that index reaches the number of goals in ``self.obj_list``, the last
      budget is returned.
      """
      maze = self._maze_id
      if maze in ('Maze_square_random', 'Maze_square_blocked'):
          return self.max_test_step_list[maze][0]

      goal_idx = self.curr_goal_idx
      goal_count = len(self.obj_list[maze]['obj_list'])
      budgets = self.max_test_step_list[maze]
      past_last_goal = goal_idx >= goal_count
      return budgets[-1] if past_last_goal else budgets[goal_idx]





  # def random_obj_random_reset(self, blank_y, blank_x, obj_num):
  #
  #     idx = 0
  #     # print(self.structure_room_obj)
  #     # print(self.curr_obj_list)
  #
  #     # time.sleep(10)
  #     # cdx direction = ^|y   ----->x
  #     if len(self.curr_obj_list) > 0:
  #       for exist_obj in self.curr_obj_list:
  #           change_id_x =  self.sim.model.get_joint_qpos_addr("{}:x".format(exist_obj))
  #           change_id_y =  self.sim.model.get_joint_qpos_addr("{}:y".format(exist_obj))
  #           # next_x = self.obj_init_pos[exist_obj][0]
  #           # next_y = self.obj_init_pos[exist_obj][1]
  #           curr_cdx = np.where(self.structure_room_obj == exist_obj)
  #           curr_x = curr_cdx[1][0]
  #           curr_y = curr_cdx[0][0]
  #           self.structure_room_obj[curr_y][curr_x] = '0'
  #           # self.sim.data.qpos[change_id_x] += self.MAZE_SIZE_SCALING * (next_x - curr_x)
  #           # self.sim.data.qpos[change_id_y] += self.MAZE_SIZE_SCALING * (curr_y - next_y)
  #           # print(exist_obj, self.sim.data.qpos[change_id_x], self.sim.data.qpos[change_id_y])
  #           self.sim.data.qpos[change_id_x] = 0.
  #           self.sim.data.qpos[change_id_y] = 0.
  #
  #     next_obj_list = np.random.choice(self.objects, obj_num, replace=False) # from obj_list, not door
  #
  #     for obj_name in next_obj_list:
  #           change_id_x =  self.sim.model.get_joint_qpos_addr("{}:x".format(obj_name))
  #           change_id_y =  self.sim.model.get_joint_qpos_addr("{}:y".format(obj_name))
  #
  #           # curr_cdx = np.where(self.structure_room_obj == obj_name)
  #           curr_x = self.obj_init_pos[obj_name][0]
  #           curr_y = self.obj_init_pos[obj_name][1]
  #
  #           next_x = blank_x[idx]
  #           next_y = blank_y[idx]
  #           self.sim.data.qpos[change_id_x] += self.MAZE_SIZE_SCALING * (next_x - curr_x)
  #           self.sim.data.qpos[change_id_y] += self.MAZE_SIZE_SCALING * (curr_y - next_y)
  #           self.structure_room_obj[next_y][next_x] = obj_name
  #           idx += 1
  #
  #
  #
  #
  #     self.curr_obj_list = next_obj_list










  def door_random_reset(self, blank_x, blank_y, reset_robot_xy = True):
      """Move door ``door_yel_0`` to a random wall cell and respawn the robot.

      The door is moved to a cell in column 6, row 6, column 0 or row 0,
      with the other coordinate drawn from 1..5. 'Maze_square_blocked' only
      uses column 6 or column 0. ``self.curr_goal`` is set according to the
      chosen side: 'door_yel_2' (column 6), 'door_yel_3' (row 6),
      'door_yel_1' (column 0) or 'door_yel_0' (row 0).

      If the door actually moves, the wall-block geom in the target cell is
      moved into the door's old cell. mujoco_py's private geom name maps are
      patched so that ``block_<row>_<col>`` names keep matching positions.
      The robot marker 'r' is then moved to the first candidate cell (from
      index 1 on) that is '0' in ``self.structure_room_obj``. The robot is
      teleported there via ``self.wrapped_env.set_xy``.

      Args:
        blank_x: Candidate coordinates. ``reset`` passes the shuffled row
          indices of '0' cells; they are used here as column indices.
        blank_y: The matching column indices, used here as row indices.
        reset_robot_xy: Unused.

      NOTE(review): because of the row/column swap above, the robot is
      placed at the transpose of a blank cell (still checked to be '0').
      This stays in bounds only for square layouts; confirm it is intended.

      Raises:
        IndexError: If the door is missing from the layout or no candidate
          cell is free.
      """

      # Position in blank_x/blank_y; becomes 1 once the door is placed, so
      # the robot search never uses candidate 0.
      idx = 0



      curr_door_pos_list = []  # unused
      reset_obj = True  # always true; the branch below always runs
      obj_name = 'door_yel_0'
      if reset_obj:
          change_id_x = self.sim.model.get_joint_qpos_addr("{}:x".format(obj_name))
          change_id_y = self.sim.model.get_joint_qpos_addr("{}:y".format(obj_name))

          # Current door cell (column, row) in the layout.
          curr_cdx = np.where(self.structure_room_obj == obj_name)
          curr_x = curr_cdx[1][0]
          curr_y = curr_cdx[0][0]


          # Pick the new position along the chosen side (1..5) and the side.
          door_cdx = random.randint(1, 5)
          if self._maze_id == 'Maze_square_blocked':
              if random.random() < 0.5:
                  direction = 0
              else:
                  direction = 2
          else:
                direction = random.randint(0, 3)

          if direction == 0:
              next_x = 6  # right
              next_y = door_cdx
              self.curr_goal = 'door_yel_2'
          elif direction == 1:
              next_x = door_cdx
              next_y = 6  # down
              self.curr_goal = 'door_yel_3'

          elif direction == 2:
              next_x = 0  # left
              next_y = door_cdx
              self.curr_goal = 'door_yel_1'

          else:
              next_x = door_cdx  # up
              next_y = 0
              self.curr_goal = 'door_yel_0'

          # Integer cell equality test: the door stays where it is.
          if abs(int(curr_x) - int(next_x)) < 1e-6 and abs(int(curr_y) - int(next_y)) < 1e-6:
              pass
          else:
              # The door's old cell becomes a wall.
              self.structure_room_obj[curr_y][curr_x] = '1'
              ##  exchange geom and door  ##
              # Slide the wall geom from the target cell into the door's old
              # cell (world y decreases as the row index grows).
              geom_id = self.sim.model.geom_name2id("block_{}_{}".format(next_y, next_x))
              self.sim.model.geom_pos[geom_id][0] += self.MAZE_SIZE_SCALING * (curr_x - next_x)
              self.sim.model.geom_pos[geom_id][1] += self.MAZE_SIZE_SCALING * (next_y - curr_y)
              # Rename the moved geom in mujoco_py's private lookup tables so
              # block names keep describing their cell.
              self.sim.model._geom_name2id["block_{}_{}".format(curr_y, curr_x)] = \
                  self.sim.model._geom_name2id.pop("block_{}_{}".format(next_y, next_x))

              self.sim.model._geom_id2name[geom_id] = "block_{}_{}".format(curr_y, curr_x)


          # Move the door body by the cell delta via its slide joints.
          self.sim.data.qpos[change_id_x] += self.MAZE_SIZE_SCALING * (next_x - curr_x)
          self.sim.data.qpos[change_id_y] += self.MAZE_SIZE_SCALING * (curr_y - next_y)
          self.structure_room_obj[next_y][next_x] = obj_name
          idx += 1


          ###  reset objects  ###
          # The values computed below are not used. np.random.shuffle and
          # random.randint still advance the global RNG state, so removing
          # them would change later random draws.
          blank_cdx = np.where(self.structure_room_obj == '0')
          shuf_idx = np.array(range(len(blank_cdx[0])), dtype=np.int32)
          np.random.shuffle(shuf_idx)

          blank_x_all = blank_cdx[0][shuf_idx]
          blank_y_all = blank_cdx[1][shuf_idx]


          obj_num = random.randint(1, 2)


      ###   reset robot  ###
      curr_cdx = np.where(self.structure_room_obj == 'r')
      curr_x = curr_cdx[1][0]
      curr_y = curr_cdx[0][0]

      # First candidate cell that is currently free.
      next_x = blank_x[idx]
      next_y = blank_y[idx]
      while self.structure_room_obj[next_y][next_x] != '0':
            idx += 1
            next_x = blank_x[idx]
            next_y = blank_y[idx]

      self.structure_room_obj[curr_y][curr_x] = '0'
      self.structure_room_obj[next_y][next_x] = 'r'
      bais_x, bais_y = self.robot_int_pos


      # Robot position relative to maze cell (0, 0).
      now_r_x, now_r_y = self.wrapped_env.get_xy()
      now_normal_x = now_r_x + bais_x * self.MAZE_SIZE_SCALING #cdx from structure
      now_normal_y = bais_y * self.MAZE_SIZE_SCALING - now_r_y

      # Algebraically this is S*(next_x - bais_x), S*(bais_y - next_y): the
      # target cell relative to the robot's start cell.
      next_set_x = now_r_x + (self.MAZE_SIZE_SCALING * next_x - now_normal_x)
      next_set_y = now_r_y + (now_normal_y - self.MAZE_SIZE_SCALING * next_y)

      self.wrapped_env.set_xy([next_set_x, next_set_y])




  def reset(self, obj_num = None):
    """Begin a new episode and return its first observation.

    For the Ant robot:
      * The first ``self.qpos_length`` joint positions are resampled as
        ``self.torso_init_pos`` plus uniform noise in [-0.1, 0.1).
      * The remaining joints keep their current positions, and their
        velocities are zeroed.
      * The robot is then placed by ``door_random_reset`` (square mazes) or
        moved to xy (0, 0), and the simulation is forwarded.
    Other robots are not re-initialised here.

    Args:
      obj_num: Unused.

    Returns:
      Tuple ``(init_img_obs, obs_pos)``. ``init_img_obs`` is the rendered
      frame, flipped vertically and scaled by 0.1 with ``cv2.resize``, or
      None when image observations are off. ``obs_pos`` is the state
      observation vector.
    """
    self.last_dis = 0
    self.t = 0
    is_square_maze = self._maze_id in ('Maze_square_random',
                                       'Maze_square_blocked')

    if self.robot == 'Ant':
      jitter = np.random.RandomState().uniform(
          size=self.sim.model.nq, low=-.1, high=.1)
      qpos = self.torso_init_pos + jitter
      qvel = self.sim.get_state().qvel
      # Non-robot joints keep their current positions and are brought to
      # rest. NOTE(review): the robot's own velocities are carried over
      # unchanged; confirm this is intended.
      qpos[self.qpos_length:] = self.sim.get_state().qpos[self.qpos_length:]
      qvel[self.qvel_length:] = 0.
      self.wrapped_env.set_state(qpos, qvel)

      if is_square_maze:
        free_rows, free_cols = np.where(self.structure_room_obj == '0')
        order = np.arange(len(free_rows), dtype=np.int32)
        np.random.shuffle(order)
        self.door_random_reset(free_rows[order], free_cols[order])
      else:
        self.wrapped_env.set_xy([0, 0])
      self.sim.forward()

    self.updata_camera(view_flag=self.view_flag)

    init_img_obs = None
    if self.see_render:
      frame = self.sim.render(1500, 1500)
      frame = frame[::-1, :, :]  # flip rows
      init_img_obs = cv2.resize(frame, None, fx=0.1, fy=0.1,
                                interpolation=cv2.INTER_LINEAR)

    if self.robot != 'Ant':
      joint_pos = self.wrapped_env.sim.data.qpos
      joint_vel = self.wrapped_env.sim.data.qvel
      obs_pos = np.concatenate([joint_pos.flat[2:], joint_vel.flat])
    else:
      obs_pos = self.wrapped_env._get_obs()

    self.ep_init_qpos = self.sim.get_state().qpos
    start_x, start_y = self.wrapped_env.get_xy()
    self.ori_xy = [start_x, start_y]
    self.curr_goal_idx = 0
    if not is_square_maze:
      goal_names = self.obj_list[self._maze_id]['obj_list']
      self.curr_goal = goal_names[self.curr_goal_idx]
    return init_img_obs, obs_pos





  def calculate_distance(self, curr_goal_name,):
      """Return the robot-to-goal distance and whether the goal is passed.

      Both positions are converted to the normalized maze frame
      (``x_ = x + bias_x * scaling``, ``y_ = bias_y * scaling - y``), where
      the biases are the object's (or robot's) initial cell indices.

      Args:
        curr_goal_name: name of the goal object. It selects the pass-check
          function in ``obj_list[maze_id]['obj_list_goal']``. It is also the
          goal whose position is measured, except in the square mazes, which
          always measure to 'door_yel_0'.

      Returns:
        ``(distance, pass_flag)``: the Euclidean distance in the normalized
        frame, and the result of the maze's pass-check for this goal.
      """
      scale = self.MAZE_SIZE_SCALING

      curr_agent_x, curr_agent_y = self.wrapped_env.get_xy()
      bias_r_x, bias_r_y = self.robot_int_pos
      normal_r_x = curr_agent_x + bias_r_x * scale
      normal_r_y = bias_r_y * scale - curr_agent_y

      if self._maze_id in ('Maze_square_random', 'Maze_square_blocked'):
          goal_name = 'door_yel_0'
      else:
          goal_name = curr_goal_name

      maze_cfg = self.obj_list[self._maze_id]
      if self._maze_id == 'Maze_U_shape':
          # Fixed goal coordinates, already in the normalized frame. Use the
          # requested goal (not self.curr_goal), so the measured position and
          # the pass-check below refer to the same goal.
          normal_g_x, normal_g_y = maze_cfg['goal_xy'][goal_name]
      else:
          # Goal objects are free joints; read their x/y offsets from qpos and
          # shift them by the object's initial cell.
          goal_x_id = self.sim.model.get_joint_qpos_addr("{}:x".format(goal_name))
          goal_y_id = self.sim.model.get_joint_qpos_addr("{}:y".format(goal_name))
          bias_x, bias_y = self.obj_init_pos[goal_name]
          goal_x = self.sim.data.qpos[goal_x_id]
          goal_y = self.sim.data.qpos[goal_y_id]
          normal_g_x = goal_x + bias_x * scale
          normal_g_y = bias_y * scale - goal_y

      distance = math.hypot(normal_g_x - normal_r_x, normal_g_y - normal_r_y)

      pass_flag = maze_cfg['obj_list_goal'][curr_goal_name](
          normal_r_x, normal_r_y, normal_g_x, normal_g_y)

      return distance, pass_flag



  def get_final_goal_xy(self, target = 'final'):
      """Return a goal position in the normalized maze frame.

      Args:
        target: 'final' for the last goal in this maze's ``obj_list``, or
          'curr' for ``self.curr_goal``. In the square mazes
          ('Maze_square_random' / 'Maze_square_blocked') the goal is always
          the 'door_yel_0' object and ``target`` is ignored.

      Returns:
        ``[x, y]`` of the goal, using the same normalization as
        ``get_curr_robot_xy``.

      Raises:
        ValueError: if ``target`` is neither 'final' nor 'curr' (non-square
          mazes only).
      """
      if self._maze_id in ('Maze_square_random', 'Maze_square_blocked'):
          desired_goal = 'door_yel_0'
      elif target == 'final':
          desired_goal = self.obj_list[self._maze_id]['obj_list'][-1]
      elif target == 'curr':
          desired_goal = self.curr_goal
      else:
          raise ValueError(
              "target must be 'final' or 'curr', got {!r}".format(target))

      if self._maze_id == 'Maze_U_shape':
          # Fixed coordinates, already normalized. Look up the requested goal
          # so that target='final' really returns the final goal.
          normal_g_x, normal_g_y = self.obj_list[self._maze_id]['goal_xy'][desired_goal]
      else:
          goal_x_id = self.sim.model.get_joint_qpos_addr("{}:x".format(desired_goal))
          goal_y_id = self.sim.model.get_joint_qpos_addr("{}:y".format(desired_goal))

          bias_x, bias_y = self.obj_init_pos[desired_goal]

          goal_x = self.sim.data.qpos[goal_x_id]
          goal_y = self.sim.data.qpos[goal_y_id]

          normal_g_x = goal_x + bias_x * self.MAZE_SIZE_SCALING
          normal_g_y = bias_y * self.MAZE_SIZE_SCALING - goal_y

      return [normal_g_x, normal_g_y]

  def get_curr_robot_xy(self):
      """Return the robot's ``[x, y]`` in the normalized maze frame.

      The simulator x is shifted by the robot's initial column; y is
      mirrored around the robot's initial row (both scaled by
      ``MAZE_SIZE_SCALING``).
      """
      sim_x, sim_y = self.wrapped_env.get_xy()
      cell_col, cell_row = self.robot_int_pos
      scale = self.MAZE_SIZE_SCALING
      return [sim_x + cell_col * scale, cell_row * scale - sim_y]



  @property
  def viewer(self):
    return self.wrapped_env.viewer

  def render(self, *args, **kwargs):
    """Forward all arguments to the wrapped environment's ``render``."""
    inner_env = self.wrapped_env
    return inner_env.render(*args, **kwargs)

  @property
  def observation_space(self):
    """Unbounded ``Box`` shaped like the wrapped env's ``_get_obs()`` output."""
    obs_shape = self.wrapped_env._get_obs().shape
    upper = np.full(obs_shape, np.inf)
    return gym.spaces.Box(-upper, upper)

  @property
  def action_space(self):
    return self.wrapped_env.action_space

  def _find_robot(self):
    structure = self.MAZE_STRUCTURE
    size_scaling = self.MAZE_SIZE_SCALING
    for i in range(len(structure)):
      for j in range(len(structure[0])):
        if structure[i][j] == 'r':
          return j * size_scaling, i * size_scaling
    assert False, 'No robot in maze specification.'

  def _find_all_robots(self):
    structure = self.MAZE_STRUCTURE
    size_scaling = self.MAZE_SIZE_SCALING
    coords = []
    for i in range(len(structure)):
      for j in range(len(structure[0])):
        if structure[i][j] == 'r':
          coords.append((j * size_scaling, i * size_scaling))
    return coords


  def step(self, action):
    """Advance the physics simulation by one action and compute rewards.

    The wrapped env's own reward, done flag and info are discarded.

    Args:
      action: continuous action forwarded to ``wrapped_env.step``.

    Returns:
      ``(img_obs_next, dense_reward, done, next_obs_qpos, sparse_reward)``.
      ``img_obs_next`` is ``None`` when ``self.see_render`` is false.
      ``dense_reward`` is ``1 / (1 + distance to the current goal)``.
      ``sparse_reward`` is 1 on the step a goal is passed, else 0.
    """
    self.t += 1


    # print(action)
    next_obs_qpos, inner_reward, done, info = self.wrapped_env.step(action)
    done = False
    # print('step',next_obs_qpos.shape)
    if self.robot == 'Ant':
        # Drop the first two entries of the wrapped obs (presumably root x/y)
        # and pad with two zeros.
        next_obs_qpos = np.concatenate((next_obs_qpos[2:],np.zeros(2)),axis=-1)
        # Freeze every joint after the robot: restore its episode-start
        # position (ep_init_qpos from reset) and zero its velocity.
        old_qpos = self.sim.get_state().qpos
        old_qpos[self.qpos_length:] = self.ep_init_qpos[self.qpos_length:]
        old_qvel = self.sim.get_state().qvel
        old_qvel[self.qvel_length:] = 0.
        self.wrapped_env.set_state(old_qpos, old_qvel)

    else:
        # qpos without its first two entries (presumably root x/y) plus qvel.
        temp_pos = self.wrapped_env.sim.data.qpos
        temp_vel = self.wrapped_env.sim.data.qvel
        # print(temp_pos.shape)
        # print(temp_pos.shape)
        next_obs_qpos = np.concatenate([temp_pos.flat[2:], temp_vel.flat])



    if self.see_render:
        # NOTE(review): reset() passes view_flag=self.view_flag here; this call
        # relies on updata_camera's default -- confirm that is intended.
        self.updata_camera()

        # Render at 1500x1500, flip rows, downscale by a factor of 10.
        img_obs_next = self.sim.render(1500, 1500)
        img_obs_next = img_obs_next[::-1, :, :]
        img_obs_next = cv2.resize(img_obs_next, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_LINEAR)
        # cv2.imshow('img',img_obs_next)
        # cv2.waitKey(500)
    else:
        img_obs_next = None


    self.last_dis, pass_flag = self.calculate_distance(self.curr_goal)

    dense_reward = 1 / (1 + self.last_dis)
    sparse_reward = 0

    if pass_flag:
        # print('pass')
        if self._maze_id == 'Maze_square_random' or self._maze_id == 'Maze_square_blocked':
            # Square mazes have a single goal: passing it ends the episode.
            sparse_reward = 1
            done = True
        else:
            # Multi-goal mazes: advance to the next goal in obj_list; the
            # episode ends once the last goal has been passed.
            self.curr_goal_idx += 1
            # print(self.curr_goal_idx)
            sparse_reward = 1

            if self.curr_goal_idx >= len(self.obj_list[self._maze_id]['obj_list']):
                done = True
                # print(done)
            else:
                self.curr_goal = self.obj_list[self._maze_id]['obj_list'][self.curr_goal_idx]

    return img_obs_next, dense_reward, done, next_obs_qpos, sparse_reward


  def step_for_tab(self, action):
    """Advance one tabular (grid-cell) step instead of simulating physics.

    The robot is moved to a neighbouring cell by ``tab_pos_step``. For the
    Ant, the joint pose and velocities (everything except the root x/y set
    by ``tab_pos_step``) are replaced by a randomly chosen row of
    ``self.offline_pos``. Every joint after the robot is held at its
    episode-start pose (``self.ep_init_qpos``) with zero velocity.

    Args:
      action: discrete grid action forwarded to ``tab_pos_step``.

    Returns:
      ``(img_obs_next, dense_reward, done, next_obs_qpos, sparse_reward)``,
      in the same order as ``step``. ``next_obs_qpos`` is ``None`` for
      non-Ant robots, and ``img_obs_next`` is ``None`` when rendering is off.
    """
    self.t += 1
    done = False
    self.tab_pos_step(action)

    if self.robot == 'Ant':
        qpos_len = self.qpos_length
        qvel_len = self.qvel_length
        # Robot qpos entries after the two root x/y coordinates.
        pose_len = qpos_len - 2

        old_qpos = self.sim.get_state().qpos
        old_qpos[qpos_len:] = self.ep_init_qpos[qpos_len:]

        # Offline rows are laid out as agent_qpos[2:] + agent_qvel + zeros(2).
        random_robot_pos_idx = np.random.choice(len(self.offline_pos))
        offline_row = self.offline_pos[random_robot_pos_idx]
        old_qpos[2:qpos_len] = offline_row[:pose_len]

        old_qvel = self.sim.get_state().qvel
        old_qvel[qvel_len:] = 0.
        old_qvel[:qvel_len] = offline_row[pose_len:-2]
        self.wrapped_env.set_state(old_qpos, old_qvel)
        next_obs_qpos = np.concatenate(
            (offline_row[:pose_len], old_qvel[:qvel_len], np.zeros(2)), axis=-1)
    else:
        next_obs_qpos = None

    if self.see_render:
        self.updata_camera()

        # Render at 1500x1500, flip rows, downscale by a factor of 10.
        img_obs_next = self.sim.render(1500, 1500)
        img_obs_next = img_obs_next[::-1, :, :]
        img_obs_next = cv2.resize(img_obs_next, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_LINEAR)
    else:
        img_obs_next = None

    self.last_dis, pass_flag = self.calculate_distance(self.curr_goal)

    dense_reward = 1/(1 + self.last_dis)
    sparse_reward = 0

    if pass_flag:
        if self._maze_id == 'Maze_square_random' or self._maze_id == 'Maze_square_blocked':
            # Single-goal mazes: passing the goal ends the episode.
            sparse_reward = 1
            done = True
        else:
            # Multi-goal mazes: advance to the next goal; done after the last.
            self.curr_goal_idx += 1
            sparse_reward = 1

            if self.curr_goal_idx >= len(self.obj_list[self._maze_id]['obj_list']):
                done = True
            else:
                self.curr_goal = self.obj_list[self._maze_id]['obj_list'][self.curr_goal_idx]

    return img_obs_next, dense_reward, done, next_obs_qpos, sparse_reward



  def tab_pos_step(self, action):
      """Teleport the robot by one cell on the maze grid (tabular move).

      The robot's expected cell is derived from ``self.ori_xy``. The target
      cell is that cell shifted by ``action``:

        0: column + 1 (x+)      1: column - 1 (x-)
        2: row - 1              3: row + 1

      If the target is a wall cell ('1') or lies off the grid, the robot
      stays in its current cell. ``self.ori_xy`` is set to the rounded
      expected position of the resulting cell, and the robot is placed there
      with ``wrapped_env.set_xy``. When ``self.random_xy_pos`` is set, the
      placement gets Gaussian jitter.

      Args:
        action: discrete action; values >= 4 leave the robot untouched.

      Raises:
        ValueError: if ``action`` is below 4 but not one of 0, 1, 2, 3.
      """
      if action >= 4:
          return

      if action == 0:
          step_col, step_row = 1, 0
      elif action == 1:
          step_col, step_row = -1, 0
      elif action == 2:
          step_col, step_row = 0, -1
      elif action == 3:
          step_col, step_row = 0, 1
      else:
          raise ValueError('action error: {!r}'.format(action))

      bias_x, bias_y = self.robot_int_pos
      scale = self.MAZE_SIZE_SCALING
      grid = self.structure_room_obj

      # Expected grid cell (column, row) of the robot.
      now_r_x, now_r_y = self.ori_xy
      now_tab_x = int(now_r_x / scale) + bias_x
      now_tab_y = bias_y - int(now_r_y / scale)

      cand_x = now_tab_x + step_col
      cand_y = now_tab_y + step_row
      # Bounds-check on both sides before indexing: a negative index would
      # otherwise silently wrap to the opposite edge of the grid.
      in_grid = 0 <= cand_y < len(grid) and 0 <= cand_x < len(grid[0][:])
      if in_grid and grid[cand_y][cand_x] != '1':
          next_x, next_y = cand_x, cand_y
      else:
          next_x, next_y = now_tab_x, now_tab_y

      # Current simulator position expressed in the normalized grid frame.
      now_real_x, now_real_y = self.wrapped_env.get_xy()
      now_normal_x = now_real_x + bias_x * scale
      now_normal_y = bias_y * scale - now_real_y

      # Simulator-frame position corresponding to the target cell.
      next_expect_x = now_real_x + (scale * next_x - now_normal_x)
      next_expect_y = now_real_y + (now_normal_y - scale * next_y)

      self.ori_xy = [round(next_expect_x)*1.0, round(next_expect_y)*1.0]

      next_set_x, next_set_y = next_expect_x, next_expect_y
      if self.random_xy_pos:
          half = scale / 2
          # Gaussian jitter (std scale/9), clipped to +/- half a cell. Drawn in
          # x-then-y order to keep the RNG stream unchanged.
          jitter_x = np.random.normal(0, scale / 9)
          jitter_y = np.random.normal(0, scale / 9)
          jitter_x = max(-half, min(half, jitter_x))
          jitter_y = max(-half, min(half, jitter_y))

          # NOTE(review): the 1 / 5 cell bounds and the +/-1 threshold are
          # hard-coded; presumably they keep the robot off the outer walls of
          # a 7-cell-wide layout -- confirm for other maze sizes.
          if not ((next_x <= 1 and jitter_x < -1) or (next_x >= 5 and jitter_x > 1)):
              next_set_x = next_expect_x + jitter_x
          if not ((next_y <= 1 and jitter_y < -1) or (next_y >= 5 and jitter_y > 1)):
              next_set_y = next_expect_y + jitter_y

      self.wrapped_env.set_xy([next_set_x, next_set_y])

