"""
Implements graph edits as the basis for graph edit networks.

"""

# REVIEWER COPY; DO NOT DISTRIBUTE!

import abc
import copy
import numpy as np


class Edit(abc.ABC):
    """ Abstract interface for a single edit operation on a graph.

    A graph is given as an adjacency matrix A together with a node attribute
    list/matrix X. Every concrete edit describes how it transforms (A, X) and
    how it is encoded as node/edge scores.

    """

    @abc.abstractmethod
    def apply(self, A, X):
        """ Returns an edited copy of the graph (A, X). The inputs are not
        modified.

        Parameters
        ----------
        A: class numpy.array
            Adjacency matrix of the input graph.
        X: class numpy.array
            Node attribute list/matrix of the input graph.

        Returns
        -------
        B: class numpy.array
            Adjacency matrix after the edit.
        Y: class numpy.array
            Attribute matrix after the edit.

        """
        ...

    @abc.abstractmethod
    def apply_in_place(self, A, X):
        """ Performs this edit directly on the given arrays, i.e. the input
        arguments are mutated.

        Parameters
        ----------
        A: class numpy.array
            Adjacency matrix to be modified.
        X: class numpy.array
            Node attribute list/matrix to be modified.

        """
        ...

    @abc.abstractmethod
    def score(self, N):
        """ Encodes this edit as scores over a graph with N nodes.

        Node operations are encoded in an N-dimensional vector and edge
        operations in an N x N matrix: +1 marks a node that spawns a new node
        or an edge that is inserted, -1 marks a node/edge that is deleted.

        Parameters
        ----------
        N: int
            Number of nodes in the graph.

        Returns
        -------
        delta: class numpy.array
            N-dimensional node score vector (+1 for spawning, -1 for deleted
            nodes).
        Epsilon: class numpy.array
            N x N edge score matrix (+1 for inserted, -1 for deleted edges).

        """
        ...

class NodeDeletion(Edit):
    """ Edit that removes a single node together with all of its edges.

    Parameters
    ----------
    index: int
        Index of the node to delete.

    """

    def __init__(self, index):
        self._index = index

    def apply(self, A, X):
        """ Deletes node self._index and returns the reduced graph.

        Row and column self._index are removed from A, and row self._index is
        removed from X. The inputs are not modified.

        Parameters
        ----------
        A: class numpy.array
            An adjacency matrix.
        X: class numpy.array
            A node attribute list/matrix.

        Returns
        -------
        B: class numpy.array
            The output adjacency matrix (one row and one column fewer).
        Y: class numpy.array
            The output attribute matrix (one row fewer).

        """
        # np.delete always returns a new array, so A and X stay untouched.
        B = np.delete(np.delete(A, self._index, axis=0), self._index, axis=1)
        Y = np.delete(X, self._index, axis=0)
        return B, Y

    def apply_in_place(self, nodes, adj):
        """ Not supported for node deletions; always raises ValueError.

        Removing a node changes the shape of the arrays, which numpy arrays
        cannot do in place. Use apply instead.

        Parameters
        ----------
        nodes: class numpy.array
            Occupies the position of A (adjacency matrix) in
            Edit.apply_in_place; ignored. The name is kept for backward
            compatibility with keyword callers.
        adj: class numpy.array
            Occupies the position of X (node attributes) in
            Edit.apply_in_place; ignored.

        Raises
        ------
        ValueError
            Always.

        """
        raise ValueError('apply_in_place can not be supported for node deletions because the size of numpy matrices can not be changed in place')

    def score(self, N):
        """ Encodes this deletion as node and edge scores.

        Parameters
        ----------
        N: int
            The size of the graph.

        Returns
        -------
        delta: class numpy.array
            An N-dimensional score vector where entry delta[self._index] = -1
            and every other entry is zero.
        Epsilon: class numpy.array
            An N x N zero matrix (edges removed along with the node are not
            scored).

        """
        delta = np.zeros(N)
        delta[self._index] = -1
        return delta, np.zeros((N, N))

    def __repr__(self):
        return 'del(%d)' % (self._index)

    def __str__(self):
        return self.__repr__()

    def __eq__(self, other):
        return isinstance(other, NodeDeletion) and self._index == other._index

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; provide a hash that is
        # consistent with __eq__ so edits can be used in sets and dict keys.
        return hash((NodeDeletion, self._index))

