"""
Grid-World Markov Decision Processes (MDPs).

The MDPs in this module are actually not complete MDPs, but rather the
sub-part of an MDP containing states, actions, and transitions (including
their probabilistic character). The reward function and terminal states are
supplied separately.

Some general remarks:
    - Edges act as barriers, i.e. if an agent takes an action that would cross
    an edge, the state will not change.

    - Actions are not restricted to specific states. Any action can be taken
    in any state and has a unique intended outcome. The result of an action
    can be stochastic, but there is always exactly one outcome that can be
    described as the intended result of the action.
"""

from matplotlib import pyplot as plt
import numpy as np
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable


class SimpleGridWorld:
    """
    Deterministic square grid world with a single absorbing terminal state.

    States are indexed ``0 .. size**2 - 1`` with ``index = x + y * size``,
    where ``x`` is the column and ``y`` the row. The terminal state is the
    last index (``x = y = size - 1``).

    Actions:
        0: Up    (y - 1)
        1: Right (x + 1)
        2: Down  (y + 1)
        3: Left  (x - 1)

    Attributes:
        size: Side length of the grid.
        num_states: Number of states, ``size**2``.
        num_actions: Number of actions (4).
        rewards: Array of shape ``(num_states, num_actions)``: ``-1`` for
            every action in a non-terminal state, ``1`` for every action in
            the terminal state.
        transitions: Array of shape ``(num_states, num_actions, num_states)``
            holding ``p(next_state | state, action)``. All transitions are
            deterministic; the terminal state transitions to itself.
    """

    def __init__(self, size, debug=False) -> None:
        """
        Build the reward and transition tables.

        Args:
            size: Side length of the square grid.
            debug: If true, print every state's successors and rewards.
        """
        self.size = size
        self.num_states = size**2
        self.num_actions = 4
        self.rewards = np.zeros((self.num_states, self.num_actions))
        self.transitions = np.zeros(
            (self.num_states, self.num_actions, self.num_states)
        )

        for state in range(self.num_states):
            for action in range(self.num_actions):
                if self.is_terminal(state):
                    # Terminal state is absorbing and keeps paying +1.
                    self.rewards[state, action] = 1
                    self.transitions[state, action, state] = 1
                else:
                    self.rewards[state, action] = -1
                    next_state = self.step(state, action)
                    self.transitions[state, action, next_state] = 1

        if debug:
            action_names = ["Up", "Right", "Down", "Left"]
            for state in range(self.num_states):
                print(f"State {self.index_to_point(state)}:")

                print("Transitions:")
                print(action_names)
                # Most likely successor of each action, printed as (row, col).
                successors = np.argmax(self.transitions[state], axis=1)
                print([self.index_to_point(i) for i in successors])
                print(self.rewards[state, :])
                print()

    def is_terminal(self, s):
        """Return whether state index ``s`` is the terminal state."""
        return s == self.num_states - 1

    def step(self, state, action):
        """
        Return the successor of ``state`` under ``action``.

        Moves that would cross the grid edge leave that coordinate unchanged.
        Action codes other than 0-3 leave the state unchanged.

        Args:
            state: Integer index of the current state.
            action: 0 (Up), 1 (Right), 2 (Down) or 3 (Left).

        Returns:
            Integer index of the next state.
        """
        y, x = self.index_to_point(state)
        if action == 0:
            y = max(0, y - 1)
        elif action == 1:
            x = min(self.size - 1, x + 1)
        elif action == 2:
            y = min(self.size - 1, y + 1)
        elif action == 3:
            x = max(0, x - 1)
        # point_to_index takes (x, y): the reverse of index_to_point's order.
        return self.point_to_index((x, y))

    def index_to_point(self, index):
        """
        Convert a state index to the coordinate representing it.

        Note:
            The result is ``(y, x)`` (row, column), which is the reverse of
            the ``(x, y)`` order accepted by ``point_to_index``.

        Args:
            index: Integer representing the state.

        Returns:
            Tuple ``(index // size, index % size)``, i.e. ``(row, column)``.
        """
        return (index // self.size, index % self.size)

    def point_to_index(self, point):
        """
        Convert a state coordinate to the index representing it.

        Note:
            Does not check if coordinates lie outside of the world. The
            coordinate is ``(x, y)`` (column, row), the reverse of the order
            returned by ``index_to_point``.

        Args:
            point: Tuple of integers ``(x, y)`` representing the state.

        Returns:
            The index as integer representing the same state as the given
            coordinate, ``x + y * size``.
        """
        return point[0] + point[1] * self.size


def plot_state_action_function(
    world, sa_function, title="Untitled", border=None, **kwargs
):
    """
    Plot a state-action function over a grid world.

    Every grid cell is split into four triangles, one per action, each one
    adjacent to the cell edge that action moves towards. A triangle is
    coloured by ``sa_function[state, action]``. The y-axis is inverted, so
    row 0 is drawn at the top and action 0 (Up, which decreases y) occupies
    the top triangle of each cell.

    Args:
        world: The grid world for which the sa_function should be plotted.
            Only ``world.size`` and ``world.point_to_index((x, y))`` are used.
        sa_function: Table indexable as ``sa_function[state, action]``
            (for example ``p(action | state)``), with actions ordered
            0 = Up, 1 = Right, 2 = Down, 3 = Left as in ``SimpleGridWorld``.
        title: Title of the plot.
        border: Optional mapping of styling options for the triangle edges.
            All key-value pairs are forwarded to ``Axes.triplot``. If
            ``None``, no edges are drawn.
        **kwargs: Forwarded to ``Axes.tripcolor``.

    Returns:
        The collection returned by ``Axes.tripcolor``.
    """
    size = world.size

    # A fresh figure per call. (A former plt.clf() here wiped the caller's
    # previous plot, or created a stray empty figure when none existed.)
    fig = plt.figure()
    ax = fig.add_subplot(121)
    ax.title.set_text(title)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # Vertices: the (size + 1)**2 cell corners followed by the size**2 cell
    # centres, both in row-major order.
    xy = [
        (x - 0.5, y - 0.5)
        for y, x in product(range(size + 1), range(size + 1))
    ]
    xy += [(x, y) for y, x in product(range(size), range(size))]
    num_corners = (size + 1) ** 2

    triangles, values = [], []
    for sy, sx in product(range(size), range(size)):
        state = world.point_to_index((sx, sy))

        # Corner indices ("b" = smaller y, "t" = larger y) and the centre.
        bl = sy * (size + 1) + sx
        br = bl + 1
        tl = bl + (size + 1)
        tr = tl + 1
        cc = num_corners + sy * size + sx

        # One triangle per action, listed in action order. With the inverted
        # y-axis the smaller-y edge (bl-br) is drawn at the top of the cell.
        triangles += [
            (bl, br, cc),  # top edge    -> action 0 (Up, y - 1)
            (tr, cc, br),  # right edge  -> action 1 (Right, x + 1)
            (tl, cc, tr),  # bottom edge -> action 2 (Down, y + 1)
            (tl, bl, cc),  # left edge   -> action 3 (Left, x - 1)
        ]
        values += [sa_function[state, action] for action in range(4)]

    x, y = zip(*xy)
    x, y = np.array(x), np.array(y)
    triangles, values = np.array(triangles), np.array(values)

    ax.set_aspect("equal")
    ax.set_xticks(range(size))
    ax.set_yticks(range(size))
    ax.set_xlim(-0.5, size - 0.5)
    ax.set_ylim(-0.5, size - 0.5)
    ax.invert_yaxis()

    p = ax.tripcolor(x, y, triangles, facecolors=values, **kwargs)

    if border is not None:
        ax.triplot(x, y, triangles, **border)

    fig.colorbar(p, cax=cax)
    return p
