from typing import Dict, Optional

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def plot_fourrooms(mdp, show_state_idx=False, figsize=(6,6), ax=None):
    """Draw the grid layout: state cells, walls, start/goal markers and grid lines.

    Args:
        mdp: object exposing ``rows``, ``cols``, ``state_to_coord`` (sequence
            of ``(row, col)`` pairs indexed by state), ``walls`` (iterable of
            ``(row, col)``) and ``start_state`` / ``goal_state`` (a
            ``(row, col)`` pair or None).
        show_state_idx: if True, write each state's index inside its cell.
        figsize: figure size, used only when a new figure is created.
        ax: axes to draw on; a new figure is created when None.

    Returns:
        The axes drawn on. The y axis is inverted, so row 0 is at the top.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    n_rows, n_cols = mdp.rows, mdp.cols

    ax.set_aspect('equal')
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.invert_yaxis()

    # One white cell per state, optionally labelled with its index.
    for idx, (row, col) in enumerate(mdp.state_to_coord):
        ax.add_patch(Rectangle((col, row), 1, 1, fill=True, edgecolor='lightgray', facecolor='white'))
        if not show_state_idx:
            continue
        ax.text(col + 0.5, row + 0.6, str(idx), ha='center', va='center', fontsize=6, color='gray')

    for row, col in mdp.walls:
        ax.add_patch(Rectangle((col, row), 1, 1, color='black'))

    # Start (green "S") and goal (red "G") cells, each drawn only if defined.
    markers = (
        (mdp.start_state, '#4daf4a', "S"),
        (mdp.goal_state, '#e41a1c', "G"),
    )
    for cell, face, label in markers:
        if cell is None:
            continue
        row, col = cell
        ax.add_patch(Rectangle((col, row), 1, 1, color=face))
        ax.text(col + 0.5, row + 0.5, label, ha='center', va='center', color='white', fontsize=12, weight='bold')

    ax.set_xticks([])
    ax.set_yticks([])

    grid_style = dict(color='lightgray', linewidth=0.5)
    for x in range(n_cols + 1):
        ax.axvline(x, **grid_style)
    for y in range(n_rows + 1):
        ax.axhline(y, **grid_style)

    return ax


def plot_eigenvectors_fourrooms(
    mdp,
    eigenvector: np.ndarray,
    cmap: str = "coolwarm",
    ax=None,
    colorbar: bool = True,
    smooth_mode: str = "grid",
    smooth_sigma: float = 0.0,
    diffusion_steps: int = 0,
    diffusion_alpha: float = 0.25,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    return_im: bool = False,
    normalise: bool = False,
):
    """Render a per-state vector as a heatmap over the grid, with walls in black.

    Args:
        mdp: object exposing ``rows``, ``cols``, ``state_to_coord`` (sequence
            of ``(row, col)`` pairs indexed by state) and ``walls``.
        eigenvector: per-state values; ``eigenvector[s]`` colours state ``s``.
        cmap: matplotlib colormap name passed to ``imshow``.
        ax: axes to draw on; a new 6x6 figure is created when None.
        colorbar: if True, attach a colorbar to ``ax``'s figure.
        smooth_mode: ``"grid"`` applies Gaussian smoothing on the grid (when
            ``smooth_sigma > 0``); ``"diffusion"`` applies neighbour-averaging
            diffusion (when ``diffusion_steps > 0``). Any other value means no
            smoothing.
        smooth_sigma: Gaussian standard deviation, in cells, for ``"grid"``.
        diffusion_steps: number of diffusion steps for ``"diffusion"``.
        diffusion_alpha: per-step diffusion rate for ``"diffusion"``.
        vmin, vmax: colour limits passed to ``imshow`` (None = autoscale).
        return_im: if True, also return the ``AxesImage``.
        normalise: if True, divide by the max absolute value (plus 1e-8)
            after smoothing.

    Returns:
        ``ax``, or ``(ax, im)`` when ``return_im`` is True.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 6))

    # The two smoothing modes are mutually exclusive.
    if smooth_sigma > 0.0 and smooth_mode == "grid":
        eigenvector = _gaussian_smooth_fourrooms(mdp, eigenvector, sigma=smooth_sigma)
    elif diffusion_steps > 0 and smooth_mode == "diffusion":
        eigenvector = _graph_diffuse_fourrooms(mdp, eigenvector, steps=diffusion_steps, alpha=diffusion_alpha)

    if normalise:
        # The epsilon keeps an all-zero vector from dividing by zero.
        max_abs = np.max(np.abs(eigenvector))
        eigenvector = eigenvector / (max_abs + 1e-8)

    # Cells that are not states stay NaN, so imshow leaves them uncoloured.
    grid = np.full((mdp.rows, mdp.cols), np.nan)
    for s, (r, c) in enumerate(mdp.state_to_coord):
        grid[r, c] = eigenvector[s]

    im = ax.imshow(grid, cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
    ax.set_xticks([])
    ax.set_yticks([])

    # imshow centres pixel (r, c) on (c, r), so each wall cell spans +/- 0.5.
    for (r, c) in mdp.walls:
        ax.add_patch(Rectangle((c - 0.5, r - 0.5), 1, 1, color='black'))

    if colorbar:
        # Use the axes' own figure rather than pyplot's current figure, so no
        # stray pyplot figure is created when `ax` lives in an unmanaged Figure.
        ax.figure.colorbar(im, ax=ax)
    if return_im:
        return ax, im
    return ax


def _gaussian_smooth_fourrooms(mdp, values: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian-smooth per-state values on the grid, ignoring non-state cells.

    Uses normalised convolution. The value grid (non-state cells zeroed) and a
    0/1 state mask are filtered with the same Gaussian. The filtered values
    are then divided by the filtered mask. Non-state cells therefore neither
    drag neighbours towards zero nor receive a value. The kernel acts on grid
    distance, so it can reach across a thin wall.

    Args:
        mdp: object exposing ``rows``, ``cols`` and ``state_to_coord``
            (sequence of ``(row, col)`` pairs indexed by state).
        values: per-state values; ``values[s]`` belongs to state ``s``.
        sigma: standard deviation of the Gaussian kernel, in grid cells.

    Returns:
        Float array shaped like ``values`` holding each state's smoothed value.
        Entries beyond ``len(mdp.state_to_coord)``, if any, are zero.
    """
    from scipy.ndimage import gaussian_filter

    # Work in float: with an integer input, np.zeros_like(values) would give
    # an integer output and truncate the smoothed values.
    values = np.asarray(values, dtype=float)
    coords = np.asarray(mdp.state_to_coord, dtype=int).reshape(-1, 2)
    n_states = len(coords)
    rr, cc = coords[:, 0], coords[:, 1]

    grid = np.full((mdp.rows, mdp.cols), np.nan)
    grid[rr, cc] = values[:n_states]

    mask = ~np.isnan(grid)
    grid_zero = np.where(mask, grid, 0.0)

    filtered_vals = gaussian_filter(grid_zero, sigma=sigma, mode='nearest')
    filtered_mask = gaussian_filter(mask.astype(float), sigma=sigma, mode='nearest')

    # Clamp the denominator so isolated cells never divide by ~0.
    smoothed_grid = np.where(mask, filtered_vals / np.maximum(filtered_mask, 1e-8), np.nan)

    smoothed = np.zeros_like(values)
    smoothed[:n_states] = smoothed_grid[rr, cc]
    return smoothed


def _graph_diffuse_fourrooms(mdp, values: np.ndarray, steps: int = 5, alpha: float = 0.25) -> np.ndarray:
    """Smooth per-state values by repeated neighbour averaging on the grid graph.

    On each step, every state that has at least one neighbour moves a fraction
    ``alpha`` of the way towards the mean of its 4-connected neighbouring
    states. All states are updated from the previous step's values (Jacobi
    style). Cells listed in ``mdp.walls``, cells outside the grid and cells
    that are not states are never neighbours.

    Args:
        mdp: object exposing ``rows``, ``cols``, ``walls`` (iterable of
            ``(row, col)``) and ``state_to_coord`` (sequence of ``(row, col)``
            pairs indexed by state).
        values: per-state values (array or sequence); not modified.
        steps: number of diffusion steps.
        alpha: step size towards the neighbour mean.

    Returns:
        A new float array of the diffused values.
    """
    coord_to_state = {(r, c): s for s, (r, c) in enumerate(mdp.state_to_coord)}
    walls = set(mdp.walls)
    vals = np.array(values, dtype=float)  # always a fresh copy

    # The neighbour structure is identical on every step, so build it once.
    adjacency = []
    for s, (r, c) in enumerate(mdp.state_to_coord):
        neighbors = []
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nr, nc = r + dr, c + dc
            if (nr, nc) in walls:
                continue
            if 0 <= nr < mdp.rows and 0 <= nc < mdp.cols and (nr, nc) in coord_to_state:
                neighbors.append(coord_to_state[(nr, nc)])
        if neighbors:
            adjacency.append((s, neighbors))

    for _ in range(steps):
        new = vals.copy()
        for s, neighbors in adjacency:
            new[s] += alpha * (np.mean(vals[neighbors]) - vals[s])
        vals = new
    return vals


def plot_eigenoption_policy_fourrooms(
    mdp,
    option: Dict[str, np.ndarray],
    arrow_scale: float = 0.35,
    ax=None,
    arrow_color: str = 'blue',
    termination_color: str = 'red',
):
    """Draw an option's greedy actions as arrows and its termination states as dots.

    Args:
        mdp: object exposing ``rows``, ``cols``, ``walls`` (iterable of
            ``(row, col)``), ``num_states`` and ``state_to_coord`` (indexed by
            state, giving ``(row, col)``).
        option: dict with ``"pi"`` (``pi[s]`` holds per-action scores for
            state ``s``; every action tied for the maximum gets an arrow) and
            ``"beta"`` (``beta[s]`` truthy means the option terminates in
            ``s``; such states get a marker instead of arrows).
        arrow_scale: arrow length, in cell units.
        ax: axes to draw on; a new 6x6 figure is created when None.
        arrow_color: colour of the policy arrows.
        termination_color: fill colour of the termination markers.

    Returns:
        The axes drawn on. Row 0 is at the top.

    Action indices 0-3 are drawn as up/down/left/right. Any other action index
    has no grid direction and is skipped.
    """
    pi = option["pi"]
    beta = option["beta"]

    rows, cols = mdp.rows, mdp.cols

    # Arrow offsets (dx, dy) per action index. The y axis is flipped below
    # (ylim = (rows, 0)), so "up" is negative y.
    arrow_deltas = {
        0: (0, -arrow_scale),  # up
        1: (0, arrow_scale),   # down
        2: (-arrow_scale, 0),  # left
        3: (arrow_scale, 0),   # right
    }

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 6))

    for (r, c) in mdp.walls:
        ax.add_patch(Rectangle((c, r), 1, 1, facecolor='black'))

    for x in range(cols + 1):
        ax.plot([x, x], [0, rows], color='lightgray', linewidth=0.4)
    for y in range(rows + 1):
        ax.plot([0, cols], [y, y], color='lightgray', linewidth=0.4)

    for s in range(mdp.num_states):
        if beta[s]:
            continue
        (r, c) = mdp.state_to_coord[s]
        # One arrow per action tied for the highest score.
        a_list = np.where(pi[s] == np.max(pi[s]))[0]
        for a in a_list:
            delta = arrow_deltas.get(int(a))
            if delta is None:
                # No grid direction for this index; skip it rather than draw
                # a stale (dx, dy) left over from a previous action.
                continue
            dx, dy = delta
            ax.arrow(
                c + 0.5, r + 0.5,
                dx, dy,
                head_width=0.2,
                head_length=0.2,
                linewidth=1.,
                color=arrow_color,
                length_includes_head=True,
            )

    terminal_states = [mdp.state_to_coord[s] for s in range(mdp.num_states) if beta[s]]
    if len(terminal_states) > 0:
        terminal_r, terminal_c = zip(*terminal_states)
        ax.scatter(
            [c + 0.5 for c in terminal_c],
            [r + 0.5 for r in terminal_r],
            s=50,
            c=termination_color,
            edgecolors='k',
            linewidths=0.8,
            zorder=3,  # keep markers above grid lines and arrows
        )

    ax.set_xlim(0, cols)
    ax.set_ylim(rows, 0)
    ax.set_aspect('equal')

    ax.set_xticks([])
    ax.set_yticks([])

    return ax


def standard_post_processing_plot(ax, fontsize: int = 20):
    """Apply the shared axes styling in place.

    Makes the left and bottom spines thick and black and hides the top and
    right spines. Ticks become thick, black and outward-pointing. Axis labels
    and the title are set to black, and tick labels are sized at ``fontsize``.

    Args:
        ax: matplotlib axes to restyle.
        fontsize: tick-label font size.
    """
    for side in ('left', 'bottom'):
        spine = ax.spines[side]
        spine.set_linewidth(2)
        spine.set_color('black')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    # Thicker, black, outward-pointing tick marks.
    ax.tick_params(width=2, color='black', direction='out')
    for text in (ax.xaxis.label, ax.yaxis.label, ax.title):
        text.set_color('black')

    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, colors='black', labelsize=fontsize)