"""Centrality measures for MiniGrid layouts via NetworkX.

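Builds an undirected 4-connected graph from the free cells of a layout and
computes several centrality vectors (shortest-path betweenness, current-flow
closeness and current-flow betweenness) used as baselines for comparison
with VPS maps.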
"""

import numpy as np
import networkx as nx
from bottleneck_env import SimpleEnv


def build_graph_from_env(env):
    """
    Convert the MiniGrid environment into an undirected 4-connected graph.

    A cell is passable when ``env.grid.get(x, y)`` returns ``None``; any cell
    holding an object is treated as blocked. That includes walls and,
    presumably, every other object placed on the grid (e.g. a goal or a
    door). NOTE(review): confirm that excluding non-wall objects is intended.

    Nodes are ``(x, y)`` tuples, and two passable cells are joined when they
    are horizontal or vertical neighbours. Nodes are created only through
    edges, so a passable cell with no passable neighbour does NOT appear in
    the returned graph.

    Args:
        env: environment exposing ``size`` (the grid is scanned as a
            ``size`` x ``size`` square) and ``grid.get(x, y)``.

    Returns:
        networkx.Graph whose nodes are the connected passable cells.
    """
    size = env.size
    cell_at = env.grid.get
    G = nx.Graph()
    for x in range(size):
        for y in range(size):
            if cell_at(x, y) is not None:
                continue  # blocked cell (wall or other object)
            # The graph is undirected, so only the right and down neighbours
            # need checking: the left/up edges were already added when those
            # cells were visited (x is the outer loop, y the inner one). This
            # gives the same nodes, edges and insertion order as checking all
            # four neighbours. Coordinates only grow here, so no >= 0 check.
            for nbr_x, nbr_y in ((x + 1, y), (x, y + 1)):
                if nbr_x < size and nbr_y < size and cell_at(nbr_x, nbr_y) is None:
                    G.add_edge((x, y), (nbr_x, nbr_y))
    return G


def compute_betweenness_centrality(env):
    """
    Return normalized shortest-path betweenness centrality as a flat vector.

    The score of cell ``(x, y)`` is written to index ``x + y * env.size``
    (the flat layout reportedly used by DQN_agent.py). Cells that are not
    nodes of the graph keep the value 0. The result has length
    ``env.size ** 2``.
    """
    side = env.size
    # NetworkX computes this with Brandes' algorithm.
    scores = nx.betweenness_centrality(build_graph_from_env(env), normalized=True)

    flat = np.zeros(side ** 2)
    # Scatter every node's score to its row-major slot in one assignment.
    flat[[col + row * side for col, row in scores]] = list(scores.values())
    return flat


def compute_closeness_centrality(env):
    """
    Current-flow closeness centrality, also known as information centrality.

    Despite the function name, this is NOT shortest-path closeness
    (``nx.closeness_centrality``). It is computed by
    ``nx.current_flow_closeness_centrality``, which is based on effective
    resistance when the grid graph is treated as an electrical network.

    Returns:
        Flat float vector of length ``env.size ** 2``. Cell ``(x, y)`` is
        stored at index ``x + y * env.size``, and cells that are not graph
        nodes are 0.

    NetworkX requires a connected graph for this measure, so a layout whose
    free cells form several components is expected to raise.
    """
    side = env.size
    G = build_graph_from_env(env)
    # Current NetworkX releases give current_flow_closeness_centrality no
    # ``normalized`` parameter (passing one raises TypeError), so the raw
    # scores are used as returned. NOTE(review): check the installed version.
    cen_dict = nx.current_flow_closeness_centrality(G)
    vec = np.zeros(side ** 2, dtype=float)
    for (x, y), v in cen_dict.items():
        vec[x + y * side] = v
    return vec


def compute_current_flow_centrality(env):
    """
    Return normalized current-flow betweenness centrality as a flat vector.

    Current-flow betweenness (also called random-walk betweenness) replaces
    shortest paths with an electrical-current model. Cell ``(x, y)`` lands at
    index ``x + y * env.size`` of a vector of length ``env.size ** 2``.
    Every cell that is not a graph node (walls, other occupied cells,
    isolated free cells) is 0.

    NetworkX requires a connected graph for this measure.
    """
    side = env.size
    graph = build_graph_from_env(env)
    flow = nx.current_flow_betweenness_centrality(graph, normalized=True)

    # Fill a (row = y, column = x) board, then flatten it row-major so that
    # cell (x, y) ends up at flat index x + y * side.
    board = np.zeros((side, side))
    for (col, row), score in flow.items():
        board[row, col] = score
    return board.flatten()


def main():
    """Reset a SimpleEnv and compute its betweenness-centrality vector.

    The vector is discarded; this entry point only exercises the pipeline.
    """
    environment = SimpleEnv(render_mode=None)
    environment.reset()
    compute_betweenness_centrality(environment)


if __name__ == "__main__":
    main()
