# TODO : replace networkx by graphtool when possible ##SEGFAULTS



from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, TransformerMixin
import gudhi as gd
from os.path import expanduser, exists
import networkx as nx
import pickle
from joblib import Parallel, delayed, cpu_count
import numpy as np
from tqdm import tqdm
from warnings import warn
from random import choice
from sklearn.neighbors import KernelDensity
from typing import Callable, Iterable
from os import walk
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from gudhi.representations import Landscape
from gudhi.representations.vector_methods import PersistenceImage
from gudhi.representations.kernel_methods import SlicedWassersteinKernel ## TODO distance instead, to store the matrix
from gudhi.representations.kernel_methods import SlicedWassersteinDistance
from torch_geometric.data.data import Data
from sklearn.neighbors import kneighbors_graph
from torch_geometric.transforms import FaceToEdge
from torch_geometric.datasets import ModelNet
from torch_geometric.datasets import ShapeNet
import functools

import torch_geometric.transforms as transforms
import graph_tool as gt # networkx is too slow to compute dijkstra (python) but segfaults....
from graph_tool.topology import shortest_distance



# DATASET_PATH="../mperslay_expes/data/"
# Root directory under which the loaders of this file look for their datasets ("~" is expanded to the user's home).
DATASET_PATH = expanduser("~/Datasets/")
# Path of graphs, e.g., imdb-binary : $HOME/Datasets/IMDB-BINARY/mat/nodes_*_edges_*_gid_*_lb_*_index_*_adj.mat
# UCR, eg, Coffee : $HOME/Datasets/UCR/Coffee/Coffee_TRAIN.tsv
# 3d Shapes : $HOME/Datasets/3dshapes/Airplane/*.off

################################################################################################# DATASET GET / SET
def get_graph_dataset(dataset, label= "lb", **kwargs):
	"""
	Loads a dataset of graphs together with its labels.

	Dispatch on `dataset`:
	 - "SBM" : delegated to get_sbm_dataset, with kwargs["N"] (default 100),
	 - "modelnet10" : delegated to modelnet2pts2gs, kwargs are forwarded,
	 - "3dshapes/<name>" : delegated to get_3dshape, kwargs are forwarded,
	 - otherwise : every file found (recursively) under DATASET_PATH/<dataset>/mat/
	   is read with scipy.io.loadmat and its 'A' entry is turned into a networkx graph.

	For the .mat datasets, the label of a graph is read in its file name: the name is
	split on "_", and the token following the first occurrence of `label` is the label
	(kept as a string).

	NOTE(review): when the label cannot be found, a warning is emitted but the graph is
	still returned, so `graphs` and `labels` may have different lengths.

	Returns
	-------
	(graphs, labels) for the .mat datasets; the output of the delegated loader otherwise.
	"""
	if dataset == "SBM":	return get_sbm_dataset(kwargs.get("N", 100))
	if dataset == "modelnet10":	return modelnet2pts2gs("10", **kwargs)
	if dataset.startswith("3dshapes/"):	return get_3dshape(dataset[9:], **kwargs)
	from os import walk
	from os.path import join
	from scipy.io import loadmat
	from warnings import warn
	path = DATASET_PATH + dataset  +"/mat/"
	labels:list[str] = []
	gs:list[nx.Graph] = []
	for root, _, files in walk(path):
		for file in files:
			tokens = file.split("_")
			# The label key cannot be the last token: it has to be followed by its value.
			label_position = next((i for i, token in enumerate(tokens[:-1]) if token == label), None)
			if label_position is None:
				warn(f"Cannot find label {label} on file {file}.")
			else:
				labels.append(tokens[label_position+1])
			# `walk` is recursive: the file lives in `root`, which is not necessarily `path`.
			adj_mat = np.array(loadmat(join(root, file))['A'], dtype=np.float32)
			gs.append(nx.Graph(adj_mat))
	return gs, labels

############################# 3D SHAPES
def get_3dshape(dataset:str, dataset_num:int|None=None, num_sample:int=0):
	"""
	Loads one (or several) labelled 3d shapes from DATASET_PATH/3dshapes/<dataset>/.

	Parameters
	----------
	dataset : name of the shape category; may also be "<category>/<num>.off", in which
		case the number is extracted from it (if `dataset_num` is not given).
	dataset_num : number of the shape to load. If it cannot be determined, the shapes
		available in the folder are loaded.
	num_sample : if positive, number of shapes drawn at random (without replacement)
		when the whole folder is loaded; otherwise every shape is loaded.

	Returns
	-------
	(data, labels) for a single shape, where `data` is the output of
	torch_geometric.io.read_off and `labels` holds one string per vertex;
	a list of such pairs when a whole folder is loaded.
	"""
	from os import listdir
	from os.path import splitext
	from torch_geometric.io import read_off
	if dataset_num is None and "/" in dataset:
		position = dataset.rfind("/")
		dataset_num = int(dataset[position+1:-4]) # cuts the "<dataset>/" and the ".off"
		dataset = dataset[:position]

	if dataset_num is None: # loads the shapes available for this dataset
		files = [file for file in listdir(DATASET_PATH+f"3dshapes/{dataset}") if "label" not in file]
		if num_sample > 0:
			files = np.random.choice(files, replace=False, size=num_sample)
		# The number of a shape is made of the digits of its file name.
		dataset_nums = np.sort([int("".join(char for char in file if char.isnumeric())) for file in files])

		print("Dataset nums : ", *dataset_nums)
		return [get_3dshape(dataset, dataset_num=num) for num in dataset_nums]

	path = DATASET_PATH+f"3dshapes/{dataset}/{dataset_num}.off"
	data = read_off(path)
	faces = data.face.numpy().T
	# Only strips the extension: splitting on the first "." would break as soon as
	# a directory of the path contains a dot.
	label_path = splitext(path)[0] + "_labels.txt"
	labels = np.zeros(len(data.pos), dtype="<U10") # Assumes labels are of size at most 10 chars
	current_label=""
	with open(label_path, "r") as label_file:
		# Even lines hold a label name, the following odd line holds the faces having this label.
		for i, line in enumerate(label_file):
			if i %  2 == 0:
				current_label = line.strip()
				continue
			faces_of_label = np.array(line.strip().split(" "), dtype=int) -1 # this starts at 1, python starts at 0
			nodes_of_label = np.unique(faces[faces_of_label].flatten())
			labels[nodes_of_label] = current_label  # labels are given on faces, they are propagated to their vertices
	return data, labels

class TorchData2NeighborGraph(BaseEstimator,TransformerMixin):
	"""
	Turns each torch_geometric `Data` into the k-nearest-neighbors graph of its
	positions (`data.pos`), as a networkx graph whose edges carry a 'weight'.

	Parameters
	----------
	num_neighbors : number of neighbors given to sklearn's kneighbors_graph.
	exp_flag : if True, each weight w (a distance) is replaced by exp(-w).
	n_jobs : forwarded to kneighbors_graph.
	"""
	def __init__(self, num_neighbors:int=8, exp_flag:bool=True, n_jobs:int=1):
		super().__init__()
		self.num_neighbors=num_neighbors
		self.exp_flag = exp_flag
		self.n_jobs = n_jobs
	def fit(self, X, y=None):
		# Nothing to learn.
		return self
	def _to_graph(self, data:Data)->nx.Graph:
		# Builds the weighted neighbor graph of a single point cloud.
		points = data.pos.numpy()
		knn = kneighbors_graph(points, self.num_neighbors, mode='distance', n_jobs=self.n_jobs)
		graph = nx.from_scipy_sparse_array(knn, edge_attribute= 'weight')
		if not self.exp_flag:
			return graph
		# The attribute dictionaries yielded here are the ones stored in the graph.
		for _, _, attributes in graph.edges(data=True):
			attributes['weight'] = np.exp(-attributes['weight'])
		return graph
		#TODO : nx.set_edge_attributes()
	def transform(self,X:list[Data]):
		return [self._to_graph(data) for data in X]

### Objects segmentation
def dijkstraSimplexTree(data_node:tuple[Data,int], backend="graph_tool") -> gd.SimplexTree:
	"""
	Builds the graph (vertices + edges) of `data` as a gudhi SimplexTree, whose vertices
	are filtered by the opposite of their distance to `node`, as given by get_dijkstra.

	Parameters
	----------
	data_node : pair (data, node), `node` being the source of the distances.
	backend : forwarded to get_dijkstra.
	"""
	data, source = data_node
	distances = get_dijkstra(data=data, node=source, backend=backend)
	n_vertices = len(distances)
	if not hasattr(data, "edge_index"):
		FaceToEdge(remove_faces=False)(data)
	edge_index = data.edge_index.numpy()
	simplextree = gd.SimplexTree()
	# Everything is first inserted at filtration value 0.
	simplextree.insert_batch(np.array([range(n_vertices)], dtype=int), np.zeros(n_vertices))
	simplextree.insert_batch(edge_index, np.zeros(edge_index.shape[1]))
	for vertex, distance in enumerate(distances):
		simplextree.assign_filtration([vertex], -distance)
	# Pushes the filtration values of the vertices to the edges.
	simplextree.make_filtration_non_decreasing() #TODO : If doing extended persistence, that's an unnecessary overhead, as gudhi only uses the filtration of 0-simplices
	return simplextree
