

import torch
import networkx as nx
from torch_geometric.data import Data
import torch_geometric.utils as pyg_utils
from torch_geometric.datasets import Planetoid, WikiCS

import torch_geometric.transforms as pyg_transforms


def CSBM(sizes, num_features, p_intra, p_inter, std_dev=1,
         unbalance=1, empty_graph=False):
    """
    Contextual SBM data with two communities.

    Samples a two-block stochastic block model with networkx and attaches
    Gaussian node features: nodes of block 0 are drawn with mean -1 and
    nodes of block 1 with mean +1 in every feature dimension, both with
    standard deviation ``std_dev``.

    Parameters
    ----------
    sizes : list of ints of length 2
        Number of nodes per block.
    num_features : int
        Dimension of features.
    p_intra : float between 0 and 1.
        Probability of connection intra class.
    p_inter : float between 0 and 1.
        Probability of connection inter class.
    std_dev : float, optional
        Std of Gaussian distrib. The default is 1.
    unbalance : optional
        Currently unused; accepted only so existing calls keep working.
    empty_graph : bool, optional
        If True return a graph with no edges. The default is False.

    Returns
    -------
    torch_geometric.data.Data
        Graph with float features ``x`` of shape (num_nodes, num_features),
        an undirected long ``edge_index`` of shape (2, num_edges) (possibly
        (2, 0)), ``y`` holding each node's block index, and an all-True
        boolean ``train_mask``. ``num_features`` and ``num_classes`` (= 2)
        are also stored on the object.
    """
    if empty_graph:
        W = [[0, 0], [0, 0]]
    else:
        W = [[p_intra, p_inter], [p_inter, p_intra]]

    graph = nx.stochastic_block_model(sizes, W)
    num_nodes = graph.number_of_nodes()

    # Node ids are used directly as row indices below, so they are assumed
    # to be 0..num_nodes-1 (networkx's default nodelist) -- TODO confirm if
    # a custom nodelist is ever introduced.
    blocks = [graph.nodes[node]['block'] for node in graph.nodes()]

    # Gaussian parameters per class. ``float(std_dev)`` keeps the std tensor
    # floating point even for an integer argument such as the default 1.
    std = torch.full((num_features,), float(std_dev))
    mean_class1 = torch.full((num_features,), -1.0)
    mean_class2 = torch.full((num_features,), 1.0)

    # One draw per node, in node order, so seeded runs stay reproducible.
    node_features = torch.zeros(num_nodes, num_features)
    for node, block in zip(graph.nodes(), blocks):
        mean = mean_class1 if block == 0 else mean_class2
        node_features[node] = torch.normal(mean, std)

    # Build a (2, num_edges) long tensor. reshape(-1, 2) keeps a valid
    # (2, 0) shape when there are no edges (e.g. empty_graph=True); without
    # it torch.tensor([]) is a 1-D float tensor that to_undirected cannot
    # handle.
    edges = torch.tensor(list(graph.edges), dtype=torch.long).reshape(-1, 2)
    edge_index = pyg_utils.to_undirected(edges.t().contiguous(),
                                         num_nodes=num_nodes)

    data = Data(x=node_features, edge_index=edge_index)
    # Boolean tensor (not a Python list) so it is indexed and subset the same
    # way as the masks of the real datasets, e.g. by
    # LargestConnectedComponents.
    data.train_mask = torch.ones(num_nodes, dtype=torch.bool)
    data.y = torch.tensor(blocks, dtype=torch.long)
    data.num_features = num_features
    data.num_classes = 2

    return data

def load_data(data_name, sizes=(200, 200), num_features=2,
              p_intra=.03, p_inter=.005, std_dev=.25, root='.'):
    """
    Load a real benchmark graph or sample a synthetic CSBM graph.

    Parameters
    ----------
    data_name : str
        'WikiCS', a Planetoid dataset name (e.g. 'Cora', 'Citeseer',
        'Pubmed'), or 'synthetic' to sample a two-community CSBM graph.
    sizes : sequence of 2 ints, optional
        Block sizes for the synthetic graph. The default is (200, 200).
    num_features : int, optional
        Feature dimension for the synthetic graph. For real datasets this
        argument is ignored and replaced by the dataset's feature dimension.
    p_intra, p_inter : float, optional
        Intra-/inter-class edge probabilities for the synthetic graph.
    std_dev : float, optional
        Feature standard deviation for the synthetic graph.
    root : str, optional
        Root directory passed to the PyG dataset constructors for real
        datasets. The default is '.'.

    Returns
    -------
    data : torch_geometric.data.Data
        The graph with its edges made undirected, restricted to its largest
        connected component.
    num_classes : int
        Number of classes (2 for the synthetic graph).
    num_features : int
        Node feature dimension.
    """
    if data_name == 'synthetic':
        data = CSBM(sizes=sizes, num_features=num_features,
                    p_intra=p_intra, p_inter=p_inter, std_dev=std_dev)
        num_classes = 2
    else:
        if data_name == 'WikiCS':
            dataset = WikiCS(root=root)
        else:
            dataset = Planetoid(root=root, name=data_name)
        num_classes = dataset.num_classes
        num_features = dataset.num_features
        data = dataset[0]

    # Symmetrize edges, then keep only the largest connected component.
    data.edge_index = pyg_utils.to_undirected(data.edge_index)
    data = pyg_transforms.LargestConnectedComponents()(data)
    return data, num_classes, num_features

