import numpy as np
import pickle as pkl
import scipy.sparse as sp
import torch

def load_data(args):
    """Load a multi-relational graph dataset from ``data/<dataset>.pkl``.

    Args:
        args: Namespace-like object providing ``dataset`` (file stem under
            ``data/``), ``relationships_list`` (keys of the adjacency
            matrices to read from the pickle) and ``self_conv`` (weight
            added to the diagonal of every adjacency matrix).

    Returns:
        Tuple ``(rownetworks, features_li, label, idx_train, idx_test)``:
        one CSR adjacency matrix per relationship, a list containing the
        LIL feature matrix once per relationship (the same object repeated,
        not copies), the raw ``label`` entry, and the flattened train/test
        index arrays.
    """
    dataset = args.dataset
    relationships = args.relationships_list
    self_conv = args.self_conv

    # NOTE(review): pickle.load can execute arbitrary code; only load dataset
    # files that come from a trusted source.
    with open('data/{}.pkl'.format(dataset), "rb") as f:
        data = pkl.load(f)
    label = data['label']
    N = label.shape[0]

    features = data['feature'].astype(float)
    # Add self-loops weighted by self_conv, then store each relation as CSR.
    rownetworks = [
        sp.csr_matrix(data[relationship] + np.eye(N) * self_conv)
        for relationship in relationships
    ]

    features = sp.lil_matrix(features)

    idx_train = data['train_idx'].ravel()
    idx_test = data['test_idx'].ravel()

    # Every relationship shares the same feature matrix object.
    features_li = [features] * len(rownetworks)

    return rownetworks, features_li, label, idx_train, idx_test


def preprocess_features(features):
    """Row-normalize a feature matrix and return it as a dense matrix.

    Each row is divided by its sum; rows that sum to zero are left as zeros.

    Args:
        features: scipy sparse matrix or 2-D numpy array of features.

    Returns:
        ``numpy.matrix`` holding the row-normalized features.
    """
    # Cast to float: np.power on an integer array with a negative exponent
    # raises ValueError, so integer-typed features would otherwise crash.
    rowsum = np.array(features.sum(1), dtype=float)
    # Zero rows produce inf here (reset to 0 below); silence the expected warning.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    if sp.issparse(features):
        return features.todense()
    # Dense input yields an ndarray here; keep the matrix return type.
    return np.asmatrix(features)

def normalize_adj(adj):
    """Scale an adjacency matrix by the inverse square root of its row sums.

    Computes ``(A @ D) .T @ D`` where ``D = diag(rowsum(A) ** -0.5)``, with
    entries for zero-sum rows set to 0.  For a symmetric ``A`` this is the
    usual ``D^-1/2 A D^-1/2``; for an asymmetric ``A`` the result is based on
    ``A`` transposed.

    Returns:
        The normalized matrix in COO format.
    """
    coo = sp.coo_matrix(adj)
    degrees = np.array(coo.sum(1))
    scale_vec = np.power(degrees, -0.5).flatten()
    # Rows with zero sum give inf; treat them as contributing nothing.
    scale_vec = np.where(np.isinf(scale_vec), 0., scale_vec)
    scale = sp.diags(scale_vec)
    right_scaled = coo.dot(scale)
    result = right_scaled.transpose().dot(scale)
    return result.tocoo()

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: any scipy sparse matrix.

    Returns:
        ``torch`` sparse COO tensor with the same shape and float32 values.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)