
import numpy as np
import gudhi as gd
import multipers as mp
from tqdm import tqdm
from itertools import product
from sklearn.neighbors import KernelDensity
from sklearn.base import BaseEstimator, TransformerMixin
from warnings import warn
from .signed_betti import *
from .invariants_with_persistable import *
from .sliced_wasserstein import *
from types import FunctionType
from joblib import Parallel, delayed, cpu_count
from os.path import exists
from typing import Iterable
from torch import Tensor
import pandas as pd
from warnings import warn
import matplotlib.pyplot as plt
import MDAnalysis as mda
from scipy.spatial import distance_matrix
from scipy.ndimage import gaussian_filter
from MDAnalysis.topology.guessers import guess_masses
from sklearn.preprocessing import LabelEncoder

def get_simplex_tree_from_delayed(x):
	"""
	Evaluates a delayed call.
	 - x : a (function, args, kwargs) triple
	Returns the value of function(*args, **kwargs).
	"""
	function, positional_arguments, keyword_arguments = x
	return function(*positional_arguments, **keyword_arguments)

def get_simplextree(x)->mp.SimplexTreeMulti:
	"""
	Returns the simplextree encoded by x.
	 - x : either a SimplexTreeMulti (returned as is), or a delayed call, 
	   i.e., a (callable, args, kwargs) triple, which is evaluated.
	Raises a TypeError if x is neither of them.
	"""
	if isinstance(x, mp.SimplexTreeMulti):
		return x
	# Any callable is accepted as first element : delayed calls built from bound methods
	# (e.g. `delayed(self._get_st)(x)`) are not instances of `FunctionType`.
	try:
		is_delayed_call = len(x) == 3 and callable(x[0])
	except (TypeError, KeyError): # no length, or not indexable by position
		is_delayed_call = False
	if not is_delayed_call:
		raise TypeError("Not a valid SimplexTree !")
	function, args, kwargs = x
	return function(*args, **kwargs)

def infer_grid_from_points(pts:Iterable[np.ndarray], num:int, strategy:str):
	"""
	Infers a filtration grid from points.
	 - pts : the points, stacked along the first axis
	 - num : the number of values of each coordinate
	 - strategy : either "regular" (evenly spaced values between the coordinatewise 
	   minimum and maximum of the points), or "quantile" (`num` evenly spaced quantiles 
	   of each coordinate)
	Returns the transposed grid, i.e., the `num` values are along the last axis.
	Raises an Exception if the strategy is not one of the above.
	"""
	if strategy == "quantile":
		levels = np.linspace(0, 1, num)
		return np.quantile(pts, q=levels, axis=0).T
	if strategy == "regular":
		lower_corner, upper_corner = np.min(pts, axis=0), np.max(pts, axis=0)
		return np.linspace(lower_corner, upper_corner, num=num).T
	raise Exception(f"Grid strategy {strategy} not implemented")

def get_filtration_weights_grid(num_parameters:int=2, resolution:int|Iterable[int]=3,*, min:float=0, max:float=20, dtype=float):
	"""
	Provides a grid of weights, for filtration rescaling.
	 - num parameter : the dimension of the grid tensor
	 - resolution :  the size of each coordinate
	 - min : minimum weight
	 - max : maximum weight
	 - dtype : the type of the grid values (usefull for int weights)
	"""
	from itertools import product
	if isinstance(resolution, int):
		resolution = [resolution]*num_parameters
	out = np.asarray(list(product(*([np.linspace(start=min,stop=max,num=r, dtype=dtype) for r in resolution]))))
	_, indices = np.unique([x / x.max() for x in out if x.max() != 0],axis=0, return_index=True)
	return list(out[indices])



################################################# Data2SimplexTree
class RipsDensity2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Transformer. Point clouds to multiparameter simplextrees, built from a Rips complex 
	(of dimension 1) whose parameter 1 is filled by the lower-star filtration of a codensity.

	Input : Iterable of point clouds.
	Output : list of simplextrees, or list of delayed calls computing them if `delayed` is True.

	 - bandwidth : bandwidth of the kernel density estimation. If negative, it is 
	   interpreted as a multiple of the scale of the dataset (see `_get_scale`).
	 - threshold : max edge length of the Rips complex. Same convention if negative.
	 - sparse : forwarded to gudhi's RipsComplex
	 - num_collapse : number of edge collapses
	 - num_parameters : number of parameters of the simplextrees
	 - kernel, rtol, atol : forwarded to sklearn's KernelDensity
	 - delayed : if True, `transform` only returns the delayed calls
	 - rescale_density : if non zero, the codensity is rescaled to [0, rescale_density]
	 - progress : shows progress
	 - n_jobs : number of threads used by `transform`
	"""
	def __init__(self, bandwidth:float=-0.1, threshold:float=np.inf, 
			sparse:float|None=None, num_collapse:int=0, 
			num_parameters:int=2, kernel:str="gaussian", delayed=False, rescale_density:float=0,
			progress:bool=False, n_jobs:int=-1, rtol:float=1e-4, atol=1e-6
		) -> None:
		super().__init__()
		self.bandwidth=bandwidth
		self.threshold = threshold
		self.sparse=sparse
		self.num_collapse=num_collapse
		self.num_parameters = num_parameters
		self.kernel = kernel
		self.delayed=delayed
		self.rescale_density = rescale_density
		self.progress=progress
		self._bandwidth=None # fitted (non-negative) bandwidth
		self._threshold=None # fitted (non-negative) threshold
		self.n_jobs = n_jobs
		self.rtol=rtol
		self.atol=atol
		self._scale=None # fitted scale of the dataset
		return
	def _get_scale(self, X):
		"""
		Estimates the scale of the dataset : the largest diameter of the point clouds 
		of a random sample (about 30%) of X. Stores it in `_scale` and returns it.
		"""
		if self.progress: print("Estimating scale...", flush=True, end="")
		indices = np.random.choice(len(X),int(0.3*len(X))+1 ,replace=False)
		self._scale=np.max([distance_matrix(x,x).max() for x in (X[i] for i in indices)])
		if self.progress: print(f"Done. {self._scale}", flush=True)
		return self._scale
	def _get_st(self,x, bandwidth=None)->mp.SimplexTreeMulti:
		"""
		Computes the simplextree of one point cloud x.
		 - bandwidth : overrides the fitted bandwidth if given
		"""
		bandwidth = self._bandwidth if bandwidth is None else bandwidth
		kde=KernelDensity(bandwidth=bandwidth, kernel=self.kernel, rtol=self.rtol, atol=self.atol)
		st = gd.RipsComplex(points = x, max_edge_length=self._threshold, sparse=self.sparse).create_simplex_tree(max_dimension=1)
		st = mp.SimplexTreeMulti(st, num_parameters = self.num_parameters)
		kde.fit(x)
		codensity = -kde.score_samples(x) # opposite of the log-density
		if self.rescale_density != 0: # Not safe...
			codensity -= codensity.min()
			if codensity.max() != 0:	codensity /= codensity.max()
			codensity *= self.rescale_density
		st.fill_lowerstar(codensity, parameter = 1)
		# NOTE(review): edges are collapsed twice (default, then non-strong); kept as is -- confirm this is intended.
		st.collapse_edges(num=self.num_collapse)
		st.collapse_edges(num=self.num_collapse, strong = False, max_dimension = 1) 
		return st
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Turns the (possibly negative, i.e., relative) bandwidth and threshold into absolute 
		ones. The scale of the dataset is only estimated if needed.
		"""
		## default value 0.1 * diameter # TODO rescale density
		if self.bandwidth < 0 or self.threshold < 0:
			self._get_scale(X)
		self._bandwidth = - self.bandwidth * self._scale if self.bandwidth < 0 else self.bandwidth
		self._threshold = - self.threshold * self._scale if self.threshold < 0 else self.threshold
		# self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
		return self

	
	def transform(self,X):
		"""
		Computes the simplextree of each point cloud of X (or the delayed calls if `delayed`).
		"""
		# The progress bar is shown if, and only if, progress is asked.
		with tqdm(X, desc="Computing simplextrees", disable=not self.progress) as data:
			if self.delayed:
				return [delayed(self._get_st)(x) for x in data] # delay the computation for the to_module pipe, as simplextrees are not pickle-able.
			return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self._get_st)(x) for x in data) # not picklable so prefer threads is necessary.

		
		
		