def dataGeodesicSimplexTree(data_node:tuple[Data,int]):
	"""
	Same as dijkstraSimplexTree, but the distances to `node` are computed with the
	"torch_geometric" backend of get_dijkstra.

	Parameters
	----------
	data_node : pair (data, node), `node` being the source of the distances.
	"""
	data, source = data_node
	# NOTE(review): the attribute tested is 'edge', not 'edge_index' -- confirm this is intended.
	if not hasattr(data, 'edge'):
		FaceToEdge(remove_faces=False)(data)
	edge_index = data.edge_index.numpy()
	distances = get_dijkstra(data, source, backend="torch_geometric")
	n_vertices = len(distances)
	simplextree = gd.SimplexTree()
	# Everything is first inserted at filtration value 0.
	simplextree.insert_batch(np.array([range(n_vertices)], dtype=int), np.zeros(n_vertices))
	simplextree.insert_batch(edge_index, np.zeros(edge_index.shape[1]))
	for vertex, distance in enumerate(distances):
		simplextree.assign_filtration([vertex], -distance)
	simplextree.make_filtration_non_decreasing()
	return simplextree
class TorchData2DijkstraSimplexTree(BaseEstimator,TransformerMixin):
	"""
	Turns pairs (data, node) into simplextrees filtered by the distance to `node`.

	Parameters
	----------
	n_jobs : number of threads used to build the simplextrees.
	dtype : if None, nothing is computed: `transform` returns the pair (function, arguments).
	progress : shows a progress bar.
	true_geodesic : uses dataGeodesicSimplexTree if True, dijkstraSimplexTree otherwise.
	"""
	def __init__(self, n_jobs:int=-1, dtype=gd.SimplexTree, progress=False, true_geodesic:bool=True):
		super().__init__()
		self.n_jobs = n_jobs
		self.dtype = dtype
		self.progress = progress
		self.true_geodesic=true_geodesic
	def fit(self, X, y=None):
		# Nothing to learn.
		return self
	def transform(self,X:list[tuple[Data|None, int]]):
		# X holds data and nodes (otherwise, we don't keep the labels).
		# To avoid copies, a repeated data can be omitted (None): the previous one is used.
		args:list[tuple[Data, int]] = []
		current:Data|None = None
		for data, node in X:
			if data is not None: # data is only updated when not None
				current = data
			assert current is not None
			args.append((current, node))
		to_st = dijkstraSimplexTree if not self.true_geodesic else dataGeodesicSimplexTree
		if self.dtype is None:
			return to_st, args
		progress_bar = tqdm(args, disable = not self.progress, desc="Generating SimplexTrees.")
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(to_st)(d) for d in progress_bar)








### Graphs
def get_sbm(n1,n2,p,q)->nx.Graph: # Stochastic Block Model
	"""
	Samples a graph with two blocks of n1 and n2 nodes: two nodes of the same block are
	linked with probability p, two nodes of different blocks with probability q.
	"""
	rand = np.random.uniform
	edges = []
	# A random number is drawn for every ordered pair, as the test on the order comes second.
	for u in range(n1): # edges of block 1
		for v in range(n1):
			if rand() < p and u < v:
				edges.append([u, v])
	for u in range(n2): # edges of block 2
		for v in range(n2):
			if rand() < p and u < v:
				edges.append([u+n1, v+n1])
	for u in range(n1): # interblock edges
		for v in range(n2):
			if rand() < q:
				edges.append([u, v+n1])
	edges = np.array(edges)
	graph = nx.Graph()
	graph.add_nodes_from(range(n1+n2))
	graph.add_edges_from(edges)
	return graph

def get_sbm_dataset(N=100, progress=False)->tuple[list[nx.Graph], list[int]]:
	"""
	Samples 2N stochastic block model graphs:
	N with blocks (100, 50), p=0.5, q=0.1 (label 0), followed by
	N with blocks (75, 75), p=0.4, q=0.2 (label 1).
	"""
	first_class = Parallel(n_jobs=5)(delayed(get_sbm)(n1=100,n2=50,p=0.5,q= 0.1) for _ in tqdm(range(N), disable = not progress))
	second_class = Parallel(n_jobs=-1)(delayed(get_sbm)(n1=75,n2=75, p=0.4, q=0.2) for _ in tqdm(range(N), disable = not progress))
	graphs:list[nx.Graph] = [*first_class, *second_class]
	labels = [0]*len(first_class) + [1]*len(second_class)
	return graphs, labels

#### graphs tools
def get_graphs(dataset:str, N:int|str="")->tuple[list[nx.Graph], list[int]]:
	"""
	Returns the graphs and the labels of `dataset`, using a pickle cache.

	The cache files are DATASET_PATH/<dataset>/graphs<N>.pkl and labels<N>.pkl.
	If both exist they are loaded; otherwise the dataset is built with
	get_graph_dataset and (except for the "3dshapes/" datasets) saved at this location.

	Parameters
	----------
	dataset : name of the dataset.
	N : suffix of the cache files.
	"""
	graphs_path = f"{DATASET_PATH}{dataset}/graphs{N}.pkl"
	labels_path = f"{DATASET_PATH}{dataset}/labels{N}.pkl"
	if exists(graphs_path) and exists(labels_path):
		# NOTE: unpickling can execute arbitrary code; only load cache files you generated yourself.
		with open(graphs_path, "rb") as graphs_file:
			graphs = pickle.load(graphs_file)
		with open(labels_path, "rb") as labels_file:
			labels = pickle.load(labels_file)
		return graphs, labels
	graphs, labels = get_graph_dataset(dataset,)
	if dataset.startswith("3dshapes/"):
		return graphs, labels
	print("Saving graphs at :", graphs_path)
	# N has to be forwarded, otherwise the cache is written at another path than the one looked up above.
	set_graphs(graphs = graphs, labels = labels, dataset = dataset, N = N)
	return graphs, labels


def set_graphs(graphs:list[nx.Graph], labels:list, dataset:str, N:int|str=""): # saves graphs (and filtration values) into a file
	"""
	Pickles `graphs` and `labels` at DATASET_PATH/<dataset>/graphs<N>.pkl and labels<N>.pkl.
	"""
	graphs_path = f"{DATASET_PATH}{dataset}/graphs{N}.pkl"
	labels_path = f"{DATASET_PATH}{dataset}/labels{N}.pkl"
	# Context managers ensure the files are flushed and closed, even if pickling fails.
	with open(graphs_path, "wb") as graphs_file:
		pickle.dump(graphs, graphs_file)
	with open(labels_path, "wb") as labels_file:
		pickle.dump(labels, labels_file)
	return

def reset_graphs(dataset:str, N=100): # Resets filtrations values on graphs
	"""
	Rebuilds `dataset` with get_graph_dataset (N is forwarded as a keyword argument)
	and overwrites its cache with set_graphs, using the default cache suffix.
	"""
	fresh_graphs, fresh_labels = get_graph_dataset(dataset, N=N)
	set_graphs(graphs=fresh_graphs, labels=fresh_labels, dataset=dataset)


############################################## Immuno 1.5mm
def get_immuno_regions():
	"""
	Loads the csv files found (recursively) in DATASET_PATH/1.5mmRegions/<label>/ for the
	labels "FoxP3", "CD8" and "CD68".

	Returns
	-------
	(X, labels), in a random order: X is the list of the tables (as arrays, divided by 1500),
	labels the array of their labels, encoded as integers by a LabelEncoder.
	"""
	from os.path import join
	X, labels = [],[]
	path = DATASET_PATH+"1.5mmRegions/"
	for label in ["FoxP3", "CD8", "CD68"]:
		for root, _, files in walk(path + label+"/"):
			for name in files:
				# `walk` is recursive: the file lives in `root`, which may be a subfolder.
				X.append(np.array(read_csv(join(root, name)))/1500)
				labels.append(label)
	labels = np.array(LabelEncoder().fit_transform(labels))
	# Same permutation for the data and the labels.
	p = np.random.permutation(len(labels))
	return [X[i] for i in p], labels[p]


########################################## LARGE IMMUNO
def get_immuno(i=1):
	"""
	Loads DATASET_PATH/LargeHypoxicRegion<i>.csv.

	Returns
	-------
	(X, Y, labels): the columns 'x' and 'y', each divided by its maximum, and the
	column 'Celltype' encoded as integers by a LabelEncoder.
	"""
	immu_dataset = read_csv(DATASET_PATH+f"LargeHypoxicRegion{i}.csv")
	# Floats are enforced: the in-place division below raises on integer arrays.
	X = np.array(immu_dataset['x'], dtype=float)
	X /= np.max(X)
	Y = np.array(immu_dataset['y'], dtype=float)
	Y /= np.max(Y)
	labels = LabelEncoder().fit_transform(immu_dataset['Celltype'])
	return X,Y, labels


############################################## UCR
def get_UCR_dataset(dataset = "Coffee", test = False):
	"""
	Loads DATASET_PATH/UCR/<dataset>/<dataset>_TRAIN.tsv (or _TEST.tsv if `test`).

	Returns
	-------
	(values, labels): the first column, encoded as integers by a LabelEncoder, gives the labels.
	NOTE(review): the values are the columns 1 to -1 (excluded), i.e., the last column
	is dropped -- confirm this is intended.
	"""
	suffix = "_TEST.tsv" if test else "_TRAIN.tsv"
	dataset_path = DATASET_PATH +"UCR/"+ dataset + "/" + dataset + suffix
	table = read_csv(dataset_path, delimiter='\t', header=None, index_col=None)
	data = np.array(table)
	values = data[:,1:-1]
	labels = LabelEncoder().fit_transform(data[:,0])
	return values, labels

