
import numpy as np
import gudhi as gd
import multipers as mp
from numpy import ndarray
from tqdm import tqdm
from itertools import product
from sklearn.neighbors import KernelDensity
from sklearn.base import BaseEstimator, TransformerMixin, clone
from warnings import warn
from .signed_betti import *
from .invariants_with_persistable import *
from .sliced_wasserstein import *
from types import FunctionType
from typing import Callable
from joblib import Parallel, delayed, cpu_count
from os.path import exists
from typing import Iterable
from torch import Tensor
import pandas as pd
from warnings import warn
import matplotlib.pyplot as plt
import MDAnalysis as mda
from scipy.spatial import distance_matrix
from scipy.ndimage import gaussian_filter
from MDAnalysis.topology.guessers import guess_masses
from sklearn.preprocessing import LabelEncoder
import networkx as nx
from numba import njit, cfunc
import GraphRicciCurvature
from .convolutions.convolutions import convolution_signed_measures

def get_simplex_tree_from_delayed(x)->mp.SimplexTreeMulti:
	"""
	Evaluates a delayed call and returns its result.

	Parameters
	----------
	 - x : a `(function, args, kwargs)` triple; `function(*args, **kwargs)` is called.
	"""
	func, positional, keywords = x
	result = func(*positional, **keywords)
	return result

def get_simplextree(x)->mp.SimplexTreeMulti:
	"""
	Returns a simplextree from `x`.

	Parameters
	----------
	 - x : either a `mp.SimplexTreeMulti` (returned unchanged), or a delayed call,
	   i.e., a `(callable, args, kwargs)` triple, which is evaluated.

	Raises
	------
	TypeError : if `x` is neither of the above.
	"""
	if isinstance(x, mp.SimplexTreeMulti):
		return x
	try:
		# `callable` rather than `isinstance(..., FunctionType)` : the transformers of this
		# file delay bound methods (e.g. `delayed(self._get_st)`), which are not FunctionType.
		is_delayed_call = len(x) == 3 and callable(x[0])
	except TypeError: # x has no length or is not indexable
		is_delayed_call = False
	if is_delayed_call:
		return get_simplex_tree_from_delayed(x)
	raise TypeError("Not a valid SimplexTree !")


def infer_grid_from_points(pts:np.ndarray, num:int, strategy:str):
	"""
	Infers, for each column of `pts`, a 1d grid of (at most) `num` values.

	Parameters
	----------
	 - pts : 2d array, one point per row, one coordinate per column
	 - num : number of values per coordinate
	 - strategy :
	   - "regular" : evenly spaced values between the min and the max of each column
	   - "quantile" : `num` evenly spaced quantiles of each column
	   - "exact" : the unique values of each column, replaced by a regular grid
	     when there are more than `num` of them

	Returns
	-------
	An array with one row per column of `pts` ("regular", "quantile"),
	or a list of 1d arrays ("exact").

	Raises
	------
	Exception : if the strategy is not one of the above.
	"""
	if strategy == "exact":
		grid = []
		for column in range(pts.shape[1]):
			values = np.unique(pts[:,column])
			if len(values) > num: # fallback to regular if too large
				values = np.linspace(values.min(), values.max(), num=num)
			grid.append(values)
		return grid
	if strategy == "quantile":
		levels = np.linspace(0,1,num)
		return np.quantile(pts, q=levels, axis=0).T
	if strategy == "regular":
		lower = np.min(pts, axis=0)
		upper = np.max(pts, axis=0)
		return np.linspace(lower, upper, num=num).T

	raise Exception(f"Grid strategy {strategy} not implemented")

def get_filtration_weights_grid(num_parameters:int=2, resolution:int|Iterable[int]=3,*, min:float=0, max:float=20, dtype=float, remove_homothetie:bool=True, weights=None):
	"""
	Provides a grid of weights, for filtration rescaling.
	 - num parameter : the dimension of the grid tensor
	 - resolution :  the size of each coordinate
	 - min : minimum weight
	 - max : maximum weight
	 - weights : custom weights (instead of linspace between min and max)
	 - dtype : the type of the grid values (useful for int weights)
	"""
	from itertools import product
	# if isinstance(resolution, int):
	try:
		float(resolution)
		resolution = [resolution]*num_parameters
	except:
		None
	if weights is None:	weights = [np.linspace(start=min,stop=max,num=r, dtype=dtype) for r in resolution]
	try:
		float(weights[0]) # same weights for each filtrations
		weights = [weights] * num_parameters
	except:
		None
	out = np.asarray(list(product(*weights)))
	if remove_homothetie:
		_, indices = np.unique([x / x.max() for x in out if x.max() != 0],axis=0, return_index=True)
		out = out[indices]
	return list(out)



################################################# Data2SimplexTree
class RipsDensity2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Transforms point clouds into multiparameter simplextrees, built from the
	1-skeleton of a (gudhi) Rips complex, and whose parameter 1 is the lower-star
	filtration of the codensity, i.e., minus the log-density given by a kernel density estimation.

	Input
	-----
	Iterable of point clouds

	Output
	------
	list of mp.SimplexTreeMulti, or list of delayed calls `(function, args, kwargs)` if `delayed` is True.

	Parameters
	----------
	 - bandwidth : bandwidth of the kernel density estimation. If not positive,
	   `-bandwidth` (in [0,1]) is a fraction of the estimated diameter of the dataset.
	 - threshold : maximal edge length of the Rips complex. If not positive,
	   `-threshold` (in [0,1]) is a fraction of the estimated diameter of the dataset.
	 - sparse : forwarded to gudhi's RipsComplex
	 - num_collapse : forwarded to `collapse_edges`
	 - num_parameters : number of parameters of the output simplextrees
	 - kernel, rtol, atol : forwarded to sklearn's KernelDensity
	 - delayed : if True, `transform` does not compute the simplextrees, and returns delayed calls instead
	 - rescale_density : currently unused
	 - progress : prints progress
	 - n_jobs : number of threads used by `transform`
	 - fit_fraction : fraction of the point clouds used to estimate the diameter in `fit`
	"""
	def __init__(self, bandwidth:float=-0.1, threshold:float=np.inf, 
			sparse:float|None=None, num_collapse:int=0, 
			num_parameters:int=2, kernel:str="gaussian", delayed=False, rescale_density:float=0,
			progress:bool=False, n_jobs:int=-1, rtol:float=1e-4, atol=1e-6, fit_fraction:float=1,
		) -> None:
		super().__init__()
		self.bandwidth=bandwidth
		self.threshold = threshold
		self.sparse=sparse
		self.num_collapse=num_collapse
		self.num_parameters = num_parameters
		self.kernel = kernel
		self.delayed=delayed
		self.rescale_density = rescale_density
		self.progress=progress
		self._bandwidth=None
		self._threshold=None
		self.n_jobs = n_jobs
		self.rtol=rtol
		self.atol=atol
		self._scale=None
		self.fit_fraction=fit_fraction # was hard-coded to 1, which silently ignored the argument
		return
	def _get_distance_quantiles(self, X, qs):
		"""
		Stores in `self._scale`, and returns, the scales `q * diameter` for q in `qs`,
		where the diameter is the largest pairwise distance found in a random sample
		of the point clouds of X (capped by `self.threshold` when it is positive).
		"""
		if len(qs) == 0: 
			self._scale = []
			return []
		if self.progress: print("Estimating scale...", flush=True, end="")
		num_samples = min(len(X), int(self.fit_fraction*len(X))+1)
		indices = np.random.choice(len(X), num_samples, replace=False)
		# Only the maximum is needed; taking it per point cloud also supports
		# point clouds of different sizes (no ragged array of distances).
		diameter = max(distance_matrix(X[i],X[i]).max() for i in indices)
		if self.threshold > 0:	diameter = min(diameter, self.threshold)
		self._scale = diameter * np.asarray(qs) 
		if self.progress: print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
		return self._scale
	def _get_st(self,x, bandwidth=None)->mp.SimplexTreeMulti:
		"""
		Computes the simplextree of the point cloud `x`, using the fitted threshold
		and the given bandwidth (the fitted one if None).
		"""
		bandwidth = self._bandwidth if bandwidth is None else bandwidth
		kde=KernelDensity(bandwidth=bandwidth, kernel=self.kernel, rtol=self.rtol, atol=self.atol)
		st = gd.RipsComplex(points = x, max_edge_length=self._threshold, sparse=self.sparse).create_simplex_tree(max_dimension=1)
		st = mp.SimplexTreeMulti(st, num_parameters = self.num_parameters)
		kde.fit(x)
		codensity = -kde.score_samples(x)
		# TODO : rescale the codensity using self.rescale_density (not safe for the moment)
		st.fill_lowerstar(codensity, parameter = 1)
		# NOTE(review): edges are collapsed twice, the second time with strong=False; confirm this is intended.
		st.collapse_edges(num=self.num_collapse)
		st.collapse_edges(num=self.num_collapse, strong = False, max_dimension = 1) 
		return st
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Resolves the bandwidth and the threshold : positive values are kept as is,
		the others are turned into a fraction of the estimated diameter of X.
		"""
		## default value 0.1 * diameter # TODO rescale density
		qs = [q for q in [-self.bandwidth, -self.threshold] if 0 <= q <= 1]
		self._get_distance_quantiles(X, qs=qs)
		self._bandwidth = self.bandwidth if self.bandwidth > 0 else self._scale[0]
		self._threshold = self.threshold if self.threshold > 0 else self._scale[-1]
		return self

	
	def transform(self,X):
		"""
		Computes one simplextree per point cloud of X, or one delayed call per point cloud if `self.delayed`.
		"""
		with tqdm(X, desc="Computing simplextrees", disable = not self.progress or self.delayed) as data:
			if self.delayed:
				return [delayed(self._get_st)(x) for x in data] # delay the computation for the to_module pipe, as simplextrees are not pickle-able.
			return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self._get_st)(x) for x in data) # not picklable so prefer threads is necessary.

		
		
		
