import numpy               as np
import tensorflow          as tf
import gudhi               as gd

############################
# Vietoris-Rips filtration #
############################

# The optimized parameters are the point coordinates X, passed as the input of
# RipsModel2.call (the model itself holds no trainable weights).

def _longest_edge(DX, simplex):
    """Return ``[a, b]``, the two vertices of ``simplex`` at maximal distance in ``DX``.

    For a single-vertex simplex the pair is that vertex twice. The order of
    the pair does not matter when ``DX`` is symmetric.
    """
    verts = np.asarray(simplex)
    sub = DX[np.ix_(verts, verts)]
    a, b = np.unravel_index(np.argmax(sub), sub.shape)
    return [simplex[a], simplex[b]]


def Rips(DX, mel, dim, card):
    """Select the distance-matrix entries that realise a Rips persistence diagram.

    Builds a Vietoris-Rips filtration from ``DX`` with Gudhi and keeps the
    finite persistence pairs whose birth simplex has dimension ``dim``. For
    each pair it returns the two vertices achieving the maximal pairwise
    distance inside the birth simplex and inside the death simplex. In a Rips
    filtration a simplex appears at the length of its longest edge, so those
    vertex pairs index the entries of ``DX`` holding the birth/death values.

    Parameters
    ----------
    DX : array-like, (n, n)
        Symmetric distance matrix.
    mel : float
        Maximum edge length for the Rips filtration.
    dim : int
        Homological dimension of the pairs to keep.
    card : int
        Number of diagram points returned, ordered by decreasing persistence
        (death - birth, i.e. distance to the diagonal).

    Returns
    -------
    list of numpy.int32
        Flat list of length ``4 * card``; each point contributes
        ``[b0, b1, d0, d1]`` (birth vertex pair, then death vertex pair).
        When fewer than ``card`` pairs exist the list is padded with zeros.
    """
    # tf.numpy_function passes scalars as NumPy scalars / 0-d arrays; use
    # plain Python values for Gudhi and for the list arithmetic below.
    mel, dim, card = float(mel), int(dim), int(card)
    DX = np.asarray(DX)

    # Compute the persistence pairs with Gudhi.
    rc = gd.RipsComplex(distance_matrix=DX, max_edge_length=mel)
    st = rc.create_simplex_tree(max_dimension=dim + 1)
    # persistence() must run before persistence_pairs(); the diagram it
    # returns is not needed here.
    st.persistence()
    pairs = st.persistence_pairs()  # list of (birth simplex, death simplex)

    rows, pers = [], []
    for s1, s2 in pairs:
        # Keep only pairs born in dimension `dim` that die (non-empty death
        # simplex); essential classes are skipped.
        if len(s1) != dim + 1 or len(s2) == 0:
            continue
        rows.append(_longest_edge(DX, s1) + _longest_edge(DX, s2))
        pers.append(st.filtration(s2) - st.filtration(s1))  # death - birth

    print(f'number of {dim}-dim features:', len(pers))

    # Order the 4-index rows by decreasing persistence and flatten them.
    order = np.argsort(pers)[::-1]
    indices = np.asarray(rows, dtype=np.int64).reshape(-1, 4)[order].ravel().tolist()

    # Truncate to `card` points, or pad with zero indices when fewer exist.
    indices = indices[:4 * card] + [0] * max(0, 4 * card - len(indices))
    return list(np.array(indices, dtype=np.int32))
    
class RipsModel2(tf.keras.Model):
    """Keras model mapping a point cloud to its Rips persistence diagram.

    Diagram values are gathered from the Euclidean distance matrix of the
    input points, so gradients of a loss on the diagram flow back to the
    point coordinates. Only the choice of which entries to gather (made by
    ``Rips`` through Gudhi) is treated as constant.
    """

    def __init__(self, mel=12, dim=1, card=50):
        """
        Parameters
        ----------
        mel : float
            Maximum edge length of the Rips filtration.
        dim : int
            Homological dimension of the diagram.
        card : int
            Number of diagram points returned (largest persistence first,
            zero-padded when fewer exist).
        """
        super().__init__()
        self.mel = mel      # maximum edge length
        self.dim = dim
        self.card = card

    def call(self, X):
        """Return the persistence diagram of the point cloud ``X``.

        Parameters
        ----------
        X : tf.Tensor, (n_points, ambient_dim)
            Point coordinates.

        Returns
        -------
        tf.Tensor, (card, 2)
            ``[birth, death]`` rows, same dtype as the distances. For
            ``dim == 0`` all births are zero.
        """
        m, d, c = self.mel, self.dim, self.card

        # Squared pairwise distances: (n,1,D) - (1,n,D) broadcasts to (n,n,D),
        # then summed over the coordinate axis -> (n,n).
        sq = tf.reduce_sum((tf.expand_dims(X, 1) - tf.expand_dims(X, 0)) ** 2, 2)
        # Gradient-safe Euclidean distance. d/dx sqrt(x) is infinite at x == 0
        # (the diagonal, duplicate points), which can yield NaN gradients for X
        # whenever such an entry is gathered below (e.g. the zero-padding index
        # [0, 0]). The inner where keeps sqrt's input positive; forward values
        # are identical to tf.math.sqrt(sq).
        positive = sq > 0
        DX = tf.where(positive,
                      tf.math.sqrt(tf.where(positive, sq, tf.ones_like(sq))),
                      tf.zeros_like(sq))
        # Add a leading batch axis of size 1 for tf.map_fn; unlike a reshape
        # with DX.shape[...], this also works when n is only known at runtime.
        DXX = tf.expand_dims(DX, 0)  # (1, n, n)

        # Wrap the NumPy/Gudhi routine as a TF op returning 4*c int32 scalars.
        out_types = [tf.int32 for _ in range(4 * c)]

        def rips_tf(dist):
            return tf.numpy_function(Rips, [dist, m, d, c], out_types)

        # Vertex indices of the selected distance entries; no gradient flows
        # through this selection step.
        ids = tf.nest.map_structure(tf.stop_gradient,
                                    tf.map_fn(rips_tf, DXX, dtype=out_types))

        if d > 0:
            # Consecutive index pairs address one DX entry each; per point the
            # first entry is the birth value and the second the death value.
            dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(ids, [2 * c, 2])), [c, 2])
        else:
            # Dimension 0: births are the (zero) diagonal entries, so only the
            # death entries (every second index pair) are gathered.
            ids = tf.reshape(ids, [2 * c, 2])[1::2, :]
            deaths = tf.reshape(tf.gather_nd(DX, ids), [c, 1])
            # Match the distance dtype so float64 inputs can be concatenated.
            dgm = tf.concat([tf.zeros([c, 1], dtype=DX.dtype), deaths], axis=1)

        return dgm
 