#%%
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
from sklearn import datasets
from networkx import draw_networkx_edges
from matplotlib.pyplot import quiver
from math import sqrt, log
from scipy.special import softmax
from numpy.random import default_rng
random = default_rng()
from scipy.linalg import pinv
import pickle


#%% "Graph construction"
def graph_Erdos_Renyi(n_vertices= 5, p= 0.5):
	"""Sample the adjacency matrix of an undirected Erdos-Renyi graph G(n, p).

	Each unordered pair of distinct vertices is joined independently with
	probability ``p``, using the module-level generator ``random``.

	Parameters
	----------
	n_vertices : int, optional
		number of vertices, by default 5
	p : float, optional
		edge probability, by default 0.5

	Returns
	-------
	array_like
		symmetric float matrix of zeros and ones with a zero diagonal,
		of shape (n_vertices, n_vertices)
	"""
	lower = np.zeros((n_vertices, n_vertices))
	# Strictly-lower-triangular positions in row-major order: (1,0), (2,0), (2,1), ...
	# i.e. the same order as a nested `for i: for j < i` loop, one draw per pair.
	rows, cols = np.tril_indices(n_vertices, k=-1)
	lower[rows, cols] = [random.binomial(1, p) for _ in range(rows.size)]
	# Mirror the lower triangle to make the graph undirected.
	return lower + lower.T

def graph_rbf(n_nodes, dimension, gamma=None, sparsity = 0.9, clusters= 2, weighted= True): ##
	"""Build a random graph whose edge weights are an RBF kernel of node features.

	Node features are drawn either uniformly in [-0.5, 0.5)^dimension (no
	clusters) or from Gaussian blobs via ``sklearn.datasets.make_blobs``. The
	adjacency matrix is the RBF kernel of those features, sparsified by zeroing
	the weights at or below the ``sparsity`` quantile of the off-diagonal weights.

	Parameters
	----------
	n_nodes : int
		number of nodes in the graph
	dimension : int
		dimension of the signal
	gamma : float, optional
		parameter of the Gaussian kernel, forwarded to ``rbf_kernel``, by default None
	sparsity : float, optional
		fraction of absent edges, by default 0.9. ``None`` or a value <= 0
		disables thresholding.
	clusters : int, array_like, bool or None, optional
		number of clusters (or blob centers, forwarded to ``make_blobs`` as
		``centers``), by default 2. ``None``, ``False`` or ``0`` draws
		unclustered uniform features instead.
	weighted : bool, optional
		True if the graph has weights, False if it is binary, by default True

	Returns
	-------
	adj : array_like
		adjacency matrix of shape (n_nodes, n_nodes) with a zero diagonal
	node_f : array_like
		vectorial graph signal of size n_nodes x dimension
	"""
	# From the github of Kaige Yang
	# Scalar check first: comparing an array of centers with False/None would
	# be elementwise and make the `if` ambiguous.
	no_clusters = clusters is None or (np.isscalar(clusters) and not clusters)
	if no_clusters:
		node_f = random.uniform(low=-0.5, high=0.5, size=(n_nodes, dimension))
	else:
		# shuffle=False keeps the nodes of each cluster contiguous.
		# NOTE: no random_state is passed, so make_blobs uses NumPy's global RNG,
		# not the module-level generator `random`.
		node_f, _ = datasets.make_blobs(n_samples=n_nodes, n_features=dimension, centers=clusters, cluster_std=0.2, center_box=(-1,1), shuffle=False)

	adj = rbf_kernel(node_f, gamma=gamma)

	# Zero every weight at or below the `sparsity` quantile of the off-diagonal
	# weights. Skipped for sparsity <= 0 (otherwise the minimum-weight edges
	# would still be removed) and for fewer than two nodes (no off-diagonal
	# weights, np.quantile would get an empty array).
	if sparsity is not None and sparsity > 0 and n_nodes > 1:
		tri_inds = np.triu_indices(n_nodes, k=1)
		thd = np.quantile(adj[tri_inds], q=sparsity)
		adj[adj <= thd] = 0.0

	np.fill_diagonal(adj, 0)  # no self-loops
	if not weighted:
		adj[adj > 0] = 1

	return adj, node_f

#%% Graph Laplacian