class RipsDensity2SimplexTrees(RipsDensity2SimplexTree):
	"""
	Same as the pipeline without the 's', but computes multiple bandwidths at once. 
	output shape :
	(data) x (bandwidth) x (simplextree)

	Parameters
	----------
	 - bandwidths : a bandwidth, or an iterable of bandwidths. Positive values are used as is,
	   values in [-1, 0] are turned into a fraction of the estimated diameter of the dataset.
	 - rips_density_arguments : forwarded to RipsDensity2SimplexTree
	"""
	def __init__(self, bandwidths:Iterable[float]=-0.1, **rips_density_arguments) -> None:
		super().__init__(**rips_density_arguments)
		self.bandwidths=bandwidths
		self._bandwidths=None
		return
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Resolves the bandwidths and the threshold from the estimated diameter of X.

		Raises
		------
		ValueError : if a bandwidth is smaller than -1.
		"""
		## default value 0.1 * diameter # TODO rescale density
		# atleast_1d : the default value is a scalar, which cannot be iterated over.
		bandwidths = np.atleast_1d(np.asarray(self.bandwidths, dtype=float))
		if np.any(bandwidths < -1):
			raise ValueError(f"Invalid bandwidths {self.bandwidths}. Relative bandwidths have to be in [-1, 0].")
		is_relative = bandwidths <= 0
		qs = [*-bandwidths[is_relative]]
		if 0 <= -self.threshold <= 1:	qs.append(-self.threshold)
		self._get_distance_quantiles(X, qs=qs)
		# The first scales are the ones of the relative bandwidths, in order. They are
		# assigned through the mask, so that absolute bandwidths keep their position.
		num_relative = int(is_relative.sum())
		self._bandwidths = bandwidths.copy()
		self._bandwidths[is_relative] = np.asarray(self._scale, dtype=float)[:num_relative]
		self._threshold = self.threshold if self.threshold > 0 else self._scale[-1]
		return self

	def _get_sts(self, x, bandwidths=None):
		"""
		Computes the simplextrees of the point cloud `x`, one per bandwidth (the fitted ones if None).
		"""
		bandwidths = self._bandwidths if bandwidths is None else bandwidths
		return [self._get_st(x, bandwidth=bandwidth) for bandwidth in bandwidths]
	def transform(self,X):
		"""
		Computes, for each point cloud of X, one simplextree per bandwidth
		(or one delayed call per bandwidth if `self.delayed`).
		"""
		# same condition as in RipsDensity2SimplexTree : no progress bar when delayed, nothing is computed.
		with tqdm(X, desc="Computing simplextrees", disable= not self.progress or self.delayed) as data:
			if self.delayed:
				return [[delayed(self._get_st)(x, bandwidth=bandwidth) for bandwidth in self._bandwidths] for x in data] # delay the computation for the to_module pipe, as simplextrees are not pickle-able.
			return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self._get_sts)(x) for x in data) # not picklable so prefer threads is necessary.


		
		
class SimplexTreeEdgeCollapser(BaseEstimator, TransformerMixin):
	"""
	Collapses the edges of simplextrees : extracts their edge lists, collapses them,
	and rebuilds simplextrees from the collapsed edge lists.

	Parameters
	----------
	 - num_collapses : forwarded as `num` to the edge collapse
	 - full : forwarded as `full` to the edge collapse
	 - max_dimension : forwarded as `expand_dim` to the reconstruction
	 - n_jobs : number of jobs of the edge collapse step
	"""
	def __init__(self, num_collapses:int=0, full:bool=False, max_dimension:int|None=None, n_jobs:int=1) -> None:
		super().__init__()
		self.full=full
		self.num_collapses=num_collapses
		self.max_dimension=max_dimension
		self.n_jobs=n_jobs
		return
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Nothing to fit.
		"""
		return self
	def transform(self,X):
		"""
		Returns the list of collapsed simplextrees.
		"""
		edges_list = Parallel(n_jobs=-1, prefer="threads")(delayed(mp.SimplexTreeMulti.get_edge_list)(x) for x in X)
		collapsed_edge_lists = Parallel(n_jobs=self.n_jobs)(delayed(mp._collapse_edge_list)(edges, full=self.full, num=self.num_collapses) for edges in edges_list)
		# NOTE(review): `_reconstruct_from_edge_list` is called unbound, with the edge list as
		# first argument, as in the previous version of this line; confirm against the multipers API.
		collapsed_simplextrees = Parallel(n_jobs=-1, prefer="threads")(
			delayed(mp.SimplexTreeMulti._reconstruct_from_edge_list)(collapsed_edges, swap=True, expand_dim = self.max_dimension)
			for collapsed_edges in collapsed_edge_lists
		)
		return collapsed_simplextrees

#### MOLECULE DATA
# def _lines2bonds(_lines):
# 	out = []
# 	index = 0
# 	while index < len(_lines) and  _lines[index].strip() != "@<TRIPOS>BOND":
# 		index += 1
# 	index += 1
# 	while index < len(_lines) and  _lines[index].strip()[0] != "@":
# 		line = _lines[index].strip().split(" ")
# 		for j,truc in enumerate(line):
# 			line[j] = truc.strip()
# 		# try:
# 		out.append([int(stuff) for stuff in line if len(stuff) > 0])
# 		# except:
# 		# 	print_lin
# 		index +=1
# 	out = pd.DataFrame(out, columns=["bond_id","atom1", "atom2", "bond_type"])
# 	out.set_index(["bond_id"],inplace=True)
# 	return out
# def _get_mol2_file(path:str, num_cols:int=9, columns:dict|None=None):
# 	from biopandas.mol2 import split_multimol2,PandasMol2
# 	columns={
# 		0:('atom_id', int), 
# 		1:('atom_name', str),
# 		2:('x', float), 
# 		3:('y', float), 
# 		4:('z', float), 
# 		5:('atom_type', str), 
# 		6:('subst_id', int), 
# 		7:('subst_name', str), 
# 		8:('charge', float)
# 	} if columns is None else columns
# 	while len(columns) > num_cols:
# 		columns.pop(len(columns)-1)
# 	# try:
# 	molecules_dfs = []
# 	bonds_dfs = []
# 	for molecule in split_multimol2(path):
# 		_code, _lines = molecule
# 		try:
# 			bonds_dfs.append(_lines2bonds(_lines))
# 			molecule_df = PandasMol2().read_mol2_from_list(mol2_lines=_lines, mol2_code=_code, columns=columns).df
# 		except:
# 			print(_code)
# 			print(_lines)
# 		molecule_df.set_index(["atom_id"], inplace=True)
# 		molecules_dfs.append(molecule_df)        
# 	# except:
# 	#     return get_mol2_file(path=path, num_cols=num_cols-1)
# 	return molecules_dfs, bonds_dfs
# def _atom_to_mass(atom)->int:
# 	return ELEMENTS[atom].mass
# 	raise Exception(f" Atom {atom} has no registered mass.")
def lines2bonds(path:str):
	"""
	Reads the bonds of a mol2 file.

	The bonds are the lines following the "@<TRIPOS>BOND" line, up to the next
	line starting with "@" (or the end of the file). Blank lines are skipped.

	Parameters
	----------
	 - path : path to the mol2 file

	Returns
	-------
	A DataFrame indexed by "bond_id", with columns "atom1", "atom2", "bond_type".
	The values are kept as strings. Empty if the file has no bond section.
	"""
	with open(path, "r") as file: # closes the file, even on error
		lines = [line.strip() for line in file]
	try:
		start = lines.index("@<TRIPOS>BOND") + 1
	except ValueError: # no bond section
		start = len(lines)
	out = []
	for line in lines[start:]:
		if not line:
			continue
		if line.startswith("@"): # next section
			break
		# Only the first 4 fields are bond data; anything after them (e.g., status bits) is dropped.
		out.append(line.split()[:4])
	out = pd.DataFrame(out, columns=["bond_id","atom1", "atom2", "bond_type"])
	out.set_index(["bond_id"],inplace=True)
	return out
def _mol2st(path:str|mda.Universe, bond_length:bool = False, charge:bool=False, atomic_mass:bool=False, bond_type=False, **kwargs):
	"""
	Builds a multiparameter simplextree from a molecule : the atoms are the vertices, the bonds are the edges.

	One filtration parameter is created per enabled flag, in this order :
	bond_length, bond_type, charge, atomic_mass.

	Parameters
	----------
	 - path : path to a mol2 file, or an already loaded MDAnalysis Universe
	 - bond_length : edges are filtered by their bond length
	 - bond_type : edges are filtered by their (label encoded) bond type, read from the mol2 file
	 - charge : lower-star filtration of the atom charges
	 - atomic_mass : lower-star filtration of minus the atom masses (guessed from the atom types when null)
	 - kwargs : ignored

	Raises
	------
	TypeError : if bond_type is asked, but `path` is a Universe instead of a path.
	"""
	is_universe = isinstance(path, mda.Universe)
	if bond_type and is_universe:
		# The bond types are read from the file itself; this error used to be built but never raised.
		raise TypeError("Expected path as input to compute bounds type. MDA doesn't handle it.")
	molecule = path if is_universe else mda.Universe(path, format="MOL2") 
	num_filtrations = bond_length + charge + atomic_mass + bond_type
	nodes = molecule.atoms.indices.reshape(1,-1)
	edges = molecule.bonds.dump_contents().T
	num_vertices = nodes.shape[1]
	num_edges =edges.shape[1]
	
	st = mp.SimplexTreeMulti(num_parameters = num_filtrations)
	
	## Edges filtration
	edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32) - np.inf
	if bond_length:
		bond_lengths = molecule.bonds.bonds()
		edges_filtration[:,0] = bond_lengths
	if bond_type:
		bond_types = LabelEncoder().fit([0,1,2,3,"am","ar"]).transform(lines2bonds(path=path)["bond_type"])
		edges_filtration[:,int(bond_length)] = bond_types

	## Nodes filtration
	nodes_filtrations = np.zeros((num_vertices,num_filtrations), dtype=np.float32) + np.min(edges_filtration, axis=0) # better than - np.inf
	st.insert_batch(nodes, nodes_filtrations)

	st.insert_batch(edges, edges_filtration)
	if charge:
		charges = molecule.atoms.charges
		st.fill_lowerstar(charges, parameter=int(bond_length + bond_type))
	if atomic_mass:
		masses = molecule.atoms.masses
		null_indices = masses == 0
		if np.any(null_indices): # guess if necessary
			masses[null_indices] = guess_masses(molecule.atoms.types)[null_indices]
		st.fill_lowerstar(-masses, parameter=int(bond_length+bond_type+charge))
	st.make_filtration_non_decreasing()
	return st

