import numpy as np
import matplotlib.pyplot as plt
import cv2
import gym


class GridWorldMDP(gym.Env):
    """Tabular grid-world MDP exposed through the gym ``Env`` interface.

    States are the cells of an ``(M, N)`` grid and are reported to callers as
    flat (row-major) indices. The four actions move up, right, down and left.
    The full transition model is stored as an array of shape
    ``(M, N, A, M, N)`` indexed as ``T[row, col, action, next_row, next_col]``.
    """

    # up, right, down, left
    _direction_deltas = [
        (-1, 0),
        (0, 1),
        (1, 0),
        (0, -1),
    ]
    _num_actions = len(_direction_deltas)

    # Per-episode bookkeeping; both are overwritten by reset().
    _episode_step = 0
    _current_state_index = 0  # flat (row-major) index of the current cell

    def __init__(self,
                 start_state,
                 reward_grid,
                 terminal_mask,
                 obstacle_mask,
                 action_noise_probability=0.1,
                 no_action_probability=0.,
                 max_episode_steps=1000,
                 early_stop=False,
                 ):
        """
        GridWorld environment that supports the value iteration solver.
        However, do not use the associated solver in this class, which is only for testing.

        Args:
            start_state: start state coordinates ``(row, col)``; must be non-negative.
            reward_grid: reward matrix (only dependent on state).
            terminal_mask: terminal matrix (only dependent on state).
            obstacle_mask: obstacle matrix (only dependent on state).
            action_noise_probability: probability of slipping to each of the two
                neighbouring directions instead of the requested one.
            no_action_probability: probability that the agent stays in place.
            max_episode_steps: max episode steps.
            early_stop: whether to end the episode when an absorbing state is reached.
        Example:
                shape = (4, 6)

                goal = (-1, -1)
                trap = (2, 3)
                start = (0, 0)
                default_reward = -0.1
                goal_reward = 1
                trap_reward = -1

                reward_grid = np.zeros(shape) + default_reward
                reward_grid[goal] = goal_reward
                reward_grid[trap] = trap_reward

                terminal_mask = np.zeros_like(reward_grid, dtype=bool)
                terminal_mask[goal] = True
                terminal_mask[trap] = True

                # no obstacle
                obstacle_mask = np.zeros_like(reward_grid, dtype=bool)

                gw = GridWorldMDP(reward_grid=reward_grid,
                                  obstacle_mask=obstacle_mask,
                                  terminal_mask=terminal_mask,
                                  action_noise_probability=0.1,
                                  no_action_probability=0.0,
                                  max_episode_steps=100,
                                  start_state=start)
        """

        self._start_state = start_state
        self._reward_grid = np.asarray(reward_grid)
        # The masks are used for boolean indexing, so force a bool dtype.
        self._terminal_mask = np.asarray(terminal_mask, dtype=bool)
        self._obstacle_mask = np.asarray(obstacle_mask, dtype=bool)
        self._max_episode_steps = max_episode_steps
        self._early_stop = early_stop

        # (direction offset relative to the requested action, probability)
        action_probabilities = [
            (-1, action_noise_probability),
            (0, 1 - 2 * action_noise_probability),
            (1, action_noise_probability),
        ]
        self._T = self._create_transition_matrix(
            action_probabilities,
            no_action_probability,
            self._obstacle_mask,
        )

        self.observation_space = gym.spaces.Discrete(self.size)
        self.action_space = gym.spaces.Discrete(self._num_actions)

        self.reset()

    def reset(self):
        """Start a new episode and return the flat index of the start state."""
        self._episode_step = 0
        self._current_state_index = self.grid_coordinates_to_indices(self._start_state)
        return self._current_state_index

    def step(self, action):
        """Sample one transition for ``action`` from the current state.

        Returns:
            ``(next_state_index, reward, done, info)`` where ``info['terminal']``
            tells whether the next state is absorbing. ``done`` is set when the
            step budget is exhausted, or when an absorbing state is reached and
            ``early_stop`` is enabled.
        """
        self._episode_step += 1
        next_state, reward, terminal = self.generate_experience(self._current_state_index, action)
        self._current_state_index = next_state
        out_of_steps = self._episode_step >= self._max_episode_steps
        done = bool((self._early_stop and terminal) or out_of_steps)
        return next_state, reward, done, {'terminal': terminal}

    def render(self, mode='human'):
        raise NotImplementedError

    @property
    def shape(self):
        """Grid shape ``(M, N)``."""
        return self._reward_grid.shape

    @property
    def size(self):
        """Number of states ``M * N``."""
        return self._reward_grid.size

    @property
    def reward_grid(self):
        """State reward grid of shape ``(M, N)``."""
        return self._reward_grid

    @property
    def transition_probability(self):
        """Copy of the transition array of shape ``(M, N, A, M, N)``."""
        return self._T.copy()

    def run_value_iterations(self, discount=1.0,
                             iterations=10):
        """Run ``iterations`` sweeps of value iteration starting from zeros.

        Returns:
            ``(policy_grids, utility_grids)``, each of shape
            ``(M, N, iterations)`` holding the result after every sweep.
        """
        utility_grids, policy_grids = self._init_utility_policy_storage(iterations)

        # Always iterate in floating point, even if the rewards are integers.
        utility_grid = np.zeros(self.shape, dtype=float)
        for i in range(iterations):
            utility_grid = self._value_iteration(utility_grid=utility_grid, discount=discount)
            policy_grids[:, :, i] = self.best_policy(utility_grid)
            utility_grids[:, :, i] = utility_grid
        return policy_grids, utility_grids

    def run_policy_iterations(self, discount=1.0,
                              iterations=10):
        """Run ``iterations`` policy-iteration steps from a random policy.

        Returns:
            ``(policy_grids, utility_grids)``, each of shape
            ``(M, N, iterations)`` holding the result after every step.
        """
        utility_grids, policy_grids = self._init_utility_policy_storage(iterations)

        policy_grid = np.random.randint(0, self._num_actions,
                                        self.shape)
        utility_grid = self._reward_grid.astype(float)

        for i in range(iterations):
            policy_grid, utility_grid = self._policy_iteration(
                policy_grid=policy_grid,
                utility_grid=utility_grid,
                discount=discount,
            )
            policy_grids[:, :, i] = policy_grid
            utility_grids[:, :, i] = utility_grid
        return policy_grids, utility_grids

    def generate_experience(self, current_state_idx, action_idx):
        """Sample a successor of ``current_state_idx`` under ``action_idx``.

        Returns:
            ``(next_state_index, reward, terminal)`` where the reward and the
            terminal flag are those of the sampled next state.
        """
        sr, sc = self.grid_indices_to_coordinates(current_state_idx)
        next_state_probs = self._T[sr, sc, action_idx].ravel()

        next_state_idx = np.random.choice(next_state_probs.size,
                                          p=next_state_probs)

        return (next_state_idx,
                self._reward_grid.ravel()[next_state_idx],
                self._terminal_mask.ravel()[next_state_idx])

    def grid_indices_to_coordinates(self, indices=None):
        """Convert flat indices to ``(rows, cols)``; all states when ``indices`` is None."""
        if indices is None:
            indices = np.arange(self.size)
        return np.unravel_index(indices, self.shape)

    def grid_coordinates_to_indices(self, coordinates=None):
        """Convert ``(rows, cols)`` to flat indices; all indices when ``coordinates`` is None."""
        # Annoyingly, this doesn't work for negative indices.
        # The mode='wrap' parameter only works on positive indices.
        if coordinates is None:
            return np.arange(self.size)
        return np.ravel_multi_index(coordinates, self.shape)

    def best_policy(self, utility_grid):
        """Return the greedy action grid ``(M, N)`` with respect to ``utility_grid``."""
        # Rounding makes the argmax stable against floating point noise on ties.
        return np.argmax(
            np.round(self._expected_utilities(utility_grid), decimals=4),
            axis=2)

    def _expected_utilities(self, utility_grid):
        """Expected next-state utility per state-action pair, shape ``(M, N, A)``."""
        M, N = self.shape
        weighted = np.reshape(utility_grid, (1, 1, 1, M, N)) * self._T
        return weighted.sum(axis=-1).sum(axis=-1)

    def _init_utility_policy_storage(self, depth):
        """Allocate zero-filled ``(M, N, depth)`` utility and policy histories."""
        M, N = self.shape
        utility_grids = np.zeros((M, N, depth))
        policy_grids = np.zeros_like(utility_grids)
        return utility_grids, policy_grids

    def _create_transition_matrix(self,
                                  action_probabilities,
                                  no_action_probability,
                                  obstacle_mask):
        """Build the transition array ``T[row, col, action, next_row, next_col]``.

        Args:
            action_probabilities: list of ``(offset, probability)`` pairs; the
                offset is added to the requested action (modulo the number of
                actions) to get the direction actually taken.
            no_action_probability: probability of staying in place.
            obstacle_mask: boolean grid; moves into marked cells leave the
                agent where it was.
        """
        M, N = self.shape
        obstacle_mask = np.asarray(obstacle_mask, dtype=bool)

        T = np.zeros((M, N, self._num_actions, M, N))

        r0, c0 = self.grid_indices_to_coordinates()

        T[r0, c0, :, r0, c0] += no_action_probability
        # A move only happens when the no-op does not, so the move
        # probabilities are scaled to keep every distribution summing to one.
        move_probability = 1.0 - no_action_probability

        for action in range(self._num_actions):
            for offset, P in action_probabilities:
                direction = (action + offset) % self._num_actions

                dr, dc = self._direction_deltas[direction]
                # Moves off the grid are clipped back onto the border cell.
                r1 = np.clip(r0 + dr, 0, M - 1)
                c1 = np.clip(c0 + dc, 0, N - 1)

                # Moves into an obstacle bounce back to the origin cell.
                temp_mask = obstacle_mask[r1, c1].flatten()
                r1[temp_mask] = r0[temp_mask]
                c1[temp_mask] = c0[temp_mask]

                T[r0, c0, action, r1, c1] += move_probability * P

        # Terminal states are absorbing: every action loops back with probability 1.
        terminal_locs = np.where(self._terminal_mask.flatten())[0]

        T[r0[terminal_locs], c0[terminal_locs], :, :, :] = 0
        T[r0[terminal_locs], c0[terminal_locs], :, r0[terminal_locs], c0[terminal_locs]] = 1.
        return T

    def _value_iteration(self, utility_grid, discount=1.0):
        """Apply one Bellman optimality backup to every state at once."""
        best_future = (discount * self._expected_utilities(utility_grid)).max(axis=2)
        out = best_future + self._reward_grid
        # The utility of an absorbing state is pinned to its reward.
        out[self._terminal_mask] = self._reward_grid[self._terminal_mask]
        return out

    def _policy_iteration(self, *, utility_grid,
                          policy_grid, discount=1.0):
        """Do one evaluation backup under ``policy_grid``, then improve greedily.

        Returns:
            ``(new_policy_grid, new_utility_grid)``.
        """
        r, c = self.grid_indices_to_coordinates()

        # Expected next-state utility of the action the policy picks in each cell.
        chosen = self._expected_utilities(utility_grid)[r, c, policy_grid.flatten()]
        utility_grid = self._reward_grid + discount * chosen.reshape(self.shape)

        utility_grid[self._terminal_mask] = self._reward_grid[self._terminal_mask]

        return self.best_policy(utility_grid), utility_grid

    def _calculate_utility(self, loc, discount, utility_grid):
        """Bellman optimality backup for the single cell ``loc = (row, col)``."""
        if self._terminal_mask[loc]:
            return self._reward_grid[loc]
        row, col = loc
        return np.max(
            discount * np.sum(
                np.sum(self._T[row, col, :, :, :] * utility_grid,
                       axis=-1),
                axis=-1)
        ) + self._reward_grid[loc]

    def plot_policy(self, utility_grid, policy_grid=None):
        """Draw the utilities as a heat map with the policy overlaid as arrows.

        Args:
            utility_grid: ``(M, N)`` utilities used for the colours.
            policy_grid: ``(M, N)`` action indices; the greedy policy for
                ``utility_grid`` is used when omitted.
        """
        if policy_grid is None:
            policy_grid = self.best_policy(utility_grid)
        markers = "^>v<"
        marker_size = 200 // np.max(policy_grid.shape)
        marker_edge_width = marker_size // 10
        marker_fill_color = 'w'

        no_action_mask = self._terminal_mask | self._obstacle_mask

        # Scale to [0, 1]; a constant grid would otherwise divide by zero.
        utility_range = utility_grid.max() - utility_grid.min()
        if utility_range > 0:
            utility_normalized = (utility_grid - utility_grid.min()) / utility_range
        else:
            utility_normalized = np.zeros(utility_grid.shape, dtype=float)

        utility_normalized = (255*utility_normalized).astype(np.uint8)

        utility_rgb = cv2.applyColorMap(utility_normalized, cv2.COLORMAP_JET)
        # Paint obstacles black in every colour channel.
        utility_rgb[self._obstacle_mask] = 0

        # OpenCV returns BGR, matplotlib expects RGB.
        plt.imshow(utility_rgb[:, :, ::-1], interpolation='none')

        for i, marker in enumerate(markers):
            y, x = np.where((policy_grid == i) & np.logical_not(no_action_mask))
            plt.plot(x, y, marker, ms=marker_size, mew=marker_edge_width,
                     color=marker_fill_color)

        y, x = np.where(self._terminal_mask)
        plt.plot(x, y, 'o', ms=marker_size, mew=marker_edge_width,
                 color=marker_fill_color)

        # Pick the "nice" tick spacing closest (in log scale) to 1/8 of the grid.
        tick_step_options = np.array([1, 2, 5, 10, 20, 50, 100])
        tick_step = np.max(policy_grid.shape)/8
        best_option = np.argmin(np.abs(np.log(tick_step) - np.log(tick_step_options)))
        tick_step = tick_step_options[best_option]
        plt.xticks(np.arange(0, policy_grid.shape[1] - 0.5, tick_step))
        plt.yticks(np.arange(0, policy_grid.shape[0] - 0.5, tick_step))
        plt.xlim([-0.5, policy_grid.shape[1]-0.5])
        # Inverted y-limits keep row 0 at the top, as imshow draws it.
        plt.ylim([policy_grid.shape[0]-0.5, -0.5])

    def get_transition_matrix(self):
        """
        Transform transition probability into the standard form.
        Returns:
            transition_matrix with shape of (S, A, S).
        """
        # Row-major flattening of (row, col) matches grid_indices_to_coordinates.
        return self._T.reshape(self.size, self._num_actions, self.size).copy()

    def get_reward_matrix(self):
        """
        Transform reward func into the standard form.
        Returns:
            reward_matrix with shape of (S,).
        """
        return self._reward_grid.flatten()

    def get_terminal_matrix(self):
        """
        Transform terminal func into the standard form.
        Returns:
            terminal_matrix with shape of (S,).
        """
        return self._terminal_mask.flatten()


