import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data


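# 1. Convert StellarGraph data to PyG format.
# NOTE: the commented-out function below is a disabled variant that takes
# explicit node features and index lists; the active definition of
# convert_stellargraph_to_pyg follows after it.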
# def convert_stellargraph_to_pyg(stellargraph, node_features, train_indices, test_indices, train_targets, test_targets):
#     # Get edge list
#     edge_list = stellargraph.edges()
#     edge_index = torch.tensor([[stellargraph.node_ids[s], stellargraph.node_ids[t]]
#                                for s, t in edge_list], dtype=torch.long).t()
#
#     # Convert node features to tensor
#     x = torch.tensor(node_features.values, dtype=torch.float)
#
#     # Convert targets to tensors
#     y_train = torch.tensor(train_targets.values, dtype=torch.long)
#     y_test = torch.tensor(test_targets.values, dtype=torch.long)
#
#     # Create masks
#     train_mask = torch.zeros(len(stellargraph.nodes()), dtype=torch.bool)
#     test_mask = torch.zeros(len(stellargraph.nodes()), dtype=torch.bool)
#
#     # Convert node indices to numeric indices
#     train_idx = [stellargraph.node_ids[n] for n in train_indices]
#     test_idx = [stellargraph.node_ids[n] for n in test_indices]
#
#     train_mask[train_idx] = True
#     test_mask[test_idx] = True
#
#     # Create PyG data object
#     data = Data(x=x, edge_index=edge_index)
#     data.train_mask = train_mask
#     data.test_mask = test_mask
#     data.y_train = y_train
#     data.y_test = y_test
#     data.train_idx = torch.tensor(train_idx)
#     data.test_idx = torch.tensor(test_idx)
#
#     return data


def convert_stellargraph_to_pyg(stellargraph, train_subjects, test_subjects, train_targets, test_targets):
    """Build a PyG ``Data`` object with train/test masks from a StellarGraph graph.

    Node features and the iloc-based edge list are read from ``stellargraph``.
    Labels are stored in a single ``data.y`` vector; nodes that belong to
    neither split keep the placeholder label ``-1``.

    Args:
        stellargraph: Graph object providing ``node_features()``,
            ``edges(use_ilocs=True)`` and ``node_ids_to_ilocs(ids)``.
        train_subjects: Object whose ``.keys()`` yields the node IDs of the
            training nodes (presumably a pandas Series indexed by node ID --
            confirm against the caller).
        test_subjects: Same as ``train_subjects`` for the test nodes.
        train_targets: Targets for the training nodes, row-aligned with
            ``train_subjects``. Accepted layouts:
            * 2-D with several columns (one-hot / scores): class = argmax
              over the columns;
            * 2-D with a single column (e.g. binary label encoding): the
              column value itself is the class index;
            * 1-D: already class indices.
        test_targets: Same as ``train_targets`` for the test nodes.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``y``, ``train_mask`` and
        ``test_mask`` set. If a node appears in both splits, its label in
        ``y`` is the test label because the test split is written last.

    Raises:
        ValueError: If the number of targets in a split differs from the
            number of node IDs in that split.
    """
    x = torch.tensor(stellargraph.node_features(), dtype=torch.float)

    # One vectorised conversion of the (source, target) pairs; reshape(-1, 2)
    # keeps the result at shape (2, 0) when the graph has no edges.
    # NOTE(review): every edge is emitted once as (source, target). If the
    # graph is undirected and the downstream model expects both directions in
    # edge_index, the caller has to symmetrise it -- confirm.
    edge_pairs = np.asarray(list(stellargraph.edges(use_ilocs=True)), dtype=np.int64)
    edge_index = torch.from_numpy(edge_pairs.reshape(-1, 2)).t().contiguous()

    data = Data(x=x, edge_index=edge_index)
    num_nodes = data.num_nodes

    labels = torch.full((num_nodes,), fill_value=-1, dtype=torch.long)
    masks = {}

    for split, subjects, targets in (
        ("train", train_subjects, train_targets),
        ("test", test_subjects, test_targets),
    ):
        ilocs = torch.from_numpy(
            stellargraph.node_ids_to_ilocs(subjects.keys()).astype(np.int64)
        )

        encoded = torch.as_tensor(targets)
        if encoded.dim() == 1:
            # Already class indices.
            class_ids = encoded.long()
        elif encoded.dim() == 2 and encoded.size(1) == 1:
            # argmax over a single column is always 0, which would silently
            # label every node as class 0; use the column value instead.
            class_ids = encoded[:, 0].long()
        else:
            class_ids = encoded.argmax(dim=1)

        # Without this check a single target would be broadcast to every
        # node of the split, hiding misaligned inputs.
        if class_ids.numel() != ilocs.numel():
            raise ValueError(
                f"{split} targets describe {class_ids.numel()} nodes but "
                f"{ilocs.numel()} {split} node IDs were given"
            )

        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[ilocs] = True
        labels[ilocs] = class_ids
        masks[split] = mask

    data.y = labels
    data.train_mask = masks["train"]
    data.test_mask = masks["test"]

    return data


def glb_convert_stellargraph_to_pyg(stellargraph, test_ilocs, test_targets):
    """Build a PyG ``Data`` object carrying only a test mask and test labels.

    Unlike ``convert_stellargraph_to_pyg``-style conversion from node IDs,
    the test nodes are given directly as integer locations (ilocs).

    Args:
        stellargraph: Graph object providing ``node_features()`` and
            ``edges(use_ilocs=True)``.
        test_ilocs: Integer node positions of the test nodes (NumPy array or
            any sequence convertible with ``np.asarray``).
        test_targets: Targets row-aligned with ``test_ilocs``. Accepted
            layouts:
            * 2-D with several columns (one-hot / scores): class = argmax
              over the columns;
            * 2-D with a single column (e.g. binary label encoding): the
              column value itself is the class index;
            * 1-D: already class indices.

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``y`` and ``test_mask`` set.
        Nodes outside the test set keep the placeholder label ``-1``.

    Raises:
        ValueError: If the number of targets differs from the number of
            test ilocs.
    """
    x = torch.tensor(stellargraph.node_features(), dtype=torch.float)

    # One vectorised conversion of the (source, target) pairs; reshape(-1, 2)
    # keeps the result at shape (2, 0) when the graph has no edges.
    # NOTE(review): every edge is emitted once as (source, target). If the
    # graph is undirected and the downstream model expects both directions in
    # edge_index, the caller has to symmetrise it -- confirm.
    edge_pairs = np.asarray(list(stellargraph.edges(use_ilocs=True)), dtype=np.int64)
    edge_index = torch.from_numpy(edge_pairs.reshape(-1, 2)).t().contiguous()

    data = Data(x=x, edge_index=edge_index)
    num_nodes = data.num_nodes

    # np.asarray first so plain sequences work; astype then makes an int64 copy.
    test_ilocs_torch = torch.from_numpy(np.asarray(test_ilocs).astype(np.int64))

    encoded = torch.as_tensor(test_targets)
    if encoded.dim() == 1:
        # Already class indices.
        test_labels = encoded.long()
    elif encoded.dim() == 2 and encoded.size(1) == 1:
        # argmax over a single column is always 0, which would silently
        # label every node as class 0; use the column value instead.
        test_labels = encoded[:, 0].long()
    else:
        test_labels = encoded.argmax(dim=1)

    # Without this check a single target would be broadcast to every test
    # node, hiding misaligned inputs.
    if test_labels.numel() != test_ilocs_torch.numel():
        raise ValueError(
            f"test targets describe {test_labels.numel()} nodes but "
            f"{test_ilocs_torch.numel()} test ilocs were given"
        )

    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask[test_ilocs_torch] = True

    labels = torch.full((num_nodes,), fill_value=-1, dtype=torch.long)
    labels[test_ilocs_torch] = test_labels

    data.y = labels
    data.test_mask = test_mask

    return data
