import torch
#from torch_geometric.data import DataLoader
from trainable_scattering.models.scatter import Scatter, Scatter_Diffuse_Second

class FastScatterTransform:
    """Replace ``data.x`` with features computed by a ``Scatter`` module.

    Meant to be called on one data object at a time. Each call builds a
    ``Scatter`` module from the width of ``data.x`` (``data.x.shape[1]``),
    applies it to ``data``, and writes the result back into ``data.x``.

    Args:
        device: Stored as ``self.device``. NOTE(review): ``__call__`` does not
            read it, so the module runs wherever ``Scatter`` places it by
            default — confirm whether this was meant to be used.
    """

    def __init__(self, device='cpu'):
        self.device = device

    def __call__(self, data):
        """Overwrite ``data.x`` with the scattering output and return ``data``.

        ``data`` is mutated in place and also returned.
        Only ``data.x.shape[1]`` is read here. Presumably ``data.x`` is a
        (num_nodes, num_features) tensor — TODO confirm against ``Scatter``.
        """
        in_channels = data.x.shape[1]
        # The module is created here and discarded after this call. If it has
        # trainable parameters, any autograd graph through them is useless
        # for training. Computing under no_grad keeps the stored features
        # free of that graph. This saves memory and lets the result be passed
        # through torch.multiprocessing (e.g. DataLoader workers), which
        # rejects non-leaf tensors that require grad.
        with torch.no_grad():
            data.x = Scatter(in_channels)(data)
        return data

class FastScatterTransformSort:
    """Replace ``data.x`` with features computed by ``Scatter_Diffuse_Second``.

    Meant to be called on one data object at a time. Each call builds a
    ``Scatter_Diffuse_Second`` module from the width of ``data.x``
    (``data.x.shape[1]``), applies it to ``data``, and writes the result back
    into ``data.x``.

    Args:
        device: Stored as ``self.device``. NOTE(review): ``__call__`` does not
            read it, so the module runs wherever ``Scatter_Diffuse_Second``
            places it by default — confirm whether this was meant to be used.
    """

    def __init__(self, device='cpu'):
        self.device = device

    def __call__(self, data):
        """Overwrite ``data.x`` with the scattering output and return ``data``.

        ``data`` is mutated in place and also returned.
        Only ``data.x.shape[1]`` is read here. Presumably ``data.x`` is a
        (num_nodes, num_features) tensor — TODO confirm against
        ``Scatter_Diffuse_Second``.
        """
        in_channels = data.x.shape[1]
        # The module is created here and discarded after this call. If it has
        # trainable parameters, any autograd graph through them is useless
        # for training. Computing under no_grad keeps the stored features
        # free of that graph. This saves memory and lets the result be passed
        # through torch.multiprocessing (e.g. DataLoader workers), which
        # rejects non-leaf tensors that require grad.
        with torch.no_grad():
            data.x = Scatter_Diffuse_Second(in_channels)(data)
        return data