class GridWorldSolver:
    """Value-iteration solver for a tabular grid world.

    Works on a transition array ``T[row, col, action, next_row, next_col]``
    and a reward that depends either on the state or on the state-action pair.
    """

    def __init__(self,
                 reward_grid,
                 terminal_mask,
                 transition_probability,
                 only_state_reward,
                 gamma=1.0,
                 early_stop=False,
                 ):
        """
        Value Iteration Solver for GridWorld
        Args:
            reward_grid: reward function. Maybe the shape of [M, N] (state only) or [M, N, A] (state-action)
            terminal_mask: terminal function. Shape of [M, N].
            transition_probability: transition matrix. Shape of [M, N, A, M, N]
            only_state_reward: whether the reward only depends on state.
            gamma: discount factor.
            early_stop: when True, the utility of a terminal state is fixed to
                its reward instead of being backed up through the transitions.
        """
        self._terminal_mask = terminal_mask
        self._T = transition_probability
        self._discount = gamma
        self._shape = transition_probability.shape[:2]
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        self._size = int(np.prod(self._shape))
        self._num_actions = transition_probability.shape[2]
        self._early_stop = early_stop

        if only_state_reward:
            # [M, N] -> [M, N, A]
            self._reward_grid = np.repeat(reward_grid[..., None], axis=-1, repeats=self._num_actions)
        else:
            # shape of [M, N, A]
            self._reward_grid = reward_grid

        assert self._reward_grid.shape == (*self._shape, self._num_actions), "reward_grid.shape:{}".format(reward_grid.shape)
        assert self._terminal_mask.shape == self._shape

    def run_value_iterations(self, iterations, discount=None):
        """Run ``iterations`` sweeps of value iteration starting from zeros.

        Args:
            iterations: number of sweeps.
            discount: discount factor; the constructor's ``gamma`` when None.
        Returns:
            ``(policy_grids, utility_grids)``, each of shape
            ``(M, N, iterations)`` holding the result after every sweep.
        """
        if discount is None:
            discount = self._discount
        utility_grids, policy_grids = self._init_utility_policy_storage(iterations)

        utility_grid = np.zeros(self._shape, dtype=utility_grids.dtype)
        for i in range(iterations):
            utility_grid = self._value_iteration(utility_grid=utility_grid, discount=discount)
            policy_grids[:, :, i] = self._get_best_policy(utility_grid, discount=discount)
            utility_grids[:, :, i] = utility_grid
        return policy_grids, utility_grids

    def _init_utility_policy_storage(self, depth):
        """Allocate zero-filled ``(M, N, depth)`` utility and policy histories."""
        M, N = self._shape
        utility_grids = np.zeros((M, N, depth))
        policy_grids = np.zeros_like(utility_grids)
        return utility_grids, policy_grids

    def _action_values(self, utility_grid, discount):
        """Return Q-values ``reward + discount * E[utility]``, shape ``(M, N, A)``."""
        M, N = self._shape
        weighted = np.reshape(utility_grid, (1, 1, 1, M, N)) * self._T
        expected = weighted.sum(axis=-1).sum(axis=-1)
        return self._reward_grid + discount * expected

    def _value_iteration(self, utility_grid, discount=1.0):
        """Apply one Bellman optimality backup to every state at once."""
        utility_grid = np.asarray(utility_grid)
        assert utility_grid.shape == self._shape

        value_out = self._action_values(utility_grid, discount).max(axis=2)

        if self._early_stop:
            terminal = np.asarray(self._terminal_mask, dtype=bool)
            terminal_rewards = self._reward_grid[terminal]  # (num_terminal, A)
            # A terminal reward must not depend on the action taken there.
            assert np.all(terminal_rewards == terminal_rewards[:, :1])
            value_out[terminal] = terminal_rewards[:, 0]
        return value_out

    def _calculate_utility(self, loc, discount, utility_grid):
        """Bellman optimality backup for the single cell ``loc = (row, col)``."""
        row, col = loc
        if self._early_stop and self._terminal_mask[loc]:
            assert np.all(self._reward_grid[row, col, 0] == self._reward_grid[row, col])
            return self._reward_grid[row, col, 0]
        assert utility_grid.shape == self._shape

        best_utility = np.max(
            self._reward_grid[row, col] + discount * np.sum(np.sum(
                self._T[row, col, :, :, :] * utility_grid,
                axis=-1), axis=-1)
        )
        return best_utility

    def _get_best_policy(self, utility_grid, discount):
        """Return the greedy action grid ``(M, N)`` with respect to ``utility_grid``."""
        # Rounding makes the argmax stable against floating point noise on ties.
        q_values = np.round(self._action_values(utility_grid, discount), decimals=4)
        return np.argmax(q_values, axis=2)

    def grid_indices_to_coordinates(self, indices=None):
        """Convert flat indices to ``(rows, cols)``; all states when ``indices`` is None."""
        if indices is None:
            indices = np.arange(self._size)
        return np.unravel_index(indices, self._shape)

    def grid_coordinates_to_indices(self, coordinates=None):
        """Convert ``(rows, cols)`` to flat indices; all indices when ``coordinates`` is None."""
        # Annoyingly, this doesn't work for negative indices.
        # The mode='wrap' parameter only works on positive indices.
        if coordinates is None:
            return np.arange(self._size)
        return np.ravel_multi_index(coordinates, self._shape)


