import json
from typing import Tuple, Literal, List, Union


# defining a bunch of types to make the code more readable
State = Tuple[int, int]
Action = Literal["left", "right", "up", "down"]
actions: List[Action] = ["left", "right", "up", "down"]
Policy = List[List[Union[List[Action], Literal["TERM"]]]]
StateValue = List[List[float]]

# size of the gridworld; remains constant
SIZE = 8


def init_policy() -> Policy:
    """
    Initializes the policy for the gridworld problem.
    """
    cols: List[Union[List[Action], Literal["TERM"]]] = [actions] * SIZE
    rows = [cols] * SIZE
    # copy and reassign (hacky)
    copy = json.dumps(rows)
    rows = json.loads(copy)
    # set terminals
    rows[0][0] = "TERM"
    rows[SIZE-1][SIZE-1] = "TERM"

    return rows


def init_state_value() -> StateValue:
    """
    Initializes the state value for the gridworld problem.
    """
    cols: List[float] = [0.0] * SIZE
    rows = [cols] * SIZE
    # copy and reassign (hacky)
    copy = json.dumps(rows)
    rows = json.loads(copy)
    return rows


def next_state(s: State, a: Action) -> State:
    """
    Produces the next state from the current state and action.
    Takes account of the boundaries of the gridworld.
    """
    i, j = s
    i_next = i
    j_next = j
    if a == "left":
        j_next = max(0, j_next - 1)
    elif a == "right":
        j_next = min(SIZE-1, j_next + 1)
    elif a == "up":
        i_next = max(0, i_next - 1)
    elif a == "down":
        i_next = min(SIZE-1, i_next + 1)

    return (i_next, j_next)


def value_iteration(p: Policy, v: StateValue, theta: float):
    """
    Runs value iteration to find the optimal policy and state value.
    The policy and state value are updated in place. Theta controls the
    convergence of the algorithm, where the algorithm stops when the
    maximum change in the state value is less than theta.
    """
    while True:
        delta = 0

        for i, row in enumerate(p):
            for j, col in enumerate(row):
                s = (i, j)
                u = v[i][j]
                if col != "TERM":
                    max_a_val = float("-inf")
                    for a in actions:
                        s_next = next_state(s, a)
                        i_next, j_next = s_next
                        r = -1
                        scaled = r + v[i_next][j_next]
                        if scaled > max_a_val:
                            max_a_val = scaled

                    v[i][j] = max_a_val

                delta = max(delta, abs(u - v[i][j]))

        if delta < theta:
            break

    for i, row in enumerate(p):
        for j, col in enumerate(row):
            s = (i, j)
            if col != "TERM":
                max_a: List[Action] = []
                max_a_val = float("-inf")
                for a in actions:
                    s_next = next_state(s, a)
                    i_next, j_next = s_next
                    r = -1
                    scaled = r + v[i_next][j_next]
                    if scaled > max_a_val:
                        max_a_val = scaled
                        max_a = [a]
                    elif scaled == max_a_val:
                        max_a.append(a)
                p[i][j] = max_a


def policy_str(p: Policy):
    buf = ""
    for row in p:
        s_row = ""
        for col in row:
            shorted = ""
            if col == "TERM":
                shorted = str(col)
            else:
                for action in col:
                    shorted += action[0].upper()
            shorted += " " * max(6 - len(shorted), 0)
            s_row += shorted + " | "
        buf += s_row.rstrip("| ") + "\n"

    return buf


# value iteration
policy = init_policy()
state_value = init_state_value()
value_iteration(policy, state_value, 0.001)