# Copyright (c) Microsoft Corporation.
# The file is modified based on the original Graphormer's source code.
# Copyright (c) 2022 Tianyu Wen
# Licensed under the MIT License.


import copy

import numpy as np
import pyximport
import torch
import torch_geometric.datasets
from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets import MoleculeNet, QM9, TUDataset

pyximport.install(setup_args={'include_dirs': np.get_include()})
import algos


def convert_to_single_emb(x, offset=512):
    """Shift every feature column of ``x`` into its own id range of one shared table.

    Column ``k`` is shifted by ``1 + k * offset``.  For values in ``[0, offset)``
    the ids of different columns therefore never overlap and id 0 stays unused.

    Args:
        x: ``[N]`` or ``[N, F]`` tensor of feature values; a 1-D tensor is
            treated as a single column.
        offset: width of the id range reserved for each column.

    Returns:
        ``x + offsets`` (broadcast over the last dimension).
    """
    feature_num = x.size(1) if x.dim() > 1 else 1
    # Create the offsets on x's device so CUDA inputs do not hit a device mismatch.
    feature_offset = 1 + torch.arange(
        0, feature_num * offset, offset, dtype=torch.long, device=x.device
    )
    return x + feature_offset


def preprocess_item(item):
    """Attach Graphormer-style structural encodings to a graph ``item`` and return it.

    Reads ``item.x``, ``item.edge_index`` and ``item.edge_attr``.  It then sets
    these attributes on ``item`` (the object is mutated in place, so pass a copy
    if the original must stay untouched):

    * ``x``: node features passed through ``convert_to_single_emb``.
    * ``adj``: ``[N, N]`` bool adjacency, ``adj[src, dst] = True`` per edge.
    * ``attn_bias``: ``[N + 1, N + 1]`` float zeros (extra row/col for the graph token).
    * ``attn_edge_type``: ``[N, N, F_e]`` long edge-feature ids; node pairs
      without an edge stay 0.
    * ``spatial_pos``, ``V_pos``, ``E_pos``: outputs of the Cython
      ``algos.SPIS`` routine as long tensors (``spatial_pos`` presumably holds
      shortest-path lengths -- confirm against ``algos``).
    * ``edge_input``: output of ``algos.gen_edge_input`` as a long tensor.
    * ``in_degree`` / ``out_degree``: ``[N]`` long counts of incoming /
      outgoing edges per node.

    Graphs without edge attributes (``edge_attr is None``) get one constant
    zero feature per edge.
    """
    edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x
    N = x.size(0)
    x = convert_to_single_emb(x)

    # node adj matrix [N, N] bool; adj[i, j] is True for an edge i -> j
    adj = torch.zeros([N, N], dtype=torch.bool)
    adj[edge_index[0, :], edge_index[1, :]] = True

    # edge feature here
    if edge_attr is None:
        # No edge features: give every edge the same single (zero) feature
        # instead of failing on ``None.size()``.
        edge_attr = torch.zeros(edge_index.size(1), dtype=torch.long)
    if edge_attr.dim() == 1:
        edge_attr = edge_attr[:, None]
    attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long)
    # The extra +1 keeps real edge ids away from the 0 used for "no edge"
    # (assuming non-negative edge feature values).
    attn_edge_type[edge_index[0, :], edge_index[1, :]] = (convert_to_single_emb(edge_attr) + 1).long()

    # structural encoding SPIS here
    shortest_path_result, path, V, E = algos.SPIS(adj.numpy(), np.eye(N, dtype=np.int64),
                                                  np.zeros((N, N), dtype=np.int64),
                                                  np.zeros(N, dtype=np.int64))
    # NOTE(review): np.amax raises ValueError on an empty array, so graphs with
    # N == 0 presumably fail here -- confirm what algos.SPIS returns for N == 0.
    max_dist = np.amax(shortest_path_result)
    edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
    spatial_pos = torch.from_numpy(shortest_path_result).long()
    V_pos = torch.from_numpy(V).long()
    E_pos = torch.from_numpy(E).long()
    attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float)  # with graph token

    # combine
    item.x = x
    item.adj = adj
    item.attn_bias = attn_bias
    item.attn_edge_type = attn_edge_type
    item.spatial_pos = spatial_pos
    item.V_pos = V_pos
    item.E_pos = E_pos
    adj_long = adj.long()
    # adj[src, dst]: column sums count edges arriving at a node, row sums count
    # edges leaving it.  Both are identical for undirected (symmetric) graphs.
    item.in_degree = adj_long.sum(dim=0).view(-1)
    item.out_degree = adj_long.sum(dim=1).view(-1)
    item.edge_input = torch.from_numpy(edge_input).long()

    return item


class MyMoleculeNetDataset(MoleculeNet):
    """``MoleculeNet`` whose scalar indexing returns preprocessed graphs.

    Integer-like indices (``int``, NumPy integer scalars, 0-dim tensors) yield a
    single graph passed through ``preprocess_item``, with ``idx`` recorded on it.
    Any other index (slice, list, tensor, array) is forwarded to
    ``index_select`` without preprocessing.
    """

    # ``download``/``process`` only delegate to the parent.  They are kept as
    # explicit overrides -- presumably because older torch_geometric releases
    # run these steps only when the name is in ``self.__class__.__dict__``.
    # Verify before removing.
    def download(self):
        super(MyMoleculeNetDataset, self).download()

    def process(self):
        super(MyMoleculeNetDataset, self).process()

    def __getitem__(self, idx):
        # ``index_select`` never returns a single preprocessed graph, so NumPy
        # integer scalars and 0-dim tensors are normalised to ``int`` first.
        if isinstance(idx, (int, np.integer)) or (torch.is_tensor(idx) and idx.dim() == 0):
            idx = int(idx)
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)


