import torch
import numpy as np
import networkx as nx
import scipy.sparse as sp


def get_ball(G, v, r):
    """Return the r-hop neighborhood of v in G as a relabeled induced subgraph.

    Args:
        G: (nx.Graph) Graph.
        v: (int) Center node.
        r: (int) Radius, i.e. the maximum hop distance from v.

    Returns:
        (nx.Graph, list) The induced subgraph on all nodes within r hops of v,
        relabeled to consecutive integers, together with the list of original
        node labels taken in the subgraph's node-iteration order (the same
        order convert_node_labels_to_integers uses by default).
    """
    hop_distances = nx.single_source_shortest_path_length(G, v, cutoff=r)
    induced = nx.subgraph(G, [*hop_distances])
    original_labels = [*induced.nodes()]
    return nx.relabel.convert_node_labels_to_integers(induced), original_labels


def _get_random_ball(G, v, r, b):
    """Get a random ball of radius r in G about v with branching factor b.

    Starting from v, each node of the current layer draws b neighbors
    uniformly *with replacement*; the drawn nodes form the next layer.
    This is repeated r times. The function returns the set of visited
    vertices rather than a subgraph.

    Nodes without neighbors (isolated nodes, or sinks in a directed graph)
    simply contribute nothing to the next layer instead of crashing
    ``np.random.choice`` on an empty population.

    Args:
        G: (nx.Graph) Graph; only ``G.neighbors(u)`` is used.
        v: (int) Node index of the center.
        r: (int) Radius (number of expansion rounds).
        b: (int) Branching factor (neighbors sampled per node).

    Returns:
        (set) All vertices visited, including v.
    """
    visited = {v}
    frontier = [v]
    for _ in range(r):
        next_frontier = []
        for u in frontier:
            nbrs = list(G.neighbors(u))
            if not nbrs:
                # np.random.choice raises on an empty population.
                continue
            next_frontier.extend(np.random.choice(nbrs, b))
        if not next_frontier:
            break
        # Build the result incrementally; sum(list_of_lists, []) is quadratic.
        visited.update(next_frontier)
        frontier = next_frontier
    return visited


def get_random_balls(G, r, b, k):
    """Get k random balls and return the induced subgraph of their union.

    Args:
        G: (nx.Graph) Graph.
        r: (int) Radius.
        b: (int) Branching factor.
        k: (int) Number of random balls to sample (roots drawn with
            replacement).

    Returns:
        (nx.Graph, list) The induced subgraph relabeled to consecutive
        integers, and the list of original node labels in the subgraph's
        node-iteration order.
    """
    # Materialize the nodes and sample *indices*. A networkx NodeView is a
    # Mapping, so handing it directly to np.random.choice relies on numpy's
    # sequence coercion. Sampling indices also keeps the original label
    # objects instead of coercing them to a single numpy dtype.
    nodes = list(G.nodes())
    root_vertices = [nodes[i] for i in np.random.choice(len(nodes), k)]
    union = set()
    for rv in root_vertices:
        union.update(_get_random_ball(G, rv, r, b))
    balls = nx.subgraph(G, union)
    mapping = list(balls.nodes())
    balls = nx.relabel.convert_node_labels_to_integers(balls)
    return balls, mapping


class AdjRandomBallsSampler(object):
    """Randomly sample a subgraph from random balls
    and return a sparse adjacency matrix.

    Args:
        r: (int) Radius of a random ball.
        b: (int) Branching factor.
        k: (int) Number of random balls to sample.
    """
    def __init__(self, r, b, k):
        self.r = r
        self.b = b
        self.k = k

    @staticmethod
    def _to_coo(graph):
        """Return the adjacency matrix of ``graph`` as a ``scipy.sparse.coo_matrix``.

        ``nx.to_scipy_sparse_matrix`` was removed in networkx 3.0. Prefer
        ``nx.to_scipy_sparse_array`` (networkx >= 2.7) and wrap its result
        in ``coo_matrix`` so callers keep receiving the same matrix type.
        """
        to_array = getattr(nx, 'to_scipy_sparse_array', None)
        if to_array is None:
            return nx.to_scipy_sparse_matrix(graph, format='coo')
        return sp.coo_matrix(to_array(graph, format='coo'))

    def __call__(self, data):
        """Process a single instance of data.

        Args:
            data: (tuple) - graph (nx.Graph), feature (indexable by a list of
                node labels, e.g. np.array), label.

        Returns:
            (coo_matrix, torch.Tensor, label) Sparse adjacency of the sampled
            subgraph, float features of its nodes, and the unchanged label.
        """
        graph, feature, label = data
        ball, mapping = get_random_balls(graph, self.r, self.b, self.k)
        ball_feature = torch.Tensor(feature[mapping]).float()
        ball_adj = self._to_coo(ball)
        return ball_adj, ball_feature, label


