"""
QM9 (Quantum Mechanics 9)

A collection of molecules with up to nine heavy atoms (C, O, N, F)
used as a benchmark dataset for molecular property prediction and
graph-classification tasks.

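This file is a loader for a small synthetic variant in the same graph format: three polystyrene chains generated with the tactic settings 'syndio', 'iso' and 'a' (graph labels 0, 1 and 2), with 2D coordinates whose z-offset encodes each stereocenter's R/S label.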

"""

import numpy as np
import torch

from torch_geometric.data import Data
from torch_geometric.utils import degree, to_undirected

from rdkit import Chem
from rdkit.Chem import AllChem

from torch_canon.E3Global.CategoricalPointCloud import CatFrame as Frame

from .bipartite_pair import DataLoader

"""
===============
Generate Splits
===============
"""

def polystyrene_loaders(loader_cfg: dict) -> dict:
    """Build data loaders and a node-degree histogram for the polystyrene dataset.

    Args:
        loader_cfg: Configuration dict. ``loader_cfg['batch_size']`` sets the
            batch size of the val/test loaders. If the optional key
            ``'align'`` is present, ``loader_cfg['align']['tol']`` is passed
            to ``Frame``. Each graph's ``pos`` is then replaced by the aligned
            positions, those positions are appended to ``x``, and the frame
            rotation/translation are stored as ``frame_R`` / ``frame_t``.

    Returns:
        Dict with ``'train'``, ``'val'`` and ``'test'`` loaders (all three
        iterate the same dataset) and ``'degrees_hist'``, a tensor counting
        node degrees over ``np.histogram`` edges 0..9. Degree ``i`` falls in
        bin ``i`` for 0..8, degree 9 also lands in the last bin, and larger
        degrees are not counted.
    """
    batch_size = loader_cfg['batch_size']
    dataset = create_polystyrene()

    if 'align' in loader_cfg:
        frame = Frame(tol=loader_cfg['align']['tol'])
        # Graphs are mutated in place; ``dataset`` holds the same objects.
        for data in dataset:
            align_pos, frame_R, frame_t = frame.get_frame(data.pos, data.x)
            data.pos = torch.from_numpy(align_pos)
            # NOTE(review): x is built as an integer tensor by mol_to_data while
            # pos is floating point; this relies on torch.cat type promotion.
            data.x = torch.cat([data.x.view(data.x.shape[0], -1), data.pos], dim=1)
            data.frame_R = frame_R
            data.frame_t = frame_t

    # Degree of every node of every graph. Computed independently of the
    # alignment step so the histogram is populated whether or not 'align' is set.
    degrees = []
    for data in dataset:
        node_degrees = degree(data.edge_index[0], num_nodes=data.x.shape[0])
        degrees.extend(node_degrees.tolist())
    degrees_hist = torch.from_numpy(np.histogram(degrees, bins=range(10))[0])

    # NOTE(review): the train loader uses batch_size=1 rather than
    # loader_cfg['batch_size']; kept as-is -- confirm this is intentional.
    train_loader = DataLoader(dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'degrees_hist': degrees_hist,
    }

"""
================
Generate Dataset
================
"""
def create_polystyrene(k=8):
    """Build the three-graph polystyrene tacticity dataset.

    Args:
        k: Chain length in repeat units, forwarded to ``generate_polystyrene``.
            Defaults to 8, the previously hard-coded value.

    Returns:
        A list of three ``Data`` graphs, in order, produced from the tactic
        settings ``'syndio'``, ``'iso'`` and ``'a'`` and labelled ``y`` = 0, 1
        and 2 respectively.
    """
    # The label of each graph is the position of its tactic in this tuple.
    tactics = ('syndio', 'iso', 'a')
    return [
        mol_to_data(generate_polystyrene(k, tactic), y=label)
        for label, tactic in enumerate(tactics)
    ]


def mol_to_data(mol, y=0):
    """Convert an RDKit molecule into a PyG ``Data`` graph.

    Computes 2D coordinates on ``mol`` in place. Every atom carrying a
    ``_CIPCode`` property of ``'R'`` or ``'S'`` is then moved to z = +0.5 or
    z = -0.5 respectively, so the stereo label is visible in the geometry.
    Other atoms keep their 2D coordinates.

    Args:
        mol: RDKit molecule. Its conformer is modified in place.
        y: Integer graph label.

    Returns:
        ``Data`` with
        ``x``: int64 zeros, one per atom (constant node feature);
        ``edge_index``: int64, both directions of every bond;
        ``edge_attr``: int32 zeros, one per directed edge;
        ``y``: int64 tensor ``[y]``;
        ``pos``: float32 atom coordinates;
        ``mol``: the input molecule.
    """
    AllChem.Compute2DCoords(mol)
    conf = mol.GetConformer()
    z_offsets = {'R': 0.5, 'S': -0.5}  # positive z for R, negative z for S

    for atom in mol.GetAtoms():
        if not atom.HasProp('_CIPCode'):
            continue
        offset = z_offsets.get(atom.GetProp('_CIPCode'))
        if offset is None:
            # Any other CIP label leaves the atom at its 2D position.
            continue
        idx = atom.GetIdx()
        new_pos = list(conf.GetAtomPosition(idx))
        new_pos[2] = offset
        conf.SetAtomPosition(idx, new_pos)
    pos = torch.tensor(conf.GetPositions(), dtype=torch.float32)

    num_atoms = mol.GetNumAtoms()
    x = torch.zeros(num_atoms, dtype=torch.long)

    # Keep only (begin, end) indices so the bond-type enums never enter the
    # tensor; reshape yields a (0, 2) tensor for bond-free molecules instead of
    # failing on the column slice.
    pairs = [(begin, end) for begin, end, _ in list_bonds(mol)]
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).T
    edge_index = to_undirected(edge_index, num_nodes=num_atoms)
    edge_attr = torch.zeros(edge_index.shape[1], dtype=torch.int)

    return Data(
        x=x,
        edge_index=edge_index,
        y=torch.LongTensor([y]),
        pos=pos,
        edge_attr=edge_attr,
        mol=mol,
    )

def list_bonds(molecule):
    """Describe every bond of ``molecule`` as a tuple.

    Returns:
        A list with one ``(begin_atom_idx, end_atom_idx, bond_type)`` tuple
        per bond, in the order ``molecule.GetBonds()`` yields them.
    """
    return [
        (bnd.GetBeginAtomIdx(), bnd.GetEndAtomIdx(), bnd.GetBondType())
        for bnd in molecule.GetBonds()
    ]

# Build a polystyrene chain of the requested tacticity from repeated styrene SMILES units.
def generate_polystyrene(n, tactic='syndio'):
    """Build an RDKit molecule for an ``n``-unit polystyrene chain.

    Every repeat unit is one of two styrene SMILES fragments that differ only
    in the chirality tag of the backbone stereocenter:
    ``A = "C[C@H](C1=CC=CC=C1)"`` and ``B = "C[C@@H](C1=CC=CC=C1)"``.

    Unit sequences per tactic:

    * ``'syndio'``: ``A`` repeated ``n`` times.
    * ``'iso'``: ``A`` followed by ``n - 1`` units alternating ``A, B, A, ...``.
    * ``'a'``: the first ``n`` units of ``A B B A B A A B``.
    * ``'a2'``: the first ``n`` units of ``A A B A B A A B``.

    The first character (the leading ``C``) of the joined SMILES is dropped
    before parsing, so the string starts at the first stereocenter.

    NOTE(review): that first ``[C@H]``/``[C@@H]`` atom is then only
    3-connected (one explicit H, ring, chain), i.e. a radical. Confirm this
    is intended.

    NOTE(review): a constant chirality tag along a consistently written chain
    is the usual SMILES encoding of an *isotactic* polymer, and an alternating
    tag that of a *syndiotactic* one, so the 'syndio' and 'iso' names look
    swapped. They are kept as-is because callers map them to class labels.
    Confirm before renaming.

    Args:
        n: Number of repeat units (>= 1; at most 8 for ``'a'`` / ``'a2'``).
        tactic: One of ``'syndio'``, ``'iso'``, ``'a'``, ``'a2'``.

    Returns:
        The parsed ``rdkit.Chem.Mol``.

    Raises:
        NotImplementedError: ``tactic`` is not recognised.
        ValueError: ``n`` is out of range for ``tactic``, or RDKit fails to
            parse the generated SMILES.
    """
    units = ("C[C@H](C1=CC=CC=C1)", "C[C@@H](C1=CC=CC=C1)")
    # Fixed atactic sequences as indices into ``units``; only 8 units are defined.
    atactic = {
        'a': (0, 1, 1, 0, 1, 0, 0, 1),
        'a2': (0, 0, 1, 0, 1, 0, 0, 1),
    }

    if tactic not in ('syndio', 'iso', 'a', 'a2'):
        raise NotImplementedError(f"unrecognized tactic-ness, {tactic}")
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    if tactic == 'syndio':
        pattern = [0] * n
    elif tactic == 'iso':
        # A leading A, then an alternating run that also starts with A.
        pattern = [0] + [i % 2 for i in range(n - 1)]
    else:
        pattern = atactic[tactic]
        if n > len(pattern):
            # Slicing would silently return a shorter chain than requested.
            raise ValueError(
                f"tactic {tactic!r} defines only {len(pattern)} units, got n={n}"
            )
        pattern = pattern[:n]

    smiles = ''.join(units[i] for i in pattern)[1:]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit failed to parse SMILES {smiles!r}")
    return mol
