import numpy as np
from ..utils import check_random_state


# Maze state is represented as a 2-element NumPy array: (Y, X). Increasing Y is South.

# Possible actions, expressed as (delta-y, delta-x).
maze_actions = {
    'N': np.array([-1, 0]),
    'S': np.array([1, 0]),
    'E': np.array([0, 1]),
    'W': np.array([0, -1]),
}

def parse_topology(topology):
    return np.array([list(row) for row in topology])


class Maze(object):
    """
    Simple wrapper around a NumPy 2D array to handle flattened indexing and staying in bounds.
    """
    def __init__(self, topology):
        self.topology = parse_topology(topology)
        self.flat_topology = self.topology.ravel()
        self.shape = self.topology.shape

    def in_bounds_flat(self, position):
        return 0 <= position < np.prod(self.shape)

    def in_bounds_unflat(self, position):
        return 0 <= position[0] < self.shape[0] and 0 <= position[1] < self.shape[1]

    def get_flat(self, position):
        if not self.in_bounds_flat(position):
            raise IndexError("Position out of bounds: {}".format(position))
        return self.flat_topology[position]

    def get_unflat(self, position):
        if not self.in_bounds_unflat(position):
            raise IndexError("Position out of bounds: {}".format(position))
        return self.topology[tuple(position)]

    def flatten_index(self, index_tuple):
        return np.ravel_multi_index(index_tuple, self.shape)

    def unflatten_index(self, flattened_index):
        return np.unravel_index(flattened_index, self.shape)

    def flat_positions_containing(self, x):
        return list(np.nonzero(self.flat_topology == x)[0])

    def flat_positions_not_containing(self, x):
        return list(np.nonzero(self.flat_topology != x)[0])

    def __str__(self):
        return '\n'.join(''.join(row) for row in self.topology.tolist())

    def __repr__(self):
        return 'Maze({})'.format(repr(self.topology.tolist()))


def move_avoiding_walls(maze, position, action):
    """
    Return the new position after moving, and the event that happened ('hit-wall' or 'moved').

    Works with the position and action as a (row, column) array.
    """
    # Compute new position
    new_position = position + action

    # Compute collisions with walls, including implicit walls at the ends of the world.
    if not maze.in_bounds_unflat(new_position) or maze.get_unflat(new_position) == '#':
        return position, 'hit-wall'

    return new_position, 'moved'



