"""
QM9 (Quantum Mechanics 9)

A collection of molecules with up to nine heavy atoms (C, O, N, S)
used as a benchmark dataset for molecular property prediction and
graph-classification tasks.

This file is a loader for variations of the dataset.

"""
from typing import Optional

from rdkit import Chem
import numpy as np
import torch
from tqdm import tqdm

from e3nn.o3 import Irreps, spherical_harmonics
from pointgroup import PointGroup

from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch_geometric.utils import degree

from torch_scatter import scatter

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

# Point-group symbol -> integer class label (the label GetPG stores in ``data.y``).
point_groups = {
    name: label
    for label, name in enumerate([
        'C1', 'C1h', 'C1v', 'C2', 'C2d', 'C2h', 'C2v', 'C3', 'C3h', 'C3v',
        'C4', 'Cs', 'Cinfv', 'Ci', 'Dinfh', 'D2', 'D2d', 'D2h', 'D3', 'D3d',
        'D3h', 'D6h', 'Oh', 'Td', 'S2', 'S4',
    ])
}

# Names of the columns of ``data.y``, in column order (GetTarget selects by position).
targets = [
    'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve',
    'U0', 'U', 'H', 'G', 'Cv',
    'U0_atom', 'U_atom', 'H_atom', 'G_atom',
    'A', 'B', 'C',
]

# Element lookup tables for the atomic numbers handled here.
atomic_number_to_symbol = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}
atomic_symbol_to_number = {symbol: number for number, symbol in atomic_number_to_symbol.items()}

# Split name -> (number of training molecules, number of validation molecules); see fixed_splits.
portions = {
    'fixed': (100_000, 1_000),
}
"""
================
Generate Loaders
================
"""


def qm9_loaders(loader_cfg: dict) -> dict:
    """Build QM9 train/val/test ``DataLoader``s from a configuration dict.

    Args:
        loader_cfg: configuration with keys
            ``'batch_size'`` -- batch size used by all three loaders;
            ``'target'`` -- ``'pointgroup'`` for point-group classification, or one
            of the names in ``targets`` for single-target regression
            (``'molecular_orbital'`` is not implemented);
            ``'align'`` (optional) -- dict with key ``'tol'``. When present, molecules
            are aligned by ``Align`` during pre-processing, ``GraphToMol`` and
            ``GetOrbitals`` run at load time, and the processed data is stored under
            ``./data/qm9_align`` instead of ``./data/qm9``.

    Returns:
        dict with the ``'train'``, ``'val'`` and ``'test'`` loaders; ``'mean'`` and
        ``'std'`` of the regression target (0 and 1 for ``'pointgroup'``); and
        ``'degrees_hist'``, a histogram (bins 0..8, with 9 counted in the last bin)
        of the degree of every node in the dataset.

    Raises:
        NotImplementedError: if the target is ``'molecular_orbital'``.
        ValueError: if the target name is not in ``targets``.
    """
    batch_size = loader_cfg['batch_size']
    target = loader_cfg['target']
    split = 'fixed'

    # Pre-transforms are applied when the raw dataset is processed; transforms on every read.
    pre_transform = [Degree(), AddPG()]
    transform = []

    if 'align' in loader_cfg:
        align_cfg = loader_cfg['align']
        pre_transform.append(Align(tol=align_cfg['tol']))
        transform.append(GraphToMol())
        transform.append(GetOrbitals())
        #transform.append(O3Attr(lmax_attr=3))
        # Different pre-processing, so the processed data lives in its own root.
        pth = './data/qm9_align'
    else:
        pth = './data/qm9'

    if target == 'pointgroup':
        transform.append(GetPG())
        target_index = None
    elif target == 'molecular_orbital':
        # Fail loudly instead of terminating the whole interpreter.
        raise NotImplementedError("QM9 target 'molecular_orbital' is not implemented")
    else:
        transform.append(GetTarget(target))
        target_index = targets.index(target)

    # Optional transforms, currently disabled:
    #transform.append(AtomicNumber())
    #transform.append(PosPlusAtomicNumber())
    #transform.append(IntegerEdgeFeatures())
    #transform.append(T.RadiusGraph(r=10.0))

    dataset = QM9(root=pth, pre_transform=T.Compose(pre_transform), transform=T.Compose(transform))

    if target == 'pointgroup':
        split = 'random_7/1/2'
        # Keep the first 1000 molecules plus every later molecule whose point group is
        # not C1 (label 0); this caps how many C1 examples enter the dataset.
        kept = [
            data
            for i, data in tqdm(enumerate(dataset), total=len(dataset))
            if i < 1000 or data.y.item() != 0
        ]
        dataset = FilteredQM9('', kept)

    # Histogram over the degree of every node of every molecule.
    node_degrees = []
    for data in tqdm(dataset, total=len(dataset)):
        node_degrees.extend(data.degrees.tolist())
    degrees_hist = torch.from_numpy(np.histogram(node_degrees, bins=range(10))[0])

    if target_index is None:
        mean, std = 0, 1
    else:
        # QM9.mean/std take a column index into the full target matrix.
        mean = dataset.mean(target_index)
        std = dataset.std(target_index)

    train_dataset, val_dataset, test_dataset = fixed_splits(dataset, split)

    return {
        'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
        'mean': mean,
        'std': std,
        'degrees_hist': degrees_hist,
    }

