import numpy as np
import networkx as nx
from scipy.linalg import eigh, inv
from random import seed
import os
# NOTE(review): this shells out to `pip` every time the module is imported or
# run, before the `rdp` / `gdp` imports below. The exit status returned by
# os.system is not checked, so the message printed next is emitted whether or
# not the installation actually succeeded.
os.system("pip install prv-accountant")
print("finish instal")
from rdp import rdp_to_approxdp, OneStepRDP, NewWightRDP
from gdp import MixtureGaussianMechanism
from IPython import embed
from tqdm import tqdm
import pickle


def vplambda(graph):
    """Compute the spectral gap of a networkx graph.

    The graph is converted to a dense matrix ``W`` with
    ``nx.to_numpy_array`` and the two smallest eigenvalues of ``I - W`` are
    requested from ``eigh``; the second of them is returned.

    Parameters
    ----------
    graph: networkx graph
        Graph whose matrix is used as ``W`` (presumably the weighted gossip
        matrix produced by ``gossip_matrix`` -- confirm against callers).

    Return
    ------
    lambda_2: float
        Second-smallest eigenvalue of ``I - W``. An assertion checks that it
        lies strictly between 0 and 1.
    """
    n_nodes = graph.number_of_nodes()
    shifted = np.eye(n_nodes) - nx.to_numpy_array(graph)
    # Only indices 0 and 1 of the spectrum are requested from the solver.
    smallest_pair = eigh(shifted, eigvals_only=True, subset_by_index=[0, 1])
    lambda_2 = smallest_pair[1]
    assert 0 < lambda_2 < 1
    return lambda_2

def _max_degree_denominator(du, dv):
    """Denominator of the "hamilton" weighting: max(du, dv) - 1."""
    return max(du, dv) - 1


def _sum_degree_denominator(du, dv):
    """Denominator of the fallback weighting: du + dv - 1."""
    return du + dv - 1


def _assign_gossip_weights(graph, denominator):
    """Add a self-loop on every node and set all edge weights in place.

    Each edge (u, v) with u != v gets weight ``1 / denominator(du, dv)`` and
    each self-loop (u, u) gets one minus the sum of the weights leaving u.
    """
    degree = nx.degree(graph)
    # Self-loops are added for *all* nodes before any weight is computed.
    # NOTE(review): `degree` is queried only after this loop; if the networkx
    # degree view is live and counts a self-loop twice (as its documentation
    # states), du below is the original degree + 2 -- confirm this is intended.
    for u in nx.nodes(graph):
        graph.add_edge(u, u)
    for u in nx.nodes(graph):
        out_w = 0
        for v in nx.neighbors(graph, u):
            if v != u:
                w = 1 / denominator(degree[u], degree[v])
                out_w += w
                graph[u][v]['weight'] = w
        # The self-loop absorbs the remaining mass so the weights at u sum to 1.
        graph[u][u]['weight'] = 1 - out_w


def gossip_matrix(graph, type_g="hamilton"):
    """Computes W with appropriate weight for a matrix gossip from a given graph

    Parameters
    ----------
    graph: networkx graph
    type_g: strategy to compute the weights
    - regular: each row of the adjacency matrix is divided by its sum
    - hamilton: Wuv = 1 / (max(du, dv) - 1)
    - anything else: Wuv = 1 / (du + dv - 1)
    For the last two strategies du and dv are read from the graph after a
    self-loop has been added to every node, and Wuu = 1 - sum_v Wuv.

    Return
    ------
    graph: networkx graph
        For "regular", a new graph built from the normalised matrix (the
        input graph is left untouched). Otherwise the input graph itself,
        modified in place with the weights and self-loops.
    """
    if type_g == "regular":
        # The `np.float` alias no longer exists in recent NumPy releases; the
        # builtin `float` designates the same float64 dtype.
        A = np.array(nx.adjacency_matrix(graph).todense(), dtype=float)
        A = A / A.sum(axis=1, keepdims=True)
        return nx.from_numpy_array(A)

    if type_g == "hamilton":
        _assign_gossip_weights(graph, _max_degree_denominator)
    else:
        _assign_gossip_weights(graph, _sum_degree_denominator)
    return graph