class GridWorld(object):
    """
    A simple task in a maze: get to the goal.

    Parameters
    ----------

    maze : list of strings or lists
        maze topology (see below)

    rewards: dict of string to number. default: {'*': 10}.
        Rewards obtained by being in a maze grid with the specified contents,
        or experiencing the specified event (either 'hit-wall' or 'moved'). The
        contributions of content reward and event reward are summed. For
        example, you might specify a cost for moving by passing
        rewards={'*': 10, 'moved': -1}.

    terminal_markers: sequence of chars, default '*'
        A grid cell containing any of these markers will be considered a
        "terminal" state.

    action_error_prob: float
        With this probability, the requested action is ignored and a random
        action is chosen instead.

    random_state: None, int, or RandomState object
        For repeatable experiments, you can pass a random state here. See
        http://scikit-learn.org/stable/modules/generated/sklearn.utils.check_random_state.html

    Notes
    -----

    Maze topology is expressed textually. Key:
     '#': wall
     '.': open (really, anything that's not '#')
     '*': goal
     'o': origin
    """

    def __init__(self, maze, rewards={'*': 10}, terminal_markers='*', action_error_prob=0, random_state=None, directions="NSEW"):

        self.maze = Maze(maze) if not isinstance(maze, Maze) else maze
        self.rewards = rewards
        self.terminal_markers = terminal_markers
        self.action_error_prob = action_error_prob
        self.random_state = check_random_state(random_state)

        self.actions = [maze_actions[direction] for direction in directions]
        self.num_actions = len(self.actions)
        self.state = None
        self.reset()
        self.num_states = self.maze.shape[0] * self.maze.shape[1]
        self.max_reward = None
        self.goal_reward = None
        self.deterministic_goal_optimal = None

    def __repr__(self):
        return 'GridWorld(maze={maze!r}, rewards={rewards}, terminal_markers={terminal_markers}, action_error_prob={action_error_prob})'.format(**self.__dict__)

    def get_max_reward(self):
        if self.max_reward is None:
            rewards = [v[0] + 3 * v[1] if isinstance(v, tuple) else v for k, v in self.rewards.items()]
            self.max_reward = max(rewards)
        return self.max_reward

    @property
    def is_deterministic_goal_optimal(self):
        if self.deterministic_goal_optimal is None:
            stochastic = max([v[0] for k, v in self.rewards.items() if isinstance(v, tuple)])
            deterministic = max([v for k, v in self.rewards.items() if not isinstance(v, tuple)])
            self.deterministic_goal_optimal = deterministic >= stochastic

        return self.deterministic_goal_optimal

    def get_goal_reward(self):
        if self.goal_reward is None:
            rewards = [v[0] + 3 * v[1] if isinstance(v, tuple) else v for k, v in self.rewards.items()]
            self.goal_reward = np.unique(rewards)[-2]
            # self.goal_reward = np.unique(rewards)[-1]
        return self.goal_reward

    def reset(self):
        """
        Reset the position to a starting position (an 'o'), chosen at random.
        """
        options = self.maze.flat_positions_containing('o')
        self.state = options[self.random_state.choice(len(options))]
        return self.state

    def is_terminal(self, state):
        """Check if the given state is a terminal state."""
        return self.maze.get_flat(state) in self.terminal_markers

    def observe(self):
        """
        Return the current state as an integer.

        The state is the index into the flattened maze.
        """
        return self.state

    def step(self, action_idx):
        """Perform an action (specified by index), yielding a new state and reward."""
        # In the absorbing end state, nothing does anything.
        if self.is_terminal(self.state):
            return self.observe(), 0

        if self.action_error_prob and self.random_state.rand() < self.action_error_prob:
            action_idx = self.random_state.choice(self.num_actions)
        action = self.actions[action_idx]
        new_state_tuple, result = move_avoiding_walls(self.maze, self.maze.unflatten_index(self.state), action)
        self.state = self.maze.flatten_index(new_state_tuple)

        destination_reward = self.rewards.get(self.maze.get_flat(self.state), 0)
        if isinstance(destination_reward, tuple):
            destination_reward = np.random.normal(*destination_reward)

        reward = destination_reward + self.rewards.get(result, 0)
        done = self.is_terminal(self.state)
        info = {}
        return self.observe(), reward, done, info

    def as_mdp(self):
        transition_probabilities = np.zeros((self.num_states, self.num_actions, self.num_states))
        rewards = np.zeros((self.num_states, self.num_actions, self.num_states))
        action_rewards = np.zeros((self.num_states, self.num_actions))
        destination_rewards = np.zeros(self.num_states)

        for state in range(self.num_states):
            destination_rewards[state] = self.rewards.get(self.maze.get_flat(state), 0)

        is_terminal_state = np.zeros(self.num_states, dtype=np.bool)

        for state in range(self.num_states):
            if self.is_terminal(state):
                is_terminal_state[state] = True
                transition_probabilities[state, :, state] = 1.
            else:
                for action in range(self.num_actions):
                    new_state_tuple, result = move_avoiding_walls(self.maze, self.maze.unflatten_index(state), self.actions[action])
                    new_state = self.maze.flatten_index(new_state_tuple)
                    transition_probabilities[state, action, new_state] = 1.
                    action_rewards[state, action] = self.rewards.get(result, 0)

        # Now account for action noise.
        transitions_given_random_action = transition_probabilities.mean(axis=1, keepdims=True)
        transition_probabilities *= (1 - self.action_error_prob)
        transition_probabilities += self.action_error_prob * transitions_given_random_action

        rewards_given_random_action = action_rewards.mean(axis=1, keepdims=True)
        action_rewards = (1 - self.action_error_prob) * action_rewards + self.action_error_prob * rewards_given_random_action
        rewards = action_rewards[:, :, None] + destination_rewards[None, None, :]
        rewards[is_terminal_state] = 0

        return transition_probabilities, rewards

    ### Old API, where terminal states were None.

    def observe_old(self):
        return None if self.is_terminal(self.state) else self.state

    def step_old(self, action_idx):
        new_state, reward = self.step(action_idx)
        if self.is_terminal(new_state):
            return None, reward
        else:
            return new_state, reward


    samples = {
        'trivial': [
            '###',
            '#o#',
            '#.#',
            '#*#',
            '###'],

        'larger': [
            '#########',
            '#..#....#',
            '#..#..#.#',
            '#..#..#.#',
            '#..#.##.#',
            '#....*#.#',
            '#######.#',
            '#o......#',
            '#########'],

        'bipolar': [
            '#########',
            '#$......#',
            '#.......#',
            '#.......#',
            '#.......#',
            '#.......#',
            '#.......#',
            '#o.....*#',
            '#########']
    }


def construct_cliff_task(width, height, goal_reward=50, move_reward=-1, cliff_reward=-100, **kw):
    """
    Construct a 'cliff' task, a GridWorld with a "cliff" between the start and
    goal. Falling off the cliff gives a large negative reward and ends the
    episode.

    Any other parameters, like action_error_prob, are passed on to the
    GridWorld constructor.
    """

    maze = ['.' * width] * (height - 1)  # middle empty region
    maze.append('o' + 'X' * (width - 2) + '*') # bottom goal row

    rewards = {
        '*': goal_reward,
        'moved': move_reward,
        'hit-wall': move_reward,
        'X': cliff_reward
    }

    return GridWorld(maze, rewards=rewards, terminal_markers='*X', **kw)
