import sys, os
import pickle
import torch
import torch_geometric
import torch.nn.functional as F

# Run on the first visible CUDA device when available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def load_from_pickle(filename):
    """Deserialize and return the single object pickled in *filename*.

    NOTE: unpickling can execute arbitrary code; only call this on files
    from a trusted source.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def only_train_nodes(nodes, train_mask):
    """Return the entries of the index tensor *nodes* whose flag in the
    boolean *train_mask* is set, preserving their order."""
    return nodes[train_mask[nodes]]
 
def get_test_neighborhood_labels(data, 
                                 combination_type='mode', 
                                 ignore_self_loop=False,
                                 ignore_test_nodes=False):
    """Derive a label for every test node from its 1-hop neighbourhood.

    For each node selected by ``data.test_mask``, the 1-hop neighbourhood is
    obtained from ``torch_geometric.utils.k_hop_subgraph`` (its node subset
    includes the node itself). The neighbours' labels ``data.y`` are then
    combined according to ``combination_type``.

    Args:
        data: graph object exposing ``x``, ``y``, ``edge_index``,
            ``train_mask`` and ``test_mask``. Labels are assumed to be
            non-negative integers (TODO confirm against callers).
        combination_type: how neighbour labels are combined:
            * ``'mean'``: threshold the mean label at 0.5. Each entry is a
              shape-``(1,)`` tensor holding 0 or 1, which presumes binary
              0/1 labels.
            * ``'mode'``: the most frequent neighbour label, as a 0-dim tensor.
            * ``'mean_prob'``: empirical label distribution over
              ``max(2, max(y) + 1)`` classes.
        ignore_self_loop: remove the test node itself from its neighbourhood.
        ignore_test_nodes: keep only neighbours flagged in ``data.train_mask``.

    Returns:
        The per-test-node results, stacked along a new leading dimension.

    Raises:
        ValueError: if ``combination_type`` is not one of the options above.

    An empty neighbourhood (possible when both ``ignore_*`` flags are set)
    behaves as follows per mode:
        * ``'mean'`` yields label 0, because the NaN mean fails the threshold.
        * ``'mean_prob'`` yields NaN probabilities.
        * ``'mode'`` lets torch raise.
    """
    if combination_type not in ('mean', 'mode', 'mean_prob'):
        raise ValueError(f"unknown combination_type: {combination_type!r}")

    num_classes = None
    if combination_type == 'mean_prob':
        # Fixed width so every row has the same length and torch.stack works.
        # The floor of 2 keeps the [P(0), P(1)] layout for binary data even
        # when only one label value occurs.
        num_classes = max(2, int(data.y.max().item()) + 1)

    test_node_indices = torch.arange(data.x.shape[0]).to(device)[data.test_mask]
    test_label_val = []
    # A single tolist() replaces one device sync per node via .item().
    for node_index in test_node_indices.tolist():
        neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph(
            node_index, num_hops=1, edge_index=data.edge_index)
        if ignore_self_loop:
            neighbors = neighbors[neighbors != node_index]
        if ignore_test_nodes:
            neighbors = only_train_nodes(neighbors, data.train_mask)
        neighbor_labels = data.y[neighbors]

        if combination_type == 'mean':
            # .float() keeps the tensor on its current device (no host copy).
            label = 1 if neighbor_labels.float().mean() > 0.5 else 0
            test_label_val.append(torch.tensor([label]).to(device))
        elif combination_type == 'mode':
            val, _ = neighbor_labels.mode()
            test_label_val.append(val)
        else:  # 'mean_prob'
            # bincount fills zero counts for absent labels on the labels' own
            # device, so no manual padding is needed.
            label_count = torch.bincount(neighbor_labels.long(),
                                         minlength=num_classes)
            test_label_val.append(label_count / label_count.sum())
    return torch.stack(test_label_val)

def get_test_nodes_w_test_neighbors(data):
    """Partition test nodes by whether they have any training-node neighbour.

    For each node selected by ``data.test_mask``, the 1-hop neighbourhood is
    obtained from ``torch_geometric.utils.k_hop_subgraph``. The node itself is
    removed, and only neighbours flagged in ``data.train_mask`` are kept.

    NOTE(review): despite the function name, the test is for *training*
    neighbours, not test neighbours.

    Returns:
        ``(empty_indices, indices_w_neighbors)``, two lists of 0-dim index
        tensors:
            * ``empty_indices``: test nodes with no training neighbour.
            * ``indices_w_neighbors``: test nodes with at least one.
    """
    test_node_indices = torch.arange(data.x.shape[0]).to(device)[data.test_mask]
    empty_indices = []
    indices_w_neighbors = []
    # Iterating a 1-D tensor yields 0-dim tensors, the same element type the
    # returned lists have always contained.
    for node_index in test_node_indices:
        neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph(
            node_index.item(), num_hops=1, edge_index=data.edge_index)
        neighbors = neighbors[neighbors != node_index]
        neighbors = only_train_nodes(neighbors, data.train_mask)
        if len(neighbors) == 0:
            empty_indices.append(node_index)
        else:
            indices_w_neighbors.append(node_index)
    return empty_indices, indices_w_neighbors

def get_kld(out, data, indices):
    """Mean KL divergence KL(true marginals || softmax(out)) over *indices*.

    Args:
        out: raw model scores. Softmax is taken over dim 1, which assumes
            shape ``(num_nodes, num_classes)`` (TODO confirm).
        data: object with a ``marginals`` attribute holding the target
            distributions (anything ``torch.as_tensor`` accepts). As a side
            effect, ``data.marginals`` is replaced by a float32 tensor on
            ``out``'s device.
        indices: rows of ``out`` / ``data.marginals`` to evaluate.

    Returns:
        0-dim tensor: the per-row KL divergence summed over classes, then
        averaged over the selected rows. Target entries equal to 0
        contribute 0, rather than NaN.
    """
    log_probs = F.log_softmax(out, dim=1)[indices]
    # as_tensor is a no-op once marginals are already float32 on out's device,
    # so repeated calls neither re-copy nor trigger torch.tensor's copy warning.
    data.marginals = torch.as_tensor(
        data.marginals, dtype=torch.float32, device=out.device)
    true_marginals = data.marginals[indices]
    # kl_div computes target * (log(target) - input) but defines the term as 0
    # where target == 0. Written out by hand, that term would be 0 * -inf = NaN.
    kl_loss = F.kl_div(log_probs, true_marginals, reduction='none')
    return kl_loss.sum(dim=1).mean()

