import numpy as np
import networkx as nx

from gym.envs.mujoco import mujoco_env
from d4rl import offline_env
from d4rl.locomotion import maze_env, goal_reaching_env, ant, wrappers


# Cell markers for maze specifications. Besides these, 0 marks an open cell and
# 1 marks a wall.
RESET = R = 'r'  # Reset position.
GOAL = G = 'g'  # Goal cell.
TERMINAL = T = 't'  # Terminal cell: coming within 0.5 of one ends the episode.

# U-maze test layout with one start (R), one goal (G) and one terminal cell (T).
# T sits directly between R and G, so the safe route runs around the U.
U_MAZE_TEST = [[1, 1, 1, 1, 1],
              [1, R, 0, 0, 1],
              [1, T, 1, 0, 1],
              [1, G, 0, 0, 1],
              [1, 1, 1, 1, 1]]


class DangerousMazeEnv(maze_env.MazeEnv):
    """Maze environment whose map may contain ``TERMINAL`` cells.

    The numeric ``_np_maze_map`` kept by the base class gets every terminal
    marker replaced by 1 (the value walls use in the maze map). World
    coordinates of the terminal cells are available through
    :meth:`get_terminal_locations`.
    """

    def __init__(
            self,
            maze_map,
            maze_size_scaling,
            maze_height=0.5,
            manual_collision=False,
            non_zero_reset=False,
            reward_type='dense',
            *args,
            **kwargs):
        super().__init__(maze_map, maze_size_scaling, maze_height, manual_collision, non_zero_reset, reward_type, *args, **kwargs)

        # NOTE(review): assumes the base class builds _np_maze_map via np.array(...),
        # which yields a string array once TERMINAL markers are present. Replace the
        # markers with 1 (the wall value) and cast back to integers so the map is
        # numeric again.
        self._np_maze_map = np.where(self._np_maze_map == TERMINAL, 1, self._np_maze_map).astype(np.int32)

    def get_terminal_locations(self):
        """Return the world coordinates of every ``TERMINAL`` cell.

        Cells are scanned row by row, left to right. Each (row, col) position is
        converted with ``_rowcol_to_xy``.

        Returns:
            list: One ``_rowcol_to_xy`` result per terminal cell, or an empty
            list when the map has none.
        """
        return [
            self._rowcol_to_xy((row_idx, col_idx))
            for row_idx, row in enumerate(self._maze_map)
            for col_idx, cell in enumerate(row)
            if cell == TERMINAL
        ]


class DangerousGoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, ant.AntEnv):
    """Ant locomotion rewarded for goal-reaching, with episode-ending terminal cells.

    With ``reward_type='dense'`` the reward follows a shortest path over the maze
    cells that are neither walls nor terminals, rather than the straight line to
    the goal. Coming within 0.5 of any terminal location sets ``done`` and
    replaces the reward with a penalty.

    The dense-reward setup reads ``self._maze_map`` and ``_rowcol_to_xy``, so this
    class is expected to be mixed with a maze env that provides them.
    """
    BASE_ENV = ant.AntEnv

    def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
                 terminal_locations_fnc=None,
                 file_path=None,
                 expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type='dense', **kwargs):
        self.terminal_locations = None
        goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler, eval=eval, reward_type=reward_type)

        if terminal_locations_fnc is not None:
            # Supplied as a callable so the maze-aware subclass can derive them from its map.
            self.terminal_locations = terminal_locations_fnc()

        if self.reward_type == 'dense':
            self.graph = self._build_navigation_graph()
            # Placeholder until reset() computes the real start-to-goal path.
            self.path = [(0, 0)]
            self.path_index = 0

        ant.AntEnv.__init__(self,
                            file_path=file_path,
                            expose_all_qpos=expose_all_qpos,
                            expose_body_coms=None,
                            expose_body_comvels=None,
                            non_zero_reset=non_zero_reset)

    def _build_navigation_graph(self):
        """Return a 4-connected grid graph over cells that are neither walls (1) nor terminals."""
        blocked = (1, TERMINAL)
        graph = nx.Graph()
        for i, row in enumerate(self._maze_map):
            for j, cell in enumerate(row):
                if cell in blocked:
                    continue
                graph.add_node((i, j))
                # Each undirected edge is added once, looking only up and left. The
                # neighbour must be open as well: add_edge silently inserts missing
                # nodes, which would let the path run through terminal cells.
                if i > 0 and self._maze_map[i - 1][j] not in blocked:
                    graph.add_edge((i, j), (i - 1, j))
                if j > 0 and row[j - 1] not in blocked:
                    graph.add_edge((i, j), (i, j - 1))
        return graph

    def step(self, a):
        """Step the parent env, then apply path-following and terminal-cell shaping.

        Returns the parent's ``(obs, reward, done, info)``. For dense rewards,
        ``reward`` is recomputed from the current waypoint. Near a terminal
        location, ``done`` is forced to True and ``reward`` becomes a penalty.
        """
        obs, reward, done, info = super().step(a)
        xy = self.get_xy()

        if self.reward_type == 'dense':
            last_index = len(self.path) - 1
            # Advance to the next waypoint once within 0.5 of the current one. Never
            # move past the final waypoint: indexing beyond it raised IndexError as
            # soon as the agent reached the goal cell.
            if (self.path_index < last_index
                    and np.linalg.norm(xy - self._rowcol_to_xy(self.path[self.path_index])) <= 0.5):
                self.path_index += 1
                print('Path index: {:d}/{:d}'.format(self.path_index, len(self.path)))
            # Distance to the current waypoint, plus _maze_size_scaling (presumably
            # one cell length) for each waypoint still remaining after it.
            reward = -np.linalg.norm(self._rowcol_to_xy(self.path[self.path_index]) - xy)
            reward += -self._maze_size_scaling * (len(self.path) - self.path_index - 1)

        if self.terminal_locations is not None:
            for terminal_location in self.terminal_locations:
                if np.linalg.norm(xy - terminal_location) <= 0.5:
                    done = True
                    if self.reward_type == 'sparse':
                        reward = -1.0
                    else:
                        # NOTE(review): this is 0 when path_index is the last waypoint,
                        # i.e. no penalty near the goal -- confirm this is intended.
                        reward = -self._maze_size_scaling * (len(self.path) - self.path_index - 1) * 200
                    break

        return obs, reward, done, info

    def reset(self):
        """Reset the parent env and, for dense rewards, recompute the path to the goal."""
        obs = super().reset()

        if self.reward_type == 'dense':
            start_cell = self._xy_to_rowcol(self.get_xy())
            goal_cell = self._xy_to_rowcol(self.target_goal)
            self.path = nx.shortest_path(self.graph, start_cell, goal_cell)
            # Skip the start cell (index 0). When start and goal share a cell the path
            # has a single entry, so stay on it instead of indexing past the end.
            self.path_index = min(1, len(self.path) - 1)

        return obs