"""
===========
Data Splits
===========

    Split the QM9 dataset containing 130,831 molecules into
    training, validation and test sets:

    `fixed_splits`: This follows the splits of EGNN and Cormorant
     - trainset <- portions either 110_000, 100_000 or 50_000 molecules
     - testset <- 0.1 * (QM9 total number of molecules)
     - valset <- remaining molecules
    
"""
def fixed_splits(dataset, split):
    """Split ``dataset`` into ``(train, val, test)`` subsets.

    Args:
        dataset: a dataset supporting ``len``, slicing and indexing with an integer
            NumPy array (a PyG dataset or a NumPy array both qualify).
        split: ``'random_7/1/2'`` for a random 70% / 10% / 20% split drawn from NumPy's
            global RNG, or a key of ``portions`` for the EGNN-style split based on a
            permutation from ``np.random.RandomState(seed=0)``.

    Returns:
        Tuple ``(train, val, test)`` of dataset subsets.

    Raises:
        KeyError: if ``split`` is neither ``'random_7/1/2'`` nor a key of ``portions``.
    """
    # Use the real dataset size rather than a hard-coded QM9 count (130,831 for QM9).
    num_molecules = len(dataset)

    if split == 'random_7/1/2':
        shuffled = dataset[np.random.permutation(num_molecules)]
        # Integer arithmetic keeps the 70/10/20 boundaries exact.
        train_end = num_molecules * 7 // 10
        val_end = train_end + num_molecules // 10
        return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]

    num_train_molecules, num_val_molecules = portions[split]
    if num_val_molecules is None:
        num_test_molecules = num_molecules // 10
        num_val_molecules = num_molecules - num_train_molecules - num_test_molecules
    elif num_val_molecules == 18_000:
        num_test_molecules = 13_000
    else:
        num_test_molecules = num_molecules - num_train_molecules - num_val_molecules

    # EGNN uses this seed, so the fixed split is reproducible across runs.
    permutation = np.random.RandomState(seed=0).permutation(num_molecules)
    val_start = num_train_molecules
    test_start = val_start + num_val_molecules
    # Slicing clamps at num_molecules, so oversized portions just shorten the last subsets.
    test_end = test_start + num_test_molecules
    train_idxs = permutation[:val_start]
    val_idxs = permutation[val_start:test_start]
    test_idxs = permutation[test_start:test_end]
    return dataset[train_idxs], dataset[val_idxs], dataset[test_idxs]

"""
Transformations 
~~~~~~~~~~~~~~~

"""

class Degree(T.BaseTransform):
    """Store in ``data.degrees`` how many edges start at each node.

    Counts occurrences of each node in ``data.edge_index[0]``, with one entry per row
    of ``data.x`` (nodes that never appear there get 0).
    """

    def __call__(self, data):
        num_nodes = data.x.size(0)
        data.degrees = degree(data.edge_index[0], num_nodes=num_nodes)
        return data