def test():
    """Check GridWorldSolver against GridWorldMDP's reference value iteration.

    Builds a 5x3 grid with one goal, one trap and one obstacle, verifies that
    both implementations produce the same utilities and the same final policy,
    then rolls out three random-action episodes and prints their returns.

    Raises:
        AssertionError: if the utilities or the final policies disagree.
    """
    shape = (5, 3)
    goal = (-1, -1)
    trap = (-1, -2)
    obstacle = (0, 1)
    start = (0, 0)
    default_reward = -0.1
    goal_reward = 1
    trap_reward = -1

    reward_grid = np.zeros(shape) + default_reward
    reward_grid[goal] = goal_reward
    reward_grid[obstacle] = 0
    reward_grid[trap] = trap_reward

    # The np.bool alias was removed in NumPy 1.24; the builtin bool is equivalent.
    terminal_mask = np.zeros_like(reward_grid, dtype=bool)
    terminal_mask[goal] = True
    terminal_mask[trap] = True

    obstacle_mask = np.zeros_like(reward_grid, dtype=bool)
    obstacle_mask[obstacle] = True

    gw = GridWorldMDP(reward_grid=reward_grid,
                      obstacle_mask=obstacle_mask,
                      terminal_mask=terminal_mask,
                      action_noise_probability=0.1,
                      no_action_probability=0.0,
                      max_episode_steps=100,
                      start_state=start)
    solver = GridWorldSolver(
        reward_grid=reward_grid,
        terminal_mask=terminal_mask,
        transition_probability=gw.transition_probability,
        only_state_reward=True,
        gamma=1.,
        early_stop=True,
    )
    policy_real, value_real = solver.run_value_iterations(iterations=10)

    policy_ref, value_ref = gw.run_value_iterations(discount=1., iterations=10)

    np.testing.assert_allclose(value_real, value_ref)
    # A mismatch must fail the test rather than open an interactive debugger.
    np.testing.assert_allclose(policy_real[..., -1], policy_ref[..., -1])

    # test env
    ac_space = gw.action_space
    for _ in range(3):
        rewards = 0
        gw.reset()
        for _ in range(100):
            _, reward, done, _ = gw.step(ac_space.sample())
            rewards += reward
            if done:
                break
        print("episode return: %.4f" % rewards)


