import math
import stag.graph
import stag.random
import stag.graphio
import stag.cluster
import numpy as np
import scipy.sparse as sp
import ClusterCountsUsingEigenGap
import BruteForce
import time
import matplotlib.pyplot as plt
import seaborn as sns
from statistics import mode,mean
import networkx as nx

def edge_sampling_algorithm(G, C, lambda_k_plus_1, rng=None):
    """
    Sparsify G by sampling each edge independently and reweighting the survivors.

    Each endpoint u of an edge e = (u, v) proposes e with probability

        p_u(v) = min(C * (log(n) / lambda_k_plus_1) * (w(u, v) / d(u)), 1),

    where d(u) is the weighted degree of u and n the number of vertices of G.
    The edge is kept if at least one endpoint proposes it, i.e. with
    probability p_e = p_u(v) + p_v(u) - p_u(v) * p_v(u). A kept edge gets
    weight w(u, v) / p_e, so its expected weight in H equals w(u, v).

    Parameters:
    - G: NetworkX graph; edge weights are read from the ``weight`` attribute
      (a missing weight counts as 1).
    - C: Positive constant scaling all sampling probabilities (larger keeps
      more edges).
    - lambda_k_plus_1: Positive spectral parameter; probabilities are divided
      by it, so smaller values keep more edges.
    - rng: Optional NumPy random generator (anything with a ``random()``
      method returning a uniform float in [0, 1)). Defaults to NumPy's global
      random state.

    Returns:
    - H: A new weighted ``nx.Graph`` on the same vertex set as G, containing
      the sampled, reweighted edges.
    """
    H = nx.Graph()
    # Keep every vertex of G, including those that lose all their edges, so H
    # lives on the same vertex set (and node order) as G.
    H.add_nodes_from(G.nodes())

    if G.number_of_edges() == 0:
        # Nothing to sample.
        return H

    coin = np.random.rand if rng is None else rng.random

    # C * log(n) / lambda_{k+1} is the same for every edge.
    scale = C * (np.log(G.number_of_nodes()) / lambda_k_plus_1)
    # Weighted degrees, computed once instead of twice per edge.
    degree = dict(G.degree(weight='weight'))

    for u, v, data in G.edges(data=True):
        w_uv = data.get('weight', 1)  # Default weight is 1 if not specified

        p_u_v = min(scale * (w_uv / degree[u]), 1)
        p_v_u = min(scale * (w_uv / degree[v]), 1)

        # Probability that at least one endpoint proposes the edge.
        p_e = p_u_v + p_v_u - (p_u_v * p_v_u)

        # A strict '<' keeps the edge with probability exactly p_e for a
        # uniform draw in [0, 1), and never keeps an edge with p_e == 0
        # (which would otherwise divide by zero below).
        if coin() < p_e:
            H.add_edge(u, v, weight=w_uv / p_e)

    return H



def fast_normalized_adjacency(graph):
    """
    Compute the normalized weighted adjacency matrix D^(-1/2) A D^(-1/2).

    Parameters:
    - graph (networkx.Graph): The input graph, weighted (``weight`` edge
      attribute) or unweighted (each edge counts as 1).

    Returns:
    - normalized_adj (scipy.sparse, CSR format): The normalized adjacency
      matrix. D holds the row sums of A; rows and columns of vertices whose
      row sum is not positive are all zero.
    """
    # Step 1: sparse weighted adjacency matrix in CSR format.
    try:
        # Newer NetworkX API
        adj_matrix = nx.to_scipy_sparse_array(graph, format='csr', weight='weight')
    except AttributeError:
        # Older NetworkX API
        adj_matrix = nx.to_scipy_sparse_matrix(graph, format='csr', weight='weight')
    # NetworkX already returns a symmetric matrix for undirected graphs, so no
    # extra symmetrization step is needed here.

    # Step 2: (weighted) degree of each node = row sums of A.
    # np.asarray also turns an np.matrix result (sparse-matrix API) into an ndarray.
    degrees = np.asarray(adj_matrix.sum(axis=1), dtype=float).ravel()

    # Step 3: D^(-1/2). The output must be pre-initialised: with `where=`,
    # np.power leaves masked-out entries untouched, so without `out` the
    # entries of zero-degree nodes would be uninitialised memory.
    inv_sqrt_deg = np.zeros_like(degrees)
    np.power(degrees, -0.5, out=inv_sqrt_deg, where=degrees > 0)
    D_inv_sqrt = sp.diags(inv_sqrt_deg)

    # Step 4: D^(-1/2) * A * D^(-1/2), returned in the documented CSR format.
    normalized_adj = D_inv_sqrt @ adj_matrix @ D_inv_sqrt
    return normalized_adj.tocsr()




