import torch
from torch_geometric.data import Data


class BatchMasking(Data):
    """A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier."""

    def __init__(self, batch=None, **kwargs):
        super(BatchMasking, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def _data_keys(data):
        """Return the attribute names stored in ``data`` as a list.

        ``Data.keys`` is a property in older torch_geometric releases and a
        method in newer ones; both forms are accepted here."""
        keys = data.keys
        return list(keys() if callable(keys) else keys)

    @staticmethod
    def from_data_list(data_list):
        """Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.

        Every attribute for which :meth:`cumsum` is true is shifted so that it
        indexes into the merged graph: ``connected_edge_indices`` by the number
        of edges of the preceding graphs, all other such attributes
        (``edge_index``, ``face``, ``masked_atom_indices``) by the number of
        preceding nodes.
        The assignment vector :obj:`batch` is created on the fly."""
        data_keys = [BatchMasking._data_keys(data) for data in data_list]
        keys = list(set.union(*[set(k) for k in data_keys]))
        assert 'batch' not in keys

        batch = BatchMasking()

        for key in keys:
            batch[key] = []
        batch.batch = []

        cumsum_node = 0
        cumsum_edge = 0

        for i, (data, item_keys) in enumerate(zip(data_list, data_keys)):
            num_nodes = data.num_nodes
            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))
            for key in item_keys:
                item = data[key]
                if batch.cumsum(key, item):
                    # Edge-level indices are offset by edges, node-level ones by nodes.
                    if key == 'connected_edge_indices':
                        item = item + cumsum_edge
                    else:
                        item = item + cumsum_node
                batch[key].append(item)

            cumsum_node += num_nodes
            cumsum_edge += data.edge_index.shape[1]

        for key in keys:
            # The concatenation axis is delegated to the base Data.__cat_dim__.
            batch[key] = torch.cat(
                batch[key], dim=data_list[0].__cat_dim__(key, batch[key][0]))
        batch.batch = torch.cat(batch.batch, dim=-1)
        return batch.contiguous()

    def cumsum(self, key, item):
        """If :obj:`True`, the attribute :obj:`key` with content :obj:`item`
        should be added up cumulatively before concatenated together.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute."""
        return key in ['edge_index', 'face',
                       'masked_atom_indices',
                       'connected_edge_indices']

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        return self.batch[-1].item() + 1


class BatchAE(Data):
    """A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier. """

    def __init__(self, batch=None, **kwargs):
        super(BatchAE, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def _data_keys(data):
        """Return the attribute names stored in ``data`` as a list.

        ``Data.keys`` is a property in older torch_geometric releases and a
        method in newer ones; both forms are accepted here."""
        keys = data.keys
        return list(keys() if callable(keys) else keys)

    @staticmethod
    def from_data_list(data_list):
        """Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.

        ``edge_index`` and ``negative_edge_index`` are shifted by the number
        of nodes of the preceding graphs so they index into the merged graph.
        The assignment vector :obj:`batch` is created on the fly."""
        data_keys = [BatchAE._data_keys(data) for data in data_list]
        keys = list(set.union(*[set(k) for k in data_keys]))
        assert 'batch' not in keys

        batch = BatchAE()

        for key in keys:
            batch[key] = []
        batch.batch = []

        cumsum_node = 0

        for i, (data, item_keys) in enumerate(zip(data_list, data_keys)):
            num_nodes = data.num_nodes
            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))
            for key in item_keys:
                item = data[key]
                if key in ['edge_index', 'negative_edge_index']:
                    item = item + cumsum_node
                batch[key].append(item)

            cumsum_node += num_nodes

        for key in keys:
            batch[key] = torch.cat(
                batch[key], dim=batch.__cat_dim__(key))
        batch.batch = torch.cat(batch.batch, dim=-1)
        return batch.contiguous()

    @property
    def num_graphs(self):
        '''Returns the number of graphs in the batch.'''
        return self.batch[-1].item() + 1

    def __cat_dim__(self, key, value=None, *args, **kwargs):
        """Return the dimension along which attribute ``key`` is concatenated.

        Edge index tensors are concatenated along the last dimension, every
        other attribute along dimension 0. ``value`` and any extra arguments
        are accepted (and ignored) so that callers using torch_geometric's
        ``__cat_dim__(key, value, ...)`` convention also work."""
        return -1 if key in ['edge_index', 'negative_edge_index'] else 0

    def __cat_dim__(self, key):
        return -1 if key in ['edge_index', 'negative_edge_index'] else 0


