"""
LJ13 (Lennard-Jones 13)

A collection of 13 small molecules, each composed of
up to 12 atoms, generated by Monte Carlo simulations.
The molecules are characterized by their Lennard-Jones
potential parameters, which describe the pairwise
interactions between the atoms in the molecule.

"""
import torch
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, InMemoryDataset

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

    `split`: (str) train validation and test data splits
        + fixed_1: 10^1 training nodes, 1,000 validation and 1,000 test nodes
        + fixed_2: 10^2 training nodes, 1,000 validation and 1,000 test nodes
        + fixed_3: 10^3 training nodes, 1,000 validation and 1,000 test nodes
        + fixed_4: 10^4 training nodes, 1,000 validation and 1,000 test nodes
    `batch_size`: (int) maximum batch size for graphs

"""

def lj13_dataloaders(
    split : str = 'fixed_1',
    batch_size : int = 128,
    root : str = '/root/workspace/data/lj13/',
):
    """
    Build the LJ13 dataset, its train/val/test subsets and one loader per subset.

    Parameters
    ----------
    split: str
        Name of the fixed split; one of ``fixed_1``, ``fixed_2``, ``fixed_3``, ``fixed_4``.
    batch_size: int
        Batch size handed to every ``DataLoader``.
    root: str
        Root directory handed to ``LJ13``. Defaults to the location that
        used to be hard-coded, so existing callers are unaffected.

    Returns
    -------
    tuple
        ``(dataset, train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)``.

    Raises
    ------
    AssertionError
        If ``split`` is not one of the supported names.
    """
    supported_splits = ('fixed_1', 'fixed_2', 'fixed_3', 'fixed_4')
    if split not in supported_splits:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type and message are unchanged.
        raise AssertionError(f'Split not recognized: {split}')

    dataset = LJ13(
        root=root,
        transform=T.Compose([RemovePosMean(), Complete(), OnesForX()]),
    )

    train_dataset, val_dataset, test_dataset = fixed_splits(dataset, split)

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return dataset, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader

"""
Data Splits
~~~~~~~~~~~~~~~

    Split the LJ13 dataset containing X molecules into
    training, validation and test sets:

    `fixed_splits`: This follows the splits of EN Flows
     - valset <- 1,000 molecules
     - testset <- 1,000 molecules
     - trainset <- specified
    
