import logging
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import torch


def compute_ave__shortest_path_len(adj_nx):
    """Return the mean of the per-component average shortest path lengths.

    Every connected component of ``adj_nx`` is copied into its own graph and
    ``nx.average_shortest_path_length`` is evaluated on it; the per-component
    values are then averaged with equal weight.

    Args:
        adj_nx: NetworkX graph accepted by ``nx.connected_components``.

    Returns:
        A 0-dimensional ``torch.Tensor`` holding the mean value.
    """
    per_component_lengths = [
        nx.average_shortest_path_length(adj_nx.subgraph(nodes).copy())
        for nodes in nx.connected_components(adj_nx)
    ]
    return torch.tensor(per_component_lengths).mean()

def compute_properties(samples, property_name):
    """Compute one scalar graph property for every adjacency matrix in ``samples``.

    Supported values of ``property_name``:

    * ``"average_degree"``: sum of all adjacency entries divided by the
      number of rows.
    * ``"num_triads"``: ``trace(A^3) / 6``.
    * ``"average_shortest_path_len"``: mean over connected components of the
      average shortest path length.
    * ``"ave_clustering_coeff"``: ``nx.average_clustering`` of the graph.
    * ``"max_cliques"``: number of nodes in the largest maximal clique.

    Args:
        samples: Iterable of square adjacency matrices, each either a
            ``np.ndarray`` or a ``torch.Tensor``.
        property_name: Name of the property to evaluate (see list above).

    Returns:
        A 1-D ``np.ndarray`` with one entry per sample. For an unsupported
        ``property_name`` a message is logged and printed for each sample and
        nothing is collected, so the returned array is empty.
    """
    # Only these properties need a NetworkX graph; the others work directly
    # on the tensor, so the (comparatively slow) conversion is skipped for them.
    needs_graph = property_name in (
        "average_shortest_path_len",
        "ave_clustering_coeff",
        "max_cliques",
    )
    result = []

    for adj in samples:
        if isinstance(adj, np.ndarray):
            adj = torch.from_numpy(adj)

        # from_numpy_array matches the rest of this file; from_numpy_matrix
        # no longer exists in NetworkX 3.x.
        adj_nx = nx.from_numpy_array(adj.numpy()) if needs_graph else None

        if property_name == "average_degree":
            ave_deg = torch.sum(adj) / len(adj)
            result.append(ave_deg.item())
        elif property_name == "num_triads":
            # For a symmetric 0/1 matrix with zero diagonal, trace(A^3) counts
            # every triangle six times (3 start nodes x 2 directions).
            closed_walks = torch.diag(torch.linalg.matrix_power(adj, 3))
            result.append((torch.sum(closed_walks) / 6).item())
        elif property_name == "average_shortest_path_len":
            ave_shortest_path_len = compute_ave__shortest_path_len(adj_nx)
            result.append(ave_shortest_path_len.item())
        elif property_name == "ave_clustering_coeff":
            result.append(nx.average_clustering(adj_nx))
        elif property_name == "max_cliques":
            result.append(len(max(nx.find_cliques(adj_nx), key=len)))
        else:
            logging.info("%s un implemented property", property_name)
            print(f"{property_name} is unimplemnted")
    return np.array(result)


def compute_all_properties(samples,properties):
    """Build a matrix of graph properties, one row per sample.

    Args:
        samples: Sequence of adjacency matrices passed on to
            ``compute_properties``.
        properties: Sequence of property names; entry ``j`` fills column ``j``.

    Returns:
        ``np.ndarray`` of shape ``(len(samples), len(properties))`` whose
        column ``j`` is ``compute_properties(samples, properties[j])``.
    """
    Z = np.ones((len(samples), len(properties)))
    for column, name in enumerate(properties):
        Z[:, column] = compute_properties(samples, name)
    return Z
# This is the function we used for graph generation; it is needed for the oracle model.
def generate_random_graphs(num_graphs_to_generate,seed=1, plot = False,n_nodes = 10, p_edge = 0.5):
    """Generate Erdos-Renyi graphs and return their adjacency matrices.

    Graph number ``i`` is drawn with
    ``nx.erdos_renyi_graph(n_nodes, p_edge, seed=i * seed)``.

    Args:
        num_graphs_to_generate: How many graphs to create.
        seed: Multiplier applied to the graph index to obtain its seed.
        plot: When equal to ``True``, the first five graphs are drawn with
            matplotlib (spring layout seeded with the graph index).
        n_nodes: Number of nodes per graph.
        p_edge: Edge probability.

    Returns:
        List of ``torch.Tensor`` adjacency matrices, one per graph.
    """
    adjacency_tensors = []
    for idx in range(num_graphs_to_generate):
        er_graph = nx.erdos_renyi_graph(n_nodes, p_edge, seed=idx * seed)
        dense = nx.to_numpy_array(er_graph)
        adjacency_tensors.append(torch.from_numpy(dense))
        if plot == True and idx < 5:
            # Seeded layout keeps the drawing reproducible.
            layout = nx.spring_layout(er_graph, seed=idx)
            nx.draw_networkx(er_graph, pos=layout)
            plt.show()
    return adjacency_tensors
    
    
  
