from itertools import product

import minigrid
import numpy as np
from gymnasium.vector import SyncVectorEnv
from matplotlib import pyplot as plt
from matplotlib.patches import Arrow, Rectangle
from minigrid.minigrid_env import MiniGridEnv
from tqdm import tqdm, trange


def plot_minigrid_with_values(envs: SyncVectorEnv, V: np.ndarray, policy: np.ndarray) -> tuple[plt.Figure, plt.Axes]:
    """Custom rendering of the MiniGrid environment with an overlay of the value function.

    Each grid cell is coloured by its tile type (wall, lava, goal or empty), annotated with
    the mean of its four per-orientation values, and overlaid with an arrow pointing along
    the orientation whose value is highest.

    Parameters:
    - envs: Vector environment; only the first sub-environment's grid is drawn.
    - V: Flat value array indexed as ``(x * height + y) * 4 + orientation``. Any extra
      trailing entries (e.g. a terminal state) are not drawn.
    - policy: Currently unused; kept for interface compatibility.

    Returns:
    - (fig, ax): The matplotlib figure and axes that were drawn on.

    Raises:
    - ValueError: If a cell holds a tile other than Wall, Lava, Goal or nothing.
    """
    env = envs.envs[0].unwrapped

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(0, env.width)
    ax.set_ylim(0, env.height)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)

    # Arrow offsets keyed by MiniGrid direction: 0 -> right, 1 -> down, 2 -> left, 3 -> up.
    # Matplotlib's y axis points up, so "down" is a negative dy.
    direction_arrows = {3: (0, 0.3), 0: (0.3, 0), 1: (0, -0.3), 2: (-0.3, 0)}
    tile_colors = (
        (minigrid.core.world_object.Wall, "gray"),
        (minigrid.core.world_object.Lava, "red"),
        (minigrid.core.world_object.Goal, "green"),
    )

    for x in range(env.width):
        for y in range(env.height):
            # MiniGrid row 0 is drawn at the top; matplotlib's row 0 is at the bottom.
            row = env.height - y - 1

            tile = env.grid.get(x, y)
            if tile is not None:
                color = next((c for tile_cls, c in tile_colors if isinstance(tile, tile_cls)), None)
                if color is None:
                    raise ValueError(f"Tile type: {tile}")
                ax.add_patch(Rectangle((x, row), 1, 1, color=color))

            # Values of the four orientations of this cell; the label shows their mean.
            state_idx = state_to_idx(env, x, y, 0)
            orientation_values = V[state_idx : state_idx + 4]
            value = np.mean(orientation_values)
            ax.text(
                x + 0.5,
                row + 0.5,
                f"{value:.2f}",
                color="black",
                fontsize=8,
                ha="center",
                va="center",
                bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
            )

            # Arrow along the orientation with the highest value.
            best_direction = int(np.argmax(orientation_values))
            arrow = direction_arrows.get(best_direction)
            if arrow is not None:
                dx, dy = arrow
                ax.add_patch(Arrow(x + 0.5, row + 0.5, dx, dy, width=0.1, color="blue"))

    return fig, ax


def state_to_idx(env: MiniGridEnv, x: int, y: int, orientation: int) -> int:
    """Flatten an ``(x, y, orientation)`` triple into a tabular state index.

    Cells are laid out column by column (``x`` outer, ``y`` inner), and each cell
    owns four consecutive slots, one per orientation:
    ``index = (x * height + y) * 4 + orientation``.
    """
    grid_height = env.unwrapped.height
    cell_index = x * grid_height + y
    return cell_index * 4 + orientation


def get_values_and_policy(envs: SyncVectorEnv, gamma: float, epsilon: float, max_steps: int) -> tuple[np.ndarray, np.ndarray]:
    """Build the tabular MDP of the first sub-environment and solve it with value iteration.

    Side effect: the sub-environment's ``max_steps`` is overwritten. Building the
    matrices also steps the environment, so its agent state is left modified.

    Parameters:
    - envs: Vector environment; only ``envs.envs[0]`` is used.
    - gamma: Discount factor passed to value iteration.
    - epsilon: Convergence threshold passed to value iteration.
    - max_steps: Maximum number of value-iteration sweeps. This is not the
      environment's episode length.

    Returns:
    - (V, policy) as produced by ``value_iteration``.
    """
    first_env = envs.envs[0]
    core_env = first_env.unwrapped
    # Assumes an Empty-like environment to upper-bound the number of states.
    # NOTE(review): presumably set large so the episode limit never interferes with
    # the single-step probes; confirm.
    state_action_count = core_env.width * core_env.height * 4 * first_env.action_space.n
    core_env.max_steps = state_action_count + 1

    transitions, rewards = get_transition_reward_matrices(first_env)
    values, greedy_policy = value_iteration(transitions, rewards, gamma, epsilon, max_steps)
    return values, greedy_policy


