"""
DW4 (Double Well, 4 particles)

A collection of configurations of a system of 4 particles in
2 dimensions that interact through a pairwise double-well
potential. Each sample stores only the positions of the
particles; the potential energy of a batch of configurations
can be evaluated with `DW4.potential_function`.

"""
import os

import torch
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, InMemoryDataset

import torch.distributed as dist

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

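    `featurization`: (str) node-feature transform, one of the keys of `t_feats`
    `adjacency`: (str) 'complete' makes every graph fully-connected, 'bond' leaves the edges unchanged
    `split`: (str) train, validation and test data splits
        + fixed_2: 10^2 training samples, 1,000 validation and 1,000 test samples
        + fixed_3: 10^3 training samples, 1,000 validation and 1,000 test samples
        + fixed_4: 10^4 training samples, 1,000 validation and 1,000 test samples
        + fixed_5: 10^5 training samples, 1,000 validation and 1,000 test samples
    `batch_size`: (int) maximum batch size for graphs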

"""

def dw4_dataloaders(
    featurization : str = 'one-hot',
    adjacency : str = 'complete',
    split : str = 'fixed_2',
    batch_size : int = 128,
    root : str = '/root/workspace/data/dw4/',
):
    """
    Build the DW4 dataset, its fixed splits and the matching data loaders.

    Parameters
    ----------
    featurization: str
        Key of ``t_feats`` selecting the node-feature transform applied last.
    adjacency: str
        ``'complete'`` inserts the ``Complete`` transform (fully-connected
        graph); ``'bond'`` leaves the edges of each sample unchanged.
    split: str
        One of ``'fixed_2'``, ``'fixed_3'``, ``'fixed_4'``, ``'fixed_5'``;
        forwarded to ``fixed_splits``.
    batch_size: int
        Batch size forwarded to ``create_dataloaders``.
    root: str
        Directory handed to ``DW4`` as the dataset root. Defaults to the
        location that used to be hard-coded here.

    Returns
    -------
    tuple
        ``(dataset, train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)``.

    Raises
    ------
    AssertionError
        If ``split`` or ``adjacency`` is not one of the supported values.
    KeyError
        If ``featurization`` is not a key of ``t_feats``.
    """
    assert split in ('fixed_2', 'fixed_3', 'fixed_4', 'fixed_5'), f'Split not recognized: {split}'
    assert adjacency in ('complete', 'bond'), f'Adjacency not recognized: {adjacency}'

    # Both pipelines are identical except for the `Complete` step.
    if adjacency == 'complete':
        transform = T.Compose([AugmentPosition(), RemovePosMean(), Complete(), OnesForX(), t_feats[featurization]])
    elif adjacency == 'bond':
        transform = T.Compose([AugmentPosition(), RemovePosMean(), OnesForX(), t_feats[featurization]])

    dataset = DW4(root=root, transform=transform)

    train_dataset, val_dataset, test_dataset = fixed_splits(dataset, split)
    train_loader, val_loader, test_loader = create_dataloaders(
        train_dataset, val_dataset, test_dataset, batch_size
    )

    return dataset, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader


def create_dataloaders(trainset, valset, testset, batch_size):
    """
    Wrap the three dataset splits in ``DataLoader`` objects.

    If ``torch.distributed`` has been initialised, every split gets its own
    ``DistributedSampler`` and only the training loader is additionally
    configured with worker processes and pinned memory. Otherwise each
    split gets a plain loader with ``shuffle=True``.

    Parameters
    ----------
    trainset, valset, testset:
        The datasets to iterate over.
    batch_size: int
        Batch size used for all three loaders.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)``.
    """
    splits = (trainset, valset, testset)

    if not dist.is_initialized():
        # Single-process run. NOTE(review): validation and test loaders are
        # shuffled as well - confirm that this is intended.
        return tuple(
            DataLoader(subset, batch_size=batch_size, shuffle=True)
            for subset in splits
        )

    # Distributed run: one sampler per split, created before any loader.
    sampler_cls = torch.utils.data.distributed.DistributedSampler
    train_sampler, val_sampler, test_sampler = [sampler_cls(subset) for subset in splits]

    # Ordering is delegated to the samplers, hence `shuffle=False` on the loaders.
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=False,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        persistent_workers=False,
    )
    eval_loaders = [
        DataLoader(subset, batch_size=batch_size, shuffle=False, sampler=sampler)
        for subset, sampler in ((valset, val_sampler), (testset, test_sampler))
    ]

    return train_loader, eval_loaders[0], eval_loaders[1]

"""
Data Splits
~~~~~~~~~~~~~~~

    Split the DW4 dataset containing X samples into
    training, validation and test sets:

    `fixed_splits`: This follows the splits of EN Flows
     - valset <- the 1,000 samples directly before the test samples
     - testset <- the last 1,000 samples
     - trainset <- the first 10^k samples, as selected by `split`
    
"""

def fixed_splits(dataset,split):
    """
    Cut ``dataset`` into fixed training, validation and test subsets.

    The last 1,000 samples form the test set, the 1,000 samples before them
    the validation set. The training set is taken from the front of the
    remaining samples and holds at most ``fixed_split_lookup[split]`` of them
    (fewer if the dataset does not contain that many).

    Parameters
    ----------
    dataset:
        Any sized object that supports slicing.
    split: str
        Key of ``fixed_split_lookup`` giving the requested training-set size.

    Returns
    -------
    tuple
        ``(dataset_train, dataset_val, dataset_test)``.

    Raises
    ------
    ValueError
        If the dataset is too small to hold the validation and test sets.
    KeyError
        If ``split`` is not a key of ``fixed_split_lookup``.
    """
    num_val = 1_000
    num_test = 1_000
    total = len(dataset)
    num_train = total - num_val - num_test

    # A negative bound would be read as "count from the end" by the slices
    # below and silently yield overlapping or empty subsets.
    if num_train < 0:
        raise ValueError(
            f'Dataset holds {total} samples but at least {num_val + num_test} '
            'are required for the validation and test sets'
        )

    size_train_set = fixed_split_lookup[split]

    dataset_train = dataset[:min(num_train, size_train_set)]
    dataset_val = dataset[num_train:total - num_test]
    dataset_test = dataset[total - num_test:]

    return dataset_train, dataset_val, dataset_test


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Dataset
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class DW4(InMemoryDataset):
    """
    In-memory dataset of DW4 configurations.

    Every sample is a ``Data`` object whose only stored attribute is ``pos``,
    a ``(num_nodes, pos_dim)`` tensor of positions.
    """
    # Number of coordinates stored per node.
    pos_dim = 2
    # Number of nodes in every sample.
    num_nodes = 4

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """
        Set up the dataset and load the processed tensors.

        Parameters
        ----------
        root: str
            Dataset root directory; the raw file is expected directly inside it.
        transform, pre_transform, pre_filter: callable, optional
            Forwarded unchanged to ``InMemoryDataset``.
        """
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE: torch.load unpickles the file - only load files from a trusted source.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Name of the file that holds the raw samples."""
        return ['dw4_a0b-4c0.9d4.pt']

    @property
    def processed_file_names(self):
        """Name of the processed file; identical to the raw file name."""
        return self.raw_file_names

    def process(self):
        """
        Convert the raw tensor into a collated list of graphs and save it.

        The raw file is read from ``{root}/{raw_file_name}`` and must contain a
        2-tuple whose first entry is the position tensor. ``pre_filter`` (a
        predicate) and ``pre_transform`` are applied, in this order, when given.
        """
        # NOTE: torch.load unpickles the file - only load files from a trusted source.
        raw, _ = torch.load(f'{self.root}/{self.raw_file_names[0]}')
        raw = raw.view(-1, self.num_nodes, self.pos_dim)  # reshape to array of graphs
        data_list = [Data(pos=p) for p in raw]

        if self.pre_filter is not None:
            # Keep the samples accepted by the predicate. Mapping the predicate
            # over the list would replace every sample by its boolean verdict.
            data_list = [sample for sample in data_list if self.pre_filter(sample)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(sample) for sample in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    @classmethod
    def potential_function(cls, pos, edge_index):
        """
        Evaluate the pairwise double-well potential of a batch of samples.

        For every edge with length ``d`` the contribution is
        ``b * (d - d0)**2 + c * (d - d0)**4``.

        Parameters
        ----------
        pos: torch.Tensor
            Positions; the first dimension is the batch and the tensor is
            viewed as ``(batch_size * num_nodes, pos_dim)``.
        edge_index:
            Pair ``(node_i, node_j)`` of index tensors into the flattened node
            axis. It has to provide ``num_nodes - 1`` edges per node, because
            the distances are viewed as
            ``(batch_size, num_nodes, num_nodes - 1)``.

        Returns
        -------
        torch.Tensor
            One energy value per batch entry, shape ``(batch_size,)``.
        """
        # NOTE(review): the transforms in this file append a third column to
        # `pos`, whereas this method views `pos` with `pos_dim` (= 2) columns -
        # confirm which layout callers pass in.
        # a = 0
        b = -4
        c = 0.9
        d0 = 4
        # tau = 1

        batch_size = pos.size(0)

        pos = pos.view(batch_size * cls.num_nodes, cls.pos_dim)
        node_i_idxs, node_j_idxs = edge_index
        pos_diffs = pos[node_i_idxs] - pos[node_j_idxs]
        dij = torch.sum(pos_diffs**2, dim=1, keepdim=True)**(1/2)
        dij = dij.view(batch_size, cls.num_nodes, cls.num_nodes - 1)
        dij_offset = dij - d0

        # Halved presumably because every pair appears once per direction - confirm.
        return (b * dij_offset**2 + c * dij_offset**4).view(batch_size, -1).sum(dim=1) / 2

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class RemovePosMean(T.BaseTransform):
    """Transform that centres the node positions of a graph at the origin."""

    def __call__(self, data):
        """
        Shift all node positions so that their mean becomes zero.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph with ``pos`` replaced by the centred positions.
        """
        # Mean over the node axis, kept as a row so it broadcasts over all nodes.
        centroid = data.pos.mean(dim=0, keepdim=True)
        data.pos = data.pos - centroid
        return data


class OnesForX(T.BaseTransform):
    """Transform that gives every node the same constant feature."""

    def __call__(self, data):
        """
        Overwrite the node features with a single column of ones.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph with ``x`` set to ``torch.ones(num_nodes, 1)``.
        """
        node_count = data.num_nodes
        data.x = torch.ones(node_count, 1)
        return data


class AtomicNumber(T.BaseTransform):
    """Transform that keeps only column 5 of the node features."""

    def __call__(self, data):
        """Replace ``data.x`` by its column with index 5 and return ``data``."""
        # NOTE(review): in `dw4_dataloaders` this runs after `OnesForX`, which
        # produces a single feature column - index 5 would then be out of
        # range. Confirm whether this featurization is meant to be used for DW4.
        selected = data.x[:, 5]
        data.x = selected
        return data

class OneHot(T.BaseTransform):
    """Identity transform: the node features are used as they are."""

    def __call__(self, data):
        """Return ``data`` unchanged."""
        return data


class OneHot_Charge(T.BaseTransform):
    """Transform that turns the node features into a dictionary of two tensors."""

    def __call__(self, data):
        """
        Replace ``data.x`` by ``{'categorical': ones, 'integer': ones}``.

        Both entries are separate all-ones tensors shaped like the previous
        ``data.x``.
        """
        features = data.x
        data.x = {
            'categorical': torch.ones_like(features),
            'integer': torch.ones_like(features),
        }
        return data

class AugmentPosition(T.BaseTransform):
    """Transform that appends one extra coordinate column to ``pos``."""

    def __call__(self, data):
        """
        Append a zero column to ``data.pos`` and subtract the row-wise mean.

        Returns the same ``data`` object with the widened ``pos``.
        """
        zero_column = torch.zeros_like(data.pos)[:, :1]
        widened = torch.cat([data.pos, zero_column], dim=1)
        # NOTE(review): the mean is taken over dim=1, i.e. over the coordinates
        # of each node rather than over the nodes, so the appended column does
        # not stay zero - confirm that this is intended.
        data.pos = widened - widened.mean(dim=1, keepdim=True)
        return data

"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    """Transform that connects every node of a graph with every other node."""

    def __call__(self, data):
        """
        Replace the edges of a graph by those of the complete graph.

        Existing edge attributes are carried over to the matching new edges;
        edges that did not exist before receive all-zero attributes.
        Self-loops are not part of the result.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph with ``edge_index`` (and ``edge_attr``) describing
            a fully-connected graph.

        References
        ----------
        .. [1] PyTorch Geometric NN Conv Example for QM9.
           https://github.com/pyg-team/pytorch_geometric/blob/66b17806b1f4a2008e8be766064d9ef9a883ff03/examples/qm9_nn_conv.py
        """
        n = data.num_nodes
        node_ids = torch.arange(n, dtype=torch.long)

        # All n * n ordered pairs: sources vary slowest, targets fastest.
        sources = node_ids.repeat_interleave(n)
        targets = node_ids.repeat(n)
        dense_edges = torch.stack([sources, targets], dim=0)

        dense_attr = None
        if data.edge_attr is not None:
            # Position of every existing edge in the flattened n x n grid.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            attr_shape = list(data.edge_attr.size())
            attr_shape[0] = n * n
            dense_attr = data.edge_attr.new_zeros(attr_shape)
            dense_attr[flat_pos] = data.edge_attr

        dense_edges, dense_attr = torch_geometric.utils.remove_self_loops(dense_edges, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = dense_edges

        return data


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Look-Up Tables
#-----------------------------------------------------------------------------------------------------------------------------------------------------

# Featurization name -> ready-to-use transform instance applied to the node features.
t_feats = {
    'atomic-number': AtomicNumber(),
    'dict': OneHot_Charge(),
    'one-hot': OneHot(),
}

# Split name -> requested number of training samples ('fixed_k' -> 10**k).
fixed_split_lookup = {f'fixed_{exponent}': 10**exponent for exponent in range(2, 6)}

if __name__ == '__main__':
    # Smoke test: build everything with the defaults and print each returned object.
    data_tuple = dw4_dataloaders()
    for position in range(len(data_tuple)):
        print(data_tuple[position])