"""Centrality measures for MiniGrid layouts via NetworkX.

Builds an undirected graph from the free (empty) cells of a layout and computes
several centrality vectors used as baselines for comparison with VPS maps.
"""

import numpy as np
import networkx as nx
from .bottleneck_env import SimpleEnv


def build_graph_from_env(env):
    """
    Convert the MiniGrid environment into an undirected 4-neighbour grid graph.

    A cell counts as passable only when ``env.grid.get(x, y)`` returns
    ``None``. Any cell holding an object (walls included) is excluded.

    Nodes are ``(x, y)`` tuples. An edge joins two passable cells that are
    horizontal or vertical neighbours.

    Nodes are created only through edges. A passable cell with no passable
    4-neighbour is therefore not part of the graph. Callers that map results
    back onto an ``env.size**2`` vector leave such cells at 0.

    Args:
        env: environment exposing ``size`` and ``grid.get(x, y)``.

    Returns:
        networkx.Graph: the passability graph.
    """
    G = nx.Graph()
    size = env.size
    grid = env.grid
    for x in range(size):
        for y in range(size):
            if grid.get(x, y) is not None:
                continue  # occupied cell: not passable
            node = (x, y)
            # Only the +x / +y neighbours need checking. The -x / -y neighbour
            # was scanned earlier (outer loop over x, inner over y). If it is
            # passable, it already added the shared edge, so looking backwards
            # would only re-add existing edges.
            for nb in ((x + 1, y), (x, y + 1)):
                if nb[0] < size and nb[1] < size and grid.get(nb[0], nb[1]) is None:
                    G.add_edge(node, nb)
    return G


def compute_betweenness_centrality(env):
    """
    Return normalized shortest-path betweenness centrality as a flat map.

    The free-cell graph from ``build_graph_from_env`` is scored with
    ``nx.betweenness_centrality(..., normalized=True)``, which uses Brandes'
    algorithm.

    The result is a float vector of length ``env.size**2``. Cell ``(x, y)`` is
    stored at index ``x + y * env.size``, and cells that are not graph nodes
    remain 0.
    """
    graph = build_graph_from_env(env)
    scores = nx.betweenness_centrality(graph, normalized=True)

    side = env.size
    flat = np.zeros(side**2)
    for cell, score in scores.items():
        col, row = cell
        flat[row * side + col] = score
    return flat


def compute_closeness_centrality(env):
    """
    Compute current-flow closeness centrality for every state.

    Uses ``nx.current_flow_closeness_centrality``. This is the
    effective-resistance variant of closeness centrality, also known as
    *information centrality*.

    Returns:
        numpy.ndarray: float vector of length ``env.size**2``. Cell ``(x, y)``
        is stored at index ``x + y * env.size``. Cells that are not graph
        nodes (occupied cells, or free cells with no free neighbour) stay 0.

    Raises:
        networkx.NetworkXError: raised by NetworkX when the free-cell graph is
            not connected, because current-flow measures require a connected
            graph.
    """
    G = build_graph_from_env(env)
    cen_dict = nx.current_flow_closeness_centrality(G)
    vec = np.zeros(env.size**2, dtype=float)
    for (x, y), v in cen_dict.items():
        vec[x + y * env.size] = v
    return vec


def compute_current_flow_centrality(env):
    """
    Return normalized current-flow betweenness centrality as a flat map.

    Scores the free-cell graph with
    ``nx.current_flow_betweenness_centrality(..., normalized=True)``.

    The output has one float per grid cell (length ``env.size**2``), with
    ``(x, y)`` stored at index ``x + y * env.size``. Cells that are not graph
    nodes, including walls, are left at 0.
    """
    graph = build_graph_from_env(env)
    flow_scores = nx.current_flow_betweenness_centrality(graph, normalized=True)

    n_cells = env.size**2
    result = np.zeros(n_cells)
    for (col, row), score in flow_scores.items():
        result[col + row * env.size] = score
    return result


def main():
    """Smoke-run: build the default SimpleEnv and compute its betweenness map.

    The computed vector is discarded; this only checks that the pipeline runs.
    """
    demo_env = SimpleEnv(render_mode=None)
    demo_env.reset()
    compute_betweenness_centrality(demo_env)


if __name__ == "__main__":
    main()

