
from typing import Callable, Iterable,List, Optional
import multipers as mp
from multipers.ml.tools import filtration_grid_to_coordinates
import numpy as np
from joblib import Parallel,  delayed
from sklearn.base import BaseEstimator, TransformerMixin
from multipers.multiparameter_module_approximation import PyModule
from tqdm import tqdm

from multipers.simplex_tree_multi import SimplexTreeMulti
# Module-level shortcut to SimplexTreeMulti's private grid-reduction helper;
# it is called by MMAFormatter._infer_grid to subsample filtration values.
reduce_grid = mp.simplex_tree_multi.SimplexTreeMulti._reduce_grid



class SimplexTree2MMA(BaseEstimator, TransformerMixin):
	"""
	Turns a list of simplextrees to MMA approximations.

	The input is either a list of `SimplexTreeMulti` or a list of sequences of
	`SimplexTreeMulti` (one tree per "axis"). `fit` reads `filtration_bounds()`
	of every tree to build a common bounding box (one box per axis), and
	`transform` passes that box to `persistence_approximation`.

	Parameters
	----------
	n_jobs : number of joblib workers (threading backend) used by `transform`.
	expand_dim : if not None, `expansion(expand_dim)` is called on each tree
		before computing its approximation.
	prune_degrees_above : if not None, `prune_above_dimension` is called on
		each tree at the beginning of `transform`.
	progress : if True, display a tqdm progress bar in `transform`.
	**persistence_kwargs : forwarded verbatim to `persistence_approximation`.
	"""
	def __init__(self,n_jobs=-1, expand_dim:Optional[int]=None, prune_degrees_above:Optional[int]=None, progress=False, **persistence_kwargs) -> None:
		super().__init__()
		self.persistence_args = persistence_kwargs
		self.n_jobs=n_jobs
		self._has_axis=None
		self._num_axis=None
		self.prune_degrees_above=prune_degrees_above
		self.progress=progress
		self.expand_dim=expand_dim
		self._boxes=None

	def fit(self, X, y=None):
		"""
		Infer whether the input has an axis level and compute the bounding box(es).

		After fitting, `_boxes` is `[m, M]` (no axis) or a list of `[m, M]`
		pairs (one per axis), where `m`/`M` hold the per-parameter min/max of
		the filtration bounds over the whole dataset.

		Raises
		------
		IndexError : if the first element of `X` is an empty sequence.
		"""
		if len(X) == 0:
			return self
		self._has_axis = not isinstance(X[0], mp.SimplexTreeMulti)
		if self._has_axis:
			try:
				first_tree = X[0][0]
			except IndexError as err:
				# Carry the diagnostic in the exception itself instead of
				# printing it and raising a message-less IndexError.
				message = f"IndexError, {X[0]=}"
				if len(X[0]) == 0:
					message += "\nNo simplextree found, maybe you forgot to give a filtration parameter to the previous pipeline"
				raise IndexError(message) from err
			assert isinstance(first_tree, mp.SimplexTreeMulti), f"X[0] is not a simplextre, {X[0]=}, and X[0][0] neither."
			self._num_axis = len(X[0])
			## Shape : axis, data, min/max, num_parameters
			filtration_values = np.asarray([[x[axis].filtration_bounds() for x in X] for axis in range(self._num_axis)])
			## Reduce over the data dimension; shape of m/M : axis, num_parameters
			m = filtration_values[:,:,0,:].min(axis=1)
			M = filtration_values[:,:,1,:].max(axis=1)
			self._boxes = [[m_of_axis,M_of_axis] for m_of_axis,M_of_axis in zip(m,M)]
		else:
			## Shape : data, min/max, num_parameters
			filtration_values = np.asarray([x.filtration_bounds() for x in X])
			m = filtration_values[:,0,:].min(axis=0)
			M = filtration_values[:,1,:].max(axis=0)
			self._boxes = [m,M]
		return self

	def transform(self,X):
		"""
		Compute the module approximation of every simplextree of `X`.

		Returns a list with one entry per element of `X`: a module (no axis)
		or a list of modules (one per axis). The pruning and expansion calls
		are made directly on the trees given in `X` (presumably modifying them
		in place -- TODO confirm against SimplexTreeMulti).
		"""
		if self.prune_degrees_above is not None:
			for x in X:
				trees = x if self._has_axis else (x,)
				for tree in trees:
					tree.prune_above_dimension(self.prune_degrees_above) # we only do for H0 for computational ease

		def todo1(x:mp.SimplexTreeMulti,box):
			# Approximates a single tree inside the box computed by `fit`.
			if self.expand_dim is not None:
				x.expansion(self.expand_dim)
			return x.persistence_approximation(box=box,verbose=False,**self.persistence_args)

		def todo(sts:List[SimplexTreeMulti]|SimplexTreeMulti):
			# With an axis level, each tree is paired with the box of its axis.
			if self._has_axis:
				assert not isinstance(sts,SimplexTreeMulti)
				return [todo1(st,box) for st,box in zip(sts,self._boxes)]
			assert isinstance(sts,SimplexTreeMulti)
			return todo1(sts,self._boxes)
		return Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(todo)(x) for x in tqdm(X, desc="Computing modules", disable = not self.progress))


