import hashlib
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt


# def extract_states(minigrid_env, num_iters=1000):
#     # all_states = set()
#     all_states = dict()
#     for i in range(num_iters):
#         obs, _ = minigrid_env.reset()
#         hash_out = hash_state(unwrap_state(obs))
#         if hash_out not in all_states:
#             all_states[hash_out] = unwrap_state(obs)
#     return all_states

def hash_state(state):
    """Return a short, deterministic fingerprint for an environment state.

    Args:
        state: A 3-item sequence ``(image, carrying, agent)``. ``image`` must
            expose a buffer through ``.data`` that supports ``tobytes()``
            (e.g. a NumPy array); ``carrying`` and ``agent`` are folded in
            via ``str()``.

    Returns:
        The first 16 hexadecimal characters of the SHA-256 digest.
    """
    grid, held, agent_info = state
    # NOTE: the *textual repr* of the raw bytes ("b'...'") is what gets
    # hashed, followed directly by str(carrying) and str(agent) with no
    # separator in between.
    pieces = (str(grid.data.tobytes()), str(held), str(agent_info))
    payload = "".join(pieces).encode("utf8")
    return hashlib.sha256(payload).hexdigest()[:16]

# def unwrap_state(state):
#     return state

def shortest_path(state, start_loc, goal_loc, next_dir=True):
    """Find a shortest 4-connected route between two cells of a grid state.

    A graph is built with one node per walkable cell and an edge between
    every pair of orthogonally adjacent nodes; the route is then obtained
    with ``networkx.shortest_path``.

    A cell ``(i, j)`` is walkable when ``state[i, j, 0]`` is one of
    ``1, 10, 8, 9`` (9 is lava, which the agent is deliberately allowed to
    walk into), or when it is a door (id ``4``) whose ``state[i, j, 2]`` is
    ``0`` (open).
    NOTE(review): ids 1 / 8 / 10 presumably correspond to MiniGrid's
    empty / goal / agent object ids -- confirm against the environment.

    Both endpoints are always inserted as nodes, even if their cells are not
    walkable, so that a route can start at the agent's own cell and end *at*
    an object such as a key or a closed door.

    Args:
        state: Array indexed as ``state[i, j, channel]`` (e.g. shape
            ``(5, 5, 3)``).
        start_loc: ``(i, j)`` index pair of the starting cell. Any 2-item
            sequence is accepted.
        goal_loc: ``(i, j)`` index pair of the target cell. Any 2-item
            sequence is accepted.
        next_dir: When true, return only the first move; otherwise return
            the whole path.

    Returns:
        * ``None`` if no route exists, or if ``next_dir`` is true and the
          start already equals the goal.
        * If ``next_dir`` is true: ``(direction, more_steps)`` where
          ``direction`` is the ``(di, dj)`` offset of the first move and
          ``more_steps`` is ``False`` exactly when the goal is the very next
          cell (path of two nodes).
        * If ``next_dir`` is false: the list of ``(i, j)`` nodes from start
          to goal, inclusive.
    """
    # Graph nodes are tuples; normalise so lists / arrays work as locations.
    start_loc = tuple(start_loc)
    goal_loc = tuple(goal_loc)

    walkable_ids = (1, 10, 8, 9)
    door_id, door_open = 4, 0

    graph = nx.Graph()
    for i in range(state.shape[0]):
        for j in range(state.shape[1]):
            cell_type = state[i, j, 0]
            is_open_door = cell_type == door_id and state[i, j, 2] == door_open
            if cell_type in walkable_ids or is_open_door:
                graph.add_node((i, j))

    # Endpoints must exist in the graph regardless of what occupies them.
    # Previously only the goal was added, so a start on a non-walkable cell
    # made networkx raise NodeNotFound. add_node is a no-op for known nodes.
    graph.add_node(goal_loc)
    graph.add_node(start_loc)

    # Neighbour order is kept as (+i, -i, +j, -j) so that the adjacency order
    # (and therefore tie-breaking between equally short paths) is unchanged.
    offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))
    for i, j in graph.nodes:
        for di, dj in offsets:
            neighbour = (i + di, j + dj)
            if neighbour in graph:
                graph.add_edge((i, j), neighbour)

    # Single search instead of has_path() followed by shortest_path().
    try:
        path = nx.shortest_path(graph, start_loc, goal_loc)
    except nx.NetworkXNoPath:
        return None

    if not next_dir:
        return path
    if len(path) == 1:
        # Already standing on the goal: there is no next move.
        return None

    (i0, j0), (i1, j1) = path[0], path[1]
    direction = (i1 - i0, j1 - j0)
    return direction, len(path) != 2
