import networkx as nx
import matplotlib.pyplot as plt
import scipy
import localgraphclustering as lgc
import stag.graph
import stag.random
import stag.graphio
import stag.cluster
import numpy as np
import time
import numpy.random as npr
import seaborn as sns
import math

def pr_for_sparsification(v, n, nx_graph, Pr_Matrix):
    """Fill the sparsification values of the edges at *v* and sample one of them.

    For every neighbour ``u`` of ``v`` the edge value is

        (1/d(u) + 1/d(v) - 1/(d(u) * d(v))) * 3 * log(n) ** 3

    where ``1/d(u) + 1/d(v) - 1/(d(u) d(v))`` equals
    ``1 - (1 - 1/d(u)) * (1 - 1/d(v))``. The value is not clipped to [0, 1].
    Values are cached symmetrically in ``Pr_Matrix``; an entry equal to 0 is
    treated as "not yet computed".

    Args:
        v: vertex whose incident edges are processed.
        n: number of vertices, used in the ``log(n) ** 3`` scaling factor.
        nx_graph: graph object providing ``neighbors(x)`` and ``degree(x)``
            (a NetworkX graph in this script).
        Pr_Matrix: 2-D array indexed ``[u, v]``; updated in place.

    Returns:
        One of the incident-edge values, chosen uniformly at random with
        ``np.random.choice``.

    Raises:
        ValueError: if ``v`` has no neighbours (nothing to sample from).
    """
    # Loop-invariant scaling factor.
    scale = 3 * math.log(n) ** 3
    deg_v = nx_graph.degree(v)
    pr = []
    for u in nx_graph.neighbors(v):
        if Pr_Matrix[u, v] == 0:  # Only compute if not already cached
            deg_u = nx_graph.degree(u)
            # The whole inclusion-exclusion term is scaled. Previously only the
            # 1/(d(u) d(v)) term was multiplied, which made typical values negative.
            temp_pr = (1 / deg_u + 1 / deg_v - 1 / (deg_u * deg_v)) * scale
            Pr_Matrix[u, v] = temp_pr
            Pr_Matrix[v, u] = temp_pr
            pr.append(temp_pr)
        else:
            pr.append(Pr_Matrix[u, v])  # Use the cached value

    if not pr:
        raise ValueError(f"vertex {v} has no neighbours; nothing to sample from")
    return np.random.choice(pr)



def simplify(num_g_vertices: int, sparse_vector):
    """
    Compute the 'simplified' approximate pagerank vector on the double cover.

    Vertex ``i`` of the original graph corresponds to entries ``i`` and
    ``i + num_g_vertices`` of ``sparse_vector``. The vector is presumably an
    approximate pagerank vector on the double cover. For every such pair the
    smaller value is subtracted from both entries, so at most one entry of each
    pair stays non-zero:

        new[i]     = max(v[i] - v[i + n], 0)
        new[i + n] = max(v[i + n] - v[i], 0)

    Entries past the end of ``sparse_vector`` are treated as zero. A vector
    shorter than ``2 * num_g_vertices`` therefore keeps the entries whose
    counterpart is missing instead of dropping them. Entries at index
    ``2 * num_g_vertices`` or beyond are ignored.

    Args:
        num_g_vertices: number of vertices ``n`` in the original graph.
        sparse_vector: column vector of shape ``(m, 1)``, scipy sparse or dense.

    Returns:
        A ``scipy.sparse.csc_matrix`` of shape ``(2 * n, 1)``.
    """
    n = num_g_vertices
    if scipy.sparse.issparse(sparse_vector):
        values = sparse_vector.toarray().ravel()
    else:
        values = np.asarray(sparse_vector, dtype=float).ravel()

    # Pad or truncate to exactly 2n entries; missing trailing entries are zero.
    padded = np.zeros(2 * n)
    length = min(values.shape[0], 2 * n)
    padded[:length] = values[:length]

    diff = padded[:n] - padded[n:]
    simplified = np.concatenate((np.maximum(diff, 0.0), np.maximum(-diff, 0.0)))
    return scipy.sparse.csc_matrix(simplified.reshape(-1, 1))


