import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as TGDataLoader

class BipartitePair(Data):
    """``Data`` subclass with custom index offsets for two node sets.

    The instance is expected to carry node features ``x_a`` and ``x_f``
    together with ``edge_index_a``, ``edge_index_f`` and ``edge_index_af``.
    ``__inc__`` tells PyTorch Geometric by how much the values of each of
    these attributes must be shifted when samples are collated into a batch.
    """

    def __inc__(self, key, value, *args, **kwargs):
        """Return the batching offset for attribute ``key``.

        * ``edge_index_a``  -> number of rows in ``x_a``
        * ``edge_index_f``  -> number of rows in ``x_f``
        * ``edge_index_af`` -> a ``2 x 1`` tensor holding the row count of
          ``x_a`` in row 0 and of ``x_f`` in row 1, so each row of the edge
          index receives its own offset (presumably row 0 indexes ``x_a``
          and row 1 indexes ``x_f`` -- confirm against the data producer).
        * any other key     -> whatever the base class decides.
        """
        if key == 'edge_index_a':
            return self.x_a.size(0)
        if key == 'edge_index_f':
            return self.x_f.size(0)
        if key == 'edge_index_af':
            return torch.tensor([[self.x_a.size(0)], [self.x_f.size(0)]])
        # Forward the extra positional/keyword arguments unchanged: the
        # collate machinery passes additional context through this hook and
        # the base implementation must receive it.
        return super().__inc__(key, value, *args, **kwargs)


class BatchPos(Data):
    """``Data`` subclass that changes how the ``pos`` attribute is batched."""

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Return the concatenation dimension for attribute ``key``.

        For ``pos`` this returns ``None``; per the PyTorch Geometric batching
        documentation, ``None`` means the attribute gets a new leading batch
        dimension instead of being concatenated along an existing one.
        Every other key defers to the base class.
        """
        if key == 'pos':
            return None
        # Forward the extra arguments unchanged so the base implementation
        # receives everything the collate machinery passed in.
        return super().__cat_dim__(key, value, *args, **kwargs)

def DataLoader(dataset, batch_size=1, shuffle=False, **kwargs):
    """Build a PyTorch Geometric ``DataLoader`` for ``dataset``.

    The first sample is inspected: if it has an ``x_a`` attribute, the loader
    is created with ``follow_batch`` containing ``'x_a'`` and ``'x_f'`` so
    that batch assignment vectors are produced for both attributes.

    Args:
        dataset: Indexable collection of samples.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data.
        **kwargs: Passed through to ``torch_geometric.loader.DataLoader``.
            A caller-supplied ``follow_batch`` is merged with
            ``['x_a', 'x_f']`` rather than clashing with it.

    Returns:
        The configured ``torch_geometric.loader.DataLoader``.
    """
    try:
        has_pair_features = hasattr(dataset[0], 'x_a')
    except IndexError:
        # An empty dataset has no sample to inspect; build a plain loader
        # instead of failing before the loader is even created.
        has_pair_features = False

    if has_pair_features:
        follow_batch = ['x_a', 'x_f']
        # Merge any keys the caller asked for; passing ``follow_batch`` both
        # explicitly and through ``kwargs`` would raise a TypeError.
        requested = kwargs.pop('follow_batch', None) or []
        follow_batch += [key for key in requested if key not in follow_batch]
        kwargs['follow_batch'] = follow_batch

    return TGDataLoader(dataset, batch_size, shuffle, **kwargs)