class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Transforms a list of mol2 files into a list of multiparameter simplextrees
	Input:
	
	 X: Iterable[path_to_files:str]
	Output:

	 Iterable[multipers.SimplexTreeMulti], or the delayed calls computing them if `delayed` is True.

	Parameters
	----------
	 - bond_length_filtration, bond_type_filtration, charge_filtration, atomic_mass_filtration :
	   the filtrations to compute; one parameter per enabled filtration.
	 - delayed : if True, `transform` returns delayed calls `(function, args, kwargs)` instead of simplextrees
	 - progress : shows a progress bar in `transform` (when not delayed)
	 - n_jobs : number of threads used by `transform`
	 - atom_columns, atom_num_columns, max_dimension : stored, but not used by this class
	"""
	def __init__(self, atom_columns:Iterable[str]|None=None, 
			atom_num_columns:int=9, max_dimension:int|None=None, delayed:bool=False, 
			progress:bool=False, 
			bond_length_filtration:bool=False,
			bond_type_filtration:bool=False,
			charge_filtration:bool=False, 
			atomic_mass_filtration:bool=False, 
			n_jobs:int=1) -> None:
		super().__init__()
		self.max_dimension=max_dimension
		self.delayed=delayed
		self.progress=progress
		self.bond_length_filtration = bond_length_filtration
		self.charge_filtration = charge_filtration
		self.atomic_mass_filtration = atomic_mass_filtration
		self.bond_type_filtration = bond_type_filtration
		self.n_jobs = n_jobs
		self.atom_columns = atom_columns
		self.atom_num_columns = atom_num_columns
		self.num_parameters = self.charge_filtration + self.bond_length_filtration + self.bond_type_filtration + self.atomic_mass_filtration
		return
	def fit(self, X:Iterable[str], y=None):
		"""
		Nothing to fit. X is not read, so any iterable is accepted.
		"""
		return self
	def transform(self,X:Iterable[str]):
		"""
		Computes one simplextree (or one delayed call) per mol2 file path of X.
		"""
		def to_simplex_tree(path_to_mol2_file:str):
			return _mol2st(path=path_to_mol2_file, 
				bond_type=self.bond_type_filtration,
				bond_length=self.bond_length_filtration,
				charge=self.charge_filtration,
				atomic_mass=self.atomic_mass_filtration,
			)
		if self.delayed:
			return [delayed(to_simplex_tree)(path) for path in X]
		paths = tqdm(X, desc="Computing simplextrees from molecules", disable = not self.progress)
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(to_simplex_tree)(path) for path in paths)



class Graph2SimplexTree(BaseEstimator,TransformerMixin):
	"""
	Transforms a list of networkx graphs into a list of simplextree multi.

	The filtration values are read from the graphs : each node and each edge
	has to hold one attribute per name in `filtrations`.
	
	Usual Filtrations
	-----------------
	- "cc" closeness centrality
	- "geodesic" if the graph provides data to compute it, e.g., BZR, COX2, PROTEINS
	- "degree" 
	- "ricciCurvature" the ricci curvature
	- "fiedler" the square of the fiedler vector

	Parameters
	----------
	 - filtrations : names of the node / edge attributes to use as filtrations
	 - delayed : if True, `transform` returns delayed calls `(function, args, kwargs)` instead of simplextrees
	 - num_collapses : number of edge collapses (only done for 2-parameter simplextrees)
	"""
	def __init__(self, filtrations:Iterable[str]=["ricciCurvature", "cc", "degree"], delayed=False, num_collapses=100):
		super().__init__()
		self.filtrations=filtrations
		self.delayed = delayed
		self.num_collapses=num_collapses
	def fit(self, X, y=None):
		"""
		Nothing to fit.
		"""
		return self
	def transform(self,X:list[nx.Graph]):
		"""
		Computes one simplextree (or one delayed call) per graph of X.
		"""
		def todo(graph, filtrations=self.filtrations) -> mp.SimplexTreeMulti: 
			st = mp.SimplexTreeMulti(num_parameters=len(filtrations))
			# vertices : one column per node, one filtration value per (node, filtration)
			vertex_values = [[graph.nodes[node][name] for name in filtrations] for node in graph.nodes]
			st.insert_batch(
				np.asarray(graph.nodes, dtype=int).reshape(1,-1),
				np.asarray(vertex_values, dtype=np.float32),
			)
			# edges : one column per edge, one filtration value per (edge, filtration)
			edge_values = [[graph[u][v][name] for name in filtrations] for u,v in graph.edges]
			st.insert_batch(
				np.asarray(graph.edges, dtype=int).T,
				np.asarray(edge_values, dtype=np.float32),
			)
			if st.num_parameters == 2: # TODO : wait for a filtration domination update
				st.collapse_edges(num=self.num_collapses)
			st.make_filtration_non_decreasing() ## Ricci is not safe ...
			return st
		if self.delayed:
			return [delayed(todo)(graph) for graph in X]
		graphs = tqdm(X, desc="Computing simplextrees from graphs")
		return Parallel(n_jobs=-1, prefer="threads")(delayed(todo)(graph) for graph in graphs)



class SimplexTree2MMA(BaseEstimator, TransformerMixin):
	"""
	Turns a list of simplextrees to MMA approximations

	Parameters
	----------
	 - n_jobs : number of threads used by `transform`
	 - persistence_kwargs : forwarded to `persistence_approximation`
	"""
	def __init__(self,n_jobs=-1, **persistence_kwargs) -> None:
		super().__init__()
		self.persistence_args = persistence_kwargs
		self.n_jobs=n_jobs
		self._is_input_delayed=None
		return		
	def fit(self, X, y=None):
		"""
		Detects, from the first element of X, if the input is made of simplextrees or of delayed calls.
		An empty X leaves the detection untouched.
		"""
		if len(X) > 0: # X[0] does not exist otherwise
			self._is_input_delayed = not isinstance(X[0], mp.SimplexTreeMulti)
		return self
	def transform(self,X)->list[mp.PyModule]:
		"""
		Computes the approximation of each element of X (after evaluating it, if the input is delayed).
		"""
		def todo(x):
			st = get_simplex_tree_from_delayed(x) if self._is_input_delayed else x
			return st.persistence_approximation(**self.persistence_args)
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in X)

class MMA2Landscape(BaseEstimator, TransformerMixin):
	"""
	Turns a list of MMA approximations into Landscapes vectorisations

	Parameters
	----------
	 - resolution : resolution of the landscapes
	 - degrees : homological degrees to consider
	 - ks : the landscapes to compute
	 - phi : aggregates the landscapes; has to accept an `axis=0` keyword
	 - box : if None, `fit` sets it from the `filtration_quantile` and `1 - filtration_quantile`
	   quantiles of the bottoms and tops of the modules
	 - plot : forwarded to `landscapes`
	 - n_jobs : number of threads
	 - filtration_quantile : see box
	"""
	def __init__(self, resolution=[100,100], degrees:list[int]|None = [0,1], ks:Iterable[int]=range(5), phi:Callable = np.sum, box=None, plot:bool=False, n_jobs=-1, filtration_quantile:float=0.01) -> None:
		super().__init__()
		self.resolution:list[int]=resolution
		self.degrees = degrees
		self.ks=ks
		self.phi=phi # Has to have a axis=0 !
		self.box = box
		self.plot = plot
		self.n_jobs=n_jobs
		self.filtration_quantile = filtration_quantile
		return
	def fit(self, X, y=None):
		"""
		Infers the box from the modules of X, if it is not given.
		"""
		if len(X) == 0:	return self # a fit has to return the estimator, even with nothing to fit
		assert X[0].num_parameters == 2, f"Number of parameters {X[0].num_parameters} has to be 2."
		if self.box is None:
			bottoms = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(lambda mod : mod.get_bottom())(mod) for mod in X)
			tops = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(lambda mod : mod.get_top())(mod) for mod in X)
			m = np.quantile(bottoms, q=self.filtration_quantile, axis=0)
			M = np.quantile(tops, q=1-self.filtration_quantile, axis=0)
			# NOTE(review): the box is stored, but not forwarded to `landscapes` in transform; confirm this is intended.
			self.box=[m,M]
		return self
	def transform(self,X)->list[np.ndarray]:
		"""
		Returns, for each module, the concatenation over the degrees of the flattened, aggregated landscapes.
		An empty input gives an empty list.
		"""
		if len(X) == 0:	return []
		def todo(mod):
			return np.concatenate([
				self.phi(mod.landscapes(ks=self.ks, resolution = self.resolution, degree=degree, plot=self.plot), axis=0).flatten()
				for degree in self.degrees
			]).flatten()
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in X)

############################################### Data2Signedmeasure

def tensor_möbius_inversion(tensor:Tensor|np.ndarray, grid_conversion:Iterable[np.ndarray]|None = None, plot:bool=False, raw:bool=False, num_parameters:int|None=None):
	betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
	num_indices, num_pts = betti_sparse.indices().shape
	num_parameters = num_indices if num_parameters is None else num_parameters
	if num_indices == num_parameters: # either hilbert or rank invariant
		rank_invariant = False
	elif 2*num_parameters == num_indices:
		rank_invariant = True
	else:
		raise TypeError(f"Unsupported betti shape. {num_indices} has to be either {num_parameters} or {2*num_parameters}.")
	points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
	weights = np.asarray(betti_sparse.values(), dtype=int)

	if grid_conversion is not None:
		coords = np.empty(shape=(num_pts,num_indices), dtype=float)
		for i in range(num_indices):
			coords[:,i] = grid_conversion[i%num_parameters][points_filtration[:,i]]
	else:
		coords = points_filtration
	if (not rank_invariant) and plot:
		plt.figure()
		plt.scatter(points_filtration[:,0],points_filtration[:,1], c=weights)
		plt.colorbar()
	if (not rank_invariant) or raw: return coords, weights
	def _is_trivial(rectangle:np.array):
		birth=rectangle[:num_parameters]
		death=rectangle[num_parameters:]
		return np.all(birth<=death) # and not np.array_equal(birth,death)
	correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
	if len(correct_indices) == 0:	return np.empty((0, num_indices)), np.empty((0))
	signed_measure = np.asarray(coords[correct_indices])
	weights = weights[correct_indices]
	if plot:
		assert signed_measure.shape[1] == 4 # plot only the rank decompo for the moment
		from matplotlib.pyplot import plot
		def _plot_rectangle(rectangle:np.ndarray, weight:float):
			x_axis=rectangle[[0,2]]
			y_axis=rectangle[[1,3]]
			color = "blue" if weight > 0 else "red"
			plot(x_axis, y_axis, c=color)
		for rectangle, weight in zip(signed_measure, weights):
			_plot_rectangle(rectangle=rectangle, weight=weight)
	return signed_measure, weights


class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
	"""
	Transforms point clouds into signed measures computed from the Hilbert
	functions given by `hf_degree_rips`.

	Input
	-----
	Iterable of point clouds

	Output
	------
	Iterable[ list[signed_measure for degree] ]

	Parameters
	----------
	 - degrees : homological degrees to compute
	 - min_rips_value, max_rips_value, min_normalized_degree, max_normalized_degree, grid_granularity :
	   forwarded to `hf_degree_rips`. If max_rips_value is negative, `-max_rips_value` is a
	   fraction of the diameter of the dataset, estimated in `fit`.
	 - progress : prints progress
	 - n_jobs : number of jobs used by `transform`
	 - sparse : if True, the output is in sparse format (points, weights)
	 - _möbius_inversion : if False, the Hilbert function is returned instead of its möbius inversion
	   (flattened, when the output is not sparse)
	 - fit_fraction : fraction of the point clouds used to estimate the diameter
	"""
	def __init__(self, degrees:Iterable[int], min_rips_value:float, 
	      max_rips_value,max_normalized_degree:float, min_normalized_degree:float, 
		  grid_granularity:int, progress:bool=False, n_jobs=1, sparse:bool=False, 
		  _möbius_inversion=True,
		  fit_fraction=1,
		  ) -> None:
		super().__init__()
		self.min_rips_value = min_rips_value
		self.max_rips_value = max_rips_value
		self.min_normalized_degree = min_normalized_degree
		self.max_normalized_degree = max_normalized_degree
		self._max_rips_value = None
		self.grid_granularity = grid_granularity
		self.progress=progress
		self.n_jobs = n_jobs
		self.degrees = degrees
		self.sparse=sparse
		self._möbius_inversion = _möbius_inversion
		self.fit_fraction=fit_fraction
		return
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Resolves the maximal rips value, from the largest pairwise distance found
		in a random sample of the point clouds of X if it is negative.
		"""
		if self.max_rips_value < 0:
			# prints are guarded by `progress`, as in the other transformers of this file
			if self.progress: print("Estimating scale...", flush=True, end="")
			indices = np.random.choice(len(X),min(len(X), int(self.fit_fraction*len(X))+1) ,replace=False)
			diameter = np.max([distance_matrix(X[i],X[i]).max() for i in indices])
			if self.progress: print(f"Done. {diameter}", flush=True)
			self._max_rips_value = - self.max_rips_value * diameter
		else:
			self._max_rips_value = self.max_rips_value
		return self
	
	def _transform1(self, data:np.ndarray):
		"""
		Computes the signed measures (one per degree) of a single point cloud.
		"""
		_distance_matrix = distance_matrix(data, data)
		signed_measures = []
		rips_values, normalized_degree_values, hilbert_functions, _ = hf_degree_rips(
			_distance_matrix,
			min_rips_value = self.min_rips_value,
			max_rips_value = self._max_rips_value,
			min_normalized_degree = self.min_normalized_degree,
			max_normalized_degree = self.max_normalized_degree,
			grid_granularity = self.grid_granularity,
			max_homological_dimension = np.max(self.degrees),
		)
		for degree in self.degrees:
			hilbert_function = hilbert_functions[degree]
			signed_measure = signed_betti(hilbert_function, threshold=True) if self._möbius_inversion else hilbert_function
			if self.sparse:
				signed_measure = tensor_möbius_inversion(
					tensor=signed_measure,num_parameters=2,
					grid_conversion=[rips_values, normalized_degree_values]
				)
			elif not self._möbius_inversion:
				# only dense outputs can be flattened : a sparse output is a (points, weights) tuple
				signed_measure = signed_measure.flatten()
			signed_measures.append(signed_measure)
		return signed_measures
	def transform(self,X):
		"""
		Computes the signed measures of each point cloud of X.
		"""
		return Parallel(n_jobs=self.n_jobs)(delayed(self._transform1)(data) 
		for data in tqdm(X, desc=f"Computing DegreeRips, of degrees {self.degrees}, signed measures.", disable = not self.progress))






