import collections
import collections.abc
import warnings
from math import exp

import networkx as nx
import numpy as np
import scipy
import scipy.linalg
from grakel.graph import Graph
from grakel.kernels import Kernel
from numpy import array
from numpy import einsum
from numpy import squeeze
from scipy.linalg import block_diag
from scipy.sparse import csr_matrix, linalg
from sklearn.exceptions import NotFittedError
from sklearn.utils import check_random_state
from sklearn.utils.validation import check_is_fitted

import mgraph

class TorelliWasserstein(Kernel):
    """
    The tropical Torelli--Wasserstein kernel.

    Every graph is mapped (see ``parse_input``) to the block-diagonal
    tropical period matrix of its connected components, zero-padded or
    randomly sub-sampled to a ``dimbound x dimbound`` matrix. Two such
    matrices are compared with ``pairwise_operation``.

    Parameters
    ----------
    n_jobs : int or None, default=None
        Forwarded to the grakel ``Kernel`` base class.

    verbose : bool, default=False
        Forwarded to the grakel ``Kernel`` base class.

    normalize : bool, default=False
        Forwarded to the grakel ``Kernel`` base class.

    dimbound : int, default=100
        Side length of every feature matrix. Smaller period matrices are
        zero-padded; larger ones are replaced by a random principal
        submatrix of this size. Must be a positive integer.

    gamma : float, default=0.5
        Stored on the instance. NOTE(review): not used by
        ``pairwise_operation``; confirm whether a kernel such as
        ``exp(-gamma * d**2)`` was intended.

    random_state : int, RandomState instance or None, default=None
        Controls the random sub-sampling applied when a period matrix is
        larger than ``dimbound``. ``None`` uses NumPy's global random state.

    Attributes
    ----------
    None.

    """

    # Define the graph format that this kernel needs (if needed)
    # _graph_format = "auto" (default: "auto")

    def __init__(self,
                 n_jobs=None,
                 verbose=False,
                 normalize=False,
                 dimbound=100,
                 gamma=0.5,
                 random_state=None,
                 ):
        """Initialise a TW kernel."""

        super(TorelliWasserstein, self).__init__(
            n_jobs=n_jobs,
            verbose=verbose,
            normalize=normalize)

        self.dimbound = dimbound
        self.gamma = gamma
        self.random_state = random_state
        self._initialized.update(
            {"dimbound": False, "gamma": False, "random_state": False})

    def initialize_(self):
        """Initialize all transformer arguments, needing initialization.

        Only defers to the base class; ``dimbound`` is validated when
        ``parse_input`` runs.
        """
        # NOTE(review): grakel's ``Kernel`` may name this hook ``initialize``
        # (no trailing underscore), in which case ``fit`` never calls this
        # override. Confirm against the installed grakel version.
        super(TorelliWasserstein, self).initialize_()

    def parse_input(self, X):
        """The tropical Torelli feature map.

        Parameters
        ----------
        X : iterable
            Each element must be either a grakel ``Graph`` or an iterable
            with one to three entries: a valid graph structure (adjacency
            matrix or edge dictionary), optionally followed by node labels
            and edge labels for the given graph format. Labels are not used
            by this kernel. Empty elements are skipped with a warning.

        Returns
        -------
        out : list
            One ``dimbound x dimbound`` matrix per non-empty element of X:
            the zero-padded (hence generally singular) or randomly
            sub-sampled tropical period matrix. Presumably symmetric positive
            semi-definite, depending on ``mgraph``'s ``trop_period``.

        Raises
        ------
        TypeError
            If X is not iterable, an element has an unsupported form, or
            ``dimbound`` is not an integer.

        ValueError
            If ``dimbound`` is smaller than 1.

        """
        if not isinstance(X, collections.abc.Iterable):
            raise TypeError('input must be an iterable\n')

        dimbound = self._checked_dimbound()
        rng = check_random_state(self.random_state)
        out = list()
        for (position, x) in enumerate(iter(X)):
            graph = self._as_graph(x, position)
            if graph is None:
                continue
            Q = self._tropical_period_matrix(graph)
            out.append(self._fit_to_dimbound(Q, dimbound, rng))
        return out

    def _checked_dimbound(self):
        """Return ``self.dimbound`` after checking it is a positive integer."""
        dimbound = self.dimbound
        if not isinstance(dimbound, (int, np.integer)):
            raise TypeError('dimbound must be an integer\n')
        if dimbound < 1:
            raise ValueError('dimbound must be a positive integer\n')
        return int(dimbound)

    def _as_graph(self, x, position):
        """Convert one element of X to a grakel Graph (None if it is empty)."""
        if isinstance(x, collections.abc.Iterable):
            items = list(x)
            if len(items) == 0:
                warnings.warn('Ignoring empty element ' +
                              'on index: ' + str(position))
                return None
            if len(items) == 1:
                return Graph(items[0], {}, {}, self._graph_format)
            if len(items) in (2, 3):
                # Only the adjacency structure is used below, so edge labels
                # (third entry) are deliberately not forwarded.
                return Graph(items[0], items[1], {}, self._graph_format)
        elif isinstance(x, Graph):
            x.desired_format(self._graph_format)
            return x
        raise TypeError('each element of X must be either a '
                        'graph or an iterable with at least 1 '
                        'and at most 3 elements\n')

    @staticmethod
    def _tropical_period_matrix(graph):
        """Block-diagonal tropical period matrix over connected components."""
        A = graph.get_adjacency_matrix()
        G = nx.from_numpy_array(A, create_using=mgraph.MetricGraph())
        blocks = list()
        for cc_nodes in nx.connected_components(G):
            subgraph = G.subgraph(cc_nodes).copy()
            MST = nx.minimum_spanning_tree(subgraph)
            Q_st = subgraph.trop_period(MST)
            # Skip components with an empty period matrix (presumably those
            # without cycles).
            if Q_st.shape[0] == 0:
                continue
            blocks.append(Q_st)
        if not blocks:
            # No component contributed: use a tiny 2-D 1x1 placeholder.
            return np.array([[1e-6]])
        return block_diag(*blocks)

    @staticmethod
    def _fit_to_dimbound(Q, dimbound, rng):
        """Zero-pad Q to ``dimbound x dimbound``, or randomly sub-sample it."""
        m = Q.shape[0]
        if m > dimbound:
            # Random principal submatrix of the requested size.
            keep = rng.choice(m, dimbound, replace=False)
            return Q[np.ix_(keep, keep)]
        padded = np.zeros((dimbound, dimbound))
        padded[:m, :m] = Q
        return padded

    @staticmethod
    def _clip_negative_eigenvalues(eigvals, message):
        """Set negative eigenvalues to 0, warning if any is below -1e-3."""
        negative = eigvals[eigvals < 0]
        if negative.size > 0 and negative.min() < -1e-3:
            warnings.warn(message)
        return np.where(eigvals < 0, 0.0, eigvals)

    def pairwise_operation(self, x, y):
        """Bures--Wasserstein kernel.

        Computes ``tr((x^{1/2} y x^{1/2})^{1/2})``, the cross term of the
        Bures--Wasserstein distance. Negative eigenvalues, which arise from
        round-off or non-PSD input, are clipped to zero. A warning is issued
        when one is below -1e-3.

        Parameters
        ----------
        x, y : psd matrices

        Returns
        -------
        kernel : number
        """
        # Principal square root of x: V diag(sqrt(d)) V^T, formed by column
        # scaling instead of an explicit diagonal matrix product.
        Dx, Vx = scipy.linalg.eigh(x)
        Dx = self._clip_negative_eigenvalues(
            Dx, "Matrix A is not positive definite.")
        sqrtx = (Vx * np.sqrt(Dx)) @ Vx.T

        W = sqrtx @ y @ sqrtx
        W = (W + W.T) / 2  # ensure symmetry for numerical reasons

        # The trace of sqrt(W) is the sum of sqrt of its eigenvalues, so the
        # eigenvectors are not needed.
        D = scipy.linalg.eigh(W, eigvals_only=True)
        D = self._clip_negative_eigenvalues(
            D, "Matrix ABA is not positive definite.")
        return np.sum(np.sqrt(D))

    @staticmethod
    def _increment_diagonal_(A, value):
        """Increment the main diagonal of a 2-D array in place.

        Parameters
        ----------
        A : np.array
            The 2-D array whose main diagonal is incremented in place.

        value : number or array
            Added to the diagonal entries. Broadcast against the diagonal's
            length ``min(A.shape)``.

        Returns
        -------
        None.

        """
        # Fancy-index assignment writes through to A directly, so there is
        # no need to flip the read-only flag of a ``diagonal()`` view.
        diag = np.arange(min(A.shape))
        A[diag, diag] += value


