import networkx as nx
import matplotlib.pyplot as plt
import pickle
import os
import copy
import numpy as np
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
from simplicial import *
import platform
import itertools

# imports for making GIFs
# import glob
# import moviepy.editor as mpy # install 
# from natsort import natsorted # install

## load and plot a saved world graph constructed from worldbuilder.py

cwd = os.getcwd()  # get current working directory

# File-paths of the files generated by worldbuilder.py. os.path.join picks the
# right separator for the running platform (so no per-OS branch is needed) and
# does not produce a doubled separator when cwd already ends with one.
result_dir = os.path.join(cwd, "result")
graph_path = os.path.normpath(os.path.join(result_dir, "graph.gpickle"))
node_positions_path = os.path.normpath(os.path.join(result_dir, "node_positions"))
node_IDs_path = os.path.normpath(os.path.join(result_dir, "node_IDs"))
node_orders_path = os.path.normpath(os.path.join(result_dir, "node_orders"))

# load variables
def _load_pickle(path):
    """Unpickle and return the object stored at `path`, always closing the file."""
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at files produced locally by worldbuilder.py.
    with open(path, 'rb') as infile:
        return pickle.load(infile)


# nx.read_gpickle was removed in NetworkX 3.0; the documented replacement is a
# plain pickle.load of the file, so fall back to that when it is missing.
if hasattr(nx, "read_gpickle"):
    world_graph = nx.read_gpickle(graph_path)
else:
    world_graph = _load_pickle(graph_path)

node_positions = _load_pickle(node_positions_path)
node_IDs = _load_pickle(node_IDs_path)
node_orders = _load_pickle(node_orders_path)
WorldArea = world_graph.nodes()

# add colour information to the graph for special nodes
# four-component colours, one per kind of node
regular_colour = [0.5, 0.5, 0.5, 0.5]
start_colour = [0, 1, 0, 1]
goal_colour = [1, 0, 0, 1]
object_colour = [0, 0, 1, 1]
wall_colour = [1, 1, 1, 0.5]

# every node starts out with the regular colour ...
node_colours = [regular_colour] * len(WorldArea)

# ... and the special nodes are then overwritten, in this order, so that a
# later category wins if an index appears in more than one list
special_colourings = (
    (node_orders[0], start_colour),
    (node_orders[1], goal_colour),
    (node_orders[2], object_colour),
    (node_orders[3], wall_colour),
)
for special_indices, special_colour in special_colourings:
    for k in special_indices:
        node_colours[k] = special_colour

# plot the undirected graph
nx.draw_networkx(
    world_graph,
    pos=node_positions,
    node_color=node_colours,
)
plt.axis('off')
plt.show()


## start construction of state complex

# create state complex graph and set node counter
# node counter: the initial state becomes vertex 0 of the state complex
state_complex_node_counter = 0

# the state complex itself is an undirected graph
state_complex = nx.Graph()

# the first vertex stores its own deep copy of node_IDs as its 'state'
# attribute, so later in-place edits of node_IDs do not alter it
state_complex.add_node(
    state_complex_node_counter,
    state=copy.deepcopy(node_IDs),
)

# starting agent node: the first entry of the agent list
agent_node = node_IDs[0][0]

# GENERATOR 1
# agent (A) can move to adjacent node if it is empty (E), i.e. green -> grey
#   A-E
#   becomes
#   E-A

def gen_1_support(agent_node):
    """Return the neighbours of `agent_node` that an agent may step onto.

    Reads the module-level `world_graph` and `node_IDs`, where node_IDs[0] to
    node_IDs[3] hold the agent, goal, object and wall nodes respectively.
    A neighbour is accepted when it is a goal not covered by an object, or
    when it belongs to none of the four lists. One line is printed per
    neighbour describing how it was classified.
    """
    print("## generator 1")
    free_neighbours = []
    for neighbour in world_graph.neighbors(agent_node):
        # agents take precedence over every other classification
        if neighbour in node_IDs[0]:
            print(neighbour, 'is an agent')
            continue
        # goals are enterable unless an object is sitting on them
        if neighbour in node_IDs[1]:
            if neighbour not in node_IDs[2]:
                free_neighbours.append(neighbour)
                print(neighbour, 'is a goal and is empty')
            else:
                print(neighbour, 'is a goal but is blocked by an object')
            continue
        if neighbour in node_IDs[2]:
            print(neighbour, 'is an object')
            continue
        if neighbour in node_IDs[3]:
            print(neighbour, 'is a wall')
            continue
        # nothing occupies this node
        free_neighbours.append(neighbour)
        print(neighbour, 'is empty')
    return free_neighbours

# GENERATOR 2
# agent (A) can push an object if there is empty space (E) behind the object (O), where *s represent any other node
# note: cannot pull!
#   A-O-E
#   *-*-*
#   becomes
#   E-A-O
#   *-*-*

def gen_2_support(agent_node):
    """Check whether the agent at `agent_node` can push an adjacent object.

    Reads the module-level `world_graph` and `node_IDs` (node_IDs[2] holds the
    object nodes, node_IDs[3] the wall nodes).

    Returns a tuple (gen_2_possible, object_movable_to, agent_movable_to).
    When no push is possible this is (False, [], []); otherwise it is
    (True, node the object moves to, node the agent moves to), the latter
    being the object's current node.
    """
    print("## generator 2")
    gen_2_possible = False
    object_movable_to = []
    agent_movable_to = []

    # Find an object next to the agent. If several are adjacent, the last one
    # reported by the graph is the one that is used.
    adjacent_to_agent = list(world_graph.neighbors(agent_node))
    object_loc = None
    for n in adjacent_to_agent:
        if n in node_IDs[2]:
            object_loc = n
            print("there is an adjacent object at node", object_loc)

    # Compare against None explicitly: a node id of 0 is a valid object
    # location but is falsy, so a plain truth test would ignore it.
    if object_loc is None:
        return gen_2_possible, object_movable_to, agent_movable_to

    set_of_objects = set(node_IDs[2])
    agent_neighbourhood = set(adjacent_to_agent)

    for m in world_graph.neighbors(object_loc):
        # the destination may be neither the pushing agent nor another object
        if m == agent_node or m in set_of_objects:
            continue
        # NOTE(review): `[m] in node_IDs` only matches a category whose list
        # is exactly [m] (e.g. a single goal); it does not detect m inside a
        # longer list such as a second agent. Kept as in the original --
        # confirm the intended rule.
        if [m] in node_IDs or m in node_IDs[3]:
            continue

        # Accept m only when the object is the one and only node adjacent to
        # both the agent and m.
        adjacent_to_empty = set(world_graph.neighbors(m))
        if adjacent_to_empty.intersection(agent_neighbourhood) == {object_loc}:
            gen_2_possible = True
            object_movable_to = m
            agent_movable_to = object_loc
            print("gen 2 possible! move agent to", agent_movable_to, "and object to", object_movable_to)

    return gen_2_possible, object_movable_to, agent_movable_to

# set-up and run the construction loop

def _record_transition(source, generator_label):
    """Link vertex `source` to the state currently held in node_IDs.

    Two states count as the same when their sets of agent nodes and their
    sets of object nodes are equal. If such a state is already known, an edge
    labelled `generator_label` is drawn to it; otherwise a new vertex is
    created (with a deep copy of node_IDs as its 'state'), linked to `source`
    and queued for expansion.
    """
    agents_now = set(node_IDs[0])
    objects_now = set(node_IDs[2])
    is_new = True
    for known_index, past_state in enumerate(todo_states):
        if agents_now == set(past_state[0]) and objects_now == set(past_state[2]):
            print('not unique')
            is_new = False
            state_complex.add_edge(source, known_index, attr=generator_label)

    if is_new:
        print('unique')
        new_index = len(todo_list)  # vertex ids are consecutive integers
        state_complex.add_node(new_index, state=copy.deepcopy(node_IDs))
        state_complex.add_edge(source, new_index, attr=generator_label)
        todo_states.append(copy.deepcopy(node_IDs))
        todo_list.append(new_index)


todo = 0
todo_list = [todo]                        # vertex ids still to expand (worklist)
todo_states = [copy.deepcopy(node_IDs)]   # todo_states[i] is the state of vertex i

step = 0

# todo_list grows while it is being iterated: every newly discovered state is
# appended and therefore expanded later in this same loop.
for x in todo_list:

    agent_nodes = copy.deepcopy(todo_states[x][0])
    # only the first object of a state is tracked by this loop
    if todo_states[x][2] == []:
        object_node = None
        node_IDs[2] = []
    else:
        object_node = copy.deepcopy(todo_states[x][2][0])
        node_IDs[2][0] = copy.deepcopy(object_node)

    print('### step', step, ': agent at node', agent_nodes, 'and object at node', object_node)

    for agent, agent_node in enumerate(agent_nodes):
        # reset the working agent list to the state being expanded
        node_IDs[0] = copy.deepcopy(agent_nodes)

        # apply GENERATOR 1
        for e in gen_1_support(agent_node):
            print('test moving agent', agent_node, 'to empty neighbour', e)
            node_IDs[0] = [e if a == agent_node else a for a in agent_nodes]
            print('current', node_IDs)
            _record_transition(x, 'Gen 1')

        node_IDs[0] = copy.deepcopy(agent_nodes)

        # apply GENERATOR 2
        # (`is not None`, because an object sitting at node id 0 is falsy)
        if object_node is not None:

            gen_2_possible, object_movable_to, agent_movable_to = gen_2_support(agent_node)

            if gen_2_possible:
                print('test moving object to', object_movable_to)
                # move the agent that is actually pushing, not always agent 0
                node_IDs[0][agent] = agent_movable_to
                original_object_loc = copy.deepcopy(node_IDs[2][0])
                node_IDs[2][0] = object_movable_to
                print('current', node_IDs)
                _record_transition(x, 'Gen 2')

                # put the object back before looking at the next agent
                node_IDs[2][0] = original_object_loc

    step += 1

# one colour per edge, derived from the generator label stored on the edge:
# "Gen 1" edges become blue and "Gen 2" edges become red
gen_attributes = [
    edge_label.replace("Gen 1", "blue").replace("Gen 2", "red")
    for edge_label in nx.get_edge_attributes(state_complex, 'attr').values()
]

# draw the whole state complex, save it, then show it
nx.draw_kamada_kawai(
    state_complex,
    width=1.0,
    alpha=0.5,
    node_size=25,
    with_labels=True,
    edge_color=gen_attributes,
    font_size=5,
)
plt.axis('off')
plt.savefig("state_complex.png", dpi=300)
plt.show()

# report planarity; no counterexample is requested
is_planar, certificate = nx.check_planarity(state_complex, counterexample=False)
print("State complex planar?", is_planar)

# find the squares to fill in: 'dancing with yourself' squares (one agent moves) and commuting-moves squares (two agents move)

candidate_paths = []
dancing_with_yourself_candidate_paths = []
commuting_moves_candidate_paths = []
test_counter = []
test_test_counter = []


def _classify_square(corner_states):
    """Classify a 4-cycle from the 'state' entries of its four corners.

    Each entry is indexed like node_IDs: [0] is the list of agent nodes and
    [2] the list of object nodes. Returns "dwy" (one agent moves, 'dancing
    with yourself'), "cm" (two agents move, commuting moves) or None when the
    cycle is neither.
    """
    # the object/s must be unmoving around the whole cycle
    object_positions = [state[2] for state in corner_states]
    if any(position != object_positions[0] for position in object_positions[1:]):
        return None

    # count in how many corner states each agent location occurs
    visit_counts = {}
    for state in corner_states:
        for location in state[0]:
            visit_counts[location] = visit_counts.get(location, 0) + 1

    # a location present in all four states belongs to an agent that never moves
    moving_counts = [count for count in visit_counts.values() if count != 4]
    if not moving_counts:
        return None  # nothing moves; also avoids dividing by zero below

    number_of_moving_agents = sum(moving_counts) / len(moving_counts)
    if number_of_moving_agents == 1:
        # if 1 agent moves, all location labels appear exactly once
        return "dwy"
    if number_of_moving_agents == 2 and all(count == 2 for count in moving_counts):
        # if 2 agents move, all location labels appear exactly twice
        return "cm"
    return None


# walk all paths of length 4 that start and end at the same vertex
for step_0 in state_complex:
    for step_1 in state_complex.neighbors(step_0):
        for step_2 in state_complex.neighbors(step_1):
            for step_3 in state_complex.neighbors(step_2):
                # the walk closes only if step_3 is adjacent to step_0; asking
                # the graph directly replaces a fifth nested loop
                if not state_complex.has_edge(step_3, step_0):
                    continue
                candidate_path = [step_0, step_1, step_2, step_3, step_0]
                # the four corners must be pairwise distinct
                if len(set(candidate_path)) != 4:
                    continue
                square_kind = _classify_square(
                    [state_complex.nodes[corner]['state'] for corner in candidate_path[:4]]
                )
                if square_kind == "dwy":
                    candidate_paths.append(candidate_path)
                    dancing_with_yourself_candidate_paths.append(candidate_path)
                elif square_kind == "cm":
                    candidate_paths.append(candidate_path)
                    commuting_moves_candidate_paths.append(candidate_path)

# Every square is found several times (different start corner / direction);
# reduce each family to its distinct vertex sets. The *_set lists hold
# frozensets and the plain lists hold the same squares as lists of vertices.
unique_square_paths_set = list(set(frozenset(path) for path in candidate_paths))
unique_square_paths = [list(square) for square in unique_square_paths_set]

unique_dwy_square_paths_set = list(set(frozenset(path) for path in dancing_with_yourself_candidate_paths))
unique_dwy_square_paths = [list(square) for square in unique_dwy_square_paths_set]

unique_cm_square_paths_set = list(set(frozenset(path) for path in commuting_moves_candidate_paths))
unique_cm_square_paths = [list(square) for square in unique_cm_square_paths_set]

def _square_patch(square, positions):
    """Return a Polygon patch for the 4-vertex `square`.

    `positions` maps a vertex to its 2-D layout coordinates. The vertices of
    `square` arrive in arbitrary order, so the corners are sorted by their
    angle around the centroid; drawing them in that order gives a simple
    (non self-crossing) quadrilateral.
    """
    corners = [tuple(positions[square[i]]) for i in range(4)]
    points = np.array(corners)
    centre = points.mean(axis=0)
    angles = np.arctan2(points[:, 1] - centre[1], points[:, 0] - centre[0])
    return Polygon([corners[i] for i in np.argsort(angles)])


# layout coordinates used to place the square overlays
nodePos = nx.kamada_kawai_layout(state_complex)

# one colour per edge: "Gen 1" edges blue, "Gen 2" edges red
gen_attributes = list(nx.get_edge_attributes(state_complex, 'attr').values())
gen_attributes = [label.replace("Gen 1", "blue").replace("Gen 2", "red") for label in gen_attributes]

# state complex with the 'dancing with yourself' squares shaded blue
fig, ax = plt.subplots()
nx.draw_kamada_kawai(state_complex, width=1.0, alpha=0.5, node_size=25, with_labels=True, edge_color=gen_attributes, ax=ax, font_size=5)
patches = [_square_patch(path, nodePos) for path in unique_dwy_square_paths]
p = PatchCollection(patches, alpha=0.1, color='blue')
ax.add_collection(p)
plt.axis('off')
plt.savefig("result/state_complex_dances.png", dpi=300)
plt.show()

# state complex with the commuting-moves squares shaded red
fig, ax = plt.subplots()
nx.draw_kamada_kawai(state_complex, width=1.0, alpha=0.5, node_size=50, node_color="red", with_labels=False, edge_color="black", ax=ax)
# Vertex 15 is re-drawn fully opaque. The id is hard-coded, so skip the
# highlight when the state complex has no such vertex instead of failing.
if 15 in state_complex:
    nx.draw_kamada_kawai(state_complex, nodelist=[15], width=1.0, node_size=50, alpha=1.0, node_color="red", ax=ax)
patches = [_square_patch(path, nodePos) for path in unique_cm_square_paths]
p = PatchCollection(patches, alpha=0.1, color='red')
ax.add_collection(p)
plt.axis('off')
plt.savefig("result/state_complex_comm-moves.png", dpi=300)
plt.show()

## for each vertex, see if it has missing cubical faces (check for local CAT(0) geometry)

def _extend_flag_complex(flag_complex, nss, k):
    """Add to `flag_complex` every k-simplex whose boundary is already present.

    `nss` maps an order to a set of simplex indices at that order. Nothing is
    done when nss has no (non-empty) entry for order k - 1. Otherwise every
    combination of (k + 1) simplices of order (k - 1) that shares an index
    with nss[k - 1] is tested; if it forms a closed boundary and no simplex
    with those faces exists yet, the simplex is added and its index is
    recorded in nss[k].

    NOTE: this relies on private members of the `simplicial` library
    (`_indices`, `_isClosed`, `_simplices`).
    """
    if ((k - 1) not in nss) or (len(nss[k - 1]) == 0):
        # no simplices to form faces of any new simplices at this order
        return

    if k not in nss:
        # create a new set into which to add created simplex indices
        nss[k] = set()

    # grab the boundary matrix of the faces
    boundary = flag_complex.boundaryOperator(k - 1)

    # test all collections of (k + 1) (k - 1)-simplices that include at least
    # one of the tracked simplices to see whether they close a new simplex
    ks = len(flag_complex._indices[k - 1])
    for fs in [set(combination) for combination in itertools.combinations(range(ks), k + 1)]:
        if nss[k - 1].isdisjoint(fs):
            continue
        if flag_complex._isClosed(boundary, list(fs)):
            # simplices form a boundary, add to the flag complex
            # (if it doesn't already exist)
            cfs = [flag_complex._indices[k - 1][i] for i in fs]
            if flag_complex.simplexWithFaces(cfs) is None:
                s = flag_complex.addSimplex(fs=cfs)
                (_, i) = flag_complex._simplices[s]
                nss[k].add(i)


def positive_curvature_test(vertex_to_test):
    """Count failures of the Gromov link condition at one state-complex vertex.

    For `vertex_to_test` this builds a simplicial complex whose 0-simplices
    are the vertex's neighbours in the module-level `state_complex` and whose
    1-simplices join two neighbours that lie, together with the vertex, in a
    square of `unique_square_paths_set`. The complex is extended with 2- and
    3-simplices, and each of those is checked for a matching cube / hypercube.

    Side effects: prints progress messages, draws the local subgraph, saves
    it to result/subgraph_of_link_of_vertex_<vertex>.png and shows it.

    Returns the number of detected failures (0 when every test passes).
    """
    instances_of_positive_curvature_count = 0

    # construct the local subgraph of a vertex: the vertex, its neighbours
    # and their neighbours
    induced_vertices = [vertex_to_test]
    first_neighbours = []
    for first_neighbour in state_complex.neighbors(vertex_to_test):
        induced_vertices.append(first_neighbour)
        first_neighbours.append(first_neighbour)
        for second_neighbour in state_complex.neighbors(first_neighbour):
            induced_vertices.append(second_neighbour)

    induced_vertices = list(set(induced_vertices))
    local_subgraph_of_vertex = nx.Graph.copy(state_complex.subgraph(induced_vertices))

    # remove single-degree vertices
    v_to_remove = [node for node, degree in dict(local_subgraph_of_vertex.degree()).items() if degree == 1]
    local_subgraph_of_vertex.remove_nodes_from(v_to_remove)

    # construct simplicial complex using 0-simplexes (vertices) and
    # 1-simplexes (edges) from the local subgraph of the vertex being tested
    c = SimplicialComplex()

    zero_simplexes = {}
    for v in first_neighbours:
        zero_simplexes[v] = c.addSimplex(id=v)

    one_simplexes = {}
    one_simplexes_square_supports_list = []
    for zero_simplex_one in zero_simplexes:
        for zero_simplex_two in zero_simplexes:
            if zero_simplex_one == zero_simplex_two:
                continue
            candidate_square_members = set([zero_simplex_one, zero_simplex_two, vertex_to_test])
            for square in unique_square_paths_set:
                if not candidate_square_members.issubset(square):
                    continue
                # the edge may already exist under either orientation of its id
                if (c.containsSimplex(tuple((zero_simplex_one, zero_simplex_two)))) or (c.containsSimplex(tuple((zero_simplex_two, zero_simplex_one)))):
                    continue
                one_simplexes[tuple(square)] = c.addSimplex(fs=[zero_simplex_one, zero_simplex_two], id=tuple((zero_simplex_one, zero_simplex_two)))
                one_simplexes_square_supports_list.append(square)

    # check Gromov link condition (here we only need to check condition 2,
    # since condition 1 never occurs in our graph/set-up)
    # NOTE(review): a square is appended to the list only at the moment an
    # entry keyed by that same square is stored in one_simplexes, so the two
    # sizes compared below look like they are always equal -- confirm that
    # this test can actually fail.
    one_simplexes_square_supports_set = set(one_simplexes_square_supports_list)
    if len(one_simplexes_square_supports_set) > len(one_simplexes):
        # more squares support the same 1-simplexes than there are unique
        # 1-simplexes added, therefore a multi-edge, breaking condition 2
        print("1-simplices test: Gromov link condition fails")
        instances_of_positive_curvature_count += len(one_simplexes_square_supports_set) - len(one_simplexes)
    else:
        print("1-simplices test: Gromov link condition passes")

    # From here, the original strategy was to build the flag complex and test
    # that it meets the Gromov link condition at all higher levels. A simpler
    # check of pairwise-disjoint supports was tried, but does not quite work
    # because of dancing-with-yourself squares and the 'fog of war' effect,
    # where the move ahead of the current move can block commutativity.
    # Therefore the original strategy is used, checking only up to 3-simplices.

    # build the 2-simplices, then the 3-simplices, implied by the complex
    # (copy.copy is a shallow copy of c)
    flag_complex = copy.copy(c)
    nss = dict()
    nss[1] = set(range(len(flag_complex.simplicesOfOrder(1))))
    _extend_flag_complex(flag_complex, nss, 2)
    _extend_flag_complex(flag_complex, nss, 3)

    # On one occasion two 2-simplexes with identical faces were created (never
    # reproduced, possibly user error). Without deleting the duplicate,
    # bettiNumbers() reported a spurious 2nd-order hole, so for sanity one
    # simplex of every such pair is deleted here.
    implied_two_simplexes = list(flag_complex.simplicesOfOrder(2))
    redundant_two_simplexes = set()
    for implied_two_complex in implied_two_simplexes:
        for other_implied_two_complex in implied_two_simplexes:
            if implied_two_complex != other_implied_two_complex:
                if flag_complex.faces(implied_two_complex) == flag_complex.faces(other_implied_two_complex):
                    redundant_two_simplexes.add(frozenset([implied_two_complex, other_implied_two_complex]))
    print("Number of redundant 2-simplexes:", len(redundant_two_simplexes))
    redundant_two_simplex_list = [list(x) for x in list(redundant_two_simplexes)]
    two_simplexes_to_delete = [x[0] for x in redundant_two_simplex_list]
    flag_complex.deleteSimplices(two_simplexes_to_delete)

    if flag_complex.maxOrder() >= 2:
        print("Flag complex has 2-simplices -- we need to check for higher-D cubes")
        check_2nd_order = True
    else:
        # message corrected: this branch is reached when there are NO 2-simplices
        print("Flag complex has no 2-simplices -- we don't need to check for higher-D cubes")
        check_2nd_order = False

    cubes = []
    hypercubes = []

    # check 2-simplexes condition
    if check_2nd_order:
        flag_complex_2nd_order_simplices = list(flag_complex.simplicesOfOrder(2))
        faces_to_check = len(flag_complex_2nd_order_simplices)
        faces_checked = 0
        for s in flag_complex_2nd_order_simplices:
            # each 1-simplex face, expressed as the list of its 0-simplex ids
            one_simplex_faces = list(flag_complex.faces(s))
            one_simplex_faces = [list(flag_complex.faces(x)) for x in one_simplex_faces]
            face_squares = [list() for x in range(0, len(one_simplex_faces))]
            missing_corners = [list() for x in range(0, len(one_simplex_faces))]
            for f in range(0, len(one_simplex_faces)):
                one_simplex_face = one_simplex_faces[f]
                # appended in place, so one_simplex_faces[f] now also holds
                # the vertex being tested
                one_simplex_face.append(vertex_to_test)
                corners_to_test = frozenset(one_simplex_face)
                for square in unique_square_paths_set:
                    if corners_to_test.issubset(square):
                        face_squares[f].append(square)
                        # the square's corner(s) not among the three tested
                        missing_corners[f].append(list(set(square).difference(corners_to_test)))

            # one missing corner per face, in every possible combination
            possible_cube_corners = list(itertools.product(*missing_corners))
            possible_cube_corners = [[item for sublist in list(x) for item in sublist] for x in possible_cube_corners]
            for possible_cube_corner in possible_cube_corners:
                # assumes state-complex vertex ids are 0..len-1, as produced
                # by the construction loop of this script
                common_neighbours = set(list(range(len(state_complex))))
                for corner in possible_cube_corner:
                    common_neighbours = common_neighbours.intersection(set([n for n in state_complex.neighbors(corner)]))
                if len(common_neighbours) == 1:
                    faces_checked += 1
                    cube = frozenset(list(common_neighbours) + possible_cube_corner + list(set([item for sublist in one_simplex_faces for item in sublist])))
                    print("Found cube corner for this 2-simplex:", list(cube))
                    cubes.append(cube)
                    local_subgraph_of_vertex.add_node(list(common_neighbours)[0], order="2")
                    for u in possible_cube_corner:
                        local_subgraph_of_vertex.add_edge(u, list(common_neighbours)[0])
                if len(common_neighbours) > 1:
                    print("Found multiple cube corners, thus failing second condition of Gromov's test")

        if faces_to_check > faces_checked:
            print("2-simplices test: Gromov link condition fails")
            instances_of_positive_curvature_count += faces_to_check - faces_checked
        else:
            print("2-simplices test: Gromov link condition passes")

        if flag_complex.maxOrder() >= 3:
            print("Flag complex has 3-simplices -- we need to check for higher-D cubes")
            check_3rd_order = True
        else:
            print("Flag complex has no 3-simplices -- we don't need to check for higher-D cubes")
            check_3rd_order = False

        if check_3rd_order:
            flag_complex_3rd_order_simplices = list(flag_complex.simplicesOfOrder(3))
            faces_to_check = len(flag_complex_3rd_order_simplices)
            faces_checked = 0
            for s in flag_complex_3rd_order_simplices:
                two_simplex_faces = list(flag_complex.faces(s))
                two_simplex_faces = [list(flag_complex.faces(x)) for x in two_simplex_faces]

                # vertices of each 2-simplex face, flattened from its edge ids
                one_simplex_subfaces = [frozenset([item for t in x for item in t]) for x in two_simplex_faces]
                three_simplex_corners = 0  # we require 4
                hypercube = []
                for cube in cubes:
                    for subface in one_simplex_subfaces:
                        if subface.issubset(cube):
                            three_simplex_corners += 1
                            hypercube.append(cube)

                if three_simplex_corners == 4:
                    faces_checked += 1
                    hypercubes.append(hypercube)
                    print("Found hypercube corner for this 3-simplex:", list(hypercube))
                if three_simplex_corners > 4:
                    print("Found multiple hypercube corners, thus failing second condition of Gromov's test")

            if faces_to_check > faces_checked:
                print("3-simplices test: Gromov link condition fails")
                instances_of_positive_curvature_count += faces_to_check - faces_checked
            else:
                print("3-simplices test: Gromov link condition passes")

    # NOTE(review): gen_attributes holds one colour per edge of the full
    # state complex, not of this local subgraph -- confirm that it is the
    # intended edge colouring here.
    nx.draw_kamada_kawai(local_subgraph_of_vertex, width=1.0, alpha=0.5, node_size=125, with_labels=False, edge_color=gen_attributes, font_size=25)
    plt.axis('off')
    plt.savefig(str("result/subgraph_of_link_of_vertex_" + str(vertex_to_test) + ".png"), dpi=300)
    plt.show()

    return instances_of_positive_curvature_count

# run the link-condition test on every vertex, in graph iteration order
pos_curvature = [positive_curvature_test(vertex) for vertex in state_complex]

# label i-th position -> failure count of the i-th tested vertex
relabelling_map = dict(enumerate(pos_curvature))

# state complex with each vertex labelled by its number of failures
nx.draw_kamada_kawai(
    state_complex,
    width=1.0,
    alpha=0.5,
    node_size=125,
    labels=relabelling_map,
    with_labels=True,
    edge_color=gen_attributes,
    font_size=15,
)
plt.axis('off')
plt.savefig("result/state_complex_with_G-test_failures.png", dpi=300)
plt.show()

# Summary statistics over the per-vertex counts collected in pos_curvature.
num_of_states = len(state_complex.nodes())
total_failures = sum(pos_curvature)
# An empty state complex would otherwise raise ZeroDivisionError here and
# ValueError in max(); report zeros instead so the stats file is still written.
avg_failure = total_failures / num_of_states if num_of_states else 0.0
max_failure = max(pos_curvature, default=0)
num_of_zero_failures = sum((x == 0) for x in pos_curvature)
num_of_dwy_squares = len(unique_dwy_square_paths_set)
num_of_cm_squares = len(unique_cm_square_paths_set)

# Write one "name: value" line per statistic. The with-block guarantees the
# file is closed even if one of the writes raises.
with open("result/_stats.txt", "w") as results_file:
    results_file.write("Stats \n")
    results_file.writelines(
        name + ": " + str(value) + "\n"
        for name, value in (
            ("num_of_states", num_of_states),
            ("total_failures", total_failures),
            ("avg_failure", avg_failure),
            ("max_failure", max_failure),
            ("num_of_zero_failures", num_of_zero_failures),
            ("num_of_dwy_squares", num_of_dwy_squares),
            ("num_of_cm_squares", num_of_cm_squares),
        )
    )

# plot states in a given walk through the state complex (currently disabled: the code below is commented out)

# for commuting_walk in unique_cm_square_paths:
    
#     node_order = list(world_graph)
#     start_order = []
#     goal_order = []
#     object_order = []
#     wall_order = []
#     node_colours = []
#     for i in range(0,len(WorldArea)):
#         node_colours.append(regular_colour)
#     visit_nums = np.zeros([len(WorldArea)])
    
#     for n in commuting_walk:
#         state = state_complex.nodes(data=True)[n]
#         state_labels = state['state']
        
#         for j in range(0,len(state_labels[0])):
#             start_order.append(node_order.index(state_labels[0][j]))
#             for s_node in state_labels[0]:
#                 visit_nums[s_node] += 1
#         for j in range(0,len(state_labels[1])):
#             goal_order.append(node_order.index(state_labels[1][j]))
#         for j in range(0,len(state_labels[2])):
#             object_order.append(node_order.index(state_labels[2][j]))
#         for j in range(0,len(state_labels[3])):
#             wall_order.append(node_order.index(state_labels[3][j]))
         
#         for k in start_order:
#             node_colours[k] = start_colour
#         for k in goal_order:
#             node_colours[k] = goal_colour
#         for k in object_order:
#             node_colours[k] = object_colour
#         for k in wall_order:
#             node_colours[k] = wall_colour
    
#     relabelling_map = dict(zip(list(range(0,len(world_graph))), visit_nums))

#     plt.figure(1, figsize=(15, 7), dpi=300)
#     plt.subplot(121)
#     nx.draw_kamada_kawai(state_complex, width=1.0, alpha=0.3)
#     nx.draw_kamada_kawai(state_complex, nodelist=commuting_walk, width=1.0, alpha=1.0)
#     plt.subplot(122)
#     nx.draw_networkx(world_graph, pos=node_positions, node_color=node_colours, labels=relabelling_map, with_labels=True)
#     plt.axis('off')
#     plt.savefig(str("walks\walk"+str(commuting_walk)+"s.png"), dpi=300)
#     plt.show()

# for commuting_walk in unique_cm_square_paths:
    
#     for n in commuting_walk:
#         state = state_complex.nodes(data=True)[n]
#         state_labels = state['state']
    
#         node_order = list(world_graph)
#         start_order = []
#         for j in range(0,len(state_labels[0])):
#             start_order.append(node_order.index(state_labels[0][j]))
#         goal_order = []
#         for j in range(0,len(state_labels[1])):
#             goal_order.append(node_order.index(state_labels[1][j]))
#         object_order = []
#         for j in range(0,len(state_labels[2])):
#             object_order.append(node_order.index(state_labels[2][j]))
#         wall_order = []
#         for j in range(0,len(state_labels[3])):
#             wall_order.append(node_order.index(state_labels[3][j]))
    
#         node_colours = []
#         for i in range(0,len(WorldArea)):
#             node_colours.append(regular_colour)   
        
#         for k in start_order:
#             node_colours[k] = start_colour
#         for k in goal_order:
#             node_colours[k] = goal_colour
#         for k in object_order:
#             node_colours[k] = object_colour
#         for k in wall_order:
#             node_colours[k] = wall_colour
            
    
#         plt.figure(1, figsize=(15, 7), dpi=300)
#         plt.subplot(121)
#         nx.draw_kamada_kawai(state_complex, width=1.0, alpha=0.3)
#         nx.draw_kamada_kawai(state_complex, nodelist=[n], width=1.0, alpha=1.0)
#         plt.subplot(122)
#         nx.draw_networkx(world_graph, pos=node_positions, node_color=node_colours, with_labels=False)
#         plt.axis('off')
#         plt.savefig(str("walks\walk"+str(commuting_walk)+"_state"+str(n)+"s.png"), dpi=300)
#         plt.show()


# plot states in a random walk through the state complex (currently disabled: the code below is commented out)

# state_num = 42
# np.random.seed(2015)

# for n in range(0,100):
#     state = state_complex.nodes(data=True)[state_num]
#     state_labels = state['state']

#     node_order = list(world_graph)
#     start_order = []
#     for j in range(0,len(state_labels[0])):
#         start_order.append(node_order.index(state_labels[0][j]))
#     goal_order = []
#     for j in range(0,len(state_labels[1])):
#         goal_order.append(node_order.index(state_labels[1][j]))
#     object_order = []
#     for j in range(0,len(state_labels[2])):
#         object_order.append(node_order.index(state_labels[2][j]))
#     wall_order = []
#     for j in range(0,len(state_labels[3])):
#         wall_order.append(node_order.index(state_labels[3][j]))

#     node_colours = []
#     for i in range(0,len(WorldArea)):
#         node_colours.append(regular_colour)   
    
#     for k in start_order:
#         node_colours[k] = start_colour
#     for k in goal_order:
#         node_colours[k] = goal_colour
#     for k in object_order:
#         node_colours[k] = object_colour
#     for k in wall_order:
#         node_colours[k] = wall_colour

#     plt.figure(1, figsize=(15, 7), dpi=300)
#     plt.subplot(121)
#     nx.draw_kamada_kawai(state_complex, width=1.0, alpha=0.3)
#     nx.draw_kamada_kawai(state_complex, nodelist=[state_num], width=1.0, alpha=1.0)
#     plt.subplot(122)
#     nx.draw_networkx(world_graph, pos=node_positions, node_color=node_colours, with_labels=False)
#     plt.axis('off')
#     plt.savefig('gif\state%s.png'% n, dpi=300)
#     plt.show()
    
#     neighbours = []
#     for k in state_complex.neighbors(state_num):
#         neighbours.append(k)
        
#     state_num = neighbours[np.random.randint(0,len(neighbours))]


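# create a GIF of the random walk (currently disabled; also needs the commented-out glob / moviepy / natsort imports at the top of the file)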
    
# gif_name = 'random_walk'
# fps = 1
# file_list = glob.glob('gif\*.png') # Get all the pngs in the gif directory
# file_list = natsorted(file_list) # Sort the images by number
# clip = mpy.ImageSequenceClip(file_list, fps=fps)
# clip.write_gif('{}.gif'.format(gif_name), fps=fps)