def add_random_edges(adj_matrix, num_edges_to_add):
    """Return the adjacency matrix of the graph with random edges added.

    Each new edge joins two distinct nodes that are not yet adjacent. A node
    pair is drawn with ``random.choice`` and redrawn until it is not already
    an edge.

    Args:
        adj_matrix: Square adjacency matrix, either a ``np.ndarray`` or an
            object exposing ``.numpy()`` (e.g. a ``torch.Tensor``).
        num_edges_to_add: Number of new edges to insert.

    Returns:
        ``np.ndarray`` of integer dtype with the adjacency matrix of the
        modified graph. The input is left untouched.

    Raises:
        ValueError: If the graph has fewer than ``num_edges_to_add`` pairs of
            distinct, non-adjacent nodes.
    """
    if isinstance(adj_matrix, np.ndarray):
        graph = nx.from_numpy_array(adj_matrix)
    else:
        graph = nx.from_numpy_array(adj_matrix.numpy())

    # The redraw loop below cannot terminate once every node pair is already
    # connected, so reject impossible requests before sampling.
    free_pairs = sum(1 for _ in nx.non_edges(graph))
    if num_edges_to_add > free_pairs:
        raise ValueError(
            f"cannot add {num_edges_to_add} edge(s): only {free_pairs} "
            "non-adjacent node pair(s) available"
        )

    for _ in range(num_edges_to_add):
        while True:
            candidates = list(graph.nodes)
            source_node = random.choice(candidates)
            # Dropping the source from the pool rules out self-loops.
            candidates.remove(source_node)
            target_node = random.choice(candidates)
            if not graph.has_edge(source_node, target_node):
                break
        graph.add_edge(source_node, target_node)

    return nx.to_numpy_array(graph, dtype=int)
    
    
def remove_random_edges(adj_matrix, num_edges_to_remove):
    """Return the adjacency matrix of the graph with random edges removed.

    The edges to drop are chosen without replacement via ``np.random.choice``
    over the indices of the graph's edge list.

    Args:
        adj_matrix: Square adjacency matrix, either a ``np.ndarray`` or an
            object exposing ``.numpy()`` (e.g. a ``torch.Tensor``).
        num_edges_to_remove: Number of existing edges to delete.

    Returns:
        ``np.ndarray`` of integer dtype with the adjacency matrix of the
        modified graph.
    """
    as_array = adj_matrix if type(adj_matrix) == np.ndarray else adj_matrix.numpy()
    graph = nx.from_numpy_array(as_array)

    edge_list = list(graph.edges)
    chosen = np.random.choice(len(edge_list), num_edges_to_remove, replace=False)
    for idx in chosen:
        graph.remove_edge(*edge_list[idx])

    return nx.to_numpy_array(graph, dtype=int)
    
    
# `dataset` is expected to be a list of adjacency matrices.
def _corrupt_once(adj_matrix):
    """Add or remove (fair coin flip) one random edge; return the result as a tensor."""
    if np.random.randint(2) == 0:
        modified_adj_matrix = add_random_edges(adj_matrix, 1)
    else:
        modified_adj_matrix = remove_random_edges(adj_matrix, 1)
    return torch.from_numpy(modified_adj_matrix)


