import random
import numpy as np
import tensorflow as tf
import gudhi as gd
from sklearn.metrics import pairwise_distances

try:
    import sys
    from utils.pto import path_to_oineus
    sys.path.append(path_to_oineus)
    import oineus as oin
except:
    print("Oineus not found.")


def interp(K, A):
    """
    Solve the linear system K @ alpha = A for alpha (kernel interpolation coefficients).

    np.linalg.solve is used instead of forming inv(K) explicitly: the result is the same in exact
    arithmetic, but solving is cheaper and numerically more stable for ill-conditioned kernel matrices.

    :param K: (m x m) square, non-singular matrix.
    :param A: (m,) or (m x p) right-hand side.
    :returns: alpha, with the same shape as A.
    :raises numpy.linalg.LinAlgError: if K is singular.
    """
    return np.linalg.solve(K, A)


def get_warps(K):
    """
    Return the eigenvectors of the inverse of K.

    :param K: square, invertible matrix.
    :returns: matrix whose columns are the (right) eigenvectors of inv(K), as returned by np.linalg.eig.
    """
    K_inv = np.linalg.inv(K)
    eigen_decomposition = np.linalg.eig(K_inv)
    eigenvectors = eigen_decomposition[1]
    return eigenvectors


def extend(M, k=2):
    """
    Lift an (n x m) matrix to an (n*k x m*k) block matrix whose (i, j) block is M[i, j] * I_k.

    This is exactly the Kronecker product M ⊗ I_k. It lets a scalar kernel act independently on each
    of the k coordinates of vector-valued data that has been flattened row by row.
    Computed with np.kron (vectorized) instead of assembling n*m blocks in a Python loop.

    :param M: (n x m) array-like.
    :param k: int, size of the identity block (typically the ambient dimension of the points).
    :returns: (n*k x m*k) numpy array.
    """
    return np.kron(np.asarray(M), np.eye(k))


def get_deformation(X, gradients, idxs, sigma, normalized=False):
    """
    From a point cloud X and a set of gradients, return a global movement MV obtained by
    (roughly) convolving the gradients with a Gaussian kernel of bandwidth sigma.

    The gradients at the q reference points X[idxs] are interpolated with a Gaussian kernel.
    The kernel acts independently on each coordinate, so it is lifted with `extend`. The interpolant
    is then evaluated at every point of X. When `normalized` is False, MV[idxs] reproduces
    gradients[idxs] (up to numerical error).

    :param X: a (n x d) point cloud: a numpy array or a TensorFlow tensor/variable (anything np.asarray
              accepts; typically d=2 in our experiments).
    :param gradients: a (n x d) numpy array representing a gradient (one arrow on each point in X, possibly 0).
    :param idxs: a list of q integers, indicating the indices in X where the gradient is non-zero.
    :param sigma: float, non-zero (typically > 0), the bandwidth of the kernel (how much we "take the
                  neighborhood with us").
    :param normalized: should the Gaussian kernel used for evaluation be normalized?

    :returns: MV, a (n x d) array representing the movement to apply to each point in X.
    """
    # Convert the point cloud once; np.asarray handles numpy arrays and eager TF tensors/variables alike.
    X_np = np.asarray(X)
    d = X_np.shape[1]
    # The points of X with non-zero gradient: shape (q x d).
    PM = X_np[idxs]
    # Their gradient vectors: shape (q x d). Flattened row-major below, matching the block layout of `extend`.
    A = gradients[idxs]
    # Gaussian kernel among the q reference points, lifted to act coordinate-wise: shape (q*d x q*d).
    K = extend(np.exp(-pairwise_distances(PM) ** 2 / (2 * sigma ** 2)), k=d)
    # Interpolation coefficients alpha = K^{-1} vec(A): shape (q*d x 1).
    alpha = interp(K, A.reshape([-1, 1]))
    # Gaussian kernel between all n points and the q reference points, lifted: shape (n*d x q*d).
    KG = extend(np.exp(-pairwise_distances(X_np, PM) ** 2 / (2 * sigma ** 2)), k=d)
    if normalized:
        KG = KG / (2 * np.pi * sigma ** 2) ** (d / 2)
    # Evaluate the interpolant at every point, which gives the global movement: shape (n x d).
    MV = KG.dot(alpha).reshape([-1, d])

    return MV


