"""
DW4 (4-particle double-well system)

A dataset of sampled configurations of 4 particles in 2 dimensions that
interact through a pairwise double-well potential

    E(x) = sum_{i<j} [ b (d_ij - d0)^2 + c (d_ij - d0)^4 ],   b = -4, c = 0.9, d0 = 4,

where d_ij is the Euclidean distance between particles i and j (see
``DW4.potential_function``). Each sample stores only the particle
positions; no velocities, temperatures or other per-sample properties.

"""
import torch
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, InMemoryDataset

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

    `split`: (str) train, validation and test data split
        + fixed_2: 10^2 training samples, 1,000 validation and 1,000 test samples
        + fixed_3: 10^3 training samples, 1,000 validation and 1,000 test samples
        + fixed_4: 10^4 training samples, 1,000 validation and 1,000 test samples
        + fixed_5: 10^5 training samples, 1,000 validation and 1,000 test samples
    `batch_size`: (int) maximum number of graphs per mini-batch

"""

def dw4_dataloaders(
    split : str = 'fixed_2',
    batch_size : int = 128,
    root : str = '/root/workspace/data/dw4/',
):
    """
    Build the DW4 dataset, its fixed train/val/test splits and their loaders.

    Every graph is transformed with ``RemovePosMean`` -> ``Complete`` -> ``OnesForX``.

    Parameters
    ----------
    split: str
        One of ``'fixed_2'``, ``'fixed_3'``, ``'fixed_4'``, ``'fixed_5'``;
        selects a training set of 10^2, 10^3, 10^4 or 10^5 samples.
    batch_size: int
        Maximum number of graphs per mini-batch.
    root: str
        Dataset root directory passed to ``DW4``. The default is the path
        that was previously hard-coded.

    Returns
    -------
    tuple
        ``(dataset, train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)``. Only the training loader
        shuffles.

    Raises
    ------
    ValueError
        If ``split`` is not a recognised split name.
    """
    # Explicit check rather than ``assert`` so the validation is not
    # stripped when Python runs with ``-O``.
    if split not in ('fixed_2', 'fixed_3', 'fixed_4', 'fixed_5'):
        raise ValueError(f'Split not recognized: {split}')

    dataset = DW4(root=root,
        transform=T.Compose([RemovePosMean(), Complete(), OnesForX()])
    )

    train_dataset, val_dataset, test_dataset = fixed_splits(dataset, split)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return dataset, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader

"""
Data Splits
~~~~~~~~~~~~~~~

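    Split the DW4 dataset into training, validation and test sets
    (each sample is one 4-particle configuration):

    `fixed_splits`: This follows the splits of E(n) Equivariant Normalizing Flows (EN Flows)
     - valset <- the 1,000 samples immediately preceding the test set
     - testset <- the last 1,000 samples of the dataset
     - trainset <- the first 10^k of the remaining samples, with k given by `split`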
    
"""

def fixed_splits(dataset, split, num_val=1_000, num_test=1_000):
    """
    Split ``dataset`` into contiguous train / validation / test subsets.

    Layout (by index): ``[ train pool | validation | test ]``. The test set is
    the last ``num_test`` samples. The validation set is the ``num_val``
    samples before it. The training set is the first
    ``fixed_split_lookup[split]`` samples of the remaining pool, or the whole
    pool if it is smaller.

    Parameters
    ----------
    dataset:
        Any sliceable dataset supporting ``len`` (e.g. a PyG ``InMemoryDataset``).
    split: str
        Key of ``fixed_split_lookup`` giving the training-set size.
    num_val: int
        Number of validation samples (default 1,000).
    num_test: int
        Number of test samples (default 1,000).

    Returns
    -------
    tuple
        ``(dataset_train, dataset_val, dataset_test)``.

    Raises
    ------
    KeyError
        If ``split`` is not a key of ``fixed_split_lookup``.
    ValueError
        If the dataset holds fewer than ``num_val + num_test`` samples. A
        negative train-pool size would otherwise make the slices below wrap
        around and silently produce wrong splits.
    """
    size_train_set = fixed_split_lookup[split]

    num_total = len(dataset)
    num_train = num_total - num_val - num_test
    if num_train < 0:
        raise ValueError(
            f'Dataset has {num_total} samples, fewer than the '
            f'{num_val} validation + {num_test} test samples required.'
        )

    # Equivalent to dataset[:num_train][:size_train_set], done as one slice.
    dataset_train = dataset[:min(num_train, size_train_set)]
    dataset_val = dataset[num_train:num_total - num_test]
    dataset_test = dataset[num_total - num_test:]

    return dataset_train, dataset_val, dataset_test


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Dataset
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class DW4(InMemoryDataset):
    """
    In-memory PyG dataset of DW4 configurations.

    Each sample is a ``Data`` object whose ``pos`` attribute is a
    ``(num_nodes, pos_dim) = (4, 2)`` tensor of particle coordinates.
    """

    # Spatial dimension of each particle position.
    pos_dim = 2
    # Number of particles (graph nodes) per sample.
    num_nodes = 4

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): recent torch releases default torch.load to
        # weights_only=True, which may refuse pickled PyG objects. Confirm
        # against the installed torch version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['dw4_a0b-4c0.9d4.pt']

    @property
    def processed_file_names(self):
        return self.raw_file_names

    def process(self):
        """
        Read the raw position tensor, split it into one graph per sample,
        apply ``pre_filter`` / ``pre_transform`` and save the collated result.
        """
        # The raw file is read from ``root`` itself (not ``raw_dir``). Only the
        # first element of the stored pair is used.
        raw_pos, _ = torch.load(f'{self.root}/{self.raw_file_names[0]}')
        raw_pos = raw_pos.view(-1, self.num_nodes, self.pos_dim)  # one entry per graph
        data_list = [Data(pos=p) for p in raw_pos]

        if self.pre_filter is not None:
            # pre_filter is a predicate: keep only the graphs it accepts.
            # (Mapping it over the list would replace each graph with a bool.)
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    @classmethod
    def potential_function(cls, pos, edge_index):
        """
        Double-well pair energy for a batch of DW4 configurations.

        Each distance ``d`` listed in ``edge_index`` contributes
        ``b * (d - d0)**2 + c * (d - d0)**4`` with ``b = -4``, ``c = 0.9`` and
        ``d0 = 4``. The per-graph total is halved.

        Parameters
        ----------
        pos: torch.Tensor
            Positions whose first dimension is the batch size and which can be
            viewed as ``(batch_size * num_nodes, pos_dim)``. Presumably
            ``(batch_size, num_nodes, pos_dim)``; TODO confirm with callers.
        edge_index: torch.LongTensor
            Indices into the flattened positions, with exactly
            ``batch_size * num_nodes * (num_nodes - 1)`` edges grouped per
            graph. The halving assumes both directions of every pair are
            present with no self-loops, as ``Complete`` produces. Verify
            against callers.

        Returns
        -------
        torch.Tensor
            Energies of shape ``(batch_size,)``.
        """
        # Parameters a = 0 and tau = 1 do not appear in this expression.
        b = -4
        c = 0.9
        d0 = 4

        batch_size = pos.size(0)

        pos = pos.view(batch_size * cls.num_nodes, cls.pos_dim)
        node_i_idxs, node_j_idxs = edge_index
        pos_diffs = pos[node_i_idxs] - pos[node_j_idxs]
        dij = torch.sum(pos_diffs**2, dim=1, keepdim=True)**(1/2)
        dij = dij.view(batch_size, cls.num_nodes, cls.num_nodes - 1)
        dij_offset = dij - d0

        # Divide by 2 because each unordered pair appears once per direction.
        return (b * dij_offset**2 + c * dij_offset**4).view(batch_size, -1).sum(dim=1) / 2


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class RemovePosMean(T.BaseTransform):
    """Translate a graph so that the mean of its node positions is the origin."""

    def __call__(self, data):
        """
        Center the node positions of ``data``.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A single graph with a ``pos`` tensor. Presumably shaped
            ``(num_nodes, dim)``, since the mean is taken over dim 0.

        Returns
        -------
        torch_geometric.data.Data
            The same object, with ``pos`` replaced by its mean-subtracted version.
        """
        centroid = data.pos.mean(dim=0, keepdim=True)
        data.pos = data.pos - centroid
        return data


class OnesForX(T.BaseTransform):
    """Give every node the same constant scalar feature, 1."""

    def __call__(self, data):
        """
        Overwrite ``data.x`` with a column of ones.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same object, with ``x`` set to a ``(num_nodes, 1)`` tensor of ones.
        """
        node_count = data.num_nodes
        data.x = torch.ones((node_count, 1))
        return data


"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    """Replace a graph's connectivity with every directed edge between distinct nodes."""

    def __call__(self, data):
        """
        Make ``data`` fully connected (no self-loops).

        Existing ``edge_attr`` rows are carried over to their edges in the
        new edge list. Edges that were absent receive all-zero attributes.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same object, with ``edge_index`` (and ``edge_attr``, if present)
            rewritten for the complete directed graph.

        References
        ----------
        .. [1] PyTorch Geometric NN Conv Example for QM9.
           https://github.com/pyg-team/pytorch_geometric/blob/66b17806b1f4a2008e8be766064d9ef9a883ff03/examples/qm9_nn_conv.py
        """
        n = data.num_nodes

        # Enumerate all ordered pairs (i, j) source-major, self-loops included:
        # sources 0,0,..,0,1,1,.. and targets 0,1,..,n-1,0,1,..
        sources = torch.arange(n, dtype=torch.long).repeat_interleave(n)
        targets = torch.arange(n, dtype=torch.long).repeat(n)
        full_index = torch.stack([sources, targets], dim=0)

        full_attr = None
        if data.edge_attr is not None:
            # Pair (i, j) sits at flat position i * n + j in the list above.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            attr_shape = list(data.edge_attr.size())
            attr_shape[0] = n * n
            full_attr = data.edge_attr.new_zeros(attr_shape)
            full_attr[flat_pos] = data.edge_attr

        full_index, full_attr = torch_geometric.utils.remove_self_loops(full_index, full_attr)
        data.edge_attr = full_attr
        data.edge_index = full_index
        return data


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Look-Up Tables
#-----------------------------------------------------------------------------------------------------------------------------------------------------

# Training-set size for each fixed split name: 'fixed_k' -> 10**k samples.
fixed_split_lookup = {f'fixed_{exponent}': 10**exponent for exponent in range(2, 6)}

if __name__ == '__main__':
    # Smoke run: build the default split and print every returned object.
    for item in dw4_dataloaders():
        print(item)