import numpy as np
import torch

class GraphBatchCollator(object):
    """Collate a list of dataset items (pairs or triplets of graphs) into batches.

    Each graph is a dict with the keys read below:

    * ``'attr_matrix'`` -- node attribute tensor whose first dimension is the
      number of nodes;
    * ``'adj_idx'`` -- edge index tensor of shape ``2 x num_edges``;
    * ``'edge_attr_matrix'`` -- edge attribute tensor (concatenated along
      dim 0), or ``None``.

    Two batching modes are supported:

    * sparse (``sparse_batching=True``): node attributes of all graphs are
      concatenated along dim 0 and the edge indices of each graph are shifted
      by the total node count of the graphs before it;
    * dense (``sparse_batching=False``): node attributes are zero-padded into
      a ``(batch, max_nnodes, ...)`` tensor and the edge indices of the i-th
      graph are shifted by ``i * max_nnodes``.

    Args:
        sparse_batching: Select sparse (True) or dense (False) batching.
        triplets: Items are ``(graph1, graph2, graph3)`` triplets.
        gen_match_label: Items are ``(graph1, graph2, (l1, l2))`` and a node
            matching label tensor is produced. Requires dense batching and
            ``triplets=False``.
    """

    def __init__(self, sparse_batching, triplets=False, gen_match_label=False):
        if gen_match_label:
            # The (batch, max_nnodes, max_nnodes) label tensor indexes nodes
            # per graph, which matches the dense padded layout; only pairs
            # are supported.
            assert not sparse_batching and not triplets, \
                'gen_match_label requires dense batching and triplets=False'
        self.sparse_batching = sparse_batching
        self.triplets = triplets
        self.gen_match_label = gen_match_label

    def __call__(self, data_batch):
        """Collate ``data_batch`` into batched graph dicts.

        Returns:
            * triplets: ``(batch1, batch2, batch3)``;
            * gen_match_label: ``(batch1, batch2, labels)`` where ``labels`` is
              a float tensor of shape ``(batch, max_nnodes, max_nnodes)`` with
              ``labels[i, a, b] = 1`` when ``l1[a] == l2[b]`` for item ``i``;
            * otherwise: ``(batch1, batch2, dists)`` where ``dists`` is a 1-D
              float tensor holding the third element of each item.

            Each ``batchN`` is the dict produced by ``_collate_graphs``; all
            returned batches share one padding width ``max_nnodes``.
        """
        if self.triplets:
            graphs1, graphs2, graphs3 = [], [], []
            for graph1, graph2, graph3 in data_batch:
                graphs1.append(graph1)
                graphs2.append(graph2)
                graphs3.append(graph3)
            # All three sides are padded to one common width, so graph3 must
            # take part in the maximum; otherwise a larger graph3 does not fit
            # the padded tensor and its dense edge offsets overlap.
            max_nnodes = self._max_num_nodes(graphs1, graphs2, graphs3)
            return (self._collate_graphs(graphs1, max_nnodes),
                    self._collate_graphs(graphs2, max_nnodes),
                    self._collate_graphs(graphs3, max_nnodes))

        graphs1, graphs2, extras = [], [], []
        for graph1, graph2, extra in data_batch:
            graphs1.append(graph1)
            graphs2.append(graph2)
            extras.append(extra)
        max_nnodes = self._max_num_nodes(graphs1, graphs2)
        batch1 = self._collate_graphs(graphs1, max_nnodes)
        batch2 = self._collate_graphs(graphs2, max_nnodes)

        if self.gen_match_label:
            return batch1, batch2, self._match_labels(extras, max_nnodes)

        dists = torch.empty(len(extras))
        for i, dist in enumerate(extras):
            dists[i] = dist
        return batch1, batch2, dists

    @staticmethod
    def _max_num_nodes(*graph_lists):
        """Return the largest node count over all graphs in ``graph_lists`` (0 if none)."""
        return max((graph['attr_matrix'].shape[0]
                    for graphs in graph_lists for graph in graphs), default=0)

    @staticmethod
    def _match_labels(label_pairs, max_nnodes):
        """Build the ``(batch, max_nnodes, max_nnodes)`` node matching label tensor.

        ``labels[i, a, b]`` is 1 when ``l1[a] == l2[b]`` for the i-th
        ``(l1, l2)`` pair, and 0 otherwise.
        """
        # NOTE(review): l1/l2 look like 1-D torch tensors of node identifiers;
        # `.item()` raises if a value occurs more than once in l2 -- confirm
        # values are unique.
        labels = torch.zeros((len(label_pairs), max_nnodes, max_nnodes))
        for i, (l1, l2) in enumerate(label_pairs):
            for idx1, value in enumerate(l1):
                matches = (l2 == value).nonzero()
                if len(matches) > 0:
                    labels[i, idx1, matches.item()] = 1.
        return labels

    def _collate_graphs(self, graph_batch, max_nnodes):
        """Merge a list of graph dicts into one batched graph dict.

        Returns a dict with ``'adj_idx'`` (offset edge indices), ``'attr_matrix'``
        (concatenated or padded node attributes, depending on the batching mode),
        ``'edge_attr_matrix'`` (edge attributes concatenated along dim 0, or an
        empty tensor when the first graph has none) and ``'num_nodes'`` (a
        LongTensor of per-graph node counts).
        """
        num_nodes = torch.LongTensor([graph['attr_matrix'].shape[0] for graph in graph_batch])
        if graph_batch[0]['edge_attr_matrix'] is None:
            edge_attr_matrix = torch.zeros(0)
        else:
            edge_attr_matrix = torch.cat([graph['edge_attr_matrix'] for graph in graph_batch], dim=0)
        return {
            'adj_idx': self._collate_adjidx([graph['adj_idx'] for graph in graph_batch], num_nodes, max_nnodes),
            'attr_matrix': self._pad_matrices([graph['attr_matrix'] for graph in graph_batch], max_nnodes),
            'edge_attr_matrix': edge_attr_matrix,
            'num_nodes': num_nodes,
        }

    def _pad_matrices(self, mats, max_len):
        """Concatenate (sparse) or zero-pad and stack (dense) node attribute tensors.

        In dense mode the result has shape ``(len(mats), max_len) + mats[0].shape[1:]``;
        every tensor must have at most ``max_len`` rows.
        """
        if self.sparse_batching:
            return torch.cat(mats, dim=0)
        out_dims = (len(mats), max_len) + tuple(mats[0].shape[1:])
        out_tensor = mats[0].new_zeros(*out_dims)
        for i, tensor in enumerate(mats):
            out_tensor[i, :tensor.size(0), ...] = tensor
        return out_tensor

    def _collate_adjidx(self, adj_idxs, num_nodes, max_nnodes):
        """Dispatch edge-index collation to the sparse or dense variant."""
        if self.sparse_batching:
            return self._collate_adjidx_sparse(adj_idxs, num_nodes)
        return self._collate_adjidx_dense(adj_idxs, max_nnodes)

    @staticmethod
    def _repeat_per_edge(offsets, adj_idxs):
        """Repeat ``offsets[i]`` once for every edge (column) of ``adj_idxs[i]``."""
        nedges = torch.tensor([adj_idx.size(1) for adj_idx in adj_idxs],
                              dtype=torch.long, device=offsets.device)
        return torch.repeat_interleave(offsets, nedges)

    def _collate_adjidx_dense(self, adj_idxs, max_nnodes):
        """Concatenate edge indices, shifting the i-th graph by ``i * max_nnodes``.

        The result (``2 x total_edges``) indexes the flattened
        ``(batch * max_nnodes)`` node axis of the padded attribute tensor.
        """
        batch_edge_idx = torch.cat(adj_idxs, dim=1)
        graph_offsets = torch.arange(len(adj_idxs), device=batch_edge_idx.device) * max_nnodes
        return batch_edge_idx + self._repeat_per_edge(graph_offsets, adj_idxs)

    def _collate_adjidx_sparse(self, adj_idxs, num_nodes):
        """Concatenate edge indices, shifting each graph by the node count of its predecessors.

        The result (``2 x total_edges``) indexes the concatenated node tensor.
        """
        batch_edge_idx = torch.cat(adj_idxs, dim=1)
        num_nodes = num_nodes.long().to(batch_edge_idx.device)
        # Exclusive prefix sum: position of each graph's first node in the
        # concatenated node tensor.
        graph_offsets = torch.cumsum(num_nodes, dim=0) - num_nodes
        return batch_edge_idx + self._repeat_per_edge(graph_offsets, adj_idxs)
