"""
LJ13 (Lennard-Jones 13)

A collection of configurations of a 13-particle system
in three dimensions, generated by Monte Carlo simulations.
Each configuration is loaded as a graph with 13 nodes, and
the particles interact pairwise through the Lennard-Jones
potential.

"""
import torch
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, InMemoryDataset

import torch.distributed as dist

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

    `split`: (str) train, validation and test data splits
        + fixed_1: 10^1 training graphs, 1,000 validation and 1,000 test graphs
        + fixed_2: 10^2 training graphs, 1,000 validation and 1,000 test graphs
        + fixed_3: 10^3 training graphs, 1,000 validation and 1,000 test graphs
        + fixed_4: 10^4 training graphs, 1,000 validation and 1,000 test graphs
    `batch_size`: (int) maximum batch size for graphs

"""

def lj13_dataloaders(
    featurization : str = 'one-hot',
    adjacency : str = 'complete',
    split : str = 'fixed_1',
    batch_size : int = 128,
    root : str = '/root/workspace/data/lj13/',
):
    """
    Build the LJ13 dataset, its fixed splits and the matching data loaders.

    Parameters
    ----------
    featurization: str
        Key into ``t_feats`` selecting the transform appended to the
        pre-transform pipeline. It is also the sub-directory of ``root``
        used as the dataset root for this featurization.
    adjacency: str
        Currently ignored: the pipeline always applies ``Complete()``.
    split: str
        One of ``'fixed_1'``, ``'fixed_2'``, ``'fixed_3'``, ``'fixed_4'``.
    batch_size: int
        Batch size handed to every data loader.
    root: str
        Base directory of the LJ13 data. Defaults to the location that was
        previously hard-coded, so existing callers are unaffected.

    Returns
    -------
    tuple
        ``(dataset, train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)``.

    Raises
    ------
    AssertionError
        If ``split`` is not one of the accepted names.
    KeyError
        If ``featurization`` is not a key of ``t_feats``.
    """
    import os

    accepted_splits = ['fixed_1','fixed_2','fixed_3','fixed_4']
    # Raised explicitly instead of through `assert` so the check is not
    # stripped under `python -O`; exception type and message are unchanged.
    # NOTE(review): `fixed_split_lookup` also defines 'fixed_5', which is
    # not accepted here - confirm whether that is intentional.
    if split not in accepted_splits:
        raise AssertionError(f'Split not recognized: {split}')

    # The trailing '' keeps the trailing separator the original path had.
    dataset = LJ13(
        root=os.path.join(root, featurization, ''),
        pre_transform=T.Compose([RemovePosMean(), Complete(), OnesForX(), t_feats[featurization]]),
    )

    train_dataset, val_dataset, test_dataset = fixed_splits(dataset, split)

    train_loader, val_loader, test_loader = create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size)

    return dataset, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader

def create_dataloaders(trainset, valset, testset, batch_size):
    """
    Wrap the three splits in data loaders.

    When a ``torch.distributed`` process group is initialized, every split
    gets its own ``DistributedSampler``; otherwise plain loaders are used.
    In both cases only the training data is shuffled: validation and test
    batches are produced in dataset order so evaluation is reproducible.

    Parameters
    ----------
    trainset, valset, testset:
        Datasets for the training, validation and test splits.
    batch_size: int
        Batch size used by all three loaders.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader)``.
    """
    if dist.is_initialized():

        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        # DistributedSampler shuffles by default; evaluation splits should not.
        val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False)
        test_sampler = torch.utils.data.distributed.DistributedSampler(testset, shuffle=False)

        pin_memory = True
        persistent_workers = False
        num_workers = 4

        # `shuffle` must stay False whenever an explicit sampler is given;
        # the sampler alone decides the ordering.
        train_loader = DataLoader(
            trainset,
            batch_size=batch_size,
            shuffle=False,
            sampler=train_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        val_loader = DataLoader(
            valset, batch_size=batch_size, shuffle=False, sampler=val_sampler
        )
        test_loader = DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=test_sampler
        )

    else:

        train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(
            valset,
            batch_size=batch_size,
            shuffle=False,
        )
        test_loader = DataLoader(
            testset,
            batch_size=batch_size,
            shuffle=False,
        )

    return train_loader, val_loader, test_loader

"""
Data Splits
~~~~~~~~~~~~~~~

    Split the LJ13 dataset containing X molecules into
    training, validation and test sets:

    `fixed_splits`: This follows the splits of EN Flows
     - valset <- 1,000 molecules
     - testset <- 1,000 molecules
     - trainset <- specified
    
