from typing import Optional, Callable

import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import barabasi_albert_graph


def house():
    r"""Builds the 5-node "house" motif.

    Nodes :obj:`0-1-2-3` form a 4-cycle and node :obj:`4` is attached to both
    node :obj:`0` and node :obj:`1`. Every undirected edge is stored in both
    directions, giving 12 directed edges in total.

    Returns:
        A tuple :obj:`(edge_index, label)` where :obj:`edge_index` is an
        integer tensor of shape :obj:`[2, 12]` (row 0 holds source nodes,
        row 1 holds target nodes) and :obj:`label` is an integer tensor of
        shape :obj:`[5]` with per-node labels :obj:`[1, 1, 2, 2, 3]`.
    """
    # Directed edges, written as two parallel rows (source -> target).
    sources = [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4]
    targets = [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1]

    # Nodes 0/1 (degree 3) share label 1, nodes 2/3 share label 2, and the
    # node attached to both 0 and 1 (node 4) gets label 3.
    node_classes = [1, 1, 2, 2, 3]

    return torch.tensor([sources, targets]), torch.tensor(node_classes)

class BAShapes(InMemoryDataset):
    r"""The BA-Shapes dataset from the `"GNNExplainer: Generating Explanations
    for Graph Neural Networks" <https://arxiv.org/pdf/1903.03894.pdf>`_ paper,
    containing a Barabasi-Albert (BA) graph with 300 nodes and a set of 80
    "house"-structured graphs connected to it.

    The dataset holds a single graph with :obj:`300 + 80 * 5 = 700` nodes and
    the following attributes:

    * :obj:`x`: all-ones node features of shape :obj:`[700, 10]`.
    * :obj:`y`: node labels; :obj:`0` for BA graph nodes and the labels
      returned by :meth:`house` for house nodes.
    * :obj:`edge_label`: :obj:`1` for edges inside a house and :obj:`0` for
      BA graph edges and for the edges linking a house to the BA graph.
    * :obj:`expl_mask`: boolean node mask that is :obj:`True` for every fifth
      node starting at index :obj:`400`.

    Args:
        connection_distribution (string, optional): Specifies how the houses
            and the BA graph get connected. Valid inputs are :obj:`"random"`
            (random BA graph nodes are selected for connection to the houses),
            and :obj:`"uniform"` (uniformly distributed BA graph nodes are
            selected for connection to the houses). (default: :obj:`"random"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
    """
    def __init__(self, connection_distribution: str = "random",
                 transform: Optional[Callable] = None):
        super().__init__('.', transform)
        assert connection_distribution in ['random', 'uniform']

        num_ba_nodes, num_houses = 300, 80

        # Base Barabasi-Albert graph. All of its nodes and edges carry
        # label 0.
        ba_edge_index = barabasi_albert_graph(num_ba_nodes, num_edges=5)
        edge_index_parts = [ba_edge_index]
        edge_label_parts = [
            torch.zeros(ba_edge_index.size(1), dtype=torch.int64)
        ]
        node_label_parts = [torch.zeros(num_ba_nodes, dtype=torch.int64)]

        # Pick the BA graph node each house is attached to.
        if connection_distribution == 'random':
            anchors = torch.randperm(num_ba_nodes)[:num_houses]
        else:
            # Equally spaced candidates 0, 3, 6, ...; this yields more than
            # `num_houses` entries, of which only the first `num_houses` are
            # consumed below.
            stride = num_ba_nodes // num_houses
            anchors = torch.arange(0, num_ba_nodes, stride)

        # Append the houses one after another. `offset` is the global index
        # of the current house's node 0, which is also the node that gets
        # linked to the anchor.
        offset = num_ba_nodes
        for anchor in anchors[:num_houses].tolist():
            motif_edge_index, motif_label = house()

            # Bidirectional edge between the anchor and the house.
            bridge = torch.tensor([[anchor, offset],
                                   [offset, anchor]])

            edge_index_parts += [motif_edge_index + offset, bridge]
            edge_label_parts += [
                torch.ones(motif_edge_index.size(1), dtype=torch.long),
                torch.zeros(2, dtype=torch.long),
            ]
            node_label_parts.append(motif_label)

            offset += 5  # Each house contributes five nodes.

        total_nodes = offset

        # Mark every fifth node from index 400 onwards.
        expl_mask = torch.zeros(total_nodes, dtype=torch.bool)
        expl_mask[400::5] = True

        data = Data(
            x=torch.ones((total_nodes, 10), dtype=torch.float),
            edge_index=torch.cat(edge_index_parts, dim=1),
            y=torch.cat(node_label_parts, dim=0),
            expl_mask=expl_mask,
            edge_label=torch.cat(edge_label_parts, dim=0),
        )

        self.data, self.slices = self.collate([data])