class NodeInsertion(Edit):
    """ Edit that appends a new node and connects it to an existing node.

    Parameters
    ----------
    index: int
        Index of the existing node the new node is attached to. Negative
        indices count from the end of the input graph, as in numpy.
    attribute: scalar or class numpy.array
        Attribute (row) of the new node.
    directed: bool
        If True (default), only the edge index -> new node is inserted.
        Otherwise, the reverse edge is inserted as well.

    """

    def __init__(self, index, attribute, directed=True):
        self._index = index
        self._attribute = attribute
        self._directed = directed

    def apply(self, A, X):
        """ Inserts a new node with attribute self._attribute into the graph.
        The new node becomes the last node (index N) and receives an edge
        from node self._index (plus the reverse edge if undirected).

        The inputs are not modified.

        Parameters
        ----------
        A: class numpy.array
            An N x N adjacency matrix.
        X: class numpy.array
            A node attribute list (length N) or matrix (N x F).

        Returns
        -------
        B: class numpy.array
            The (N+1) x (N+1) output adjacency matrix (float dtype).
        Y: class numpy.array
            The output attribute list/matrix with one additional row
            (float dtype).

        """
        X = np.asarray(X)
        N = len(A)
        A_new = np.zeros((N + 1, N + 1))
        # Keep any trailing attribute dimensions, so both a 1-D attribute
        # list and an N x F attribute matrix are supported.
        X_new = np.zeros((N + 1,) + X.shape[1:])

        # A negative index refers to a node of the input graph. Resolve it
        # against N, because in A_new (size N+1) index -1 would address the
        # new node itself and create a self-loop.
        i = self._index
        if -N <= i < 0:
            i += N

        A_new[:N, :N] = A
        A_new[i, N] = 1
        if not self._directed:
            A_new[N, i] = 1

        X_new[:N] = X
        X_new[N] = self._attribute
        return A_new, X_new

    def apply_in_place(self, nodes, adj):
        """ Not supported for node insertions; always raises ValueError.

        Inserting a node changes the shape of the arrays, which numpy arrays
        cannot do in place. Use apply instead.

        Parameters
        ----------
        nodes: class numpy.array
            Occupies the position of A (adjacency matrix) in
            Edit.apply_in_place; ignored. The name is kept for backward
            compatibility with keyword callers.
        adj: class numpy.array
            Occupies the position of X (node attributes) in
            Edit.apply_in_place; ignored.

        Raises
        ------
        ValueError
            Always.

        """
        raise ValueError('apply_in_place can not be supported for Node insertions because the size of numpy matrices can not be changed in place')

    def score(self, N):
        """ Encodes this insertion as node and edge scores.

        Parameters
        ----------
        N: int
            The size of the graph.

        Returns
        -------
        delta: class numpy.array
            An N-dimensional score vector where entry delta[self._index] = +1
            (the node that spawns the new node) and every other entry is zero.
        Epsilon: class numpy.array
            An N x N zero matrix.

        """
        delta = np.zeros(N)
        delta[self._index] = +1
        return delta, np.zeros((N, N))

    def __repr__(self):
        return 'ins(%d, %s)' % (self._index, self._attribute)

    def __str__(self):
        return self.__repr__()

    def __eq__(self, other):
        if not isinstance(other, NodeInsertion):
            return False
        a, b = self._attribute, other._attribute
        # Compare element-wise if EITHER side is an array. A plain `!=` with
        # an array on one side yields an array whose truth value is ambiguous.
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            if not np.array_equal(a, b):
                return False
        elif a != b:
            return False
        return self._index == other._index and self._directed == other._directed

    def __hash__(self):
        # The attribute may be an unhashable ndarray, so it is left out. This
        # is still consistent with __eq__, because equal edits always share
        # index and directedness.
        return hash((NodeInsertion, self._index, self._directed))


