import networkx as nx
import matplotlib.pyplot as plt
import localgraphclustering as lgc
from scipy.io import loadmat
import numpy as np
import time
import seaborn as sns


# ---------------------------------------------------------------------------
# Experiment: compare ECD (evo_cut_directed) with ECD+Our (new_evo_cut_directed)
# on directed two-block stochastic block models of increasing size, recording
# the best flow ratio over NUM_TRIALS runs and the total run time of those runs.
# ---------------------------------------------------------------------------
no_nodes = [500, 700, 1000, 1100]  # vertices per block; each graph has 2*n vertices
ETA = 0.8         # edge probability block 0 -> block 1 (block 1 -> block 0 uses 1 - ETA)
NUM_TRIALS = 5    # runs per algorithm; the lowest flow ratio is kept
SEED_NODES = [5]  # seed list passed unchanged to both clustering routines
edge_list_file = "directed_sbm_edgelist.edgelist"

pt_time = []   # ECD run times (seconds), one entry per element of no_nodes
our_time = []  # ECD+Our run times (seconds)
pt_flr = []    # ECD best flow ratios
our_flr = []   # ECD+Our best flow ratios


def _best_flow_ratio(cluster_fn, graph, n, num_trials=NUM_TRIALS):
    """Run ``cluster_fn`` ``num_trials`` times and keep the lowest flow ratio.

    Parameters
    ----------
    cluster_fn : callable
        ``lgc.find_bipartite_clusters.evo_cut_directed`` or
        ``new_evo_cut_directed``; called as
        ``cluster_fn(graph, SEED_NODES, 0.1, T=2)`` and expected to return a
        3-tuple ``(L, R, extra)``.
    graph : lgc.GraphLocal
        Semi-double-cover graph built from the 2*n-vertex SBM.
    n : int
        Vertices per SBM block.
    num_trials : int
        Number of repeated runs.

    Returns
    -------
    tuple
        ``(best_flow_ratio, best_L, best_R, elapsed_seconds)``; the elapsed
        time covers all trials and is measured with ``time.perf_counter``.
    """
    best_fr = best_L = best_R = None
    start = time.perf_counter()
    for _ in range(num_trials):
        L, R, _extra = cluster_fn(graph, SEED_NODES, 0.1, T=2)
        # Assumes the semi double cover stores the second copy of original
        # vertex v at index v + 2*n (the original graph has 2*n vertices).
        # TODO confirm against GraphLocal(semi_double_cover=True).
        flow_ratio = graph.compute_conductance(L + [v + 2 * n for v in R])
        if best_fr is None or flow_ratio < best_fr:
            best_fr, best_L, best_R = flow_ratio, L, R
    elapsed = time.perf_counter() - start
    return best_fr, best_L, best_R, elapsed


for ind, n in enumerate(no_nodes):
    # Sparse inside each block (about 9 expected out-edges per vertex to its
    # own block), dense across blocks with probabilities ETA and 1 - ETA.
    prob_matrix = [[9 / n, ETA], [1 - ETA, 9 / n]]
    G = nx.stochastic_block_model([n, n], prob_matrix, directed=True)
    # data=False writes bare "u v" lines without edge attributes.
    # NOTE(review): write_edgelist only writes edges, so an isolated vertex
    # would be missing from the file; the loaded graph could then have fewer
    # than 2*n vertices and the v + 2*n offset above would be wrong. Unlikely
    # at these densities, but worth confirming.
    nx.write_edgelist(G, edge_list_file, data=False)
    directed_sbm_semi_DC = lgc.GraphLocal(edge_list_file, 'edgelist', separator=' ', semi_double_cover=True)

    print(f" Iteration: {ind + 1} ")
    print("--------------------------------------------------------------------------------------------------")
    print(f"ECD algo with {2*n} nodes")
    best_fr, best_L, best_R, elapsed = _best_flow_ratio(
        lgc.find_bipartite_clusters.evo_cut_directed, directed_sbm_semi_DC, n)
    print(f"Flow Ratio: {best_fr:.3f}")
    # ".4f" = four decimals (the previous ": 4f" spec meant width 4, six decimals).
    print(f"Time taken : {elapsed:.4f} secs")
    pt_flr.append(best_fr)
    pt_time.append(elapsed)

    print(" ")
    print(f"ECD+Our algo with {2*n} nodes")
    best_fr, best_L, best_R, elapsed = _best_flow_ratio(
        lgc.find_bipartite_clusters.new_evo_cut_directed, directed_sbm_semi_DC, n)
    print(f"Flow Ratio: {best_fr:.3f}")
    print(f"Time taken : {elapsed:.4f} secs")
    our_flr.append(best_fr)
    our_time.append(elapsed)

    print("---------------------------------------------------------------------------------------------------------------")
    print(" ")