class BallSampler(object):
    """Randomly sample a subgraph (a ball about a uniformly chosen node).

    Args:
        r: (int) Maximum radius of subgraph.
    """
    def __init__(self, r):
        assert r >= 0, "Radius needs to be non-negative."
        self.r = r

    @staticmethod
    def _label_to_tensor(label):
        """Convert ``label`` to a float tensor.

        ``torch.Tensor(int)`` treats the int as a *size* and returns an
        uninitialized tensor of that length. ``torch.as_tensor`` converts
        the value itself, so it handles scalars as well as lists and arrays.
        """
        return torch.as_tensor(label, dtype=torch.float)

    def __call__(self, data):
        """Process a single instance of data.

        Args:
            data: (tuple) - graph (nx.Graph), feature (np.array), label (int)

        Returns:
            (nx.Graph, torch.Tensor, torch.Tensor) The sampled ball relabeled
            to consecutive integers, float features of its nodes, and the
            label as a float tensor.
        """
        graph, feature, label = data
        # Assumes nodes are labeled 0..n-1 (the sampled v is used as a label).
        v = np.random.randint(0, graph.number_of_nodes())
        ball, mapping = get_ball(graph, v, self.r)
        ball_feature = torch.Tensor(feature[mapping]).float()
        label = self._label_to_tensor(label)
        return ball, ball_feature, label


class AdjSampler(object):
    """Randomly sample a subgraph and return a sparse adjacency matrix.

    Args:
        r: (int) Maximum radius of subgraph.
    """
    def __init__(self, r):
        assert r >= 0, "Radius needs to be non-negative."
        self.r = r

    @staticmethod
    def _to_coo(graph):
        """Return the adjacency matrix of ``graph`` as a ``scipy.sparse.coo_matrix``.

        ``nx.to_scipy_sparse_matrix`` was removed in networkx 3.0. Prefer
        ``nx.to_scipy_sparse_array`` (networkx >= 2.7) and wrap its result
        in ``coo_matrix`` so callers keep receiving the same matrix type.
        """
        to_array = getattr(nx, 'to_scipy_sparse_array', None)
        if to_array is None:
            return nx.to_scipy_sparse_matrix(graph, format='coo')
        return sp.coo_matrix(to_array(graph, format='coo'))

    def __call__(self, data):
        """Process a single instance of data.

        Args:
            data: (tuple) - graph (nx.Graph), feature (np.array), label.

        Returns:
            (coo_matrix, torch.Tensor, label) Sparse adjacency of the sampled
            ball, float features of its nodes, and the unchanged label.
        """
        graph, feature, label = data
        # Assumes nodes are labeled 0..n-1 (the sampled v is used as a label).
        v = np.random.randint(0, graph.number_of_nodes())
        ball, mapping = get_ball(graph, v, self.r)
        ball_feature = torch.Tensor(feature[mapping]).float()
        ball_adj = self._to_coo(ball)
        return ball_adj, ball_feature, label


class MultiAdjSampler(object):
    """Sample multiple subgraphs and return their sparse adjacency matrices.

    The matrices are returned as a list, one per sampled ball; assembling
    them into a block-diagonal matrix (if needed) is presumably done by the
    caller. TODO confirm.

    Args:
        r: (int) Maximum radius of subgraph.
        K: (int) Number of subgraphs to sample (centers drawn with
            replacement).
    """
    def __init__(self, r, K):
        assert r >= 0, "Radius needs to be non-negative."
        assert K >= 1, "Sample at least one subgraph."
        self.r = r
        self.K = K

    @staticmethod
    def _to_coo(graph):
        """Return the adjacency matrix of ``graph`` as a ``scipy.sparse.coo_matrix``.

        ``nx.to_scipy_sparse_matrix`` was removed in networkx 3.0. Prefer
        ``nx.to_scipy_sparse_array`` (networkx >= 2.7) and wrap its result
        in ``coo_matrix`` so callers keep receiving the same matrix type.
        """
        to_array = getattr(nx, 'to_scipy_sparse_array', None)
        if to_array is None:
            return nx.to_scipy_sparse_matrix(graph, format='coo')
        return sp.coo_matrix(to_array(graph, format='coo'))

    def __call__(self, data):
        """Process a single instance of data.

        Args:
            data: (tuple) - graph (nx.Graph), feature (np.array), label.

        Returns:
            (list, list, label) One coo_matrix per sampled ball, the matching
            raw feature slices (not converted to tensors), and the unchanged
            label.
        """
        graph, feature, label = data
        # Assumes nodes are labeled 0..n-1 (sampled values are used as labels).
        nodes = np.random.choice(range(graph.number_of_nodes()), self.K)
        balls, mappings = zip(*[get_ball(graph, v, self.r) for v in nodes])
        ball_features = [feature[m] for m in mappings]
        ball_adjs = [self._to_coo(b) for b in balls]
        return ball_adjs, ball_features, label
