import torch
import numpy as np
import scipy.sparse as s
from bsgnn.utils.data import load_pickle
from torch.utils.data import Dataset


class InMemory(Dataset):
    """Map-style dataset that keeps graphs, node features and labels in memory.

    Each item is a ``(graph, feature, label)`` tuple. If a sampler is given,
    it is applied to that tuple and its return value is yielded instead.
    """

    def __init__(self, graphs, node_features, labels, sampler=None):
        """
        Args:
            graphs (list): List of graphs (e.g. networkx.Graph objects).
            node_features (list): List of per-graph node features as np.array(s).
            labels (list): List of labels or regression targets.
            sampler (callable, optional): Called on each ``(graph, feature,
                label)`` tuple in ``__getitem__``; its return value is what
                the dataset yields. ``None`` disables sampling.
        """
        # Store per-graph objects in 1-D object arrays filled element by
        # element. ``np.array(items, dtype=object)`` tries to broadcast nested
        # sequences: equally-shaped feature matrices collapse into a single
        # N-D object array (items then come back with dtype=object), and
        # matrices sharing only their leading dimension raise ValueError.
        self.graphs = self._to_object_array(graphs)
        self.node_features = self._to_object_array(node_features)
        self.labels = np.array(labels)
        self.sampler = sampler

    @staticmethod
    def _to_object_array(items):
        """Return a 1-D object array whose i-th element is ``items[i]`` unchanged."""
        items = list(items)
        arr = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            # Scalar-index assignment stores the object itself, no broadcasting.
            arr[i] = item
        return arr

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(graph, feature, label)`` for ``idx``, passed through the sampler if set."""
        # DataLoader samplers may hand over tensor indices; numpy needs Python ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        data = (self.graphs[idx], self.node_features[idx], self.labels[idx])
        if self.sampler is not None:
            data = self.sampler(data)
        return data


def process_batch(adjs, features):
    """Convert list of adjacency matrices to one block diagonal one.

    Args:
        adjs: (list) A list of square COO sparse adjacency matrices, one per
            graph.
        features: (list) A list of Tensors for each graph's features, in the
            same order as ``adjs``.

    Returns:
        tuple: ``(blk_adj, feat, graph_sizes)`` where ``blk_adj`` is a float32
        torch sparse COO tensor holding the block-diagonal adjacency, ``feat``
        is the features concatenated along dim 0, and ``graph_sizes`` lists
        ``adj.shape[0]`` for each input matrix.
    """
    graph_sizes = [a.shape[0] for a in adjs]
    blkadj_dim = sum(graph_sizes)
    # Ask for COO explicitly: the .row/.col/.data attributes read below only
    # exist on COO matrices, so don't depend on block_diag's default format.
    blk_adj = s.block_diag(adjs, format="coo")
    feat = torch.cat(features, dim=0)
    indices = np.stack((blk_adj.row, blk_adj.col))
    blk_adj = torch.sparse_coo_tensor(indices, blk_adj.data,
                                      (blkadj_dim, blkadj_dim))
    return blk_adj.float(), feat, graph_sizes


def collate_single_sample(data):
    """Collate function for a batch of simple sample data.

    Transposes a list of per-item tuples into a tuple of per-field lists,
    e.g. ``[(g0, f0, y0), (g1, f1, y1)] -> ([g0, g1], [f0, f1], [y0, y1])``.
    The result is materialized (not a one-shot ``zip`` iterator), so it can be
    unpacked, indexed or iterated more than once.

    Args:
        data: (list) Contains (graph, feature, label) tuples.

    Returns:
        tuple: One list per tuple field; ``()`` for an empty batch.
    """
    return tuple(list(column) for column in zip(*data))


def collate_multiple_sample(data):
    """Collate function for a batch of multi-samples data.

    Transposes a list of per-item tuples into a tuple of per-field lists. The
    result is materialized (not a one-shot ``zip`` iterator), so it can be
    unpacked, indexed or iterated more than once.

    Args:
        data: (list) Contains (graph, feature, label) tuples.
            NOTE(review): for multi-sample batches each field presumably
            holds several samples per item; this function does not inspect
            them either way — confirm against the sampler.

    Returns:
        tuple: One list per tuple field; ``()`` for an empty batch.
    """
    return tuple(list(column) for column in zip(*data))
