import torch
from torch_geometric.utils import to_networkx


def index_to_mask(index, size, device='cpu'):
    """Build a boolean mask that is ``True`` exactly at ``index``.

    Args:
        index: Any object usable to index the mask tensor (e.g. a
            LongTensor or list of positions) selecting entries to set.
        size: Shape passed to ``torch.zeros`` for the mask.
        device: Device on which the mask is allocated (CPU by default).

    Returns:
        A ``torch.bool`` tensor of shape ``size``. Entries selected by
        ``index`` are ``True``; all others are ``False``.
    """
    selected = torch.zeros(size, dtype=torch.bool, device=device)
    selected[index] = True
    return selected


def random_planetoid_splits(data,
                            num_classes,
                            percls_trn=20,
                            val_lb=500,
                            Flag=0):
    """Randomly assign the nodes of ``data`` to train/val/test masks.

    For every class ``c`` in ``range(num_classes)``, the nodes with
    ``data.y == c`` are shuffled. The first ``percls_trn`` of them go to
    the training set. The rest are split according to ``Flag``:

    * ``Flag == 0``: the remaining nodes of all classes are pooled and
      shuffled. The first ``val_lb`` of them form the validation set, and
      all others form the test set. Here ``val_lb`` is a *total* count.
    * any other value: the next ``val_lb`` nodes of *each* class form the
      validation set, and all remaining nodes form the test set. Here
      ``val_lb`` is a *per-class* count.

    Nodes whose label is not in ``range(num_classes)`` are in no split.

    Args:
        data: Graph object exposing ``y`` (node labels) and ``num_nodes``.
            Its ``train_mask``, ``val_mask`` and ``test_mask`` attributes
            are overwritten.
        num_classes: Number of label values to iterate over.
        percls_trn: Training nodes taken per class.
        val_lb: Validation size (total or per class, see above).
        Flag: Selects the splitting mode described above.

    Returns:
        The same ``data`` object, with the three masks set in place. The
        masks are allocated on the same device as ``data.y``.
    """
    # Permutations come from the default (CPU) generator, so a fixed seed
    # gives the same split no matter which device data.y lives on.
    indices = []
    for c in range(num_classes):
        index = (data.y == c).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        indices.append(index)

    # The indices live on data.y's device. Building the masks on the CPU
    # (index_to_mask's default) would fail for CUDA indices, so allocate
    # the masks on the labels' device instead.
    device = data.y.device
    num_nodes = data.num_nodes

    train_index = torch.cat([i[:percls_trn] for i in indices], dim=0)

    if Flag == 0:
        rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0)
        rest_index = rest_index[torch.randperm(rest_index.size(0))]
        val_index = rest_index[:val_lb]
        test_index = rest_index[val_lb:]
    else:
        val_index = torch.cat(
            [i[percls_trn:percls_trn + val_lb] for i in indices], dim=0)
        test_index = torch.cat([i[percls_trn + val_lb:] for i in indices],
                               dim=0)
        # The order does not affect the resulting mask. The shuffle is kept
        # so the RNG stream consumed by this call is unchanged for callers
        # that rely on seeded reproducibility.
        test_index = test_index[torch.randperm(test_index.size(0))]

    data.train_mask = index_to_mask(train_index, size=num_nodes, device=device)
    data.val_mask = index_to_mask(val_index, size=num_nodes, device=device)
    data.test_mask = index_to_mask(test_index, size=num_nodes, device=device)
    return data


def get_maxDegree(graphs):
    """Return the largest node degree found across ``graphs``.

    Each graph is converted with ``to_networkx(graph, to_undirected=True)``,
    and the per-node degrees reported by the resulting networkx graph are
    compared.

    Args:
        graphs: Iterable of graph objects accepted by ``to_networkx``.

    Returns:
        The maximum degree as an int. Returns 0 if ``graphs`` is empty or
        every graph has no nodes.
    """
    maxdegree = 0
    for graph in graphs:
        g = to_networkx(graph, to_undirected=True)
        # default=0: a graph with no nodes yields an empty degree view, on
        # which a bare max() would raise ValueError.
        gdegree = max((deg for _, deg in g.degree), default=0)
        maxdegree = max(maxdegree, gdegree)
    return maxdegree