# Chebyshev moments for the most recent matrix passed to count_eigenvalues.
# They are reused across calls so several intervals can be scored with the
# same random probe vectors.
_CHEBYSHEV_MOMENT_CACHE = {"matrix": None, "moments": None}


def _chebyshev_moments(M, degree, num_probes):
    """Return a (num_probes, degree + 1) array with entries x_t^T T_k(M) x_t for random {-1, +1} probes x_t."""
    n = M.shape[0]
    # One probe per column, so each recurrence step is a single sparse x dense
    # product instead of num_probes separate matrix-vector products.
    probes = np.random.choice([-1, 1], size=(n, num_probes)).astype(float)
    moments = np.empty((num_probes, degree + 1))
    moments[:, 0] = np.einsum("ij,ij->j", probes, probes)
    if degree == 0:
        return moments
    prev, curr = probes, M @ probes
    moments[:, 1] = np.einsum("ij,ij->j", probes, curr)
    for k in range(2, degree + 1):
        # Chebyshev recurrence: T_k(M) x = 2 M T_{k-1}(M) x - T_{k-2}(M) x.
        prev, curr = curr, 2 * (M @ curr) - prev
        moments[:, k] = np.einsum("ij,ij->j", probes, curr)
    return moments


def count_eigenvalues(M, a, b, i, *, epsilon=0.01, degree=None, num_probes=None):
    """
    Estimate how many eigenvalues of the symmetric matrix M lie in [a, b].

    The indicator function of [a, b] is expanded in Chebyshev polynomials up
    to ``degree`` (coefficients from ``compute_alpha``). Its trace is then
    estimated with ``num_probes`` random {-1, +1} probe vectors (Hutchinson's
    estimator). The expansion lives on [-1, 1], so the spectrum of M is
    assumed to lie there; a and b are clamped to [-1, 1].

    Parameters:
    M: Square matrix (dense or scipy.sparse); converted to CSR if needed.
    a (float): Lower bound of the interval.
    b (float): Upper bound of the interval.
    i (int): Step index of the caller's interval search. ``i == 0`` draws
        fresh probe vectors and recomputes the Chebyshev moments. Any other
        value reuses the moments of the previous call, provided it was made
        with the same matrix object and the same degree/num_probes.
    epsilon (float): Accuracy parameter behind the defaults
        ``degree = floor(3 / epsilon)`` and
        ``num_probes = int(2 * log(n) / epsilon)``.
    degree (int, optional): Chebyshev degree; overrides the epsilon default.
    num_probes (int, optional): Number of probe vectors; overrides the
        epsilon default.

    Returns:
    int: The estimated count, rounded to the nearest integer (halves round up).
    """
    # Ensure CSR format for efficient products.
    M_csr = M if (sp.issparse(M) and M.format == "csr") else sp.csr_matrix(M)
    n = M_csr.shape[0]

    if degree is None:
        degree = math.floor(3 / epsilon)
    if num_probes is None:
        num_probes = max(1, int(2 * math.log(n) / epsilon)) if n > 1 else 1

    cache = _CHEBYSHEV_MOMENT_CACHE
    moments = cache["moments"]
    if (
        i == 0
        or moments is None
        or cache["matrix"] is not M
        or moments.shape != (num_probes, degree + 1)
    ):
        moments = _chebyshev_moments(M_csr, degree, num_probes)
        cache["matrix"] = M
        cache["moments"] = moments

    # arcsin/arccos in compute_alpha are only defined on [-1, 1].
    lo = min(max(a, -1.0), 1.0)
    hi = min(max(b, -1.0), 1.0)
    alpha = compute_alpha(degree, lo, hi)
    count = np.sum(np.dot(moments, alpha)) / num_probes

    return math.ceil(count) if (count % 1) >= 0.5 else math.floor(count)



def compute_alpha(N, a, b):
    """
    Chebyshev coefficients of the indicator function of [a, b] on [-1, 1].

    alpha_0 = (arcsin(b) - arcsin(a)) / pi and, for i >= 1,
    alpha_i = (2 / pi) * (sin(i * arccos(a)) - sin(i * arccos(b))) / i.
    The truncated series sum_i alpha_i * T_i(x) then approximates 1 for x in
    [a, b] and 0 elsewhere in [-1, 1], where T_i are the Chebyshev
    polynomials of the first kind.

    Parameters:
    N (int): Highest Chebyshev degree (the result has N + 1 entries).
    a (float): Lower bound of the interval; clamped to [-1, 1].
    b (float): Upper bound of the interval; clamped to [-1, 1].

    Returns:
    np.ndarray: Array of shape (N + 1,) holding alpha_0, ..., alpha_N.
    """
    # arcsin/arccos return NaN outside [-1, 1], which would poison every
    # coefficient. The expansion only lives on [-1, 1], so clamping leaves
    # the approximated function unchanged there.
    a = min(max(a, -1.0), 1.0)
    b = min(max(b, -1.0), 1.0)

    alpha = np.zeros(N + 1)
    alpha[0] = (1 / np.pi) * (np.arcsin(b) - np.arcsin(a))

    # alpha_i for i = 1..N, evaluated for all orders at once.
    orders = np.arange(1, N + 1)
    alpha[1:] = (2 / np.pi) * (np.sin(orders * np.arccos(a)) - np.sin(orders * np.arccos(b))) / orders

    return alpha