def logW(W):
    """
    Compute the graph-specific loss term for a given gossip matrix.

    The graph ``W`` is converted to a dense matrix and eigendecomposed with
    ``eigh``. The last eigenpair returned (the largest eigenvalue, expected to
    be 1) is dropped, and the result is ``V diag(-log(1 - lambda)) V^T`` over
    the remaining eigenpairs, i.e. essentially log(1 - W + 1/n * 1 * 1^T) up
    to sign.

    An assertion checks that the full eigendecomposition reconstructs the
    matrix.
    """
    matrix = nx.to_numpy_array(W)
    spectrum, basis = eigh(matrix, eigvals_only=False)
    # Drop the last eigenpair (largest eigenvalue) before taking the log.
    log_terms = -np.log(1 - spectrum[:-1])
    rebuilt = basis @ np.diag(spectrum) @ basis.T
    assert np.isclose(rebuilt, matrix).all()
    reduced_basis = basis[:, :-1]
    return reduced_basis @ np.diag(log_terms) @ reduced_basis.T


def communicability(W):
    """
    Compute the communicability matrix of the graph.

    The graph ``W`` is converted to a dense matrix, eigendecomposed with
    ``eigh``, and ``V diag(exp(lambda)) V^T`` is formed; the diagonal is then
    set to zero. The networkx implementation is not used because it returns a
    dict of dicts instead of a matrix.
    """
    matrix = nx.to_numpy_array(W)
    spectrum, basis = eigh(matrix, eigvals_only=False)
    exp_spectrum = np.diag(np.exp(spectrum))
    com = basis @ exp_spectrum @ basis.T
    # Self-communicability is discarded.
    com[np.diag_indices_from(com)] = 0
    return com

def computeTwalk(graph, sigma):
    """
    Compute the theoretical number of random-walk steps for convergence.

    Returns ``int(8 * ceil(log(n) / lambda_2) * (0.25 + sigma^2) / sigma^2)``
    where ``lambda_2`` is the value returned by ``vplambda(graph)`` and ``n``
    is the number of nodes of ``graph``.
    """
    spectral_gap = vplambda(graph)
    n_nodes = graph.number_of_nodes()
    mixing_steps = np.ceil(1 / spectral_gap * np.log(n_nodes))
    variance = sigma**2
    # Evaluated left to right so the truncation by int() sees the same value.
    return int(8 * mixing_steps * (.25 + variance) / variance)


def priv_global(logW, T, sigma):
    """Build the pairwise privacy-bound matrix from the graph loss matrix.

    Computes ``T / (sigma^2 n^2) * H_T + T * logW / (sigma^2 n)`` where
    ``H_T = sum_{t=1..T} 1/t`` and ``n`` is the number of rows of ``logW``,
    then sets the diagonal to zero. This is the bound for eps according to
    Theorem 3 of https://arxiv.org/pdf/2402.07471.

    Parameters
    ----------
    logW: numpy array
        Square matrix, e.g. the output of ``logW``. It is not modified.
    T: int
        Number of steps.
    sigma: float
        Noise level.

    Return
    ------
    priv: numpy array
        New array of the same shape as ``logW`` with a zero diagonal.
    """
    n = logW.shape[0]
    harmonic_sum = np.sum(1 / np.arange(1, 1 + T))
    offset = T / (sigma**2 * n**2) * harmonic_sum
    bound = offset + T * logW / (sigma**2 * n)
    # A node has no privacy loss towards itself.
    diagonal = np.arange(n)
    bound[diagonal, diagonal] = 0
    return bound

def eps_global(priv, delta):
    """Convert every off-diagonal entry of ``priv`` with ``rdp_to_approxdp``.

    Parameters
    ----------
    priv: numpy array
        Square matrix of per-pair values passed to ``rdp_to_approxdp``.
    delta: float
        Second argument forwarded to ``rdp_to_approxdp``.

    Return
    ------
    eps: numpy array
        Array shaped like ``priv``; entry (i, j) holds
        ``rdp_to_approxdp(priv[i, j], delta)`` and the diagonal stays zero.
    """
    eps = np.zeros_like(priv)
    side = len(eps)
    # A single flat loop keeps one progress bar over all side**2 pairs; the
    # row index i varies fastest.
    for flat_index in tqdm(range(side**2)):
        j, i = divmod(flat_index, side)
        if i != j:
            eps[i, j] = rdp_to_approxdp(priv[i, j], delta)
    return eps


def maxi_priv(graph, logW):
    """Take, for each pair (u, v), the largest ``logW[u][w]`` over neighbors w of v.

    Parameters
    ----------
    graph: networkx graph
        Provides ``nodes()`` and ``neighbors(v)``. Nodes are used directly as
        row indices of ``logW``, and column indices ``range(len(logW))`` are
        passed to ``neighbors``.
    logW: numpy array
        Square matrix indexed by node.

    Return
    ------
    P: numpy array
        Array shaped like ``logW`` where ``P[u][v]`` is the maximum of
        ``logW[u][w]`` over the neighbors ``w`` of ``v`` with ``w != v``
        (the self-loop is skipped), or ``-inf`` when there is no such
        neighbor.
    """
    P = np.zeros_like(logW)
    size = len(logW)
    for u in graph.nodes():
        for v in range(size):
            # Seeding the list with -inf reproduces a running maximum that
            # starts at -inf.
            candidates = [-np.inf]
            candidates.extend(logW[u][w] for w in graph.neighbors(v) if w != v)
            P[u][v] = max(candidates)
    return P