################################################# Synthetic
#def circle_pole(n_circle:int=500,n_noise:int=100, low:float= 1, high:float=1.1, k:int=3, sigma:float=0.5, pinch:bool=False)->np.ndarray:
#	def pt()->np.ndarray:
#		n = np.random.normal(loc=0,scale=sigma)
#		r = np.sqrt(np.random.uniform(low = low, high = high**2)) - pinch*0.1/sigma*(1-np.abs(n))
#		θ = np.random.choice(range(k)) * 2*np.pi / k + n
#		return np.array([r*np.cos(θ), r* np.sin(θ)])
#	out = np.vstack([np.array([pt() for _ in range(n_circle)], dtype=float),np.random.uniform(low=-1.2, high=1.2, size=(n_noise,2))])
#	np.random.shuffle(out)
#	return out

################################################################################################# GRAPHS FILTRATION
def compute_ricci(graphs:list[nx.Graph], alpha=0.5, progress = 1):
	"""
	Applies OllivierRicci(graph, alpha=alpha).compute_ricci_curvature() to every graph,
	in parallel, and returns the list of the outputs.
	"""
	from GraphRicciCurvature.OllivierRicci import OllivierRicci
	def _ricci(g, alpha=alpha):
		curvature = OllivierRicci(g,alpha=alpha)
		return curvature.compute_ricci_curvature()
	progress_bar = tqdm(graphs, disable = not progress, desc="Computing ricci")
	return Parallel(n_jobs=-1)(delayed(_ricci)(g) for g in progress_bar)

def compute_cc(graphs:list[nx.Graph], progress = 1):
	"""
	Stores the closeness centrality of every graph in the attribute "cc" of its nodes;
	the "cc" of an edge is the maximum of the ones of its endpoints.
	Computed in parallel; returns the list of the annotated graphs.
	"""
	def _cc(g):
		centrality = nx.closeness_centrality(g)
		nx.set_node_attributes(g, centrality, "cc")
		edge_centrality = {}
		for u, v in g.edges:
			edge_centrality[(u, v)] = max(centrality[u], centrality[v])
		nx.set_edge_attributes(g, edge_centrality, "cc")
		return g
	progress_bar = tqdm(graphs, disable = not progress, desc="Computing cc")
	return Parallel(n_jobs=-1,)(delayed(_cc)(g) for g in progress_bar)
	# for g in tqdm(graphs, desc="Computing cc"):
	# 	_cc(g)
	# return graphs

def compute_degree(graphs:list[nx.Graph], progress=1):
	"""
	Stores the degree of the nodes of every graph in their attribute "degree";
	the "degree" of an edge is the maximum of the ones of its endpoints.
	Computed in parallel; returns the list of the annotated graphs.
	"""
	def _degree(g):
		node_degree = dict(g.degree)
		nx.set_node_attributes(g, node_degree, "degree")
		edge_degree = {}
		for u, v in g.edges:
			edge_degree[(u, v)] = max(node_degree[u], node_degree[v])
		nx.set_edge_attributes(g, edge_degree, "degree")
		return g
	progress_bar = tqdm(graphs, disable = not progress, desc="Computing degree")
	return Parallel(n_jobs=-1)(delayed(_degree)(g) for g in progress_bar)
	# for g in tqdm(graphs, desc="Computing degree"):
	# 	_degree(g)
	# return graphs

def compute_fiedler(graphs:list[nx.Graph], progress = 1): # TODO : make it compatible with non-connected graphs
	"""
	Stores the squared coordinates of the Fiedler vector of every graph in the attribute
	"fiedler" of its nodes; the "fiedler" of an edge is the maximum of the ones of its endpoints.
	Computed in parallel; returns the list of the annotated graphs.
	"""
	def _fiedler(g):
		fiedler = nx.fiedler_vector(g)**2
		# The vector has one coordinate per node, in the iteration order of the graph:
		# it is keyed by the nodes themselves, so that graphs whose nodes are not
		# 0, ..., n-1 (in this order) are handled as well.
		node_fiedler = dict(zip(g.nodes, fiedler))
		nx.set_node_attributes(g, node_fiedler, "fiedler")
		edge_fiedler = {(u,v):max(node_fiedler[u], node_fiedler[v]) for u,v in g.edges}
		nx.set_edge_attributes(g, edge_fiedler, "fiedler")
		return g
	graphs = Parallel(n_jobs=-1)(delayed(_fiedler)(g) for g in tqdm(graphs, disable = not progress, desc="Computing fiedler"))
	return graphs
	# for g in tqdm(graphs, desc="Computing fiedler"):
	# 	_fiedler(g)
	# return graphs


def compute_filtration(dataset:str, filtration:str, **kwargs):
	"""
	Computes a filtration on the (cached) graphs of `dataset`, and saves them back.

	Parameters
	----------
	dataset : name of the dataset.
	filtration : one of
		- "ALL" : resets the dataset, then computes "cc", "degree", "ricciCurvature" and "fiedler",
		- "cc", "degree", "ricciCurvature", "fiedler" : computes this filtration,
		- "dijkstra" : only loads the graphs, nothing is computed nor saved.
		Anything else emits a warning.
	kwargs : forwarded to get_graphs (e.g., N, the suffix of the cache files).
	"""
	# The graphs are saved where get_graphs reads them, i.e., with the same suffix.
	N = kwargs.get("N", "")
	if filtration == "ALL":
		reset_graphs(dataset)
		graphs,labels = get_graphs(dataset, **kwargs)
		for compute in (compute_cc, compute_degree, compute_ricci, compute_fiedler):
			graphs = compute(graphs)
		set_graphs(graphs=graphs, labels=labels, dataset=dataset, N=N)
		return
	graphs,labels = get_graphs(dataset, **kwargs)
	if filtration == "dijkstra":
		return
	available_filtrations = {
		"cc": compute_cc,
		"degree": compute_degree,
		"ricciCurvature": compute_ricci,
		"fiedler": compute_fiedler,
	}
	compute = available_filtrations.get(filtration)
	if compute is None:
		warn(f"Filtration {filtration} not implemented !")
		return
	graphs = compute(graphs)
	set_graphs(graphs=graphs, labels=labels, dataset=dataset, N=N)
	return

##############################################################################################


def get_dijkstra(data:Data, node:int, backend="graph_tool")->np.ndarray:
	"""
	Computes the distance from `node` to every node of `data`.

	Parameters
	----------
	data : torch_geometric data; its edges (`edge_index`) are computed from its faces if needed.
	node : index of the source node.
	backend :
		- "graph_tool" : shortest path distances on the graph (tends to segfault),
		- "networkx" : shortest path distances on the graph; the nodes which cannot be
		  reached from `node` are at distance inf,
		- "torch_geometric" : geodesic distances on the mesh (uses `data.pos` and `data.face`).

	Returns
	-------
	The array of the distances, indexed by the nodes.

	Raises
	------
	ValueError if the backend is unknown.
	"""
	# NOTE(review): the attribute tested is 'edges', not 'edge_index' -- confirm this is intended.
	if not hasattr(data, 'edges'):
		FaceToEdge(remove_faces=False)(data)
	edges = data.edge_index.numpy().T
	if backend == "graph_tool": #TENDS TO SEGFAULT
		from graph_tool.topology import shortest_distance

		num_nodes = len(data.pos)
		# Nodes are indexed from 0 to num_nodes-1.
		assert node < num_nodes
		g = gt.Graph()
		g.add_vertex(n=num_nodes)
		g.add_edge_list(edges)
		distances = shortest_distance(g=g, source=node).get_array()
		return distances
	# if backend == "igraph":
	# 	g = ig.Graph(edges = data.T)
	# 	return g.distances(source=node)
	if backend == "networkx":
		num_nodes = data.num_nodes
		g = nx.Graph()
		# Isolated nodes do not appear in the edges: they have to be added explicitly,
		# for the output to be indexed by the nodes of data.
		g.add_nodes_from(range(num_nodes))
		g.add_edges_from(edges)
		# shortest_path_length only returns the reachable nodes.
		distances = np.full(num_nodes, np.inf)
		for reached_node, distance in nx.shortest_path_length(g, node).items():
			distances[reached_node] = distance
		return distances

	if backend == "torch_geometric":
		from torch import Tensor
		from torch_geometric.utils import geodesic_distance
		return geodesic_distance(pos=data.pos, face=data.face,src=Tensor([node])).numpy()[0]
	raise ValueError(f"Cannot find backend {backend}")



################################################################################################# WRAPPERS
from types import FunctionType
def get_simplextree(x)->gd.SimplexTree:
	"""
	Retrieves a simplextree from x, which is either
	 - a simplextree, returned as is,
	 - a function, which is called without argument,
	 - a triple (function, args, kwargs), evaluated as function(*args, **kwargs).

	Raises a TypeError otherwise.
	"""
	if isinstance(x, gd.SimplexTree):
		simplextree = x
	elif isinstance(x, FunctionType):
		simplextree = x()
	elif len(x) == 3 and isinstance(x[0],FunctionType):
		builder, positional, keywords = x
		simplextree = builder(*positional, **keywords)
	else:
		raise TypeError("Not a valid SimplexTree")
	return simplextree
def get_simplextrees(X)->Iterable[gd.SimplexTree]:
	"""
	Retrieves simplextrees from X, which is either
	 - a pair (function, data): a generator of function(x), for x in data, is returned,
	 - a sequence of simplextrees (only the first element is checked), returned as is;
	   an empty X gives an empty list.

	Raises a TypeError otherwise.
	"""
	is_delayed = len(X) == 2 and isinstance(X[0], FunctionType)
	if is_delayed:
		to_st, data = X
		return (to_st(x) for x in data)
	if len(X) == 0:
		return []
	if isinstance(X[0], gd.SimplexTree):
		return X
	raise TypeError
	



