"""
QM9 (Quantum Mechanics 9)

A collection of small organic molecules with up to nine heavy atoms
(C, N, O, F), used as a benchmark dataset for molecular property
prediction and other graph-learning tasks.

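This file is a loader for variations of the dataset. The code below builds
the synthetic "k-chain" dataset: two labelled path graphs on k + 2 nodes
whose 3D embeddings differ only in the position of the first endpoint.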

"""
from typing import Optional

import numpy as np
import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import degree, to_undirected

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

from .bipartite_pair import BatchPos

"""
===============
Generate Splits
===============
"""

def kchain_loaders(loader_cfg: dict) -> dict:
    k = loader_cfg['k']
    batch_size = loader_cfg['batch_size']
    dataset = create_kchains(k=k)
    degrees = []
    if 'align' in loader_cfg:
        align_cfg = loader_cfg['align']
        frame = Frame(tol=align_cfg['tol'], save='all')
        for i,data in enumerate(dataset):
            align_pos, frame_R, frame_t = frame.get_frame(data.pos, data.x)
            data.align_pos = torch.from_numpy(align_pos)
            data.frame_R = frame_R
            data.frame_t = frame_t
            symmetric_elements = frame.symmetric_elements
            symmetric_edge_index = [[i,j] for symmetry_element in symmetric_elements for i in symmetry_element for j in symmetry_element]
            projection_edge_index = [[symmetry_element[0],symmetry_element[j]] for symmetry_element in symmetric_elements for j in range(1,len(symmetry_element))]
            data.project_edge_index = torch.tensor(projection_edge_index, dtype=torch.long).T
            data.symmetric_edge_index = torch.tensor(symmetric_edge_index, dtype=torch.long).T
            asu = frame.simple_asu
            asu_edge_index = [[i,j] for i in asu for j in asu]
            data.asu_edge_index = torch.tensor(asu_edge_index, dtype=torch.long).T
            degrees.append(degree(data.edge_index[0], num_nodes=data.x.shape[0]).tolist()[0])
    degrees_hist = torch.from_numpy(np.histogram(degrees, bins=range(10))[0])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    return {'train': dataloader, 'val': val_loader, 'test': test_loader, 'degrees_hist': degrees_hist}

"""
================
Generate Dataset
================
"""

def create_kchains(k):
    assert k >= 2
    
    dataset = []

    # Graph 0
    atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )
    edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )
    pos = torch.FloatTensor(
        [[-4, -3, 0]] + 
        [[0, 5*i , 0] for i in range(k)] + 
        [[4, 5*(k-1) + 3, 0]]
    )
    center_of_mass = torch.mean(pos, dim=0)
    pos = pos - center_of_mass
    y = torch.LongTensor([0])  # Label 0
    data1 = Data(x=atoms, atoms=atoms, z=atoms, edge_index=edge_index, pos=pos, y=y, edge_attr=torch.zeros(edge_index.shape[1]).to(torch.int))
    data1.edge_index = to_undirected(data1.edge_index)
    data1.edge_attr = torch.zeros(data1.edge_index.shape[1]).to(torch.int)
    dataset.append(data1)
    
    # Graph 1
    atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )
    edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )
    pos = torch.FloatTensor(
        [[4, -3, 0]] + 
        [[0, 5*i , 0] for i in range(k)] + 
        [[4, 5*(k-1) + 3, 0]]
    )
    center_of_mass = torch.mean(pos, dim=0)
    pos = pos - center_of_mass
    y = torch.LongTensor([1])  # Label 1
    data2 = Data(x=atoms, atoms=atoms, z=atoms, edge_index=edge_index, pos=pos, y=y, edge_attr=torch.zeros(edge_index.shape[1]).to(torch.int))
    data2.edge_index = to_undirected(data2.edge_index)
    data2.edge_attr = torch.zeros(data2.edge_index.shape[1]).to(torch.int)
    dataset.append(data2)
    
    return dataset

def kchain_to_mol(data):
    """Attach an RDKit molecule whose leading atom coordinates match ``data.pos``.

    Builds a linear all-carbon SMILES with one carbon per row of ``data.x``,
    adds explicit hydrogens, embeds a 3D conformer with
    ``AllChem.EmbedMolecule``, then overwrites the positions of the first
    ``data.pos.shape[0]`` atoms with the rows of ``data.pos``. The molecule
    is stored as ``data.mol`` and ``data`` (modified in place) is returned.

    NOTE(review): only the first N atom positions are overwritten. Assuming
    RDKit places the added hydrogens after the heavy atoms, the hydrogens
    keep the coordinates from the unseeded embedding. Confirm that this is
    intended.
    """
    n_carbons = data.x.shape[0]
    molecule = Chem.AddHs(Chem.MolFromSmiles('C' * n_carbons))
    AllChem.EmbedMolecule(molecule)

    conformer = molecule.GetConformer()
    coords = data.pos.to(torch.double).numpy()
    for atom_idx, xyz in enumerate(coords):
        conformer.SetAtomPosition(atom_idx, Point3D(*xyz))

    data.mol = molecule
    return data

