

import gudhi as _gd
import matplotlib.pyplot as _plt
from matplotlib.cm import get_cmap as _get_cmap
import sys as _sys
from matplotlib.transforms import Bbox
import numpy as _np
from shapely.geometry import box as _rectangle_box
from shapely.geometry import MultiPolygon, Polygon
from shapely.ops import unary_union
from torch import threshold
import mma 

from matplotlib.patches import Rectangle as _Rectangle
def _rectangle(x,y,color, alpha):
	"""Build a matplotlib Rectangle patch for the box {z | x ≤ z ≤ y}.

	``x`` is the lower-left corner and ``y`` the upper-right corner. If ``y``
	lies below ``x`` along an axis, the extent along that axis is clamped to 0.
	``color`` and ``alpha`` are passed straight to the patch.
	"""
	width = max(y[0] - x[0], 0)
	height = max(y[1] - x[1], 0)
	return _Rectangle(x, width, height, color=color, alpha=alpha)

def _d_inf(a,b):
	"""Return ``min_i |b_i - a_i|``, the smallest absolute coordinate gap.

	Despite the name, this is a *minimum* over coordinates, not the sup-norm
	distance. In this file it is called on a (birth, death) pair to measure the
	shortest side of the rectangle they span. Non-array inputs (lists, tuples,
	scalars) are converted to numpy arrays first.
	"""
	delta = _np.asarray(b) - _np.asarray(a)
	return _np.abs(delta).min()

def plot_2d_summand(birth_list, death_list, box, min_interleaving = 0, save=False, dpi=200, xlabel=None, ylabel=None):
	"""Plot the support of a single 2-parameter summand from its corners.

	The support is drawn as the union of the rectangles ``[birth, death]`` over
	every pair with ``birth < death`` in both coordinates.

	Parameters
	----------
	birth_list, death_list : iterable of 2D points
		Birth (lower) and death (upper) corners of the summand.
	box : pair of 2D points
		``[[xmin, ymin], [xmax, ymax]]``, used as the axis limits.
	min_interleaving : float
		The summand is only filled if at least one of its rectangles satisfies
		``_d_inf(birth, death) > min_interleaving``. With the default 0, every
		non-empty summand is drawn.
	save : str or False
		If truthy, path passed to ``plt.savefig`` before the figure is shown.
	dpi : int
		Resolution used when saving.
	xlabel, ylabel : str or None
		Optional axis labels.
	"""
	trivial_summand = True
	list_of_rect = []
	for birth in birth_list:
		for death in death_list:
			if death[1] > birth[1] and death[0] > birth[0]:
				if trivial_summand and _d_inf(birth, death) > min_interleaving:
					trivial_summand = False
				list_of_rect.append(_rectangle_box(birth[0], birth[1], death[0], death[1]))
	_, ax = _plt.subplots()
	ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
	if not trivial_summand:
		summand_shape = unary_union(list_of_rect)
		# A connected union is a single Polygon; otherwise the result is a
		# multi-part geometry whose parts are listed in ``.geoms``.
		if isinstance(summand_shape, Polygon):
			polygons = [summand_shape]
		else:
			polygons = summand_shape.geoms
		for polygon in polygons:
			# Only the exterior ring is filled; interior holes (if any) are not cut out.
			xs, ys = polygon.exterior.xy
			ax.fill(xs, ys, ec='None')
	if xlabel is not None:
		ax.set_xlabel(xlabel)
	if ylabel is not None:
		ax.set_ylabel(ylabel)
	if save:
		_plt.savefig(save, dpi=dpi)
	_plt.show()