def random_walk_laplacian(Adj):
	"""Random-walk normalized Laplacian ``I - D^{-1} Adj``.

	``D`` is the diagonal matrix of the row sums of ``Adj``. Rows whose degree
	is not positive (e.g. isolated nodes) get an inverse degree of 0, so they
	become identity rows.

	Parameters
	----------
	Adj : array_like
		square adjacency matrix of shape (n, n)

	Returns
	-------
	ndarray
		matrix of shape (n, n)
	"""
	deg = Adj.sum(axis= 1, keepdims= True)
	# Divide only where deg > 0: np.where(deg > 0, 1/deg, 0) would evaluate
	# 1/deg everywhere and emit a divide-by-zero RuntimeWarning.
	invDeg = np.divide(1.0, deg, out=np.zeros(deg.shape), where=deg > 0)
	return np.eye(len(Adj)) - Adj*invDeg

def Laplacian(Adj, which= "random_walk"):
	"""Combinatorial graph Laplacian ``D - Adj``.

	``D`` is the diagonal matrix of the row sums of ``Adj``.

	Parameters
	----------
	Adj : array_like
		square adjacency matrix of shape (n, n)
	which : str, optional
		currently ignored. The combinatorial Laplacian is always returned,
		despite the "random_walk" default. The parameter is kept so existing
		calls keep working; see random_walk_laplacian / normalized_Laplacian
		for the normalized variants.

	Returns
	-------
	ndarray
		matrix of shape (n, n) whose rows sum to zero
	"""
	# The degrees must sit on the diagonal. Subtracting Adj from the bare
	# degree vector would broadcast it along rows, giving deg[j] - Adj[i, j].
	return np.diag(Adj.sum(axis= 1)) - Adj

def normalized_Laplacian(Adj):
	"""Symmetric normalized Laplacian ``I - D^{-1/2} Adj D^{-1/2}``.

	``D`` is the diagonal matrix of the row sums of ``Adj`` (used on both sides,
	which matches the usual definition when ``Adj`` is symmetric). Nodes whose
	degree is not positive get a scaling of 0, so their rows and columns of the
	result are those of the identity.

	Parameters
	----------
	Adj : array_like
		square adjacency matrix of shape (n, n)

	Returns
	-------
	ndarray
		matrix of shape (n, n)
	"""
	deg = Adj.sum(axis= 1)
	# Divide only where deg > 0, so isolated nodes do not trigger a
	# divide-by-zero RuntimeWarning (np.where would evaluate 1/deg everywhere).
	invDeg = np.divide(1.0, deg, out=np.zeros(deg.shape), where=deg > 0)
	sqrtInvDeg = invDeg**0.5
	return np.eye(len(Adj)) - Adj*np.outer(sqrtInvDeg, sqrtInvDeg)

# def doubly_stochastic_Adjacency(Adj):
# 	# TODO replace the fixed 1000 iterations by a loop that runs while the row sums or column sums are far from the ones vector
# 	for _ in range(1000):
# 		deg = Adj.sum(axis= 1, keepdims= True)
# 		invDeg = np.where(deg>0, 1/deg, 0)
# 		Adj *= invDeg
# 		deg = Adj.sum(axis= 0)
# 		invDeg = np.where(deg>0, 1/deg, 0)
# 		Adj *= invDeg
# 	return Adj

def quiver_graph(graph, pos, signal, alpha= 0.2):
	"""Plot a 2-D vector signal on a graph as unit arrows colored by magnitude.

	Each node gets an arrow of unit length in the direction of its signal
	vector; the vector norms are passed to ``quiver`` as the color array.
	The graph edges are drawn on top with ``draw_networkx_edges``.

	Parameters
	----------
	graph : networkx graph
		graph whose edges are drawn
	pos : array_like of shape (n, 2)
		node positions; the columns give the arrow origins and the array is
		also forwarded to ``draw_networkx_edges``
	signal : array_like of shape (n, 2)
		one 2-D vector per node
	alpha : float, optional
		transparency of the edges, by default 0.2

	Returns
	-------
	None
	"""
	# Comma (not semicolon) so the message is actually attached to the assertion.
	assert signal.shape[1] == 2, "The signal to plot must be two dimensional !"
	norms = np.linalg.norm(signal, axis= 1, keepdims= True)
	# Zero vectors keep a zero direction instead of becoming 0/0 = nan.
	sig_norm = np.divide(signal, norms, out=np.zeros(signal.shape), where=norms > 0)
	quiver(pos[:,0], pos[:,1], sig_norm[:,0], sig_norm[:,1], norms)
	draw_networkx_edges(graph, pos = pos, alpha= alpha)
	return None