def get_transition_reward_matrices(env: MiniGridEnv) -> tuple[np.ndarray, np.ndarray]:
    """
    Computes the transition tensor and reward matrix for a MiniGrid environment.

    The tabular state space has one state per ``(x, y, orientation)`` triple, indexed
    with ``state_to_idx``, plus one extra terminal state at the last index. Outcomes are
    found by placing the agent in each state and calling ``env.step`` once per action.
    Each ``(state, action)`` records the single observed outcome with probability 1,
    so the dynamics are treated as deterministic.

    Parameters:
    - env: The (possibly wrapped) MiniGrid environment. It is stepped repeatedly, and its
      agent position, direction and step counter are left modified afterwards.

    Returns:
    - transition_matrix: Shape (S, S, A). Entry ``[s, s_next, a]`` is 1 when action ``a``
      takes state ``s`` to ``s_next``.
    - reward_vector: Shape (S, A). Immediate reward of action ``a`` in state ``s``.
    """
    base_env = env.unwrapped
    # 4 orientations per cell, plus one extra index used as the terminal state.
    state_space_size = base_env.width * base_env.height * 4 + 1
    action_space_size = env.action_space.n
    terminal_state_idx = state_space_size - 1

    transition_matrix = np.zeros((state_space_size, state_space_size, action_space_size))
    reward_vector = np.zeros((state_space_size, action_space_size))

    # The wrapper stack does not change while the matrices are built, so inspect its
    # string representation once instead of once per (state, action).
    env_description = str(env)
    lava_kills_agent = "LavaNotDeadWrapper" not in env_description
    lava_is_penalized = "LavaNegativeRewardWrapper" in env_description

    # Enumerate the states with 4 possible orientations.
    state_list = list(product(range(base_env.width), range(base_env.height), range(4)))
    for x, y, orientation in tqdm(state_list, "Building transition matrix", leave=False):
        state_idx = state_to_idx(env, x, y, orientation)
        # The tile under the agent does not depend on the action, so fetch it once per state.
        tile = base_env.grid.get(x, y)

        if isinstance(tile, minigrid.core.world_object.Wall):
            # The agent cannot stand inside a wall: every action leads to the terminal
            # state with zero reward (reward_vector is already zero-initialised).
            transition_matrix[state_idx, terminal_state_idx, :] = 1
            continue

        if isinstance(tile, minigrid.core.world_object.Lava) and lava_kills_agent:
            # Lava kills the agent: every action leads to the terminal state. The reward is
            # the wrapper's lava penalty when LavaNegativeRewardWrapper is present, else 0.
            transition_matrix[state_idx, terminal_state_idx, :] = 1
            reward_vector[state_idx, :] = env.get_wrapper_attr("lava_penalty") if lava_is_penalized else 0
            continue

        for action in range(action_space_size):
            # Place the agent in the state and take one step from a fresh step counter.
            # The counter lives on the base env; resetting it on a wrapper has no effect.
            base_env.agent_pos = (x, y)
            base_env.agent_dir = orientation
            base_env.step_count = 0
            # The third element of gymnasium's step tuple is `terminated`; truncation is
            # a time-limit artefact, not an MDP terminal, so it is ignored.
            _obs, reward, terminated, *_rest = env.step(action)

            if terminated:
                new_state_idx = terminal_state_idx
            else:
                new_x, new_y = base_env.agent_pos
                new_state_idx = state_to_idx(env, new_x, new_y, base_env.agent_dir)

            transition_matrix[state_idx, new_state_idx, action] = 1.0
            reward_vector[state_idx, action] = reward

    return transition_matrix, reward_vector


def value_iteration(T: np.ndarray, R: np.ndarray, gamma: float, epsilon: float, max_iter: int):
    """
    Performs in-place (Gauss-Seidel) value iteration to find the optimal value function.

    Parameters:
    - T: Transition tensor of shape (S, S, A). ``T[s, s_next, a]`` is the probability of
      moving from state ``s`` to ``s_next`` when taking action ``a``.
    - R: Reward matrix of shape (S, A). ``R[s, a]`` is the immediate reward of action ``a``
      in state ``s``.
    - gamma: Discount factor.
    - epsilon: Convergence threshold on the largest change of V within one sweep.
    - max_iter: Maximum number of sweeps over the state space.

    Returns:
    - V: Value function of shape (S,).
    - policy: Greedy action per state, shape (S,), integer dtype.
    """
    num_states, _num_actions = R.shape
    V = np.zeros(num_states)  # Initialize value function
    policy = np.zeros(num_states, dtype=int)  # Initialize policy

    for iteration in trange(0, max_iter, desc="Calculating optimal V", leave=True):
        delta = 0
        for s in range(num_states):
            # Q(s, a) = R(s, a) + gamma * sum_{s'} T(s, s', a) * V(s'), for all actions at once.
            # V is updated in place, so states later in this sweep already see new values.
            action_values = R[s] + gamma * (V @ T[s])

            new_value = np.max(action_values)
            delta = max(delta, abs(V[s] - new_value))
            V[s] = new_value
            policy[s] = np.argmax(action_values)

        if delta < epsilon:  # Convergence check
            tqdm.write(f"Converged after {iteration + 1} iterations with delta {delta:.6f}")
            break

    return V, policy