def approx(simplextree, filters, precision=0.01, box=None,*,verbose=False, threshold=True):
	"""Approximate a multiparameter persistence module with ``mma._approx``.

	Parameters
	----------
	simplextree : gudhi SimplexTree or boundary matrix
		A gudhi ``SimplexTree`` is converted with
		``mma.simplex_tree_to_boundary_matrix``. Anything else is forwarded to
		``mma._approx`` unchanged.
	filters :
		Filtration values, forwarded to ``mma._approx``.
	precision : float
		Forwarded to ``mma._approx``.
	box : mma.Box, [lower_corner, upper_corner], or None
		Region passed to ``mma._approx``. ``None`` (the default) means
		``mma.Box([0], [0])``, matching the previous default.
	verbose, threshold :
		Forwarded to ``mma._approx`` as keywords.

	Returns
	-------
	Whatever ``mma._approx`` returns.

	Raises
	------
	TypeError
		If ``box`` is neither ``None``, an ``mma.Box``, nor a list/tuple of two corners.
	"""
	if isinstance(simplextree, _gd.simplex_tree.SimplexTree):
		boundary = mma.simplex_tree_to_boundary_matrix(simplextree.thisptr)
	else:
		boundary = simplextree
	if box is None:
		# Built per call instead of as a default argument, so no Box instance is
		# created at import time or shared between calls.
		bbox = mma.Box([0], [0])
	elif isinstance(box, (list, tuple)):
		bbox = mma.Box(box[0], box[1])
	elif isinstance(box, mma.Box):
		bbox = box
	else:
		# Previously this printed a message and returned None, which only made
		# callers fail later with a less helpful error.
		raise TypeError(f"Cannot interpret box of type {type(box).__name__}; expected mma.Box or [lower_corner, upper_corner].")
	return mma._approx(boundary, filters, precision, bbox, verbose=verbose, threshold=threshold)


def vine_alt(B, filters, precision, box = None, dimension = -1, threshold=False, multithread = False, verbose = False):
	"""Compute vineyard barcodes of a 2-parameter filtration.

	Parameters
	----------
	B : gudhi SimplexTree or boundary matrix
		A ``SimplexTree`` is converted with ``simplextree_to_sparse_boundary``.
	filters : numpy array of shape (n, 2) or list/tuple of two sequences
		Filtration values. An array is split into its two columns.
	precision :
		Forwarded to the barcode computation.
	box : list or None
		``[[xmin, ymin], [xmax, ymax]]``. ``None`` or an empty list/tuple means
		"bounding box of ``filters``".
	dimension : int
		``-1`` (default) computes every dimension with ``compute_vineyard_barcode``.
		Any other value uses ``compute_vineyard_barcode_in_dimension``.
	threshold, verbose :
		Forwarded to the barcode computation.
	multithread : bool
		Only forwarded when ``dimension == -1``.
	"""
	if box is None:
		box = []
	# Only lists/tuples are tested for emptiness: ``ndarray == []`` does not
	# yield a plain bool.
	if isinstance(box, (list, tuple)) and len(box) == 0:
		if isinstance(filters, _np.ndarray):
			box = [[min(filters[:,0]), min(filters[:,1])], [max(filters[:,0]), max(filters[:,1])]]
		elif isinstance(filters, (list, tuple)):
			box = [[min(filters[0]), min(filters[1])], [max(filters[0]), max(filters[1])]]
	if isinstance(filters, _np.ndarray):
		assert filters.shape[1] == 2
		filtration = [filters[:,0], filters[:,1]]
	else:
		filtration = filters
	if isinstance(B, _gd.simplex_tree.SimplexTree):
		B = simplextree_to_sparse_boundary(B)
	# NOTE(review): compute_vineyard_barcode, compute_vineyard_barcode_in_dimension,
	# simplextree_to_sparse_boundary and Box are not defined or imported in this
	# file; presumably they must come from the mma backend — confirm.
	if dimension == -1:  # if dimension is not specified we return every dimension
		return compute_vineyard_barcode(B, filtration, precision, Box(box), threshold, multithread, verbose)
	return compute_vineyard_barcode_in_dimension(B, filtration, precision, Box(box), dimension, threshold, verbose)


