from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred import NodePropPredDataset
import torch_geometric.transforms as T
import torch
import numpy as np
import pickle


class NCDataset(object):
    """Container for a single-graph node-classification dataset.

    Attributes:
        name: Identifier passed to the constructor and stored unchanged.
        graph: Dict describing the graph. It starts empty and is filled in
            by whoever builds the dataset.
        label: Node labels; ``None`` until assigned by the builder.
    """

    def __init__(self, name):
        self.name = name
        self.graph = {}
        self.label = None

    def __getitem__(self, idx):
        """Return ``(graph, label)`` for the only valid index, ``0``.

        Raises:
            IndexError: If ``idx`` is anything other than ``0``.
        """
        # Raise IndexError rather than using ``assert``: asserts are stripped
        # under ``python -O``, and IndexError is what the sequence protocol
        # relies on to terminate ``for item in dataset`` / ``list(dataset)``.
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self.graph, self.label

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


# Set to a truthy value to rebuild the split pickles from the raw
# ogbn-products download; leave at 0 to load the pre-built pickles.
debug = 0


def _edges_within(edge_index, node_ids):
    """Return edges whose two endpoints both fall inside ``node_ids``.

    ``node_ids`` is treated as a contiguous ascending id range: only its first
    and last entries are used as inclusive bounds. The surviving edges keep
    their original order and are shifted so the first node id maps to 0.
    """
    lo, hi = node_ids[0], node_ids[-1]
    src, dst = edge_index[0], edge_index[1]
    inside = (src >= lo) & (src <= hi) & (dst >= lo) & (dst <= hi)
    return edge_index[:, inside] - lo


def _build_split(node_feat, labels, edge_index, node_ids):
    """Build an ``NCDataset`` for the sub-graph induced by ``node_ids``."""
    split = NCDataset(-1)
    feats = torch.tensor(node_feat[node_ids, :], dtype=torch.float)
    edges = torch.tensor(_edges_within(edge_index, node_ids), dtype=torch.long)
    split.graph = {'edge_index': edges,
                   'edge_feat': None,
                   'node_feat': feats,
                   'num_nodes': feats.shape[0]}
    split.label = torch.tensor(labels[node_ids])
    if split.label.dim() == 1:
        # Store labels as a column so label.shape[1] below is defined.
        split.label = split.label.unsqueeze(1)
    split.n = split.graph['num_nodes']
    split.c = max(split.label.max().item() + 1, split.label.shape[1])
    split.d = split.graph['node_feat'].shape[1]
    return split


def _report_and_save(split, title, path):
    """Print a summary of ``split`` and pickle it to ``path``."""
    print(split.graph['edge_index'].shape)
    print(f"{title} num nodes {split.n} | num classes {split.c} | num node feats {split.d}")
    with open(path, 'wb') as f:
        pickle.dump(split, f)
    print(f"Save {path} Successful.")


def _load_split(path, num_classes):
    """Unpickle a split written by ``_report_and_save`` and set its ``c``.

    SECURITY: unpickling can execute arbitrary code, so ``path`` must only
    ever point at a trusted, locally generated file.
    """
    with open(path, 'rb') as f:
        split = pickle.load(f)
    split.c = num_classes
    return split


if debug:
    dataset = NodePropPredDataset(name='ogbn-products', root="../data")
    # The splits are the contiguous node-id ranges below; the split shipped
    # with the dataset (dataset.get_idx_split()) is not used.
    train_idx = np.arange(200000, 510000)    # 310000
    valid_idx = np.arange(580000, 650000)   # 70000
    test_idx = np.arange(510000, 580000)    # 70000
    node_feat = dataset.graph['node_feat']
    edge_index = dataset.graph['edge_index']    # (2, 123718280)
    labels = dataset.labels

    # Keep only edges with source id < target id. This is done with a boolean
    # mask instead of a Python loop, which preserves edge order and is far
    # faster on an edge list of this size.
    edge_index = edge_index[:, edge_index[0] < edge_index[1]]  # (2, 61859012)

    train_dataset = _build_split(node_feat, labels, edge_index, train_idx)
    _report_and_save(train_dataset, "Train", "train_dataset.pkl")

    valid_dataset = _build_split(node_feat, labels, edge_index, valid_idx)
    _report_and_save(valid_dataset, "Valid", "valid_dataset.pkl")

    test_dataset = _build_split(node_feat, labels, edge_index, test_idx)
    _report_and_save(test_dataset, "Test", "test_dataset.pkl")

else:
    # NOTE: the debug branch writes its pickles to the current working
    # directory, whereas they are read from GraphOOD/sales_products/ here.
    # The class count is forced to 47 for every split instead of the
    # per-split value stored in the pickle -- presumably so that all splits
    # share one label space; confirm against the training code.
    dataset_train = _load_split("GraphOOD/sales_products/train_dataset.pkl", 47)
    dataset_valid = _load_split("GraphOOD/sales_products/valid_dataset.pkl", 47)
    dataset_test = _load_split("GraphOOD/sales_products/test_dataset.pkl", 47)
    datasets_train = [dataset_train]
    datasets_valid = [dataset_valid]
    datasets_test = [dataset_test]
