import torch
import numpy as np


def preprocess_adj(edges: torch.Tensor, tot_nodes: int):
    """Build a sparse binary adjacency matrix augmented with self-loops.

    Input:
        edges: index tensor with two rows, (row, col) pairs of the edges.
        tot_nodes: number of nodes N; the result is N x N.
    Output:
        an (uncoalesced) sparse COO float32 tensor of shape [N, N] holding a 1 at
        every edge position and on the diagonal. Repeated index pairs (e.g. an edge
        list that already contains self-loops) add up once coalesced/densified.
    """
    device = edges.device
    n = int(tot_nodes)  # accept numpy integers as well as Python ints
    # self-loops (i, i) for every node, created directly on the edges' device
    loop = torch.arange(n, dtype=torch.long, device=device)
    tot_edges = torch.cat([edges, torch.stack([loop, loop])], dim=1)
    # fill every entry with 1
    tot_elems = torch.ones(tot_edges.shape[1], dtype=torch.float32, device=device)
    # torch.sparse_coo_tensor works for any device; the legacy
    # torch.sparse.FloatTensor constructor rejects CUDA tensors.
    return torch.sparse_coo_tensor(tot_edges, tot_elems, (n, n))


def normalize_sym(bin_mat: torch.Tensor):
    """A -> D^{-.5}AD^{-.5}, with D = diag(row sums of A).

    A is a binary square matrix, either a sparse COO tensor or a dense tensor.
    The result has the same layout as the input; sparse results are coalesced and
    store no explicit zeros. Rows that sum to zero get a scaling factor of 0
    instead of inf, so they stay all-zero.
    """
    if bin_mat.is_sparse:
        mat = bin_mat.coalesce()
        # to_dense() yields one degree per row: torch.sparse.sum omits rows that
        # have no stored entries, which would misalign degrees with row indices.
        degrees = torch.sparse.sum(mat, 1).to_dense()
    else:
        mat = bin_mat
        degrees = mat.sum(1)
    dinv = degrees.pow(-.5).flatten()
    dinv[torch.isinf(dinv)] = 0.

    if not mat.is_sparse:
        # broadcasting instead of diag(dinv) @ A @ diag(dinv): no diagonal matrix
        # and no N x N x N matmul work
        return dinv.unsqueeze(1) * mat * dinv.unsqueeze(0)

    # scale only the stored entries: value(i, j) * dinv[i] * dinv[j]
    idx = mat.indices()
    vals = mat.values() * dinv[idx[0]] * dinv[idx[1]]
    keep = vals != 0  # store only non-zero entries
    return torch.sparse_coo_tensor(idx[:, keep], vals[keep], mat.shape).coalesce()


def normalize_row(unnorm_mat: torch.Tensor):
    """A -> D^{-1}A, with D = diag(row sums of A), i.e. every row is divided by its sum.

    Accepts a 2-D sparse COO tensor or a dense tensor; dense inputs may carry
    leading batch dimensions (rows are taken along the last dimension). The result
    has the same layout as the input; sparse results are coalesced and store no
    explicit zeros. Rows that sum to zero are left as all zeros instead of inf.
    """
    if not unnorm_mat.is_sparse:
        dinv = unnorm_mat.sum(-1).pow(-1.)
        dinv[torch.isinf(dinv)] = 0.
        # broadcasting instead of diag(dinv) @ A avoids an N x N diagonal matrix
        return dinv.unsqueeze(-1) * unnorm_mat

    mat = unnorm_mat.coalesce()
    # to_dense() yields one degree per row: torch.sparse.sum omits rows that
    # have no stored entries, which would misalign degrees with row indices.
    dinv = torch.sparse.sum(mat, 1).to_dense().pow(-1.).flatten()
    dinv[torch.isinf(dinv)] = 0.
    idx = mat.indices()
    vals = mat.values() * dinv[idx[0]]
    keep = vals != 0  # store only non-zero entries
    return torch.sparse_coo_tensor(idx[:, keep], vals[keep], mat.shape).coalesce()