class MyZINCDataset(torch_geometric.datasets.ZINC):
    """``ZINC`` whose scalar indexing returns preprocessed graphs.

    Integer-like indices (``int``, NumPy integer scalars, 0-dim tensors) yield a
    single graph passed through ``preprocess_item``, with ``idx`` recorded on it.
    Any other index (slice, list, tensor, array) is forwarded to
    ``index_select`` without preprocessing.
    """

    # ``download``/``process`` only delegate to the parent.  They are kept as
    # explicit overrides -- presumably because older torch_geometric releases
    # run these steps only when the name is in ``self.__class__.__dict__``.
    # Verify before removing.
    def download(self):
        super(MyZINCDataset, self).download()

    def process(self):
        super(MyZINCDataset, self).process()

    def __getitem__(self, idx):
        # ``index_select`` never returns a single preprocessed graph, so NumPy
        # integer scalars and 0-dim tensors are normalised to ``int`` first.
        if isinstance(idx, (int, np.integer)) or (torch.is_tensor(idx) and idx.dim() == 0):
            idx = int(idx)
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)


class MyQM9Dataset(QM9):
    """``QM9`` whose scalar indexing returns preprocessed graphs.

    Integer-like indices (``int``, NumPy integer scalars, 0-dim tensors) yield a
    single graph passed through ``preprocess_item``, with ``idx`` recorded on it.
    Any other index (slice, list, tensor, array) is forwarded to
    ``index_select`` without preprocessing.
    """

    # ``download``/``process`` only delegate to the parent.  They are kept as
    # explicit overrides -- presumably because older torch_geometric releases
    # run these steps only when the name is in ``self.__class__.__dict__``.
    # Verify before removing.
    def download(self):
        super(MyQM9Dataset, self).download()

    def process(self):
        super(MyQM9Dataset, self).process()

    def __getitem__(self, idx):
        # ``index_select`` never returns a single preprocessed graph, so NumPy
        # integer scalars and 0-dim tensors are normalised to ``int`` first.
        if isinstance(idx, (int, np.integer)) or (torch.is_tensor(idx) and idx.dim() == 0):
            idx = int(idx)
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)


class MyTUDataset(TUDataset):
    """``TUDataset`` whose scalar indexing returns preprocessed graphs.

    Integer-like indices (``int``, NumPy integer scalars, 0-dim tensors) yield a
    single graph passed through ``preprocess_item``, with ``idx`` recorded on it.
    Any other index (slice, list, tensor, array) is forwarded to
    ``index_select`` without preprocessing.
    """

    # ``download``/``process`` only delegate to the parent.  They are kept as
    # explicit overrides -- presumably because older torch_geometric releases
    # run these steps only when the name is in ``self.__class__.__dict__``.
    # Verify before removing.
    def download(self):
        super(MyTUDataset, self).download()

    def process(self):
        super(MyTUDataset, self).process()

    def __getitem__(self, idx):
        # ``index_select`` never returns a single preprocessed graph, so NumPy
        # integer scalars and 0-dim tensors are normalised to ``int`` first.
        if isinstance(idx, (int, np.integer)) or (torch.is_tensor(idx) and idx.dim() == 0):
            idx = int(idx)
            item = self.get(self.indices()[idx])
            item.idx = idx
            return preprocess_item(item)
        else:
            return self.index_select(idx)


class MyQM8Dataset(InMemoryDataset):
    """QM8 graphs loaded from a pre-processed file written with ``torch.save``.

    Args:
        path: file handed to ``torch.load``.  The default is the original
            hard-coded location, resolved relative to the current working
            directory.
    """

    def __init__(self, path='../../our_TF/processed_qm8'):
        super().__init__()
        # Presumably an indexable collection of graph objects -- TODO confirm.
        # NOTE: torch.load unpickles the file; only load trusted data.
        self.dataset = torch.load(path)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one preprocessed graph for a scalar index, else a list of them.

        Scalar indices are ``int``, NumPy integer scalars, or 0-dim tensors.
        Slices and any other iterable of indices (list, 1-D tensor, array)
        return a list with one preprocessed graph per index.
        """
        if isinstance(idx, (int, np.integer)) or (torch.is_tensor(idx) and idx.dim() == 0):
            idx = int(idx)
            # preprocess_item reassigns attributes such as ``x`` (shifted by
            # convert_to_single_emb) on the object it receives.  Work on a
            # shallow copy so repeated access does not re-apply the shift to
            # the cached graph.
            item = copy.copy(self.dataset[idx])
            item.idx = idx
            return preprocess_item(item)
        if isinstance(idx, slice):
            idx = range(*idx.indices(len(self)))
        return [self.__getitem__(int(i)) for i in idx]