def local_bipart_dc(g: stag.graph.Graph, start_vertex: int, alpha: float, eps: float):
    """
    Run the LocBipartDC procedure on ``g`` from ``start_vertex`` using STAG.

    Steps:
      1. Build the double cover of ``g``, whose adjacency matrix is
         ``[[0, A], [A, 0]]``; vertex ``i`` of ``g`` becomes double-cover
         vertices ``i`` and ``i + n``.
      2. Run ``stag.cluster.approximate_pagerank`` on the double cover, seeded
         with unit mass at ``start_vertex``, with the given ``alpha`` and ``eps``.
      3. Simplify the pagerank vector and take the sweep set returned by
         ``stag.cluster.sweep_set_conductance``.

    Returns:
        ``(this_cluster, that_cluster, bipartiteness)``:
          - ``this_cluster``: sweep-set vertices in the first half of the double cover.
          - ``that_cluster``: sweep-set vertices in the second half, mapped back
            to ``g``'s vertex ids.
          - ``bipartiteness``: the conductance of the sweep set in the double cover.
    """
    num_vertices = g.number_of_vertices()

    # Double cover: off-diagonal copies of A, all-zero diagonal blocks.
    adjacency = g.adjacency().to_scipy()
    zero_block = scipy.sparse.csc_matrix((num_vertices, num_vertices))
    double_cover = stag.graph.Graph(
        scipy.sparse.bmat([[zero_block, adjacency], [adjacency, zero_block]])
    )

    # Seed vector with all mass on the start vertex.
    seed = scipy.sparse.lil_matrix((double_cover.number_of_vertices(), 1))
    seed[start_vertex, 0] = 1
    pagerank, _ = stag.cluster.approximate_pagerank(double_cover, seed.tocsc(), alpha, eps)

    simplified = simplify(num_vertices, pagerank.to_scipy())

    sweep = stag.cluster.sweep_set_conductance(double_cover, simplified)
    bipartiteness = stag.cluster.conductance(double_cover, sweep)

    # First-half vertices stay on the seed's side; second-half ones map back by -n.
    this_cluster = [x for x in sweep if x < num_vertices]
    that_cluster = [x - num_vertices for x in sweep if x >= num_vertices]
    return this_cluster, that_cluster, bipartiteness


nodes = [2000, 2500, 3000, 3500, 4000, 4500, 5000]
pt_time = []   # LocBipartDC runtime on the original graph, one entry per size in `nodes`
our_time = []  # LocBipartDC runtime on the sparsified graph
pt_flr = []    # bipartiteness on the original graph
our_flr = []   # bipartiteness on the sparsified graph


def _edges_to_remove(nx_graph, n, Pr_Matrix):
    """Return the (u, v) edges selected for removal by the sparsification rule.

    For each vertex v in ``range(n)`` and each neighbour u, the edge is selected
    when ``Pr_Matrix[u, v]`` is below a threshold drawn uniformly at random from
    the values of all edges incident to v. There is one independent draw per
    edge. An edge can be listed twice, once from each endpoint;
    ``remove_edges_from`` tolerates that.
    """
    edges = []
    for v in range(n):
        neighbours = list(nx_graph.neighbors(v))
        if not neighbours:
            continue  # isolated vertex: no incident edges to consider
        # Fill Pr_Matrix[u, v] for every neighbour u BEFORE reading it. The old
        # inline comparison read the cache first, so the first uncached
        # neighbour of each vertex was compared as 0.
        pr_for_sparsification(v, n, nx_graph, Pr_Matrix)
        edge_pr = Pr_Matrix[neighbours, v]
        # One threshold per edge, drawn in a single call rather than rescanning
        # every neighbour for each edge (which cost O(deg(v)^2) per vertex).
        thresholds = np.random.choice(edge_pr, size=len(neighbours))
        edges.extend((u, v) for u, p, t in zip(neighbours, edge_pr, thresholds) if p < t)
    return edges