class MMAFormatter(BaseEstimator, TransformerMixin):
	"""
	Formats (and optionally normalizes) a list of MMA modules.

	The input is a list of `PyModule`, a list of sequences of `PyModule`
	(one per "axis"), or a list of pickled (`bytes`) versions of those.

	Parameters
	----------
	degrees : degrees that are kept / rescaled.
	axis : axis to select. `None` means "no axis"; if the data turns out to
		have an axis level, `fit` replaces `None` by `-1` (all axes).
	verbose : if True, print the inferred statistics.
	normalize : if True, infer the filtration bounds from the data; `transform`
		then translates and rescales each degree of each module accordingly.
	weights : optional per-parameter weights; they multiply the normalization
		factors and are used as the upper corner of the new box.
	quantiles : optional pair `(qm, qM)`; when given, bounds are the `qm` and
		`1-qM` quantiles of the per-module bounds instead of their min / max.
	dump : if True, pickle each output (only on the non-normalizing path of
		`transform`).
	from_dump : stored but not read in this class; pickled inputs are detected
		by checking whether the first element is `bytes`.
	"""

	def __init__(self, degrees:list=[0,1], axis=None, verbose:bool=False, normalize:bool=False,weights=None, quantiles=None, dump=False,from_dump=False):
		self._module_bounds=None
		self.verbose=verbose
		self.axis=axis
		self._axis=[]
		self._has_axis=None
		self._num_axis=0
		self.degrees=degrees
		self.normalize = normalize
		self._num_parameters = None
		self.weights = weights
		self.quantiles=quantiles
		self.dump=dump
		self.from_dump=from_dump

	@staticmethod
	def _maybe_from_dump(X_in):
		"""
		Unpickle `X_in` if its first element is `bytes`, else return it unchanged.
		"""
		if len(X_in) == 0:
			return X_in
		import pickle
		if isinstance(X_in[0], bytes):
			# SECURITY: pickle.loads can execute arbitrary code; only feed
			# dumps that were produced by a trusted pipeline.
			X = [pickle.loads(mods) for mods in X_in]
		else:
			X = X_in
		return X

	@staticmethod
	def _get_module_bound(x,degree):
		"""
		Lower / upper filtration bounds of the degree-`degree` part of module `x`.

		For each parameter, takes the first and last entry of the values given
		by `get_filtration_values(unique=True)` (presumably sorted -- TODO
		confirm). If this does not yield exactly a lower and an upper row, a
		message is printed and NaN bounds are returned.

		Output format : (2,num_parameters)
		"""
		filtration_values = x.get_module_of_degree(degree).get_filtration_values(unique=True)
		out = np.array([[f[0],f[-1]] for f in filtration_values if len(f)>0 ]).T
		if len(out) != 2:
			print(f"Missing degree {degree} here !")
			m = M = [np.nan for _ in range(x.num_parameters)]
		else:
			m,M = out
		return m,M

	@staticmethod
	def _infer_axis(X):
		"""
		Return True if the elements of `X` are not `PyModule` themselves, in
		which case their first entry is asserted to be a `PyModule`.
		"""
		has_axis = not isinstance(X[0], PyModule)
		assert not has_axis or isinstance(X[0][0], PyModule)
		return has_axis

	@staticmethod
	def _infer_num_parameters(X,ax=slice(None)):
		"""
		Read `num_parameters` from the first element of `X`, indexed by `ax`.
		"""
		return X[0][ax].num_parameters

	@staticmethod
	def _infer_bounds(X, degrees=None, axis=[slice(None)], quantiles=None):
		"""
		Compute bounds of filtration values of a list of modules.

		Per-module bounds come from `_get_module_bound`; they are aggregated
		over the dataset with `nanmin` / `nanmax`, or with `nanquantile` at
		`qm` / `1-qM` when `quantiles=(qm,qM)` is given.

		Output Format
		-------------
		m,M, each of shape : (num_axis,num_degrees,num_parameters)
		"""
		if degrees is None:
			degrees = np.arange(X[0][axis[0]].max_degree+1)
		## Shape : data, axis, degree, min/max, num_parameters
		bounds = np.array([[[MMAFormatter._get_module_bound(x[ax],degree) for degree in degrees] for ax in axis] for x in X])
		_, num_axis,num_degrees,_,num_parameters = bounds.shape

		def aggregate(reducer, side):
			# Applies `reducer` over the dataset, independently for each
			# (axis, degree, parameter); `side` is 0 for lower, 1 for upper.
			return np.asarray([[[reducer(bounds[:,ax,degree,side,parameter]) for parameter in range(num_parameters)] for degree in range(num_degrees)] for ax in range(num_axis)])

		if quantiles is not None:
			qm,qM = quantiles
			m = aggregate(lambda values : np.nanquantile(values, axis=0, q=qm), 0)
			M = aggregate(lambda values : np.nanquantile(values, axis=0, q=1-qM), 1)
		else:
			m = aggregate(lambda values : np.nanmin(values, axis=0), 0)
			M = aggregate(lambda values : np.nanmax(values, axis=0), 1)
		return (m,M)

	@staticmethod
	def _infer_grid(X:List[PyModule], strategy:str,resolution:int, degrees=None):
		"""
		Given a list of PyModules, computes a multiparameter discrete grid,
		with a given strategy,
		from the filtration values of the summands of the modules.

		If `strategy` contains "_mean", each module's values are reduced with
		the sub-strategy (the part before the first "_") and the reduced grids
		are averaged; otherwise the values of all modules are pooled per
		parameter and reduced once.

		Returns the pair `(coordinates, new_resolution)` produced by
		`filtration_grid_to_coordinates`.
		"""
		num_parameters = X[0].num_parameters
		if degrees is None:
			## Format here : ((filtration values of parameter) for parameter)
			filtration_values = tuple(mod.get_filtration_values(unique=True) for mod in X)
		else:
			filtration_values = tuple(mod.get_module_of_degrees(degrees).get_filtration_values(unique=True) for mod in X)

		if "_mean" in strategy:
			substrategy = strategy.split("_")[0]
			processed_filtration_values = [reduce_grid(f, resolution, substrategy, unique=False) for f in filtration_values]
			reduced_grid = np.mean(processed_filtration_values, axis=0)
		else:
			filtration_values = [np.unique(np.concatenate([f[parameter] for f in filtration_values], axis=0)) for parameter in range(num_parameters)]
			reduced_grid = reduce_grid(filtration_values, resolution, strategy,unique=True)

		coordinates, new_resolution = filtration_grid_to_coordinates(reduced_grid, return_resolution=True)
		return coordinates,new_resolution

	def fit(self, X_in, y=None):
		"""
		Infer the axis layout, number of parameters, bounds and normalization factors.

		Note that `self.axis` is overwritten with `-1` when it is `None` and
		the data has an axis level.

		Raises
		------
		Exception : if an axis was requested but the data has no axis level.
		"""
		X = self._maybe_from_dump(X_in)
		if len(X) == 0:
			return self
		self._has_axis = self._infer_axis(X)
		if self.axis is None and self._has_axis:
			self.axis = -1
		if self.axis is not None and not (self._has_axis):
			raise Exception(f"SMF didn't find an axis, but requested axis {self.axis}")
		if self._has_axis:
			self._num_axis = len(X[0])
		self._axis = [slice(None)] if self.axis is None else range(self._num_axis) if self.axis == -1 else [self.axis]
		# Inferred before the verbose report below, so that the report shows
		# the actual number of parameters instead of None.
		self._num_parameters = self._infer_num_parameters(X, ax=self._axis[0])
		if self.verbose:
			print('-----------MMAFormatter-----------')
			print('---- Infered stats')
			print(f'Found axis : {self._has_axis}, num : {self._num_axis}')
			print(f'Number of parameters : {self._num_parameters}')

		if self.normalize:
			self._module_bounds = self._infer_bounds(X,self.degrees, self._axis, self.quantiles)
		else:
			# Trivial bounds (m=0, M=1), i.e. normalization factors equal to
			# the weights.
			m = np.zeros((self._num_axis,len(self.degrees),self._num_parameters))
			M = m+1
			self._module_bounds = (m,M)
		assert self._num_parameters == self._module_bounds[0].shape[-1]
		if self.verbose:
			print('---- Bounds (only computed if normalize):')
			if self._has_axis and self._num_axis>1:
				print('(axis) x (degree) x (parameter)')
			else:
				print('(degree) x (parameter)')
			m,M = self._module_bounds
			print('-- Lower bound : ', m.shape)
			print(m)
			print('-- Upper bound :', M.shape)
			print(M)
		w = 1 if self.weights is None else np.asarray(self.weights)
		m,M = self._module_bounds
		normalizer = M-m
		zero_normalizer = normalizer==0
		if np.any(zero_normalizer):
			from warnings import warn
			warn(f"Encountered empty bounds. Please fix me. \n M-m = {normalizer}")
		# Avoid dividing by zero on degenerate bounds.
		normalizer[zero_normalizer] = 1
		self._normalization_factors = w/normalizer
		if self.verbose:
			print('-- Normalization factors:', self._normalization_factors.shape)
			print(self._normalization_factors)

		if self.verbose:
			print('---- Module size :')
			for ax in self._axis:
				print(f'- Axis {ax}')
				for degree in self.degrees:
					sizes = [len(x[ax].get_module_of_degree(degree)) for x in X]
					print(f' - Degree {degree} size {np.mean(sizes).round(decimals=2)}±{np.std(sizes).round(decimals=2)}')
			print('----------------------------------')
		return self

	@staticmethod
	def copy_transform(mod, degrees, translation, rescale_factors, new_box):
		"""
		Return the restriction of `mod` to `degrees`, where the j-th degree is
		translated by `translation[j]` then rescaled by `rescale_factors[j]`,
		and whose box is set to `new_box`.
		"""
		copy = mod.get_module_of_degrees(degrees) # and only returns the specific degrees
		for j,degree in enumerate(degrees):
			copy.translate(translation[j], degree=degree)
			copy.rescale(rescale_factors[j], degree=degree)
		copy.set_box(new_box)
		return copy

	def transform(self, X_in):
		"""
		Normalize the modules if any normalization factor differs from 1;
		otherwise select the requested axis and optionally pickle the result.

		On the normalizing path the output is a list (per data point) of lists
		(per selected axis) of modules.
		"""
		X = self._maybe_from_dump(X_in)
		if np.any(self._normalization_factors != 1):
			if self.verbose: print("Normalizing...", end="")
			w = [1]*self._num_parameters if self.weights is None else np.asarray(self.weights)
			standard_box = mp.multiparameter_module_approximation.PyBox([0]*self._num_parameters, w)

			X_copy = [[self.copy_transform(
						mod=x[ax],
						degrees=self.degrees,
						translation=-self._module_bounds[0][i],
						rescale_factors = self._normalization_factors[i],
						new_box=standard_box)
				for i,ax in enumerate(self._axis)]
			for x in X]
			if self.verbose:
				print("Done.")
			# NOTE(review): `self.dump` is not applied on this path -- confirm
			# whether this is intended.
			return X_copy
		# NOTE(review): when the data has no axis, `self.axis` is None here and
		# `x[None]` is evaluated -- confirm that PyModule indexing supports it.
		if self.axis != -1:
			X = [x[self.axis] for x in X]
		if self.dump:
			import pickle
			X = [pickle.dumps(mods) for mods in X]
		return X