class DangerousAntMazeEnv(DangerousMazeEnv, DangerousGoalReachingAntEnv, offline_env.OfflineEnv):
    """Ant navigating a maze that contains episode-ending terminal cells."""
    LOCOMOTION_ENV = DangerousGoalReachingAntEnv

    def __init__(self, goal_sampler=None, expose_all_qpos=True,
                 reward_type='dense', v2_resets=False,
                 *args, **kwargs):
        if goal_sampler is None:
            goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
        DangerousMazeEnv.__init__(
            self, *args, manual_collision=False,
            goal_sampler=goal_sampler,
            # Resolved lazily by DangerousGoalReachingAntEnv once the maze map exists.
            terminal_locations_fnc=lambda: DangerousMazeEnv.get_terminal_locations(self),
            expose_all_qpos=expose_all_qpos,
            reward_type=reward_type,
            **kwargs)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # We set the target goal here for evaluation.
        self.set_target()
        self.v2_resets = v2_resets

    def reset(self):
        """Reset the environment, re-drawing the target goal first when ``v2_resets`` is set."""
        if self.v2_resets:
            # The target goal for evaluation in antmazes is randomized.
            # antmazes-v0 and -v1 resulted in really high-variance evaluations
            # because the target goal was set once at the seed level. This led to
            # each run running evaluations with one particular goal. To accurately
            # cover each goal, this requires about 50-100 seeds, which might be
            # computationally infeasible. As an alternate fix, to reduce variance
            # in result reporting, the v2 environments use the same offline
            # dataset as the v0 environments, with the distinction that the
            # randomization of goals during evaluation is performed at the level
            # of each rollout. Thus running a few seeds, but performing the final
            # evaluation over 100-200 episodes will give a valid estimate of an
            # algorithm's performance.
            self.set_target()
        return super().reset()

    def set_target(self, target_location=None):
        """Set the target goal by delegating to ``set_target_goal``.

        NOTE(review): presumably a goal is sampled when ``target_location`` is
        None -- confirm against ``set_target_goal``.
        """
        return self.set_target_goal(target_location)

    def seed(self, seed=0):
        """Seed the MuJoCo env and return its list of seeds (gym ``Env.seed`` contract)."""
        return mujoco_env.MujocoEnv.seed(self, seed)


def make_dangerous_ant_maze_env(**kwargs):
    """Construct a ``DangerousAntMazeEnv`` from *kwargs* and wrap it in ``NormalizedBoxEnv``."""
    return wrappers.NormalizedBoxEnv(DangerousAntMazeEnv(**kwargs))


def modify_dataset(dataset, env):
    """Flag dataset transitions that lie within 0.5 of a terminal cell.

    For every observation whose first two entries (the x, y position) are within
    0.5 of a terminal location of the maze, this sets ``dataset['dones'][i] = True``
    and ``dataset['rewards'][i] = -1.0`` in place. The distances are computed
    with vectorised numpy operations rather than a per-transition Python loop.

    Args:
        dataset: Mapping with ``'observations'``, ``'dones'`` and ``'rewards'``
            sequences of equal length. Modified in place.
        env: A ``DangerousMazeEnv``, or a wrapper that exposes one as
            ``wrapped_env``.

    Returns:
        The same ``dataset`` object, for convenience.

    Raises:
        ValueError: If the (unwrapped) env is not a ``DangerousMazeEnv``.
    """
    inner_env = getattr(env, 'wrapped_env', env)
    if not isinstance(inner_env, DangerousMazeEnv):
        raise ValueError('This function only supports DangerousMazeEnv')

    observations = np.asarray(dataset['observations'])
    if len(observations) == 0:
        return dataset
    xy = observations[:, :2]

    near_terminal = np.zeros(len(xy), dtype=bool)
    for terminal_location in inner_env.get_terminal_locations():
        near_terminal |= np.linalg.norm(xy - np.asarray(terminal_location), axis=1) <= 0.5

    # Index-wise assignment works for numpy arrays and plain lists alike.
    for i in np.flatnonzero(near_terminal):
        dataset['dones'][i] = True
        dataset['rewards'][i] = -1.0
    return dataset