class RipsDensity2SimplexTrees(RipsDensity2SimplexTree):
	"""
	Same as the pipeline without the 's', but computes multiple bandwidths at once. 
	transform shape :
	(data) x (bandwidth) x (simplextree)

	 - bandwidths : the bandwidths to consider (a single one is also accepted). 
	   Negative bandwidths are interpreted as multiples of the scale of the dataset.
	 - rips_density_arguments : arguments of RipsDensity2SimplexTree
	"""
	def __init__(self, bandwidths:Iterable[float]=-0.1, **rips_density_arguments) -> None:
		super().__init__(**rips_density_arguments)
		self.bandwidths=bandwidths
		self._bandwidths=None # fitted (non-negative) bandwidths
		return
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Turns the (possibly negative, i.e., relative) bandwidths and threshold into absolute 
		ones. The scale of the dataset is only estimated if needed.
		"""
		## default value 0.1 * diameter # TODO rescale density
		# The default value of bandwidths is a scalar : it is treated as a single bandwidth.
		bandwidths = [self.bandwidths] if np.isscalar(self.bandwidths) else list(self.bandwidths)
		if any(b < 0 for b in bandwidths) or self.threshold < 0:
			self._get_scale(X)
		self._bandwidths = [- b * self._scale if b < 0 else b for b in bandwidths]
		self._threshold = - self.threshold * self._scale if self.threshold < 0 else self.threshold
		return self

	def _get_sts(self, x, bandwidths=None):
		"""
		Computes the simplextrees of one point cloud x, one per bandwidth.
		 - bandwidths : overrides the fitted bandwidths if given
		"""
		bandwidths = self._bandwidths if bandwidths is None else bandwidths
		return [self._get_st(x, bandwidth=bandwidth) for bandwidth in bandwidths]
	def transform(self,X):
		"""
		Computes, for each point cloud of X, its simplextrees (or the delayed calls if `delayed`).
		"""
		# The progress bar is shown if, and only if, progress is asked.
		with tqdm(X, desc="Computing simplextrees", disable=not self.progress) as data:
			if self.delayed:
				return [[delayed(self._get_st)(x, bandwidth=bandwidth) for bandwidth in self._bandwidths] for x in data] # delay the computation for the to_module pipe, as simplextrees are not pickle-able.
			return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self._get_sts)(x) for x in data) # not picklable so prefer threads is necessary.


		
		
class SimplexTreeEdgeCollapser(BaseEstimator, TransformerMixin):
	"""
	Transformer. Calls `collapse_edges` on each simplextree of the input, and returns the input.
	 - num_collapses : forwarded as `num` to `collapse_edges`
	 - full : forwarded as `full` to `collapse_edges`
	 - max_dimension : stored, but not used by `transform`
	"""
	def __init__(self, num_collapses:int=0, full:bool=False, max_dimension:int|None=None) -> None:
		super().__init__()
		self.num_collapses=num_collapses
		self.max_dimension=max_dimension
		self.full=full
	def fit(self, X:np.ndarray|list, y=None):
		"""Nothing to fit."""
		return self
	def transform(self,X):
		"""
		Collapses the edges of each simplextree of X.
		The results of `collapse_edges` are ignored, and X itself is returned. 
		NOTE(review): this presumably relies on `collapse_edges` working in place -- confirm.
		"""
		collapse_arguments = {"full": self.full, "num": self.num_collapses}
		for position in range(len(X)):
			simplextree = X[position]
			simplextree.collapse_edges(**collapse_arguments)
		return X

#### MOLECULE DATA
# def _lines2bonds(_lines):
# 	out = []
# 	index = 0
# 	while index < len(_lines) and  _lines[index].strip() != "@<TRIPOS>BOND":
# 		index += 1
# 	index += 1
# 	while index < len(_lines) and  _lines[index].strip()[0] != "@":
# 		line = _lines[index].strip().split(" ")
# 		for j,truc in enumerate(line):
# 			line[j] = truc.strip()
# 		# try:
# 		out.append([int(stuff) for stuff in line if len(stuff) > 0])
# 		# except:
# 		# 	print_lin
# 		index +=1
# 	out = pd.DataFrame(out, columns=["bond_id","atom1", "atom2", "bond_type"])
# 	out.set_index(["bond_id"],inplace=True)
# 	return out
# def _get_mol2_file(path:str, num_cols:int=9, columns:dict|None=None):
# 	from biopandas.mol2 import split_multimol2,PandasMol2
# 	columns={
# 		0:('atom_id', int), 
# 		1:('atom_name', str),
# 		2:('x', float), 
# 		3:('y', float), 
# 		4:('z', float), 
# 		5:('atom_type', str), 
# 		6:('subst_id', int), 
# 		7:('subst_name', str), 
# 		8:('charge', float)
# 	} if columns is None else columns
# 	while len(columns) > num_cols:
# 		columns.pop(len(columns)-1)
# 	# try:
# 	molecules_dfs = []
# 	bonds_dfs = []
# 	for molecule in split_multimol2(path):
# 		_code, _lines = molecule
# 		try:
# 			bonds_dfs.append(_lines2bonds(_lines))
# 			molecule_df = PandasMol2().read_mol2_from_list(mol2_lines=_lines, mol2_code=_code, columns=columns).df
# 		except:
# 			print(_code)
# 			print(_lines)
# 		molecule_df.set_index(["atom_id"], inplace=True)
# 		molecules_dfs.append(molecule_df)        
# 	# except:
# 	#     return get_mol2_file(path=path, num_cols=num_cols-1)
# 	return molecules_dfs, bonds_dfs
# def _atom_to_mass(atom)->int:
# 	return ELEMENTS[atom].mass
# 	raise Exception(f" Atom {atom} has no registered mass.")
def lines2bonds(path:str):
	"""
	Reads the bonds of a mol2 file, i.e., the content of its "@<TRIPOS>BOND" section.
	 - path : path to the mol2 file

	Returns a DataFrame indexed by "bond_id", of columns "atom1", "atom2", "bond_type". 
	The values are kept as strings, as the bond types are not all numeric (e.g. "ar", "am").
	Only the first 4 fields of each bond line are kept; the remaining ones, if any, are ignored. 
	Blank lines are skipped.
	"""
	with open(path, "r") as mol2_file: # the file is closed, even on error
		lines = [line.strip() for line in mol2_file]
	bonds = []
	in_bond_section = False
	for line in lines:
		if not in_bond_section:
			in_bond_section = line == "@<TRIPOS>BOND"
			continue
		if line.startswith("@"): # beginning of the next section
			break
		fields = line.split()
		if fields: # a blank line holds no bond
			bonds.append(fields[:4])
	out = pd.DataFrame(bonds, columns=["bond_id","atom1", "atom2", "bond_type"])
	out.set_index(["bond_id"],inplace=True)
	return out
def _mol2st(path:str|mda.Universe, bond_length:bool = False, charge:bool=False, atomic_mass:bool=False, bond_type=False, **kwargs):
	"""
	Builds the simplextree (atoms as vertices, bonds as edges) of a molecule, 
	with one filtration parameter per requested quantity.
	 - path : path to a mol2 file, or an MDAnalysis Universe
	 - bond_length : filters the edges by the length of the bonds
	 - bond_type : filters the edges by the (encoded) type of the bonds. Needs a path.
	 - charge : lower-star filtration of the charge of the atoms
	 - atomic_mass : lower-star filtration of the opposite of the mass of the atoms
	 - kwargs : ignored

	The parameters are in this order, when requested : 
	bond_length, bond_type, charge, atomic_mass.

	Raises a TypeError if bond_type is requested from a Universe.
	"""
	molecule = path if isinstance(path, mda.Universe) else mda.Universe(path, format="MOL2") 
	# if isinstance(bonds_df, list):	
	# 	if len(bonds_df) > 1:	warn("Multiple molecule found in the same data ! Taking the first only.")
	# 	molecule_df = molecule_df[0]
	# 	bonds_df = bonds_df[0]
	num_filtrations = bond_length + charge + atomic_mass + bond_type
	nodes = molecule.atoms.indices.reshape(1,-1)
	edges = molecule.bonds.dump_contents().T
	num_vertices = nodes.shape[1]
	num_edges =edges.shape[1]
	
	st = mp.SimplexTreeMulti(num_parameters = num_filtrations)
	
	## Edges filtration
	# edges = np.array(bonds_df[["atom1", "atom2"]]).T
	edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32) - np.inf
	if bond_length:
		bond_lengths = molecule.bonds.bonds()
		edges_filtration[:,0] = bond_lengths
	if bond_type:
		if isinstance(path, mda.Universe):
			# The bond types are read from the file itself.
			raise TypeError("Expected path as input to compute bounds type. MDA doesn't handle it.")
		bond_types = LabelEncoder().fit([0,1,2,3,"am","ar"]).transform(lines2bonds(path=path)["bond_type"])
		edges_filtration[:,int(bond_length)] = bond_types

	## Nodes filtration
	# NOTE(review): the minimum of the edges is *subtracted*, so the nodes get the opposite of 
	# the smallest edge value (and +inf on the columns left at -inf) -- confirm the sign is intended.
	nodes_filtrations = np.zeros((num_vertices,num_filtrations), dtype=np.float32) -np.min(edges_filtration, axis=0) # better than - np.inf
	st.insert_batch(nodes, nodes_filtrations)

	st.insert_batch(edges, edges_filtration)
	if charge:
		charges = molecule.atoms.charges
		st.fill_lowerstar(charges, parameter=int(bond_length + bond_type))
		# raise Exception("TODO")
	if atomic_mass:
		masses = molecule.atoms.masses
		null_indices = masses == 0
		if np.any(null_indices): # guess if necessary
			masses[null_indices] = guess_masses(molecule.atoms.types)[null_indices]
		st.fill_lowerstar(-masses, parameter=int(bond_length+bond_type+charge))
	st.make_filtration_non_decreasing()
	return st
class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Transformer. Molecules to simplextrees.

	Input : Iterable of paths to mol2 files
	Output : list of simplextrees, or list of delayed calls computing them if `delayed` is True.

	 - bond_length_filtration, bond_type_filtration, charge_filtration, atomic_mass_filtration : 
	   the filtrations to consider, forwarded to `_mol2st`. 
	 - delayed : if True, `transform` only returns the delayed calls
	 - n_jobs : number of threads used by `transform`
	 - atom_columns, atom_num_columns, max_dimension, progress : stored, but not used by this class
	"""
	def __init__(self, atom_columns:Iterable[str]|None=None, 
			atom_num_columns:int=9, max_dimension:int|None=None, delayed:bool=False, 
			progress:bool=False, 
			bond_length_filtration:bool=False,
			bond_type_filtration:bool=False,
			charge_filtration:bool=False, 
			atomic_mass_filtration:bool=False, 
			n_jobs:int=1) -> None:
		super().__init__()
		self.max_dimension=max_dimension
		self.delayed=delayed
		self.progress=progress
		self.bond_length_filtration = bond_length_filtration
		self.charge_filtration = charge_filtration
		self.atomic_mass_filtration = atomic_mass_filtration
		self.bond_type_filtration = bond_type_filtration
		self.n_jobs = n_jobs
		self.atom_columns = atom_columns
		self.atom_num_columns = atom_num_columns
		# number of requested filtrations, as computed at construction
		self.num_parameters = self.charge_filtration + self.bond_length_filtration + self.bond_type_filtration + self.atomic_mass_filtration
		return
	def fit(self, X:Iterable[str], y=None):
		"""Nothing to fit : any iterable (even unsized) is accepted."""
		return self
	def transform(self,X:Iterable[str]):
		"""
		Computes the simplextree of each molecule of X (or the delayed calls if `delayed`).
		"""
		def to_simplex_tree(path_to_mol2_file:str):
			return _mol2st(path=path_to_mol2_file, 
				bond_type=self.bond_type_filtration,
				bond_length=self.bond_length_filtration,
				charge=self.charge_filtration,
				atomic_mass=self.atomic_mass_filtration,
			)
		if self.delayed:
			return [delayed(to_simplex_tree)(path) for path in X]
		# simplextrees are computed in threads
		return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(to_simplex_tree)(path) for path in X)
		