class MMA2IMG(BaseEstimator, TransformerMixin):
	"""
	Turns a list of MMA modules into images.

	`fit` infers, with `MMAFormatter._infer_grid`, the grid coordinates (one
	grid per axis when the data has an axis level) on which `transform` calls
	`_compute_pixels` for every module.

	Parameters
	----------
	degrees : degrees to compute; forwarded as `degrees`.
	bandwidth : forwarded to `_compute_pixels` as `delta`.
	power : forwarded to `_compute_pixels` as `p`.
	normalize : forwarded to `_compute_pixels` as `normalize`.
	resolution : resolution given to the grid inference.
	plot : stored; not read by `fit` nor `transform`.
	box : forwarded to `_compute_pixels` as `box`.
	n_jobs : used both for the joblib pool over the data and forwarded to
		`_compute_pixels`.
	flatten : if True, each output is a single flat array.
	progress : if True, display a tqdm progress bar.
	grid_strategy : strategy given to the grid inference.
	"""
	def __init__(self,
			degrees:list,
			bandwidth:float=0.1,
			power:float=1,
			normalize:bool=False,
			resolution:list|int=50,
			plot:bool=False,
			box = None,
			n_jobs=1,
			flatten=False,
			progress=False,
			grid_strategy="regular",
		):
		# Public parameters
		self.degrees = degrees
		self.bandwidth=bandwidth
		self.power = power
		self.normalize = normalize
		self.resolution=resolution
		self.plot = plot
		self.box=box
		self.n_jobs=n_jobs
		self.flatten=flatten
		self.progress=progress
		self.grid_strategy=grid_strategy
		# State inferred by `fit`
		self._box=None
		self._has_axis=None
		self._num_parameters=None
		self._num_axis=None
		self._coords_to_compute=None
		self._new_resolutions=None

	def fit(self, X, y=None):
		"""
		Detect the axis layout of `X` and infer the evaluation grid(s).
		"""
		# TODO infer box
		# TODO rescale module
		self._has_axis = MMAFormatter._infer_axis(X)
		if self._has_axis:
			self._num_axis = len(X[0])
		# NOTE(review): `_box` is stored here but `transform` forwards
		# `self.box`, not `self._box`.
		self._box = [[0],[1,1]] if self.box is None else self.box
		if not self._has_axis:
			grid, grid_resolution = MMAFormatter._infer_grid(X, self.grid_strategy,self.resolution, degrees=self.degrees)
			self._coords_to_compute = grid
			self._new_resolutions = grid_resolution
			return self
		# One grid per axis, inferred from the modules of that axis only.
		per_axis = tuple(
			MMAFormatter._infer_grid(tuple(x[axis] for x in X), self.grid_strategy,self.resolution, degrees=self.degrees)
			for axis in range(self._num_axis)
		)
		self._coords_to_compute = [grid for grid,_ in per_axis] ## not the same resolutions, so cannot be put in an array
		self._new_resolutions = np.asarray([grid_resolution for _, grid_resolution in per_axis])
		return self

	def transform(self, X):
		"""
		Evaluate every element of `X` on the grid(s) inferred by `fit`.

		Returns one entry per element of `X`: a flat array if `flatten` is
		set, otherwise a list of arrays reshaped to
		`(len(degrees), *resolution)`.
		"""
		img_args = {
			"delta":self.bandwidth,
			"p":self.power,
			"normalize" : self.normalize,
			"box" : self.box,
			"degrees" : self.degrees,
			"n_jobs":self.n_jobs, # num_jobs is better for parallel over modules.
		}
		if self._has_axis:
			def pixels(mods):
				# One image per axis, each on the grid of its axis.
				return [mod._compute_pixels(grid, **img_args) for mod,grid in zip(mods, self._coords_to_compute)]
		else:
			def pixels(mod):
				# Leading dimension added so the shape is the same as has_axis.
				return mod._compute_pixels(self._coords_to_compute, **img_args)[None,:]

		if self.flatten:
			def todo(mods):
				return np.concatenate(pixels(mods),axis=1).flatten()
		else:
			def todo(mods):
				# NOTE(review): without an axis level, `_new_resolutions` is the
				# single resolution returned by `_infer_grid`; confirm that
				# zipping over it is intended.
				num_degrees = len(img_args["degrees"])
				return [img.reshape(num_degrees,*r) for img,r in zip(pixels(mods), self._new_resolutions)]

		progress_bar = tqdm(X, desc="Computing images", disable = not self.progress)
		return Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(todo)(x) for x in progress_bar) ## res depends on ax (infer_grid)






