"""
QM9 (Quantum Mechanics 9)

A collection of molecules with up to nine heavy atoms (C, O, N, F)
used as a benchmark dataset for molecular property prediction and
graph-regression tasks.

This file is a loader for variations of the dataset.

"""

from typing import Optional

import numpy as np
import torch

from torch_scatter import scatter
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops

from e3nn.o3 import Irreps, spherical_harmonics #TODO: In House


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Split and Load
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Dataset Loading
~~~~~~~~~~~~~~~

    `target`: (str) label target for training
        + `mu`: Dipole moment
        + `alpha`: Isotropic polarizability
        + `homo`: Highest occupied molecular orbital energy 
        + `lumo`: Lowest unoccupied molecular orbital energy 
        + `gap`: Gap between homo and lumo
        + `r2`: Electronic spatial extent
        + `zpve`: Zero point vibrational energy
        + `U0`: Internal energy at 0K
        + `U`: Internal energy at 298.15K
        + `H`: Enthalpy at 298.15K
        + `G`: Free energy at 298.15K
        + `Cv`: Heat capacity at 298.15K
        + `U0_atom`: Atomization energy at 0K
        + `U_atom`: Atomization energy at 298.15K
        + `H_atom`: Atomization enthalpy at 298.15K
        + `G_atom`: Atomization free energy at 298.15K
        + `A`: Rotational constant
        + `B`: Rotational constant
        + `C`: Rotational constant
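    `featurization`: (str) encoding of atomic number information
        + `atomic-number`: raw atomic number (requires `all_features=False`)
        + `one-hot`
        + `cormorant`
        + `valence`
        + `dict`: dictionary of `categorical` and `integer` features (requires `all_features=True`)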
    `all_features`: (bool) allow use of additional features
    `adjacency`: (str) encoding of edge indices and/or edge weights
        + `bond`: edge indices filled if there is an existing bond
        + `radial`: edge indices filled if atoms are within radial cutoff
        + `full`: fully connected edge indices for each molecule
    `radius`: (int) radial cutoff, required when `adjacency` is `radial`
    `o3_attr`: (bool) allow use of spherical harmonic edge featurization
    `lmax_attr`: (int) maximum spherical harmonic degree, required when `o3_attr` is set
    `split`: (str) data splits for common ML papers
        + `random_38/01/61`: random split of 50_000 training, 1_000 validation and rest test
        + `random_76/14/10`: random split of 100_000 training, 10% test and the rest validation
        + `random_76/01/23`: random split of 100_000 training, 1_000 validation and rest test
        + `random_84/08/08`: random split of 110_000 training, 10_000 validation and rest test
        + `random_84/01/15`: random split of 110_000 training, 1_000 validation and rest test
    `batch_size`: (int) maximum batch size for graphs

    The mean and mean absolute deviation of the target are also
    computed and stored as `meann` and `mad` on the `data` object of
    the training and validation datasets; they can be used to
    normalize the target and to recompose the original values.

"""
def qm9_dataloaders(
    target : str = 'homo',
    featurization : str = 'one-hot',
    all_features : bool = True,
    adjacency : str = 'bond',
    radius : Optional[int] = None,
    o3_attr : bool = False,
    lmax_attr : Optional[int] = None,
    split : str = 'random_76/14/10',
    batch_size : int = 128,
    root : str = '/root/workspace/data/qm9',
):
    """Build the QM9 dataset, split it and wrap each split in a DataLoader.

    Args:
        target: name of the property to predict; must be an entry of `targets`.
        featurization: node featurization, a key of `t_feats`.
        all_features: keep the additional node features; when False (and the
            featurization is not 'atomic-number') `ReduceFeatures` is applied.
        adjacency: 'bond' keeps the edges stored in the dataset, 'radial'
            rebuilds them with `T.RadiusGraph(r=radius)`, 'full' applies
            `Complete`.
        radius: cutoff for 'radial' adjacency; required in that case.
        o3_attr: add spherical harmonic attributes through `O3Attr`.
        lmax_attr: argument passed to `O3Attr`; required when `o3_attr` is set.
        split: key of `random_portions` selecting the split sizes.
        batch_size: batch size shared by the three loaders.
        root: directory handed to the `QM9` dataset constructor.

    Returns:
        The tuple `(dataset, train_dataset, val_dataset, test_dataset,
        train_loader, val_loader, test_loader)`.

    Raises:
        AssertionError: if an argument is not recognized or the combination
            of arguments is inconsistent.
    """

    assert(target in targets), f'Property not recognized: {target}'
    assert(featurization in ['atomic-number','dict','one-hot','cormorant','valence']), f'Featurization not recognized: {featurization}'
    assert(featurization != 'atomic-number' or not all_features), f'Featurization is incompatible: {featurization, all_features}'
    assert(featurization != 'dict' or all_features), f'Featurization is incompatible: {featurization, all_features}'
    assert(adjacency in ['bond','radial','full']), f'Adjacency not recognized: {adjacency}'
    # The literal must be spelled 'radial'; with a misspelling this guard can never fire.
    assert((adjacency!='radial') or (radius is not None)),f'Radial adjacency and radius do not match {adjacency,radius}'
    assert((not o3_attr) or (lmax_attr is not None)),f'O3 attributes and lmax do not match{o3_attr,lmax_attr}'
    assert(split in ['random_38/01/61','random_76/01/23', 'random_76/14/10', 'random_84/08/08', 'random_84/01/15']), f'Split not recognized: {split}'

    # Transforms run in this order: label selection, node featurization,
    # optional feature reduction, edge construction, spherical harmonics.
    transform = [Target(target)]
    transform.append(t_feats[featurization])
    if not all_features and featurization != 'atomic-number':
        transform.append(ReduceFeatures())
    if adjacency == 'radial':
        transform.append(T.RadiusGraph(r=radius))
    elif adjacency == 'full':
        transform.append(Complete())
    if o3_attr:
        transform.append(O3Attr(lmax_attr))

    dataset = QM9(root=root, transform=T.Compose(transform))

    # Every split name accepted above is a key of `random_portions`.
    train_dataset, val_dataset, test_dataset = random_splits(dataset, split)

    # Only the training loader is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Attach the target statistics to the train/val datasets.
    train_loader, val_loader = compute_mean_mad(train_loader, val_loader, target)

    return dataset, train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader


"""
Data Splits
~~~~~~~~~~~~~~~

    Split the QM9 dataset containing 130,831 molecules into
    training, validation and test sets:

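    `random_splits`: This follows the splits of EGNN and Cormorant
     - trainset <- portions either 110_000, 100_000 or 50_000 molecules
     - testset <- 0.1 * (QM9 total number of molecules), or the remaining
       molecules when the split fixes the validation size
     - valset <- remaining molecules, or the fixed size given by the split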
    
"""
def random_splits(dataset, split):
    """Shuffle the molecule indices with a fixed seed and cut them into three subsets.

    `split` is looked up in `random_portions`, which yields the training size
    and either a validation size or None. With None, a tenth of the molecules
    is reserved for testing and validation receives what is left; otherwise
    the test subset receives what is left after training and validation.

    Returns `(train, val, test)`, each obtained by indexing `dataset` with an
    array of shuffled indices.
    """
    total = 130_831
    n_train, n_val = random_portions[split]
    if n_val is None:
        n_val = total - n_train - total // 10

    # Seed 0 is the one EGNN uses, so the shuffling matches theirs.
    shuffled = np.random.RandomState(seed=0).permutation(total)

    val_end = n_train + n_val
    train_idxs = shuffled[:n_train]
    val_idxs = shuffled[n_train:val_end]
    test_idxs = shuffled[val_end:]
    return dataset[train_idxs], dataset[val_idxs], dataset[test_idxs]



#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Transforms
#-----------------------------------------------------------------------------------------------------------------------------------------------------

class Target(T.BaseTransform):
    """Keep a single molecular property as the label of each sample."""

    def __init__(self, target):
        # Position of the requested property inside the `targets` list.
        self.target = target
        self.target_idx = targets.index(target)

    def __call__(self, data):
        # Row 0 of `data.y`, column of the selected property.
        column = self.target_idx
        data.y = data.y[0, column]
        return data

"""
Featurizations 
~~~~~~~~~~~~~~~

"""

class AtomicNumber(T.BaseTransform):
    """Use the atomic number column (index 5 of `data.x`) as the only node feature."""

    def __call__(self, data):
        node_features = data.x
        data.x = node_features[:, 5]
        return data

class OneHot(T.BaseTransform):
    """Identity transform: the node features are returned as stored."""

    def __call__(self, data):
        # Nothing to change; the sample is handed back untouched.
        return data

class CormorantFeatures(T.BaseTransform):
    """Featurization as described in section 7.3 of https://arxiv.org/pdf/1906.04015.pdf"""

    def __call__(self, data):
        # The first five columns of `data.x` are the atom-type one-hot and
        # the next column holds the atomic number.
        n_types = 5
        type_one_hot = data.x[:, :n_types]
        z = data.x[:, n_types]
        # Powers 0..2 of the atomic number scaled by the constant 9.
        data.x = self.get_cormorant_features(type_one_hot, z, 2, 9)
        return data

    def get_cormorant_features(self, one_hot, charges, charge_power, charge_scale):
        """Multiply the one-hot by the powers 0..charge_power of charges / charge_scale.

        The per-atom products are flattened into one feature row per atom.
        """
        exponents = torch.arange(charge_power + 1., dtype=torch.float32)
        scaled = charges.unsqueeze(-1) / charge_scale
        powers = scaled.pow(exponents)
        # Insert a singleton axis so the powers broadcast against the one-hot.
        powers = powers.view(charges.shape + (1, charge_power + 1))
        products = one_hot.unsqueeze(-1) * powers
        return products.view(charges.shape[:2] + (-1,))

class OneHot_Cormorant(T.BaseTransform):
    """Replace `data.x` by a dictionary with `categorical` and `integer` entries.

    `integer` is the atomic number column reshaped to one column;
    `categorical` is the last five columns of the Cormorant-style features,
    computed with the largest atomic number of the sample as scale.
    """

    def __call__(self, data):
        n_types = 5
        type_one_hot = data.x[:, :n_types]
        z = data.x[:, n_types]
        scalars = self.get_cormorant_features(type_one_hot, z, 2, torch.max(z))
        data.x = {
            'categorical': scalars[:, -5:],
            'integer': data.x[:, n_types].view(-1, 1),
        }
        return data

    def get_cormorant_features(self, one_hot, charges, charge_power, charge_scale):
        """Multiply the one-hot by the powers 0..charge_power of charges / charge_scale.

        The per-atom products are flattened into one feature row per atom.
        """
        exponents = torch.arange(charge_power + 1., dtype=torch.float32)
        scaled = charges.unsqueeze(-1) / charge_scale
        powers = scaled.pow(exponents)
        # Insert a singleton axis so the powers broadcast against the one-hot.
        powers = powers.view(charges.shape + (1, charge_power + 1))
        products = one_hot.unsqueeze(-1) * powers
        return products.view(charges.shape[:2] + (-1,))

class ValenceFeatures(T.BaseTransform):
    """Cormorant-style featurization driven by valence electron counts."""

    def __call__(self, data):
        n_types = 5
        type_one_hot = data.x[:, :n_types]
        z = data.x[:, n_types]

        # Start from a copy of the atomic numbers and overwrite each known
        # element with its valence count; the masks are taken on the
        # untouched atomic numbers, so replacements cannot interfere.
        valence = z.clone()
        for atomic_number, n_valence in atomic_number_to_valence:
            mask = z == int(atomic_number)
            valence[mask] = n_valence

        data.x = self.get_cormorant_features(type_one_hot, valence, 2, torch.max(valence))
        return data

    def get_cormorant_features(self, one_hot, charges, charge_power, charge_scale):
        """Multiply the one-hot by the powers 0..charge_power of charges / charge_scale.

        The per-atom products are flattened into one feature row per atom.
        """
        exponents = torch.arange(charge_power + 1., dtype=torch.float32)
        scaled = charges.unsqueeze(-1) / charge_scale
        powers = scaled.pow(exponents)
        # Insert a singleton axis so the powers broadcast against the one-hot.
        powers = powers.view(charges.shape + (1, charge_power + 1))
        products = one_hot.unsqueeze(-1) * powers
        return products.view(charges.shape[:2] + (-1,))

class ReduceFeatures(T.BaseTransform):
    """Drop the last six columns of `data.x`.

    NOTE(review): this runs after the featurization transform, so the columns
    removed depend on what that transform produced — confirm this is intended
    for featurizations other than one-hot.
    """

    def __call__(self, data):
        kept_columns = slice(None, -6)
        data.x = data.x[:, kept_columns]
        return data


"""
Adjacency
~~~~~~~~~

"""

class Complete(T.BaseTransform):
    """Connect every ordered pair of distinct nodes.

    Existing edge attributes are kept at the position of their edge; the
    newly created edges receive zero attributes.
    """

    def __call__(self, data):
        n = data.num_nodes
        node_ids = torch.arange(n, dtype=torch.long, device=data.edge_index.device)

        # Enumerate all n * n ordered pairs (i, j) in row-major order.
        src = node_ids.repeat_interleave(n)
        dst = node_ids.repeat(n)
        full_index = torch.stack([src, dst], dim=0)

        dense_attr = None
        if data.edge_attr is not None:
            # Row-major position of each existing edge in the n * n grid.
            flat_position = data.edge_index[0] * n + data.edge_index[1]
            shape = list(data.edge_attr.size())
            shape[0] = n * n
            dense_attr = data.edge_attr.new_zeros(shape)
            dense_attr[flat_position] = data.edge_attr

        full_index, dense_attr = remove_self_loops(full_index, dense_attr)
        data.edge_attr = dense_attr
        data.edge_index = full_index

        return data

class O3Attr(T.BaseTransform):
    """Generate spherical harmonic edge attributes and node attributes.

    Args:
        lmax_attr: argument forwarded to `Irreps.spherical_harmonics` to
            build the irreps of the attributes.
    """
    def __init__(self, lmax_attr):
        super().__init__()
        self.attr_irreps = Irreps.spherical_harmonics(lmax_attr)

    def __call__(self, data):
        """ Creates spherical harmonic edge attributes and node attributes for the SEGNN

        Sets on `data`:
            edge_attr: spherical harmonics of the relative position of each edge.
            node_attr: mean of `edge_attr` grouped by `edge_index[1]`, with
                one row per node.
            edge_dist / additional_message_features: the squared length of
                the relative position of each edge (no square root is taken).
        """
        edge_index = data.edge_index
        pos = data.pos
        rel_pos = pos[edge_index[0]] - pos[edge_index[1]]  # pos_j - pos_i (note in edge_index stores tuples like (j,i))
        edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)  # squared distance
        # The input vectors are normalized before evaluation (normalize=True).
        edge_attr = spherical_harmonics(self.attr_irreps, rel_pos, normalize=True,
                                        normalization='component')
        # dim_size pins the output to one row per node; without it the result
        # is sized from the largest index present, so trailing nodes with no
        # incoming edge would be missing.
        node_attr = scatter(edge_attr, edge_index[1], dim=0,
                            dim_size=data.num_nodes, reduce="mean")
        data.edge_attr = edge_attr
        data.node_attr = node_attr
        data.edge_dist = edge_dist
        data.additional_message_features = edge_dist
        return data


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# DataLoader Statistics
#-----------------------------------------------------------------------------------------------------------------------------------------------------

def compute_mean_mad(train_loader, val_loader, target):
    """Compute the mean and mean absolute deviation of `target` and store them.

    The target column is read from the `y` entry of the `data` object behind
    each loader's dataset; both columns are concatenated before the
    statistics are taken. The results are written back as `meann` and `mad`
    on both `data` objects, and the two loaders are returned.

    NOTE(review): if the split datasets share one underlying `data` object,
    these statistics cover every stored molecule rather than only the
    train/val subsets — confirm against torch_geometric's indexing semantics.
    """
    column = targets.index(target)
    train_data = train_loader.dataset.data
    val_data = val_loader.dataset.data

    values = torch.hstack([train_data['y'][:, column], val_data['y'][:, column]])
    meann = torch.mean(values)
    mad = torch.mean(torch.abs(values - meann))

    for storage in (train_data, val_data):
        storage['meann'] = meann
        storage['mad'] = mad
    return train_loader, val_loader


#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Look-Up Tables
#-----------------------------------------------------------------------------------------------------------------------------------------------------

"""
Function Dictionary 
~~~~~~~~~~~~~~~~~~~

.. caution::
    Initializes one instance of each class on file import

"""

# Featurization name -> transform instance. The instances are created once,
# when the module is imported, and reused by every pipeline built from it.
t_feats = {
    'atomic-number': AtomicNumber(),
    'dict': OneHot_Cormorant(),
    'one-hot': OneHot(),
    'cormorant': CormorantFeatures(),
    'valence': ValenceFeatures(),
}

def baseline_kwargs() -> dict:
    """Return the loader keyword arguments of the baseline models.

    Returns:
        A dictionary keyed by model name ('egnn', 'schnet'); each value is a
        dictionary of keyword arguments for `qm9_dataloaders`. The 'schnet'
        entry defines no 'split'.
    """
    dicts = {
        'egnn': {
            'featurization' : 'cormorant',
            'all_features' : True,
            'adjacency' : 'full',
            'radius' : None,
            'o3_attr' : False,
            'lmax_attr' :  None,
            'split' : 'random_76/14/10',
        },
        'schnet' : {
            'featurization' : 'one-hot',
            'all_features' : False,
            'adjacency' : 'bond',
            'radius' : None,
            'o3_attr' : False,
            'lmax_attr' :  None,
        }
    }
    # Without this return the table was built and discarded.
    return dicts

"""
Conversions
~~~~~~~~~~~

"""

# Conversion factors to electron-volts.
HAR2EV = 27.211386246      # Hartree -> eV
KCALMOL2EV = 0.04336414    # kcal/mol -> eV

# One factor per target; the 19 positions presumably line up with the
# `targets` list of this module (named in the comments below) — TODO confirm.
conversion = torch.tensor(
    [1.] * 2                # mu, alpha
    + [HAR2EV] * 3          # homo, lumo, gap
    + [1.]                  # r2
    + [HAR2EV] * 5          # zpve, U0, U, H, G
    + [1.]                  # Cv
    + [KCALMOL2EV] * 4      # U0_atom, U_atom, H_atom, G_atom
    + [1.] * 3              # A, B, C
)

# Five reference values per key. Presumably the keys are target indices and
# the five entries are per-element reference energies — TODO confirm against
# the torch_geometric QM9 dataset.
atomrefs = {
    6: [0.] * 5,
    7: [-13.61312172, -1029.86312267, -1485.30251237, -2042.61123593, -2713.48485589],
    8: [-13.5745904, -1029.82456413, -1485.26398105, -2042.5727046, -2713.44632457],
    9: [-13.54887564, -1029.79887659, -1485.2382935, -2042.54701705, -2713.42063702],
    10: [-13.90303183, -1030.25891228, -1485.71166277, -2043.01812778, -2713.88796536],
    11: [0.] * 5,
}

# Property names; the position of a name in this list is the column used to
# index the labels.
targets = [
    'mu', 'alpha',
    'homo', 'lumo', 'gap',
    'r2', 'zpve',
    'U0', 'U', 'H', 'G', 'Cv',
    'U0_atom', 'U_atom', 'H_atom', 'G_atom',
    'A', 'B', 'C',
]

# Subset of the targets above.
thermo_targets = [
    'U',
    'U0',
    'H',
    'G',
]

# (atomic number, valence electron count) pairs, kept as a tuple of tuples.
atomic_number_to_valence = tuple({
    1: 1,  # H
    6: 4,  # C
    7: 5,  # N
    8: 6,  # O
    9: 7,  # F
}.items())

# Split name -> (training size, validation size). A validation size of None
# means the validation subset is whatever remains once the training and test
# subsets have been taken.
random_portions = {
    'random_38/01/61': (50_000, 1_000),
    'random_76/01/23': (100_000, 1_000),
    'random_76/14/10': (100_000, None),
    'random_84/08/08': (110_000, 10_000),
    'random_84/01/15': (110_000, 1_000),
}

#-----------------------------------------------------------------------------------------------------------------------------------------------------
# Main
#-----------------------------------------------------------------------------------------------------------------------------------------------------
    
if __name__ == "__main__":
    # Smoke test: build everything with the default arguments and print the
    # train, validation and test subsets.
    outputs = qm9_dataloaders()
    train_dset, val_dset, test_dset = outputs[1:4]
    for subset in (train_dset, val_dset, test_dset):
        print(subset)