class EdgeDeletion(Edit):
    """ Edit that removes the edge from node i to node j.

    Parameters
    ----------
    i: int
        Source node of the edge.
    j: int
        Target node of the edge.
    directed: bool
        If True (default), only A[i, j] is cleared. Otherwise, A[j, i] is
        cleared as well.

    """

    def __init__(self, i, j, directed=True):
        self._i = i
        self._j = j
        self._directed = directed

    def apply(self, A, X):
        """ Deletes the edge from node self._i to node self._j and returns the
        result as copies; the inputs are not modified.

        Parameters
        ----------
        A: class numpy.array
            An adjacency matrix.
        X: class numpy.array
            A node attribute list/matrix.

        Returns
        -------
        B: class numpy.array
            A copy of A with the edge removed.
        Y: class numpy.array
            A copy of X (attributes are unchanged by this edit).

        """
        A = np.copy(A)
        X = np.copy(X)
        self.apply_in_place(A, X)
        return A, X

    def apply_in_place(self, A, X):
        """ Deletes the edge from node self._i to node self._j by writing 0
        into A (and into the mirrored entry if undirected). X is not touched.

        Parameters
        ----------
        A: class numpy.array
            An adjacency matrix, modified in place.
        X: class numpy.array
            A node attribute list/matrix (unused).

        """
        A[self._i, self._j] = 0
        if not self._directed:
            A[self._j, self._i] = 0

    def score(self, N):
        """ Encodes this deletion as node and edge scores.

        Parameters
        ----------
        N: int
            The size of the graph.

        Returns
        -------
        delta: class numpy.array
            An N-dimensional zero vector.
        Epsilon: class numpy.array
            An N x N matrix where entry Epsilon[i, j] = -1 (and Epsilon[j, i]
            = -1 if undirected) and everything else is zero.

        """
        delta = np.zeros(N)
        Epsilon = np.zeros((N, N))
        Epsilon[self._i, self._j] = -1
        if not self._directed:
            Epsilon[self._j, self._i] = -1
        return delta, Epsilon

    def __repr__(self):
        return 'del_edge(%d, %d)' % (self._i, self._j)

    def __str__(self):
        return self.__repr__()

    def __eq__(self, other):
        return (isinstance(other, EdgeDeletion)
                and self._i == other._i
                and self._j == other._j
                and self._directed == other._directed)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash the same fields
        # that __eq__ compares so edits can be used in sets and dict keys.
        return hash((EdgeDeletion, self._i, self._j, self._directed))

class EdgeInsertion(Edit):
    """ Edit that inserts an edge from node i to node j.

    Parameters
    ----------
    i: int
        Source node of the edge.
    j: int
        Target node of the edge.
    directed: bool
        If True (default), only A[i, j] is set. Otherwise, A[j, i] is set as
        well.

    """

    def __init__(self, i, j, directed=True):
        self._i = i
        self._j = j
        self._directed = directed

    def apply(self, A, X):
        """ Inserts a new edge from node self._i to node self._j and returns
        the result as copies; the inputs are not modified.

        Parameters
        ----------
        A: class numpy.array
            An adjacency matrix.
        X: class numpy.array
            A node attribute list/matrix.

        Returns
        -------
        B: class numpy.array
            A copy of A with the edge inserted.
        Y: class numpy.array
            A copy of X (attributes are unchanged by this edit).

        """
        A = np.copy(A)
        X = np.copy(X)
        self.apply_in_place(A, X)
        return A, X

    def apply_in_place(self, A, X):
        """ Inserts a new edge from node self._i to node self._j by writing 1
        into A (and into the mirrored entry if undirected). X is not touched.

        Parameters
        ----------
        A: class numpy.array
            An adjacency matrix, modified in place.
        X: class numpy.array
            A node attribute list/matrix (unused).

        """
        A[self._i, self._j] = 1
        if not self._directed:
            A[self._j, self._i] = 1

    def score(self, N):
        """ Encodes this insertion as node and edge scores.

        Parameters
        ----------
        N: int
            The size of the graph.

        Returns
        -------
        delta: class numpy.array
            An N-dimensional zero vector.
        Epsilon: class numpy.array
            An N x N matrix where entry Epsilon[i, j] = +1 (and Epsilon[j, i]
            = +1 if undirected) and everything else is zero.

        """
        delta = np.zeros(N)
        Epsilon = np.zeros((N, N))
        Epsilon[self._i, self._j] = +1
        if not self._directed:
            Epsilon[self._j, self._i] = +1
        return delta, Epsilon

    def __repr__(self):
        return 'ins_edge(%d, %d)' % (self._i, self._j)

    def __str__(self):
        return self.__repr__()

    def __eq__(self, other):
        return (isinstance(other, EdgeInsertion)
                and self._i == other._i
                and self._j == other._j
                and self._directed == other._directed)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash the same fields
        # that __eq__ compares so edits can be used in sets and dict keys.
        return hash((EdgeInsertion, self._i, self._j, self._directed))
