"""
QM9 (Quantum Mechanics 9)

A collection of 130,831 small organic molecules with up to nine heavy atoms
(C, O, N, F), used as a benchmark dataset for molecular property prediction
(graph-level regression) tasks.

This file provides loaders for several variants of the dataset (node
featurizations, graph adjacencies and train/val/test splits).

"""
from typing import Optional

import numpy as np
import torch

from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops
from torch_scatter import scatter

from e3nn.o3 import Irreps, spherical_harmonics #TODO: In House

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

"""

def qm9_gen_dataloaders(
    filter : bool = False,
    featurization : str = 'dict',
    adjacency : str = 'bond',
    radius : Optional[float] = None,
    o3_attr : bool = False,
    lmax_attr : Optional[int] = None,
    split : str = 'random_100/18/13',
    batch_size : int = 128,
):
    """Build QM9 (or its fixed-size subset) with train/val/test subsets and loaders.

    Args:
        filter: If True, load :class:`QM9Positional` (molecules with exactly
            ``QM9Positional.num_nodes`` nodes) from ``/root/workspace/data/qm9-pos``
            and apply the transforms once as a ``pre_transform`` whose output is
            cached on disk (PyG skips processing when the processed files already
            exist, so delete them after changing the transforms). If False, load
            PyG ``QM9`` from ``/root/workspace/data/qm9`` and apply the transforms
            on every access.
        featurization: Key of ``t_feats``: ``'atomic-number'``, ``'dict'`` or
            ``'one-hot'``.
        adjacency: ``'bond'`` (keep the dataset's edges), ``'radial'``
            (``T.RadiusGraph``; requires ``radius``) or ``'full'``
            (:class:`Complete`).
        radius: Cut-off radius used when ``adjacency='radial'``.
        o3_attr: If True, append :class:`O3Attr` (requires ``lmax_attr``).
        lmax_attr: Maximum spherical-harmonic degree passed to :class:`O3Attr`.
        split: ``'fixed'`` (``_qm9_enf_splits``, or the precomputed
            ``QM9Positional`` split files when ``filter=True``) or any key of
            ``random_portions``.
        batch_size: Batch size of all three loaders.

    Returns:
        ``(dataset, train_dataset, val_dataset, test_dataset, train_loader,
        val_loader, test_loader)``; only ``train_loader`` shuffles.

    Raises:
        ValueError: for an unknown ``featurization``/``adjacency``/``split``, a
            missing ``radius``/``lmax_attr``, or ``filter=True`` combined with a
            random split.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if featurization not in t_feats:
        raise ValueError(f'Featurization not recognized: {featurization}')
    if adjacency not in ('bond', 'radial', 'full'):
        raise ValueError(f'Adjacency not recognized: {adjacency}')
    if split != 'fixed' and split not in random_portions:
        raise ValueError(f'Split not recognized: {split}')
    if adjacency == 'radial' and radius is None:
        raise ValueError("adjacency='radial' requires a radius")
    if o3_attr and lmax_attr is None:
        raise ValueError('o3_attr=True requires lmax_attr')
    if filter and split != 'fixed':
        # QM9Positional only stores the three precomputed `_qm9_enf_splits`
        # subsets; it holds no indexable data of its own to re-split randomly.
        raise ValueError("filter=True only supports split='fixed'")

    transform = [t_feats[featurization]]
    if adjacency == 'radial':
        transform.append(T.RadiusGraph(r=radius))
    elif adjacency == 'full':
        transform.append(Complete())
    if o3_attr:
        transform.append(O3Attr(lmax_attr))

    if filter:
        dataset = QM9Positional(root='/root/workspace/data/qm9-pos',
            pre_transform = T.Compose(transform)
        )
    else:
        dataset = QM9(root='/root/workspace/data/qm9',
            transform = T.Compose(transform)
        )

    # Splits
    if split == 'fixed':
        if filter:
            # processed_file_names is ordered (train, val, test).
            train_dataset, val_dataset, test_dataset = (
                dataset.get(name) for name in dataset.processed_file_names
            )
        else:
            train_dataset, val_dataset, test_dataset = _qm9_enf_splits(dataset)
    else:
        train_dataset, val_dataset, test_dataset = random_splits(dataset, split)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return dataset, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader

"""
Data Splits
~~~~~~~~~~~~~~~

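    Split the QM9 dataset containing 130,831 molecules into
    training, validation and test sets:

    `random_splits`: follows the splits of EGNN and Cormorant. A seed-0
    permutation of the molecules is cut into three consecutive chunks whose
    sizes come from `random_portions`:
     - trainset <- the first 20_000, 50_000, 100_000 or 110_000 molecules
     - valset <- the next molecules, as many as the validation portion; if that
       portion is None, every molecule not in the train or test set
     - testset <- if the validation portion is None, 10% of QM9
       (130_831 // 10 = 13_083 molecules); otherwise all remaining molecules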
    
"""
def random_splits(dataset, split):
    """Randomly split QM9 into train/val/test subsets (EGNN/Cormorant protocol).

    The indices ``0 .. 130_830`` are shuffled with ``np.random.RandomState(seed=0)``
    (the seed used by EGNN) and cut into three consecutive chunks.

    Args:
        dataset: Indexable dataset accepting an integer ``np.ndarray`` as index
            (e.g. PyG ``QM9``); assumed to hold at least 130,831 items.
        split: A key of ``random_portions`` or an explicit
            ``(num_train, num_val)`` tuple. ``num_val=None`` reserves
            ``130_831 // 10`` molecules for test and gives the rest to
            validation. Otherwise the test set receives every molecule not used
            for train/val (e.g. 12,831 for ``'random_100/18/13'``).

    Returns:
        ``(train, val, test)``: ``dataset`` indexed by the three index chunks.

    Raises:
        KeyError: if ``split`` is a string not found in ``random_portions``.
        ValueError: if the portions are negative or exceed 130,831 molecules.
    """
    num_molecules = 130_831

    if isinstance(split, str):
        num_train_molecules, num_val_molecules = random_portions[split]
    else:
        num_train_molecules, num_val_molecules = split

    if num_val_molecules is None:
        # Fixed 10% test set; validation takes whatever is left.
        num_test_molecules = num_molecules // 10
        num_val_molecules = num_molecules - num_train_molecules - num_test_molecules

    if (num_train_molecules < 0 or num_val_molecules < 0
            or num_train_molecules + num_val_molecules > num_molecules):
        raise ValueError(
            f'Invalid portions for {num_molecules} molecules: '
            f'train={num_train_molecules}, val={num_val_molecules}'
        )

    # Randomize the order of the data; the test set is the remainder.
    rng = np.random.RandomState(seed=0)  # EGNN uses this seed
    permutation = rng.permutation(num_molecules)
    train_idxs, val_idxs, test_idxs = np.split(
        permutation,
        indices_or_sections=(num_train_molecules,
                             num_train_molecules + num_val_molecules)
    )
    return dataset[train_idxs], dataset[val_idxs], dataset[test_idxs]


def _qm9_enf_splits(dataset: QM9):
    """Return the fixed EGNN-style ``(train, val, test)`` split of QM9.

    The indices ``0 .. 130_830`` are shuffled with
    ``np.random.RandomState(seed=0)``. The first 100,000 go to training, the
    last ``130_831 // 10`` (13,083) go to test, and the 17,748 in between go to
    validation.
    """
    total = 130_831
    n_train = 100_000
    n_test = total // 10
    n_val = total - n_train - n_test

    order = np.random.RandomState(seed=0).permutation(total)  # same seed as EGNN

    train_end = n_train
    val_end = n_train + n_val
    return (
        dataset[order[:train_end]],
        dataset[order[train_end:val_end]],
        dataset[order[val_end:]],
    )


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Dataset
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class DatasetNoIO(InMemoryDataset):
    """In-memory dataset built directly from pre-collated ``(data, slices)``.

    It passes ``root=None`` to the base class; all content comes from the
    constructor arguments. ``dataset_name`` is only stored for reference.
    """

    def __init__(self, dataset_name, data, slices, transform=None):
        super().__init__(root=None, transform=transform)
        self.data = data
        self.slices = slices
        self.dataset_name = dataset_name


class QM9Positional(InMemoryDataset):
    """QM9 restricted to molecules with exactly ``num_nodes`` (19) nodes.

    :meth:`process` splits the stock PyG ``QM9`` with ``_qm9_enf_splits``,
    filters each subset by node count (plus any user ``pre_filter``), applies
    ``pre_transform`` and saves each subset to its own processed file. Load a
    subset with :meth:`get`, passing one of :attr:`processed_file_names`.
    """
    pos_dim = 3
    num_nodes = 19

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # The node-count filter is always applied; a user filter is AND-ed with it.
        if pre_filter is not None:
            pre_filter = Compose([HomogeneousVertexCount(self.num_nodes), pre_filter])
        else:
            pre_filter = HomogeneousVertexCount(self.num_nodes)
        super().__init__(root, transform, pre_transform, pre_filter=pre_filter)

    @property
    def raw_file_names(self):
        prefix = f'data_v3_atoms{self.num_nodes}_'
        return [f'{prefix}{p}.pt' for p in ('train', 'val', 'test')]

    @property
    def processed_file_names(self):
        # One file per split, ordered (train, val, test).
        return self.raw_file_names

    def process(self):
        # Source molecules come from the stock PyG QM9 rooted at the same directory.
        splits = _qm9_enf_splits(QM9(self.root))
        for pfn, subset in zip(self.processed_file_names, splits):
            data_list = [d for d in subset if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            data, slices = self.collate(data_list)
            torch.save((data, slices), f'{self.processed_dir}/{pfn}')

    @staticmethod
    def _torch_load(path):
        """Load a ``(data, slices)`` file written by :meth:`process`."""
        import inspect
        # The file holds pickled PyG objects, which torch>=2.6 rejects under its
        # default `weights_only=True`; torch<1.13 has no `weights_only` argument.
        # Full unpickling can execute code: only load files produced locally.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

    def get(self, processed_file_name):
        """Return one processed split file as a :class:`DatasetNoIO` (no transform)."""
        data, slices = self._torch_load(f'{self.processed_dir}/{processed_file_name}')
        return DatasetNoIO(self.__class__.__name__, data, slices)


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------


"""
Filters 
~~~~~~~~

"""

class Filter:
    """Base class for dataset filters: callables that map a sample to a bool.

    Subclasses must implement ``__call__``.
    """

    def __call__(self, data):
        raise NotImplementedError

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class Compose(Filter):
    """Logical AND of several filters; evaluation stops at the first rejection."""

    def __init__(self, filters):
        self.filters = filters

    def __call__(self, data):
        for keep in self.filters:
            if not keep(data):
                return False
        return True


class HomogeneousVertexCount(Filter):
    """Accept only samples whose ``pos`` tensor has exactly ``vertex_count`` rows."""

    def __init__(self, vertex_count):
        self.vertex_count = vertex_count

    def __call__(self, data):
        num_points = data.pos.size(0)
        return num_points == self.vertex_count


"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class AtomicNumber(T.BaseTransform):
    """Replace ``data.x`` with its column 5 (the atomic number), as a 1-D tensor."""

    def __call__(self, data):
        atomic_numbers = data.x[:, 5]
        data.x = atomic_numbers
        return data

class OneHot(T.BaseTransform):
    """Identity featurization: keep the dataset's stored ``data.x`` unchanged.

    NOTE(review): named "one-hot" on the assumption that the stored features
    start with one-hot atom types — confirm against the dataset's ``x`` layout.
    """

    def __call__(self, data):
        # Nothing to transform; hand the sample back as-is.
        return data


class OneHot_Charge(T.BaseTransform):
    """Split ``data.x`` into a dict of one-hot atom types and atomic numbers.

    After the call ``data.x`` is
    ``{'categorical': x[:, :5], 'integer': x[:, 5] reshaped to (N, 1)}``.
    """

    def __call__(self, data):
        n_types = 5  # number of leading one-hot atom-type columns
        feats = data.x
        data.x = {
            'categorical': feats[:, :n_types],
            'integer': feats[:, n_types].view(-1, 1),
        }
        return data


"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    """Replace the edges with a fully connected graph without self-loops.

    Every ordered pair ``(i, j)`` with ``i != j`` becomes an edge. If the graph
    has ``edge_attr``, the existing attributes are copied onto their pairs and
    all other pairs get all-zero attributes.
    """

    def __call__(self, data):
        device = data.edge_index.device
        n = data.num_nodes

        # All n*n ordered pairs in row-major order:
        # sources 0,0,...,1,1,... and targets 0,1,...,0,1,...
        nodes = torch.arange(n, dtype=torch.long, device=device)
        dense_index = torch.stack([nodes.repeat_interleave(n), nodes.repeat(n)], dim=0)

        dense_attr = None
        if data.edge_attr is not None:
            # Flat row-major position of each existing edge inside the n*n grid.
            flat_pos = data.edge_index[0] * n + data.edge_index[1]
            dense_attr = data.edge_attr.new_zeros((n * n, *data.edge_attr.shape[1:]))
            dense_attr[flat_pos] = data.edge_attr

        dense_index, dense_attr = remove_self_loops(dense_index, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = dense_index

        return data

class O3Attr(T.BaseTransform):
    """Attach spherical-harmonic edge and node attributes (for the SEGNN).

    Sets on each graph:
      * ``edge_attr``: spherical harmonics up to degree ``lmax_attr`` of the
        normalized relative positions ``pos[edge_index[0]] - pos[edge_index[1]]``.
        This overwrites any existing ``edge_attr``.
      * ``node_attr``: mean of ``edge_attr`` over each node's edges indexed by
        ``edge_index[1]``; one row per node, zero for nodes without such edges.
      * ``edge_dist`` and ``additional_message_features``: the *squared* edge
        length, shape ``(num_edges, 1)``.
    """

    def __init__(self, lmax_attr):
        super().__init__()
        self.attr_irreps = Irreps.spherical_harmonics(lmax_attr)

    def __call__(self, data):
        """Creates spherical harmonic edge attributes and node attributes for the SEGNN."""
        edge_index = data.edge_index
        pos = data.pos
        # pos_j - pos_i (edge_index stores tuples like (j, i)).
        rel_pos = pos[edge_index[0]] - pos[edge_index[1]]
        # Squared length despite the name; kept for downstream compatibility.
        edge_dist = rel_pos.pow(2).sum(-1, keepdim=True)
        edge_attr = spherical_harmonics(self.attr_irreps, rel_pos, normalize=True,
                                        normalization='component')
        # dim_size keeps one row per node; without it, trailing nodes with no
        # incoming edge (possible e.g. with a small radius graph) are dropped
        # and node_attr no longer aligns with the node features.
        node_attr = scatter(edge_attr, edge_index[1], dim=0,
                            dim_size=data.num_nodes, reduce="mean")
        data.edge_attr = edge_attr
        data.node_attr = node_attr
        data.edge_dist = edge_dist
        data.additional_message_features = edge_dist
        return data

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Look-Up Tables
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Function Dictionary 
~~~~~~~~~~~~~~~~~~~

.. caution::
    Instantiates one shared instance of each transform class at import time.

"""

# Featurization name (as accepted by `qm9_gen_dataloaders`) -> shared transform.
t_feats = dict([
    ('atomic-number', AtomicNumber()),  # x -> atomic-number column only
    ('dict', OneHot_Charge()),          # x -> {'categorical', 'integer'} dict
    ('one-hot', OneHot()),              # x left unchanged
])

"""
Conversions
~~~~~~~~~~~

"""

HAR2EV = 27.211386246      # Hartree -> eV
KCALMOL2EV = 0.04336414    # kcal/mol -> eV

# Per-target multiplicative factors; entry i is meant for `targets[i]`
# (the column order of PyG's QM9 `y`). 1. means "no conversion".
conversion = torch.tensor([
    1.,            # mu
    1.,            # alpha
    HAR2EV,        # homo
    HAR2EV,        # lumo
    HAR2EV,        # gap
    1.,            # r2
    HAR2EV,        # zpve
    HAR2EV,        # U0
    HAR2EV,        # U
    HAR2EV,        # H
    HAR2EV,        # G
    1.,            # Cv
    KCALMOL2EV,    # U0_atom
    KCALMOL2EV,    # U_atom
    KCALMOL2EV,    # H_atom
    KCALMOL2EV,    # G_atom
    1.,            # A
    1.,            # B
    1.,            # C
])

# Per-element reference (single-atom) values.
# NOTE(review): keys appear to be target column indices. 6..11 would match
# `targets[6:12]` (zpve, U0, U, H, G, Cv), and zpve/Cv carry all-zero references.
# Each list has five entries, presumably for H, C, N, O, F in that order
# (the elements of `atomic_number_to_valence`). This seems to mirror `atomrefs`
# in torch_geometric.datasets.qm9 — confirm ordering and units (likely eV).
atomrefs = {
    6: [0., 0., 0., 0., 0.],
    7: [
        -13.61312172, -1029.86312267, -1485.30251237, -2042.61123593,
        -2713.48485589
    ],
    8: [
        -13.5745904, -1029.82456413, -1485.26398105, -2042.5727046,
        -2713.44632457
    ],
    9: [
        -13.54887564, -1029.79887659, -1485.2382935, -2042.54701705,
        -2713.42063702
    ],
    10: [
        -13.90303183, -1030.25891228, -1485.71166277, -2043.01812778,
        -2713.88796536
    ],
    11: [0., 0., 0., 0., 0.],
}

# Names of the 19 regression targets, in column order.
targets = ('mu alpha homo lumo gap r2 zpve U0 U H G Cv '
           'U0_atom U_atom H_atom G_atom A B C').split()

# Thermodynamic energy targets.
thermo_targets = 'U U0 H G'.split()

# (atomic number, number of valence electrons) for the elements present in QM9.
atomic_number_to_valence = (
    (1, 1),  # H
    (6, 4),  # C
    (7, 5),  # N
    (8, 6),  # O
    (9, 7),  # F
)

# Named random splits used by `random_splits`: name -> (num_train, num_val).
# The name encodes approximate train/val/test shares of the 130,831 molecules.
# With num_val=None the test set is 130_831 // 10 molecules and validation gets
# the rest; otherwise the test set is every molecule not used for train/val.
# NOTE(review): 'random_18/01/61' maps to (20_000, 10_000), i.e. roughly
# 15/8/77 % — the label does not match the portions; confirm which is intended.
# NOTE(review): 'random_100/18/13' counts thousands of molecules, but the test
# set is the remainder: 12,831 molecules, not 13,000.
random_portions = {
    'random_18/01/61':(20_000,10_000),
    'random_38/01/61':(50_000,1_000),
    'random_76/01/23':(100_000,1_000),
    'random_76/14/10':(100_000, None),
    'random_100/18/13':(100_000, 18_000),
    'random_84/08/08':(110_000, 10_000),
    'random_84/01/15':(110_000,1_000)
}

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Main
#-----------------------------------------------------------------------------------------------------------------------------------------------------
    
if __name__ == "__main__":
    # Smoke run: build the default loaders and print a summary of each split.
    (_, train_dset, val_dset, test_dset, train_dl, _, _) = qm9_gen_dataloaders()
    for split_dset in (train_dset, val_dset, test_dset):
        print(split_dset)