#%%
def create_cluster_indices(n_points, n_clusters, weights= None, imbalance= None, sorted= True):
	"""Draw a random cluster index in ``range(n_clusters)`` for each point.

	Parameters
	----------
	n_points : int
		number of indices to draw
	n_clusters : int
		number of clusters
	weights : array_like of length n_clusters, optional
		relative cluster frequencies, normalized to sum to one. The caller's
		array is not modified. By default None (uniform draw).
	imbalance : float, optional
		if given, the probabilities become ``softmax(imbalance * weights)``,
		computed on the normalized weights. Requires ``weights``.
	sorted : bool, optional
		if True (default), the indices are returned in non-decreasing order

	Returns
	-------
	ndarray of int, shape (n_points,)

	Raises
	------
	ValueError
		if ``imbalance`` is given without ``weights``
	"""
	if weights is not None:
		# Normalize a float copy. An in-place `/=` would mutate the caller's
		# array, and it fails on lists and integer arrays.
		weights = np.asarray(weights, dtype=float)
		weights = weights / weights.sum()
	if imbalance is not None:
		if weights is None:
			raise ValueError("imbalance requires weights to be given")
		weights = softmax(imbalance*weights)
	res = random.choice(np.arange(n_clusters), size= n_points, p= weights)
	if sorted:
		res.sort()
	return res
#%% Sherman-Morrison formula
def rank1_update(invA, u):
    """Return the Sherman-Morrison correction for adding ``u u^T`` to ``A``.

    With ``invA = A^{-1}`` symmetric, the result ``C`` satisfies
    ``(A + u u^T)^{-1} = invA - C``, where
    ``C = (invA u)(invA u)^T / (1 + u^T invA u)``.
    """
    w = invA @ u
    denom = 1 + u @ w
    return np.outer(w, w) / denom

def rank1_update_decomposed(invA, u):
    """Return the vector whose outer square is the Sherman-Morrison correction.

    The Sherman-Morrison formula updates the inverse with a rank-one matrix
    ``invA u u^T invA / (1 + u^T invA u)``. When ``invA`` is symmetric this
    equals ``outer(v, v)`` for the vector ``v`` returned here. Requires
    ``invA`` to be symmetric.

    Raises
    ------
    ValueError
        from ``math.sqrt`` if ``1 + u^T invA u`` is negative
    """
    w = invA @ u
    scale = sqrt(1 + u @ w)
    return w / scale

def rank1_update_vectorized_last(invA, u):
    """Batched Sherman-Morrison correction, batch axis last.

    Parameters
    ----------
    invA : array_like, shape (d, d, K)
        stack of inverse matrices, one per slice ``invA[:, :, k]``
    u : array_like, shape (d, K)
        one update vector per column ``u[:, k]``

    Returns
    -------
    ndarray, shape (d, d, K)
        the increment to subtract: slice ``k`` is
        ``outer(w_k, w_k) / (1 + u_k . w_k)`` with ``w_k = invA[:, :, k] @ u[:, k]``
    """
    products = np.einsum("ijk,jk->ik", invA, u)
    denominators = 1 + np.einsum("ik,ik->k", u, products)
    outers = np.einsum("ik,jk->ijk", products, products)
    return outers / denominators[np.newaxis, np.newaxis, :]

def rank1_update_vectorized_first(invA, u):
    """Batched Sherman-Morrison correction (batch axis LAST, see note).

    Parameters
    ----------
    invA : array_like, shape (d, d, K)
        stack of inverse matrices, one per slice ``invA[:, :, k]``
    u : array_like, shape (d, K)
        one update vector per column ``u[:, k]``

    Returns
    -------
    ndarray, shape (d, d, K)
        the increment to subtract, one slice per batch element

    NOTE(review): despite the ``_first`` suffix, this computes exactly the same
    thing as rank1_update_vectorized_last (the batch axis is the last one).
    A batch-first variant would use subscripts such as "kij,kj->ki". It is left
    unchanged so existing callers keep working -- confirm the intended layout.
    """
    invA_u = np.einsum("ijk,jk->ik", invA, u)
    quad = np.einsum("ik,ik->k", u, invA_u)
    return np.einsum("ik,jk->ijk", invA_u, invA_u) / (1 + quad)[np.newaxis, np.newaxis, :]



#%% network lasso related
def boundary_indicator(edges, cluster_inds):
    """Flag the edges whose two endpoints lie in different clusters.

    Parameters
    ----------
    edges : iterable of tuple
        edges of a networkx object, or any iterable of ``(u, v, ...)`` tuples
        of integer node indices; only the first two entries of each edge are used
    cluster_inds : array_like of int
        list of the nodes' cluster indices / classes, indexed by node

    Returns
    -------
    numpy array of bool
        one entry per edge, True when the edge lies on a cluster boundary
    """
    # list(...) materializes views/iterators before the array conversion.
    edge_arr = np.asarray(list(edges))
    if edge_arr.size == 0:
        # np.asarray([]) is a float array and cannot be used as an index.
        return np.zeros(0, dtype=bool)
    # asarray lets plain Python lists be fancy-indexed by the node index arrays.
    cluster_inds = np.asarray(cluster_inds)
    return cluster_inds[edge_arr[:, 0]] != cluster_inds[edge_arr[:, 1]]