def fdp_privacy_maxtrix(graph, c, T, sigma, delta, K=1, b=1, Delta=1):
    """Compute the pairwise epsilon matrix with ``MixtureGaussianMechanism``.

    For every ordered pair (i, j) with i != j a ``MixtureGaussianMechanism``
    is built on the dense matrix of ``graph`` with ``i`` and ``j`` as its
    fourth and fifth positional arguments, and its ``approxdp(delta)`` value
    is stored.

    Parameters
    ----------
    graph: networkx graph
        Converted to a dense matrix with ``nx.to_numpy_array``.
    c, T, sigma, K, b, Delta:
        Forwarded unchanged to ``MixtureGaussianMechanism``.
    delta: float
        Argument of ``approxdp``.

    Return
    ------
    eps: numpy array
        Matrix shaped like the dense graph matrix, zero on the diagonal.
    """
    W = nx.to_numpy_array(graph)
    eps = np.zeros_like(W)
    side = len(eps)
    # One flat loop so that a single progress bar covers all side**2 pairs;
    # the row index i varies fastest.
    for flat_index in tqdm(range(side**2)):
        j, i = divmod(flat_index, side)
        if i != j:
            mechanism = MixtureGaussianMechanism(W, c, T, i, j, K=K, Delta=Delta, b=b, sigma=sigma)
            eps[i, j] = mechanism.approxdp(delta)
    return eps
        

# Plotting setup, executed at import time.
# The backend is selected before pyplot is imported; 'Agg' is matplotlib's
# non-interactive backend, so figures are only written to files.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# NOTE(review): this figure is created at import time and is not referenced
# again in the visible code, which opens its own figures before plotting.
plt.figure(figsize=[8, 6])
# Optional font override, currently disabled.
# plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams.update({
    'font.size': 12,
    # LaTeX text rendering, currently disabled:
    # 'text.usetex': True,
    # 'text.latex.preamble': r'\usepackage{amsfonts}'
})

