import os.path as osp
import numpy as np
import scipy.sparse as sp
from torch_geometric.datasets import \
    WikipediaNetwork, LINKXDataset, Actor, WebKB, Amazon
from data.io import StandardSplits
from data.sparsegraph import SparseGraph


def pyg_to_sparsegraph(dataset) -> SparseGraph:
    """Build a ``SparseGraph`` from a dataset exposing PyG-style fields.

    Args:
        dataset: Object providing ``y`` (labels), ``x`` (node attributes)
            and ``edge_index``. ``edge_index`` must unpack into exactly two
            index sequences (source rows, target columns) -- presumably a
            ``(2, num_edges)`` tensor; verify against the caller.

    Returns:
        A ``SparseGraph`` whose adjacency is a square CSR matrix with one
        row/column per entry of ``dataset.y``, and whose attributes and
        labels are ``dataset.x`` / ``dataset.y`` passed through
        ``np.asarray``.
    """
    num_nodes = len(dataset.y)
    sources, targets = dataset.edge_index

    # Every listed edge carries a float32 weight of 1. The matrix is built
    # in COO form first; scipy sums duplicate coordinates when converting
    # to CSR, so a repeated edge ends up with a weight above 1.
    unit_weights = np.ones_like(sources, dtype=np.float32)
    adjacency = sp.coo_matrix(
        (unit_weights, (sources, targets)),
        shape=(num_nodes, num_nodes),
    ).tocsr()

    features = np.asarray(dataset.x)
    node_labels = np.asarray(dataset.y)
    return SparseGraph(
        adj_matrix=adjacency,
        attr_matrix=features,
        labels=node_labels,
    )


def prepare_pyg_dataset(
        dataset, save_path='./data', dataset_name='actor'
) -> None:
    """Convert ``dataset`` to a ``SparseGraph`` and store it as an ``.npz``.

    The archive is written to ``<save_path>/<dataset_name>.npz`` and holds
    the entries returned by the graph's ``to_flat_dict()``.

    Args:
        dataset: Dataset accepted by ``pyg_to_sparsegraph``.
        save_path: Directory receiving the archive. It is created (with any
            missing parents) when it does not exist yet.
        dataset_name: File stem of the archive, without the extension.
    """
    # Function-scope import: the module itself only binds ``os.path``.
    import os

    graph = pyg_to_sparsegraph(dataset)

    # np.savez does not create missing directories and would fail with
    # FileNotFoundError. An empty ``save_path`` means "current directory"
    # for osp.join, so there is nothing to create in that case.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    np.savez(
        osp.join(save_path, dataset_name + '.npz'),
        **graph.to_flat_dict())


def prepare_pyg_dataset_with_splits(
        dataset, save_path='./data', dataset_name='actor', splits: int = 10
) -> None:
    """Store a dataset's graph archive together with its standard splits.

    Two files are written into ``save_path``: ``<dataset_name>.npz`` (via
    ``prepare_pyg_dataset``) and ``<dataset_name>_splits.npz`` holding the
    entries of ``StandardSplits.to_dict()``.

    Args:
        dataset: Dataset exposing ``train_mask``, ``val_mask`` and
            ``test_mask`` in addition to what ``prepare_pyg_dataset`` needs.
        save_path: Directory receiving both archives.
        dataset_name: Common file stem of both archives.
        splits: Forwarded as the last argument of
            ``StandardSplits.from_masks`` -- presumably the number of splits
            contained in the masks; confirm against ``data.io``.
    """
    # The graph archive is written first, the split archive afterwards.
    prepare_pyg_dataset(dataset, save_path, dataset_name)

    standard_splits = StandardSplits.from_masks(
        dataset.train_mask, dataset.val_mask, dataset.test_mask, splits)
    splits_file = osp.join(save_path, dataset_name + '_splits.npz')
    np.savez(splits_file, **standard_splits.to_dict())


if __name__ == '__main__':
    # Script entry point: convert a fixed list of PyG datasets into graph
    # and split archives under ``save_dir``.
    pyg_dataset_dir = './pyg_datasets'
    save_dir = './data'

    # One entry per dataset: (progress message, output file stem, loader).
    # Loaders are zero-argument callables so that each dataset is only
    # constructed after its progress message has been printed.
    jobs = (
        ('- processing cornell dataset (10 splits) ...',
         'cornell',
         lambda: WebKB(pyg_dataset_dir, 'Cornell')),
        ('- processing texas dataset (10 splits) ...',
         'texas',
         lambda: WebKB(pyg_dataset_dir, 'Texas')),
        ('- processing wisconsin dataset (10 splits) ...',
         'wisconsin',
         lambda: WebKB(pyg_dataset_dir, 'Wisconsin')),
        ('- processing actor dataset (10 splits) ...',
         'actor',
         lambda: Actor(pyg_dataset_dir)),
        ('- processing squirrel dataset (10 splits) ...',
         'squirrel',
         lambda: WikipediaNetwork(
             pyg_dataset_dir, 'squirrel', geom_gcn_preprocess=True)),
    )

    for banner, stem, load in jobs:
        print(banner)
        prepare_pyg_dataset_with_splits(load(), save_dir, stem)

    print('- preparation complete.')