############## INTERVALS (for sliced wasserstein)
class Graph2SimplexTree(BaseEstimator,TransformerMixin):
	"""
	Turns networkx graphs into gudhi simplextrees: nodes and edges are inserted with
	the value of their attribute `f` as a filtration value.

	Parameters
	----------
	f : name of the attribute (of the nodes and the edges) holding the filtration.
	dtype : if None, nothing is computed: `transform` returns [function, graphs],
		to delay the computation in the pipe (for parallelism).
	reverse_filtration : stored, but not used by this class (TODO).
	"""
	def __init__(self, f:str="ricciCurvature",dtype=gd.SimplexTree, reverse_filtration:bool=False):
		super().__init__()
		self.f=f # filtration to search in graph
		self.dtype = dtype # If None, will delay the computation in the pipe (for parallelism)
		self.reverse_filtration = reverse_filtration # reverses the filtration #TODO
	def fit(self, X, y=None):
		# Nothing to learn.
		return self
	def transform(self,X:list[nx.Graph]):
		def graph2simplextree(graph, f=self.f) -> gd.SimplexTree: # TODO : use batch insert
			simplextree = gd.SimplexTree()
			for node, attributes in graph.nodes(data=True):
				simplextree.insert([node], attributes[f])
			for u, v, attributes in graph.edges(data=True):
				simplextree.insert([u,v], attributes[f])
			return simplextree
		if self.dtype is None:
			return [graph2simplextree, X]
		return Parallel(n_jobs=-1, prefer="threads")(delayed(graph2simplextree)(graph) for graph in X)


class PointCloud2SimplexTree(BaseEstimator,TransformerMixin):
	"""
	Turns point clouds into the simplextrees of their alpha complexes.

	Parameters
	----------
	delayed : if True, nothing is computed: `transform` returns [function, point_clouds],
		to delay the computation in the pipe. None is also accepted, for backward compatibility.
	threshold : the complexes are built with max_alpha_square = threshold**2.
		If negative, it is replaced in `fit` by the largest diameter of the fitted point clouds.
	"""
	def __init__(self, delayed:bool = False, threshold = np.inf):
		super().__init__()
		self.delayed = delayed
		self.threshold=threshold
	@staticmethod
	def _get_point_cloud_diameter(x):
		# Largest distance between two points of x.
		from scipy.spatial import distance_matrix
		return np.max(distance_matrix(x,x))
	def fit(self, X, y=None):
		if self.threshold < 0:
			self.threshold = max(self._get_point_cloud_diameter(x) for x in X)
		return self
	def transform(self,X):
		def todo(point_cloud) -> gd.SimplexTree:
			st = gd.AlphaComplex(points=point_cloud).create_simplex_tree(max_alpha_square = self.threshold**2)
			return st
		# The flag used to be compared to None, which never holds for a boolean: `delayed=True` was ignored.
		# None is still understood as "delayed", as it was the only value triggering it.
		# In this method, the name `delayed` is joblib's; the flag is `self.delayed`.
		if self.delayed is None or self.delayed:
			return [todo, X]
		return Parallel(n_jobs=-1, prefer="threads")(delayed(todo)(point_cloud) for point_cloud in X)



#################### FILVEC
def get_filtration_values(g:nx.Graph, f:str)->np.ndarray:
	"""
	Returns the values of the attribute `f` of the nodes of g, followed by the ones of its edges.
	"""
	node_values = [attributes[f] for _, attributes in g.nodes(data=True)]
	edge_values = [attributes[f] for _, _, attributes in g.edges(data=True)]
	return np.array(node_values + edge_values)
def graph2filvec(g:nx.Graph, f:str, range:tuple, bins:int)->np.ndarray:
	"""
	Histogram (counts only) of the values of the attribute `f` on the nodes and the edges
	of g, with `bins` bins on the interval `range`.
	"""
	counts, _ = np.histogram(get_filtration_values(g, f), bins=bins, range=range)
	return counts
class FilvecGetter(BaseEstimator, TransformerMixin):
	"""
	Turns graphs into the histograms of the values of their attribute `f` (nodes and edges).

	Parameters
	----------
	f : name of the attribute holding the filtration.
	quantile : the histograms range from the quantile `quantile` to the quantile
		`1-quantile` of the values seen in `fit`.
	bins : number of bins of the histograms.
	n_jobs : number of jobs.
	"""
	def __init__(self, f:str="ricciCurvature",quantile:float=0., bins:int=100, n_jobs:int=1):
		super().__init__()
		self.f=f
		self.quantile=quantile
		self.bins=bins
		self.range:tuple[float]|None=None # set by fit
		self.n_jobs=n_jobs
	def fit(self, X, y=None):
		values_per_graph = Parallel(n_jobs=self.n_jobs)(delayed(get_filtration_values)(g,f=self.f) for g in X)
		all_values = np.concatenate(values_per_graph)
		bounds = np.quantile(all_values, [self.quantile, 1-self.quantile])
		self.range= tuple(bounds)
		return self
	def transform(self,X):
		if self.range is None:
			print("Fit first")
			return
		histogram_kwargs = dict(f=self.f, range=self.range, bins=self.bins)
		return Parallel(n_jobs=self.n_jobs)(delayed(graph2filvec)(g, **histogram_kwargs) for g in X)




############# Filvec from SimplexTree
# Input: a simplextree; output: the histogram of the filtration values of its simplices
def simplextree2hist(simplextree, range:tuple[float, float], bins:int, density:bool)->np.ndarray: #TODO : Anything to histogram
	"""
	Histogram of the filtration values of the simplices of `simplextree`, with `bins`
	bins on the interval `range`. `density` is forwarded to np.histogram.
	Only the values of the histogram are returned (not the edges of the bins).
	"""
	values = np.array([filtration for _, filtration in simplextree.get_simplices()])
	histogram, _ = np.histogram(values, bins=bins, range=range, density=density)
	return histogram
class SimplexTree2Histogram(BaseEstimator, TransformerMixin):
	"""
	Turns simplextrees into the histograms of their filtration values.

	The input X is either a sequence of simplextrees, or a pair (function, data) such
	that the simplextrees are (function(x) for x in data).

	Parameters
	----------
	quantile : the histograms range from the quantile `quantile` to the quantile
		`1-quantile` of the (finite) filtration values seen in `fit`.
	bins : number of bins of the histograms.
	n_jobs : number of jobs; only used when the simplextrees are given as (function, data).
	progress : shows a progress bar.
	density : forwarded to np.histogram.
	"""
	def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1, progress:bool=False, density:bool=True):
		super().__init__()
		self.range:np.ndarray | None=None # set by fit
		self.quantile:float=quantile
		self.bins:int=bins
		self.n_jobs=n_jobs
		self.density=density
		self.progress = progress
		# self.max_dimension=None # TODO: maybe use it
	@staticmethod
	def _split_input(X):
		# Returns (to_st, data); to_st is None if data already holds simplextrees.
		if isinstance(X[0], gd.SimplexTree):
			return None, X
		to_st, data = X # asserts len(X) == 2
		return to_st, data
	def fit(self, X, y=None):
		if len(X) == 0:	return self
		to_st, data = self._split_input(X)
		if to_st is None:
			to_st = lambda x : x
		persistence_values = np.array([f for x in data for _,f in to_st(x).get_simplices()])
		persistence_values = persistence_values[persistence_values<np.inf]
		self.range = np.quantile(persistence_values, [self.quantile, 1-self.quantile])
		return self
	def transform(self,X):
		# No data, no histogram (the estimator itself used to be returned here).
		if len(X) == 0:	return []
		to_st, data = self._split_input(X)
		histogram_kwargs = dict(range=self.range, bins=self.bins, density=self.density)
		if to_st is None:
			if self.n_jobs > 1:
				warn("Cannot pickle simplextrees, reducing to 1 thread to compute the simplextrees")
			return [simplextree2hist(st, **histogram_kwargs) for st in tqdm(data, desc="Computing diagrams", disable=not self.progress)]
		# The simplextrees are built by the workers, as they cannot be pickled.
		def pickle_able_todo(x, **kwargs):
			simplextree = to_st(x)
			return simplextree2hist(simplextree=simplextree, **kwargs)
		return Parallel(n_jobs=self.n_jobs)(delayed(pickle_able_todo)(x, **histogram_kwargs) for x in tqdm(data, desc="Computing simplextrees and their diagrams", disable=not self.progress))




############# PERVEC
# Input list of [list of diagrams], outputs histogram of persistence values (x and y coord mixed) 
def dgm2pervec(dgms, range:tuple[float, float], bins:int)->np.ndarray: #TODO : Anything to histogram
	"""
	Histogram of all the coordinates (births and deaths mixed) of the diagrams `dgms`,
	with `bins` bins on the interval `range`.
	Only the values of the histogram are returned (not the edges of the bins).
	"""
	coordinates = np.concatenate([dgm.ravel() for dgm in dgms])
	histogram, _ = np.histogram(coordinates, bins=bins, range=range)
	return histogram
class Dgm2Histogram(BaseEstimator, TransformerMixin):
	"""
	Turns lists of diagrams into the histograms of their coordinates (see dgm2pervec).

	Parameters
	----------
	quantile : the histograms range from the quantile `quantile` to the quantile
		`1-quantile` of the coordinates (below inf) seen in `fit`.
	bins : number of bins of the histograms.
	n_jobs : number of jobs.
	"""
	def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1):
		super().__init__()
		self.range:np.ndarray | None=None # set by fit
		self.quantile:float=quantile
		self.bins:int=bins
		self.n_jobs=n_jobs
	def fit(self, X, y=None): # X:list[diagrams]
		flat_diagrams = [dgm.flatten() for dgms in X for dgm in dgms]
		coordinates = np.concatenate(flat_diagrams, axis=0).flatten()
		bounded = coordinates[coordinates<np.inf]
		self.range = np.quantile(bounded, [self.quantile, 1-self.quantile])
		return self
	def transform(self,X):
		histogram_kwargs = dict(range=self.range, bins=self.bins)
		return Parallel(n_jobs=self.n_jobs)(delayed(dgm2pervec)(dgms, **histogram_kwargs) for dgms in X)







