"""
QM9 (Quantum Mechanics 9)

A collection of molecules with up to nine heavy atoms (C, O, N, F)
used as a benchmark dataset for molecular property prediction and
graph-classification tasks.

This file is a loader for variations of the dataset; the loader defined
below builds a toy dataset containing a single hand-specified H2O molecule.

"""
from typing import Optional

import numpy as np
import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import degree, to_undirected

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

from .bipartite_pair import BatchPos

"""
===============
Generate Splits
===============
"""

def h2o_loaders(loader_cfg: dict) -> dict:
    """Build the train/val/test loaders for the toy H2O dataset.

    Reads ``loader_cfg['orientation']`` and ``loader_cfg['batch_size']``. When
    ``loader_cfg`` also has an ``'align'`` entry, every sample is annotated in
    place with the outputs of a ``Frame`` built with ``tol=align['tol']``:
    ``align_pos``, ``frame_R``, ``frame_t``, three edge-index tensors derived
    from the frame's symmetry information, and a ``y`` that stacks the target
    with its images under the frame's symmetries.

    Returns a dict with the keys ``'train'``, ``'val'``, ``'test'`` (all three
    loaders wrap the same dataset) and ``'degrees_hist'``.
    """
    orientation = loader_cfg['orientation']
    batch_size = loader_cfg['batch_size']
    samples = create_h2o(orientation=orientation)

    node_degrees = []
    if 'align' in loader_cfg:
        canonicalizer = Frame(tol=loader_cfg['align']['tol'], save='all')
        for sample in samples:
            aligned, rotation, translation = canonicalizer.get_frame(sample.pos, sample.z)
            sample.align_pos = torch.from_numpy(aligned)
            sample.frame_R = rotation
            sample.frame_t = translation

            # Edges inside each symmetry group: every ordered pair (including
            # self-pairs), and first-member -> every other member.
            groups = canonicalizer.symmetric_elements
            all_pairs = [[a, b] for group in groups for a in group for b in group]
            first_to_rest = [
                [group[0], group[k]]
                for group in groups
                for k in range(1, len(group))
            ]
            sample.project_edge_index = torch.tensor(first_to_rest, dtype=torch.long).T
            sample.symmetric_edge_index = torch.tensor(all_pairs, dtype=torch.long).T

            # Every ordered pair of entries of the frame's `simple_asu`.
            asu = canonicalizer.simple_asu
            asu_pairs = [[a, b] for a in asu for b in asu]
            sample.asu_edge_index = torch.tensor(asu_pairs, dtype=torch.long).T

            # NOTE(review): create_h2o stores y with a leading singleton axis,
            # so the mean over dim 0 equals y itself and this subtraction
            # yields all zeros -- confirm the intended centering axis.
            centered = sample.y - sample.y.mean(dim=0)
            base = centered.unsqueeze(0).to(torch.float32)
            operations = torch.tensor(canonicalizer.symmetries).to(torch.float32)
            # The first symmetry is the identity, so only the remaining ones
            # contribute additional images of the base target.
            images = [base]
            images.extend((base[0] @ op.T).unsqueeze(0) for op in operations[1:])
            stacked = torch.cat(images, dim=0)
            sample.y = stacked.unsqueeze(0).to(torch.float32)

            # NOTE(review): only the degree of node 0 is recorded per sample,
            # and only on the align path -- confirm this is intended.
            per_node = degree(sample.edge_index[0], num_nodes=sample.z.shape[0])
            node_degrees.append(per_node.tolist()[0])

    counts, _ = np.histogram(node_degrees, bins=range(10))
    degrees_hist = torch.from_numpy(counts)

    # The training loader uses a fixed batch size of 1; only val/test honour
    # the configured batch size.
    train_loader = DataLoader(samples, batch_size=1, shuffle=True)
    val_loader = DataLoader(samples, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(samples, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'degrees_hist': degrees_hist,
    }

"""
================
Generate Dataset
================
"""

# Atom order used by every tensor below: oxygen first, then the two hydrogens.
_H2O_ATOMIC_NUMBERS = (8, 1, 1)

# Geometry shared by all orientations, one xyz row per atom.
_H2O_POSITIONS = (
    (0.0, 0.4, 0.0),
    (-0.78, -0.24, 0.0),
    (0.78, -0.24, 0.0),
)

# Fully connected directed graph over the three atoms, one pair per row.
_H2O_EDGES = ((0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1))

# Target 3-vectors for each supported orientation, one per row
# (presumably one per atom, in the atom order above -- confirm).
_H2O_TARGETS = {
    'symmetric-stretch': (
        (0.0, 0.8, 0.0),
        (-1.56, -0.48, 0.0),
        (1.56, -0.48, 0.0),
    ),
    'asymmetric-stretch': (
        (0.0, 0.8, 0.0),
        (0.78, -0.48, 0.0),
        (-0.78, -0.48, 0.0),
    ),
    'bending-motion': (
        (0.78, 0.0, 0.0),
        (-1.56, -0.44, 0.0),
        (-0.78, 0.24, 0.0),
    ),
}

# Tuple form so membership tests never hash the candidate value.
_H2O_ORIENTATIONS = tuple(_H2O_TARGETS)


def create_h2o(orientation):
    """Build the single-sample H2O dataset for one target orientation.

    Args:
        orientation: One of ``'symmetric-stretch'``, ``'asymmetric-stretch'``
            or ``'bending-motion'``; selects which target vectors are stored
            in ``y``.

    Returns:
        A list holding one ``Data`` object with
        ``z`` (long atomic numbers ``[8, 1, 1]``), ``pos`` (3 x 3 positions),
        ``y`` (float32 targets with a leading singleton axis, shape
        ``(1, 3, 3)``) and ``edge_index`` (the fully connected graph passed
        through ``to_undirected``).

    Raises:
        AssertionError: If ``orientation`` is not a supported name.
    """
    # Raised explicitly rather than via `assert` so the check survives
    # `python -O`, where an unknown name would otherwise silently yield an
    # empty dataset. AssertionError is kept for compatibility with callers.
    if orientation not in _H2O_ORIENTATIONS:
        raise AssertionError(
            f"orientation must be one of {list(_H2O_ORIENTATIONS)}, got {orientation!r}"
        )

    # The molecule is returned in its reference pose: no rotation or
    # translation is applied, and no random numbers are drawn.
    atoms = torch.tensor(_H2O_ATOMIC_NUMBERS, dtype=torch.long)
    pos = torch.tensor(_H2O_POSITIONS)
    y = torch.tensor(_H2O_TARGETS[orientation], dtype=torch.float32)
    edge_index = torch.tensor(_H2O_EDGES, dtype=torch.long).T

    sample = Data(z=atoms, edge_index=edge_index, y=y.unsqueeze(0), pos=pos)
    # The edge list is already symmetric; to_undirected is kept so the stored
    # edge ordering matches what it produces.
    sample.edge_index = to_undirected(sample.edge_index)

    return [sample]
