# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapted from efficient-hrl maze_env.py."""

import math
import os
import random
import tempfile
import xml.etree.ElementTree as ET
from copy import deepcopy
from collections import defaultdict

import gym
import numpy as np

# Cell markers used in the maze grids below. Besides these two markers a cell
# is either 1 (wall; MazeEnv turns it into an unmovable block) or 0 (free).
RESET = R = 'r'  # Reset position.
GOAL = G = 'g'  # Goal marker; MazeEnv treats such a cell as a free cell.

# Every map is a list of rows and is indexed as map[row][col].
# NOTE: MazeEnv.__init__ only accepts a map that compares equal (by value) to
# U_MAZE_TEST or BIG_MAZE_TEST; any other layout raises ValueError there.

# yapf: disable
# Maze specifications for dataset generation
TEST_MAP = [[1, 1, 1, 1, 1],
            [1, 1, 0, 1, 1],
            [1, 0, R, 0, 1],
            [1, 1, 0, 1, 1],
            [1, 1, 1, 1, 1]]

# Same layout as U_MAZE_TEST below.
U_MAZE = [[1, 1, 1, 1, 1],
          [1, R, 0, 0, 1],
          [1, 1, 1, 0, 1],
          [1, G, 0, 0, 1],
          [1, 1, 1, 1, 1]]

BIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1],
            [1, R, 0, 1, 1, 0, 0, 1],
            [1, 0, 0, 1, 0, 0, G, 1],
            [1, 1, 0, 0, 0, 1, 1, 1],
            [1, 0, 0, 1, 0, 0, 0, 1],
            [1, G, 1, 0, 0, 1, 0, 1],
            [1, 0, 0, 0, 1, G, 0, 1],
            [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, R, 0, 0, 0, 1, G, 0, 0, 0, 0, 1],
                [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
                [1, 0, 0, 0, 0, G, 0, 1, 0, 0, G, 1],
                [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],
                [1, 0, G, 1, 0, 1, 0, 0, 0, 0, 0, 1],
                [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                [1, 0, 0, 1, G, 0, G, 1, 0, G, 0, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

# Maze specifications with a single target goal
U_MAZE_TEST = [[1, 1, 1, 1, 1],
               [1, R, 0, 0, 1],
               [1, 1, 1, 0, 1],
               [1, G, 0, 0, 1],
               [1, 1, 1, 1, 1]]

BIG_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1],
                 [1, R, 0, 1, 1, 0, 0, 1],
                 [1, 0, 0, 1, 0, 0, 0, 1],
                 [1, 1, 0, 0, 0, 1, 1, 1],
                 [1, 0, 0, 1, 0, 0, 0, 1],
                 [1, 0, 1, 0, 0, 1, 0, 1],
                 [1, 0, 0, 0, 1, 0, G, 1],
                 [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, R, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
                     [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
                     [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                     [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],
                     [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
                     [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                     [1, 0, 0, 1, 0, 0, 0, 1, 0, G, 0, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

# Maze specifications for evaluation
U_MAZE_EVAL = [[1, 1, 1, 1, 1],
               [1, 0, 0, R, 1],
               [1, 0, 1, 1, 1],
               [1, 0, 0, G, 1],
               [1, 1, 1, 1, 1]]

BIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1],
                 [1, R, 0, 0, 0, 0, G, 1],
                 [1, 0, 1, 0, 1, 1, 0, 1],
                 [1, 0, 0, 0, 0, 1, 0, 1],
                 [1, 1, 1, 0, 0, 1, 1, 1],
                 [1, G, 0, 0, 0, 0, 0, 1],
                 [1, 0, 0, 1, 1, G, 0, 1],
                 [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, R, 0, 1, G, 0, 0, 1, 0, G, 0, 1],
                     [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],
                     [1, 0, 0, 1, 0, 1, G, 0, 0, 0, 0, 1],
                     [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],
                     [1, G, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                     [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                     [1, 0, 0, 0, G, 1, G, 0, 0, 0, G, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

# Evaluation layouts with a single goal cell each.
U_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1],
                    [1, 0, 0, R, 1],
                    [1, 0, 1, 1, 1],
                    [1, 0, 0, G, 1],
                    [1, 1, 1, 1, 1]]

BIG_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1],
                      [1, R, 0, 0, 0, 0, G, 1],
                      [1, 0, 1, 0, 1, 1, 0, 1],
                      [1, 0, 0, 0, 0, 1, 0, 1],
                      [1, 1, 1, 0, 0, 1, 1, 1],
                      [1, 0, 0, 0, 0, 0, 0, 1],
                      [1, 0, 0, 1, 1, 0, 0, 1],
                      [1, 1, 1, 1, 1, 1, 1, 1]]

HARDEST_MAZE_EVAL_TEST = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                          [1, R, 0, 1, 0, 0, 0, 1, 0, G, 0, 1],
                          [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1],
                          [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
                          [1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],
                          [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                          [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1],
                          [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
                          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

# yapf: enable


class MazeEnv(gym.Env):
    """Grid-maze wrapper around a locomotion environment.

    The constructor parses the XML model referenced by
    ``LOCOMOTION_ENV.FILE``, adds one box geom per wall cell of the maze map,
    writes the result to a temporary file and initialises ``LOCOMOTION_ENV``
    on it. Locomotion behaviour (``step``, ``reset``, ``get_xy`` ...) is
    obtained by calling ``LOCOMOTION_ENV`` methods with this instance as
    ``self``, so concrete subclasses are presumably expected to inherit from
    that locomotion class as well -- confirm against the subclasses.
    """
    LOCOMOTION_ENV = None  # Must be specified by child class.

    def __init__(self,
                 maze_map,
                 maze_size_scaling,
                 maze_height=0.5,
                 manual_collision=False,
                 non_zero_reset=False,
                 reward_type='dense',
                 return_direction=False,
                 *args,
                 **kwargs):
        """Build the maze model and initialise the locomotion environment.

        Wall geoms are appended to the XML model found at
        ``LOCOMOTION_ENV.FILE``; the augmented model is written to a temporary
        file which is then handed to ``LOCOMOTION_ENV.__init__``.

        Args:
            maze_map: Grid (list of rows) whose cells are ``1`` (wall), ``0``
                (free), ``RESET`` or ``GOAL``. It must compare equal to
                ``U_MAZE_TEST`` or ``BIG_MAZE_TEST``.
            maze_size_scaling: Side length of one grid cell in world units.
            maze_height: Wall height factor; the z position and z size of
                each wall box are ``maze_height / 2 * maze_size_scaling``.
            manual_collision: If true, ``step`` restores the previous xy
                position whenever the new position lies inside a wall cell.
            non_zero_reset: Forwarded to ``LOCOMOTION_ENV.__init__``.
            reward_type: Forwarded to ``LOCOMOTION_ENV.__init__``.
            return_direction: If true, policies built by
                ``create_navigation_policy`` give the sub-policy an offset
                relative to the robot instead of an absolute waypoint.
            *args: Forwarded to ``LOCOMOTION_ENV.__init__``.
            **kwargs: Forwarded to ``LOCOMOTION_ENV.__init__``.

        Raises:
            ValueError: If ``LOCOMOTION_ENV`` is unset or ``maze_map`` is not
                one of the supported layouts.
        """
        if self.LOCOMOTION_ENV is None:
            raise ValueError('LOCOMOTION_ENV is unspecified.')

        xml_path = self.LOCOMOTION_ENV.FILE
        tree = ET.parse(xml_path)
        worldbody = tree.find(".//worldbody")

        self._maze_map = maze_map

        self._maze_height = maze_height
        self._maze_size_scaling = maze_size_scaling
        self._manual_collision = manual_collision
        self._return_direction = return_direction

        # Number the non-wall cells in row-major order and remember their
        # world coordinates. The (j - 1, i - 1) offset puts the origin at the
        # cell in row 1, column 1, which is where RESET sits in both supported
        # maps. astype("int") truncates, so a non-integer scaling loses the
        # fractional part here.
        self.id_to_xy = {}
        idx = 0
        for i, row in enumerate(maze_map):
            for j, cell in enumerate(row):
                if cell != 1:  # if not wall
                    self.id_to_xy[idx] = tuple(
                        (np.array([j - 1, i - 1]) *
                         self._maze_size_scaling).astype("int"))
                    idx += 1

        self.xy_to_id = {v: k for k, v in self.id_to_xy.items()}

        # Start groups partition the free cells by world position.
        # NOTE(review): the thresholds below are hard-coded world units and
        # presumably tuned for one particular maze_size_scaling -- confirm.
        if maze_map == U_MAZE_TEST:
            self.n_start_groups = 3

            def _pos_to_start_id(xy):
                if xy[1] < 4 and xy[0] < 6:
                    return 0
                elif xy[1] >= 4 and xy[0] < 6:
                    return 1
                else:
                    return 2

        elif maze_map == BIG_MAZE_TEST:
            self.n_start_groups = 4

            def _pos_to_start_id(xy):
                if xy[1] < 10 and xy[0] < 10:
                    return 0
                elif xy[1] >= 10 and xy[0] < 10:
                    return 1
                elif xy[1] < 10 and xy[0] >= 10:
                    return 2
                else:
                    # Remaining quadrant; a plain else keeps the function
                    # total instead of implicitly returning None.
                    return 3

        else:
            raise ValueError('Invalid maze shape.')

        self._pos_to_start_id = _pos_to_start_id
        self.set_start_id_to_pos_group_dict()
        self.n_goals = len(self.id_to_xy)

        # Numeric copy of the map (RESET and GOAL become free cells) so that
        # reset locations can be sampled from it with numpy.
        self._np_maze_map = np.array(
            [[0 if cell in (RESET, GOAL) else cell for cell in row]
             for row in maze_map])

        torso_x, torso_y = self._find_robot()
        self._init_torso_x = torso_x
        self._init_torso_y = torso_y
        self._init_torso = np.array([torso_x, torso_y])

        half_height = self._maze_height / 2 * self._maze_size_scaling
        half_width = 0.5 * self._maze_size_scaling
        for i, row in enumerate(maze_map):
            for j, struct in enumerate(row):
                if struct == 1:  # Unmovable block.
                    # Offset all coordinates so that robot starts at the origin.
                    ET.SubElement(
                        worldbody,
                        "geom",
                        name="block_%d_%d" % (i, j),
                        pos="%f %f %f" %
                        (j * self._maze_size_scaling - torso_x,
                         i * self._maze_size_scaling - torso_y, half_height),
                        size="%f %f %f" %
                        (half_width, half_width, half_height),
                        type="box",
                        material="",
                        contype="255",
                        conaffinity="255",
                        rgba="0.7 0.5 0.3 1.0",
                    )

        # mkstemp returns an already-open descriptor; close it right away so
        # it is not leaked -- tree.write() reopens the file by path.
        # The file itself is left on disk because it is unclear from here
        # when LOCOMOTION_ENV has finished reading it.
        fd, file_path = tempfile.mkstemp(text=True, suffix='.xml')
        os.close(fd)
        tree.write(file_path)

        self.LOCOMOTION_ENV.__init__(self,
                                     *args,
                                     file_path=file_path,
                                     non_zero_reset=non_zero_reset,
                                     reward_type=reward_type,
                                     **kwargs)

        self.target_goal = None

    def pos_to_start_id(self, xy: np.ndarray):
        """Return the start-group index of position ``xy``.

        Delegates to the maze-specific classifier stored in
        ``self._pos_to_start_id`` by the constructor.
        """
        classify = self._pos_to_start_id
        return classify(xy)

    def set_start_id_to_pos_group_dict(self):
        """Group every known cell position by its start-group index.

        Populates ``self.start_id_to_pos_group_dict``, a ``defaultdict(list)``
        mapping start id -> list of xy tuples taken from ``self.xy_to_id``.
        """
        groups = self.start_id_to_pos_group_dict = defaultdict(list)
        for xy in self.xy_to_id:
            groups[self.pos_to_start_id(xy)].append(xy)

    def goal_id_and_start_id_to_task_id(self, goal_id: int, start_id: int):
        """Encode a (goal, start group) pair as a single task id.

        Task ids are laid out goal-major: each goal owns a contiguous run of
        ``n_start_groups`` ids.
        """
        first_task_of_goal = goal_id * self.n_start_groups
        return first_task_of_goal + start_id

    def task_id_to_goal_id(self, task_id: int):
        """Recover the goal id from a task id (inverse of the goal-major encoding)."""
        goal_id, _start_id = divmod(task_id, self.n_start_groups)
        return goal_id

    def task_id_to_start_id(self, task_id: int):
        """Recover the start-group id from a task id."""
        _goal_id, start_id = divmod(task_id, self.n_start_groups)
        return start_id

    def goal_id_to_task_id_list(self, goal_id: int):
        """List every task id that shares the given goal, one per start group."""
        group_count = self.n_start_groups
        return [goal_id * group_count + start_id
                for start_id in range(group_count)]

    def setup_task(self, start_id: int, goal_id: int):
        """Initialize the robot position and the goal.

        The simulation and the locomotion model are reset, the goal is set to
        the xy position registered for ``goal_id``, and the robot is placed
        on a cell drawn uniformly from the ``start_id`` group, excluding
        cells that are within one unit of the goal on both axes.

        Args:
            start_id: Key into ``self.start_id_to_pos_group_dict``.
            goal_id: Key into ``self.id_to_xy``.

        Returns:
            The goal position (the ``self.id_to_xy`` entry for ``goal_id``).

        Raises:
            ValueError: If the start group offers no cell away from the goal.
        """
        self.sim.reset()
        self.LOCOMOTION_ENV.reset_model(self)
        goal = self.id_to_xy[goal_id]
        self.set_target_goal(goal_input=goal)

        goal_xy = np.array(goal)
        dtype = self.observation_space.dtype
        # Filter up front instead of rejection sampling: a group whose only
        # cell coincides with the goal would otherwise loop forever.
        candidates = [
            pos for pos in self.start_id_to_pos_group_dict[start_id]
            if not np.all(np.abs(np.array(pos, dtype=dtype) - goal_xy) < 1.)
        ]
        if not candidates:
            raise ValueError(
                'No reset location in start group %r is away from goal %r.' %
                (start_id, goal))
        reset_location = np.array(random.choice(candidates), dtype=dtype)

        self.set_init_xy(init_xy=reset_location)

        self.set_marker()
        return goal

    def _xy_to_rowcol(self, xy):
        """Convert a world ``(x, y)`` position to the nearest ``[row, col]`` cell.

        The position is shifted by the initial torso offset, divided by the
        cell size and rounded to the nearest cell centre; the result is
        reversed because x maps to the column and y to the row.

        Args:
            xy: Sequence or array holding ``(x, y)``.

        Returns:
            Integer numpy array ``[row, col]``. A 64-bit integer type is used
            so that positions far outside the maze do not wrap around, as
            they did with the former 8-bit cast.
        """
        size_scaling = self._maze_size_scaling
        shifted = np.array(xy) + self._init_torso
        rowcol = np.floor(shifted / size_scaling + 0.5)[::-1]
        return rowcol.astype(np.int64)

    def _get_reset_location(self, ):
        """Sample a jittered xy reset position on a free cell away from the goal.

        A cell is drawn with probability proportional to ``1 - map value``
        (so walls have zero weight), redrawn while it coincides with the
        cell of ``self.target_goal``, and finally perturbed by uniform noise
        of at most an eighth of a cell on each axis. All randomness comes
        from the global ``np.random`` state.

        Returns:
            numpy array ``[x, y]``.
        """
        free_weight = 1.0 - self._np_maze_map
        cell_prob = free_weight / np.sum(free_weight)
        row_prob = np.sum(cell_prob, 1)
        n_rows, n_cols = self._np_maze_map.shape

        def draw_cell():
            # Row first, then a column conditioned on that row.
            row = np.random.choice(np.arange(n_rows), p=row_prob)
            col = np.random.choice(np.arange(n_cols),
                                   p=cell_prob[row] * 1.0 / row_prob[row])
            return row, col

        cell = draw_cell()
        goal_cell = self._xy_to_rowcol(self.target_goal)
        while np.abs(np.array(cell) - goal_cell).sum() < 1e-4:
            cell = draw_cell()

        base_x, base_y = self._rowcol_to_xy(cell)

        # Add some random noise
        jitter_x = np.random.uniform(low=-0.5,
                                     high=0.5) * 0.25 * self._maze_size_scaling
        jitter_y = np.random.uniform(low=-0.5,
                                     high=0.5) * 0.25 * self._maze_size_scaling

        return np.array([base_x + jitter_x, base_y + jitter_y])

    def _rowcol_to_xy(self, rowcol, add_random_noise=False):
        """Convert a ``(row, col)`` cell to the world xy of its centre.

        Args:
            rowcol: Pair ``(row, col)``.
            add_random_noise: If true, add uniform jitter of at most an
                eighth of a cell per axis, drawn from the global
                ``np.random`` state.

        Returns:
            Tuple ``(x, y)``.
        """
        row, col = rowcol
        cell_size = self._maze_size_scaling
        x = col * cell_size - self._init_torso_x
        y = row * cell_size - self._init_torso_y
        if not add_random_noise:
            return (x, y)

        x = x + np.random.uniform(low=-0.5, high=0.5) * cell_size * 0.25
        y = y + np.random.uniform(low=-0.5, high=0.5) * cell_size * 0.25
        return (x, y)

    def goal_sampler(self, np_random, only_free_cells=True, interpolate=True):
        """Sample a jittered goal position on a random maze cell.

        Args:
            np_random: Random source providing ``choice`` and ``uniform``
                (e.g. the ``numpy.random`` module or a ``RandomState``). It
                now drives both the cell choice and the jitter, so a seeded
                source yields a reproducible goal.
            only_free_cells: If true, only non-wall cells (``0``, ``RESET``,
                ``GOAL``) are candidates; otherwise every cell is.
            interpolate: Accepted for interface compatibility; not used.

        Returns:
            numpy array ``[x, y]``: the cell centre plus uniform jitter of at
            most an eighth of a cell per axis.

        Raises:
            ValueError: If there is no candidate cell.
        """
        n_rows = len(self._maze_map)
        n_cols = len(self._maze_map[0]) if n_rows else 0
        # GOAL cells get no special treatment: any candidate cell may be
        # returned as the goal.
        candidates = [
            (i, j) for i in range(n_rows) for j in range(n_cols)
            if not only_free_cells or self._maze_map[i][j] in (0, RESET, GOAL)
        ]
        if not candidates:
            raise ValueError('No cell available for goal sampling.')

        cell = candidates[np_random.choice(len(candidates))]
        x, y = self._rowcol_to_xy(cell, add_random_noise=False)

        random_x = np_random.uniform(low=-0.5,
                                     high=0.5) * 0.25 * self._maze_size_scaling
        random_y = np_random.uniform(low=-0.5,
                                     high=0.5) * 0.25 * self._maze_size_scaling

        return np.array([x + random_x, y + random_y])

    def set_target_goal(self, goal_input=None):
        """Set ``self.target_goal`` (and ``self._goal``) to a given or sampled goal.

        Args:
            goal_input: Goal position to use as-is. When ``None`` a goal is
                sampled with ``goal_sampler`` and resampled until it falls in
                a different cell than the robot's current one.
        """
        robot_cell = self._xy_to_rowcol(self.LOCOMOTION_ENV.get_xy(self))
        if goal_input is not None:
            self.target_goal = goal_input
        else:
            while True:
                self.target_goal = self.goal_sampler(np.random)
                goal_cell = self._xy_to_rowcol(self.target_goal)
                same_cell = np.abs(robot_cell - goal_cell).sum() < 1e-4
                if not same_cell:
                    break

        # Keep the goal stored in self._goal in sync with the target goal.
        self._goal = self.target_goal

    def set_init_xy(self, init_xy=None):
        """Place the robot at ``init_xy`` or at a sampled reset location.

        Args:
            init_xy: Position to move the robot to; when ``None`` one is
                obtained from ``_get_reset_location``.

        Raises:
            ValueError: If the chosen position lies in the same cell as
                ``self.target_goal``. The robot has already been moved when
                this is raised, and both positions are printed first.
        """
        start_xy = self._get_reset_location() if init_xy is None else init_xy
        self.set_xy(start_xy)

        start_cell = self._xy_to_rowcol(start_xy)
        goal_cell = self._xy_to_rowcol(self.target_goal)
        starts_on_goal = np.abs(start_cell - goal_cell).sum() < 1e-4
        if not starts_on_goal:
            return

        print(start_xy, start_cell)
        print(self.target_goal, goal_cell)
        raise ValueError

    def _find_robot(self):
        """Locate the first ``RESET`` cell (row-major) and return its scaled position.

        Returns:
            Tuple ``(col * scale, row * scale)`` for that cell.

        Raises:
            ValueError: If the maze map has no ``RESET`` cell.
        """
        grid = self._maze_map
        scale = self._maze_size_scaling
        reset_cells = ((r, c)
                       for r in range(len(grid))
                       for c in range(len(grid[0]))
                       if grid[r][c] == RESET)
        first = next(reset_cells, None)
        if first is None:
            raise ValueError('No robot in maze specification.')
        row, col = first
        return col * scale, row * scale

    def _is_in_collision(self, pos):
        """Tell whether world position ``pos`` lies inside (or on the edge of) a wall cell.

        Args:
            pos: Pair ``(x, y)`` in world coordinates.

        Returns:
            ``True`` if some wall cell's square, boundaries included,
            contains the point; ``False`` otherwise.
        """
        x, y = pos
        grid = self._maze_map
        scale = self._maze_size_scaling
        half = scale * 0.5
        for i in range(len(grid)):
            for j in range(len(grid[0])):
                if grid[i][j] != 1:
                    continue
                # Axis-aligned extent of this wall cell in world coordinates.
                in_x = (j * scale - half - self._init_torso_x <= x <=
                        j * scale + half - self._init_torso_x)
                if not in_x:
                    continue
                in_y = (i * scale - half - self._init_torso_y <= y <=
                        i * scale + half - self._init_torso_y)
                if in_y:
                    return True
        return False

    def step(self, action):
        """Advance the locomotion env by one action and refresh the goal marker.

        With manual collision handling enabled, the robot is moved back to
        its pre-step xy position if the step ended inside a wall cell.

        Args:
            action: Passed through to ``LOCOMOTION_ENV.step``.

        Returns:
            Tuple ``(obs, reward, done, info)``; reward, done and info come
            from the inner step, while ``obs`` is recomputed with
            ``_get_obs`` after any position correction.
        """
        handle_collision = self._manual_collision
        if handle_collision:
            pos_before = self.get_xy()

        _, reward, done, info = self.LOCOMOTION_ENV.step(self, action)

        if handle_collision:
            pos_after = self.get_xy()
            if self._is_in_collision(pos_after):
                self.set_xy(pos_before)

        self.set_marker()
        return self._get_obs(), reward, done, info

    def get_obs(self):
        """Return the current observation (public alias of ``_get_obs``)."""
        observation = self._get_obs()
        return observation

    def reset(self, **kwargs):
        """Reset the locomotion env, sample a new target goal and return the observation.

        Args:
            **kwargs: Forwarded to ``LOCOMOTION_ENV.reset``.

        Returns:
            The observation from ``_get_obs``, taken after the new goal is
            set; the observation returned by the inner reset is discarded.
        """
        self.LOCOMOTION_ENV.reset(self, **kwargs)
        self.set_target_goal()
        return self._get_obs()

    def set_marker(self):
        """Move the site named ``target_site`` to the current target goal (z = 0)."""
        goal_x, goal_y = self.target_goal[0], self.target_goal[1]
        marker_pos = np.array([goal_x, goal_y, 0.0])
        site_positions = self.data.site_xpos
        site_index = self.model.site_name2id("target_site")
        site_positions[site_index] = marker_pos

    def _get_best_next_rowcol(self, current_rowcol, target_rowcol):
        """Return the next cell on a shortest path from current to target.

        A breadth-first search expands outwards from the target over
        non-wall cells. The first expanded cell that has ``current_rowcol``
        as a 4-neighbour is one step closer to the target and is returned.

        Args:
            current_rowcol: ``(row, col)`` of the robot.
            target_rowcol: ``(row, col)`` of the target.

        Returns:
            Tuple ``(row, col)`` of the cell to move to next; the target
            itself when both cells coincide.

        Raises:
            ValueError: If the search is exhausted without reaching
                ``current_rowcol``.
        """
        current_rowcol = tuple(current_rowcol)
        target_rowcol = tuple(target_rowcol)
        if target_rowcol == current_rowcol:
            return target_rowcol

        n_rows = len(self._maze_map)
        n_cols = len(self._maze_map[0])
        # Cells are marked when they are queued, not when they are expanded,
        # so each cell enters the frontier at most once. Marking on expansion
        # let the frontier grow combinatorially in open areas.
        seen = {target_rowcol}
        to_visit = [target_rowcol]
        while to_visit:
            next_visit = []
            for rowcol in to_visit:
                row, col = rowcol
                left = (row, col - 1)
                right = (row, col + 1)
                down = (row + 1, col)
                up = (row - 1, col)
                for next_rowcol in (left, right, down, up):
                    # Checked before the bounds and wall tests on purpose: the
                    # robot's cell is accepted even if it is not a free cell.
                    if next_rowcol == current_rowcol:  # Found a shortest path.
                        return rowcol
                    next_row, next_col = next_rowcol
                    if next_row < 0 or next_row >= n_rows:
                        continue
                    if next_col < 0 or next_col >= n_cols:
                        continue
                    if self._maze_map[next_row][next_col] not in (0, RESET,
                                                                  GOAL):
                        continue
                    if next_rowcol in seen:
                        continue
                    seen.add(next_rowcol)
                    next_visit.append(next_rowcol)
            to_visit = next_visit

        print(current_rowcol, target_rowcol)
        raise ValueError('No path found to target.')

    def create_navigation_policy(self,
                                 goal_reaching_policy_fn,
                                 obs_to_robot=lambda obs: obs[:2],
                                 obs_to_target=lambda obs: obs[-2:],
                                 relative=False):
        """Creates a navigation policy by guiding a sub-policy to waypoints.

        Args:
            goal_reaching_policy_fn: Callable ``(obs, target) -> action``.
            obs_to_robot: Extracts the robot ``(x, y)`` from an observation.
            obs_to_target: Accepted for interface compatibility; not used,
                the target is read from ``self.target_goal`` instead.
            relative: If true, ``self.target_goal`` is interpreted as an
                offset from the robot position.

        Returns:
            ``policy_fn(obs)``, which picks the next maze cell on a shortest
            path to the target and calls ``goal_reaching_policy_fn`` with
            either that waypoint or, when ``self._return_direction`` is set,
            the offset from the robot to it.
        """

        def policy_fn(obs):
            robot_x, robot_y = obs_to_robot(obs)
            robot_row, robot_col = self._xy_to_rowcol([robot_x, robot_y])

            target_x, target_y = self.target_goal
            if relative:
                # Target is given in relative coordinates.
                target_x = target_x + robot_x
                target_y = target_y + robot_y
            target_row, target_col = self._xy_to_rowcol([target_x, target_y])

            next_row, next_col = self._get_best_next_rowcol(
                [robot_row, robot_col], [target_row, target_col])

            reached_target_cell = (next_row == target_row
                                   and next_col == target_col)
            if reached_target_cell:
                # Final leg: aim at the exact goal rather than a cell centre.
                waypoint_x, waypoint_y = target_x, target_y
            else:
                waypoint_x, waypoint_y = self._rowcol_to_xy(
                    [next_row, next_col], add_random_noise=True)

            offset_x = waypoint_x - robot_x
            offset_y = waypoint_y - robot_y

            target = ((offset_x, offset_y) if self._return_direction else
                      (waypoint_x, waypoint_y))
            return goal_reaching_policy_fn(obs, target)

        return policy_fn

    def viewer_setup(self):
        """Point the viewer camera straight down at the centre of the maze."""
        n_rows, n_cols = self._np_maze_map.shape
        origin_x, origin_y = self._init_torso
        cell_size = self._maze_size_scaling

        # Centre of the grid in world coordinates (cell centres span 0..n-1).
        centre_x = (n_cols - 1) * cell_size / 2 - origin_x
        centre_y = (n_rows - 1) * cell_size / 2 - origin_y

        self.viewer.cam.lookat[0] = centre_x
        self.viewer.cam.lookat[1] = centre_y
        self.viewer.cam.lookat[2] = 8.
        self.viewer.cam.elevation = -90

    def get_distance(self, start_xy, target_xy):
        """Return the shortest-path length, in cells, between two positions.

        Both positions are mapped to maze cells and a breadth-first search
        expands outwards from the target cell over non-wall cells.

        Args:
            start_xy: World ``(x, y)`` of the start.
            target_xy: World ``(x, y)`` of the target.

        Returns:
            Number of 4-connected steps between the two cells; ``0`` when
            both positions fall into the same cell.

        Raises:
            ValueError: If the search is exhausted without reaching the
                start cell.
        """
        current_rowcol = tuple(self._xy_to_rowcol(start_xy))
        target_rowcol = tuple(self._xy_to_rowcol(target_xy))
        if target_rowcol == current_rowcol:
            return 0

        n_rows = len(self._maze_map)
        n_cols = len(self._maze_map[0])
        # Cells are marked when they are queued, not when they are expanded,
        # so each cell enters the frontier at most once. Marking on expansion
        # let the frontier grow combinatorially in open areas.
        seen = {target_rowcol}
        to_visit = [target_rowcol]
        distance = 1
        while to_visit:
            next_visit = []
            for rowcol in to_visit:
                row, col = rowcol
                left = (row, col - 1)
                right = (row, col + 1)
                down = (row + 1, col)
                up = (row - 1, col)
                for next_rowcol in (left, right, down, up):
                    # Checked before the bounds and wall tests on purpose: the
                    # start cell is accepted even if it is not a free cell.
                    if next_rowcol == current_rowcol:  # Found a shortest path.
                        return distance
                    next_row, next_col = next_rowcol
                    if next_row < 0 or next_row >= n_rows:
                        continue
                    if next_col < 0 or next_col >= n_cols:
                        continue
                    if self._maze_map[next_row][next_col] not in (0, RESET,
                                                                  GOAL):
                        continue
                    if next_rowcol in seen:
                        continue
                    seen.add(next_rowcol)
                    next_visit.append(next_rowcol)
            to_visit = next_visit
            distance += 1

        print(current_rowcol, target_rowcol)
        raise ValueError('No path found to target.')