############################################### Data2Signedmeasure

def tensor_möbius_inversion(tensor:Tensor|np.ndarray, grid_conversion:Iterable[np.ndarray]|None = None, plot:bool=False, raw:bool=False, num_parameters:int|None=None):
	betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
	num_indices, num_pts = betti_sparse.indices().shape
	num_parameters = num_indices if num_parameters is None else num_parameters
	if num_indices == num_parameters: # either hilbert or rank invariant
		rank_invariant = False
	elif 2*num_parameters == num_indices:
		rank_invariant = True
	else:
		raise TypeError(f"Unsupported betti shape. {num_indices} has to be either {num_parameters} or {2*num_parameters}.")
	points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
	weights = np.asarray(betti_sparse.values(), dtype=int)

	if grid_conversion is not None:
		coords = np.empty(shape=(num_pts,num_indices), dtype=float)
		for i in range(num_indices):
			coords[:,i] = grid_conversion[i%num_parameters][points_filtration[:,i]]
	else:
		coords = points_filtration
	if (not rank_invariant) and plot:
		plt.figure()
		plt.scatter(points_filtration[:,0],points_filtration[:,1], c=weights)
		plt.colorbar()
	if (not rank_invariant) or raw: return coords, weights
	def _is_trivial(rectangle:np.array):
		birth=rectangle[:num_parameters]
		death=rectangle[num_parameters:]
		return np.all(birth<=death) # and not np.array_equal(birth,death)
	correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
	if len(correct_indices) == 0:	return np.empty((0, num_indices)), np.empty((0))
	signed_measure = np.asarray(coords[correct_indices])
	weights = weights[correct_indices]
	if plot:
		assert signed_measure.shape[1] == 4 # plot only the rank decompo for the moment
		from matplotlib.pyplot import plot
		def _plot_rectangle(rectangle:np.ndarray, weight:float):
			x_axis=rectangle[[0,2]]
			y_axis=rectangle[[1,3]]
			color = "blue" if weight > 0 else "red"
			plot(x_axis, y_axis, c=color)
		for rectangle, weight in zip(signed_measure, weights):
			_plot_rectangle(rectangle=rectangle, weight=weight)
	return signed_measure, weights