"""

def fixed_splits(dataset,split):
    """
    Cut the processed LJ13 files into training, validation and test sets.

    The first processed file supplies the training pool, of which the
    leading ``fixed_split_lookup[split]`` graphs are kept. The second
    processed file supplies the held-out graphs: entries 0-999 form the
    test set and entries 1,000-1,999 the validation set.

    Returns
    -------
    tuple
        ``(dataset_train, dataset_val, dataset_test)``.
    """
    n_train = fixed_split_lookup[split]

    file_names = dataset.processed_file_names
    train_pool = dataset.get(file_names[0])
    held_out = dataset.get(file_names[1])

    return train_pool[:n_train], held_out[1_000:2_000], held_out[:1_000]


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Dataset
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class DatasetNoIO(InMemoryDataset):
    """
    In-memory dataset built directly from an already collated
    ``(data, slices)`` pair. The parent class is initialized with
    ``None`` as its root.
    """
    def __init__(self, dataset_name, data, slices, transform=None):
        super().__init__(root=None, transform=transform)
        self.dataset_name = dataset_name
        # Assigned in the same order as before: data first, then slices.
        self.data = data
        self.slices = slices


class LJ13(InMemoryDataset):
    """
    LJ13 dataset: every sample is a graph of 13 nodes with 3-D positions.

    Two raw files (a training partition and a combined validation/test
    partition) are read from ``root`` and written, after optional
    filtering and transformation, under the same file names into
    ``processed_dir``.
    """
    pos_dim = 3     # coordinates per node
    num_nodes = 13  # nodes per graph

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """File names of the training and validation/test partitions."""
        return ['lj13dim3num_particles13_train.pt', 'lj13dim3num_particles13_val_test.pt']

    @property
    def processed_file_names(self):
        """Processed files reuse the raw file names."""
        return self.raw_file_names

    def process(self):
        """
        Convert each raw tensor file into a collated ``(data, slices)`` file.

        Every raw file is loaded from ``{root}/{raw_file_name}``, reshaped
        to ``(-1, num_nodes, pos_dim)`` and turned into one ``Data`` object
        per graph. ``pre_filter`` (a predicate) and ``pre_transform`` are
        applied when given, and the result is saved to ``processed_dir``.
        """
        for rfn, pfn in zip(self.raw_file_names, self.processed_file_names):
            data = torch.load(f'{self.root}/{rfn}')
            if isinstance(data, tuple):
                data = data[-1]  # train partition also stores idx
            data = data.view(-1, self.num_nodes, self.pos_dim)  # reshape to array of graphs
            data_list = [Data(pos=p) for p in data]

            if self.pre_filter is not None:
                # Keep only the graphs the predicate accepts. Mapping the
                # predicate over the list (as was done before) replaced the
                # graphs with the predicate's return values.
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            data, slices = self.collate(data_list)
            torch.save((data, slices), f'{self.processed_dir}/{pfn}')

    def get(self, processed_file_name):
        """
        Load one processed file and wrap it in a ``DatasetNoIO``.

        Note that the argument is a processed file name, not an integer
        index.
        """
        return DatasetNoIO(self.__class__.__name__, *torch.load(f'{self.processed_dir}/{processed_file_name}'))

    @classmethod
    def potential_function(cls, pos, edge_index):
        """
        Evaluate the Lennard-Jones energy of a batch of configurations.

        Parameters
        ----------
        pos: torch.Tensor
            Positions whose first dimension is the batch size; viewed as
            ``(batch_size * num_nodes, pos_dim)``.
        edge_index: torch.Tensor
            Pair of node-index rows addressing the flattened positions. It
            must hold ``batch_size * num_nodes * (num_nodes - 1)`` edges,
            because the distances are viewed with that shape.

        Returns
        -------
        torch.Tensor
            One energy per batch element:
            ``sum((r_m / d)**12 - 2 * (r_m / d)**6)`` over all given edges,
            with ``r_m = 1``.
        """
        # eps = 1
        r_m = 1
        # tau = 0.5

        batch_size = pos.size(0)

        pos = pos.view(batch_size * cls.num_nodes, cls.pos_dim)
        node_i_idxs, node_j_idxs = edge_index
        pos_diffs = pos[node_i_idxs] - pos[node_j_idxs]
        # Euclidean length of every edge.
        dij = torch.sum(pos_diffs**2, dim=1, keepdim=True)**(1/2)
        dij = dij.view(batch_size, cls.num_nodes, cls.num_nodes - 1)

        # NOTE(review): if edge_index lists both directions of each pair,
        # every pair contributes twice to the sum - confirm this is intended.
        return ((r_m / dij)**12 - 2 * (r_m / dij)**6).view(batch_size, -1).sum(dim=1)

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class RemovePosMean(T.BaseTransform):
    def __call__(self, data):
        """
        Centers the node positions of a graph at the origin.

        The mean position over all nodes is subtracted from every
        node position.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph with zero-mean ``pos``.
        """
        centroid = torch.mean(data.pos, dim=0, keepdim=True)
        data.pos = data.pos - centroid
        return data


class OnesForX(T.BaseTransform):
    def __call__(self, data):
        """
        Gives every node the single constant feature ``1``.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph with ``x`` replaced by a ``(num_nodes, 1)``
            tensor of ones.
        """
        feature_shape = (data.num_nodes, 1)
        data.x = torch.ones(feature_shape)
        return data


"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    def __call__(self, data):
        """
        Connects every node to every other node (no self loops).

        Existing edge attributes, if any, are scattered into the dense
        edge list; edges that did not exist before get zero attributes.

        Parameters
        ----------
        data: torch_geometric.data.Data
            A Data object representing a single graph.

        Returns
        -------
        torch_geometric.data.Data
            The same graph with ``edge_index`` (and ``edge_attr``)
            describing a fully-connected graph.

        References
        ----------
        .. [1] PyTorch Geometric NN Conv Example for QM9.
           https://github.com/pyg-team/pytorch_geometric/blob/66b17806b1f4a2008e8be766064d9ef9a883ff03/examples/qm9_nn_conv.py
        """
        n = data.num_nodes
        node_ids = torch.arange(n, dtype=torch.long)

        # Sources: 0,0,..,1,1,..   Targets: 0,1,..,0,1,..
        sources = node_ids.repeat_interleave(n)
        targets = node_ids.repeat(n)
        dense_index = torch.stack([sources, targets], dim=0)

        if data.edge_attr is None:
            dense_attr = None
        else:
            # Position of each existing edge inside the n*n dense edge list.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            dense_shape = [n * n, *data.edge_attr.size()[1:]]
            dense_attr = data.edge_attr.new_zeros(dense_shape)
            dense_attr[flat_pos] = data.edge_attr

        dense_index, dense_attr = torch_geometric.utils.remove_self_loops(dense_index, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = dense_index

        return data


class AtomicNumber(T.BaseTransform):
    """
    Replace the node features with column 5 of the feature matrix
    (presumably the atomic number, per the original description).
    """
    def __call__(self, data):
        # NOTE(review): in `lj13_dataloaders` this runs after `OnesForX`,
        # which sets `x` to shape (num_nodes, 1); column 5 would then be
        # out of range - confirm this featurization is usable for LJ13.
        selected_column = data.x[:,5]
        data.x = selected_column
        return data

class OneHot(T.BaseTransform):
    """
    Identity transform: the incoming node features are already treated
    as the one-hot featurization, so the graph is returned untouched.
    """
    def __call__(self, data):
        return data


class OneHot_Charge(T.BaseTransform):
    """
    Replace the node features with a dict holding two entries,
    ``'categorical'`` and ``'integer'``, each a full slice of the
    original features.
    """
    def __call__(self, data):
        # Both entries are taken from the same feature tensor.
        data.x = {
            'categorical': data.x[:],
            'integer': data.x[:],
        }
        return data

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Look-Up Tables
#-----------------------------------------------------------------------------------------------------------------------------------------------------

# Featurization name -> transform instance; `lj13_dataloaders` appends the
# selected entry to its pre-transform pipeline.
t_feats = {
    'atomic-number': AtomicNumber(),
    'dict': OneHot_Charge(),
    'one-hot': OneHot(),
}


# Split name -> number of training graphs: 'fixed_k' keeps 10**k of them.
fixed_split_lookup = {f'fixed_{exponent}': 10**exponent for exponent in range(1, 6)}

if __name__=='__main__':
    # Smoke run: build everything with the default arguments and print
    # each returned object on its own line.
    data_tuple = lj13_dataloaders()
    for position in range(len(data_tuple)):
        print(data_tuple[position])