class BatchSubstructContext(Data):
    """A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.

    Specialized batching for substructure/context pairs: the substructure
    graphs and the context graphs are each merged into their own
    disconnected graph, with index attributes offset per graph family."""

    def __init__(self, batch=None, **kwargs):
        super(BatchSubstructContext, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        """Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.

        Only elements that have an ``x_context`` attribute are batched; the
        others are skipped and do not consume a graph id. Substructure index
        attributes are offset by the number of substructure nodes of the
        preceding kept elements, context index attributes by the number of
        context nodes.
        The assignment vector :obj:`batch` is created on the fly."""
        substruct_keys = ['center_substruct_idx', 'edge_attr_substruct',
                          'edge_index_substruct', 'x_substruct']
        context_keys = ['overlap_context_substruct_idx', 'edge_attr_context',
                        'edge_index_context', 'x_context']
        keys = substruct_keys + context_keys

        batch = BatchSubstructContext()

        for key in keys:
            batch[key] = []

        batch.batch = []
        batch.batch_overlapped_context = []
        batch.overlapped_context_size = []

        cumsum_substruct = 0
        cumsum_context = 0

        # Elements without a context graph are dropped; ids stay contiguous.
        kept = [data for data in data_list if hasattr(data, 'x_context')]

        for i, data in enumerate(kept):
            num_nodes = data.num_nodes
            num_nodes_substruct = len(data.x_substruct)
            num_nodes_context = len(data.x_context)
            num_overlap = len(data.overlap_context_substruct_idx)

            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))
            batch.batch_overlapped_context.append(
                torch.full((num_overlap,), i, dtype=torch.long))
            batch.overlapped_context_size.append(num_overlap)

            # batching for the substructure graph
            for key in substruct_keys:
                item = data[key]
                item = item + cumsum_substruct if batch.cumsum(key, item) else item
                batch[key].append(item)

            # batching for the context graph
            for key in context_keys:
                item = data[key]
                item = item + cumsum_context if batch.cumsum(key, item) else item
                batch[key].append(item)

            cumsum_substruct += num_nodes_substruct
            cumsum_context += num_nodes_context

        for key in keys:
            batch[key] = torch.cat(batch[key], dim=batch.__cat_dim__(key))
        batch.batch = torch.cat(batch.batch, dim=-1)
        batch.batch_overlapped_context = torch.cat(batch.batch_overlapped_context, dim=-1)
        batch.overlapped_context_size = torch.LongTensor(batch.overlapped_context_size)

        return batch.contiguous()

    def __cat_dim__(self, key, value=None, *args, **kwargs):
        """Return the dimension along which attribute ``key`` is concatenated.

        Edge index tensors are concatenated along the last dimension, every
        other attribute along dimension 0. ``value`` and any extra arguments
        are accepted (and ignored) so that callers using torch_geometric's
        ``__cat_dim__(key, value, ...)`` convention also work."""
        return -1 if key in ['edge_index', 'edge_index_substruct', 'edge_index_context'] else 0

    def cumsum(self, key, item):
        """If :obj:`True`, the attribute :obj:`key` with content :obj:`item`
        should be added up cumulatively before concatenated together.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute. """
        return key in ['edge_index', 'edge_index_substruct',
                       'edge_index_context',
                       'overlap_context_substruct_idx',
                       'center_substruct_idx']

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        return self.batch[-1].item() + 1