def corrupt_graphs(dataset, n_samples, prev_dataset=None):
    """Produce ``n_samples`` corrupted graphs from ``dataset``.

    Every graph in ``dataset`` is corrupted once (one random edge added or
    removed) and pooled together with ``prev_dataset``. If the pool already
    holds at least ``n_samples`` graphs, ``n_samples`` of them are drawn
    without replacement. Otherwise randomly chosen graphs from ``dataset``
    are corrupted and appended until the pool reaches ``n_samples``.

    Args:
        dataset: Sequence of adjacency matrices (``np.ndarray`` or
            ``torch.Tensor``).
        n_samples: Number of corrupted graphs to return.
        prev_dataset: Optional list of previously corrupted graphs to include
            in the pool. It is copied, so the caller's list is not modified.

    Returns:
        List of ``n_samples`` graphs; newly corrupted ones are
        ``torch.Tensor`` adjacency matrices.

    Raises:
        ValueError: If more samples are needed than the pool holds and
            ``dataset`` is empty.
    """
    # Copy so that appending below never leaks into the caller's list.
    modified_dataset = [] if prev_dataset is None else list(prev_dataset)
    for adj_matrix in dataset:
        modified_dataset.append(_corrupt_once(adj_matrix))

    if n_samples <= len(modified_dataset):
        return random.sample(modified_dataset, n_samples)

    n_train = len(dataset)
    if n_train == 0:
        raise ValueError("cannot draw additional samples from an empty dataset")

    # Index the dataset directly instead of stacking it into one array, so
    # graphs of different sizes (or tensors) are handled as well.
    while len(modified_dataset) < n_samples:
        chosen_graph_idx = np.random.randint(n_train, size=1)
        modified_dataset.append(_corrupt_once(dataset[int(chosen_graph_idx[0])]))

    return modified_dataset


def n_community(p,n,num_communities, max_nodes, p_inter=0.05):
    """Build a random graph made of several bridged communities.

    Community graph construction from
    https://github.com/ermongroup/GraphScoreMatching/blob/master/utils/data_generators.py#L10

    Each community is a ``nx.gnp_random_graph`` with
    ``max_nodes // num_communities`` nodes and edge probability ``p``. The
    communities are joined with ``nx.disjoint_union_all`` and the connected
    components of that union are then bridged pairwise. A community block
    that is itself disconnected therefore contributes several components,
    each of which takes part in the bridging.

    ``p_make_a_bridge`` is chosen so that
    ``p_inter = E[number of bridge edges] / total number of nodes``.
    With ``M = num_communities`` and ``N = one community size``::

        p_inter
        = E[number of bridge edges] / total number of nodes
        = (p_make_a_bridge * C(M, 2) * N^2) / (M * N)
        = p_make_a_bridge * (M - 1) * N / 2

    Args:
        p: Edge probability inside each community.
        n: Only used to size the range ``range(0, num_communities * n)`` from
            which the per-community seeds are sampled.
        num_communities: Number of communities; must be greater than 1.
        max_nodes: Upper bound on the total node count.
        p_inter: Target ratio of expected bridge edges to total nodes.

    Returns:
        The resulting ``networkx`` graph.
    """
    assert num_communities > 1

    community_size = max_nodes // num_communities
    c_sizes = [community_size] * num_communities
    p_make_a_bridge = p_inter * 2 / ((num_communities - 1) * community_size)

    seeds = random.sample(range(0, len(c_sizes) * n), len(c_sizes))
    parts = [
        nx.gnp_random_graph(size, p, seed=part_seed)
        for size, part_seed in zip(c_sizes, seeds)
    ]
    G = nx.disjoint_union_all(parts)

    # Node lists are taken from copies made before any bridge is added, so
    # the grouping is not affected by the edges inserted below.
    node_groups = [
        list(G.subgraph(component).copy().nodes())
        for component in nx.connected_components(G)
    ]

    for position, left_nodes in enumerate(node_groups):
        for right_nodes in node_groups[position + 1:]:  # every unordered pair of groups
            bridged = False
            for u in left_nodes:
                for v in right_nodes:
                    if np.random.rand() < p_make_a_bridge:
                        G.add_edge(u, v)
                        bridged = True
            if not bridged:
                # Guarantee at least one bridge between the two groups.
                G.add_edge(left_nodes[0], right_nodes[0])
    return G

def generate_n_comm_graphs(n_graphs, p, n_start=12, n_end=20):
    """Generate ``n_graphs`` two-community graphs with random node budgets.

    For every graph a node budget is drawn uniformly from
    ``range(n_start, n_end)`` with ``np.random.choice`` and passed to
    ``n_community`` as ``max_nodes`` (with ``num_communities=2`` and
    ``p_inter=0.05``).

    Args:
        n_graphs: Number of graphs to generate; also forwarded to
            ``n_community`` as its ``n`` argument.
        p: Edge probability inside each community.
        n_start: Smallest node budget (inclusive).
        n_end: Largest node budget (exclusive).

    Returns:
        List of the generated graphs.
    """
    # The candidate budgets do not change between graphs, so build them once.
    node_budgets = np.arange(n_start, n_end).tolist()
    graph_list = []
    for _ in range(n_graphs):
        n_max = np.random.choice(node_budgets)
        G = n_community(p, n_graphs, num_communities=2, max_nodes=n_max, p_inter=0.05)
        graph_list.append(G)
    return graph_list
