"""
MD17 (Molecular Dynamics 17)

A collection of molecular-dynamics trajectories of small organic
molecules, used as a benchmark dataset for molecular property
prediction (and, in this loader, point-group classification).

This file is a loader for variations of the dataset.

"""
from typing import Optional

from rdkit import Chem
import numpy as np
import torch
from tqdm import tqdm

from e3nn.o3 import Irreps, spherical_harmonics
from pointgroup import PointGroup

from sklearn.utils import shuffle

from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets import MD17
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch_geometric.utils import degree

from torch_scatter import scatter

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

# Point-group label -> integer class index. The index of a label is its
# position in the list below, so the order of the entries matters.
point_groups = {
    label: class_index
    for class_index, label in enumerate([
        'C1', 'C1v',
        'C2', 'C2v', 'C2d', 'C2h',
        'C3', 'C3v', 'C3d', 'C3h',
        'C6', 'C6v', 'C6d', 'C6h',
        'Ci', 'Cs',
        'D2', 'D2d', 'D2h',
        'D3', 'D3d', 'D3h',
        'D6', 'D6d', 'D6h',
        'S2', 'S3', 'S6',
    ])
}

# Element lookup in both directions; only H, C, N, O and F are covered.
atomic_number_to_symbol = {
    1: 'H',
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
}
atomic_symbol_to_number = {
    symbol: number for number, symbol in atomic_number_to_symbol.items()
}

"""
================
Generate Loaders
================
"""


def md17_loaders(loader_cfg: dict) -> dict:
    """Build the MD17 dataset and wrap random splits of it in data loaders.

    Args:
        loader_cfg: Loader configuration. The keys read here are:
            ``'name'``: forwarded to ``MD17`` as the dataset name.
            ``'batch_size'``: batch size shared by all three loaders.
            ``'target'``: when equal to ``'pointgroup'``, ``GetPG`` is
                appended so ``data.y`` becomes the point-group class index.
            ``'align'`` (optional): mapping with a ``'tol'`` entry; when
                present, ``Align``, ``GraphToMol`` and ``GetOrbitals`` are
                enabled and the dataset is rooted at ``./data/md17_align``.

    Returns:
        A dict with the ``'train'``, ``'val'`` and ``'test'`` loaders, plus
        ``'mean'`` (``dataset.mean()``), ``'std'`` (fixed at 1.0) and
        ``'degrees_hist'`` (histogram tensor of the collected degrees).

    Raises:
        KeyError: if ``'name'``, ``'batch_size'`` or ``'target'`` is missing
            from ``loader_cfg`` (or ``'tol'`` from the ``'align'`` mapping).
    """
    name = loader_cfg['name']
    batch_size = loader_cfg['batch_size']

    # Pre-transforms are applied once when the dataset is processed;
    # transforms are applied every time a sample is accessed.
    pre_transform = [T.RadiusGraph(10.0), Degree(), AddPG()]
    transform = []

    if 'align' in loader_cfg:
        align_cfg = loader_cfg['align']
        pre_transform.append(Align(tol=align_cfg['tol']))
        transform.append(GraphToMol())
        transform.append(GetOrbitals())
        # Separate root so the aligned variant gets its own processed files.
        pth = './data/md17_align'
    else:
        pth = './data/md17'

    if loader_cfg['target'] == 'pointgroup':
        transform.append(GetPG())

    dataset = MD17(
        root=pth,
        name=name,
        pre_transform=T.Compose(pre_transform),
        transform=T.Compose(transform),
    )

    mean = dataset.mean()
    std = 1.0

    # NOTE(review): only the degree of the first node of each sample is
    # collected, and the bins cover degrees 0..9 only - confirm both are
    # intended for this dataset.
    degrees = [
        data.degrees.tolist()[0]
        for data in tqdm(dataset, total=len(dataset))
    ]
    degrees_hist = torch.from_numpy(np.histogram(degrees, bins=range(10))[0])

    # 200 training samples, 1000 validation samples, remainder for testing.
    split_dict = random_splits(len(dataset), 200, 1000, seed=42)

    train_dataset = dataset[split_dict['train']]
    val_dataset = dataset[split_dict['valid']]
    test_dataset = dataset[split_dict['test']]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'mean': mean,
        'std': std,
        'degrees_hist': degrees_hist,
    }

"""
===========
Data Splits
===========

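    Adapted from DIG:

    `random_splits`: shuffles the sample indices with a fixed seed, then
     - trainset <- the first `train_size` shuffled indices
     - valset <- the next `valid_size` shuffled indices
     - testset <- all remaining indices

    (In DIG this follows the QM9 splits of EGNN and Cormorant: a trainset of
     either 110_000, 100_000 or 50_000 molecules, a testset of
     0.1 * (QM9 total number of molecules) and a valset of the remaining
     molecules. Those sizes do not apply to the function below.)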
    
"""
def random_splits(data_size, train_size, valid_size, seed):
    """Split ``range(data_size)`` into shuffled train/valid/test index tensors.

    Args:
        data_size: Total number of samples to split.
        train_size: Number of shuffled indices assigned to the train split.
        valid_size: Number of shuffled indices assigned to the valid split.
        seed: Passed to ``shuffle`` as ``random_state``.

    Returns:
        A dict with keys ``'train'``, ``'valid'`` and ``'test'``; each value
        is a long tensor of sample indices. The test split receives every
        index left over after the train and valid splits.
    """
    ids = shuffle(range(data_size), random_state=seed)
    valid_end = train_size + valid_size

    # dtype is forced to long so that an empty split is still a valid index
    # tensor (torch.tensor([]) would otherwise default to float32).
    return {
        'train': torch.tensor(ids[:train_size], dtype=torch.long),
        'valid': torch.tensor(ids[train_size:valid_end], dtype=torch.long),
        'test': torch.tensor(ids[valid_end:], dtype=torch.long),
    }

"""
Transformations 
~~~~~~~~~~~~~~~

"""

class Degree(T.BaseTransform):
    """Store a per-node degree vector on the sample as ``data.degrees``."""

    def __call__(self, data):
        # Degrees are counted over the first row of edge_index; the output
        # length is taken from the number of entries in data.z.
        first_row = data.edge_index[0]
        node_count = data.z.shape[0]
        data.degrees = degree(first_row, num_nodes=node_count)
        return data

class Align(T.BaseTransform):
    """Attach the output of ``CatFrame.get_frame`` and derived index sets.

    The transform adds these attributes to the sample:
        ``align_pos``: tensor built from the first value of ``get_frame``.
        ``frame_R``, ``frame_t``: second and third values of ``get_frame``,
            stored as returned.
        ``project_edge_index``: pairs linking the first member of each
            entry of ``frame.symmetric_elements`` to its other members.
        ``symmetric_edge_index``: all ordered pairs (including self-pairs)
            within each entry of ``frame.symmetric_elements``.
        ``asu_edge_index``: all ordered pairs of ``frame.simple_asu``.

    NOTE(review): presumably ``symmetric_elements`` groups
    symmetry-equivalent atom indices and ``simple_asu`` lists the atoms of
    the asymmetric unit - confirm against torch_canon.

    Args:
        tol: Tolerance forwarded to ``CatFrame``.
    """

    def __init__(self, tol: Optional[float] = 1e-3):
        self.tol = tol

    @staticmethod
    def _pairs_to_edge_index(pairs):
        """Turn a list of ``[i, j]`` pairs into a ``(2, E)`` long tensor."""
        # The explicit (len(pairs), 2) shape keeps the result two-dimensional
        # for an empty list, which would otherwise become a 1-D tensor.
        return torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2).T

    def __call__(self, data):
        frame = Frame(tol=self.tol, save='all')
        align_pos, frame_R, frame_t = frame.get_frame(data.pos.numpy(), data.z.numpy())
        data.align_pos = torch.from_numpy(align_pos)
        data.frame_R = frame_R
        data.frame_t = frame_t

        symmetric_elements = frame.symmetric_elements
        # Every ordered pair inside a group, self-pairs included.
        symmetric_pairs = [
            [i, j]
            for group in symmetric_elements
            for i in group
            for j in group
        ]
        # First member of each group paired with each remaining member.
        projection_pairs = [
            [group[0], group[k]]
            for group in symmetric_elements
            for k in range(1, len(group))
        ]
        data.project_edge_index = self._pairs_to_edge_index(projection_pairs)
        data.symmetric_edge_index = self._pairs_to_edge_index(symmetric_pairs)

        asu = frame.simple_asu
        asu_pairs = [[i, j] for i in asu for j in asu]
        data.asu_edge_index = self._pairs_to_edge_index(asu_pairs)
        return data

class AtomicNumber(T.BaseTransform):
    """Replace ``data.x`` with column 5 of itself as a ``(N, 1)`` long tensor.

    NOTE(review): assumes column 5 of ``data.x`` holds the atomic number -
    confirm against the feature layout of the dataset in use.
    """

    def __call__(self, data):
        data.x = data.x[:, 5].to(torch.long).view(-1, 1)
        # Transforms must hand the sample back so they can be composed.
        return data

class GetTarget(T.BaseTransform):
    """Reduce ``data.y`` to the single column that belongs to one target.

    Args:
        target: Name of the target; its position in ``targets`` selects the
            column of ``data.y`` that is kept.
    """

    def __init__(self, target: str):
        # NOTE(review): `targets` is not defined in the visible part of this
        # file - confirm it is provided elsewhere before using this class.
        column = targets.index(target)
        self.index = column

    def __call__(self, data):
        selected = data.y[:, self.index]
        data.y = selected
        return data

class IntegerEdgeFeatures(T.BaseTransform):
    """Keep only column 0 of ``data.edge_attr``, cast to a long tensor."""

    def __call__(self, data):
        first_column = data.edge_attr[:, 0]
        data.edge_attr = first_column.to(torch.long)
        return data

class GraphToMol(T.BaseTransform):
    """Build an RDKit molecule from ``data.z`` / ``data.pos`` as ``data.mol``.

    One atom is added per entry of ``data.z``. A single bond is added
    between every pair of atoms whose distance (computed from ``data.pos``)
    is below ``threshold``; bond orders are not inferred.

    Args:
        threshold: Distance cutoff below which two atoms are bonded, in the
            same units as ``data.pos``.
    """

    def __init__(self, threshold: float = 1.6):
        super().__init__()
        self.threshold = threshold

    def __call__(self, data):
        mol = Chem.RWMol()

        # Atoms are added in the order of data.z, so RDKit atom indices
        # coincide with node indices.
        for atomic_num in data.z:
            mol.AddAtom(Chem.Atom(int(atomic_num.item())))

        # Visit each unordered pair (i < j) once and bond close neighbours.
        num_atoms = len(data.pos)
        for i in range(num_atoms):
            for j in range(i + 1, num_atoms):
                dist = torch.norm(data.pos[i] - data.pos[j]).item()
                if dist < self.threshold:
                    mol.AddBond(i, j, Chem.BondType.SINGLE)

        # Freeze the editable molecule into a plain Mol object.
        data.mol = mol.GetMol()
        return data

class AddPG(T.BaseTransform):
    """Store the point-group label of the sample as ``data.pg``.

    The label is whatever ``PointGroup(...).get_point_group()`` returns for
    ``data.pos`` and the element symbols looked up from ``data.z``. If the
    lookup or the point-group detection fails for any reason (for example
    an atomic number missing from ``atomic_number_to_symbol``), the label
    falls back to ``'C1'``.
    """

    def __call__(self, data):
        try:
            symbols = [atomic_number_to_symbol[atomic_num.item()] for atomic_num in data.z]
            data.pg = PointGroup(data.pos, symbols).get_point_group()
        except Exception:
            # Best-effort labelling: keep the fallback, but let
            # KeyboardInterrupt / SystemExit propagate.
            data.pg = 'C1'
        return data

class GetPG(T.BaseTransform):
    """Set ``data.y`` to the class index of the sample's point-group label."""

    def __call__(self, data):
        # NOTE(review): a label that is not a key of `point_groups` raises
        # KeyError here - confirm every label that can occur is listed.
        class_index = point_groups[data.pg]
        data.y = torch.tensor([class_index], dtype=torch.long)
        return data
    

class FilteredQM9(InMemoryDataset):
    """In-memory dataset backed by an already-built list of samples.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        data_list: The samples to expose.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``.
    """

    def __init__(self, root, data_list, transform=None, pre_transform=None):
        # The list must be stored before the base initialiser runs.
        self.data_list = data_list
        super().__init__(root, transform=transform, pre_transform=pre_transform)
        # Collate the list into the (data, slices) storage format.
        collated = self.collate(self.data_list)
        self.data, self.slices = collated

    def len(self):
        """Return the number of samples in the backing list."""
        return len(self.data_list)

    def get(self, idx):
        """Return the sample at ``idx`` straight from the backing list."""
        # Lets samples be accessed as they would be in the original QM9 dataset.
        sample = self.data_list[idx]
        return sample

class GetOrbitals(T.BaseTransform):
    """Set ``data.x`` to the per-atom hybridization codes of ``data.mol``.

    For every atom of the molecule the integer value of its RDKit
    hybridization is collected, after checking that the atom's element
    agrees with the matching entry of ``data.z``.

    NOTE(review): RDKit normally assigns hybridization during sanitization;
    confirm ``data.mol`` has it set, otherwise every code may be the
    "unspecified" value.
    """

    def __call__(self, data):
        hybridization_codes = []
        for atom in data.mol.GetAtoms():
            atom_index, atom_symbol = atom.GetIdx(), atom.GetSymbol()
            # Guard against the molecule and the graph disagreeing on atom order.
            assert data.z[atom_index] == atomic_symbol_to_number[atom_symbol], f"Atomic number mismatch: {data.z[atom_index]} vs {atomic_symbol_to_number[atom_symbol]}"
            hybridization_codes.append(atom.GetHybridization().real)
        data.x = torch.tensor(hybridization_codes, dtype=torch.long)
        return data

class O3Attr(T.BaseTransform):
    """Generate spherical-harmonic edge attributes and node attributes.

    Args:
        lmax_attr: Passed to ``Irreps.spherical_harmonics`` to build the
            irreps in which the attributes are expressed.
    """

    def __init__(self, lmax_attr):
        super().__init__()
        self.attr_irreps = Irreps.spherical_harmonics(lmax_attr)

    def __call__(self, data):
        """Add ``edge_attr``, ``node_attr`` and ``edge_dist`` to ``data``.

        ``edge_attr`` holds the spherical harmonics of each edge's relative
        position, ``node_attr`` is the mean of ``edge_attr`` grouped by
        ``edge_index[1]``, and ``edge_dist`` is the squared length of the
        relative position with shape ``(E, 1)``.
        """
        edge_index = data.edge_index
        # NOTE(review): assumes data.x is 2-D with the (aligned) coordinates
        # in columns 1 onwards - confirm against the transforms run before.
        pos = data.x[:, 1:]
        # NOTE(review): the original comment calls this "pos_j - pos_i",
        # with edge_index storing (j, i) tuples - confirm the convention.
        rel_pos = pos[edge_index[0]] - pos[edge_index[1]]
        # Squared distance; no square root is taken.
        edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)
        edge_attr = spherical_harmonics(self.attr_irreps, rel_pos, normalize=True,
                                        normalization='component')
        # dim_size pins the output to one row per node, even when the
        # highest-numbered nodes have no incoming edge.
        node_attr = scatter(edge_attr, edge_index[1], dim=0,
                            dim_size=pos.shape[0], reduce="mean")
        data.edge_attr = edge_attr
        data.node_attr = node_attr
        data.edge_dist = edge_dist
        return data