################# SignedMeasureImage
class Dgms2SignedMeasureImage(BaseEstimator, TransformerMixin):
	"""
	Turns lists of diagrams (one per degree) into vectors: for each degree, a kernel
	density is fitted on the births, another on the deaths, and the difference of their
	scores (`score_samples`) is evaluated on a regular grid. The vectors of the degrees
	are concatenated.

	Parameters
	----------
	ranges : the grids, one per degree. Overwritten by `fit`.
	resolution : number of points of each grid.
	quantile : each grid ranges from the quantile `quantile` to the quantile `1-quantile`
		of the finite coordinates, of this degree, seen in `fit`.
	bandwidth, kernel : parameters of the kernel density estimators.
	"""
	def __init__(self, ranges:None|Iterable[Iterable[float]]=None, resolution:int=100, quantile:float=0, bandwidth:float=1, kernel:str="gaussian") -> None:
		super().__init__()
		self.ranges=ranges
		self.resolution=resolution
		self.quantile = quantile
		self.bandwidth = bandwidth
		self.kernel = kernel
	def fit(self, X, y=None): # X:list[diagrams]
		num_degrees = len(X[0])
		persistence_values = [np.concatenate([dgms[i].flatten() for dgms in X], axis=0) for i in range(num_degrees)] # values per degree
		persistence_values = [degrees_values[(-np.inf<degrees_values) * (degrees_values<np.inf)] for degrees_values in persistence_values] # non-trivial values
		quantiles = [np.quantile(degree_values, [self.quantile, 1-self.quantile]) for degree_values in persistence_values] # quantiles
		# Each grid is a column (resolution x 1), as expected by score_samples.
		self.ranges = np.array([np.linspace(start=[a], stop=[b], num=self.resolution) for a,b in quantiles])
		return self

	def _density_scores(self, values:np.ndarray, grid:np.ndarray)->np.ndarray:
		# Scores, on the grid, of the kernel density fitted on the values (a column).
		estimator = KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel)
		return estimator.fit(values).score_samples(grid)

	def _dgm2smi(self, dgms:Iterable[np.ndarray]):
		# The same kernel is used for the births and the deaths
		# (the deaths used to silently fall back to the default kernel).
		smi = np.concatenate(
				[
					self._density_scores(dgm[:,[0]], grid) - self._density_scores(dgm[:,[1]], grid)
					for dgm, grid in zip(dgms, self.ranges)
				],
			axis=0)
		return smi

	def transform(self,X): # X is a list (data) of list of diagrams
		assert self.ranges is not None
		out = Parallel(n_jobs=1, prefer="threads")(
			delayed(self._dgm2smi)(dgms=dgms)
			for dgms in X
			)

		return out



################# SignedMeasureHistogram
class Dgms2SignedMeasureHistogram(BaseEstimator, TransformerMixin):
	"""
	Turns lists of diagrams (one per degree) into vectors: for each degree, the histogram
	of the deaths is subtracted from the histogram of the births. The vectors of the
	degrees are concatenated.

	Parameters
	----------
	ranges : the interval of the histograms, one per degree. Overwritten by `fit`.
	bins : number of bins of each histogram.
	quantile : each interval ranges from the quantile `quantile` to the quantile
		`1-quantile` of the finite coordinates, of this degree, seen in `fit`.
	"""
	def __init__(self, ranges:None|list[tuple[float,float]]=None, bins:int=100, quantile:float=0) -> None:
		super().__init__()
		self.ranges=ranges
		self.bins=bins
		self.quantile = quantile
	def fit(self, X, y=None): # X:list[diagrams]
		num_degrees = len(X[0])
		bounds = [self.quantile, 1-self.quantile]
		ranges = []
		for degree in range(num_degrees):
			values = np.concatenate([dgms[degree].flatten() for dgms in X], axis=0) # values of this degree
			values = values[(-np.inf<values) * (values<np.inf)] # non-trivial values
			ranges.append(np.quantile(values, bounds))
		self.ranges = ranges
		return self
	def _signed_histogram(self, dgm:np.ndarray, interval)->np.ndarray:
		# Histogram of the births minus histogram of the deaths.
		births, _ = np.histogram(dgm[:,0], bins=self.bins,range=interval)
		deaths, _ = np.histogram(dgm[:,1], bins=self.bins,range=interval)
		return births - deaths
	def transform(self,X): # X is a list (data) of list of diagrams
		assert self.ranges is not None
		out = []
		for dgms in X:
			per_degree = [self._signed_histogram(dgm, interval) for dgm, interval in zip(dgms, self.ranges)]
			out.append(np.concatenate(per_degree))
		return out








################## Signed Measure Kernel 1D
# input : list of [list of diagrams], outputs: the kernel to feed to an svm

# TODO : optimize ?
## TODO : np.triu
class Dgms2SignedMeasureDistance(BaseEstimator, TransformerMixin):
	"""
	Computes, for each degree, the matrix of the distances between the diagrams given to
	`transform` (rows) and the ones given to `fit` (columns). Diagrams are seen as
	signed measures (birth = +, death = -).

	Parameters
	----------
	n_jobs : number of jobs.
	distance_matrix_path : if not None, the matrix of the degree d is cached in the file
		"<distance_matrix_path>_<d>": it is loaded if the file exists, computed and saved otherwise.
	progress : shows a progress bar.
	"""
	def __init__(self, n_jobs:int=1, distance_matrix_path:str|None=None, progress:bool = False) -> None:
		super().__init__()
		self.degrees:list[int]|None=None
		self.X:None|list[np.ndarray] = None
		self.n_jobs=n_jobs
		self.distance_matrix_path = distance_matrix_path
		self.progress=progress
	def fit(self, X:list[np.ndarray], y=None):
		if len(X) <= 0:
			warn("Fit a nontrivial vector")
			return self # fit has to return the estimator, for fit(...).transform(...) to work
		self.X = X
		self.degrees = list(range(len(X[0]))) # Assumes that all x \in X have the same number of diagrams
		return self

	@staticmethod
	def wasserstein_1(a:np.ndarray,b:np.ndarray)->float:
		# a and b have to be of the same size.
		return np.abs(np.sort(a) - np.sort(b)).mean() # norm 1
	@staticmethod
	def OSWdistance(mu:list[np.ndarray], nu:list[np.ndarray], dim:int)->float:
		# Both sides hold len(mu[dim]) + len(nu[dim]) points.
		return Dgms2SignedMeasureDistance.wasserstein_1(np.hstack([mu[dim][:,0], nu[dim][:,1]]), np.hstack([nu[dim][:,0], mu[dim][:,1]])) # TODO : check: do we want to sum the kernels or the distances ? add weights ?
	@staticmethod
	def _ds(mu:list[np.ndarray], nus:list[list[np.ndarray]], dim:int): # mu and nu are lists of diagrams seen as signed measures (birth = +, death = -)
		return [Dgms2SignedMeasureDistance.OSWdistance(mu,nu, dim) for nu in nus]

	def _distance_matrix(self, X, degree:int, **parallel_kwargs)->np.ndarray:
		# Distances, in degree `degree`, between the elements of X (rows) and the fitted ones (columns).
		with tqdm(X, desc=f"Computing distance matrix of degree {degree}", disable=not self.progress) as diagrams_iterator:
			rows = Parallel(n_jobs=self.n_jobs, **parallel_kwargs)(delayed(self._ds)(mu, self.X, degree) for mu in diagrams_iterator)
		return np.array(rows)

	def transform(self,X): # X is a list (data) of list of diagrams
		if self.X is None or self.degrees is None:
			warn("Fit first !")
			return np.array([[]])
		# Cannot use sklearn / scipy, measures don't have the same size, -> no numpy array
		distances_matrices = []
		for degree in self.degrees:
			if self.distance_matrix_path is None:
				distances_matrices.append(self._distance_matrix(X, degree, prefer="threads"))
				continue
			matrix_path = f"{self.distance_matrix_path}_{degree}"
			if exists(matrix_path):
				with open(matrix_path, "rb") as matrix_file:
					distance_matrix = np.load(matrix_file)
			else:
				distance_matrix = self._distance_matrix(X, degree)
				# A file object is given to np.save, to keep the file name as is (no ".npy" appended).
				with open(matrix_path, "wb") as matrix_file:
					np.save(matrix_file, distance_matrix)
			distances_matrices.append(distance_matrix)
		return np.asarray(distances_matrices)
		# kernels = [np.exp(-distance_matrix / (2*self.sigma**2)) for distance_matrix in distances_matrices]
		# return np.sum(kernels, axis=0)
		# kernels = [np.exp(-distance_matrix / (2*self.sigma**2)) for distance_matrix in distances_matrices]
		# return np.sum(kernels, axis=0)




