import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Amazon, Planetoid
from ogb.nodeproppred import NodePropPredDataset

import numpy as np
from os import path


class NCDataset(object):
    """Single-graph node-classification container modelled on OGB's NodePropPredDataset.

    Based off of ogb NodePropPredDataset
    https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py
    but the loader functions fill it with torch tensors instead of numpy arrays.

    Usage after construction:

        split_idx = dataset.get_idx_split()
        train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
        graph, label = dataset[0]

    where ``dataset.graph`` is a dictionary of the form::

        {'edge_index': edge_index,
         'edge_feat': None,
         'node_feat': node_feat,
         'num_nodes': num_nodes}

    For additional documentation, see the OGB Library-Agnostic Loader
    https://ogb.stanford.edu/docs/nodeprop/
    """

    def __init__(self, name):
        """Create an empty dataset container.

        Args:
            name (str): name of the dataset as given by the caller
                (e.g. 'ogbn-arxiv', 'cora').
        """
        # Graph dict and labels start empty; loader functions assign them.
        self.label = None
        self.graph = {}
        self.name = name

    def get_idx_split(self, split_type='random', train_prop=.6, valid_prop=.2, label_num_per_class=20):
        """Placeholder split hook; currently always returns ``None``.

        Args:
            split_type: requested split strategy (currently ignored).
            train_prop: proportion of nodes for the train split, in [0, 1] (currently ignored).
            valid_prop: proportion of nodes for the validation split, in [0, 1] (currently ignored).
            label_num_per_class: labelled nodes per class (currently ignored).

        Returns:
            None. Some loaders instead attach a precomputed ``split`` dict
            attribute directly on the instance.
        """
        return None

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair; only index 0 is valid."""
        assert idx == 0, 'This dataset has only one graph'
        return (self.graph, self.label)

    def __len__(self):
        """The container always holds exactly one graph."""
        return 1

    def __repr__(self):
        return f'{type(self).__name__}({len(self)})'


def load_dataset(data_dir, dataname, sub_dataname=''):
    """Loader for NCDataset: dispatch on ``dataname`` and return an NCDataset.

    Args:
        data_dir: root data directory, forwarded to the specific loader.
        dataname: one of 'cora', 'citeseer', 'pubmed', 'amazon-photo',
            'amazon-computer', 'roman-empire', 'ogbn-arxiv', 'ogbn-products',
            'chameleon', 'squirrel'.
        sub_dataname: currently unused; kept for backward compatibility.

    Returns:
        The NCDataset produced by the matching loader.

    Raises:
        ValueError: if ``dataname`` is not one of the supported names.
    """
    if dataname in ('cora', 'citeseer', 'pubmed'):
        return load_planetoid_dataset(data_dir, dataname)
    if dataname in ('amazon-photo', 'amazon-computer'):
        return load_amazon_dataset(data_dir, dataname)
    # Must be a 1-tuple: a bare string here would turn `in` into a substring
    # test and accept names like '', 'e' or 'empire'.
    if dataname in ('roman-empire',):
        return load_hetero_dataset(data_dir, dataname)
    if dataname in ('ogbn-arxiv', 'ogbn-products'):
        return load_ogb_dataset(data_dir, dataname)
    if dataname in ('chameleon', 'squirrel'):
        return load_wiki_new(data_dir, dataname)
    raise ValueError('Invalid dataname')

def load_planetoid_dataset(data_dir, name):
    """Load a Planetoid graph ('cora', 'citeseer' or 'pubmed') as an NCDataset.

    The data is read through ``Planetoid`` rooted at ``<data_dir>/Planetoid``.
    The returned dataset carries a ``split`` dict whose 'train', 'valid' and
    'test' entries are 1-D index tensors of the nodes selected by the graph's
    train/val/test boolean masks.
    """
    planetoid_root = path.join(data_dir, 'Planetoid')
    graph_data = Planetoid(root=planetoid_root, name=name)[0]

    result = NCDataset(name)
    result.graph = {
        'edge_index': graph_data.edge_index,
        'node_feat': graph_data.x,
        'edge_feat': None,
        'num_nodes': graph_data.num_nodes,
    }
    result.label = graph_data.y

    # Convert each boolean mask to the indices of its True entries.
    split_masks = (
        ('train', graph_data.train_mask),
        ('valid', graph_data.val_mask),
        ('test', graph_data.test_mask),
    )
    result.split = {part: torch.nonzero(mask).squeeze(-1) for part, mask in split_masks}
    return result

def load_wiki_new(data_dir, name):
    """Load a filtered graph (e.g. 'chameleon', 'squirrel') from an ``.npz`` file.

    Reads ``{data_dir}/Hetero/{name}_filtered.npz``, which must contain the
    arrays ``node_features``, ``node_labels`` and ``edges`` (edges stored as
    (E, 2)). Edges are transposed into a (2, E) ``edge_index``.

    Args:
        data_dir: root data directory.
        name: dataset name used to build the file name.

    Returns:
        NCDataset with tensor ``edge_index``, ``node_feat`` (unnormalized) and labels.
    """
    # Named `npz_path` so the module-level `os.path` import is not shadowed.
    npz_path = f'{data_dir}/Hetero/{name}_filtered.npz'
    # Close the NpzFile deterministically; indexing it reads each array fully
    # into memory, so the tensors remain valid after the file is closed.
    with np.load(npz_path) as data:
        node_feat = torch.as_tensor(data['node_features'])  # unnormalized
        labels = torch.as_tensor(data['node_labels'])
        edge_index = torch.as_tensor(data['edges'].T)  # (E, 2) -> (2, E)

    dataset = NCDataset(name)
    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': node_feat.shape[0]}
    dataset.label = labels
    return dataset

def load_hetero_dataset(data_dir, name):
    """Load a graph stored as ``{data_dir}/Hetero/{name}.npz`` (e.g. 'roman-empire').

    The ``.npz`` file must contain ``node_features``, ``node_labels`` and
    ``edges`` (edges stored as (E, 2)). Edges are transposed into a (2, E)
    ``edge_index``.

    Args:
        data_dir: root data directory.
        name: dataset name used to build the file name.

    Returns:
        NCDataset with tensor ``edge_index``, ``node_feat`` (unnormalized) and labels.
    """
    # Named `npz_path` so the module-level `os.path` import is not shadowed.
    npz_path = f'{data_dir}/Hetero/{name}.npz'
    # Close the NpzFile deterministically; indexing it reads each array fully
    # into memory, so the tensors remain valid after the file is closed.
    with np.load(npz_path) as data:
        node_feat = torch.as_tensor(data['node_features'])  # unnormalized
        labels = torch.as_tensor(data['node_labels'])
        edge_index = torch.as_tensor(data['edges'].T)  # (E, 2) -> (2, E)

    dataset = NCDataset(name)
    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': node_feat.shape[0]}
    dataset.label = labels
    return dataset


def load_amazon_dataset(data_dir, name):
    """Load an Amazon co-purchase graph ('amazon-photo' or 'amazon-computer').

    Data is read via PyG's ``Amazon`` dataset rooted at ``<data_dir>/Amazon``.
    Node features go through the ``T.NormalizeFeatures()`` transform.

    Args:
        data_dir: root data directory.
        name: 'amazon-photo' or 'amazon-computer'.

    Returns:
        NCDataset with ``edge_index``, ``node_feat``, ``num_nodes`` and labels.

    Raises:
        ValueError: if ``name`` is not a supported Amazon dataset. Previously
            an unknown name caused an UnboundLocalError.
    """
    # Public dataset name -> sub-dataset name expected by PyG's Amazon class.
    pyg_names = {'amazon-photo': 'Photo', 'amazon-computer': 'Computers'}
    if name not in pyg_names:
        raise ValueError(f'Invalid Amazon dataset name: {name!r}')

    p = path.join(data_dir, 'Amazon')
    transform = T.NormalizeFeatures()
    torch_dataset = Amazon(root=p, name=pyg_names[name], transform=transform)
    data = torch_dataset[0]

    dataset = NCDataset(name)
    dataset.graph = {'edge_index': data.edge_index,
                     'node_feat': data.x,
                     'edge_feat': None,
                     'num_nodes': data.num_nodes}
    dataset.label = data.y
    return dataset

def load_ogb_dataset(data_dir, name):
    """Load an OGB node-property-prediction dataset (e.g. 'ogbn-arxiv') as an NCDataset.

    The OGB files live under ``<data_dir>/ogb``. The OGB graph dict is reused
    in place, with its 'edge_index' and 'node_feat' entries converted to
    tensors. The official split is stored on ``dataset.split`` as index
    tensors. Labels are reshaped to a column of shape (N, 1).
    """
    ogb_dataset = NodePropPredDataset(name=name, root=path.join(data_dir, 'ogb'))
    dataset = NCDataset(name)

    # Same dict object as ogb_dataset.graph; only the listed entries are converted.
    graph = ogb_dataset.graph
    for key in ('edge_index', 'node_feat'):
        graph[key] = torch.as_tensor(graph[key])
    dataset.graph = graph

    official_split = ogb_dataset.get_idx_split()
    dataset.split = {part: torch.tensor(official_split[part])
                     for part in ('train', 'valid', 'test')}

    dataset.label = torch.as_tensor(ogb_dataset.labels).reshape(-1, 1)
    return dataset