"""
QM9 (Quantum Mechanics 9)

A collection of small organic molecules with up to nine heavy atoms
(C, N, O, F) used as a benchmark dataset for molecular property
prediction and other graph-level regression tasks.

This file is a loader for variations of the dataset; the functions below
build a synthetic "box" toy dataset of four-atom squares.

"""
from typing import Optional
import math

import numpy as np
import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import degree, to_undirected

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

from .bipartite_pair import BatchPos

"""
===============
Generate Splits
===============
"""

def box_loaders(loader_cfg: dict) -> dict:
    """Build train/val/test DataLoaders over the toy ``create_box`` dataset.

    Each graph is centred on its centroid, canonicalised with ``Frame``
    (``CatFrame``), and annotated with the frame outputs and with
    symmetry-derived edge indices. The graphs are then wrapped in PyG
    ``DataLoader`` objects. All three splits iterate over the *same* list of
    graphs.

    Args:
        loader_cfg: Configuration dict. Reads ``'batch_size'`` (used by the
            val/test loaders) and ``'relax'``. If ``'relax'`` is truthy,
            ``data.y`` is expanded with one transformed copy per non-identity
            entry of ``frame.symmetries``.

    Returns:
        Dict with keys ``'train'``, ``'val'``, ``'test'`` (DataLoaders) and
        ``'degrees_hist'``. ``'degrees_hist'`` is a tensor of node-degree
        counts from ``np.histogram`` with bin edges 0..9: degrees 8 and 9
        share the last bin, and larger degrees are not counted.

    Raises:
        KeyError: If ``'batch_size'`` (or, for a non-empty dataset,
            ``'relax'``) is missing from ``loader_cfg``.
    """
    batch_size = loader_cfg['batch_size']
    dataset = create_box()

    degrees = []
    frame = Frame(tol=0.01, save='all')
    for data in dataset:
        data.pos = data.pos - data.pos.mean(dim=0)
        align_pos, frame_R, frame_t = frame.get_frame(data.pos, data.z)
        data.align_pos = torch.from_numpy(align_pos)
        data.frame_R = frame_R
        data.frame_t = frame_t
        # The frame's symmetry attributes are read right after get_frame,
        # presumably describing the cloud just processed — keep this ordering.
        _attach_symmetry_edges(data, frame)

        if loader_cfg['relax']:
            data.y = _symmetry_expanded_targets(data.y, frame.symmetries)

        # Record the degree of every node so the histogram covers the whole
        # degree distribution, not just the first node of each graph.
        node_degrees = degree(data.edge_index[0], num_nodes=data.z.shape[0])
        degrees.extend(node_degrees.tolist())
    degrees_hist = torch.from_numpy(np.histogram(degrees, bins=range(10))[0])

    # NOTE(review): the train loader uses batch_size=1 rather than
    # loader_cfg['batch_size'] — confirm this is intentional.
    train_loader = DataLoader(dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    return {'train': train_loader, 'val': val_loader, 'test': test_loader, 'degrees_hist': degrees_hist}


def _pairs_to_edge_index(pairs) -> torch.Tensor:
    """Convert a list of ``[src, dst]`` pairs into a ``(2, num_edges)`` long tensor.

    The reshape keeps the result two-dimensional (``(2, 0)``) when ``pairs``
    is empty, instead of producing a 1-D ``(0,)`` tensor.
    """
    return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T


def _attach_symmetry_edges(data, frame) -> None:
    """Store edge indices derived from ``frame``'s symmetry groups and ASU on ``data``."""
    groups = frame.symmetric_elements
    # Edges from the first member of each group to every other member.
    data.project_edge_index = _pairs_to_edge_index(
        [[group[0], group[k]] for group in groups for k in range(1, len(group))]
    )
    # All ordered pairs (self-pairs included) within each group.
    data.symmetric_edge_index = _pairs_to_edge_index(
        [[i, j] for group in groups for i in group for j in group]
    )
    asu = frame.simple_asu
    # All ordered pairs (self-pairs included) among the atoms in ``frame.simple_asu``.
    data.asu_edge_index = _pairs_to_edge_index([[i, j] for i in asu for j in asu])


def _symmetry_expanded_targets(y, symmetries) -> torch.Tensor:
    """Centre ``y`` over atoms and append its images under each non-identity symmetry.

    Assumes ``y`` has shape ``(copies, num_atoms, 3)``; ``create_box`` produces
    ``(1, 4, 3)``. ``symmetries[0]`` is the identity and is skipped. Every
    other matrix ``S`` contributes the copy ``y[0] @ S.T``.

    NOTE(review): the symmetries come from the canonical frame while ``y`` is
    in the input frame — confirm no conjugation by ``frame_R`` is required.

    Returns:
        float32 tensor of shape ``(1, copies + len(symmetries) - 1, num_atoms, 3)``.
    """
    y = y.to(torch.float32)
    # Centre over the atom axis. Averaging over dim 0 (the singleton copies
    # axis) would subtract y from itself and zero every target.
    y = y - y.mean(dim=-2, keepdim=True)
    mats = torch.tensor(symmetries).to(torch.float32)
    images = [(y[0] @ m.T).unsqueeze(0) for m in mats[1:]]
    return torch.cat([y, *images], dim=0).unsqueeze(0)

"""
================
Generate Dataset
================
"""

# Corners of the square shared by every "box" sample (atoms 0..3).
_BOX_CORNERS = ((1.0, 1.0, 0.0), (-1.0, -1.0, 0.0), (1.0, -1.0, 0.0), (-1.0, 1.0, 0.0))

# Per-atom target vectors of each sample, in dataset order. Each is the
# corner layout above with selected coordinates displaced by +1.
_BOX_TARGETS = (
    # S > R: y-coordinate of atoms 0 and 3 displaced.
    ((1.0, 2.0, 0.0), (-1.0, -1.0, 0.0), (1.0, -1.0, 0.0), (-1.0, 2.0, 0.0)),
    # S > R2: x-coordinate of atoms 0 and 2 displaced.
    ((2.0, 1.0, 0.0), (-1.0, -1.0, 0.0), (2.0, -1.0, 0.0), (-1.0, 1.0, 0.0)),
)


def create_box(version2=False):
    """Build the toy "box" dataset: fully connected 4-atom squares with vector targets.

    Every sample places four atoms (atomic number 1) on the corners of the
    square ``(+-1, +-1, 0)``. Positions and targets are each centred on their
    own mean. The samples differ only in their per-atom targets; see
    ``_BOX_TARGETS`` for the order and layout.

    Args:
        version2: If True, each sample's positions and targets are multiplied
            by its own random orthogonal matrix. The matrix is the V factor of
            an SVD of a standard-normal 3x3 matrix.

    Returns:
        List of ``torch_geometric.data.Data`` objects with fields ``z``,
        ``edge_index`` (undirected, no self-loops), ``y`` of shape
        ``(1, 4, 3)`` and ``pos`` of shape ``(4, 3)``.
    """
    return [_box_sample(targets, version2) for targets in _BOX_TARGETS]


def _box_sample(targets, version2):
    """Build one fully connected square graph whose per-atom targets are ``targets``."""
    pos = torch.tensor(_BOX_CORNERS)
    pos = pos - pos.mean(dim=0)
    y = torch.tensor(targets, dtype=torch.float32).unsqueeze(0)  # (1, num_atoms, 3)
    y = y - y.mean(dim=1, keepdim=True)

    if version2:
        # V of the SVD of a Gaussian matrix: a random orthogonal matrix (it may
        # include a reflection). Equivalent to the V returned by torch.svd.
        q = torch.linalg.svd(torch.randn(3, 3)).Vh.T
        pos = pos @ q.T
        y = y @ q.T

    num_atoms = pos.shape[0]
    edge_index = torch.tensor(
        [[i, j] for i in range(num_atoms) for j in range(num_atoms) if i != j],
        dtype=torch.long,
    ).T
    data = Data(
        z=torch.ones(num_atoms, dtype=torch.long),
        edge_index=edge_index,
        y=y,
        pos=pos,
    )
    data.edge_index = to_undirected(data.edge_index)
    return data