################################################# SimplexTree2...



def _st2ranktensor(st:mp.SimplexTreeMulti, filtration_grid:np.ndarray, degree:int, plot:bool, reconvert_grid:bool, num_collapse:int|str=0):
	"""
	Computes the rectangle decomposition of the rank invariant of a 2-parameter simplextree.

	Parameters
	----------
	 - st : the simplextree; it is copied, and left untouched
	 - filtration_grid : the grid on which to compute, one array of filtration values per parameter
	 - degree : the homological degree
	 - plot : forwarded to `tensor_möbius_inversion`
	 - reconvert_grid : if True, the output is given in filtration values instead of grid coordinates
	 - num_collapse : either "full" or the number of edge collapses to do before the computation

	Returns
	-------
	The output of `tensor_möbius_inversion`, applied to the rank decomposition.

	Raises
	------
	TypeError : if num_collapse is neither "full" nor an int.
	"""
	# Work on a copy : squeezing overwrites the filtration values by grid coordinates
	squeezed = mp.SimplexTreeMulti(st)
	squeezed.grid_squeeze(filtration_grid = filtration_grid, coordinate_values = True)
	if num_collapse == "full":
		collapse_arguments = dict(full=True)
	elif isinstance(num_collapse, int):
		collapse_arguments = dict(num=num_collapse)
	else:
		raise TypeError(f"Invalid num_collapse={num_collapse} type. Either full, or an integer.")
	squeezed.collapse_edges(ignore_warning=True, max_dimension=degree+1, **collapse_arguments)
	# rank invariant, as a dense tensor over the grid
	grid_shape = [len(f) for f in filtration_grid]
	rank_tensor = mp.rank_invariant2d(squeezed, degree=degree, grid_shape=grid_shape)
	# rectangle decomposition of this tensor, then its sparse representation
	rank_decomposition = rank_decomposition_by_rectangles(rank_tensor, threshold=True)
	return tensor_möbius_inversion(
		tensor = rank_decomposition,
		grid_conversion = filtration_grid if reconvert_grid else None,
		plot=plot,
		num_parameters=st.num_parameters,
	)

class SimplexTree2RectangleDecomposition(BaseEstimator,TransformerMixin):
	"""
	Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition. 

	Parameters
	----------
	 - filtration_grid : the grid on which to compute, one array of filtration values per parameter
	 - degrees : the homological degrees to compute
	 - plot, reconvert_grid, num_collapses : forwarded to `_st2ranktensor`

	Output
	------
	list[ list[rectangle decomposition for degree] ], one inner list per simplextree.
	"""
	def __init__(self, filtration_grid:np.ndarray, degrees:Iterable[int], plot=False, reconvert_grid=True, num_collapses:int=0):
		super().__init__()
		self.filtration_grid = filtration_grid
		self.degrees = degrees
		self.plot=plot
		self.reconvert_grid = reconvert_grid
		self.num_collapses=num_collapses
		return
	def fit(self, X, y=None):
		"""
		Nothing to fit.
		TODO : infer grid from multiple simplextrees
		"""
		return self
	def transform(self,X:Iterable[mp.SimplexTreeMulti]):
		"""
		Computes, for each simplextree of X, its rectangle decomposition in each degree.
		"""
		rectangle_decompositions = []
		for simplextree in X:
			decompositions_of_st = []
			for degree in self.degrees:
				decomposition = _st2ranktensor(
					simplextree,
					filtration_grid=self.filtration_grid,
					degree=degree,
					plot=self.plot,
					reconvert_grid=self.reconvert_grid,
					num_collapse=self.num_collapses,
				)
				decompositions_of_st.append(decomposition)
			rectangle_decompositions.append(decompositions_of_st)
		## TODO : return iterator ?
		return rectangle_decompositions


def betti_matrix2signed_measure(betti:coo_array|np.ndarray, grid_conversion:Iterable[np.ndarray]|None = None):
	if isinstance(betti, np.ndarray):   betti = coo_array(betti)
	points_filtration = np.empty(shape=(betti.getnnz(),2), dtype=int) # coo matrix is only for matrices -> 2d
	points_filtration[:,0] = betti.row
	points_filtration[:,1] = betti.col
	weights = np.array(betti.data, dtype=int)
	if grid_conversion is not None:
		coords = np.empty(shape=(betti.getnnz(),2), dtype=float)
		for i in range(2):
			coords[:,i] = grid_conversion[i][points_filtration[:,i]]
	else:
		coords = points_filtration
	return coords, weights