for ind, n in enumerate(nodes):
    G = stag.random.sbm(n, 2, 0.03, 0.3)
    nx_graph = G.to_networkx()
    # Cache of per-edge sparsification values. Only the node count is needed,
    # so skip materialising a dense adjacency matrix.
    num_nodes = nx_graph.number_of_nodes()
    Pr_Matrix = np.zeros((num_nodes, num_nodes))
    edges_to_remove = _edges_to_remove(nx_graph, n, Pr_Matrix)

    starting_vertex = 1

    # Run the STAG implementation of the bipartite clusters algorithm on the original graph.
    start_time = time.perf_counter()
    L, R, bipartiteness = local_bipart_dc(G, starting_vertex, 0.5, 4e-7)
    elapsed = time.perf_counter() - start_time  # measured before any printing
    print(f" Iteration: {ind + 1} ")
    print("--------------------------------------------------------------------------------------------------")
    print(f"LocBipartDC with {n} nodes")
    print(f"bipartiteness Ratio: {bipartiteness:.3f}")
    pt_flr.append(bipartiteness)
    print(f"Time taken : {elapsed : .4f} secs")
    pt_time.append(elapsed)
    print(" ")

    # Same algorithm on the sparsified graph.
    nx_graph.remove_edges_from(edges_to_remove)
    H = stag.graph.from_networkx(nx_graph)

    # NOTE(review): this timing covers only the clustering call on the sparsified
    # graph; the sparsification computed above is not included — confirm intended.
    start_time = time.perf_counter()
    L, R, bipartiteness = local_bipart_dc(H, starting_vertex, 0.5, 4e-7)
    elapsed = time.perf_counter() - start_time
    print(f"bipartiteness Ratio: {bipartiteness:.3f}")
    our_flr.append(bipartiteness)
    print(f"Time taken : {elapsed : .4f} secs")
    our_time.append(elapsed)

    print("---------------------------------------------------------------------------------------------------------------")
    print(" ")


sns.set(style="whitegrid", context="paper")

# Colour palette shared by both plots.
colors = sns.color_palette("deep", 8)

# Large figure size so the saved images stay legible when scaled down.
fig_width, fig_height = 12, 8


def _plot_comparison(ours, baseline, ylabel, title, filename, dpi,
                     our_colour, baseline_colour, ylim=None):
    """Plot the "MS+Our" and "MS" curves against `nodes` and save the figure.

    Args:
        ours: y-values for the "MS+Our" curve, one per entry of `nodes`.
        baseline: y-values for the "MS" curve, one per entry of `nodes`.
        ylabel: y-axis label.
        title: figure title.
        filename: output image path passed to `plt.savefig`.
        dpi: resolution of the saved image.
        our_colour: line colour for "MS+Our".
        baseline_colour: line colour for "MS".
        ylim: optional ``(low, high)`` y-axis limits.
    """
    plt.figure(figsize=(fig_width, fig_height))
    plt.plot(nodes, ours, label="MS+Our", color=our_colour, marker='o', linestyle='-', linewidth=3, markersize=10)
    plt.plot(nodes, baseline, label="MS", color=baseline_colour, marker='s', linestyle='--', linewidth=3, markersize=10)

    # NOTE(review): `nodes` is passed to stag.random.sbm as its vertex-count
    # argument; confirm whether that is the per-partition size this label claims.
    plt.xlabel("Number of Vertices in Each Partition", fontsize=20, fontweight='bold')
    plt.ylabel(ylabel, fontsize=20, fontweight='bold')
    plt.title(title, fontsize=22, fontweight='bold')
    if ylim is not None:
        plt.ylim(*ylim)

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=18, loc='best', frameon=True, fancybox=True, shadow=True)

    plt.tight_layout()
    plt.savefig(filename, dpi=dpi)
    plt.close()


# 1. Runtime comparison, saved at 600 DPI.
_plot_comparison(our_time, pt_time, "Time (Seconds)", "Runtime Comparison",
                 'runtime_comparison_undirected_increasing_nodes.png', 600,
                 colors[0], colors[1])

# 2. Bipartiteness ratio comparison, saved at 300 DPI with a fixed y-range.
_plot_comparison(our_flr, pt_flr, "Bipartiteness Ratio", "Bipartiteness Ratio Comparison",
                 'Bpr_comparison_undirected_increasing_nodes.png', 300,
                 colors[2], colors[3], ylim=(0, 1.5))