def auto_sigma(X, min_sigma=None):
    """
    Heuristic for automatic computation of the kernel bandwidth: the median of all pairwise distances in X.

    The median is taken over the full (n x n) distance matrix, diagonal zeros included.
    No lower bound is applied unless `min_sigma` is given. Without one, sigma is 0 for a single point or
    for fully duplicated points, which makes `get_deformation` divide by zero.
    Choosing a lower bound is delicate: what matters is the ratio |x-y|^2 / sigma^2, so a fixed floor
    such as 0.1 may be too small if |x-y|^2 is large.

    :param X: (n x d) numpy array
    :param min_sigma: optional float; if given, the returned bandwidth is at least `min_sigma`.
                      Defaults to None (no lower bound).
    :returns: float, a bandwidth.
    """
    sigma = np.median(pairwise_distances(X))
    if min_sigma is not None:
        sigma = max(sigma, min_sigma)
    return sigma


def _build_DeformRips(homology_dimension: int,
                      max_edge_length: float,
                      use_deformations: bool,
                      sigma: float,
                      input_dimension: int,
                      subsample_size: int,
                      use_oineus: bool,
                      n_preserved: int):
    """
    This private function is used as a way to define a RipsLayer (in Gudhi and Oineus's styles) that is
    both compatible with automatic differentiation and custom gradients.
    The parameters are captured in a closure so that tf.custom_gradient only sees the point cloud as a
    (differentiable) input; all other parameters are treated as non-differentiable.

    :param homology_dimension: int, the homology dimension for the Rips.
    :param max_edge_length: float, the max edge to be used in the Rips filtration.
    :param use_deformations: bool, should we use std gradient (False) or deformation gradient (True)?
    :param sigma: float, the bandwidth of the Gaussian kernel (only used if `use_deformations=True`).
    :param input_dimension: the dimension of the input point cloud.
    :param subsample_size: int or None. If not None, the diagram is computed on `subsample_size` points drawn
                           uniformly *with replacement* from the input (duplicates are possible); gradients are
                           routed back to the corresponding input points.
    :param use_oineus: bool, should we use Oineus for making big steps (True) or not (False)?
                        If True, the output is not a diagram but an array with one row per critical edge:
                        [x (D coords), y (D coords), target value, index of x, index of y], where x and y are
                        the endpoints of the edge and the indices refer to the (possibly subsampled) point cloud.
                        Make sure you use an adequate loss afterward,
                            of the form sum(length(critical_edges) - target_values)
    :param n_preserved: int, only used if `use_oineus=True` and "simplify" loss of Oineus.
                             The first n-1 points in the diagram are preserved, the others are sent to the diagonal.

    :returns: a DeformRips type of Layer, i.e. a function that can be called in a `tape.gradient` environment.
              Its gradient w.r.t. the input point cloud has the same dtype as that input.
    """

    @tf.custom_gradient
    def DeformRips(X_input):
        """
        Compute the diagram of a point cloud X_input (typically a tf.Variable)
        with respect to the parameters prescribed above.
        """
        N, D = X_input.shape
        if subsample_size is not None:
            # Draw `subsample_size` indices in [0, N) uniformly *with replacement* (duplicates are possible).
            ech = random.choices(range(N), k=subsample_size)
            # Build the subsampled point cloud.
            X = tf.gather(X_input, ech)
        else:
            ech = np.arange(N)  # Identity map, so that gradients can always be routed back through `ech`.
            X = X_input

        if not use_oineus:
            # Construction with gudhi's API
            DX = tf.norm(tf.expand_dims(X, 1) - tf.expand_dims(X, 0), axis=2)
            rc = gd.RipsComplex(distance_matrix=DX.numpy(), max_edge_length=max_edge_length)
            st = rc.create_simplex_tree(max_dimension=homology_dimension + 1)
            st.compute_persistence(homology_coeff_field=2)
            pairs = st.flag_persistence_generators()
            if homology_dimension == 0:
                finite_pairs = pairs[0]
                essential_pairs = pairs[2]
            else:
                finite_pairs = pairs[1][homology_dimension - 1] if len(pairs[1]) >= homology_dimension else np.empty(
                    shape=[0, 4])
                essential_pairs = pairs[3][homology_dimension - 1] if len(pairs[3]) >= homology_dimension else np.empty(
                    shape=[0, 2])
            finite_indices = np.array(finite_pairs.flatten(), dtype=np.int32)
            essential_indices = np.array(essential_pairs.flatten(), dtype=np.int32)
            indices = (finite_indices, essential_indices)

            if homology_dimension > 0:
                # Each finite pair is 4 vertex indices: a birth edge (i1, i2) and a death edge (i3, i4).
                finite_dgm = tf.reshape(tf.gather_nd(DX, tf.reshape(indices[0], [-1, 2])), [-1, 2])
            else:
                # Each finite H0 pair is 3 indices: a vertex (birth 0) and a death edge (i3, i4).
                reshaped_cur_idx = tf.reshape(indices[0], [-1, 3])
                # Use DX's dtype for the zero births so the concat also works for float64 inputs.
                finite_dgm = tf.concat([tf.zeros([reshaped_cur_idx.shape[0], 1], dtype=DX.dtype),
                                        tf.reshape(tf.gather_nd(DX, reshaped_cur_idx[:, 1:]), [-1, 1])], axis=1)

        else:
            # Construction with Oineus' API
            X = tf.cast(X, dtype=tf.float64)
            fil, longest_edges = oin.get_vr_filtration_and_critical_edges(np.array(X.numpy(), dtype=np.float64),
                                                                          max_dim=homology_dimension + 1,
                                                                          max_radius=max_edge_length,
                                                                          n_threads=1)
            top_opt = oin.TopologyOptimizer(fil)
            eps = top_opt.get_nth_persistence(homology_dimension, n_preserved)
            # We can potentially use other losses from Oineus.
            # The birth-birth strategy is equivalent to "ul.death_killer": every point (b, d) of the diagram that
            # is not preserved wants to be matched to (b, b).
            # Below, indices: simplices we want to update, and values: values we want to assign to them.
            indices, values = top_opt.simplify(eps, oin.DenoiseStrategy.BirthBirth, homology_dimension)
            # To each simplex to be updated, assign a critical set (list of indices) to which we may want to assign
            # the same value. The following is thus a list of pairs [(values, indices) ...].
            critical_sets = top_opt.singletons(indices, values)
            # As a given index could appear twice (or more), we need to choose which value to actually assign.
            # The heuristic is to take the maximum. In the following, we eventually store the list of indices to be
            # updated, and their corresponding values.
            crit_indices, crit_values = top_opt.combine_loss(critical_sets, oin.ConflictStrategy.Max)

            # Map each critical simplex to its longest edge, i.e. the pair of points whose distance is the
            # simplex's Rips filtration value.
            crit_indices = np.array(crit_indices, dtype=np.int32)
            crit_edges = longest_edges[crit_indices, :]
            # Indices (in the possibly subsampled X) of the two endpoints of each critical edge.
            crit_edges_x, crit_edges_y = crit_edges[:, 0], crit_edges[:, 1]

            # Everything needed downstream is stored in `finite_dgm`, which is NOT a persistence diagram here.
            # It has one row per critical edge and 2D + 3 columns: [0, D) is the endpoint x, [D, 2D) the
            # endpoint y, column 2D the target value for the pair, and the last two columns the indices of x and y.
            finite_dgm = tf.concat([tf.gather(X, crit_edges_x),
                                    tf.gather(X, crit_edges_y),
                                    tf.Variable(np.array(crit_values)[:, None], dtype=tf.float64, trainable=False),
                                    crit_edges_x[:, None],
                                    crit_edges_y[:, None]],
                                   axis=1)

        def grad(dd):

            gradient = np.zeros(shape=X_input.numpy().shape)

            # Upstream gradient w.r.t. `finite_dgm` as a numpy array; if it arrives as tf.IndexedSlices,
            # only its `.values` are used.
            try:
                d_dgm = dd.values.numpy()
            except AttributeError:
                d_dgm = dd.numpy()

            if not use_oineus:
                I = indices[0]
                X_np = X.numpy()  # Convert once instead of indexing the TF tensor inside the loop.

                for idx_p in range(len(d_dgm)):

                    d_px, d_py = d_dgm[idx_p, 0], d_dgm[idx_p, 1]

                    # We need to distinguish between homology dimension == 0, in which case (Rips) birth is always 0,
                    # and thus the gradient only depends on death time (i3, i4 in the following),
                    # and homology dimension > 0, where both birth (i1, i2) and death (i3, i4)
                    # have to be taken into account.
                    # d|x_a - x_b| / dx_a = (x_a - x_b) / |x_a - x_b|, and the opposite for x_b.
                    if homology_dimension > 0:
                        i1, i2, i3, i4 = I[4 * idx_p:4 * (idx_p + 1)]
                        v12 = X_np[i1, :] - X_np[i2, :]
                        n12 = np.linalg.norm(v12)
                        gradient[ech[i1], :] += v12 * (d_px / n12)
                        gradient[ech[i2], :] += -v12 * (d_px / n12)
                    else:
                        _, i3, i4 = I[3 * idx_p:3 * (idx_p + 1)]

                    v34 = X_np[i3, :] - X_np[i4, :]
                    n34 = np.linalg.norm(v34)
                    gradient[ech[i3], :] += v34 * (d_py / n34)
                    gradient[ech[i4], :] += -v34 * (d_py / n34)

            else:
                # Columns [0, D) of the upstream gradient act on the x endpoints, [D, 2D) on the y endpoints.
                # crit_edges_* index the (possibly subsampled) X, so map them back to X_input through `ech`.
                # np.add.at accumulates every contribution when a point is an endpoint of several critical
                # edges; `gradient[idx] += ...` would keep only one contribution per repeated index.
                ech_arr = np.asarray(ech)
                np.add.at(gradient, ech_arr[crit_edges_x], d_dgm[:, 0:input_dimension])
                np.add.at(gradient, ech_arr[crit_edges_y], d_dgm[:, input_dimension:2 * input_dimension])

            MV = gradient
            if use_deformations:
                # Only points with a non-negligible gradient serve as interpolation centers.
                idxs = np.argwhere(np.linalg.norm(gradient, axis=1) > 1e-2).ravel()
                if len(idxs) > 0:
                    MV = get_deformation(X_input, gradient, idxs, sigma, normalized=False)

            # Return the gradient in the input's dtype so it matches the variable it will be applied to.
            return tf.constant(MV, dtype=X_input.dtype)

        return finite_dgm, grad

    return DeformRips