class SimplexTree2SignedMeasure(BaseEstimator,TransformerMixin):
	"""
	Transformer from multiparameter simplextrees to signed measures.

	Input
	-----
	Iterable[SimplexTreeMulti]
	(or Iterable of delayed simplextrees, i.e. (function, args, kwargs) triples,
	see `get_simplex_tree_from_delayed`)

	Output
	------
	Iterable[ list[signed_measure for degree] ]

	signed measure is either (points : (n x num_parameters) array, weights : (n) int array ) if sparse, else an integer matrix.

	Parameters
	----------
	 - degrees : list of degrees to compute
	 - filtration grid : the grid on which to compute. If None, the fit will infer it from
	   - fit_fraction : the fraction of data to consider for the fit, seed is controlled by the seed parameter
	   - resolution : the resolution of this grid
	   - filtration_quantile : filtrations values quantile to ignore
	   - infer_filtration_strategy:str : 'regular' or 'quantile' or 'exact'
	   - normalize filtration : if sparse, will normalize all filtrations.
	 - expand : expands the simplextree to compute correctly the degree, for flag complexes
	 - invariant : the topological invariant to produce the signed measure. Choices are "hilbert" or "euler". Will add rank invariant later.
	 - num_collapse : Either an int or "full". Collapse the complex before doing computation.
	 - _möbius_inversion : if False, will not do the mobius inversion. output has to be a matrix then.
	 - enforce_null_mass : Returns a zero mass measure, by thresholding the module if True.
	 - out_resolution : number of bins per parameter of the dense output of the default invariant ("_").
	   Defaults to `resolution`.
	"""
	def __init__(self, 
			degrees:list[int]|None=None, # homological degrees
			filtration_grid:Iterable[np.ndarray]|None=None, # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i] 
			progress=False, # tqdm
			num_collapses:int|str=0, # edge collapses before computing 
			n_jobs=1, 
			resolution:Iterable[int]|int|None=None, # when filtration grid is not given, the resolution of the filtration grid to infer
			sparse=True, # sparse output
			plot:bool=False, 
			filtration_quantile:float=0., # quantile for inferring filtration grid
			_möbius_inversion:bool=True, # whether or not to do the möbius inversion (not recommended to touch)
			expand=True, # expand the simplextree before computing the homology
			normalize_filtrations:bool=False,
			infer_filtration_strategy:str="regular",
			seed:int=0, # if fit_fraction is not 1, the seed sampling
			fit_fraction = 1,  # the fraction of the data on which to fit
			invariant="_", 
			out_resolution:Iterable[int]|int|None=None,
			enforce_null_mass:bool=True,
		  ):
		super().__init__()
		self.degrees = degrees
		self.filtration_grid = filtration_grid
		self.progress = progress
		self.num_collapses=num_collapses
		# Non-positive values mean "use every cpu"
		self.n_jobs = cpu_count() if n_jobs <= 0 else n_jobs
		self.resolution = resolution
		self.plot=plot
		self.sparse=sparse
		self.filtration_quantile=filtration_quantile
		self.normalize_filtrations = normalize_filtrations # Only available for sparse output (discrete matrices cannot be "rescaled"), see the assert below
		self.infer_filtration_strategy = infer_filtration_strategy
		assert not self.normalize_filtrations or self.sparse, "Not able to normalize a matrix without losing information. Will not normalize."
		assert resolution is not None or filtration_grid is not None or infer_filtration_strategy == "exact"
		self.num_parameter = None
		self._is_input_delayed = None
		self._möbius_inversion = _möbius_inversion
		self._reconversion_grid = None
		self.expand = expand
		self._refit_grid = filtration_grid is None # will only refit the grid if filtration_grid has never been given.
		self.seed=seed
		self.fit_fraction = fit_fraction
		self.invariant = invariant
		self._transform_st = None
		self._to_simplex_tree = None
		self.out_resolution = out_resolution
		self.enforce_null_mass = enforce_null_mass

	def _infer_filtration(self,X):
		"""
		Infers `self.filtration_grid` from (a random subsample of) the simplextrees of X,
		following `self.infer_filtration_strategy` ("regular", "exact" or "quantile").
		"""
		print(f"Inferring filtration grid from simplextrees, with strategy {self.infer_filtration_strategy}...", end="", flush=True)
		# NOTE : this reseeds numpy's global random state.
		np.random.seed(self.seed)
		# Subsample of the data used for the fit, always at least one element
		indices = np.random.choice(len(X), min(int(self.fit_fraction* len(X)) +1, len(X)), replace=False)
		prefer = "processes" if self._is_input_delayed else "threads"
		if self.infer_filtration_strategy == "regular":
			get_filtration_bounds = lambda x : self._to_simplex_tree(x).filtration_bounds(q=self.filtration_quantile)
			filtration_bounds =  Parallel(n_jobs=self.n_jobs, prefer=prefer)(delayed(get_filtration_bounds)(x) for x in (X[idx] for idx in indices))
			box = [np.min(filtration_bounds, axis=(0,1)), np.max(filtration_bounds, axis=(0,1))]
			# Enlarges the bounding box by 10% of its largest side, in every direction
			diameter = np.max(box[1] - box[0])
			box = np.array([box[0] -0.1*diameter, box[1] + 0.1 * diameter])
			self.filtration_grid = [np.linspace(*np.asarray(box)[:,i], num=self.resolution[i]) for i in range(len(box[0]))]
			print("Done.")
			return
		get_st_filtration = lambda x : self._to_simplex_tree(x).get_filtration_grid(grid_strategy="exact")
		filtrations =  Parallel(n_jobs=self.n_jobs, prefer=prefer)(delayed(get_st_filtration)(x) for x in (X[idx] for idx in indices))
		num_parameters = len(filtrations[0])

		if self.infer_filtration_strategy == "exact":
			# unique + sort
			filtrations = [np.unique(np.concatenate([x[i] for x in filtrations])) for i in range(num_parameters)]
			# If the number of exact filtration values is too high, replaces them by a linspace
			if self.resolution is not None:
				for i,(f,r) in enumerate(zip(filtrations, self.resolution)):
					if len(f) > r:
						m,M = f[0], f[-1]
						filtrations[i] = np.linspace(start=m, stop=M, num=r)
		elif self.infer_filtration_strategy == "quantile":
			filtrations = [np.unique(np.quantile(np.concatenate([x[i] for x in filtrations]), q=np.linspace(0,1,num=self.resolution[i]))) for i in range(num_parameters)]
		else:
			raise Exception(f"Strategy {self.infer_filtration_strategy} is not implemented. Available are regular, exact, quantile.")
		# Adds a last one, to take essential summands into account
		for i,f in enumerate(filtrations):
			m,M = f[0], f[-1]
			filtrations[i] = np.unique(np.append(f, M + 0.1 * (M-m)))
		
		self.filtration_grid = filtrations
		print("Done.")
		return

	def _tensor2signed_measure(self, tensor, grid_shape:Iterable[int], _reconversion_grid):
		"""
		Shared post-processing of the hilbert and euler invariants : möbius inversion
		(unless disabled) and, for sparse output, conversion to (points, weights).
		"""
		signed_measure = signed_betti(tensor, threshold=self.enforce_null_mass, sparse=False) if self._möbius_inversion else tensor
		if self.sparse:
			signed_measure = tensor_möbius_inversion(tensor = signed_measure, 
			grid_conversion=_reconversion_grid, plot = self.plot, num_parameters=len(grid_shape))
		return signed_measure

	def fit(self, X, y=None):
		"""
		Infers the filtration grid if none was given, and selects the function
		computing the invariant of a (squeezed) simplextree.
		"""
		assert self.invariant != "_" or self._möbius_inversion
		self._is_input_delayed = not isinstance(X[0], mp.SimplexTreeMulti)
		if self._is_input_delayed:
			self._to_simplex_tree = get_simplex_tree_from_delayed
		else:
			self._to_simplex_tree = lambda x : x
		if isinstance(self.resolution, int):
			self.resolution = [self.resolution]*self._to_simplex_tree(X[0]).num_parameters
		if self._refit_grid:
			self._infer_filtration(X=X)
		# Has to be computed after the grid inference : with the "exact" strategy both
		# the resolution and the initial filtration grid are allowed to be None.
		self.num_parameter = len(self.filtration_grid) if self.resolution is None else len(self.resolution)
		if self.out_resolution is None:
			self.out_resolution = self.resolution
		elif isinstance(self.out_resolution, int):
			self.out_resolution = [self.out_resolution]*self.num_parameter
		
		if self.normalize_filtrations:
			self._reconversion_grid = [f/np.std(f) for f in self.filtration_grid] # not the best, but better than some weird magic
		else:
			self._reconversion_grid = self.filtration_grid
		
		if self.invariant == "hilbert":
			def transform_hilbert(simplextree:mp.SimplexTreeMulti, degree:int, grid_shape:Iterable[int], _reconversion_grid):
				hilbert = mp.hilbert(simplextree=simplextree, degree=degree, grid_shape=grid_shape)
				return self._tensor2signed_measure(hilbert, grid_shape, _reconversion_grid)
			self._transform_st = transform_hilbert
		elif self.invariant == "euler":
			assert self.degrees == [None], f"Invariant euler incompatible with degrees {self.degrees}"
			def transform_euler(simplextree:mp.SimplexTreeMulti, degree:int, grid_shape:Iterable[int], _reconversion_grid):
				euler = mp.euler(simplextree=simplextree, degree=degree, grid_shape=grid_shape)
				return self._tensor2signed_measure(euler, grid_shape, _reconversion_grid)
			self._transform_st = transform_euler
		elif self.invariant == "_":
			assert self._möbius_inversion is True
			def transform_sm(simplextree:mp.SimplexTreeMulti, degree:int|None, grid_shape:Iterable[int], _reconversion_grid):
				signed_measure = mp.signed_measure(
					simplextree=simplextree,degree=degree, 
					grid_shape=grid_shape, zero_pad=self.enforce_null_mass, 
					grid_conversion=_reconversion_grid, 
					unsparse = False)

				if not self.sparse:
					# Dense output : histogram of the sparse measure
					pts, weights = signed_measure
					# out_resolution is None when neither resolution nor out_resolution
					# were given; falls back to one bin per grid value in that case.
					resolutions = grid_shape if self.out_resolution is None else self.out_resolution
					bins = [[f.min(), f.max()] for f in _reconversion_grid]
					# Slightly enlarged bounds, so that extremal points fall inside the bins
					bins = [np.linspace(m-0.1*(M-m)/r, M+0.1*(M-m)/r, num=r+1) for (m,M),r in zip(bins, resolutions)]
					signed_measure,_ = np.histogramdd(
						pts,bins=bins, 
						weights=weights
						)
				return signed_measure
			self._transform_st = transform_sm
		else:
			raise Exception(f"Bad invariant {self.invariant}. Pick either euler or hilbert.")
		return self

	def transform1(self, simplextree, filtration_grid=None, _reconversion_grid=None):
		"""
		Computes the signed measures (one per degree) of a single simplextree.
		The grids default to the ones of the fit.
		"""
		if filtration_grid is None: filtration_grid = self.filtration_grid
		if _reconversion_grid is None: _reconversion_grid = self._reconversion_grid
		st = self._to_simplex_tree(simplextree)
		st = mp.SimplexTreeMulti(st, num_parameters = st.num_parameters) ## COPY
		st.grid_squeeze(filtration_grid = filtration_grid, coordinate_values = True)
		# Edge collapses are only done for 2-parameter simplextrees
		if st.num_parameters == 2:
			if self.num_collapses == "full":
				st.collapse_edges(full=True,max_dimension=1)
			elif isinstance(self.num_collapses, int):
				st.collapse_edges(num=self.num_collapses,max_dimension=1)
			else:
				raise Exception("Bad edge collapse type. either 'full' or an int.")
		signed_measures = []
		if self.expand :
			# degrees == [None] is the euler case (see fit) : expands (almost) without bound
			max_degree = 1000 if self.degrees == [None] else np.max(self.degrees)+1
			st.expansion(max_degree)
		grid_shape = [len(f) for f in filtration_grid]
		for degree in self.degrees:
			signed_measure = self._transform_st(
				simplextree=st,degree=degree,
				grid_shape=grid_shape,
				_reconversion_grid=_reconversion_grid
			)
			signed_measures.append(signed_measure)
		return signed_measures

	def transform(self,X):
		"""
		Applies `transform1` to every element of X, in parallel.
		"""
		assert self.filtration_grid is not None and self._transform_st is not None
		prefer = "processes" if self._is_input_delayed else "threads"
		out = Parallel(n_jobs=self.n_jobs, prefer=prefer)(
			delayed(self.transform1)(to_st) for to_st in tqdm(X, disable = not self.progress, desc=f"Computing topological invariant {self.invariant}")
		)
		return out





