import osmnx as ox
import numpy as np
import random
import networkx as nx
import pickle
from grakel import Graph

# Disable osmnx's local cache of HTTP responses: every graph_from_point call
# below is sent to the OSM API instead of being answered from disk.
ox.settings.use_cache = False

def sample_RN(city, outer_radius, inner_radius, num_samples):
    '''
    Sample small road networks centred on random nodes of a larger one.

    A background network of radius `outer_radius` is downloaded around
    `city`. Then `num_samples` distinct nodes of it are drawn uniformly
    without replacement, and a fresh network of radius `inner_radius` is
    downloaded around each drawn node's coordinates.

    city: center point passed to ``ox.graph_from_point`` (osmnx expects (lat, lon)).
    outer_radius: ``dist`` of the background query.
    inner_radius: ``dist`` of each sampled query.
    num_samples: how many centers to draw. ``np.random.choice`` raises
        ValueError if this exceeds the number of background nodes.

    Returns a list of osmnx graphs, one per sampled center, in draw order.
    '''
    background = ox.graph_from_point(city, dist=outer_radius, network_type='all', simplify=True)
    candidates = list(background.nodes)
    chosen = np.random.choice(candidates, num_samples, replace=False)

    def _local_network(node):
        # Re-query OSM around the node; coordinates are stored as y=lat, x=lon.
        attrs = background.nodes[node]
        return ox.graph_from_point(
            (attrs['y'], attrs['x']), dist=inner_radius, network_type='all', simplify=True
        )

    return [_local_network(node) for node in chosen]

# define function to subdivide edges
def edge_subdivision(G, unit):
    '''
    Refine the combinatorial model by edge subdivision.

    Every edge whose 'length' attribute is longer than `unit` is replaced
    by a chain of ``div = ceil(length / unit)`` edges. The ``div - 1`` new
    interior vertices are placed at evenly spaced linear interpolations of
    the endpoints' 'x'/'y' coordinates. Edges no longer than `unit` are
    kept unchanged. The input graph is not modified; a subdivided copy is
    returned.

    Interior vertices are named ``"{u}-{v}-{i}"`` (simple graphs) or
    ``"{u}-{v}-{key}-{i}"`` (multigraphs). Parallel edges and the two
    directions of a directed two-way street therefore each get their own
    chain.

    Each new segment receives ``length=unit``, so a subdivided edge's total
    weight is ``div * unit``. That can exceed the original length by less
    than `unit`, while the vertices are geometrically spaced ``length/div``
    apart. NOTE(review): confirm that this uniform weight is intended rather
    than ``length / div``.

    G: networkx graph (simple or multi, directed or not) whose edges carry
       'length' and whose nodes carry 'x' and 'y'.
    unit: maximum edge length after subdivision.
    '''
    H = G.copy()  # work on a copy so G is left untouched
    is_multi = H.is_multigraph()

    # Collect changes first; the graph must not be mutated while iterating its edges.
    new_nodes = []
    new_edges = []
    remove_edges = []

    if is_multi:
        # Keys are needed to name and remove one specific parallel edge.
        edge_iter = H.edges(keys=True, data=True)
    else:
        edge_iter = ((u, v, None, d) for u, v, d in H.edges(data=True))

    for u, v, key, data in edge_iter:
        length = data['length']
        div = int(np.ceil(length / unit))
        # div <= 1 means the edge is already at most `unit` long (this also covers
        # length == unit, which previously re-added and then deleted the edge).
        if div < 2:
            continue

        ux, uy = H.nodes[u]['x'], H.nodes[u]['y']
        vx, vy = H.nodes[v]['x'], H.nodes[v]['y']
        prefix = f"{u}-{v}-{key}" if is_multi else f"{u}-{v}"

        previous = u
        for i in range(1, div):
            t = i / div
            name = f"{prefix}-{i}"
            new_nodes.append((name, {'x': (1 - t) * ux + t * vx, 'y': (1 - t) * uy + t * vy}))
            new_edges.append((previous, name, unit))
            previous = name

        # Close the chain at the original endpoint v.
        new_edges.append((previous, v, unit))
        remove_edges.append((u, v, key) if is_multi else (u, v))

    H.add_nodes_from(new_nodes)
    H.add_weighted_edges_from(new_edges, weight='length')
    # New edges always touch an interior vertex, so this removes only original edges.
    H.remove_edges_from(remove_edges)

    return H