def alpha_theory_2(coef, XX, delta= 1e-3):
    """Return a width term computed from a stack of matrices ``XX``.

    With ``t = log(1/delta)`` the result is
    ``coef * sqrt(S1 + sqrt(S2 * t) + S3 * t)`` where
    ``S1 = trace(sum_i XX[i])``, ``S2 = sum_i ||XX[i]||_F^2`` and
    ``S3 = max_i ||XX[i]||_F``.

    Parameters
    ----------
    coef : float
        multiplicative constant
    XX : array_like
        3-D array; each ``XX[i]`` is a matrix
    delta : float, optional
        parameter entering through ``log(1/delta)``, by default 1e-3
    """
    trace_of_sum = np.trace(np.sum(XX, axis= 0))
    frob_sq_total = np.einsum("ijk,ijk", XX, XX)
    frob_max = np.max(np.linalg.norm(XX, axis= (1,2)))
    log_term = log(1/delta)
    return coef * sqrt(trace_of_sum + sqrt(frob_sq_total*log_term) + frob_max*log_term)

# %%
def sum_Hsu(v1, v2, v3, v_delta):
	"""Return ``v1 + 2*sqrt(v2*v_delta) + 2*v3*v_delta`` (elementwise for arrays).

	NOTE(review): presumably the Hsu-Kakade-Zhang quadratic-form tail bound,
	with v1, v2, v3 the trace-type terms and v_delta the log(1/delta) factor --
	confirm against the caller.
	"""
	cross_term = 2*np.sqrt(v2 * v_delta)
	tail_term = 2 * v3 * v_delta
	return v1 + cross_term + tail_term

def fast_pinv(M):
	"""Moore-Penrose pseudo-inverse of ``M`` computed as ``pinv(M^T M) @ M^T``.

	Only the Gram matrix ``M^T M`` (size n_cols x n_cols) is pseudo-inverted,
	which is cheaper than ``pinv(M)`` when M has few columns; note that forming
	the Gram matrix squares the condition number.
	"""
	gram = M.T @ M
	return pinv(gram) @ M.T

def postprocess(source):
	"""Split a saved results dictionary into agent rewards and oracle rewards.

	Parameters
	----------
	source : str or dict
		path to a pickle file holding the results dictionary, or the dictionary
		itself. It must have the keys "results" (an array with at least four
		axes; index 0 / 1 of the third axis hold the rewards / oracle rewards)
		and "agents_names".

	Returns
	-------
	tuple
		(rewards, rewards_oracle, agents_names)

	Raises
	------
	TypeError
		if ``source`` is neither a str nor a dict
	"""
	if isinstance(source, dict):
		res_dict = source
	elif isinstance(source, str):
		# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
		with open(source, 'rb') as file:
			res_dict = pickle.load(file)
	else:
		raise TypeError("source must have type either a string or a dict")
	results = res_dict["results"]
	return results[:,:,0,:], results[:,:,1,:], res_dict["agents_names"]

def postprocess_many(source_list):
	"""Concatenate several saved results and split them into agent / oracle rewards.

	Parameters
	----------
	source_list : iterable of str or dict
		each element is a path to a pickle file holding a results dictionary,
		or the dictionary itself, with keys "results" (an array with at least
		four axes; index 0 / 1 of the third axis hold the rewards / oracle
		rewards) and "agents_names". The "results" arrays are concatenated
		along axis 0 and the names lists are chained in the same order.

	Returns
	-------
	tuple
		(rewards, rewards_oracle, agents_names)

	Raises
	------
	TypeError
		if an element is neither a str nor a dict
	ValueError
		if ``source_list`` is empty
	"""
	agents_names = []
	res_list = []
	for source in source_list:
		if isinstance(source, str):
			# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
			with open(source, 'rb') as file:
				res_dict = pickle.load(file)
		elif isinstance(source, dict):
			res_dict = source
		else:
			raise TypeError("source must have type either a string or a dict")
		res_list.append(res_dict["results"])
		agents_names.extend(res_dict["agents_names"])
	if not res_list:
		raise ValueError("source_list must contain at least one source")
	# np.concatenate rather than np.concat: the np.concat alias only exists in NumPy >= 2.0.
	res = np.concatenate(res_list, axis= 0)
	rewards = res[:,:,0,:]
	rewards_oracle = res[:,:,1,:]
	return rewards, rewards_oracle, agents_names