from graph_learning.dataset.graph import GLGraph, edge_batch

class DataTransform(object):
    """ Base class for graph data transforms.

    Subclasses implement ``_transform``, which receives a single ``GLGraph``.
    ``transform`` dispatches to it for a single graph, or applies it to every
    element of a list of graphs.
    """
    def transform(self, graph):
        """ Apply the transform to a graph or a list of graphs.

        Parameters
        ----------
        graph: GLGraph or list
            A single graph, or a list whose elements are graphs (nested lists
            are handled recursively).

        Returns
        -------
        object
            For a single graph, whatever ``_transform`` returns. For a list,
            the same list object, updated in place with the transformed
            elements.

        Raises
        ------
        TypeError
            If ``graph`` is neither a ``GLGraph`` nor a list.
        """
        if isinstance(graph, GLGraph):
            return self._transform(graph)
        if isinstance(graph, list):
            # Update the caller's list in place (and return it) so existing
            # references observe the transformed graphs; recursing through
            # ``transform`` lets nested lists work too. Building the new
            # contents first means a failure midway leaves the list untouched.
            graph[:] = [self.transform(g) for g in graph]
            return graph
        # Unsupported input used to fall through and yield None, which only
        # failed later (e.g. in the next pipeline stage) far from the cause.
        raise TypeError(
            "{}.transform expects a GLGraph or a list of GLGraph, got {}".format(
                type(self).__name__, type(graph).__name__))

    def _transform(self, data):
        """ Transform a single graph; subclasses must override this.

        Parameters
        ----------
        data: object
            Data object to be transformed (a ``GLGraph`` when called via
            ``transform``).

        Returns
        -------
        object
            Transformed data.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

class DataPipeline(object):
    """ Data transform pipeline.

    Holds an ordered list of transformers and applies them one after another:
    the output of each transformer's ``transform`` is the input to the next.
    """
    def __init__(self, transformers=None):
        """ Create a pipeline.

        Parameters
        ----------
        transformers: iterable, optional
            Initial transformers, in application order. Each must provide a
            ``transform(data)`` method. Defaults to an empty pipeline.
        """
        # Copy into a fresh list so later ``append`` calls never mutate a
        # sequence owned by the caller.
        self.transformers = [] if transformers is None else list(transformers)

    def append(self, transformer):
        """ Add a transformer to the end of the pipeline.

        Parameters
        ----------
        transformer: object
            Object providing a ``transform(data)`` method.
        """
        self.transformers.append(transformer)

    def transform(self, data):
        """ Run ``data`` through every transformer, in order.

        Parameters
        ----------
        data: object
            Data object to be transformed.

        Returns
        -------
        object
            Transformed data. With no transformers, ``data`` is returned
            unchanged.
        """
        for t in self.transformers:
            data = t.transform(data)
        return data