def plot_vine_2d(matrix, filters, precision, box=[], dimension=0, return_barcodes=False, separated = False, multithread = True, save=False, dpi=50):
	"""Compute vineyard barcodes with ``vine_alt`` and draw each matching as line segments.

	Parameters
	----------
	matrix :
		Boundary matrix or gudhi ``SimplexTree``, forwarded to ``vine_alt``.
	filters : numpy array or list of two sequences
		Filtration values. When ``box`` is empty, they also define its bounds.
	precision :
		Forwarded to ``vine_alt``.
	box : list
		``[[xmin, ymin], [xmax, ymax]]`` used as axis limits. An empty list means
		"bounding box of filters".
	dimension : int
		Forwarded to ``vine_alt``.
	return_barcodes : bool
		If True, return the barcodes computed by ``vine_alt``; otherwise return None.
	separated : bool
		If True, ``plt.show()`` is called after every matching with at least one bar.
	multithread : bool
		NOTE(review): not forwarded; ``vine_alt`` is always called with
		``multithread=False``. Confirm whether this is intentional.
	save : str or False
		If truthy, path passed to ``plt.savefig`` before the final ``show``.
	dpi : int
		Resolution used when saving.
	"""
	# Coordinates equal to this value are skipped or left undrawn below;
	# presumably it encodes +infinity.
	float_max = _sys.float_info.max
	if box == []:
		if type(filters) is _np.ndarray:
			box = [[min(filters[:,0]), min(filters[:,1])], [max(filters[:,0]), max(filters[:,1])]]
		elif type(filters) is list:
			box = [[min(filters[0]), min(filters[1])], [max(filters[0]), max(filters[1])]]
	barcodes = vine_alt(matrix, filters, precision, box, dimension = dimension, threshold = True, multithread = False)
	cmap = _get_cmap("Spectral")
	n_matchings = len(barcodes)
	for idx, matching in enumerate(barcodes):
		has_bar = False
		for bar in matching:
			birth, death = bar[0], bar[1]
			# Empty, zero-length, or infinitely-born bars are ignored entirely.
			if birth == [] or death == [] or death == birth or birth[0] == float_max:
				continue
			has_bar = True
			# Bars dying at infinity count as present but are not drawn.
			if death[0] != float_max and death[1] != float_max:
				_plt.plot([birth[0], death[0]], [birth[1], death[1]], c=cmap(idx / n_matchings))
		if has_bar:
			_plt.xlim(box[0][0], box[1][0])
			_plt.ylim(box[0][1], box[1][1])
			if separated:
				_plt.show()
	if save:
		_plt.savefig(save, dpi=dpi)
	_plt.show()
	return barcodes if return_barcodes else None



def plot_approx_2d(B, filters, precision=0.1, box = None, dimension=-1, return_corners=False, separated=False, min_interleaving = 0, multithread = False, complete=True, alpha=1, verbose = False, save=False, dpi=200, shapely = True):
	"""Approximate a 2-parameter module with ``approx`` and plot it per dimension.

	Parameters
	----------
	B :
		gudhi ``SimplexTree`` or boundary matrix, forwarded to ``approx``.
	filters : numpy array or list/tuple of two sequences
		Filtration values. When ``box`` is not given, they also define its bounds.
	precision : float
		Forwarded to ``approx`` and ``plot_approx``.
	box : list, mma.Box or None
		``[[xmin, ymin], [xmax, ymax]]``. ``None`` or an empty list/tuple means
		"bounding box of filters".
	dimension : int
		Dimension to plot. A negative value plots every dimension from 0 to
		``module[-1].get_dimension()``.
	multithread, complete :
		Accepted for backward compatibility but currently unused. ``approx``
		does not take these keywords, and passing them raised ``TypeError``.
	alpha : float
		Fill transparency. ``alpha >= 1`` disables the shapely code path.
	verbose : bool
		Forwarded to ``approx``.
	return_corners, separated, min_interleaving, save, dpi, shapely :
		Forwarded to ``plot_approx``.
	"""
	if alpha >= 1:
		shapely = False  # With opaque fills, plain patches are used; unclear which path is faster here.
	if box is None:
		box = []
	# Only lists/tuples are tested for emptiness: ``ndarray == []`` or an
	# ``mma.Box`` comparison would not give a reliable bool.
	if isinstance(box, (list, tuple)) and len(box) == 0:
		if isinstance(filters, _np.ndarray):
			box = [[min(filters[:,0]), min(filters[:,1])], [max(filters[:,0]), max(filters[:,1])]]
		elif isinstance(filters, (list, tuple)):
			box = [[min(filters[0]), min(filters[1])], [max(filters[0]), max(filters[1])]]
	module = approx(B, filters, precision, box=box, threshold=True, verbose=verbose)
	plot_kwargs = dict(precision=precision, box=box, return_corners=return_corners, separated=separated, min_interleaving=min_interleaving, alpha=alpha, dpi=dpi, shapely=shapely, save=save)
	if dimension < 0:
		# NOTE(review): assumes the last summand of ``module`` carries the
		# highest dimension — confirm against mma's module API.
		dimensions = range(module[-1].get_dimension() + 1)
	else:
		dimensions = [dimension]
	for dim in dimensions:
		# NOTE(review): ``plot_approx`` is not defined or imported in this file;
		# presumably it is provided elsewhere — confirm.
		plot_approx(module.get_module_of_dimension(dim), **plot_kwargs)



