from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
import scipy
import scipy.io
from sklearn.preprocessing import label_binarize
from ogb.nodeproppred import NodePropPredDataset
from os import path
import pickle as pkl


class NCDataset:
    """Minimal single-graph container for node-classification data.

    Attributes are plain and meant to be filled in by a loader after
    construction:

    * ``name``  -- identifier passed to the constructor.
    * ``graph`` -- dict describing the graph; starts out empty.
    * ``label`` -- label object for the graph; starts out as ``None``.

    The object behaves like a length-1 sequence whose only element is the
    ``(graph, label)`` pair.
    """

    def __init__(self, name):
        """Create an empty dataset tagged with ``name``."""
        self.name, self.graph, self.label = name, {}, None

    def __len__(self):
        """Return 1: the dataset always holds exactly one graph."""
        return 1

    def __getitem__(self, idx):
        """Return ``(graph, label)``; only index 0 is accepted."""
        assert idx == 0, 'This dataset has only one graph'
        return (self.graph, self.label)

    def __repr__(self):
        """Return ``ClassName(length)``."""
        return f'{type(self).__name__}({len(self)})'


def load_nc_dataset(data_dir, dataname, sub_dataname=''):
    """Load a node-classification dataset by name.

    Args:
        data_dir: Root directory handed to the dataset-specific loader.
        dataname: Dataset identifier; only ``'elliptic'`` is recognised.
        sub_dataname: For ``'elliptic'``, the graph index, expected to be a
            member of ``range(0, 49)``. Any other value (including the
            default ``''``) is replaced by ``0`` after printing a notice.

    Returns:
        Whatever ``load_elliptic_dataset`` returns for the chosen graph.

    Raises:
        ValueError: If ``dataname`` is not a recognised dataset.
    """
    # Reject unknown datasets up front so the happy path stays flat.
    if dataname != 'elliptic':
        raise ValueError('Invalid dataname')

    graph_idx = sub_dataname
    if graph_idx not in range(0, 49):
        print('Invalid sub_dataname, deferring to graph1')
        graph_idx = 0
    return load_elliptic_dataset(data_dir, graph_idx)


def load_elliptic_dataset(data_dir, lang):
    """Load one pickled Elliptic graph into an ``NCDataset``.

    The file ``{data_dir}/elliptic/{lang}.pkl`` is unpickled and unpacked as
    a 3-tuple ``(A, label, features)``.

    Args:
        data_dir: Directory that contains the ``elliptic`` sub-directory.
        lang: Graph index; must be a member of ``range(0, 49)``. (Despite
            the parameter name, this is the numeric file stem, not a
            language code.)

    Returns:
        An ``NCDataset`` whose ``graph`` dict has the keys ``edge_index``
        (long tensor built from ``A.nonzero()``), ``edge_feat`` (``None``),
        ``node_feat`` (float tensor built from ``features``) and
        ``num_nodes`` (``node_feat.shape[0]``); ``label`` is a tensor built
        from the pickled labels and ``mask`` is the boolean tensor
        ``label >= 0``.

    Raises:
        AssertionError: If ``lang`` is outside ``range(0, 49)``.
        FileNotFoundError: If the pickle file does not exist.
    """
    assert lang in range(0, 49), 'Invalid dataset'

    pkl_path = '{}/elliptic/{}.pkl'.format(data_dir, lang)
    # SECURITY NOTE: unpickling can execute arbitrary code; only point
    # ``data_dir`` at files from a trusted source.
    # The context manager guarantees the handle is closed even if
    # unpickling fails (the previous bare ``open`` leaked it).
    with open(pkl_path, 'rb') as fh:
        A, label, features = pkl.load(fh)

    dataset = NCDataset(lang)

    # ``A.nonzero()`` yields a tuple of index arrays (presumably A is a
    # scipy sparse / numpy adjacency matrix -- TODO confirm). Stacking them
    # into one ndarray first avoids PyTorch's slow element-wise conversion
    # of a sequence of ndarrays and the warning that accompanies it; the
    # resulting tensor is the same.
    edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
    node_feat = torch.tensor(features, dtype=torch.float)

    dataset.graph = {'edge_index': edge_index,
                     'edge_feat': None,
                     'node_feat': node_feat,
                     'num_nodes': node_feat.shape[0]}
    dataset.label = torch.tensor(label)
    # Nodes with a negative label are excluded by the mask (presumably they
    # are the unlabeled nodes -- verify against the data preparation code).
    dataset.mask = (dataset.label >= 0)
    return dataset