class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
	"""
	Transformer. Point clouds to the signed measures of their degree-Rips bifiltration.

	Input : Iterable of point clouds
	Output : (data) x (degree) x (signed measure)

	 - degrees : the homological degrees to consider
	 - min_rips_value, max_rips_value, min_normalized_degree, max_normalized_degree, 
	   grid_granularity : forwarded to `hf_degree_rips`. A negative max_rips_value is 
	   interpreted as a multiple of the largest diameter of a random sample of the point clouds.
	 - sparse : if True, signed measures are given as (points, weights)
	 - _möbius_inversion : if False, the Hilbert functions are not turned into signed measures
	 - progress : shows progress
	 - n_jobs : number of jobs used by `transform`
	"""
	def __init__(self, degrees:Iterable[int], min_rips_value:float, max_rips_value,max_normalized_degree:float, min_normalized_degree:float, grid_granularity:int, progress:bool=False, n_jobs=1, sparse:bool=False, _möbius_inversion=True) -> None:
		super().__init__()
		self.degrees = degrees
		self.min_rips_value = min_rips_value
		self.max_rips_value = max_rips_value
		self._max_rips_value = None # fitted (non-negative) max rips value
		self.min_normalized_degree = min_normalized_degree
		self.max_normalized_degree = max_normalized_degree
		self.grid_granularity = grid_granularity
		self.sparse=sparse
		self._möbius_inversion = _möbius_inversion
		self.progress=progress
		self.n_jobs = n_jobs
	def fit(self, X:np.ndarray|list, y=None):
		"""
		Turns the (possibly negative, i.e., relative) max rips value into an absolute one.
		"""
		if not self.max_rips_value < 0:
			self._max_rips_value = self.max_rips_value
			return self
		print("Estimating scale...", flush=True, end="")
		sample = np.random.choice(len(X),int(0.3*len(X))+1 ,replace=False)
		diameters = np.max([distance_matrix(X[i],X[i]).max() for i in sample])
		print(f"Done. {diameters}", flush=True)
		self._max_rips_value = - self.max_rips_value * diameters
		return self
	
	def _transform1(self, data:np.ndarray):
		"""
		Computes the signed measures, one per degree, of one point cloud.
		"""
		pairwise_distances = distance_matrix(data, data)
		rips_values, normalized_degree_values, hilbert_functions, _ = hf_degree_rips(
			pairwise_distances,
			min_rips_value = self.min_rips_value,
			max_rips_value = self._max_rips_value,
			min_normalized_degree = self.min_normalized_degree,
			max_normalized_degree = self.max_normalized_degree,
			grid_granularity = self.grid_granularity,
			max_homological_dimension = np.max(self.degrees),
		)
		def _signed_measure_of(degree):
			out = hilbert_functions[degree]
			if self._möbius_inversion:
				out = signed_betti(out, threshold=True)
			if not self.sparse:
				return out
			return tensor_möbius_inversion(
				tensor=out,num_parameters=2,
				grid_conversion=[rips_values, normalized_degree_values]
			)
		return [_signed_measure_of(degree) for degree in self.degrees]
	def transform(self,X):
		"""
		Computes the signed measures of each point cloud of X.
		"""
		description = f"Computing DegreeRips, of degrees {self.degrees}, signed measures."
		point_clouds = tqdm(X, desc=description, disable = not self.progress)
		return Parallel(n_jobs=self.n_jobs)(delayed(self._transform1)(data) for data in point_clouds)






################################################# SimplexTree2...

def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
	grid_iterator = np.asarray(list(product(*filtration_grid)))
	grid_shape = [len(f) for f in filtration_grid]
	if len(pts) == 0:
		# warn("Found a trivial signed measure !")
		return np.zeros(shape=grid_shape)
	kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol = 1e-4, **more_kde_args) # TODO : check rtol
	
	pos_indices = pts_weights>0
	neg_indices = pts_weights<0
	img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
	img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
	return np.exp(img_pos) - np.exp(img_neg)



def _st2ranktensor(st:mp.SimplexTreeMulti, filtration_grid:np.ndarray, degree:int, plot:bool, reconvert_grid:bool, num_collapse:int|str=0):
	"""
	Computes the rectangle decomposition of the rank invariant of a simplextree.
	 - st : the simplextree. It is copied, and left untouched.
	 - filtration_grid : the filtration values of each parameter, on which to compute
	 - degree : the homological degree
	 - plot : plots the decomposition
	 - reconvert_grid : if True, the output is in filtration values; else in grid coordinates
	 - num_collapse : number of edge collapses, or "full"

	Raises a TypeError if num_collapse is neither an int nor "full".
	"""
	# The squeeze changes the filtration values : works on a copy, 
	# turned into a coordinate simplextree.
	coordinate_st = mp.SimplexTreeMulti(st)
	coordinate_st.grid_squeeze(filtration_grid = filtration_grid, coordinate_values = True)
	# stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
	if num_collapse == "full":
		collapse_arguments = {"full": True}
	elif isinstance(num_collapse, int):
		collapse_arguments = {"num": num_collapse}
	else:
		raise TypeError(f"Invalid num_collapse={num_collapse} type. Either full, or an integer.")
	coordinate_st.collapse_edges(ignore_warning=True, max_dimension=degree+1, **collapse_arguments)
	# Rank invariant tensor, then its decomposition by rectangles (signed betti)
	grid_shape = [len(f) for f in filtration_grid]
	rank_tensor = mp.rank_invariant2d(coordinate_st, degree=degree, grid_shape=grid_shape)
	rank_decomposition = rank_decomposition_by_rectangles(rank_tensor, threshold=True)
	return tensor_möbius_inversion(
		tensor = rank_decomposition, 
		grid_conversion = filtration_grid if reconvert_grid else None, 
		plot=plot, 
		num_parameters=st.num_parameters,
	)

