import os
import scipy.sparse as sp
import numpy as np
import torch_geometric
from torch_geometric.datasets import WebKB
from ppnp.data.sparsegraph import SparseGraph, remove_self_loops
from ppnp.data.io import load_dataset


def save_sparsegraph(filename: str, sg: SparseGraph) -> None:
    """Write ``sg`` to disk as an uncompressed NumPy ``.npz`` archive.

    Every key/value pair returned by ``sg.to_flat_dict()`` is stored as a
    separately named array inside the archive at ``filename``.
    """
    flat_arrays = sg.to_flat_dict()
    np.savez(filename, **flat_arrays)


def _webkb(root: str, name: str) -> SparseGraph:
    """Load a WebKB graph via PyTorch Geometric and convert it to a SparseGraph.

    The raw dataset is fetched/processed under ``root`` by ``WebKB``. The
    converted graph (with self-loops removed) is also written to
    ``<root>/<name.lower()>.npz`` the first time it is built; an existing file
    at that path is left untouched.

    Args:
        root: Directory handed to ``WebKB`` and where the ``.npz`` is written.
        name: WebKB dataset name, e.g. ``'Cornell'``, ``'Texas'``,
            ``'Wisconsin'``.

    Returns:
        The graph as a ``SparseGraph`` built from the CSR adjacency matrix,
        float32 node attributes and int64 labels.
    """
    dataset = WebKB(root, name)
    # Index the dataset once; PyG may construct a fresh Data object per access.
    data = dataset[0]
    num_nodes = data.y.size(-1)  # one label per node

    # Convert explicitly to numpy instead of relying on scipy coercing tensors.
    row_idx, col_idx = data.edge_index.numpy()
    # NOTE(review): csr_matrix sums duplicate (row, col) entries, so an edge
    # listed twice in edge_index would get weight 2 — confirm it is coalesced.
    adj_matrix = sp.csr_matrix(
        (np.ones(len(row_idx)), (row_idx, col_idx)),
        shape=(num_nodes, num_nodes))
    attr_matrix = data.x.numpy().astype('float32')
    labels = data.y.numpy().astype('int64')

    sg = remove_self_loops(SparseGraph(adj_matrix, attr_matrix, labels))

    # Cache the converted graph once; never overwrite an existing export.
    sg_path = os.path.join(root, f'{name.lower()}.npz')
    if not os.path.exists(sg_path):
        save_sparsegraph(sg_path, sg)
    return sg


def cornell(root: str) -> SparseGraph:
    """Return the WebKB *Cornell* graph, using ``root`` as the data directory."""
    return _webkb(root=root, name='Cornell')


def texas(root: str) -> SparseGraph:
    """Return the WebKB *Texas* graph, using ``root`` as the data directory."""
    return _webkb(root=root, name='Texas')


def wisconsin(root: str) -> SparseGraph:
    """Return the WebKB *Wisconsin* graph, using ``root`` as the data directory."""
    return _webkb(root=root, name='Wisconsin')


if __name__ == '__main__':
    # Smoke test: load Cora-ML plus the three WebKB graphs and print them.
    data_root = './ppnp/data'
    cora = load_dataset('cora_ml', data_root)
    print(cora, cora.node_names[0])
    # Build all graphs first, then print them, in Cornell/Texas/Wisconsin order.
    webkb_graphs = [loader(data_root) for loader in (cornell, texas, wisconsin)]
    for graph in webkb_graphs:
        print(graph)