## To do k folds with a distance matrix, we need to slice it into list of distances. 
# k-fold usually shuffles the lists, so we need to add an identifier to each entry, 
# 
class DistanceMatrix2DistanceList(BaseEstimator, TransformerMixin):
	"""
	Prepends to each row of a (distance) matrix the index of this row, so that the rows
	can be shuffled (e.g., by a k-fold) and still be identified.
	"""
	def __init__(self) -> None:
		super().__init__()
	def fit(self, X, y=None):
		# Nothing to learn.
		return self
	def transform(self,X):
		matrix = np.asarray(X)
		assert matrix.ndim == 2 ## Its a matrix
		indexed_rows = []
		for index, row in enumerate(matrix):
			indexed_rows.append([index, *row])
		return np.asarray(indexed_rows)


class DistanceList2DistanceMatrix(BaseEstimator, TransformerMixin):
	"""
	Input: rows whose first entry is the index of a point, followed by its distances.
	Output: the columns of X holding the distances to the points indexed by its rows.
	"""
	def __init__(self) -> None:
		super().__init__()
	def fit(self, X, y=None):
		# Nothing to learn.
		return self
	def transform(self,X):
		point_indices = np.asarray(X[:,0], dtype=int)
		# shift of 1, because the first column holds the indices of the points
		columns = point_indices + 1
		return X[:, columns] ## The distance matrix of the indexed points
	

class DistanceMatrices2DistancesList(BaseEstimator, TransformerMixin):
	"""
	Input (degree) x (distance matrix) or (axis) x (degree) x (distance matrix D)
	Output (D1) x opt (axis) x (degree) x (D2, with indices first)
	"""
	def __init__(self) -> None:
		super().__init__()
		self._axes=None # True iff the input has an (axis) dimension; set by fit
	def fit(self, X, y=None):
		X = np.asarray(X)
		self._axes = X.ndim ==4
		assert self._axes or X.ndim == 3, " Bad input shape. Input is either (degree) x (distance matrix) or (axis) x (degree) x (distance matrix) "

		return self
	def transform(self, X):
		X = np.asarray(X)
		assert (X.ndim == 3 and not self._axes) or (X.ndim == 4 and self._axes), f"X shape ({X.shape}) is not valid"
		to_list = lambda matrix : DistanceMatrix2DistanceList().fit_transform(matrix)
		if not self._axes:
			out = np.array([to_list(M) for M in X]) ## indices are at [:,0,Any_coord]
			# (degree, D1, D2) -> (D1, degree, D2)
			return np.transpose(out, (1,0,2))
		out = np.asarray([[to_list(M) for M in matrices_in_axes] for matrices_in_axes in X])
		# (axis, degree, D1, D2) -> (D1, axis, degree, D2)
		return np.transpose(out, (2,0,1,3))


	def predict(self,X):
		return self.transform(X)

class DistancesLists2DistanceMatrices(BaseEstimator, TransformerMixin):
	"""
	Input (D1) x opt (axis) x (degree) x (D2 with indices first)
	Output opt (axis) x (degree) x (distance matrix (D1,D2))
	"""
	def __init__(self) -> None:
		super().__init__()
		self.train_indices = None # indices of the points seen in fit
		self._axes = None # True iff the input has an (axis) dimension; set by fit
	def fit(self, X:np.ndarray, y=None):
		X = np.asarray(X)
		assert X.ndim in [3,4]
		self._axes = X.ndim == 4
		# The index of a point is the first entry of its lists of distances.
		first_entries = X[:,0,0,0] if self._axes else X[:,0,0]
		self.train_indices = np.asarray(first_entries, dtype=int)
		return self
	def transform(self,X):
		X = np.asarray(X)
		assert X.ndim in [3,4]
		# First coord of X is test indices by design, train indices have to be selected in the last coord.
		# shift of 1, because the first entry is for indexing the pts
		columns = self.train_indices+1
		if not self._axes:
			Y = X[:,:,columns] ## we only keep the good indices
			# (D1, degree, D2) -> (degree, D1, D2): we put back the degree axis first
			return np.transpose(Y, (1,0,2))
		Y = X[:,:,:,columns]
		# (D1, axis, degree, D2) -> (axis, degree, D1, D2)
		return np.transpose(Y, (1,2,0,3))
		
		# # out = np.moveaxis(Y,-1,0) ## we put back the degree axis first
		# return out
	


class DistanceMatrix2Kernel(BaseEstimator, TransformerMixin):
	"""
	Input : (degree) x (distance matrix) or (axis) x (degree) x (distance matrix) in the second case, axis HAS to be specified (meant for cross validation)
	Output : kernel of the same shape of distance matrix, i.e., the mean over the degrees of weight * exp(-distance_matrix / (2*sigma**2))

	Parameters
	 - sigma : bandwidth of the kernel. NOTE(review): the annotation allows an iterable, but the formula only works for values supporting `**` (scalar or numpy array); a plain list raises.
	 - axis : which axis of a 4-dimensional input to use. If None and the input has a single axis, axis 0 is used.
	 - weights : one weight per degree, or a single number used for every degree.

	The constructor parameters are never modified by fit : what fit resolves is stored in `_axis`, `_weights` and `_num_degrees`,
	so the same estimator can be refitted on data with another number of degrees, and get_params keeps returning the user values.
	"""
	def __init__(self, sigma:float|Iterable[float]=1, axis:int|None=None, weights:Iterable[float]|float=1) -> None:
		super().__init__()
		self.sigma = sigma
		self.axis=axis
		self.weights = weights
		# Fitted state, (re)computed by each call to fit
		self._num_degrees = None
		self._axis = None
		self._weights = None
	def fit(self, X, y=None):
		"""Checks the shape of X and resolves the axis to use and the per-degree weights."""
		if len(X) == 0: return self
		X = np.asarray(X)
		assert X.ndim in [3,4], "Bad input."
		axis = self.axis
		if axis is None:
			assert X.ndim ==3 or X.shape[0] == 1, "Set an axis for data with axis !"
			if X.shape[0] == 1 and X.ndim == 4:
				axis = 0 # a single axis : no ambiguity
		else:
			assert X.ndim ==4, "Cannot choose axis from data with no axis !"
		num_degrees = len(X) if axis is None else len(X[axis])
		weights = self.weights
		if isinstance(weights, (int, float, np.number)): # np.number : numpy scalars are single weights too
			weights = [weights]*num_degrees
		assert len(weights) == num_degrees, f"Number of weights ({len(weights)}) has to be the same as the number of degrees ({num_degrees})"
		self._axis = axis
		self._weights = weights
		self._num_degrees = num_degrees
		return self
	def transform(self,X)->np.ndarray:
		"""Returns the mean over the degrees of the weighted kernels of the distance matrices of X."""
		weights = getattr(self, "_weights", None)
		if weights is None:
			# Not fitted (or fitted on empty data) : fall back to the raw parameters
			axis, weights = self.axis, self.weights
		else:
			axis = self._axis
		if axis is not None:
			X=X[axis]
		kernels = np.asarray([np.exp(-distance_matrix / (2*self.sigma**2))*weight for distance_matrix, weight in zip(X, weights)])
		return np.mean(kernels, axis=0)


## Wrapper for SW, in order to take as an input a list of (list of diagrams)
class Dgms2SWK(BaseEstimator, TransformerMixin):
	"""
	Sliced Wasserstein kernel on lists of diagrams.
	Input : list of (list of diagrams), every inner list having the same number of diagrams (one per degree).
	Output : sum over the degrees of exp(-SW_distance_matrix / (2*bandwidth**2)).

	Parameters
	 - num_directions, n_jobs : given to each SlicedWassersteinDistance.
	 - bandwidth : bandwidth of the gaussian kernel.
	 - distance_matrix_path : if not None, the distance matrix of degree i is cached in the file f"{distance_matrix_path}_{i}".
	   NOTE(review): the cache only depends on the path and the degree, not on the data given to transform; an existing file is reused as is.
	 - progress : stored, not used by this class.
	"""
	def __init__(self, num_directions:int=10, bandwidth:float=1.0, n_jobs:int=1, distance_matrix_path:str|None = None, progress:bool = False) -> None:
		super().__init__()
		self.num_directions:int=num_directions
		self.bandwidth:float = bandwidth
		self.n_jobs=n_jobs
		self.SW_:list = []
		self.distance_matrix_path = distance_matrix_path
		self.progress = progress
	def fit(self, X:list[list[np.ndarray]], y=None):
		"""Fits one SlicedWassersteinDistance per degree."""
		# Assumes that all x \in X have the same size
		self.SW_ = [
			SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs = self.n_jobs).fit([dgms[i] for dgms in X])
			for i in range(len(X[0]))
		]
		return self
	def _distance_matrix(self, i:int, X)->np.ndarray:
		"""Distance matrix of degree i, read from the cache file if it exists, computed (and cached if a path is set) otherwise."""
		if self.distance_matrix_path is None:
			return self.SW_[i].transform([dgms[i] for dgms in X])
		SW_i_path = f"{self.distance_matrix_path}_{i}"
		if exists(SW_i_path):
			with open(SW_i_path, "rb") as cache_file:
				return np.load(cache_file)
		distance_matrix = self.SW_[i].transform([dgms[i] for dgms in X])
		# Saving through a file object keeps the exact file name (no ".npy" appended), which the `exists` check relies on.
		with open(SW_i_path, "wb") as cache_file:
			np.save(cache_file, distance_matrix)
		return distance_matrix
	def transform(self,X)->np.ndarray:
		"""Returns the sum over the degrees of the gaussian kernels of the sliced Wasserstein distance matrices."""
		distance_matrices = [self._distance_matrix(i, X) for i in range(len(self.SW_))]
		kernels = [np.exp(-distance_matrix / (2*self.bandwidth**2)) for distance_matrix in distance_matrices]
		return np.sum(kernels, axis=0) # TODO fix this, we may want to sum the distances instead of the kernels. 