def preprocess_feat(feat, phi, mu=3e-4, normalize=False):
    """Create feat = cat(mu * s * phi, feat) along the last dimension, for task-learning.

    s = feat.sum(-1) / phi.sum(-1), so each (s * phi) row sums to the same total
    as the matching feat row; mu then shrinks that block. Rows whose phi sums to
    zero get s = 0 (an all-zero phi block) rather than inf/NaN.
    If ``normalize`` is True the concatenated result is row-normalized with
    normalize_row.
    """
    feat_sum = feat.sum(-1)
    phi_sum = phi.sum(-1)
    zero = phi_sum == 0
    # Divide by a safe denominator so no inf is ever formed; selecting after an
    # unsafe division would still leak NaN gradients from the discarded branch.
    safe_sum = torch.where(zero, torch.ones_like(phi_sum), phi_sum)
    mag_div = (feat_sum / safe_sum).masked_fill(zero, 0.)
    out_fts = torch.cat([mu * mag_div.unsqueeze(-1) * phi, feat], dim=-1)
    if normalize:
        out_fts = normalize_row(out_fts)
    return out_fts


class graph_batch(object):
    """Pack a list of graphs into one block-diagonal batch for graph classification.

    Each graph is expected to expose ``g`` (``len(g)`` is its node count), ``label``,
    ``node_features`` (a tensor with nodes on dim 0) and ``edge_mat`` (a two-row
    node-index tensor) -- presumably populated S2VGraph instances; confirm against
    the data loader.
    """

    # Batches with more nodes than this are normalized graph by graph
    # (see _get_normalized_batch_graph) to keep each normalize_sym call small.
    _DENSE_NORM_MAX_NODES = 4000

    def __init__(self, batch_graph: list, device='cuda:0'):
        """
        Input:
            batch_graph [list]: a list of graphs
            device: device that every batch tensor is moved to.
        """
        super(graph_batch, self).__init__()
        self.device = device
        self._graph_batch = batch_graph

        self._get_batch_labels()
        self._get_batch_fts()
        self._count_nodes_n_edges()
        self._get_whole_batch_graph()
        self._get_normalized_batch_graph()

    def _get_batch_labels(self):
        """
        Collect the graph-level labels, one per graph: [B],
        get tensor attributes: self.labels.
        """
        labels = torch.LongTensor([graph.label for graph in self._graph_batch])
        self.labels = labels.to(self.device)

    def _get_batch_fts(self):
        """
        Concatenate all node features on the node dimension: [N1*K, N2*K, ...] -> [(N1+N2+...), K],
        get tensor attributes: self.all_fts.
        """
        all_fts = torch.cat([graph.node_features for graph in self._graph_batch], 0)
        self.all_fts = all_fts.to(self.device)

    def _count_nodes_n_edges(self):
        """
        Count the nodes of each graph in the batch,
        get non-tensor attributes: self.num_nodes_ (nodes per graph),
        self.start_idx_ (offset of each graph's first node, with the total appended),
        self.tot_nodes (total nodes),
        get tensor attribute: self.kl_weights.
        """
        # nodes statistics
        num_nodes_ = np.array([len(graph.g) for graph in self._graph_batch])
        start_idx_ = np.array([0] + np.cumsum(num_nodes_).tolist())
        self.num_nodes_ = num_nodes_
        self.start_idx_ = start_idx_
        self.tot_nodes = int(start_idx_[-1])

        # kl weights: 1-D tensor of length tot_nodes, the average weights for each row of KL mat
        # n1 * [1/n1] + n2 * [1/n2] + ... + nB * [1/nB]
        # int(n) matters: multiplying a list by a numpy integer broadcasts the
        # list (one element) instead of repeating it n times.
        counts = [int(n) for n in num_nodes_]
        kl_weights = torch.cat([torch.full((n,), 1. / max(n, 1)) for n in counts])
        self.kl_weights = kl_weights.to(self.device)

    def _get_whole_batch_graph(self):
        """
        Create (sparse) block diagonal adj, also store the original adjs in the batch in a list
        get tensor attributes: self.bin_adj, self.bin_adj_
        """
        # shift each graph's node indices by its offset to place it on the diagonal
        tot_edge_mat = torch.cat(
            [graph.edge_mat + int(offset)
             for graph, offset in zip(self._graph_batch, self.start_idx_)],
            1,
        )

        bin_adj = preprocess_adj(tot_edge_mat, self.tot_nodes)
        bin_adj_ = [preprocess_adj(graph.edge_mat, len(graph.g)).to(self.device)
                    for graph in self._graph_batch]
        self.bin_adj = bin_adj.to(self.device)
        self.bin_adj_ = bin_adj_

    def _get_normalized_batch_graph(self):
        """
        Create (sparse) normalized block diagonal adj,
        get tensor attributes: self.norm_adj.
        """
        if self.tot_nodes <= self._DENSE_NORM_MAX_NODES:
            norm_adj = normalize_sym(self.bin_adj)
        else:
            # Normalize each graph separately and stitch the values onto the
            # block-diagonal index set. Coalesced indices are sorted row-major, and
            # the blocks occupy increasing row ranges, so the concatenation of the
            # per-graph coalesced values lines up with the whole-batch indices.
            indices = self.bin_adj.coalesce().indices()
            values = torch.cat([normalize_sym(adj).coalesce().values() for adj in self.bin_adj_], dim=0)
            norm_adj = torch.sparse_coo_tensor(indices, values, [self.tot_nodes, self.tot_nodes])

        # self.norm_adj is on the same device as self.bin_adj and self.bin_adj_
        self.norm_adj = norm_adj