def generate_cliff_chain(length=5):
    """
    Solve a cliff chain example and print the resulting policy and values.

    The grid has two rows and ``length`` columns:

    top row:     1 -> 2 -> 3 -> 4 -> 5
                 |    |    |    |    |
    bottom row:  6    7    8    9    10

    Every bottom-row cell and the last top-row cell are absorbing. The last
    top-row cell has a reward of +1 and every other cell has a reward of 0.
    Transitions are deterministic, the start state is the first top-row cell,
    and the problem is solved with 100 value-iteration sweeps at gamma=0.99.

    Args:
        length: number of columns in the chain.
    Returns:
        ``(policy, value)`` as returned by ``GridWorldSolver.run_value_iterations``,
        each of shape ``(2, length, 100)``.
    """
    shape = (2, length)
    start = (0, 0)

    reward_grid = np.zeros(shape)
    reward_grid[0, -1] = 1.
    print(reward_grid.size)

    # The np.bool alias was removed in NumPy 1.24; the builtin bool is equivalent.
    terminal_mask = np.zeros_like(reward_grid, dtype=bool)
    terminal_mask[0, -1] = True
    terminal_mask[1, :] = True

    obstacle_mask = np.zeros_like(reward_grid, dtype=bool)

    gw = GridWorldMDP(reward_grid=reward_grid,
                      obstacle_mask=obstacle_mask,
                      terminal_mask=terminal_mask,
                      action_noise_probability=0.0,
                      no_action_probability=0.0,
                      max_episode_steps=100,
                      start_state=start)
    solver = GridWorldSolver(
        reward_grid=reward_grid,
        terminal_mask=terminal_mask,
        transition_probability=gw.transition_probability,
        only_state_reward=True,
        gamma=0.99,
        early_stop=False,
    )

    policy, value = solver.run_value_iterations(iterations=100)
    print("=============[Policy]=============")
    print(policy[..., -1])
    print("=============[Value]=============")
    print(value[..., -1])
    return policy, value


# Run the solver/environment self-check when the file is executed as a script.
if __name__ == "__main__":
    test()