class DeformRipsLayer(tf.keras.layers.Layer):
    """
    This is the class to be imported by the user. See the doc of _build_DeformRips for parameters descriptions.

    It is simply used as (e.g.) DRL = DeformRipsLayer(blabla), then dgm = DRL(X).
    Extra keyword arguments (e.g. `name`) are forwarded to tf.keras.layers.Layer.

    :raises ValueError: if `use_deformations=True` but no `sigma` is given. Without this check, the missing
                        bandwidth would only surface as an opaque TypeError inside the gradient computation.
    """

    def __init__(self,
                 homology_dimension: int,
                 input_dimension: int,
                 max_edge_length: float,
                 use_deformations: bool = False,
                 sigma: float = None,
                 subsample_size: int = None,
                 use_oineus: bool = False,
                 n_preserved: int = 1,
                 **kwargs):
        # The deformation gradient divides by sigma, so it must be provided up front.
        if use_deformations and sigma is None:
            raise ValueError("`sigma` must be provided when `use_deformations=True`.")
        super().__init__(**kwargs)
        self.layer = _build_DeformRips(homology_dimension=homology_dimension,
                                       max_edge_length=max_edge_length,
                                       use_deformations=use_deformations,
                                       sigma=sigma,
                                       input_dimension=input_dimension,
                                       subsample_size=subsample_size,
                                       use_oineus=use_oineus,
                                       n_preserved=n_preserved)

    def call(self, X):
        """Apply the DeformRips function built in __init__ to the point cloud X."""
        return self.layer(X)