class SimplexTree2RectangleDecomposition(BaseEstimator,TransformerMixin):
	"""
	Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition. 

	Input : Iterable of simplextrees
	Output : (data) x (degree) x (rectangle decomposition)

	 - filtration_grid : the filtration values of each parameter, on which to compute
	 - degrees : the homological degrees to consider
	 - plot : plots the decompositions
	 - reconvert_grid : if True, the outputs are in filtration values; else in grid coordinates
	 - num_collapses : number of edge collapses, or "full"
	"""
	def __init__(self, filtration_grid:np.ndarray, degrees:Iterable[int], plot=False, reconvert_grid=True, num_collapses:int=0):
		super().__init__()
		self.degrees = degrees
		self.filtration_grid = filtration_grid
		self.num_collapses=num_collapses
		self.reconvert_grid = reconvert_grid
		self.plot=plot
	def fit(self, X, y=None):
		"""
		Nothing to fit.
		TODO : infer grid from multiple simplextrees
		"""
		return self
	def transform(self,X:Iterable[mp.SimplexTreeMulti]):
		"""
		Computes the rectangle decompositions, one per degree, of each simplextree of X.
		"""
		def _decompositions_of(simplextree):
			out = []
			for degree in self.degrees:
				out.append(_st2ranktensor(
					simplextree, 
					filtration_grid=self.filtration_grid,
					degree=degree,
					plot=self.plot,
					reconvert_grid = self.reconvert_grid,
					num_collapse=self.num_collapses,
				))
			return out
		## TODO : return iterator ?
		return [_decompositions_of(simplextree) for simplextree in X]


def betti_matrix2signed_measure(betti:coo_array|np.ndarray, grid_conversion:Iterable[np.ndarray]|None = None):
	if isinstance(betti, np.ndarray):   betti = coo_array(betti)
	points_filtration = np.empty(shape=(betti.getnnz(),2), dtype=int) # coo matrix is only for matrices -> 2d
	points_filtration[:,0] = betti.row
	points_filtration[:,1] = betti.col
	weights = np.array(betti.data, dtype=int)
	if grid_conversion is not None:
		coords = np.empty(shape=(betti.getnnz(),2), dtype=float)
		for i in range(2):
			coords[:,i] = grid_conversion[i][points_filtration[:,i]]
	else:
		coords = points_filtration
	return coords, weights


class SimplexTree2SignedMeasure(BaseEstimator,TransformerMixin):
	"""
	Input : Iterable[SimplexTreeMulti]

	Output: Iterable[ list[signed_measure for degree] ]

	signed measure is either (points : (n x num_parameters) array, weights : (n) int array ) if sparse, else an integer matrix.

	The simplextrees of the input may also be delayed calls, i.e., (function, args, kwargs) triples.
	If no filtration grid is given, it is inferred at fit time, from the filtration 
	bounds of a random sample of the simplextrees, and the resolution.
	"""
	def __init__(self, 
			degrees:list[int], # homological degrees
			filtration_grid:Iterable[np.ndarray]=None, # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i] 
			progress=False, # tqdm
			num_collapses="full", # edge collapses before computing 
			n_jobs=1, 
			resolution:Iterable[int]=None, # when filtration grid is not given, the resolution of the filtration grid to infer
			sparse=False, # sparse output
			plot:bool=False, 
			filtration_quantile:float=0, # quantile for inferring filtration grid
			_möbius_inversion:bool=True, # whether or not to do the möbius inversion (not recommended to touch)
			expand=True, # expand the simplextree before computing the homology
			normalize_filtrations:bool=False,
		):
		super().__init__()
		self.degrees = degrees
		self.filtration_grid = filtration_grid
		self.progress = progress
		self.num_collapses=num_collapses
		self.n_jobs = cpu_count() if n_jobs <= 0 else n_jobs
		self.resolution = resolution
		self.plot=plot
		self.sparse=sparse
		self.filtration_quantile=filtration_quantile
		self.normalize_filtrations = normalize_filtrations # Will only work for sparse output. (discrete matrices cannot be "rescaled")
		assert not self.normalize_filtrations or self.sparse, "Not able to normalize a matrix without losing information. Will not normalize."
		assert resolution is not None or filtration_grid is not None
		self._is_input_delayed = None
		self._möbius_inversion = _möbius_inversion
		self._reconversion_grid = None # grid used to convert the coordinates of the sparse outputs
		self.expand = expand
		self._refit_grid = filtration_grid is None # will only refit the grid if filtration_grid has never been given.
		return

	def _refit(self,X):
		"""
		Infers the filtration grid : a regular grid on the bounding box (enlarged by 10% of 
		its largest side, in each direction) of the filtration values of a random sample 
		(about 30%) of the simplextrees.
		"""
		print("Inferring filtration grid from simplextrees...", end="", flush=True)
		indices = np.random.choice(len(X), int(0.3 * len(X)) +1, replace=False)
		get_filtration_bounds = lambda x : self._to_simplex_tree(x).filtration_bounds(q=self.filtration_quantile, remove_inf=True)
		prefer = "processes" if self._is_input_delayed else "threads"
		filtration_bounds =  Parallel(n_jobs=self.n_jobs, prefer=prefer)(delayed(get_filtration_bounds)(x) for x in (X[idx] for idx in indices))
		box = [np.min(filtration_bounds, axis=(0,1)), np.max(filtration_bounds, axis=(0,1))]
		diameter = np.max(box[1] - box[0])
		box = np.array([box[0] - 0.1* diameter, box[1] + 0.1 * diameter])
		self.filtration_grid = [np.linspace(*np.asarray(box)[:,i], num=self.resolution[i]) for i in range(len(box[0]))]
		print("Done.")
		# self._reconversion_grid = [np.linspace(0,1, num=len(f)) for f in self.filtration_grid] if self.normalize_filtrations else self.filtration_grid
		
	def fit(self, X, y=None): # Todo : infer filtration grid ? quantiles ?
		"""
		Detects if the input is delayed, and infers the filtration grid if it was not given.
		"""
		self._is_input_delayed = not isinstance(X[0], mp.SimplexTreeMulti)
		if isinstance(self.resolution, int):
			self.resolution = [self.resolution]*self._to_simplex_tree(X[0]).num_parameters
		if self._refit_grid:
			self._refit(X)
		# If normalized, the output coordinates are rescaled to [0,1].
		self._reconversion_grid = [np.linspace(0,1, num=len(f)) for f in self.filtration_grid] if self.normalize_filtrations else self.filtration_grid
		return self
	def _to_simplex_tree(self,x):
		"""Evaluates x if the input is delayed, else returns it."""
		return get_simplex_tree_from_delayed(x) if self._is_input_delayed else  x
	def transform1(self, simplextree, filtration_grid=None, _reconversion_grid=None):
		"""
		Computes the signed measures, one per degree, of one simplextree.
		 - simplextree : the simplextree (or its delayed call). It is copied, and left untouched.
		 - filtration_grid : the grid on which to compute. Defaults to the fitted one.
		 - _reconversion_grid : the grid used to convert the coordinates of the sparse 
		   output. Defaults to the fitted one.
		"""
		if filtration_grid is None: filtration_grid = self.filtration_grid
		if _reconversion_grid is None: _reconversion_grid = self._reconversion_grid
		st = self._to_simplex_tree(simplextree)
		st = mp.SimplexTreeMulti(st, num_parameters = st.num_parameters) ## COPY
		st.grid_squeeze(filtration_grid = filtration_grid, coordinate_values = True)
		if st.num_parameters == 2:
			if self.num_collapses == "full":
				st.collapse_edges(full=True,max_dimension=1)
			elif isinstance(self.num_collapses, int):
				st.collapse_edges(num=self.num_collapses,max_dimension=1)
			else:
				raise Exception("Bad edge collapse type. either 'full' or an int.")
		signed_measures = []
		if self.expand :	st.expansion(np.max(self.degrees)+1)
		grid_shape = [len(f) for f in filtration_grid]
		for degree in self.degrees:
			hilbert = mp.hilbert(simplextree=st, degree=degree, grid_shape=grid_shape) # TODO : nd ?
			signed_measure = signed_betti(hilbert, threshold=True, sparse=False) if self._möbius_inversion else hilbert ## FOR BENCHMARK ONLY
			if self.sparse:
				signed_measure = tensor_möbius_inversion(tensor = signed_measure, 
				grid_conversion=_reconversion_grid, plot = self.plot, num_parameters=len(filtration_grid))
			signed_measures.append(signed_measure)
		return signed_measures
	def transform(self,X):
		"""
		Computes the signed measures of each simplextree of X.
		"""
		assert self.filtration_grid is not None
		prefer = "processes" if self._is_input_delayed else "threads" #simplextrees are not picklable
		return Parallel(n_jobs=self.n_jobs // 2 + 1, prefer=prefer)(
			delayed(self.transform1)(to_st) for to_st in tqdm(X, disable = not self.progress, desc="Computing Hilbert function")
		)
		# return [self.transform1(x) for x in tqdm(X, disable = not self.progress, desc="Computing Hilbert function")]