class Align(T.BaseTransform):
    """Align a molecule with ``CatFrame`` and attach symmetry-derived edge sets.

    Sets on ``data``:
        align_pos: positions returned by ``Frame.get_frame(data.pos, data.z)``,
            converted with ``torch.from_numpy`` (so presumably a NumPy array).
        frame_R, frame_t: the other two values returned by ``get_frame``
            (presumably the frame's rotation and translation -- confirm).
        symmetric_edge_index: every ordered pair ``(i, j)``, self-pairs included,
            inside each group of ``frame.symmetric_elements``.
        project_edge_index: an edge from the first member of each group to each of
            the other members of that group.
        asu_edge_index: every ordered pair ``(i, j)`` inside ``frame.simple_asu``.
    All ``*_edge_index`` tensors are ``torch.long`` with shape ``(2, num_edges)``;
    the shape stays ``(2, 0)`` when there are no pairs.

    Args:
        tol: tolerance forwarded to ``Frame``.
    """

    def __init__(self, tol: Optional[float] = 1e-3):
        self.tol = tol

    @staticmethod
    def _to_edge_index(pairs) -> torch.Tensor:
        """Convert a list of ``[i, j]`` pairs into a ``(2, num_pairs)`` long tensor."""
        # reshape(-1, 2) turns an empty list into (0, 2) -> (2, 0) after the transpose,
        # instead of a 1-D tensor of shape (0,) that breaks edge_index consumers.
        return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T

    def __call__(self, data):
        frame = Frame(tol=self.tol, save=True)
        align_pos, frame_R, frame_t = frame.get_frame(data.pos, data.z)
        data.align_pos = torch.from_numpy(align_pos)
        data.frame_R = frame_R
        data.frame_t = frame_t

        # NOTE(review): symmetric_elements / simple_asu are read after get_frame;
        # presumably save=True is what makes Frame keep them -- confirm.
        groups = frame.symmetric_elements
        symmetric_pairs = [[i, j] for group in groups for i in group for j in group]
        projection_pairs = [[group[0], group[k]] for group in groups for k in range(1, len(group))]
        data.project_edge_index = self._to_edge_index(projection_pairs)
        data.symmetric_edge_index = self._to_edge_index(symmetric_pairs)

        asu = frame.simple_asu
        data.asu_edge_index = self._to_edge_index([[i, j] for i in asu for j in asu])
        return data

class AtomicNumber(T.BaseTransform):
    """Replace ``data.x`` with a ``(num_nodes, 1)`` ``torch.long`` column taken from ``data.x[:, 5]``.

    NOTE(review): assumes the PyG QM9 node-feature layout, in which column 5 of
    ``data.x`` holds the atomic number -- confirm if the features change.
    """

    def __call__(self, data):
        data.x = data.x[:, 5].to(torch.long).view(-1, 1)
        # Transforms must return the sample; returning None breaks T.Compose.
        return data

class GetTarget(T.BaseTransform):
    """Keep a single regression-target column of ``data.y``.

    Args:
        target: a name from ``targets``; its position selects the column of
            ``data.y``. Unknown names raise ``ValueError`` (from ``list.index``).
    """

    def __init__(self, target: str):
        column = targets.index(target)
        self.index = column

    def __call__(self, data):
        all_targets = data.y
        data.y = all_targets[:, self.index]
        return data

class IntegerEdgeFeatures(T.BaseTransform):
    """Replace ``data.edge_attr`` with its first column cast to ``torch.long``.

    NOTE(review): if ``edge_attr`` is PyG QM9's one-hot bond-type encoding, column 0
    only flags a single bond type; ``argmax(dim=-1)`` may have been intended -- confirm.
    """

    def __call__(self, data):
        first_column = data.edge_attr[:, 0]
        data.edge_attr = first_column.long()
        return data

class GraphToMol(T.BaseTransform):
    """Build an RDKit molecule from ``data.z`` and ``data.pos`` and store it in ``data.mol``.

    One atom is added per entry of ``data.z`` (atomic numbers), in order, so RDKit atom
    index ``i`` corresponds to node ``i``. Bonds are inferred purely from geometry:
    every pair of atoms closer than ``threshold`` (in the units of ``data.pos``) gets a
    SINGLE bond. Bond orders are not inferred.

    Args:
        threshold: distance cutoff below which two atoms are bonded (default 1.6).
        sanitize: if True, run ``Chem.SanitizeMol`` on the result so derived atom
            properties such as hybridization get assigned. This is best-effort:
            sanitization errors are caught (``catchErrors=True``) and not raised.
    """

    def __init__(self, threshold: float = 1.6, sanitize: bool = True):
        self.threshold = threshold
        self.sanitize = sanitize

    def __call__(self, data):
        mol = Chem.RWMol()
        for atomic_num in data.z.tolist():
            mol.AddAtom(Chem.Atom(int(atomic_num)))

        pos = data.pos
        num_atoms = len(pos)
        # All pairs i < j in row-major order (same order as a nested i/j loop), so the
        # bonds and their RDKit indices come out deterministically.
        rows, cols = torch.triu_indices(num_atoms, num_atoms, offset=1, device=pos.device)
        dist = (pos[rows] - pos[cols]).norm(dim=-1)
        bonded = dist < self.threshold
        for i, j in zip(rows[bonded].tolist(), cols[bonded].tolist()):
            mol.AddBond(i, j, Chem.BondType.SINGLE)

        out = mol.GetMol()
        if self.sanitize:
            # GetMol() does not sanitize; without this, every atom's hybridization
            # stays UNSPECIFIED, which is what GetOrbitals reads.
            Chem.SanitizeMol(out, catchErrors=True)
        data.mol = out
        return data