class BatchSubstructContext3D(Data):
    """A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.

    Specialized batching for substructure/context pairs that additionally
    carries the main graph (``x``, ``edge_attr``, ``edge_index``) and its
    ``positions``."""

    def __init__(self, batch=None, **kwargs):
        # The original passed BatchSubstructContext here, which raises
        # TypeError because self is not an instance of that class.
        super(BatchSubstructContext3D, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        """Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.

        Only elements that have an ``x_context`` attribute are batched; the
        others are skipped and do not consume a graph id. Index attributes of
        the main, substructure and context graphs are each offset by the node
        count of the preceding kept elements in their own graph family.
        The assignment vector :obj:`batch` is created on the fly."""
        substruct_keys = ['center_substruct_idx', 'edge_attr_substruct',
                          'edge_index_substruct', 'x_substruct']
        context_keys = ['overlap_context_substruct_idx', 'edge_attr_context',
                        'edge_index_context', 'x_context']
        main_keys = ['positions', 'x', 'edge_attr', 'edge_index']
        keys = substruct_keys + context_keys + main_keys

        batch = BatchSubstructContext3D()

        for key in keys:
            batch[key] = []

        batch.batch = []
        batch.batch_overlapped_context = []
        batch.overlapped_context_size = []

        cumsum_main = 0
        cumsum_substruct = 0
        cumsum_context = 0

        # Elements without a context graph are dropped; ids stay contiguous.
        kept = [data for data in data_list if hasattr(data, 'x_context')]

        for i, data in enumerate(kept):
            num_nodes = data.num_nodes
            num_nodes_substruct = len(data.x_substruct)
            num_nodes_context = len(data.x_context)
            num_overlap = len(data.overlap_context_substruct_idx)

            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))
            batch.batch_overlapped_context.append(
                torch.full((num_overlap,), i, dtype=torch.long))
            batch.overlapped_context_size.append(num_overlap)

            # batching for the main graph (only edge_index needs an offset)
            for key in main_keys:
                item = data[key]
                item = item + cumsum_main if batch.cumsum(key, item) else item
                batch[key].append(item)

            # batching for the substructure graph
            for key in substruct_keys:
                item = data[key]
                item = item + cumsum_substruct if batch.cumsum(key, item) else item
                batch[key].append(item)

            # batching for the context graph
            for key in context_keys:
                item = data[key]
                item = item + cumsum_context if batch.cumsum(key, item) else item
                batch[key].append(item)

            cumsum_main += num_nodes
            cumsum_substruct += num_nodes_substruct
            cumsum_context += num_nodes_context

        for key in keys:
            batch[key] = torch.cat(batch[key], dim=batch.__cat_dim__(key))
        batch.batch = torch.cat(batch.batch, dim=-1)
        batch.batch_overlapped_context = torch.cat(batch.batch_overlapped_context, dim=-1)
        batch.overlapped_context_size = torch.LongTensor(batch.overlapped_context_size)

        return batch.contiguous()

    def __cat_dim__(self, key, value=None, *args, **kwargs):
        """Return the dimension along which attribute ``key`` is concatenated.

        Edge index tensors are concatenated along the last dimension, every
        other attribute along dimension 0. ``value`` and any extra arguments
        are accepted (and ignored) so that callers using torch_geometric's
        ``__cat_dim__(key, value, ...)`` convention also work."""
        return -1 if key in ['edge_index', 'edge_index_substruct', 'edge_index_context'] else 0

    def cumsum(self, key, item):
        """If :obj:`True`, the attribute :obj:`key` with content :obj:`item`
        should be added up cumulatively before concatenated together.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute. """
        return key in ['edge_index', 'edge_index_substruct',
                       'edge_index_context',
                       'overlap_context_substruct_idx',
                       'center_substruct_idx']

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        return self.batch[-1].item() + 1