class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
	"""
	Same as SimplexTree2SignedMeasure, with several simplextrees per data point.
	One filtration grid is fitted per axis.

	Input
	-----
	
	(data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
	
	Output
	------ 
	(data) x (axis) x (degree) x (signed measure)
	"""
	# NOTE(review): the signature only exposes **kwargs; sklearn's get_params / clone
	# presumably cannot see the parameters of this class — confirm before using it
	# inside a cloned pipeline or a grid search.
	def __init__(self,**kwargs):
		super().__init__(**kwargs)
		self._num_st_per_data=None
		self._filtration_grids = None
		return
	def fit(self, X, y=None):
		"""
		Fits the parent transformer once per axis, and stores the resulting filtration grids.
		"""
		# Nothing to fit on empty data
		if len(X) == 0 or len(X[0]) == 0: return self
		self._num_st_per_data = len(X[0])
		self._filtration_grids=[]
		for axis in range(self._num_st_per_data):
			self._filtration_grids.append(super().fit([x[axis] for x in X]).filtration_grid)
		return self
	def transform(self, X):
		"""
		Computes, for each data point, the signed measures of each axis with the grid fitted for this axis.
		"""
		if self.normalize_filtrations:
			# NOTE(review): the parent class normalizes with f/np.std(f) in its fit; this
			# linspace differs from it (and ignores the grid spacing) — confirm which one is intended.
			_reconversion_grids = [[np.linspace(0,1, num=len(f), dtype=float) for f in F] for F in self._filtration_grids]
		else:
			_reconversion_grids = self._filtration_grids
		def todo(x):
			return [
				self.transform1(x[axis],filtration_grid=filtration_grid, _reconversion_grid=_reconversion_grid) 
				for axis, filtration_grid, _reconversion_grid in zip(range(self._num_st_per_data), self._filtration_grids, _reconversion_grids)]
		# Honors the n_jobs given by the user, as the parent's transform does
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in X)


def rescale_sparse_signed_measure(signed_measure, filtration_weights, normalize_scales=None):
	"""
	Rescales the coordinates of sparse signed measures. The input is left untouched.

	Parameters
	----------
	 - signed_measure : (degree) x (points, weights). Sequence of sparse signed measures,
	   points being a (n x num_parameters) array.
	 - filtration_weights : (num_parameters). The parameter p of the points is multiplied by filtration_weights[p].
	 - normalize_scales : None or (degree) x (num_parameters). If given, the parameter p
	   of degree d is furthermore divided by normalize_scales[d][p].

	Returns
	-------
	A deep copy of `signed_measure`, with rescaled points. Numpy points whose dtype cannot
	hold the rescaled values (e.g. integer coordinates with float factors) are promoted
	instead of raising a casting error.
	"""
	from copy import deepcopy
	out = deepcopy(signed_measure)
	num_parameters = len(filtration_weights)
	for degree in range(len(out)):
		if normalize_scales is None:
			factors = [filtration_weights[parameter] for parameter in range(num_parameters)]
		else:
			factors = [
				filtration_weights[parameter] / normalize_scales[degree][parameter]
				for parameter in range(num_parameters)
			]
		if not factors:
			continue
		pts = out[degree][0]
		if isinstance(pts, np.ndarray):
			# In-place multiplication cannot change the dtype : integer coordinates
			# rescaled by float factors have to be promoted first.
			promoted_dtype = np.result_type(pts.dtype, np.asarray(factors).dtype)
			if not np.can_cast(promoted_dtype, pts.dtype, casting="same_kind"):
				pts = pts.astype(promoted_dtype)
				measure = out[degree]
				if isinstance(measure, list):
					measure[0] = pts
				else: # tuples are immutable, rebuilds the measure
					out[degree] = (pts, *measure[1:])
		for parameter, factor in enumerate(factors):
			pts[:,parameter] *= factor
	return out

