"""Visualization utilities for MiniGrid-based gridworld experiments.

Provides 2-D/3-D heatmaps, policy arrow plots, and helper routines to
highlight top-k cells or compare salient states across metrics.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, hsv_to_rgb
import os

# Process-wide environment tweaks applied when this module is imported:
#   KMP_DUPLICATE_LIB_OK - tolerate a second copy of the OpenMP runtime, which
#                          some MKL builds would otherwise refuse to load.
#   OMP_NUM_THREADS / MKL_NUM_THREADS - single-threaded math; plotting does
#                          not need more than one CPU thread.
# NOTE(review): numpy/matplotlib are imported above, before these assignments
# run. Thread-count variables are typically read when the math runtime
# initialises, so they may have no effect here. Consider moving this block
# above the numpy/matplotlib imports.
os.environ.update(
    {
        "KMP_DUPLICATE_LIB_OK": "TRUE",
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
    }
)


class BottleneckVisualization:
    """Plotting and analysis helpers for a MiniGrid-style grid environment.

    Per-state data is passed as a flat vector indexed by
    ``idx = x + y * width``. Cells for which ``env.grid.get(x, y)`` returns
    anything other than ``None`` are treated as walls: they become NaN when
    decoded, are drawn gray in heatmaps, and are never reported as top-k cells.
    """

    def __init__(self, env):
        # ``env`` must expose ``width``, ``height`` and ``grid.get(x, y)``.
        self.env = env
        self.height = env.height
        self.width = env.width

    # --------------------------------------------------------------------- #
    #                         helper: state ↔ matrix                        #
    # --------------------------------------------------------------------- #
    def decode_vector_to_matrix(self, vector: np.ndarray) -> np.ndarray:
        """
        Reshape a flat per-state vector into a ``(width, height)`` matrix.

        ``vector[x + y * width]`` is written to ``matrix[x, y]``. Cells where
        ``env.grid.get(x, y)`` is not ``None`` (walls, and presumably any other
        grid object such as goals/doors -- TODO confirm this is intended) stay
        NaN, so colormaps render them with their "bad" color and top-k searches
        skip them. The vector must be long enough to cover every non-wall
        index (normally exactly ``width * height`` entries).
        """
        matrix = np.full((self.width, self.height), np.nan)  # walls → NaN
        for x in range(self.width):
            for y in range(self.height):
                if self.env.grid.get(x, y) is None:          # non-wall cell
                    matrix[x, y] = vector[x + y * self.width]
        return matrix

    # --------------------------------------------------------------------- #
    #                   find top-k locations (ignore walls)                 #
    # --------------------------------------------------------------------- #
    def find_top_k_elements(self, data_matrix: np.ndarray, k: int):
        """
        Return index pairs of the ``k`` largest non-NaN entries of a 2-D matrix.

        For matrices produced by ``decode_vector_to_matrix`` the pairs are
        ``(x, y)`` cell coordinates. NaN entries (wall cells) are never
        returned. If ``k`` exceeds the number of non-NaN entries, all of them
        are returned, and ``k <= 0`` yields ``[]``. Pairs are ordered from
        largest to smallest value; among equal values the order is unspecified.
        """
        values = np.asarray(data_matrix, dtype=float)
        flat = values.ravel()
        valid_idx = np.flatnonzero(~np.isnan(flat))

        # Clamp k so walls are never padded into the result and argpartition
        # never receives an out-of-range kth. (A kth of -0 would also select
        # every cell instead of none.)
        k = min(int(k), valid_idx.size)
        if k <= 0:
            return []

        valid_vals = flat[valid_idx]
        top = np.argpartition(valid_vals, -k)[-k:]               # O(n) selection
        top = top[np.argsort(-valid_vals[top], kind="stable")]   # largest first
        rows, cols = np.unravel_index(valid_idx[top], values.shape)
        return [(int(r), int(c)) for r, c in zip(rows, cols)]

    # --------------------------------------------------------------------- #
    #                       generate vivid RGB colors                       #
    # --------------------------------------------------------------------- #
    @staticmethod
    def _vivid_colors(k: int, seed: int = 0):
        """
        Return ``k`` bright RGB tuples with distinct, randomly chosen hues.

        Colors use HSV value 1.0 and saturation 0.45, so none is black, gray
        or white. The result is deterministic for a given ``seed``. For up to
        360 colors, hues are sampled without replacement from a 1-degree grid.
        Beyond that, ``k`` evenly spaced hues are shuffled, because the grid
        cannot supply more than 360 distinct hues. ``k <= 0`` returns ``[]``.
        """
        if k <= 0:
            return []

        rng = np.random.default_rng(seed)
        if k <= 360:
            hues = rng.choice(
                np.linspace(0, 1, 360, endpoint=False), size=k, replace=False
            )
        else:
            hues = rng.permutation(np.linspace(0, 1, k, endpoint=False))
        return [tuple(hsv_to_rgb([h, 0.45, 1.0])) for h in hues]

    # --------------------------------------------------------------------- #
    #                            2-D heat-map                               #
    # --------------------------------------------------------------------- #
    def plot_2d_heatmap(
        self,
        data_vector: np.ndarray,
        topk: int = 32,
        title: str = "2D Heatmap",
        color_bar: bool = True,
        rep_states=None,
        cmap_name: str = "hot",  # "hot" = continuous; "auto" = discrete clusters
    ):
        """
        Show a 2-D heatmap of a per-state vector, with wall cells drawn gray.

        Args:
            data_vector: flat per-state values indexed ``x + y * width``.
            topk: outline the ``topk`` highest-valued non-wall cells in blue
                (``0`` disables).
            title: figure title.
            color_bar: draw a colorbar (ignored when ``cmap_name == "auto"``).
            rep_states: optional iterable of flat state indices, each outlined
                in red.
            cmap_name: a Matplotlib colormap name for continuous data, or
                ``"auto"`` for integer cluster labels. In ``"auto"`` mode,
                ``-1`` is drawn black and labels ``0..K-1`` each get their own
                color.
        """
        data_matrix = self.decode_vector_to_matrix(data_vector).astype(float)
        extent = [0, self.width, self.height, 0]

        # -------- choose colormap --------
        if cmap_name == "auto":
            valid = data_matrix[data_matrix >= 0]
            K = int(np.nanmax(valid)) + 1 if valid.size else 0
            vivid = self._vivid_colors(K)
            color_list = [(0.0, 0.0, 0.0, 1.0)] + vivid        # black + vivid colors
            cmap = ListedColormap(color_list)
            cmap.set_bad("gray")                               # NaN → gray

            # Shift labels so -1 → 0 (black) and 0..K-1 → 1..K (one color each).
            plot_mat = data_matrix + 1
            vmin, vmax = 0, max(K, 1)
            show_cbar = False
        else:
            # Copy before set_bad: get_cmap may return a shared/registered
            # instance (or the caller's own Colormap), which must not be mutated.
            cmap = plt.get_cmap(cmap_name).copy()
            cmap.set_bad("gray")
            plot_mat = data_matrix
            vmin = vmax = None
            show_cbar = color_bar

        # -------- draw --------
        plt.figure(figsize=(8, 8))
        plt.imshow(
            plot_mat.T,            # matrix is [x, y]; imshow wants [row=y, col=x]
            cmap=cmap,
            origin="upper",
            interpolation="nearest",
            extent=extent,
            vmin=vmin,
            vmax=vmax,
        )
        if show_cbar:
            cbar = plt.colorbar(shrink=0.8)
            cbar.ax.tick_params(labelsize=18)

        plt.title(title, fontsize=20)
        plt.xticks(np.arange(0, self.width + 1))
        plt.yticks(np.arange(0, self.height + 1))
        plt.xlim(0, self.width)
        plt.ylim(self.height, 0)
        plt.grid(color="lightgray", linewidth=0.5)

        ax = plt.gca()
        ax.tick_params(axis="both", which="major", labelsize=18)

        # draw blue frames on top-k cells
        if topk > 0:
            for x, y in self.find_top_k_elements(data_matrix, topk):
                ax.add_patch(
                    plt.Rectangle(
                        (x, y), 1, 1, edgecolor="blue", linewidth=2, fill=False
                    )
                )

        # highlight representative states in red
        if rep_states is not None:
            for idx in rep_states:
                row, col = divmod(idx, self.width)   # idx = col + row * width
                ax.add_patch(
                    plt.Rectangle(
                        (col, row), 1, 1, edgecolor="red", linewidth=4, fill=False
                    )
                )

        ax.set_xticklabels([])  # hide axis numbers
        ax.set_yticklabels([])
        plt.tight_layout()
        plt.show()

    # --------------------------------------------------------------------- #
    #                     arrow map for optimal policy                      #
    # --------------------------------------------------------------------- #
    def plot_policy_arrows(
        self,
        action_vector: np.ndarray,
        option_q_table: np.ndarray,
        title: str = "Optimal Policy Arrows",
    ):
        """
        Plot one arrow per non-wall cell showing its chosen action.

        ``action_vector[idx]`` is the action of state ``idx``
        (``idx = x + y * width``). Action code: 0=↑, 1=↓, 2=←, 3=→; any other
        code gets a zero-length arrow. If ``max(option_q_table[idx]) <= 0``,
        the state is considered an option-termination state and is shown as a
        red dot instead of an arrow. Wall cells are filled gray.
        """
        action_to_delta = {0: (0, -0.3), 1: (0, 0.3), 2: (-0.3, 0), 3: (0.3, 0)}

        plt.figure(figsize=(8, 8))
        ax = plt.gca()
        ax.set_title(title, fontsize=20)
        ax.set_xlim(0, self.width)
        ax.set_ylim(0, self.height)
        ax.set_xticks(np.arange(0, self.width + 1))
        ax.set_yticks(np.arange(0, self.height + 1))
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.grid(True, color="lightgray", linewidth=0.5)
        ax.set_aspect("equal")

        # draw walls first
        for x in range(self.width):
            for y in range(self.height):
                if self.env.grid.get(x, y) is not None:  # wall
                    ax.add_patch(plt.Rectangle((x, y), 1, 1, color="gray"))

        # draw arrows / terminal dots
        for x in range(self.width):
            for y in range(self.height):
                idx = x + y * self.width
                if self.env.grid.get(x, y) is None:      # non-wall
                    if np.max(option_q_table[idx]) <= 0:  # termination
                        ax.plot(x + 0.5, y + 0.5, "ro", markersize=6)
                    else:
                        a = action_vector[idx]
                        dx, dy = action_to_delta.get(a, (0, 0))
                        ax.arrow(
                            x + 0.5,
                            y + 0.5,
                            dx,
                            dy,
                            head_width=0.15,
                            head_length=0.15,
                            fc="black",
                            ec="black",
                        )

        # Flip y so row 0 is at the top, matching the heatmap orientation;
        # "up" (dy < 0) then points up on screen.
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.show()

    # --------------------------------------------------------------------- #
    #                               3-D plot                                #
    # --------------------------------------------------------------------- #
    def display_3d_heatmap(self, data_vector, title="3D Heatmap"):
        """
        Show a 3-D surface plot of a per-state vector.

        Wall cells are not skipped: their NaN values are replaced by 0, so they
        appear as flat cells at height 0. The grid is rotated 180 degrees (both
        axes flipped) before plotting to give a nicer default viewing angle.
        """
        data_matrix = self.decode_vector_to_matrix(data_vector)
        data_matrix = np.flipud(np.fliplr(data_matrix))  # flip for nicer view

        x = np.arange(self.width)
        y = np.arange(self.height)
        X, Y = np.meshgrid(x, y)                  # both shaped (height, width)
        Z = np.nan_to_num(data_matrix.T, nan=0)   # (width, height) → (height, width)

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(X, Y, Z, cmap="coolwarm", edgecolor="k")

        ax.set_title(title)
        ax.set_xlabel("X Position")
        ax.set_ylabel("Y Position")
        ax.set_zlabel("Value")

        plt.show()

    # --------------------------------------------------------------------- #
    #               overlap of top-k states between two vectors             #
    # --------------------------------------------------------------------- #
    def compare_top_k_positions(
        self, vector1: np.ndarray, vector2: np.ndarray, topk: int
    ) -> int:
        """
        Count the cells shared by the top-``topk`` non-wall cells of two vectors.

        Both vectors are decoded with ``decode_vector_to_matrix``, so wall
        cells never count. ``topk <= 0`` gives 0. If ``topk`` exceeds the
        number of non-wall cells, every non-wall cell is used.
        """
        matrix1 = self.decode_vector_to_matrix(vector1)
        matrix2 = self.decode_vector_to_matrix(vector2)

        top_1 = set(self.find_top_k_elements(matrix1, topk))
        top_2 = set(self.find_top_k_elements(matrix2, topk))

        return len(top_1 & top_2)


# ----------------------- example usage -----------------------
if __name__ == "__main__":
    # Demo: visualise per-state reward increments recorded during training and
    # measure how much their hottest cells agree with raw visit counts.
    from bottleneck_env import SimpleEnv

    demo_env = SimpleEnv(render_mode=None)
    demo_env.reset()

    # Arrays written by DQN_agent.py (example file names). Each is assumed to be
    # a flat per-state vector indexed x + y * width -- TODO confirm.
    reward_increments = np.load("state_total_increment.npy")
    visits = np.load("visit_count.npy")

    viz = BottleneckVisualization(demo_env)

    # 2-D view first, then the 3-D surface of the same data.
    viz.plot_2d_heatmap(reward_increments, title="Cumulative State Reward Heatmap")
    viz.display_3d_heatmap(
        reward_increments, title="3D Cumulative State Reward Heatmap"
    )

    # How many of the 32 highest-reward cells are also among the 32 most visited?
    overlap_count = viz.compare_top_k_positions(reward_increments, visits, topk=32)
    print(f"Number of overlapping top-32 positions: {overlap_count}")