class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
	"""
	Input : (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
	Output : (data) x (axis) x (signed measure)

	 - filtration_grids : one filtration grid per axis. If neither filtration_grids nor 
	   filtration_grid is given, the grids are inferred at fit time, one per axis.
	 - kwargs : arguments of SimplexTree2SignedMeasure
	"""
	def __init__(self,filtration_grids=None,**kwargs):
		super().__init__(**kwargs)
		self._refit_grid = self.filtration_grid is None  and filtration_grids is None# will only refit the grid if filtration_grid has never been given.
		self.filtration_grids = filtration_grids
		self._num_st_per_data=None # size of the axis, set at fit time
		return
	def _refit(self,X):
		"""
		Infers the filtration grid of each axis : a regular grid on the bounding box (enlarged 
		by 10% of its largest side, in each direction) of the filtration values of the 
		simplextrees of this axis, for a random sample (about 30%) of the data.
		"""
		print("Inferring filtration grid from simplextrees...", end="", flush=True)
		indices = np.random.choice(len(X), int(0.3 * len(X)) +1, replace=False)
		get_filtration_bounds_of_st = lambda x : self._to_simplex_tree(x).filtration_bounds(q=self.filtration_quantile, remove_inf=True)

		prefer = "processes" if self._is_input_delayed else "threads"

		get_filtration_bounds_of_axis = lambda axis : Parallel(n_jobs=self.n_jobs, prefer=prefer)(delayed(get_filtration_bounds_of_st)(x[axis]) for x in (X[idx] for idx in indices))
		filtration_bounds_of_axes = Parallel(n_jobs=self.n_jobs, prefer=prefer)(
			delayed(get_filtration_bounds_of_axis)(axis) for axis in range(self._num_st_per_data)
		)

		self.filtration_grids = []
		for filtration_bounds in filtration_bounds_of_axes:
			box = [np.min(filtration_bounds, axis=(0,1)), np.max(filtration_bounds, axis=(0,1))]
			diameter = np.max(box[1] - box[0])
			box = np.array([box[0] - 0.1* diameter, box[1] + 0.1 * diameter])
			self.filtration_grids.append([np.linspace(*np.asarray(box)[:,i], num=self.resolution[i]) for i in range(len(box[0]))])
		print("Done.")
	def fit(self, X, y=None): # Todo : infer filtration grid ? quantiles ?
		"""
		Detects the size of the axis and if the input is delayed, 
		and infers the filtration grids if none was given.
		"""
		if len(X) == 0: return self
		self._num_st_per_data = len(X[0])
		if self.filtration_grids is None and self.filtration_grid is not None:
			self.filtration_grids = [self.filtration_grid]*self._num_st_per_data
		self._is_input_delayed = not isinstance(X[0][0], mp.SimplexTreeMulti)
		if isinstance(self.resolution, int):
			self.resolution = [self.resolution]*self._to_simplex_tree(X[0][0]).num_parameters
		if self._refit_grid:
			self._refit(X)
		if self.normalize_filtrations:
			# One normalized grid per axis, built from the grid of this axis : 
			# `filtration_grid` is None when the grids are inferred, or given per axis.
			self._reconversion_grid = [[np.linspace(0,1, num=len(f)) for f in grid] for grid in self.filtration_grids]
		else:
			self._reconversion_grid = self.filtration_grids
		return self
	def transformk(self, simplextrees):
		"""
		Computes the signed measures of the simplextrees of one data, 
		each one with the grids of its axis.
		"""
		return Parallel(n_jobs=self.n_jobs // 2 + 1, prefer="threads")(delayed(self.transform1)(st, filtration_grid, renormalization_grid) for st, filtration_grid, renormalization_grid in zip(simplextrees, self.filtration_grids, self._reconversion_grid))
	def transform(self, X,y=None):
		"""
		Computes the signed measures of each data of X.
		"""
		assert self.filtration_grids is not None
		prefer = "processes" if self._is_input_delayed else "threads" #simplextrees are not picklable
		return Parallel(n_jobs=self.n_jobs // 2 + 1, prefer=prefer)(
			delayed(self.transformk)(to_st) for to_st in tqdm(X, disable = not self.progress, desc="Computing Hilbert function")
		)
	

class SignedMeasureFormatter(BaseEstimator,TransformerMixin):
	"""
	Input :  (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
	Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
	
	The second is meant to use multiple choices for signed measure input. An example of usage : they come from a Rips + Density with different bandwidth. 
	It is controlled by the axis parameter.

	Output : Iterable[list[(reweighted)_sparse_signed_measure of degree]]

	 - filtrations_weights : the coordinate i of the signed measures is rescaled to 
	   [0, filtrations_weights[i]]. Defaults to 1 for each coordinate.
	 - num_parameters : forwarded to `tensor_möbius_inversion`
	 - unsparse : if True, the outputs are dense arrays, and the rescaled coordinates are ints
	 - axis : if given, the index of the signed measures to consider in each data
	 - n_jobs : number of jobs used by `transform`
	 - plot : stored, but not used by this class
	"""
	def __init__(self, 
			filtrations_weights:Iterable[float]=None,
			num_parameters:int|None=None,
			plot:bool=False,
			n_jobs:int=1, 
			unsparse:bool=False,
			axis:int=None,
		):
		super().__init__()
		self.filtrations_weights = filtrations_weights
		self.num_parameters = num_parameters
		self.unsparse = unsparse
		self.axis=axis
		self.plot=plot
		self.n_jobs = n_jobs
		self._grid =None # fitted conversion grid
		self._old_shape = None # shape of the input matrices
	def fit(self, X, y=None):
		"""
		Computes the conversion grid, from the shape of the first signed measure of X.
		Nothing is fitted on an empty input.
		"""
		## Gets a grid. This will be the max in each coord+1
		if len(X) == 0:
			return self
		if len(X[0]) == 0:
			return self
		if self.axis is not None and len(X[0][0][0]) == 0:
			return self
		# assume that every degree has a similarly shaped matrix
		first_data = X[0] if self.axis is None else X[0][self.axis]
		self._old_shape = first_data[0].shape
		if self.filtrations_weights is None:
			self.filtrations_weights = [1]*len(self._old_shape)
		assert len(self.filtrations_weights) == len(self._old_shape), "Number of parameter is not consistent."
		
		# If unsparse, the grid is made of ints : enforces weights to be greater than the old shape
		linspace_arguments = {"dtype": int} if self.unsparse else {}
		self._grid = [
			np.linspace(start = 0, stop=weight, num=size, **linspace_arguments) 
			for weight, size in zip(self.filtrations_weights, self._old_shape)
		]
		return self

	def transform(self,X):
		"""
		Reformats the signed measures of each data of X.
		"""
		def _signed_measures_of(x):
			return x if self.axis is None else x[self.axis]
		def _sparsify(signed_measure_matrix):
			return tensor_möbius_inversion(
				tensor=signed_measure_matrix,
				grid_conversion=self._grid,
				num_parameters=self.num_parameters
			)
		def todo_sparse(x):
			return [_sparsify(signed_measure_matrix) for signed_measure_matrix in _signed_measures_of(x)]
		def todo_unsparse(x):
			from torch import sparse_coo_tensor
			dense_measures = []
			for signed_measure_matrix in _signed_measures_of(x):
				indices, values = _sparsify(signed_measure_matrix)
				size = [np.max(f)+1 for f in self._grid]
				# print(indices, size)
				dense = sparse_coo_tensor(indices=indices.T, values=values, size=size).to_dense()
				dense_measures.append(np.asarray(dense))
			return dense_measures
		todo = todo_sparse
		if self.unsparse:
			todo = todo_unsparse
		
		return Parallel(n_jobs=self.n_jobs)(delayed(todo)(x) for x in X)











class SignedMeasure2Img(BaseEstimator,TransformerMixin):
	"""
	Turns signed measures into images, by smoothing them.

	Input format : each element of X is an iterable of signed measures (e.g. one per degree), a signed measure being either
	 - sparse : a tuple (pts, weights), which is convolved on `filtration_grid` by `_pts_convolution_sparse`,
	 - dense : an array, which is smoothed by `scipy.ndimage.gaussian_filter`.
	The images of the signed measures of one element are concatenated along their first axis.

	Parameters
	 - filtration_grid : the grid on which sparse signed measures are evaluated. Inferred from the training points at `fit` time if None.
	 - kernel : forwarded to `_pts_convolution_sparse` (sparse input only).
	 - bandwidth : bandwidth of the sparse convolution, or `sigma` of the gaussian filter.
	 - flatten : if True, every output image is flattened.
	 - n_jobs : number of joblib jobs used by `transform`.
	 - resolution : number of values per axis of the inferred grid. Only used if `filtration_grid` is None.
	 - grid_strategy : strategy given to `infer_grid_from_points` when the grid is inferred.
	 - sparse : whether the signed measures are sparse. Inferred from the training data at `fit` time if None.
	 - progress : if True, shows a progress bar in `transform`.
	 - kwargs : forwarded to the smoothing function.

	Note : `fit` stores the values it infers in `sparse` and `filtration_grid`.
	"""
	def __init__(
		self,
		filtration_grid:Iterable[np.ndarray]=None,
		kernel="gaussian",
		bandwidth=1.,
		flatten:bool=False,
		n_jobs:int=1,
		resolution:int=None,
		grid_strategy:str="regular",
		sparse:bool|None = None,
		progress:bool=False,
		**kwargs,
	):
		super().__init__()
		self.kernel=kernel
		self.bandwidth=bandwidth
		self.more_kde_kwargs=kwargs
		self.filtration_grid=filtration_grid
		self.flatten=flatten
		self.progress=progress
		self.n_jobs = n_jobs
		self.resolution = resolution
		self.grid_strategy = grid_strategy
		self.sparse=sparse # input is either sparse or not.
		self._is_input_sparse = None
		return

	def fit(self, X, y=None):
		"""
		Infers, when they were not provided, whether the input is sparse and the filtration grid.

		Returns `self`. Nothing is inferred from an empty `X`, or from an `X` whose first element is empty.
		Raises a ValueError if the input is sparse and neither `filtration_grid` nor `resolution` was provided.
		"""
		if len(X) == 0 or len(X[0]) == 0:
			return self

		## Infers if the input is sparse given X : sparse signed measures are tuples
		self._is_input_sparse = isinstance(X[0][0], tuple)
		if self.sparse is None:
			self.sparse = self._is_input_sparse

		## If not sparse : signed measures are matrices, which already live on a grid
		if not self.sparse:
			return self

		## If sparse : a grid has to be defined
		if self.filtration_grid is None:
			if self.resolution is None:
				raise ValueError("Cannot infer filtration grid. Provide either a filtration grid or a resolution.")
			# The grid is inferred from the points of every signed measure of the training set
			pts = np.concatenate([
				sm[0] for signed_measures in X for sm in signed_measures
			])
			self.filtration_grid = infer_grid_from_points(pts, strategy=self.grid_strategy, num=self.resolution)

		return self

	def _sparsify(self,sm):
		# `tensor` is the keyword used by the other calls to tensor_möbius_inversion
		return tensor_möbius_inversion(tensor=sm,grid_conversion=self.filtration_grid)

	def _sm2smi(self, signed_measures:Iterable[np.ndarray]):
		# Dense input : each signed measure is smoothed by a gaussian filter
		return np.concatenate([
				gaussian_filter(input=signed_measure, sigma=self.bandwidth, **self.more_kde_kwargs)
			for signed_measure in signed_measures], axis=0)

	def _sm2smi_sparse(self, signed_measures:Iterable[np.ndarray]):
		# Sparse input : each (pts, weights) pair is convolved on the filtration grid
		return np.concatenate([
				_pts_convolution_sparse(
					pts = signed_measure_pts, pts_weights = signed_measure_weights,
					filtration_grid = self.filtration_grid,
					kernel=self.kernel,
					bandwidth=self.bandwidth,
					**self.more_kde_kwargs
				)
			for signed_measure_pts, signed_measure_weights  in signed_measures], axis=0)

	def transform(self,X):
		"""
		Computes one image per element of `X`, flattened if `flatten` is set.

		Raises a RuntimeError if the sparsity of the input is unknown, i.e. `sparse` was not provided and `fit` did not infer it.
		"""
		if self.sparse is None:
			raise RuntimeError("Fit first")
		# Looked up on the instance, so that subclasses overriding these methods are honored
		todo = self._sm2smi_sparse if self.sparse else self._sm2smi
		out = Parallel(n_jobs=self.n_jobs)(
			delayed(todo)(signed_measures)
			for signed_measures in tqdm(X, desc="Computing images", disable = not self.progress)
		)
		if self.flatten:
			return [x.flatten() for x in out]
		return out




class SignedMeasure2SlicedWassersteinDistance(BaseEstimator,TransformerMixin):
	"""
	Transformer from signed measure to distance matrix.

	Format :
	 - a signed measure : tuple of arrays. (point positions) : npts x (num_parameters) and weights : npts
	 - each data is a list of signed measures (for e.g. multiple degrees)

	Parameters
	 - n_jobs : stored on the estimator; not used by the methods of this class.
	 - num_directions : given to `SlicedWassersteinDistance`.
	 - _sliced : if True, uses `SlicedWassersteinDistance`, otherwise `WassersteinDistance`.
	 - epsilon, ground_norm : given to `WassersteinDistance`.
	 - progress : if True, shows a progress bar in `transform`.

	out:
	 - list of distance matrices, one per degree
	"""
	def __init__(self, n_jobs:int=1, num_directions:int=10, _sliced:bool=True, epsilon=-1, ground_norm=1, progress = False):
		super().__init__()
		self.n_jobs=n_jobs
		self._SWD_list = None
		self._sliced=_sliced
		self.epsilon = epsilon
		self.ground_norm = ground_norm
		self.num_directions = num_directions
		self.progress = progress
		return

	def _signed_measures_of_degree(self, X, degree:int):
		# Signed measures that are not given as tuples go through tensor_möbius_inversion first
		signed_measures = [x[degree] for x in X]
		if self.sparse:
			return signed_measures
		return [tensor_möbius_inversion(tensor=sm) for sm in signed_measures]

	def fit(self, X, y=None):
		"""
		Fits one distance estimator per degree, the number of degrees being `len(X[0])`.

		Returns `self`; nothing is fitted if `X` is empty.
		"""
		if len(X) == 0:
			return self
		self.sparse = isinstance(X[0][0], tuple)
		num_degrees = len(X[0])
		# WARNING if _sliced is false, this distance is not CNSD
		self._SWD_list = [
			SlicedWassersteinDistance(num_directions=self.num_directions)
			if self._sliced else
			WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm)
			for _ in range(num_degrees)
		]
		for degree, swd in enumerate(self._SWD_list):
			swd.fit(self._signed_measures_of_degree(X, degree))
		return self

	def transform(self,X):
		"""
		Returns the list, over the degrees, of what the fitted distance estimator of this degree computes on `X`.
		"""
		assert self._SWD_list is not None, "Fit first"
		degrees = tqdm(
			enumerate(self._SWD_list),
			desc="Computing distance matrices",
			total=len(self._SWD_list),
			disable= not self.progress,
		)
		return [swd.transform(self._signed_measures_of_degree(X, degree)) for degree, swd in degrees]

	def predict(self, X):
		return self.transform(X)

