import torch
import torch_geometric.transforms

from tqdm import tqdm
from joblib import Parallel, delayed

# %%


class DataStructure:
    """Pre-process a list of graph datasets for SSL or a supervised task.

    For each dataset the constructor may merge PPI into a single graph, make
    Planetoid names distinct, concatenate Laplacian positional encodings to
    the node features and replace them with stacked multi-hop aggregates
    (via ``neighbourhood_aggregation_fn``). When ``ssl`` is False, it also
    attaches the node- or link-level splits required by ``task``. The list
    ``d`` is modified in place and kept as ``self.datasets``.
    """

    train_ratio = 0.7
    val_ratio = 0.
    test_ratio = 1 - (train_ratio + val_ratio)
    batch_size = 2 ** 12  # 4096, does not include negative sampling
    # Laplacian positional encodings are only computed for graphs with fewer
    # nodes than this (presumably because the eigendecomposition becomes too
    # expensive for larger graphs).
    pos_enc_max_nodes = 200_000

    def __init__(self, d, task=None, neighbourhood_aggregation=None, pos_enc_dim=15, batch_processing=False, ssl=True,
                 aggregation='nagphormer'):
        """
        Args:
            d: mutable list of datasets, each exposing a ``.data`` graph object.
                Entries may be replaced in place (PPI is merged into one graph).
            task: ``'node_classification'`` or ``'link_prediction'``; only
                used when ``ssl`` is False.
            neighbourhood_aggregation: number of hops to aggregate, or None to
                leave the node features untouched.
            pos_enc_dim: number of Laplacian eigenvectors concatenated to ``x``.
            batch_processing: accepted for future use; currently has no effect.
            ssl: if True, no task-specific split is created.
            aggregation: aggregation scheme forwarded to
                ``neighbourhood_aggregation_fn``.

        Raises:
            NotImplementedError: if ``ssl`` is False and ``task`` is unknown.
        """
        # TODO: batch processing for larger graphs
        self.datasets = d
        self.ssl = ssl
        self.task = task
        self.aggregation = aggregation

        if neighbourhood_aggregation is not None:
            print(f"constructing data structure with {len(d)} datasets and neighbourhood of {neighbourhood_aggregation}"
                  f" for {'ssl' if self.ssl else self.task}")
            self.num_hops = neighbourhood_aggregation
            self.pos_enc_dim = pos_enc_dim

        for i in tqdm(range(len(d))):
            dataset = self._normalise_dataset(i)
            data = dataset.data
            if neighbourhood_aggregation is not None:
                self._aggregate_features(dataset, data)
            if not hasattr(dataset, 'num_features'):
                dataset.num_features = data.x.shape[-1]
            if self.ssl:
                pass  # TODO: SSL-specific pre-processing (batched or full-graph)
            else:
                self._add_task_split(data)

    def _normalise_dataset(self, i):
        """Merge PPI into one graph and make dataset ``i``'s name unique; return it."""
        dataset = self.datasets[i]
        if not hasattr(dataset, 'name') and isinstance(dataset, torch_geometric.datasets.PPI):
            dataset = MergedDataset(dataset)
            dataset.name = 'ppi'
            self.datasets[i] = dataset
        if isinstance(dataset, torch_geometric.datasets.planetoid.Planetoid):
            dataset.name = dataset.name + '-planetoid'  # prevent overlapping names of datasets
        return dataset

    def _aggregate_features(self, dataset, data):
        """Concatenate Laplacian PEs to ``data.x`` (small graphs only), then
        replace ``data.x`` with the stacked multi-hop aggregates."""
        if data.x.shape[0] < self.pos_enc_max_nodes:
            transformed = torch_geometric.transforms.AddLaplacianEigenvectorPE(
                k=self.pos_enc_dim,
                attr_name=None,
            )(data)
            # Newer PyG transforms operate on a shallow copy and return it, so
            # the result must be written back (a no-op for in-place versions).
            data.x = transformed.x
        data.x = neighbourhood_aggregation_fn(
            dataset=dataset,
            num_hops=self.num_hops,
            aggregation=self.aggregation,
        )

    def _add_task_split(self, data):
        """Attach the splits required by ``self.task`` to ``data``.

        Raises:
            NotImplementedError: if ``self.task`` is not a supported task.
        """
        if self.task == 'node_classification':
            split = torch_geometric.transforms.RandomNodeSplit()(data)
            # Copy the masks back: newer PyG transforms do not mutate ``data``.
            data.train_mask = split.train_mask
            data.val_mask = split.val_mask
            data.test_mask = split.test_mask
        elif self.task == 'link_prediction':
            train_data, val_data, test_data = torch_geometric.transforms.RandomLinkSplit(
                num_val=0.1,
                num_test=0.2,
                is_undirected=torch_geometric.utils.is_undirected(data.edge_index),
                disjoint_train_ratio=0.3,
            )(data)
            data.train_data = train_data
            data.val_data = val_data
            data.test_data = test_data
            # TODO: batched loading, e.g. torch_geometric.loader.LinkNeighborLoader
        else:
            raise NotImplementedError(f"task type {self.task} not implemented")