class MMA2Landscape(BaseEstimator, TransformerMixin):
	"""
	Turns a list of MMA approximations into Landscapes vectorisations

	Parameters
	----------
	resolution : forwarded to `landscapes` as `resolution`.
	degrees : degrees for which landscapes are computed and concatenated.
	ks : forwarded to `landscapes` as `ks`.
	phi : reduction applied to the output of `landscapes`; it is called with
		`axis=0`, so it has to accept that keyword.
	box : optional box; if None, `fit` infers one from the data.
	plot : forwarded to `landscapes` as `plot`.
	n_jobs : number of joblib workers (threading backend).
	filtration_quantile : quantile `q` used by `fit` to infer the box (`q` for
		the lower corner, `1-q` for the upper corner).
	"""
	def __init__(self, resolution=[100,100], degrees:list[int]|None = [0,1], ks:Iterable[int]=range(5), phi:Callable = np.sum, box=None, plot:bool=False, n_jobs=-1, filtration_quantile:float=0.01) -> None:
		super().__init__()
		self.resolution:list[int]=resolution
		self.degrees = degrees
		self.ks=ks
		self.phi=phi # Has to have a axis=0 !
		self.box = box
		self.plot = plot
		self.n_jobs=n_jobs
		self.filtration_quantile = filtration_quantile

	def fit(self, X, y=None):
		"""
		Check that the modules have 2 parameters and, if no box was given,
		infer one from quantiles of the modules' `get_bottom()` / `get_top()`.

		Always returns `self`, also on empty input, so that
		`fit(X).transform(X)` chaining keeps working.
		"""
		if len(X) <= 0:
			return self
		assert X[0].num_parameters == 2, f"Number of parameters {X[0].num_parameters} has to be 2."
		if self.box is None:
			def _bottom(mod):
				return mod.get_bottom()
			def _top(mod):
				return mod.get_top()
			bottoms = Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(_bottom)(mod) for mod in X)
			tops = Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(_top)(mod) for mod in X)
			m = np.quantile(bottoms, q=self.filtration_quantile, axis=0)
			M = np.quantile(tops, q=1-self.filtration_quantile, axis=0)
			# NOTE(review): this box is stored but `transform` does not
			# forward it to `landscapes` -- confirm whether it should.
			self.box=[m,M]
		return self

	def transform(self,X)->list[np.ndarray]:
		"""
		Compute, for each module, the flattened concatenation over `degrees`
		of `phi(landscapes(...), axis=0)`.

		Returns an empty list on empty input.
		"""
		if len(X) <= 0:
			return []
		def todo(mod):
			return np.concatenate([
				self.phi(mod.landscapes(ks=self.ks, resolution = self.resolution, degree=degree, plot=self.plot), axis=0).flatten()
				for degree in self.degrees
			]).flatten()
		return Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(todo)(x) for x in X)