if __name__ == "__main__":
    # Script entry point. For one hard-coded graph it:
    #   1. builds the gossip matrix and the theoretical number of steps T,
    #   2. loads previously computed pairwise epsilon matrices and plots them,
    #   3. plots the communicability matrix,
    #   4. computes (or reloads) epsilon-vs-delta curves for three accountants
    #      between node 0 and node n-1 and plots them.
    # All figures and the results pickle are written under "final/".

    # Various constants
    seed(0)
    np.random.seed(0)
    results_dict = dict()

    ## Different graphs
    # All candidates are built eagerly, in this order, so that the random
    # generators are consumed the same way whichever graph is selected.
    hypercube = nx.hypercube_graph(5)  # exponential graph
    regular = nx.random_regular_graph(3, 24)
    d_cliques = nx.ring_of_cliques(3, 8)
    women = nx.davis_southern_women_graph()

    sigma = 1
    delta = 1e-5

    sizes = [10, 10, 5]
    probs = [[0.25, 0.05, 0.02], [0.05, 0.35, 0.07], [0.02, 0.07, 0.40]]
    # Communities whose intra- and inter-community link probabilities are
    # given by `probs`.
    g = nx.stochastic_block_model(sizes, probs, seed=0)
    star = nx.star_graph(30)

    print("sigma:", sigma)
    print("delta:", delta)
    results_dict["sigma"] = sigma
    results_dict["delta"] = delta

    graph_name = "star"
    results_dict["graph"] = graph_name
    available_graphs = {
        "hypercube": hypercube,
        "regular": regular,
        "cliques": d_cliques,
        "block": g,
        "star": star,
        "women": women,
    }
    if graph_name not in available_graphs:
        raise ValueError(f"No such graph: {graph_name} !!!")
    choose_graph = available_graphs[graph_name]
    print("The graph is:", graph_name)

    ## Compute parameters for the chosen graph
    graph = gossip_matrix(choose_graph)
    n = graph.number_of_nodes()
    print("With", n, "nodes.")
    T = computeTwalk(graph, sigma)
    # T=110
    print("T:", T, n*np.log(n))
    c = 1
    save_name = f"{graph_name}-sigma{sigma}-delta{delta}-T{T}-n{n}-c{c}"

    # Every output below goes to "final/"; create it so savefig and
    # pickle.dump do not fail on a fresh checkout.
    os.makedirs("final", exist_ok=True)

    # SECURITY: pickle.load can execute arbitrary code; only load result files
    # produced by this project.
    # NOTE(review): this file is read from the working directory whereas every
    # other artefact lives under "final/" -- confirm the location is intended.
    # The loaded dict replaces the `results_dict` filled in above.
    with open(save_name+".pkl", "rb") as results_file:
        results_dict = pickle.load(results_file)
    rdp_eps = results_dict["RDP_eps"]
    fdp_eps = results_dict["fdp_eps"]

    ## Pairwise epsilon heatmaps, sharing one colour scale
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6, 3))
    max_value = max(np.max(fdp_eps), np.max(rdp_eps))
    im0 = axes[0].imshow(fdp_eps, vmin=0, vmax=max_value)
    axes[0].set_title("Ours")
    im1 = axes[1].imshow(rdp_eps, vmin=0, vmax=max_value)
    axes[1].set_title("RDP")
    plt.tight_layout()

    # Reserve the right-hand strip of the figure for the shared colorbar.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(im0, cax=cbar_ax)
    plt.savefig(f"final/{save_name}-final-epsilon.pdf", dpi=300)

    ## For connection
    com = communicability(graph)
    plt.figure(figsize=(4, 3))
    plt.imshow(com)
    plt.colorbar()
    plt.title("Communicability")
    plt.tight_layout()
    plt.savefig(f"final/{save_name}-final-com.pdf", dpi=300)

    ## Epsilon-delta curves between node 0 and node n-1
    save = True  # True: recompute and store the curves; False: reload them
    W = nx.to_numpy_array(graph)
    print(W.shape)
    deltas_exp = np.linspace(-7, -2, 100)
    deltas = 10**deltas_exp
    if save:
        prv_RDP = OneStepRDP(transition_matrix=W, T=T, start=0, end=n-1, K=1, Delta=1, b=1, sigma=sigma)
        prv_RDP_new = NewWightRDP(transition_matrix=W, T=T, start=0, end=n-1, K=1, Delta=1, b=1, sigma=sigma)
        prv_mixture_gaussian = MixtureGaussianMechanism(transition_matrix=W, c=c, T=T, start=0, end=n-1, K=1, Delta=1, b=1, sigma=sigma)

        RDP_result = []
        RDP_result_new = []
        mixture_Guassian = []

        # A distinct loop variable keeps the configured `delta` intact.
        for delta_value in tqdm(deltas):
            RDP_result.append(prv_RDP.approxdp(delta_value))
            RDP_result_new.append(prv_RDP_new.approxdp(delta_value))
            mixture_Guassian.append(prv_mixture_gaussian.approxdp(delta_value))

        RDP_result = np.array(RDP_result)
        RDP_result_new = np.array(RDP_result_new)
        mixture_Guassian = np.array(mixture_Guassian)

        results_dict["x"] = deltas.tolist()
        results_dict["rdp"] = RDP_result.tolist()
        results_dict["rdp_new"] = RDP_result_new.tolist()
        results_dict["fdp"] = mixture_Guassian.tolist()

        with open("final/"+save_name+".pkl", "wb") as results_file:
            pickle.dump(results_dict, results_file)
    else:
        # SECURITY: same caveat as above, only load trusted pickle files.
        with open("final/"+save_name+".pkl", "rb") as results_file:
            results_dict = pickle.load(results_file)
        # Plot against the delta grid stored with the curves so the x-axis
        # always matches the loaded values.
        deltas = np.asarray(results_dict["x"])
        RDP_result = results_dict["rdp"]
        RDP_result_new = results_dict["rdp_new"]
        mixture_Guassian = results_dict["fdp"]

    fig, axs = plt.subplots(figsize=(4, 3))
    plt.ylabel("varepsilon")
    plt.title("varepsilon-delta plot")
    # plt.xlabel(r"$\delta$")
    plt.plot(deltas, mixture_Guassian, label="Ours", color="red")
    plt.plot(deltas, RDP_result, label="RDP", color="blue")
    plt.plot(deltas, RDP_result_new, label="RDP (hitting time)", color="orange")
    plt.xscale("log")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"final/{save_name}-final-DP.pdf", dpi=300)


    
    # plt.figure()
    # graph.remove_edges_from(nx.selfloop_edges(graph))
    # nx.draw(graph, node_color=priv,node_size=30, alpha=1, edge_color='xkcd:silver', width=0.8, cmap=plt.cm.cividis)
    # # plt.tight_layout()
    # plt.savefig(f"fig/{save_name}-figi.pdf", dpi=300)