def neighbourhood_aggregation_fn(dataset, num_hops, aggregation='nagphormer'):
    """Stack node features aggregated over 0..num_hops hops.

    Args:
        dataset: object exposing ``dataset.data.x`` (indexed ``x[node, feature]``)
            and ``dataset.data.edge_index`` (rows ``[sources, targets]``).
        num_hops: number of aggregation hops.
        aggregation: ``'nagphormer'`` - hop ``h+1`` of node ``i`` is the sum of
            hop ``h`` over the unique sources ``j != i`` of edges ``(j, i)``.
            ``'new'`` - the neighbour set itself expands every hop: it becomes
            the unique sources of edges pointing into the previous hop's set,
            minus ``i``.

    Returns:
        Tensor of shape ``(num_nodes, num_hops + 1, num_features)`` with the raw
        features at index 0, allocated with the default dtype.

    Raises:
        NotImplementedError: for an unknown ``aggregation``.
        IndexError: if ``edge_index`` refers to a node outside ``x``.
    """
    if aggregation not in ('nagphormer', 'new'):
        raise NotImplementedError(f'aggregation type \'{aggregation}\' not implemented.')

    x = dataset.data.x
    edge_index = dataset.data.edge_index
    node_features = torch.zeros((x.shape[0], num_hops + 1, x.shape[1]))
    node_features[:, 0, :] = x  # hop 0: the raw node features

    if aggregation == 'nagphormer':
        _aggregate_nagphormer(node_features, edge_index, num_hops)
    else:
        _aggregate_new(node_features, edge_index, num_hops)
    return node_features


def _aggregate_nagphormer(node_features, edge_index, num_hops):
    """Fill hops 1..num_hops in place by repeated sums over unique neighbours."""
    num_nodes = node_features.shape[0]
    if edge_index.numel() and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        raise IndexError('edge_index refers to a node outside [0, num_nodes)')
    sources, targets = edge_index
    not_self_loop = sources != targets  # a node is never its own neighbour
    if not bool(not_self_loop.any()):
        return  # no neighbours anywhere: hops 1.. stay zero
    # Deduplicate (target, source) pairs so repeated edges are counted once.
    pairs = torch.unique(
        torch.stack([targets[not_self_loop], sources[not_self_loop]]), dim=1
    ).long()
    adjacency = torch.sparse_coo_tensor(
        pairs,
        torch.ones(pairs.shape[1], dtype=node_features.dtype),
        (num_nodes, num_nodes),
    )
    # One sparse product per hop replaces a full edge scan for every node.
    for hop in range(num_hops):
        node_features[:, hop + 1, :] = torch.sparse.mm(adjacency, node_features[:, hop, :])


def _aggregate_new(node_features, edge_index, num_hops):
    """Fill hops 1..num_hops in place using per-node expanding frontiers."""
    num_nodes = node_features.shape[0]
    sources, targets = edge_index
    in_frontier = torch.empty(num_nodes, dtype=torch.bool)
    frontiers = [torch.tensor([i]) for i in range(num_nodes)]  # frontier of i starts as {i}
    for hop in range(num_hops):
        for i in range(num_nodes):
            in_frontier.fill_(False)
            in_frontier[frontiers[i]] = True
            edge_into_frontier = torch.index_select(in_frontier, 0, targets)
            neighbours = torch.unique(sources[edge_into_frontier])
            neighbours = neighbours[neighbours != i]  # remove self from neighbours
            frontiers[i] = neighbours
            node_features[i, hop + 1, :] = torch.sum(node_features[neighbours, hop, :], dim=0)

def aggregation_parallel(data, i):
    """Return node ``i`` together with its unique neighbours.

    The neighbours are the sources ``edge_index[0]`` of every edge whose target
    ``edge_index[1]`` equals ``i``, de-duplicated (sorted by ``torch.unique``)
    and with ``i`` itself removed.

    Returns:
        ``[[i], [neighbours]]`` where ``neighbours`` is a 1-D index tensor.
    """
    sources, targets = data.edge_index
    is_node_i = torch.zeros(data.x.shape[0], dtype=torch.bool)
    is_node_i[i] = True
    points_into_i = torch.index_select(is_node_i, 0, targets)
    candidates = sources[points_into_i].unique()
    return [[i], [candidates[candidates != i]]]

class MergedDataset:
    """Merge every graph of a multi-graph dataset (e.g. PPI) into one graph.

    Node features ``x`` and labels ``y`` are concatenated in dataset order.
    Each graph's ``edge_index`` is shifted by the number of nodes in the
    graphs before it, so the result is the disjoint union of the inputs.

    Attributes:
        data: the merged ``torch_geometric.data.Data`` object.
        num_nodes: number of rows of the merged ``x``.
        num_edges: number of columns of the merged ``edge_index``.
    """

    def __init__(self, d):
        # Materialise each graph once; indexing a PyG dataset may rebuild the
        # graph object on every access.
        graphs = [d[i] for i in range(len(d))]
        if not graphs:
            raise ValueError("cannot merge an empty dataset")

        # Build edge_index from the graphs themselves rather than sizing it
        # from the collated ``d.data``: for a sliced PyG dataset ``d.data``
        # still holds every graph, which would leave spurious 0 -> 0 edges.
        shifted_edges = []
        num_nodes_covered = 0
        for graph in graphs:
            shifted_edges.append(graph.edge_index + num_nodes_covered)
            num_nodes_covered += graph.num_nodes
        edge_index = torch.cat(shifted_edges, dim=1)
        x = torch.cat([graph.x for graph in graphs])
        y = torch.cat([graph.y for graph in graphs])

        self.data = torch_geometric.data.Data(x=x, edge_index=edge_index, y=y)
        self.num_nodes = x.shape[0]
        self.num_edges = edge_index.shape[1]

class Dummydatasets:
    """Wrap a single graph object in a minimal dataset-like interface.

    Exposes ``data`` (the wrapped object), ``name``, ``num_nodes`` (rows of
    ``d.x``), ``num_node_features`` (columns of ``d.x``) and ``num_classes``.
    """

    def __init__(self, d, name):
        features = d.x
        self.data = d
        self.name = name
        self.num_nodes = features.size(0)
        self.num_node_features = features.size(1)
        # NOTE(review): counts the distinct values present in ``y`` (not
        # max label + 1), so classes absent from ``y`` are not counted.
        self.num_classes = len(torch.unique(d.y))