"""

def fixed_splits(dataset,split):
    """
    Cut the processed LJ13 files into train, validation and test subsets.

    Parameters
    ----------
    dataset:
        Object exposing ``processed_file_names`` and a ``get(file_name)``
        method that returns a sliceable dataset.
    split: str
        Key into ``fixed_split_lookup`` giving the training-set size.

    Returns
    -------
    tuple
        ``(train, val, test)``: the first ``size`` entries of the first
        processed file, entries 1000-1999 of the second file, and
        entries 0-999 of the second file.
    """
    # Resolve the size first so an unknown split fails before any file is loaded.
    num_train = fixed_split_lookup[split]

    train_pool = dataset.get(dataset.processed_file_names[0])
    held_out_pool = dataset.get(dataset.processed_file_names[1])

    return (
        train_pool[:num_train],
        held_out_pool[1_000:2_000],
        held_out_pool[:1_000],
    )


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Dataset
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class DatasetNoIO(InMemoryDataset):
    """
    In-memory dataset built directly from an already collated
    ``(data, slices)`` pair; it is constructed with no root directory.
    """

    def __init__(self, dataset_name, data, slices, transform=None):
        super().__init__(root=None, transform=transform)
        self.dataset_name = dataset_name
        self.data = data
        self.slices = slices


class LJ13(InMemoryDataset):
    """
    LJ13 dataset: every sample is a graph of ``num_nodes`` nodes whose
    ``pos`` attribute holds ``pos_dim`` coordinates per node.

    The raw tensors are read from ``root`` and written, collated, to
    ``processed_dir`` under the same file names.
    """
    pos_dim = 3
    num_nodes = 13

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['lj13dim3num_particles13_train.pt', 'lj13dim3num_particles13_val_test.pt']

    @property
    def processed_file_names(self):
        # Processed files reuse the raw file names (they live in `processed_dir`).
        return self.raw_file_names

    def process(self):
        """Convert each raw tensor file into a collated ``(data, slices)`` file."""
        for rfn, pfn in zip(self.raw_file_names, self.processed_file_names):
            data = torch.load(f'{self.root}/{rfn}')
            if isinstance(data, tuple):
                data = data[-1]  # train partition also stores idx
            data = data.view(-1, self.num_nodes, self.pos_dim)  # reshape to array of graphs
            data_list = [Data(pos=p) for p in data]

            if self.pre_filter is not None:
                # `pre_filter` is a predicate: keep only the graphs it accepts.
                # (Mapping it over the list would replace each graph by a bool.)
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            data, slices = self.collate(data_list)
            torch.save((data, slices), f'{self.processed_dir}/{pfn}')

    def get(self, processed_file_name):
        """
        Load one processed file and wrap it in a ``DatasetNoIO``.

        Parameters
        ----------
        processed_file_name: str
            File name inside ``processed_dir`` (see ``processed_file_names``).

        Returns
        -------
        DatasetNoIO
            Dataset over the graphs stored in that file, carrying this
            dataset's ``transform`` so it is applied when items are accessed.
        """
        # NOTE(review): this takes a file name rather than an integer index,
        # so integer indexing of an LJ13 instance presumably does not work - confirm.
        data, slices = torch.load(f'{self.processed_dir}/{processed_file_name}')
        # Forward the transform: otherwise the one given to LJ13 is never applied.
        return DatasetNoIO(self.__class__.__name__, data, slices, transform=self.transform)

    @classmethod
    def potential_function(cls, pos, edge_index):
        """
        Evaluate the summed pair potential of a batch of configurations.

        Parameters
        ----------
        pos: torch.Tensor
            Positions; the first dimension is the batch and the tensor is
            viewed as ``(batch_size * num_nodes, pos_dim)``.
        edge_index: torch.Tensor
            Pair of index tensors into the flattened node list. The result is
            viewed as ``(batch_size, num_nodes, num_nodes - 1)``, so this
            assumes every ordered pair without self-loops is present,
            grouped graph by graph - TODO confirm with the caller.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch_size,)`` holding the sum of
            ``(r_m / d)**12 - 2 * (r_m / d)**6`` over all given edges.
        """
        # eps = 1
        r_m = 1  # this pair term is minimal at distance d == r_m
        # tau = 0.5

        batch_size = pos.size(0)

        pos = pos.view(batch_size * cls.num_nodes, cls.pos_dim)
        node_i_idxs, node_j_idxs = edge_index
        pos_diffs = pos[node_i_idxs] - pos[node_j_idxs]
        # Euclidean length of every edge vector.
        dij = torch.sum(pos_diffs**2, dim=1, keepdim=True)**(1/2)
        dij = dij.view(batch_size, cls.num_nodes, cls.num_nodes - 1)

        return ((r_m / dij)**12 - 2 * (r_m / dij)**6).view(batch_size, -1).sum(dim=1)

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class RemovePosMean(T.BaseTransform):
    def __call__(self, data):
        """
        Centers the node positions of a graph at the origin by
        subtracting their mean over the nodes.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same object, with ``pos`` replaced by its centered version.
        """
        centroid = data.pos.mean(dim=0, keepdim=True)
        data.pos = data.pos - centroid
        return data


class OnesForX(T.BaseTransform):
    def __call__(self, data):
        """
        Gives every node the constant feature 1.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same object, with ``x`` set to ``torch.ones(num_nodes, 1)``.
        """
        node_count = data.num_nodes
        data.x = torch.ones((node_count, 1))
        return data


"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    def __call__(self, data):
        """
        Replaces the edges of a graph by those of the complete graph
        (every ordered pair of distinct nodes).

        Existing edge attributes are carried over to the matching new
        edges; edges that did not exist before get zero attributes.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same object, with ``edge_index`` (and ``edge_attr``)
            describing a fully-connected graph without self-loops.

        References
        ----------
        .. [1] PyTorch Geometric NN Conv Example for QM9.
           https://github.com/pyg-team/pytorch_geometric/blob/66b17806b1f4a2008e8be766064d9ef9a883ff03/examples/qm9_nn_conv.py
        """
        n = data.num_nodes
        node_ids = torch.arange(n, dtype=torch.long)

        # All n * n ordered pairs: source varies slowest, target fastest.
        src = node_ids.repeat_interleave(n)
        dst = node_ids.repeat(n)
        dense_index = torch.stack((src, dst), dim=0)

        dense_attr = None
        sparse_attr = data.edge_attr
        if sparse_attr is not None:
            # Position of each existing edge (i, j) in the dense n * n layout.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            dense_attr = sparse_attr.new_zeros((n * n, *sparse_attr.size()[1:]))
            dense_attr[flat_pos] = sparse_attr

        dense_index, dense_attr = torch_geometric.utils.remove_self_loops(dense_index, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = dense_index

        return data


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Look-Up Tables
#-----------------------------------------------------------------------------------------------------------------------------------------------------

# Training-set size for each named split: 'fixed_k' -> 10**k, for k = 1..5.
# NOTE: 'fixed_5' is listed here, but `lj13_dataloaders` only accepts
# 'fixed_1' through 'fixed_4'.
fixed_split_lookup = {
    f'fixed_{exponent}': 10**exponent
    for exponent in range(1, 6)
}

if __name__=='__main__':
    # Smoke test: build everything with the default arguments and print each returned object.
    for item in lj13_dataloaders():
        print(item)