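"""
Empty Sample Graphs

Build graph datasets from given node properties (features and/or positions)
without any file I/O, plus simple graph transforms: position centering,
constant node features, and full connectivity.

"""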
from typing import Optional

import torch
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.data import Data, InMemoryDataset

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

    `split`: (str) train, validation, and test data splits
        + fixed_2: 10^2 training nodes, 1,000 validation and 1,000 test nodes
        + fixed_3: 10^3 training nodes, 1,000 validation and 1,000 test nodes
        + fixed_4: 10^4 training nodes, 1,000 validation and 1,000 test nodes
        + fixed_5: 10^5 training nodes, 1,000 validation and 1,000 test nodes
    `batch_size`: (int) maximum number of graphs per batch

    Note: `sample_graphs` currently accepts `split` and `batch_size` but does
    not use them.

"""

def sample_graphs(
    x : Optional[list] = None,
    pos : Optional[list] = None,
    adjacency : str = 'full',
    split : str = 'fixed_2',
    batch_size : int = 128,
):
    """
    Wrap per-graph node features and/or positions into an in-memory dataset.

    Graph ``i`` becomes ``Data(x=x[i], pos=pos[i])``. All graphs are collated
    with ``InMemoryDataset.collate`` and held by a :class:`DatasetNoIO`, so
    nothing is read from or written to disk.

    Parameters
    ----------
    x: list, optional
        Per-graph node features (any iterable is accepted and materialized
        once). If omitted, every graph gets ``x=None``.
    pos: list, optional
        Per-graph node positions (any iterable is accepted and materialized
        once). If omitted, every graph gets ``pos=None``.
    adjacency: str
        ``'full'`` attaches a :class:`Complete` transform, so each graph is
        made fully connected (without self-loops) by the dataset's transform.
        Any other value attaches no adjacency transform.
    split: str
        Currently unused; kept for interface compatibility.
    batch_size: int
        Currently unused; kept for interface compatibility.

    Returns
    -------
    DatasetNoIO
        Dataset holding the collated graphs.

    Raises
    ------
    ValueError
        If both ``x`` and ``pos`` are ``None``, if they contain a different
        number of graphs, or if there are no graphs at all.
    """
    # Materialize once so generators are not consumed by the placeholder
    # construction below and so lengths can be compared.
    features = list(x) if x is not None else None
    positions = list(pos) if pos is not None else None

    if features is None and positions is None:
        raise ValueError("sample_graphs requires at least one of `x` or `pos`.")
    if features is None:
        features = [None] * len(positions)
    if positions is None:
        positions = [None] * len(features)
    if len(features) != len(positions):
        # zip() would silently drop the surplus graphs otherwise.
        raise ValueError(
            f"`x` and `pos` must describe the same number of graphs "
            f"(got {len(features)} and {len(positions)})."
        )
    if not features:
        raise ValueError("sample_graphs requires at least one graph.")

    data_list = [Data(x=feat, pos=p) for feat, p in zip(features, positions)]

    transform = [Complete()] if adjacency == 'full' else []

    data, slices = InMemoryDataset.collate(data_list)

    dataset = DatasetNoIO(
        f'{InMemoryDataset.__name__}: Latent Sample',
        data, slices,
        transform=T.Compose(transform)
    )

    return dataset

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Dataset
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class DatasetNoIO(InMemoryDataset):
    """
    In-memory dataset wrapping already-collated storage, with no root directory.

    The base class is constructed with ``root=None``. The given ``data`` and
    ``slices`` (for example, the pair returned by ``InMemoryDataset.collate``)
    are attached unchanged.

    Parameters
    ----------
    dataset_name: str
        Human-readable name, stored as ``self.dataset_name``.
    data:
        Collated graph storage.
    slices:
        Slice dictionary matching ``data``.
    transform: callable, optional
        Forwarded to the base-class constructor as ``transform``.
    """

    def __init__(self, dataset_name, data, slices, transform=None):
        super().__init__(None, transform)
        self.dataset_name = dataset_name
        # Attach the pre-collated storage directly; nothing is loaded from disk.
        self.data = data
        self.slices = slices

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------


"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class RemovePosMean(T.BaseTransform):
    """Translate node positions so that their centroid lies at the origin."""

    def __call__(self, data):
        """
        Center node positions by subtracting their per-coordinate mean.

        ``data.pos`` is replaced with a new tensor; the graph object itself is
        modified and returned.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph, with ``pos`` centered at zero.
        """
        # keepdim=True keeps a (1, D) row so it broadcasts over all nodes.
        centroid = data.pos.mean(dim=0, keepdim=True)
        data.pos = data.pos - centroid
        return data


class OnesForX(T.BaseTransform):
    """Give every node the same single constant feature (1.0)."""

    def __call__(self, data):
        """
        Overwrite node features with a column of ones.

        Any existing ``data.x`` is replaced. The graph object itself is
        modified and returned.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph, with ``x`` set to ``torch.ones(num_nodes, 1)``.
        """
        num_nodes = data.num_nodes
        data.x = torch.ones((num_nodes, 1))
        return data


"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    """Replace a graph's edges with every ordered node pair except self-loops."""

    def __call__(self, data):
        """
        Makes a graph fully-connected.

        Every ordered pair ``(i, j)`` with ``i != j`` becomes an edge. If the
        graph already has ``edge_attr``, each existing edge keeps its
        attribute row, and newly created edges get zero-filled attributes.
        The graph object itself is modified and returned.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            Updated graph with edge_index representing a fully-connected graph.

        References
        ----------
        .. [1] PyTorch Geometric NN Conv Example for QM9.
           https://github.com/pyg-team/pytorch_geometric/blob/66b17806b1f4a2008e8be766064d9ef9a883ff03/examples/qm9_nn_conv.py
        """
        n = data.num_nodes
        nodes = torch.arange(n, dtype=torch.long)

        # Enumerate all n*n ordered pairs in row-major order:
        # sources 0,0,...,1,1,...  and targets 0,1,...,n-1,0,1,...
        src = nodes.repeat_interleave(n)
        dst = nodes.repeat(n)
        dense_index = torch.stack((src, dst), dim=0)

        dense_attr = None
        if data.edge_attr is not None:
            # Scatter existing attributes into a zero-filled (n*n, ...) buffer,
            # addressed by the same row-major pair number used above.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            dense_attr = data.edge_attr.new_zeros((n * n, *data.edge_attr.shape[1:]))
            dense_attr[flat_pos] = data.edge_attr

        dense_index, dense_attr = torch_geometric.utils.remove_self_loops(dense_index, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = dense_index
        return data