class AddPG(T.BaseTransform):
    """Attach the molecule's point-group label (a string) as ``data.pg``.

    Calls ``PointGroup(data.pos, symbols).get_point_group()`` with element symbols
    derived from ``data.z``. Best-effort: if an atomic number is missing from
    ``atomic_number_to_symbol`` or point-group detection raises, ``'C1'`` is used.
    """

    def __call__(self, data):
        try:
            symbols = [atomic_number_to_symbol[atomic_num.item()] for atomic_num in data.z]
            pg = PointGroup(data.pos, symbols).get_point_group()
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit during the long
            # pre-processing pass must still propagate.
            pg = 'C1'
        data.pg = pg
        return data

class GetPG(T.BaseTransform):
    """Turn the string label ``data.pg`` into an integer class target in ``data.y``.

    The class index comes from ``point_groups``, and ``data.y`` becomes a one-element
    ``torch.long`` tensor. A label missing from the table raises ``KeyError``.
    """

    def __call__(self, data):
        label = point_groups[data.pg]
        data.y = torch.tensor([label], dtype=torch.long)
        return data
    

class FilteredQM9(InMemoryDataset):
    """In-memory dataset serving an explicit list of already-built ``Data`` objects.

    ``len`` and ``get`` read straight from ``data_list``; the list is also collated into
    ``data``/``slices`` for ``InMemoryDataset`` machinery that uses the collated storage.
    NOTE(review): ``get`` returns the stored objects themselves, so a ``transform``
    that mutates its input would modify them in place.
    """

    def __init__(self, root, data_list, transform=None, pre_transform=None):
        # Must be set before the base initializer runs.
        self.data_list = data_list
        super().__init__(root, transform, pre_transform)
        collated, slices = self.collate(self.data_list)
        self.data = collated
        self.slices = slices

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        return self.data_list[idx]

class GetOrbitals(T.BaseTransform):
    """Replace ``data.x`` with each atom's RDKit hybridization code.

    Reads the RDKit molecule in ``data.mol`` and writes a 1-D ``torch.long`` tensor with
    one ``HybridizationType`` integer per atom, in RDKit atom order. RDKit assigns
    hybridization during sanitization, so an unsanitized ``data.mol`` presumably yields
    UNSPECIFIED (0) for every atom.

    Raises:
        ValueError: if an RDKit atom's element disagrees with ``data.z`` at the same index.
        KeyError: if an atom's symbol is not in ``atomic_symbol_to_number``.
    """

    def __call__(self, data):
        molecule = data.mol
        orbitals = []
        for atom in molecule.GetAtoms():
            atom_index = atom.GetIdx()
            expected = atomic_symbol_to_number[atom.GetSymbol()]
            actual = int(data.z[atom_index])
            if actual != expected:
                # An explicit raise (not ``assert``) so the check survives ``python -O``.
                raise ValueError(f"Atomic number mismatch: {actual} vs {expected}")
            orbitals.append(int(atom.GetHybridization()))
        data.x = torch.tensor(orbitals, dtype=torch.long)
        return data

class O3Attr(T.BaseTransform):
    """Generate spherical-harmonic edge attributes and node attributes (for SEGNN).

    Sets ``data.edge_attr`` (spherical harmonics of each edge's relative position),
    ``data.node_attr`` (mean of the incoming edge attributes of each node; zeros for
    nodes without incoming edges) and ``data.edge_dist`` (squared edge length, shape
    ``(num_edges, 1)``).

    Args:
        lmax_attr: maximum degree ``l`` of the spherical harmonics.
    """

    def __init__(self, lmax_attr):
        super().__init__()
        self.attr_irreps = Irreps.spherical_harmonics(lmax_attr)

    def __call__(self, data):
        edge_index = data.edge_index
        # NOTE(review): assumes columns 1: of data.x hold the (aligned) coordinates -- confirm.
        pos = data.x[:, 1:]
        # pos_j - pos_i (edge_index stores tuples like (j, i)).
        rel_pos = pos[edge_index[0]] - pos[edge_index[1]]
        edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)
        # normalize=True rescales rel_pos to unit length before evaluating the harmonics.
        edge_attr = spherical_harmonics(self.attr_irreps, rel_pos, normalize=True,
                                        normalization='component')
        # dim_size keeps one row per node even when trailing nodes have no incoming edges.
        node_attr = scatter(edge_attr, edge_index[1], dim=0, dim_size=pos.size(0), reduce="mean")
        data.edge_attr = edge_attr
        data.node_attr = node_attr
        data.edge_dist = edge_dist
        return data