def save_grakel_graphs(class_list, path):
    '''
    Save the graphs in grakel format, shuffled together with their labels.

    class_list: list of lists of graphs. Each inner list is one class, and
        its graphs receive that class's index in `class_list` as label.
    path: string prefix for the output files ``path + "networks.pkl"`` and
        ``path + "labels.pkl"``. Include a trailing separator when `path`
        is a directory.

    Each graph is converted to a grakel Graph from its dense adjacency
    matrix weighted by the 'length' edge attribute. The (graph, label) pairs
    are shuffled jointly with ``random.shuffle``, then two parallel tuples
    are pickled. With no graphs at all, two empty tuples are written.
    '''
    # Build (graph, label) pairs in one pass so graphs and labels cannot drift apart.
    samples = []
    for label, G_list in enumerate(class_list):
        for G in G_list:
            adjacency = nx.adjacency_matrix(G, weight='length').toarray()
            samples.append((Graph(adjacency), label))

    # Shuffle the dataset and labels together.
    random.shuffle(samples)
    dataset = tuple(graph for graph, _ in samples)
    labels = tuple(label for _, label in samples)

    # Save the networks and their labels.
    with open(path + "networks.pkl", "wb") as f:
        pickle.dump(dataset, f)
    with open(path + "labels.pkl", "wb") as f:
        pickle.dump(labels, f)


######### to-use functions #########

def _single_edge_length(H, a, b):
    '''Return the 'length' of one edge a -> b of H, or None if no such edge exists.'''
    data = H.get_edge_data(a, b)
    if data is None:
        return None
    if H.is_multigraph():
        # Multigraphs map edge key -> attribute dict; don't assume the key is 0.
        data = next(iter(data.values()))
    return data['length']


def remove_interior_nodes(G, center, remove_ratio):
    '''
    Remove nodes of degree 2 that are close to the center.

    Degree-2 nodes are ranked by the Euclidean distance between their
    (y, x) coordinates and `center` (given in the same (y, x) order). The
    closest ``int(count * remove_ratio)`` of them are candidates. Each
    candidate with two distinct neighbours u, v is removed, and its two
    edges are replaced by one u-v edge whose 'length' is their sum.

    A candidate is skipped (left in place) when:
    it has a self-loop; it does not have exactly two neighbours; either
    edge u -> node or node -> v is missing; or u and v are already adjacent
    (to avoid creating parallel edges).

    The second-to-last case arises on directed graphs, where
    ``neighbors`` yields successors only.

    The input graph is not modified; a new graph is returned. The merged
    edge carries only the 'length' attribute.
    '''
    H = G.copy()

    # Distance of every degree-2 node to the center (degrees taken before any removal).
    deg2_dist = {}
    for node, coord in H.nodes(data=True):
        if H.degree[node] == 2:
            offset = np.array([coord["y"] - center[0], coord["x"] - center[1]])
            deg2_dist[node] = np.linalg.norm(offset)

    # Sort the nodes by distance to the center and keep the closest fraction.
    sorted_nodes = sorted(deg2_dist, key=deg2_dist.get)
    remove_n = int(len(sorted_nodes) * remove_ratio)

    for node in sorted_nodes[:remove_n]:
        # Skip nodes with self-loops (special case).
        if H.has_edge(node, node):
            continue

        neighbours = list(H.neighbors(node))
        if len(neighbours) != 2:
            continue

        u, v = neighbours
        length1 = _single_edge_length(H, u, node)
        length2 = _single_edge_length(H, node, v)
        if length1 is None or length2 is None:
            continue

        if H.has_edge(u, v):
            continue

        # Replace the two edges through `node` by a single edge of summed length.
        # Degrees of u and v are unchanged, so later candidates stay degree 2.
        H.remove_node(node)
        H.add_edge(u, v, length=length1 + length2)
    return H

