import py3Dmol
from rdkit import Chem
import torch_geometric
import torch_geometric.datasets


def to_rdkit_molecule(data: torch_geometric.data.Data) -> Chem.Mol:
    """
    Convert a PyTorch Geometric graph to an RDKit molecule.

    Atoms are created from the atomic numbers in ``data["z"]``, the
    coordinates in ``data["pos"]`` are stored in a single conformer, and each
    distinct pair of nodes connected in ``data.edge_index`` becomes one single
    bond. The graph carries no bond-order information here, so every bond is
    ``Chem.BondType.SINGLE``.

    Args:
        data: A PyTorch Geometric Data object containing:
            - z: Atomic number of each atom (one entry per node)
            - pos: Node position coordinates (N x 3 tensor)
            - edge_index (optional): 2 x E tensor of node index pairs. Edges
              are treated as undirected: both directions of a pair, and any
              repeated entries, collapse to a single bond. Self-loops are
              ignored. If absent (``None``), no bonds are added.

    Returns:
        mol: RDKit molecule (an editable ``Chem.RWMol``) with one conformer
        holding the 3D positions.

    Raises:
        ValueError: If the number of positions differs from the number of atoms.
        IndexError: If an edge references a node index outside ``[0, N)``.
    """
    mol = Chem.RWMol()

    # Add atoms; RDKit assigns indices 0..N-1 in insertion order, so the node
    # index in the graph is also the atom index in the molecule.
    for atomic_num in data["z"]:
        mol.AddAtom(Chem.Atom(int(atomic_num)))
    num_atoms = mol.GetNumAtoms()

    positions = data["pos"]
    if len(positions) != num_atoms:
        raise ValueError(
            f"Expected {num_atoms} positions (one per atom), "
            f"got {len(positions)}"
        )

    # Create a conformer to store 3D positions
    conf = Chem.Conformer(num_atoms)
    for i, position in enumerate(positions):
        x, y, z = position.tolist()
        conf.SetAtomPosition(i, (float(x), float(y), float(z)))
    mol.AddConformer(conf)

    # Add one single bond per undirected edge.
    edge_index = getattr(data, "edge_index", None)
    if edge_index is not None:
        # Convert once to plain Python ints instead of indexing 0-dim tensors.
        sources = edge_index[0].tolist()
        targets = edge_index[1].tolist()
        bonded_pairs = set()
        for a, b in zip(sources, targets):
            if a == b:
                continue  # self-loops have no chemical meaning
            if not (0 <= a < num_atoms and 0 <= b < num_atoms):
                raise IndexError(
                    f"Edge ({a}, {b}) references an atom outside "
                    f"[0, {num_atoms})"
                )
            # Normalise direction so (a, b) and (b, a) map to the same bond;
            # RDKit raises if the same bond is added twice.
            pair = (a, b) if a < b else (b, a)
            if pair in bonded_pairs:
                continue
            bonded_pairs.add(pair)
            mol.AddBond(pair[0], pair[1], Chem.BondType.SINGLE)

    return mol

def visualize_qm9_data(data):
    """
    Render a PyTorch Geometric molecule graph in an interactive py3Dmol viewer.

    The graph is converted to an RDKit molecule with ``to_rdkit_molecule``,
    serialised to a Mol block, and shown with a stick-plus-sphere style.

    Args:
        data: Graph object accepted by ``to_rdkit_molecule``.
    """
    molecule = to_rdkit_molecule(data)

    # py3Dmol consumes the molecule as Mol-block text.
    mol_block = Chem.MolToMolBlock(molecule)
    render_style = {"stick": {}, "sphere": {"scale": 0.3}}

    viewer = py3Dmol.view(data=mol_block, style=render_style)
    viewer.zoomTo()
    viewer.show()