import torch
import networkx as nx

from torch import LongTensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.datasets import CitationFull
from torch_geometric.utils import is_undirected, subgraph, to_networkx
from typing import Optional, Tuple

from baselines.NetGAN.netgan.utils import train_val_test_split_adjacency


def load_data(path: str, dataset_name: str) -> Data:
    """
    Return a PyTorch Geometric Data object representing a single graph from a one-graph dataset.
    In case of 'Cora_ML', 'Cora', 'DBLP' and 'CiteSeer', the largest connected component is selected.

    Args:
        path: str
            Path to the folder that contains the data (CitationFull downloads into it if needed).
        dataset_name: str
            Either 'Cora_ML', 'Cora', 'CiteSeer', 'DBLP' or 'PubMed'.

    Returns:
        data: torch_geometric.data.Data
            A PyTorch Geometric Data object representing a graph. For the LCC datasets the nodes are
            relabelled to 0..k-1 (keeping their original relative order) and only x, y and edge_index are kept.
    """
    # Initialize the dataset and access the single graph inside it
    dataset = CitationFull(root=path, name=dataset_name)
    data = dataset[0]

    if dataset_name in ('Cora_ML', 'Cora', 'CiteSeer', 'DBLP'):
        data = _largest_connected_component(data)

    return data


def _largest_connected_component(data: Data) -> Data:
    """
    Return the subgraph of `data` induced by its largest connected component.

    For an undirected (symmetric) edge_index ordinary connectivity is used; otherwise weak connectivity is used,
    because nx.connected_components is not implemented for directed graphs.

    Args:
        data: torch_geometric.data.Data
            Graph with node features `x` and node labels `y`.

    Returns:
        data: torch_geometric.data.Data
            New Data object holding x, y and edge_index of the largest component, with nodes relabelled to
            0..k-1 in ascending order of their original indices.
    """
    undirected = is_undirected(data.edge_index, num_nodes=data.num_nodes)

    # Per the PyG docs, to_networkx(to_undirected=True) keeps only the upper triangle of the adjacency matrix,
    # which is only lossless for a symmetric edge_index; for a directed graph keep an nx.DiGraph instead.
    graph = to_networkx(data, to_undirected=undirected)
    if undirected:
        components = nx.connected_components(graph)
    else:
        components = nx.weakly_connected_components(graph)

    # Sort the node indices so the relabelling below preserves the original relative node order instead of
    # depending on networkx's internal set iteration order; `default` guards against a graph with no nodes.
    largest_connected_component = torch.tensor(sorted(max(components, key=len, default=set())), dtype=torch.long)

    # Extract the induced subgraph; relabel_nodes maps the i-th selected node to index i, matching x/y indexing
    edge_index, _ = subgraph(largest_connected_component, data.edge_index, relabel_nodes=True,
                             num_nodes=data.num_nodes)

    # Create a new Data object holding the largest connected component of the original graph
    return Data(x=data.x[largest_connected_component], edge_index=edge_index,
                y=data.y[largest_connected_component], num_nodes=largest_connected_component.size(0))


def train_val_test_split_graph(data: Data, seed: Optional[int] = None, *, p_val: float = 0.10,
                               p_test: float = 0.05) \
        -> Tuple[LongTensor, LongTensor, LongTensor, LongTensor, LongTensor]:
    """
    Split a graph into training, validation and test edges, and validation and test non-edges.

    Args:
        data: torch_geometric.data.Data
            Data object representing an undirected graph (the split is performed with undirected=True).
        seed: int (optional, default: None)
            Seed for reproducible splitting.
        p_val: float (keyword-only, default: 0.10)
            Fraction of edges used for validation.
        p_test: float (keyword-only, default: 0.05)
            Fraction of edges used for testing.

    Returns:
        train_edge_index: torch.LongTensor, shape: (2, num_train_edges)
            Training edges.
        val_edge_index: torch.LongTensor, shape: (2, num_val_edges)
            Validation edges.
        val_non_edge_index: torch.LongTensor, shape: (2, num_val_non_edges)
            Validation non-edges.
        test_edge_index: torch.LongTensor, shape: (2, num_test_edges)
            Test edges.
        test_non_edge_index: torch.LongTensor, shape: (2, num_test_non_edges)
            Test non-edges.
    """
    # Convert data to a scipy.sparse.csr_matrix representing the adjacency matrix of the graph
    A = SparseTensor(row=data.edge_index[0], col=data.edge_index[1],
                     sparse_sizes=(data.num_nodes, data.num_nodes)).to_scipy(layout='csr')

    # Perform the split of the adjacency matrix into training, validation and test edges,
    # and validation and test non-edges (each returned as an array of node-index pairs)
    train_ones, val_ones, val_zeros, test_ones, test_zeros = train_val_test_split_adjacency(
        A, p_val=p_val, p_test=p_test, seed=seed, neg_mul=1, every_node=True, connected=True,
        undirected=True, use_edge_cover=True, set_ops=True, asserts=True)

    return (_pairs_to_edge_index(train_ones), _pairs_to_edge_index(val_ones), _pairs_to_edge_index(val_zeros),
            _pairs_to_edge_index(test_ones), _pairs_to_edge_index(test_zeros))


def _pairs_to_edge_index(pairs) -> LongTensor:
    """
    Convert an array of node-index pairs, shape (num_pairs, 2), into a contiguous int64 tensor of shape
    (2, num_pairs).

    An explicit 'int64' is used instead of the platform `int`, which NumPy < 2 maps to int32 on Windows. The
    transpose is made contiguous so PyG operations that call .view() on the edge index work. reshape(-1, 2)
    also turns an empty 1-D array into shape (0, 2), giving a (2, 0) edge index.
    """
    return torch.from_numpy(pairs.reshape(-1, 2).astype('int64')).t().contiguous()