def max_neighbor_distance(G):
    '''
    Map every node to the largest Euclidean distance to any of its neighbours.

    Distances use the raw (y, x) node attributes. Neighbours are those
    returned by ``G.neighbors`` (successors on a directed graph). Nodes
    without neighbours map to 0. Keys follow the iteration order of
    ``G.nodes``.
    '''
    position = {n: [attrs['y'], attrs['x']] for n, attrs in G.nodes(data=True)}

    result = {}
    for node in G.nodes:
        nbrs = list(G.neighbors(node))
        if nbrs:
            here = np.array(position[node])
            result[node] = max(np.linalg.norm(here - np.array(position[nb])) for nb in nbrs)
        else:
            result[node] = 0  # isolated node
    return result

def remove_dense_nodes(G, remove_ratio=0.1):
    '''
    Remove the fraction of nodes whose neighbours are all closest to them.

    Each node is scored by ``max_neighbor_distance`` (isolated nodes score
    0). The ``int(n * remove_ratio)`` lowest-scoring nodes are dropped.

    Returns a new, independent (mutable) graph induced on the remaining
    nodes; G itself is not modified. Previously this was a frozen subgraph
    view of a full copy, which raised on modification and was inconsistent
    with ``prune_graph``.
    '''
    # Compute max distance to neighbors for each node.
    max_dists = max_neighbor_distance(G)

    # Sort nodes by max distance to neighbors (ties keep graph order).
    sorted_nodes = sorted(max_dists, key=max_dists.get)

    # Determine how many nodes to retain.
    remove_n = int(len(sorted_nodes) * remove_ratio)
    keep_nodes = sorted_nodes[remove_n:]

    # .copy() materializes the induced subgraph instead of returning a frozen view.
    return G.subgraph(keep_nodes).copy()

def remove_close_nodes(G, center, remove_ratio=0.1):
    '''
    Remove the fraction of nodes closest to `center`.

    center: point in (y, x) order, matching ``[data['y'], data['x']]``.
    Distance is plain Euclidean distance on the raw coordinates, used as
    an approximation. The ``int(n * remove_ratio)`` closest nodes are
    dropped.

    Returns a new, independent (mutable) graph induced on the remaining
    nodes; G itself is not modified (previously a frozen subgraph view of a
    full copy was returned).
    '''
    # Compute distance from each node to center (Euclidean, as an approximation).
    distances = {
        node: np.linalg.norm(np.array([data['y'], data['x']]) - center)
        for node, data in G.nodes(data=True)
    }

    # Sort nodes by distance (ties keep graph order).
    sorted_nodes = sorted(distances, key=distances.get)

    # Determine how many nodes to retain.
    remove_n = int(len(sorted_nodes) * remove_ratio)
    keep_nodes = sorted_nodes[remove_n:]

    # .copy() materializes the induced subgraph instead of returning a frozen view.
    return G.subgraph(keep_nodes).copy()

def prune_graph(G, retain_ratio=0.7):
    '''
    Keep only the highest-degree fraction of nodes.

    Nodes are ranked by degree, highest first (ties keep graph order), and
    the top ``int(n * retain_ratio)`` are kept. Returns an independent copy
    of the subgraph they induce.
    '''
    degree_of = dict(G.degree())
    ranked = sorted(degree_of.items(), key=lambda item: item[1], reverse=True)
    cutoff = int(len(ranked) * retain_ratio)
    survivors = {node for node, _ in ranked[:cutoff]}
    return G.subgraph(survivors).copy()