# %%
import networkx as nx
import numpy as np
import joblib
from itertools import product
from matplotlib import pyplot as plt
import pygsp as gsp
from graph_tool import generation as gt_gen
from graph_tool import spectral as gt_spec
from numpy.random import default_rng
random = default_rng()
from tqdm import tqdm

# Global figure styling: serif fonts at 14 pt, with all text rendered through
# LaTeX (so a TeX installation is required). PGF output would use pdflatex.
# If PGF output is needed, a custom "pgf.preamble" can be re-added here, e.g.
# loading inputenc (utf8x), fontenc (T1), cmbright and
# amsmath/amsfonts/amssymb/amstext/amsthm/bbm/mathtools.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.size": 14,
        "pgf.texsystem": "pdflatex",
        # Same effect as plt.rc('text', usetex=True), applied last.
        "text.usetex": True,
    }
)

# %%
def compute_connectivities(L, connected= True):
    """Return two connectivity measures derived from the Laplacian pseudo-inverse.

    With ``L_pinv`` the Moore-Penrose pseudo-inverse of ``L``, the result is
    ``(1 / ||L_pinv||_2, 1 / max_i |L_pinv[i, i]|)``. For a connected graph the
    first value equals the algebraic connectivity (second-smallest Laplacian
    eigenvalue). The second value is what the plotting code calls the
    "minimum topological centrality index".

    Parameters
    ----------
    L : square 2-D numpy array
        Graph Laplacian.
    connected : bool, default True
        When True, the graph is assumed connected and the pseudo-inverse is
        obtained as ``inv(L + J/n) - J/n`` (``J`` the all-ones matrix). This is
        faster than ``np.linalg.pinv`` but only valid for connected graphs.
        When False, ``np.linalg.pinv`` is used.

    Returns
    -------
    tuple of float
        ``(algebraic-connectivity-like value, inverse of max diagonal of L_pinv)``.
    """
    if not connected:
        L_pinv = np.linalg.pinv(L)
    else:
        shift = 1/L.shape[0]
        # Adding a scalar broadcasts to every entry, i.e. adds J/n.
        L_pinv = np.linalg.inv(L + shift) - shift
    spectral_norm = np.linalg.norm(L_pinv, ord= 2)
    max_abs_diag = np.linalg.norm(np.diag(L_pinv), ord= np.inf)
    return 1/spectral_norm, 1/max_abs_diag


import scipy
from utils import rank1_update

def add_edge(L, i, j):
    """Add a unit-weight edge (i, j) to the Laplacian ``L`` in place and return it.

    Both diagonal entries grow by 1 and both off-diagonal entries shrink by 1.
    ``L`` is mutated and the same object is returned.
    """
    updates = ((i, i, 1), (j, j, 1), (i, j, -1), (j, i, -1))
    for row, col, delta in updates:
        L[row, col] += delta
    return L

N = 200
I = np.eye(N)
N_comp = N//2
# Laplacian of the complete graph K_n is n*I - J (J = all-ones matrix).
L_comp = N_comp*np.eye(N_comp)-np.ones((N_comp,N_comp))
# Start from two disjoint complete graphs.
L = scipy.linalg.block_diag(L_comp, L_comp)

frac = 1.0
# Every possible edge between the first and the second complete graph.
linking_edges = [(i,j) for i in range(N//2) for j in range(N//2,N)]
n_chosen_edges = int(frac*len(linking_edges))
# Random subset of the linking edges; its order is the insertion order below.
chosen_edges_inds = random.choice(len(linking_edges), size= n_chosen_edges, replace= False)
# The two components are not yet linked, so the pinv branch is required here.
# NOTE(review): for a disconnected graph 1/||L^+||_2 is the smallest *nonzero*
# Laplacian eigenvalue, not the algebraic connectivity (which is 0) — confirm
# this first data point is what the plots should show.
alg_con, inf_con = compute_connectivities(L, connected= False)
alg_connectivities = [alg_con]
inf_connectivities = [inf_con]


n_points = 200
# Record roughly n_points states. Clamp to 1 so small `frac` values do not
# make `k % period` divide by zero.
period = max(1, n_chosen_edges // n_points)

for k, edge_ind in enumerate(tqdm(chosen_edges_inds)):
    # Use the randomly chosen edge (previously the loop ignored
    # chosen_edges_inds and added edges in fixed row-major order).
    (i,j) = linking_edges[edge_ind]
    L = add_edge(L, i, j)

    # k+1 linking edges are now present, so the graph is connected and the
    # faster connected-graph branch applies.
    if k % period == 0:
        alg_con, inf_con = compute_connectivities(L)

        alg_connectivities.append(alg_con)
        inf_connectivities.append(inf_con)

alg_connectivities = np.array(alg_connectivities)
inf_connectivities = np.array(inf_connectivities)


# %%
# x-coordinate of each recorded sample. The initial (disconnected) state has 0
# linking edges, and the sample recorded at loop step k has k+1 edges. The old
# arange(1, n+period, period) shifted every point one period to the right.
n_edges_at_samples = np.concatenate(([0], np.arange(0, n_chosen_edges, period) + 1))
frac_range = n_edges_at_samples/len(linking_edges)
plt.figure()
plt.plot(frac_range, alg_connectivities, label = "algebraic connectivity")
plt.plot(frac_range, inf_connectivities, label= "minimum topological centrality index")
plt.legend()
plt.xlabel("fraction of all edges linking the two complete graphs")
plt.savefig("./alg_connectivity_vs_centrality.pdf", format= "pdf")

# %%
plt.figure()
# Index 0 is the disconnected starting graph and is excluded from the ratio.
# The curve is labelled so plt.legend() has an entry; previously it warned
# "No artists with labels found" and drew nothing.
plt.plot(frac_range[1:], inf_connectivities[1:]/alg_connectivities[1:],
         label= "minimum topological centrality index / algebraic connectivity")
plt.legend()
plt.xlabel("fraction of all edges linking the two complete graphs")
plt.savefig("./alg_connectivity_over_centrality_ratio.pdf", format= "pdf")