class Dgms2SlicedWassersteinDistanceMatrices(BaseEstimator, TransformerMixin):
	"""
	Input : list of (list of diagrams), one diagram per degree.
	Output : array of the sliced Wasserstein distance matrices, one per degree (degree is the first axis).
	"""
	def __init__(self, num_directions:int=10, n_jobs:int=1) -> None:
		super().__init__()
		self.num_directions:int=num_directions
		self.n_jobs=n_jobs
		self.SW_:list = []

	def fit(self, X:list[list[np.ndarray]], y=None):
		"""Creates, then fits, one SlicedWassersteinDistance per degree."""
		# The number of degrees is read on the first element : all the elements of X are assumed to have the same size
		num_degrees = len(X[0])
		self.SW_ = []
		for _ in range(num_degrees):
			self.SW_.append(SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs = self.n_jobs))
		for degree in range(num_degrees):
			diagrams = [dgms[degree] for dgms in X]
			self.SW_[degree] = self.SW_[degree].fit(diagrams)
		return self

	@staticmethod
	def _get_distance(diagrams, SWD):
		"""Distance matrix of the diagrams, w.r.t. the ones SWD was fitted on."""
		return SWD.transform(diagrams)

	def transform(self,X):
		"""Computes the distance matrix of each degree, one joblib task per degree."""
		tasks = (
			delayed(self._get_distance)([dgms[degree] for dgms in X], swd)
			for degree, swd in enumerate(self.SW_)
		)
		return np.asarray(Parallel(n_jobs = self.n_jobs)(tasks))