class node_batch(object):
    def __init__(self, adj, features, labels, ind_dict, device='cuda:0'):
        """
        Bundle a single graph for node classification and graph reconstruction.

        Input:
            adj     : [torch.Tensor][sparse] binary sparse adj (dense is also accepted),
            features: [torch.Tensor] node features,
            labels  : [torch.Tensor] node labels,
            ind_dict: [dict] a dictionary of tensors, keys() = ('train', 'val', 'test'),
            device  : [torch.device] set to 'cuda:0' by default.
        Every tensor-valued attribute is moved to `device` at the end of construction.
        """
        n = adj.shape[0]
        n_cells = n ** 2

        # node-classification inputs
        self.bin_adj = adj
        self.norm_adj = normalize_sym(adj)
        self.features = features
        self.labels = labels
        self.idx_train, self.idx_val, self.idx_test = (
            ind_dict[split] for split in ("train", "val", "test")
        )

        # graph-reconstruction weights: positives = sum of adj entries (the edge
        # count for a binary adj), negatives = every remaining cell of the n x n grid
        pos = torch.sparse.sum(adj) if adj.is_sparse else adj.sum()
        neg = n_cells - pos
        self.pos_weight = neg / pos
        self.norm = n_cells / (2. * neg)

        # scalings consumed by EPM_VAE_Loss
        self.pos_lab_multiple = .5 * pos.pow(-1.)
        self.neg_lab_multiple = .5 * neg.pow(-1.)
        self.kl_weights = torch.ones(n, device=adj.device) * (1 / n)

        self.device = device
        self._device_shuttle()

    def _device_shuttle(self):
        """Move every tensor-valued instance attribute to self.device."""
        for name, value in list(vars(self).items()):
            if torch.is_tensor(value):
                setattr(self, name, value.to(self.device))


class S2VGraph(object):
    def __init__(self, g, label, node_tags=None, node_features=None):
        '''
            g: a networkx graph
            label: an integer graph label
            node_tags: a list of integer node tags
            node_features: a torch float tensor, one-hot representation of the tag that is used as input to neural nets;
                stored as given, or left as the placeholder 0 when omitted (to be filled in later)
            edge_mat: a torch long tensor, contain edge list, will be used to create torch sparse tensor
            neighbors: list of neighbors (without self-loop)
        '''
        self.label = label
        self.g = g
        self.node_tags = node_tags
        self.neighbors = []
        # keep the 0 placeholder only when no features were supplied; previously
        # the documented node_features argument was silently discarded
        self.node_features = 0 if node_features is None else node_features
        self.edge_mat = 0

        self.max_neighbor = 0