# Experiment: grow the inter-cluster edge probability q of an SBM graph and,
# as long as the estimated number of clusters (mode over NUM_RUNS graphs) is
# still NUM_CLUSTERS, record c = q / p and the mean time per graph.
NUM_NODES = 2000
NUM_CLUSTERS = 4
P_INTRA = 0.06
NUM_RUNS = 5


def _interval_lower_bound(i, num_nodes=NUM_NODES):
    """Lower end 1 - 2^i / (num_nodes / 2) of the i-th search interval, clamped to -1."""
    # Eigenvalues of a normalized adjacency matrix lie in [-1, 1], so going
    # below -1 cannot change the count. The clamp also keeps the Chebyshev
    # coefficients (arcsin/arccos of the bound) finite instead of NaN.
    return max(-1.0, 1 - (2 ** i) / (0.5 * num_nodes))


times = []  # mean seconds per graph, one entry per accepted q
q_arr = []  # matching c = q / p values
i_for_q = 1
q = 0.006
while True:
    q_current = (i_for_q ** 0.5) * q
    candidates = []
    start = time.perf_counter()
    for j in range(NUM_RUNS):
        G = stag.random.sbm(NUM_NODES, NUM_CLUSTERS, P_INTRA, q_current)
        G = G.to_networkx()
        H = edge_sampling_algorithm(G, 1, 0.1)
        print("Sparsification done.")
        # The Chebyshev degree N, probe count l and moment buffer T stay
        # module-level for count_eigenvalues implementations that read them
        # as globals.
        epsilon = 0.01
        N = math.floor(3 / epsilon)
        l = int((2 * math.log(NUM_NODES) / (epsilon)))
        T = np.zeros((l, N + 1))
        A = fast_normalized_adjacency(H)

        # Phase 1: double the interval width until it holds more than one
        # eigenvalue. i == 0 makes count_eigenvalues start afresh for A, and
        # each interval is evaluated only once.
        i = 0
        k = 1
        while True:
            lower = _interval_lower_bound(i)
            estimate = count_eigenvalues(A, lower, 1, i)
            if estimate > 1:
                k = estimate
                break
            if lower <= -1.0:
                # The interval already spans the whole spectrum.
                break
            i += 1

        # Phase 2: keep widening until two consecutive counts agree. Once the
        # bound is clamped at -1 the count stops changing, so this terminates.
        # NOTE(review): the first pass re-evaluates the interval phase 1
        # stopped at, so it normally agrees immediately; confirm whether i
        # was meant to be incremented before the first comparison.
        k_now = k
        while True:
            k = count_eigenvalues(A, _interval_lower_bound(i), 1, i)
            if k == k_now:
                break
            k_now = k
            i += 1
        print(f"For graph with {NUM_NODES} nodes k: {k} (run {j+1})")
        candidates.append(k)
    k_out = mode(candidates)
    end = time.perf_counter()
    print(f"Iteration {i_for_q} complete.")
    # Decide on the mode over all runs, not on the last run's estimate.
    if k_out == NUM_CLUSTERS:
        c = q_current / P_INTRA
        print(f"c: {c}")
        q_arr.append(c)
        times.append((end - start) / NUM_RUNS)
        i_for_q += 1
    else:
        break

# Plot the mean runtime per graph against c = q / p and save a high-DPI PNG.
sns.set_theme(style="whitegrid")

plt.figure(figsize=(12, 6))
sns.lineplot(
    x=q_arr,
    y=times,
    label='Our Algorithm',
    color='navy',
    linewidth=2,
    marker='o',
    markersize=6,
)

# Axis titles and tick-label sizes.
plt.xlabel('c=q/p', fontsize=14)
plt.ylabel('Execution Time (s)', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

# Dashed, semi-transparent grey grid on top of the seaborn theme.
plt.grid(True, linestyle='--', linewidth=0.7, color='gray', alpha=0.6)

plt.savefig('Experiment_3.png', format='png', dpi=1000, bbox_inches='tight')