def plot_module(corners, xlim, ylim, separated=False, min_interleaving = 0, alpha=1, verbose = False, save=False, dpi=200, shapely = True, xlabel=None, ylabel=None):
	"""Plot the supports of the summands of a 2-parameter module.

	Parameters
	----------
	corners : sequence
		One entry per summand. ``corners[i][0]`` holds the birth corners and
		``corners[i][1]`` the death corners of summand ``i``. Each summand is
		drawn as the union of the rectangles ``[birth, death]`` over the pairs
		with ``birth < death`` in both coordinates.
	xlim, ylim : pair of floats
		Axis limits of every figure created.
	separated : bool
		If True, each non-trivial summand is drawn in its own figure.
	min_interleaving : float
		A summand is skipped unless at least one of its rectangles satisfies
		``_d_inf(birth, death) > min_interleaving``.
	alpha : float
		Transparency of the drawn regions.
	verbose : bool
		Currently unused.
	save : str or False
		If truthy and ``separated`` is False, path passed to ``plt.savefig``
		before showing.
	dpi : int
		Resolution used when saving.
	shapely : bool
		If True, each summand's rectangles are merged with shapely and the
		resulting polygons are filled. Otherwise one matplotlib patch is added
		per rectangle.
	xlabel, ylabel : str or None
		Optional axis labels, applied to every figure created.

	Summand ``i`` is colored with ``Spectral(i / len(corners))``.
	"""
	cmap = _get_cmap("Spectral")
	n_summands = len(corners)

	def _new_axes():
		# Create a figure with the shared limits and (optional) labels.
		_, new_ax = _plt.subplots()
		new_ax.set(xlim=xlim, ylim=ylim)
		if xlabel is not None:
			new_ax.set_xlabel(xlabel)
		if ylabel is not None:
			new_ax.set_ylabel(ylabel)
		return new_ax

	if not separated:
		ax = _new_axes()
	for i in range(n_summands):
		color = cmap(i / n_summands)
		trivial_summand = True
		list_of_rect = []
		for birth in corners[i][0]:
			for death in corners[i][1]:
				if death[1] > birth[1] and death[0] > birth[0]:
					if trivial_summand and _d_inf(birth, death) > min_interleaving:
						trivial_summand = False
					if shapely:
						list_of_rect.append(_rectangle_box(birth[0], birth[1], death[0], death[1]))
					else:
						list_of_rect.append(_rectangle(birth, death, color, alpha))
		if trivial_summand:
			continue
		if separated:
			ax = _new_axes()
		if shapely:
			summand_shape = unary_union(list_of_rect)
			# A connected union is a single Polygon; otherwise the parts are in ``.geoms``.
			if isinstance(summand_shape, Polygon):
				polygons = [summand_shape]
			else:
				polygons = summand_shape.geoms
			for polygon in polygons:
				xs, ys = polygon.exterior.xy
				ax.fill(xs, ys, alpha=alpha, fc=color, ec='None')
		else:
			# list_of_rect holds matplotlib patches here, which shapely cannot
			# union, so they are added to the axes directly.
			for rectangle in list_of_rect:
				ax.add_patch(rectangle)
	if save and not separated:
		_plt.savefig(save, dpi=dpi)
	_plt.show()
	