class TorelliEuclidean(Kernel):
    """
    The tropical Torelli--Euclidean kernel.

    Every graph is mapped (see ``parse_input``) to the block-diagonal
    tropical period matrix of its connected components, zero-padded or
    randomly sub-sampled to ``dimbound x dimbound`` and flattened into a
    sparse feature row. The kernel is the dot product of feature rows.

    Parameters
    ----------
    n_jobs : int or None, default=None
        Forwarded to the grakel ``Kernel`` base class.

    verbose : bool, default=False
        Forwarded to the grakel ``Kernel`` base class.

    normalize : bool, default=False
        Forwarded to the grakel ``Kernel`` base class.

    dimbound : int, default=100
        Side length of every period matrix before flattening. Smaller
        matrices are zero-padded; larger ones are replaced by a random
        principal submatrix of this size. Must be a positive integer.

    random_state : int, RandomState instance or None, default=None
        Controls the random sub-sampling applied when a period matrix is
        larger than ``dimbound``. ``None`` uses NumPy's global random state.

    Attributes
    ----------
    sparse_ : bool
        Set to True by ``parse_input``, which always returns a CSR matrix.

    """

    # Define the graph format that this kernel needs (if needed)
    # _graph_format = "auto" (default: "auto")

    def __init__(self,
                 n_jobs=None,
                 verbose=False,
                 normalize=False,
                 dimbound=100,
                 random_state=None,
                 ):
        """Initialise a TE kernel."""

        super(TorelliEuclidean, self).__init__(
            n_jobs=n_jobs,
            verbose=verbose,
            normalize=normalize)

        self.dimbound = dimbound
        self.random_state = random_state
        self._initialized.update({"dimbound": False, "random_state": False})

    def initialize_(self):
        """Initialize all transformer arguments, needing initialization.

        Only defers to the base class; ``dimbound`` is validated when
        ``parse_input`` runs.
        """
        # NOTE(review): grakel's ``Kernel`` may name this hook ``initialize``
        # (no trailing underscore), in which case ``fit`` never calls this
        # override. Confirm against the installed grakel version.
        super(TorelliEuclidean, self).initialize_()

    def parse_input(self, X):
        """The tropical Torelli feature map.

        Parameters
        ----------
        X : iterable
            Each element must be either a grakel ``Graph`` or an iterable
            with one to three entries: a valid graph structure (adjacency
            matrix or edge dictionary), optionally followed by node labels
            and edge labels for the given graph format. Labels are not used
            by this kernel. Empty elements are skipped with a warning.

        Returns
        -------
        out : scipy.sparse.csr_matrix, shape = [n_graphs, dimbound ** 2]
            One row per non-empty element of X: the flattened, zero-padded
            or randomly sub-sampled tropical period matrix. The shape is
            (0, dimbound ** 2) when no element survives parsing.

        Raises
        ------
        TypeError
            If X is not iterable, an element has an unsupported form, or
            ``dimbound`` is not an integer.

        ValueError
            If ``dimbound`` is smaller than 1.

        """
        if not isinstance(X, collections.abc.Iterable):
            raise TypeError('input must be an iterable\n')

        dimbound = self._checked_dimbound()
        rng = check_random_state(self.random_state)
        rows = list()
        for (position, x) in enumerate(iter(X)):
            graph = self._as_graph(x, position)
            if graph is None:
                continue
            Q = self._tropical_period_matrix(graph)
            rows.append(self._fit_to_dimbound(Q, dimbound, rng).ravel())

        if rows:
            features = csr_matrix(np.vstack(rows))
        else:
            # Keep the expected column count even with zero graphs.
            features = csr_matrix((0, dimbound * dimbound))
        self.sparse_ = True
        return features

    def _checked_dimbound(self):
        """Return ``self.dimbound`` after checking it is a positive integer."""
        dimbound = self.dimbound
        if not isinstance(dimbound, (int, np.integer)):
            raise TypeError('dimbound must be an integer\n')
        if dimbound < 1:
            raise ValueError('dimbound must be a positive integer\n')
        return int(dimbound)

    def _as_graph(self, x, position):
        """Convert one element of X to a grakel Graph (None if it is empty)."""
        if isinstance(x, collections.abc.Iterable):
            items = list(x)
            if len(items) == 0:
                warnings.warn('Ignoring empty element ' +
                              'on index: ' + str(position))
                return None
            if len(items) == 1:
                return Graph(items[0], {}, {}, self._graph_format)
            if len(items) in (2, 3):
                # Only the adjacency structure is used below, so edge labels
                # (third entry) are deliberately not forwarded.
                return Graph(items[0], items[1], {}, self._graph_format)
        elif isinstance(x, Graph):
            x.desired_format(self._graph_format)
            return x
        raise TypeError('each element of X must be either a '
                        'graph or an iterable with at least 1 '
                        'and at most 3 elements\n')

    @staticmethod
    def _tropical_period_matrix(graph):
        """Block-diagonal tropical period matrix over connected components."""
        A = graph.get_adjacency_matrix()
        G = nx.from_numpy_array(A, create_using=mgraph.MetricGraph())
        blocks = list()
        for cc_nodes in nx.connected_components(G):
            subgraph = G.subgraph(cc_nodes).copy()
            MST = nx.minimum_spanning_tree(subgraph)
            Q_st = subgraph.trop_period(MST)
            # Skip components with an empty period matrix (presumably those
            # without cycles).
            if Q_st.shape[0] == 0:
                continue
            blocks.append(Q_st)
        if not blocks:
            # No component contributed: use a tiny 2-D 1x1 placeholder.
            return np.array([[1e-6]])
        return block_diag(*blocks)

    @staticmethod
    def _fit_to_dimbound(Q, dimbound, rng):
        """Zero-pad Q to ``dimbound x dimbound``, or randomly sub-sample it."""
        m = Q.shape[0]
        if m > dimbound:
            # Random principal submatrix of the requested size.
            keep = rng.choice(m, dimbound, replace=False)
            return Q[np.ix_(keep, keep)]
        padded = np.zeros((dimbound, dimbound))
        padded[:m, :m] = Q
        return padded

    def _calculate_kernel_matrix(self, Y=None):
        """Calculate the kernel matrix given a target_graph and a kernel.

        The kernel matrix holds the dot product of every row of Y (targets)
        with every row of ``self.X`` (inputs).

        Parameters
        ----------
        Y : sparse or dense matrix, default=None
            Feature rows of the target graphs, in the layout produced by
            ``parse_input``. Only its first ``self.X.shape[1]`` columns are
            used.

        Returns
        -------
        K : numpy array, shape = [n_targets, n_inputs]
            The kernel matrix: a calculation between all pairs of graphs
            between targets and inputs. If Y is None targets and inputs
            are the taken from self.X. Otherwise Y corresponds to targets
            and self.X to inputs.

        """
        if Y is None:
            K = self.X.dot(self.X.T)
        else:
            K = Y[:, :self.X.shape[1]].dot(self.X.T)

        if self.sparse_:
            return K.toarray()
        return K

    def _row_squared_norms(self, M):
        """Squared Euclidean norm of every row of M, as a 1-D array."""
        if self.sparse_:
            # ravel (not squeeze) keeps a 1-D result for a single row.
            return np.asarray(M.multiply(M).sum(axis=1)).ravel()
        return einsum('ij,ij->i', M, M)

    def diagonal(self):
        """Calculate the kernel matrix diagonal of the fitted data.

        Parameters
        ----------
        None.

        Returns
        -------
        X_diag : np.array
            1-D array of each fitted feature row's kernel value with itself
            (its squared norm). Cached on ``self._X_diag``.

        Y_diag : np.array
            Returned only when ``self._Y`` is set (i.e. after a transform),
            in which case the result is the pair ``(X_diag, Y_diag)``. It
            holds the same quantity for the transformed rows.

        """
        # Check is fit had been called
        check_is_fitted(self, ['X', 'sparse_'])
        try:
            check_is_fitted(self, ['_X_diag'])
        except NotFittedError:
            self._X_diag = self._row_squared_norms(self.X)

        try:
            # If transform has happened return both diagonals
            check_is_fitted(self, ['_Y'])
        except NotFittedError:
            # Else just return X_diag
            return self._X_diag
        return self._X_diag, self._row_squared_norms(self._Y)
    
    