def accuracy_to_csv(X,Y,cl, cln:str, k:float=10, dataset:str = "", filtration:str = "", shuffle=True,  verbose:bool=True, **kwargs):
	"""
	Scores the classifier `cl` on (X,Y), and appends the result as a new line of a csv file.

	 - X, Y : the data and its labels
	 - cl : the classifier, which has to provide `fit` and `score`
	 - cln : the name of the classifier, written in the "pipeline" column
	 - k : if k > 1, the number of folds of a stratified k-fold (cast to int); if 0 < k <= 1, the `test_size` of a single train / test split
	 - dataset, filtration : written in the columns of the same name. The file is `result_{dataset}.csv`, where "/" is replaced by "_" and ".off" is removed
	 - shuffle : given to the k-fold, or to the train / test split
	 - verbose : prints the accuracies
	 - kwargs : additional (column name, value) pairs. The ones named as a default column are ignored, with a warning

	The line holds the mean and the standard deviation of the accuracies, and is appended to the file if it already exists.
	"""
	import pandas as pd
	assert k > 0, "k is either the number of kfold > 1 or the test size > 0."

	def print_best_params():
		# Best effort : not every estimator has a `best_params_` attribute
		try:
			best_params = cl.best_params_
		except AttributeError:
			return
		print("Best classification parameters : ", best_params)

	if k>1:
		k = int(k)
		from sklearn.model_selection import StratifiedKFold as KFold
		kfold = KFold(k, shuffle=shuffle).split(X,Y)
		accuracies = np.zeros(k)
		for i,(train_idx, test_idx) in enumerate(tqdm(kfold, total=k, desc="Computing kfold")):
			xtrain = [X[j] for j in train_idx]
			ytrain = [Y[j] for j in train_idx]
			cl.fit(xtrain, ytrain)
			xtest = [X[j] for j in test_idx]
			ytest = [Y[j] for j in test_idx]
			accuracies[i] = cl.score(xtest, ytest)
			if verbose:
				print(f"step {i+1}, {dataset} : {accuracies[i]}", flush=True)
				print_best_params()
	else:
		from sklearn.model_selection import train_test_split
		print("Computing accuracy, with train test split", flush=True)
		xtrain, xtest, ytrain, ytest = train_test_split(X, Y, shuffle=shuffle, test_size=k)
		print("Fitting...", end="", flush=True)
		cl.fit(xtrain, ytrain)
		print("Computing score...", end="", flush=True)
		accuracies = cl.score(xtest, ytest)
		print_best_params()
		print("Done.")
		if verbose:	print(f"Accuracy {dataset} : {accuracies} ")

	file_path:str = f"result_{dataset}.csv".replace("/", "_").replace(".off", "")
	columns:list[str] = ["dataset", "filtration", "pipeline", "cv", "mean", "std"]
	if exists(file_path):
		df:pd.DataFrame = pd.read_csv(file_path)
	else:
		df:pd.DataFrame = pd.DataFrame(columns= columns)

	# Additional columns, which cannot override the default ones
	more_names = []
	more_values = []
	for key, value in kwargs.items():
		if key in columns:
			warn(f"Duplicate key {key} ! with values {cln} and {value}")
			continue
		more_names.append(key)
		more_values.append(value)

	new_line:pd.DataFrame = pd.DataFrame([[dataset, filtration, cln, k, np.mean(accuracies), np.std(accuracies)]+more_values], columns = columns+more_names)
	df = pd.concat([df, new_line])
	df.to_csv(file_path, index=False)