class SignedMeasureFormatter(BaseEstimator,TransformerMixin):
	"""
	Reweights / normalizes the filtrations of signed measures, and optionally
	turns sparse signed measures into (flattened) histograms.

	Input
	-----
	
	(data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
	
	Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
	
	The second is meant to use multiple choices for signed measure input. An example of usage : they come from a Rips + Density with different bandwidth. 
	It is controlled by the axis parameter.

	Output
	------
	
	Iterable[list[(reweighted)_sparse_signed_measure of degree]]

	Parameters
	----------
	 - filtrations_weights : one weight per parameter, the points of sparse measures are multiplied by them. Defaults to 1.
	 - normalize : if True, the points of sparse measures are divided by the range (max - min) of each filtration, computed at fit time.
	 - unsparse : if True, sparse measures are turned into histograms.
	 - resolution : int or one int per parameter, number of bin edges of these histograms.
	 - flatten : if True, the dense outputs are flattened and concatenated over the degrees.
	 - axis : if not None, the axis to select in a (data) x (axis) x (degree) input.
	 - n_jobs : number of jobs of the transform.
	"""
	def __init__(self, 
			filtrations_weights:Iterable[float]=None,
			normalize=False,
			num_parameters:int|None=None,
			plot:bool=False,
			n_jobs:int=1, 
			unsparse:bool=False,
			axis:int=None,
			resolution:int|Iterable[int]=50,
			flatten:bool=False,
		):
		super().__init__()
		self.filtrations_weights = filtrations_weights
		self.num_parameters = num_parameters
		self.plot=plot
		self._grid =None
		self._old_shape = None
		self.n_jobs = n_jobs
		self.unsparse = unsparse
		self.axis=axis
		self._is_input_sparse=None
		self.resolution = resolution
		self._filtrations_bounds=None
		self.flatten=flatten
		self.normalize=normalize
		self._normalization_factors=None
		return
	def fit(self, X, y=None):
		"""
		Infers the input format (sparse or not), the number of parameters and,
		if needed, the bounds of the filtration values of each degree.
		"""
		if len(X) == 0 or len(X[0]) == 0 or (self.axis is not None and len(X[0][0][0]) == 0):	return self
		
		# A sparse signed measure is a tuple (points, weights)
		self._is_input_sparse = (isinstance(X[0][0], tuple) and self.axis is None) or (isinstance(X[0][0][0], tuple) and self.axis is not None)
		if self.axis is None:
			self.num_parameters = X[0][0][0].shape[1] if self._is_input_sparse else X[0][0].ndim
		else:
			#  (data) x (axis) x (degree) x (signed measure)
			self.num_parameters = X[0][0][0][0].shape[1] if self._is_input_sparse else X[0][0][0].ndim
		# Sets weights to 1 if None
		if self.filtrations_weights is None:
			self.filtrations_weights = np.array([1]*self.num_parameters)
		
		# resolution is iterable over the parameters : repeats it if it is a single number
		try:
			float(self.resolution)
		except (TypeError, ValueError):
			pass # already one resolution per parameter
		else:
			self.resolution = [self.resolution]*self.num_parameters
		assert len(self.filtrations_weights) == self.num_parameters == len(self.resolution), f"Number of parameter is not consistent. Inferred : {self.num_parameters}, Filtration weigths : {len(self.filtrations_weights)}, Resolutions : {len(self.resolution)}."
		# if not sparse : not recommended. 
		assert np.all(1 == np.asarray(self.filtrations_weights)) or self._is_input_sparse, f"Use sparse signed measure to rescale. Recieved weights {self.filtrations_weights}"
		
		# The bounds are read from the points of sparse measures, and are only used
		# by the sparse code paths (unsparse, normalize) of the transform.
		if self._is_input_sparse and (self.unsparse or self.normalize):
			if self.axis is None:
				stuff = [np.concatenate([sm[d][0] for sm in X], axis=0) for d in range(len(X[0]))]
				sizes_ = np.array([len(x)>0 for x in stuff])
				assert np.all(sizes_), f"Degrees {np.where(~sizes_)[0]} are trivial !"
				self._filtrations_bounds = np.asarray([[f.min(axis=0), f.max(axis=0)] for f in stuff])
			else:
				stuff = [np.concatenate([sm[self.axis][d][0] for sm in X], axis=0) for d in range(len(X[0][0]))]
				self._filtrations_bounds = np.asarray([[f.min(axis=0), f.max(axis=0)] for f in stuff])
			# (degree) x (parameter) ranges of the filtration values
			self._normalization_factors = self._filtrations_bounds[:,1] - self._filtrations_bounds[:,0] if self.normalize else None
			if self.normalize and np.any(self._normalization_factors == 0 ):
				# Constant filtration : nothing to normalize, avoids a division by 0
				indices = np.where(self._normalization_factors == 0)
				self._normalization_factors[indices] = 1
		return self

	def _scaled_filtration_bounds(self, degree:int):
		"""
		Bounds (lower, upper) of the filtration values of a degree, after the rescaling
		that `transform` applies to the points (weights, and normalization if any).
		"""
		lower, upper = self._filtrations_bounds[degree]
		scale = np.asarray(self.filtrations_weights, dtype=float)
		if self._normalization_factors is not None:
			scale = scale / self._normalization_factors[degree]
		first, second = lower * scale, upper * scale
		# A negative weight swaps the bounds
		return np.minimum(first, second), np.maximum(first, second)

	def unsparse_signed_measure(self, sparse_signed_measure:Iterable[tuple[np.ndarray, np.ndarray]]):
		"""
		Turns the (already rescaled) sparse signed measures of each degree into histograms,
		with `resolution[p]` bin edges along the parameter p, between the fitted bounds.
		"""
		out = []
		for degree, (pts, weights) in enumerate(sparse_signed_measure): # over degree
			lower, upper = self._scaled_filtration_bounds(degree)
			# One array of bin edges per parameter
			bins = [np.linspace(start=a, stop=b, num=r) for a,b,r in zip(lower, upper, self.resolution)]
			signed_measure,_ = np.histogramdd(
				pts,bins=bins, 
				weights=weights
				)
			if self.flatten:	signed_measure = signed_measure.flatten()
			out.append(signed_measure)
		if self.flatten:	out = np.concatenate(out).flatten()
		return out

	def transform(self,X):
		"""
		Rescales sparse signed measures (and turns them into histograms if `unsparse`).
		Non-sparse signed measures are returned as is, or flattened if `flatten`.
		"""
		def todo_from_not_sparse(signed_measure:Iterable[np.ndarray]):
			if not self.flatten:
				return signed_measure
			return np.asarray([sm.flatten() for sm in signed_measure]).flatten()

		def todo_from_sparse(sparse_signed_measure:Iterable[tuple[np.ndarray, np.ndarray]]):
			out = rescale_sparse_signed_measure(sparse_signed_measure, filtration_weights=self.filtrations_weights, normalize_scales = self._normalization_factors)
			return out
			
		if self._is_input_sparse:
			todo = todo_from_sparse
		else:
			todo = todo_from_not_sparse
		
		if self.axis is None:
			it = X
		else:
			it = (x[self.axis] for x in X)
		out = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in it)

		if self.unsparse and self._is_input_sparse:
			out = [self.unsparse_signed_measure(x) for x in out]
		return out










class SignedMeasure2Img(BaseEstimator,TransformerMixin):
	"""
	Turns a signed measure into an image

	Input
	-----
	
	(data) x (degree) x (signed measure)

	Parameters
	----------
	 - filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
	 - resolution : int or (num_parameter) : If filtration grid is not given, will infer a grid, with this resolution
	 - infer_grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
	 - flatten : if true, the output will be flattened
	 - bandwidth : bandwidth of the convolution. A non-positive value b is interpreted as a fraction :
	   -b times the input resolution (matrix input) or -b times the diameter of the grid (sparse input).
	 - n_jobs : number of jobs of the transform
	 - progress : if true, shows a progress bar (matrix input)

	Output
	------
	
	(data) x (concatenation of imgs of degree)
	"""
	def __init__(self, 
	      filtration_grid:Iterable[np.ndarray]=None, 
		  kernel="gaussian", 
	      bandwidth:float|Iterable[float]=1., 
		  flatten:bool=False, n_jobs:int=1,
		  resolution:int=None, 
		  infer_grid_strategy:str="exact",
		  progress:bool=False, 
		  _old_implementation=True,
		  **kwargs):
		super().__init__()
		self.kernel=kernel
		self.bandwidth=bandwidth
		self.more_kde_kwargs=kwargs
		self.filtration_grid=filtration_grid
		self.flatten=flatten
		self.progress=progress
		self.n_jobs = n_jobs
		self.resolution = resolution
		self.infer_grid_strategy = infer_grid_strategy
		self._is_input_sparse = None
		self._refit = filtration_grid is None # the grid is only inferred if none was given
		self._input_resolution=None
		self._bandwidths=None
		self.diameter=None
		self._old_implementation=_old_implementation
		return
	def fit(self, X, y=None):
		"""
		Infers the input format. For matrix input, computes the bandwidth of each axis;
		for sparse input, infers the evaluation grid if none was given.
		"""
		if len(X) == 0: return self
		## Infers if the input is sparse given X : a sparse signed measure is a tuple (points, weights)
		self._is_input_sparse = isinstance(X[0][0], tuple)
		if not self._is_input_sparse:
			self._input_resolution = X[0][0].shape
			try:
				b = float(self.bandwidth)
			except (TypeError, ValueError):
				# One bandwidth per axis
				self._bandwidths = [b if b > 0 else -b * s for s,b in zip(self._input_resolution, self.bandwidth)]
			else:
				# Same bandwidth for every axis
				self._bandwidths = [b if b > 0 else -b * s for s in self._input_resolution]
			return self # in that case, signed measures are matrices, and the grid is already given
		
		if self.filtration_grid is None and self.resolution is None:
			raise Exception("Cannot infer filtration grid. Provide either a filtration grid or a resolution.")
		## Sparse input : a grid has to be defined
		if self._refit:
			pts = np.concatenate([
				sm[0] for signed_measures in X for sm in signed_measures
			])
			self.filtration_grid = infer_grid_from_points(pts, strategy=self.infer_grid_strategy, num=self.resolution)
		if self.filtration_grid is not None: self.diameter=np.linalg.norm([f.max() - f.min() for f in self.filtration_grid])
		return self
	
	def _sparsify(self,sm):
		# Every other call of tensor_möbius_inversion in this file passes the matrix
		# as `tensor` (this one used `input`) — presumably the name of the parameter.
		return tensor_möbius_inversion(tensor=sm,grid_conversion=self.filtration_grid)

	def _sm2smi(self, signed_measures:Iterable[np.ndarray]):
		"""
		Gaussian filter of the matrix signed measure of each degree, concatenated along the first axis.
		"""
		return np.concatenate([
				gaussian_filter(input=signed_measure, sigma=self._bandwidths,mode="constant", cval=0)
			for signed_measure in signed_measures], axis=0)

	def _transform_from_sparse(self,X):
		# A non-positive bandwidth is a fraction of the diameter of the grid
		bandwidth = self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
		return convolution_signed_measures(X, filtrations=self.filtration_grid, bandwidth=bandwidth, flatten=self.flatten, n_jobs=self.n_jobs, old_implementation=self._old_implementation)
	def transform(self,X):
		"""
		Computes the images of the signed measures of X. Has to be fitted first.
		"""
		if self._is_input_sparse is None:	raise Exception("Fit first")
		if self._is_input_sparse:
			return self._transform_from_sparse(X)
		todo = SignedMeasure2Img._sm2smi
		out =  Parallel(n_jobs=self.n_jobs)(delayed(todo)(self, signed_measures) for signed_measures in tqdm(X, desc="Computing images", disable = not self.progress))
		if self.flatten:	out = [x.flatten() for x in out]
		return out