# Gudhi simplexTree to list of diagrams
class SimplexTree2Dgm(BaseEstimator, TransformerMixin):
	"""
	Gudhi simplextrees to lists of persistence diagrams.
	Input : list of gd.SimplexTree, or a pair (to_st, dataset) where to_st(x) returns the simplextree of x, for x in dataset.
	Output : for each simplextree, a list of diagrams : one per degree for standard persistence, a single one for extended persistence.

	Parameters
	 - degrees : homological degrees. Default : [0], or every degree reached by `extended`.
	 - extended : falsy for standard persistence. Otherwise, a list of ints selecting the extended diagrams to keep
	   (4 diagrams per dimension, eg., [0,2,5,7] is Ord0, Ext+0, Rel1, Ext-1); any other truthy value means [0,2,5,7].
	 - n_jobs, progress : number of joblib threads, and progress bar.
	 - threshold : the diagram coordinates are clipped to [-threshold, threshold]. If <= 0, fit replaces it by the max absolute filtration value.
	"""
	# The default of `extended` is a list, but it is only read, never modified.
	def __init__(self, degrees:list[int]|None = None, extended:list[int]|bool=[], n_jobs=1, progress:bool=False, threshold:float=np.inf) -> None:
		super().__init__()
		if not extended:
			self.extended:list[int]|bool = False
		elif isinstance(extended, list):
			self.extended = extended
		else:
			self.extended = [0,2,5,7]
		if degrees:
			self.degrees:list[int] = degrees
		elif self.extended:
			self.degrees = list(range((max(self.extended) // 4)+1))
		else:
			self.degrees = [0]
		self.n_jobs=n_jobs 
		self.progress = progress # progress bar
		self.threshold = threshold # Threshold value
	def fit(self, X:list[gd.SimplexTree], y=None):
		"""If the threshold is not positive, sets it to the max absolute filtration value of the simplextrees of X."""
		if self.threshold <= 0:
			self.threshold = max( (abs(f) for simplextree in get_simplextrees(X) for s,f in simplextree.get_simplices()) )  ## MAX FILTRATION VALUE
			print(f"Setting threshold to {self.threshold}.")
		return self
	def transform(self,X:list[gd.SimplexTree]):
		"""Computes the diagrams of each simplextree. Returns [] if X is empty."""
		if len(X) == 0:	return [] # nothing to compute, and X[0] below would raise
		def reshape(dgm:np.ndarray|list)->np.ndarray:
			# Diagram as a (n,2) array, clipped to [-threshold, threshold]
			out = np.array(dgm) if len(dgm) > 0 else np.empty((0,2)) 
			if self.threshold != np.inf:
				out[out>self.threshold] = self.threshold
				out[out<-self.threshold] = -self.threshold
			return out
		def todo_standard(st):
			st.compute_persistence()
			return [reshape(st.persistence_intervals_in_dimension(d)) for d in self.degrees]
		def todo_extended(st):
			st.extend_filtration()
			dgms = st.extended_persistence()
			# Diagram j, dimension d is kept iff d is in degrees and j+4*d in extended; the kept bars are merged in a single diagram
			return [reshape([bar for j,dgm in enumerate(dgms) for d, bar in dgm if d in self.degrees and j+4*d in self.extended])]
		todo = todo_extended if self.extended else todo_standard

		if isinstance(X[0],gd.SimplexTree): # simplextree aren't pickleable, no parallel
			return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in tqdm(X, disable=not self.progress, desc="Computing diagrams"))
		# Delayed input : X is (function building a simplextree, dataset)
		to_st = X[0]
		dataset = X[1]
		def pickleable_todo(x):
			return todo(to_st(x))
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(pickleable_todo)(x) for x in tqdm(dataset, disable=not self.progress, desc="Computing simplextrees and diagrams"))

# Shuffles a diagram shaped array. Input : list of (list of diagrams), output, list of (list of shuffled diagrams)
class DiagramShuffle(BaseEstimator, TransformerMixin):
	"""
	Shuffles the coordinates of each diagram, with the global numpy random state.
	Input : list of (list of diagrams)
	Output : list of (list of shuffled diagrams), the input diagrams are not modified.
	"""
	def __init__(self, ) -> None:
		super().__init__()

	def fit(self, X:list[list[np.ndarray]], y=None):
		"""Nothing to fit."""
		return self

	def transform(self,X:list[list[np.ndarray]]):
		"""Returns, for each diagram, a new array of the same shape with all of its entries permuted."""
		shuffled_dataset = []
		for dgms in X:
			shuffled_dgms = []
			for dgm in dgms:
				entries = dgm.flatten() # copy : births and deaths are mixed together
				np.random.shuffle(entries)
				shuffled_dgms.append(entries.reshape(dgm.shape))
			shuffled_dataset.append(shuffled_dgms)
		return shuffled_dataset


class Dgms2Landscapes(BaseEstimator, TransformerMixin):
	"""
	Input : list of (list of diagrams), one diagram per degree.
	Output : the landscapes of each degree, concatenated along axis 1. An empty input gives [].
	num and resolution are given to gudhi's Landscape; n_jobs is stored but not used by this class.
	"""
	def __init__(self, num:int=5, resolution:int=100,  n_jobs:int=1) -> None:
		super().__init__()
		self.degrees:list[int] = []
		self.num:int= num
		self.resolution:int = resolution
		self.landscapes:list[Landscape]= []
		self.n_jobs=n_jobs

	def fit(self, X, y=None):
		"""Fits one Landscape per degree; the number of degrees is read on X[0]."""
		if len(X) == 0:
			return self
		self.degrees = list(range(len(X[0])))
		self.landscapes = []
		for degree in self.degrees:
			diagrams = [dgms[degree] for dgms in X]
			landscape = Landscape(num_landscapes=self.num,resolution=self.resolution)
			self.landscapes.append(landscape.fit(diagrams))
		return self

	def transform(self,X):
		"""Vectorizes each degree with its fitted Landscape, and concatenates the results."""
		if len(X) == 0:
			return []
		per_degree = []
		for degree, landscape in enumerate(self.landscapes):
			per_degree.append(landscape.transform([dgms[degree] for dgms in X]))
		return np.concatenate(per_degree, axis=1)

class Dgms2Image(BaseEstimator, TransformerMixin):
	"""
	Input : list of (list of diagrams), one diagram per degree.
	Output : the persistence images of each degree, concatenated along axis 1. An empty input gives [].
	bandwidth and resolution are given to gudhi's PersistenceImage; n_jobs is stored but not used by this class.
	"""
	def __init__(self, bandwidth:float=1, resolution:tuple[int,int]=(20,20),  n_jobs:int=1) -> None:
		super().__init__()
		self.degrees:list[int] = []
		self.bandwidth:float= bandwidth
		self.resolution = resolution
		self.PI:list[PersistenceImage]= []
		self.n_jobs=n_jobs

	def fit(self, X, y=None):
		"""Fits one PersistenceImage per degree; the number of degrees is read on X[0]."""
		if len(X) == 0:
			return self
		self.degrees = list(range(len(X[0])))
		self.PI = []
		for degree in self.degrees:
			diagrams = [dgms[degree] for dgms in X]
			pers_image = PersistenceImage(bandwidth=self.bandwidth,resolution=self.resolution)
			self.PI.append(pers_image.fit(diagrams))
		return self

	def transform(self,X):
		"""Vectorizes each degree with its fitted PersistenceImage, and concatenates the results."""
		if len(X) == 0:
			return []
		per_degree = []
		for degree, pers_image in enumerate(self.PI):
			per_degree.append(pers_image.transform([dgms[degree] for dgms in X]))
		return np.concatenate(per_degree, axis=1)

############################################################################################### ACCURACIES HELPERS
def kfold_acc(cls,x,y, k:int=10, clsn=None):
	"""
	Stratified, shuffled k-fold accuracy of several classifiers, evaluated on the same folds.

	Parameters
	 - cls : list of classifiers, with fit(x,y) and score(x,y).
	 - x, y : indexable data and labels.
	 - k : number of folds.
	 - clsn : names of the classifiers, default : their position in cls.

	Returns one string per classifier, with the mean and the standard deviation of its accuracy, in percent.
	"""
	from sklearn.model_selection import StratifiedKFold as sKFold
	if clsn is None:
		clsn = range(len(cls))
	accuracies = np.zeros((len(cls), k))
	folds = sKFold(k, shuffle=True).split(x,y)
	for fold,(train_idx, test_idx) in enumerate(tqdm(folds, total=k, desc="Computing kfold")):
		# The split only depends on the fold : it is built once, for all the classifiers
		xtrain = [x[i] for i in train_idx]
		ytrain = [y[i] for i in train_idx]
		xtest = [x[i] for i in test_idx]
		ytest = [y[i] for i in test_idx]
		for j, cl in enumerate(cls):
			cl.fit(xtrain, ytrain)
			accuracies[j][fold] = cl.score(xtest, ytest)
	return [f"Classifier {cl_name} : {np.mean(acc*100).round(decimals=3)}% ±{np.std(acc*100).round(decimals=3)}" for cl_name,acc in zip(clsn, accuracies)]
	





def accuracy_to_csv(X,Y,cl, cln:str, k:float=10, dataset:str = "", filtration:str = "", shuffle=True,  verbose:bool=True, **kwargs):
	"""
	Computes the accuracy of the classifier cl on (X,Y), and appends it as a new line of a csv file.

	Parameters
	 - X, Y : indexable data and labels.
	 - cl : classifier, with fit and score. If it has a best_params_ attribute after fit (e.g., GridSearchCV), it is printed.
	 - cln : name of the classifier, written in the "pipeline" column.
	 - k : if > 1, number of folds of a stratified k-fold; otherwise, test size of a single train / test split.
	 - dataset, filtration : written in the csv; dataset also gives the file name.
	 - shuffle : given to the k-fold or to train_test_split.
	 - verbose : prints the accuracies.
	 - kwargs : more columns to write. The keys which are already one of the default columns are ignored, with a warning.

	The file is f"result_{dataset}.csv", in the current directory, with "/" replaced by "_" and ".off" removed.
	It is created if it doesn't exist; its columns are dataset, filtration, pipeline, cv, mean, std, and the ones of kwargs.
	"""
	import pandas as pd
	assert k > 0, "k is either the number of kfold > 1 or the test size > 0."

	def print_best_params():
		# Only the classifiers with a parameter search have best_params_. The other ones are skipped,
		# without hiding the unrelated exceptions as a bare except would.
		if hasattr(cl, "best_params_"):
			print("Best classification parameters : ", cl.best_params_)

	if k>1:
		k = int(k)
		from sklearn.model_selection import StratifiedKFold as KFold
		kfold = KFold(k, shuffle=shuffle).split(X,Y)
		accuracies = np.zeros(k)
		for fold,(train_idx, test_idx) in enumerate(tqdm(kfold, total=k, desc="Computing kfold")):
			xtrain = [X[i] for i in train_idx]
			ytrain = [Y[i] for i in train_idx]
			cl.fit(xtrain, ytrain)
			xtest = [X[i] for i in test_idx]
			ytest = [Y[i] for i in test_idx] 
			accuracies[fold] = cl.score(xtest, ytest)
			if verbose:
				print(f"step {fold+1}, {dataset} : {accuracies[fold]}", flush=True)
				print_best_params()
	else:
		from sklearn.model_selection import train_test_split
		print("Computing accuracy, with train test split", flush=True)
		xtrain, xtest, ytrain, ytest = train_test_split(X, Y, shuffle=shuffle, test_size=k)
		print("Fitting...", end="", flush=True)
		cl.fit(xtrain, ytrain)
		print("Computing score...", end="", flush=True)
		accuracies = cl.score(xtest, ytest) # single value : the std written below is 0
		print_best_params()
		print("Done.")
		if verbose:	print(f"Accuracy {dataset} : {accuracies} ")
	file_path:str = f"result_{dataset}.csv".replace("/", "_").replace(".off", "")
	columns:list[str] = ["dataset", "filtration", "pipeline", "cv", "mean", "std"]
	df:pd.DataFrame = pd.read_csv(file_path) if exists(file_path) else pd.DataFrame(columns= columns)
	more_names = []
	more_values = []
	for key, value in kwargs.items():
		if key not in columns:
			more_names.append(key)
			more_values.append(value)
		else:
			warn(f"Duplicate key {key} ! with values {cln} and {value}")
	new_line:pd.DataFrame = pd.DataFrame([[dataset, filtration, cln, k, np.mean(accuracies), np.std(accuracies)]+more_values], columns = columns+more_names)
	df = pd.concat([df, new_line])
	df.to_csv(file_path, index=False)





#################################################### TESTS

def test_dijkstra_graph_tool():
	"""Checks that graph_tool and networkx give the same unweighted shortest path lengths, from a random node of an Airplane shape."""
	data, _ = get_3dshape("Airplane", num_sample=1)[0]
	source = choice(range(len(data.pos)))
	edge_list = FaceToEdge()(data).edge_index.numpy().T
	nx_graph = nx.from_edgelist(edge_list)
	gt_graph = gt.Graph(directed=False)
	gt_graph.add_edge_list(edge_list=edge_list)
	gt_distances = shortest_distance(g=gt_graph, source=source, target=np.arange(gt_graph.num_vertices()))
	# networkx returns a dict {node : distance}, it is put in an array to be compared
	nx_lengths = nx.shortest_path_length(nx_graph, source)
	nx_distances = np.empty(len(gt_distances))
	for vertex in nx_lengths:
		nx_distances[vertex] = nx_lengths[vertex]
	assert (gt_distances == nx_distances).all()


def get_torch_test_data():
	"""
	Returns (data, edges, nodes, input) to play with : an Airplane shape, its edges, 10 distinct random nodes, and the list of [data, node].
	NOTE(review): the tests of this file unpack get_3dshape(dataset, num_sample=1)[0], not get_3dshape(dataset); confirm that this call is still valid.
	"""
	data, _ = get_3dshape("Airplane")
	edges = FaceToEdge()(data).edge_index.numpy()
	num_points = len(data.pos)
	nodes = np.random.choice(num_points, replace=False, size=10)
	data_node_pairs = []
	for node in nodes:
		data_node_pairs.append([data, node])
	print("data, edges, nodes, input") # reminder of the order of the returned values
	return data, edges, nodes, data_node_pairs
def test_delaying_dijkstra():
	"""Checks that the delayed output of TorchData2DijkstraSimplexTree (dtype=None) gives the same simplextrees as the direct one."""
	data, _ = get_3dshape("Airplane", num_sample=1)[0]
	nodes = np.random.choice(len(data.pos), replace=False, size=10)
	data_node_pairs = [[data, node] for node in nodes]
	direct_simplextrees = TorchData2DijkstraSimplexTree().fit_transform(data_node_pairs)
	# With dtype=None, the output is (function, dataset) : the simplextrees are the function applied to each element
	to_simplextree, dataset = TorchData2DijkstraSimplexTree(dtype=None).fit_transform(data_node_pairs)
	delayed_simplextrees = list(map(to_simplextree, dataset))
	assert direct_simplextrees == delayed_simplextrees


def test_dijkstra():
	"""Checks that the graph_tool and networkx backends of get_dijkstra agree, from a random node of an Airplane shape."""
	data, _ = get_3dshape("Airplane", num_sample=1)[0]
	node = choice(range(len(data.pos)))
	graph_tool_distances = get_dijkstra(data=data, node=node, backend="graph_tool")
	# The igraph backend is not compared.
	networkx_distances = get_dijkstra(data=data, node=node, backend="networkx")
	assert (graph_tool_distances == networkx_distances).all()

def test_torch2st_delayed():
	"""Checks that TorchData2DijkstraSimplexTree gives the same simplextrees when the data is only given with the first node of each shape (None afterwards)."""
	datas:list[tuple[Data, np.ndarray]] =  get_3dshape("Airplane", num_sample=5)
	nodess = [np.random.choice(range(len(data.pos)), replace=False, size=5) for data,_ in datas]
	explicit_input, implicit_input = [], []
	for (data, _), nodes in zip(datas, nodess):
		for position, node in enumerate(nodes):
			assert len(data.pos) > node
			explicit_input.append([data, node])
			# The data is only given once, with the first node
			implicit_input.append([data if position == 0 else None, node])
	explicit_simplextrees = TorchData2DijkstraSimplexTree().fit_transform(explicit_input)
	implicit_simplextrees = TorchData2DijkstraSimplexTree().fit_transform(implicit_input)
	assert explicit_simplextrees == implicit_simplextrees

def test_smk():
	"""Smoke test : runs simplextrees -> extended diagrams -> SignedMeasureDistance on the first 10 nodes of an Airplane shape. Nothing is asserted."""
	data, _ = get_3dshape("Airplane", num_sample=1)[0]
	data_node_pairs = [[data, node] for node in range(10)]
	simplextrees = TorchData2DijkstraSimplexTree().fit_transform(data_node_pairs)
	diagrams = SimplexTree2Dgm(extended=True).fit_transform(simplextrees)
	SignedMeasureDistance().fit_transform(diagrams)



if __name__ == "__main__":
	# TESTS HERE #TODO
	tests = (
		test_dijkstra_graph_tool,
		test_dijkstra,
		test_delaying_dijkstra, ## DIJKSTRA
		test_smk, ## Signed measure kernel
		test_torch2st_delayed,
	)
	# Run in this order, the first failure stops the script
	for test in tests:
		test()
	

	 
