import numpy as np
from deep_sprl.util.maze_env_utils import construct_maze, is_feasible, find_robot
from gym import utils
from gym.envs.mujoco import MujocoEnv
import xml.etree.ElementTree as ET
import tempfile
import os


class MazeEnv(MujocoEnv, utils.EzPickle):
    """Point-mass maze environment whose walls are generated from a maze layout.

    The base MuJoCo model ``data/point.xml`` (located next to this file) is extended
    with one box geom per wall cell of the layout returned by ``construct_maze`` and
    with explicit contact pairs between every geom of the ``torso`` body and every
    wall block. On each reset ``context`` is written into ``qpos[2:]``. The reward is
    1.0 when the torso is within ``terminal_eps`` of the ``target`` body and
    ``is_feasible`` accepts the current context, otherwise 0.0.
    """

    # Not used inside this class; kept for code that reads it.
    ORI_IND = None
    # Overwritten per instance in __init__ with the ``maze_size_scaling`` argument.
    MAZE_SIZE_SCALING = None

    def __init__(self, maze_id=0, length=1, maze_height=0.5, maze_size_scaling=2, terminal_eps=0.3,
                 context=None, *args, **kwargs):
        """Build the maze model and initialise the MuJoCo environment.

        Args:
            maze_id: layout identifier forwarded to ``construct_maze``.
            length: forwarded to ``construct_maze``.
            maze_height: wall height; the box half-height is
                ``maze_height / 2 * maze_size_scaling``.
            maze_size_scaling: world-space edge length of one maze cell.
            terminal_eps: distance between torso and target below which the goal
                counts as reached.
            context: values written into ``qpos[2:]`` on every reset. Defaults to
                ``np.array([4., 4.])`` (a fresh array per instance).
            *args, **kwargs: accepted for compatibility; they are only recorded for
                pickling and are not forwarded to ``MujocoEnv``.

        Raises:
            Exception: if a geom inside the ``torso`` body has no ``name``.
        """
        if context is None:
            # Build a new array per instance: a shared array default would let an
            # in-place edit of one env's context leak into every later env.
            context = np.array([4., 4.])

        self._maze_id = maze_id
        self.length = length
        self.terminal_eps = terminal_eps
        self.MAZE_SIZE_SCALING = maze_size_scaling
        self.MAZE_STRUCTURE = construct_maze(maze_id=self._maze_id, length=self.length)
        self.context = context

        torso_x, torso_y = find_robot(self.MAZE_STRUCTURE, self.MAZE_SIZE_SCALING)
        self._init_torso_x = torso_x
        self._init_torso_y = torso_y
        # Set to True by reset_model; step() never reports done before that.
        self._inited = False

        tree = self._build_maze_xml_tree(maze_height, maze_size_scaling)

        fd, temp_file_path = tempfile.mkstemp(text=True, suffix='.xml')
        # mkstemp returns an open descriptor that the caller must close.
        os.close(fd)
        try:
            tree.write(temp_file_path)
            MujocoEnv.__init__(self, temp_file_path, 5)
        finally:
            # gym's MujocoEnv loads the model from this path inside its __init__,
            # so the temporary XML is not needed afterwards; remove it instead of
            # leaving one file behind per environment instance.
            os.remove(temp_file_path)

        # Record only the constructor arguments (not every local variable) so that
        # unpickling re-creates the environment with the same configuration.
        utils.EzPickle.__init__(self, maze_id, length, maze_height, maze_size_scaling,
                                terminal_eps, context, *args, **kwargs)

    def _build_maze_xml_tree(self, maze_height, maze_size_scaling):
        """Return ``data/point.xml`` parsed and extended with maze walls and contacts.

        Uses ``self.MAZE_STRUCTURE`` and the torso start offsets set in ``__init__``.
        Cells whose string form is ``'1'`` become box geoms named ``block_<i>_<j>``.

        Raises:
            Exception: if a geom inside the ``torso`` body has no ``name`` attribute
                (contact pairs reference geoms by name).
        """
        xml_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data")
        tree = ET.parse(os.path.join(xml_dir, "point.xml"))
        worldbody = tree.find(".//worldbody")
        torso_x, torso_y = self._init_torso_x, self._init_torso_y

        block_names = []
        for i, row in enumerate(self.MAZE_STRUCTURE):
            for j, cell in enumerate(row):
                if str(cell) != '1':
                    continue
                block_name = "block_%d_%d" % (i, j)
                block_names.append(block_name)
                # offset all coordinates so that robot starts at the origin
                ET.SubElement(
                    worldbody, "geom",
                    name=block_name,
                    pos="%f %f %f" % (j * maze_size_scaling - torso_x,
                                      i * maze_size_scaling - torso_y,
                                      maze_height / 2 * maze_size_scaling),
                    size="%f %f %f" % (0.5 * maze_size_scaling,
                                       0.5 * maze_size_scaling,
                                       maze_height / 2 * maze_size_scaling),
                    type="box",
                    material="",
                    contype="1",
                    conaffinity="1",
                    rgba="0.4 0.4 0.4 1."
                )

        torso = tree.find(".//body[@name='torso']")
        torso_geoms = torso.findall(".//geom")
        for geom in torso_geoms:
            if 'name' not in geom.attrib:
                raise Exception("Every geom of the torso must have a name "
                                "defined")

        # One explicit contact pair per (wall block, torso geom), in row-major block order.
        contact = ET.SubElement(tree.find("."), "contact")
        for block_name in block_names:
            for geom in torso_geoms:
                ET.SubElement(
                    contact, "pair",
                    geom1=geom.attrib["name"],
                    geom2=block_name
                )
        return tree

    def step(self, action):
        """Simulate ``frame_skip`` frames with ``action``.

        Returns:
            ``(obs, reward, done, info)``. ``done`` is True only when the goal is
            reached and ``reset_model`` has run at least once. ``info`` contains
            ``success`` (goal reached) and ``distance`` (torso-target distance).
        """
        self.do_simulation(action, self.frame_skip)
        reward = self._compute_dist_reward()

        ob = self._get_obs()
        success = self._is_goal_reached()
        # NOTE(review): the _inited guard presumably ignores steps issued before the
        # first reset (e.g. by MujocoEnv.__init__) — confirm.
        done = bool(success and self._inited)

        return ob, reward, done, dict(
            success=success,
            distance=self._get_dist2goal())

    def _get_obs(self):
        """Return the first two entries of ``qpos`` followed by the first two of ``qvel``."""
        # NOTE: the original Goal GAN setup also observed torso orientation and
        # position; consider adding them if learning fails with this observation.
        return np.concatenate([
            self.sim.data.qpos.flat[:2],
            self.sim.data.qvel.flat[:2]
        ])

    def _get_dist2goal(self):
        """Return the planar (x, y) distance between the ``torso`` and ``target`` bodies."""
        dist = np.linalg.norm(
            self.get_body_com("torso")[:2] - self.get_body_com("target")[:2]
        )
        return dist

    def _compute_dist_reward(self):
        """Return 1.0 if the goal is reached and the context is feasible, else 0.0.

        Feasibility is only checked (via ``is_feasible``) once the goal is reached.
        """
        if not self._is_goal_reached():
            return 0.0
        feasible = is_feasible(self.context, self.MAZE_STRUCTURE, self.MAZE_SIZE_SCALING,
                               self._init_torso_x, self._init_torso_y)
        return 1.0 if feasible else 0.0

    def _is_goal_reached(self):
        """Return whether the torso is closer than ``terminal_eps`` to the target."""
        return self._get_dist2goal() < self.terminal_eps

    def reset_model(self):
        """Reset to the initial state with ``context`` written into ``qpos[2:]``.

        Returns:
            The observation after the reset.
        """
        # Copy so the base-class initial state is not mutated in place.
        qpos = self.init_qpos.copy()
        qpos[2:] = np.array(self.context)
        qvel = self.init_qvel.copy()
        self.set_state(qpos, qvel)
        self._inited = True
        return self._get_obs()

    def viewer_setup(self):
        """Place the camera at the model's extent, looking straight down."""
        self.viewer.cam.distance = self.model.stat.extent
        self.viewer.cam.elevation = -90