sns.set(style="whitegrid", context="talk")

# Eight distinct hues from seaborn's HUSL palette; the run-time plot uses
# colors[0:2] and the flow-ratio plot uses colors[2:4].
colors = sns.color_palette("husl", 8)


def _plot_comparison(x_values, our_values, pt_values, ylabel, filename,
                     our_color, pt_color, ylim=None):
    """Plot ECD+Our against ECD over ``x_values`` and save the figure.

    Parameters
    ----------
    x_values : sequence
        Vertices per SBM partition, shared by both curves.
    our_values, pt_values : sequence
        Y values for ECD+Our (solid, circles) and ECD (dashed, squares).
    ylabel : str
        Y-axis label.
    filename : str
        Output image path; written at 600 DPI.
    our_color, pt_color :
        Matplotlib colors for the two curves.
    ylim : tuple or None
        Optional ``(bottom, top)`` y-axis limits.
    """
    fig = plt.figure(figsize=(10, 6))  # large figure for readability
    plt.plot(x_values, our_values, label="ECD+Our", color=our_color, marker='o', linestyle='-', linewidth=2.5, markersize=8)
    plt.plot(x_values, pt_values, label="ECD", color=pt_color, marker='s', linestyle='--', linewidth=2.5, markersize=8)

    plt.xlabel(r"Number of Vertices in Each Partition", fontsize=18, fontweight='bold')
    plt.ylabel(ylabel, fontsize=18, fontweight='bold')
    if ylim is not None:
        plt.ylim(*ylim)

    plt.legend(fontsize=16, loc='best')
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.savefig(filename, dpi=600)  # 600 DPI for high-quality print
    plt.close(fig)


# NOTE: the "eta0.8" in these filenames mirrors the 0.8 inter-block edge
# probability used to build the SBMs; keep them in sync if that changes.
# First plot: run-time comparison.
_plot_comparison(no_nodes, our_time, pt_time, r"Run Time (Seconds)",
                 'Runtime_comparison_eta0.8.png', colors[0], colors[1])

# Second plot: flow-ratio comparison; flow ratios are plotted on [0, 1].
_plot_comparison(no_nodes, our_flr, pt_flr, r"Flow Ratio",
                 'temp_flr_comparison_eta0.8.png', colors[2], colors[3],
                 ylim=(0, 1))




# NOTE(review): the triple-quoted string below is an earlier, disabled version
# of the two plots above. It is a bare string literal, so none of it executes.
# If re-enabled, it would overwrite 'Runtime_comparison_eta0.8.png' (also
# written above) and write 'flr_comparison_eta0.8.png', both at dpi=2000,
# without the seaborn styling. Consider deleting it; version control already
# keeps the history.
'''
plt.figure()
plt.plot(no_nodes, our_time, label="ECD+Our")
plt.plot(no_nodes, pt_time, label="ECD")

# Adding labels and title
plt.xlabel("Number of Vertices in each partition")
plt.ylabel("Run time in Seconds")
#plt.title("Plot with Two Lines")

# Adding a legend
plt.legend()

# Display the graph
plt.savefig('Runtime_comparison_eta0.8.png', dpi = 2000)
plt.close()


plt.figure()
plt.plot(no_nodes, our_flr, label="ECD+Our")
plt.plot(no_nodes, pt_flr, label="ECD")

# Adding labels and title
plt.xlabel("Number of Vertices in each partition")
plt.ylabel("Flowratio")
#plt.title("Plot with Two Lines")
plt.ylim(0,1)
# Adding a legend
plt.legend()

# Display the graph
plt.savefig('flr_comparison_eta0.8.png', dpi = 2000)
plt.close()
'''