class SignedMeasure2SlicedWassersteinDistance(BaseEstimator,TransformerMixin):
	"""
	Transformer from signed measures to distance matrices, one per degree.

	Input
	-----

	(data) x (degree) x (signed measure)

	Format
	------
	- a signed measure : tuple of arrays. (point positions) : npts x (num_parameters) and weights : npts
	- each data is a list of signed measures (e.g. one per degree)
	- signed measures that are not tuples are first converted with `tensor_möbius_inversion`

	Output
	------
	- (degree) x (distance matrix)
	"""
	def __init__(self, n_jobs:int=1, num_directions:int=10, _sliced:bool=True, epsilon=-1, ground_norm=1, progress = False, grid_reconversion=None, scales=None):
		super().__init__()
		# Hyper-parameters, stored under the names of the constructor arguments.
		self.n_jobs = n_jobs
		self.num_directions = num_directions
		self._sliced = _sliced
		self.epsilon = epsilon
		self.ground_norm = ground_norm
		self.progress = progress
		self.grid_reconversion = grid_reconversion
		self.scales = scales
		# One distance estimator per degree; filled in by `fit`.
		self._SWD_list = None

	def _make_distance(self):
		"""Build an unfitted distance estimator: sliced Wasserstein if `self._sliced`, Wasserstein otherwise."""
		if self._sliced:
			return SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs=self.n_jobs, scales=self.scales)
		# WARNING (kept from the original author): this distance is not CNSD.
		return WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm, n_jobs=self.n_jobs)

	def _measures_of_degree(self, X, degree):
		"""Collect the signed measures of one degree over all data, converting dense ones."""
		column = [data[degree] for data in X]
		if self.sparse:
			return column
		return [tensor_möbius_inversion(tensor=dense, grid_conversion=self.grid_reconversion) for dense in column]

	def fit(self, X, y=None):
		"""
		Create and fit one distance estimator per degree.

		The number of degrees and the input format (tuple or not) are read from
		the first element of `X`. Nothing is done when `X` is empty.
		"""
		if len(X) == 0:
			return self
		self.sparse = isinstance(X[0][0], tuple)
		self._SWD_list = [self._make_distance() for _ in range(len(X[0]))]
		for degree, distance in enumerate(self._SWD_list):
			distance.fit(self._measures_of_degree(X, degree))
		return self

	def transform(self,X):
		"""Return the array holding, for each degree, the output of its fitted distance estimator on `X`."""
		assert self._SWD_list is not None, "Fit first"
		degrees = tqdm(enumerate(self._SWD_list), desc="Computing distance matrices", total=len(self._SWD_list), disable= not self.progress)
		matrices = [distance.transform(self._measures_of_degree(X, degree)) for degree, distance in degrees]
		return np.asarray(matrices)

	def predict(self, X):
		"""Alias of `transform`."""
		return self.transform(X)


class SignedMeasures2SlicedWassersteinDistances(BaseEstimator,TransformerMixin):
	"""
	Transformer from signed measures to distance matrices, for several axes and scales.

	One `SignedMeasure2SlicedWassersteinDistance` is fitted per (axis, scales) pair.

	Input
	-----
	(data) x opt (axis) x (degree) x (signed measure)

	Format
	------
	- a signed measure : tuple of arrays. (point positions) : npts x (num_parameters) and weights : npts
	- each data is a list of signed measures (e.g. one per degree)

	Parameters
	----------
	- progress : show a progress bar in `transform`.
	- n_jobs : parallelism budget of `transform`; about half of it is given to the loop over the (axis, scales) pairs. None means 1, negative values follow the joblib convention (-1 is all the cpus).
	- scales : None, one list of scales, or a list of lists of scales.
	- kwargs : forwarded to the constructor of `SignedMeasure2SlicedWassersteinDistance`.

	Output
	------
	- list with one entry per (axis, scales) pair, each being (degree) x (distance matrix)
	"""
	def __init__(self, progress=False, n_jobs:int=1, scales:Iterable[Iterable[float]]|None = None, **kwargs): # same init
		# Template of the per-(axis, scales) estimators; it is cloned in `fit`.
		# NOTE(review): `kwargs` are only kept inside this template -- check how this interacts with `get_params` / `clone` of this estimator.
		self._init_child = SignedMeasure2SlicedWassersteinDistance(progress=False, scales=None,n_jobs=-1, **kwargs)
		self._axe_iterator=None
		self._childs_to_fit=None
		self._scales=None # normalized copy of `scales`, computed by `fit`
		self.scales = scales
		self.progress = progress
		self.n_jobs=n_jobs
		return

	def _normalized_scales(self):
		"""Return `self.scales` as something to iterate on: `[None]`, or a 2d array with one row per list of scales."""
		if self.scales is None:
			return [None]
		scales = np.asarray(self.scales)
		if scales.ndim == 1:
			# A single list of scales.
			scales = np.asarray([scales])
		assert scales.ndim == 2, "Scales have to be either None or a list of scales !"
		return scales

	def _outer_n_jobs(self)->int:
		"""
		Number of jobs of the loop over the (axis, scales) pairs: half of the budget, plus one.

		Negative budgets are first resolved to a number of cpus; otherwise
		`n_jobs=-1` would give `-1//2 + 1 == 0`, which joblib rejects.
		"""
		n_jobs = 1 if self.n_jobs is None else self.n_jobs
		if n_jobs < 0:
			n_jobs = max(cpu_count() + 1 + n_jobs, 1)
		return n_jobs//2 + 1

	def fit(self, X, y=None):
		"""
		Fit one child estimator per (axis, scales) pair.

		The data has no axis level when the first signed measure found is a
		tuple; the children are then fitted on the data as is. Nothing is done
		when `X` is empty.
		"""
		if len(X) == 0:	 return self
		if isinstance(X[0][0],tuple): # Meaning that there are no axes
			self._axe_iterator = [slice(None)]
		else:
			self._axe_iterator = range(len(X[0]))
		# The constructor parameter `self.scales` is left untouched, so that fitting
		# twice (or cloning a fitted estimator) starts from the value given by the user.
		self._scales = self._normalized_scales()
		self._childs_to_fit = [
			clone(self._init_child).set_params(scales=scales).fit(
				[x[axis] for x in X])
				for axis, scales in product(self._axe_iterator, self._scales)
			]
		print("New axes : ", list(product(self._axe_iterator, self._scales)))
		return self

	def transform(self,X):
		"""Return the list of the outputs of the fitted children on `X`, in the order of the (axis, scales) pairs."""
		assert self._childs_to_fit is not None, "Fit first"
		# Children were created in the order of this product, so they can be paired with their axis.
		axes = [axis for axis, _ in product(self._axe_iterator, self._scales)]
		return Parallel(n_jobs=self._outer_n_jobs())(
			delayed(child.transform)([x[axis] for x in X])
				for axis, child in tqdm(zip(axes, self._childs_to_fit),
					desc="Computing distances matrices of axis, and scales", disable=not self.progress, total=len(self._childs_to_fit)
				)
		)




def accuracy_to_csv(X,Y,cl, cln:str, k:float=10, dataset:str = "", filtration:str = "", shuffle=True,  verbose:bool=True, **kwargs):
	"""
	Score the classifier `cl` on (X, Y) and append the result to a csv file.

	The file is `result_{dataset}.csv` (with "/" replaced by "_" and ".off"
	removed), in the current directory. It is created if missing.

	Parameters
	----------
	- X, Y : samples and labels, indexable by integer positions.
	- cl : estimator with `fit(X, y)` and `score(X, y)`. If it has a `best_params_` attribute once fitted, it is printed.
	- cln : name of the pipeline, written in the "pipeline" column.
	- k : if > 1, number of folds (cast to int) of a stratified k-fold; otherwise `test_size` of a single train / test split.
	- dataset, filtration : written in the columns of the same name.
	- shuffle : forwarded to the splitter.
	- verbose : print the accuracy of each fold / of the split.
	- kwargs : extra (column name, value) pairs added to the csv line. Keys that clash with a default column are dropped with a warning.

	Raises
	------
	AssertionError if `k` is not strictly positive.
	"""
	import pandas as pd
	assert k > 0, "k is either the number of kfold > 1 or the test size > 0."
	if k>1:
		k = int(k)
		from sklearn.model_selection import StratifiedKFold as KFold
		kfold = KFold(k, shuffle=shuffle).split(X,Y)
		accuracies = np.zeros(k)
		for fold,(train_idx, test_idx) in enumerate(tqdm(kfold, total=k, desc="Computing kfold")):
			xtrain = [X[i] for i in train_idx]
			ytrain = [Y[i] for i in train_idx]
			cl.fit(xtrain, ytrain)
			xtest = [X[i] for i in test_idx]
			ytest = [Y[i] for i in test_idx]
			accuracies[fold] = cl.score(xtest, ytest)
			if verbose:
				print(f"step {fold+1}, {dataset} : {accuracies[fold]}", flush=True)
				# Only search-like estimators have this attribute; anything else is simply skipped.
				if hasattr(cl, "best_params_"):
					print("Best classification parameters : ", cl.best_params_)
	else:
		# 0 < k <= 1 : k is the proportion (or number) of test samples.
		from sklearn.model_selection import train_test_split
		print("Computing accuracy, with train test split", flush=True)
		xtrain, xtest, ytrain, ytest = train_test_split(X, Y, shuffle=shuffle, test_size=k)
		print("Fitting...", end="", flush=True)
		cl.fit(xtrain, ytrain)
		print("Computing score...", end="", flush=True)
		accuracies = cl.score(xtest, ytest)
		if hasattr(cl, "best_params_"):
			print("Best classification parameters : ", cl.best_params_)
		print("Done.")
		if verbose:	print(f"Accuracy {dataset} : {accuracies} ")
	file_path:str = f"result_{dataset}.csv".replace("/", "_").replace(".off", "")
	columns:list[str] = ["dataset", "filtration", "pipeline", "cv", "mean", "std"]
	if exists(file_path):
		df:pd.DataFrame = pd.read_csv(file_path)
	else:
		df:pd.DataFrame = pd.DataFrame(columns= columns)
	# Extra columns given by the caller; the default columns cannot be overridden.
	more_names = []
	more_values = []
	for key, value in kwargs.items():
		if key not in columns:
			more_names.append(key)
			more_values.append(value)
		else:
			warn(f"Duplicate key {key} ! with values {cln} and {value}")
	new_line:pd.DataFrame = pd.DataFrame([[dataset, filtration, cln, k, np.mean(accuracies), np.std(accuracies)]+more_values], columns = columns+more_names)
	df = pd.concat([df, new_line])
	df.to_csv(file_path, index=False)
