import os
import pickle
import re

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import rdchem
from rdkit.Geometry import Point3D



# Open & Save Data
def open_pickle(mol_path):
    """Load and return the single object pickled in the file at ``mol_path``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(mol_path, mode="rb") as handle:
        return pickle.load(handle)

def save_pyg_data_to_pkl(data, smi, args, task = 'local'):
    """
    Pickle ``data`` to ``<args.root>/data/<args.data_type>_<args.mode>/<smi>.pkl``.

    Args:
        data: Object to pickle (e.g. a PyG ``Data`` object).
        smi: SMILES string used as the file stem. Path-separator characters
            (``os.sep`` / ``os.altsep``) are escaped as ``%2F`` (slash) and
            ``%5C`` (backslash), because SMILES use them as double-bond
            stereo markers and they would otherwise be read as directories.
        args: Namespace providing ``root``, ``data_type`` and ``mode``.
        task: Unused; kept only for backward compatibility. The subfolder
            suffix comes from ``args.mode``.
    """
    # e.g. subfolder = "qm9_local"
    subfolder = f"{args.data_type}_{args.mode}"
    out_dir = os.path.join(args.root, "data", subfolder)
    os.makedirs(out_dir, exist_ok=True)

    # A raw SMILES such as "C/C=C/C" would otherwise become nested path
    # components and typically raise FileNotFoundError on open().
    escapes = {"/": "%2F", "\\": "%5C"}
    stem = smi
    for sep in (os.sep, os.altsep):
        if sep:
            stem = stem.replace(sep, escapes.get(sep, "_"))

    # Construct the file path and write
    filepath = os.path.join(out_dir, f"{stem}.pkl")
    with open(filepath, "wb") as f:
        pickle.dump(data, f)

# Clean data
def get_full_smiles(mol):
    """
    Return a non-canonical, isomeric SMILES of ``mol`` in which every atom
    carries an atom-map number equal to its 0-based atom index.

    RDKit omits the ``:0`` annotation for atoms whose map number is 0, so the
    SMILES is generated with map numbers ``idx + 1``. Every bracket-atom map
    number is then decremented by one in the resulting string.

    The map numbers are written on a copy, so the caller's ``mol`` (including
    any atom-map numbers it already carried) is left unmodified.

    Args:
        mol (Chem.Mol): Input molecule.

    Returns:
        str: Atom-mapped SMILES generated with ``canonical=False``.
    """
    # Work on a copy: the +1 shift is meant to be temporary and must not
    # leak into the caller's molecule.
    mol = Chem.Mol(mol)

    # Temporarily assign map numbers as idx + 1 so atom 0 is not suppressed.
    for idx, atom in enumerate(mol.GetAtoms()):
        atom.SetAtomMapNum(idx + 1)

    # Generate SMILES without canonical re-ranking of the atoms.
    smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)

    # Subtract 1 from every map number. Matching any bracket-atom content
    # (not a fixed character set) also covers atoms such as "[*:3]".
    def replace_mapnum(match):
        symbol = match.group(1)
        num = int(match.group(2))
        return f"[{symbol}:{num - 1}]"

    return re.sub(r"\[([^\[\]:]+):(\d+)\]", replace_mapnum, smiles)

def clean_data(ref_conformer, conformers, test = False):
    """
    Filter a molecule's conformers down to those whose atom-mapped SMILES
    matches that of ``ref_conformer``.

    The whole molecule is rejected (empty list returned) if its reference
    SMILES contains '.' (disconnected fragments) or if any atom has more than
    4 neighbors.

    Args:
        ref_conformer (Chem.Mol): Reference molecule (e.g. the 1st conformer).
        conformers: Conformers to filter. When ``test`` is False each entry is
            a dict whose ``'rd_mol'`` key holds the molecule; when ``test`` is
            True each entry is the molecule itself.
        test (bool): Selects the conformer-list layout described above.

    Returns:
        tuple: ``(cleaned_conformers, full_smiles)`` -- note the order.
        ``cleaned_conformers`` is ``[]`` when the molecule is rejected or no
        conformer matches.
    """
    full_smiles = get_full_smiles(ref_conformer)

    # Basic molecule check: '.' separates disconnected fragments.
    if '.' in full_smiles:
        print(f'{full_smiles}, conformers with fragments')
        return [], full_smiles

    # Skip mols with atoms with more than 4 neighbors for now.
    # default=0 keeps an atom-less molecule from raising on an empty max().
    max_neighbors = max(
        (len(a.GetNeighbors()) for a in ref_conformer.GetAtoms()), default=0
    )
    if max_neighbors > 4:
        print(f'{full_smiles}, at least one atom with more than 4 neighbors')
        return [], full_smiles

    if test:
        cleaned_conformers = clean_confs_test(full_smiles, conformers)
    else:
        cleaned_conformers = clean_confs(full_smiles, conformers)
    if not cleaned_conformers:
        print(f'{full_smiles}, has no conformers')
    return cleaned_conformers, full_smiles

def clean_confs(full_smiles, confs):
    """
    Return the molecules stored under ``'rd_mol'`` in each entry of ``confs``
    whose atom-mapped SMILES equals ``full_smiles``, in input order.
    """
    return [
        rd_mol
        for rd_mol in (entry['rd_mol'] for entry in confs)
        if get_full_smiles(rd_mol) == full_smiles
    ]

def clean_confs_test(full_smiles, confs):
    """
    Return the molecules in ``confs`` whose atom-mapped SMILES equals
    ``full_smiles``, in input order.
    """
    kept = []
    for rd_mol in confs:
        if get_full_smiles(rd_mol) != full_smiles:
            continue
        kept.append(rd_mol)
    return kept

# Save pos 
def update_conformer_positions(mol: Chem.Mol, new_pos) -> Chem.Mol:
    """
    Return a copy of ``mol`` whose conformers are replaced by a single new
    conformer built from ``new_pos``.

    Args:
        mol (Chem.Mol): An RDKit molecule. All of its conformers are dropped
            from the returned copy; the input molecule is not modified.
        new_pos (Union[np.ndarray, torch.Tensor, Sequence]): An (N_atoms, 3)
            array-like of new positions. Tensors are detached and moved to CPU.

    Returns:
        Chem.Mol: A copy of the molecule carrying only the new conformer.

    Raises:
        ValueError: If ``new_pos`` does not have shape (N_atoms, 3).
    """
    mol = Chem.Mol(mol)  # copy; the caller's molecule stays untouched
    n_atoms = mol.GetNumAtoms()

    if isinstance(new_pos, torch.Tensor):
        # Cast to float64 before .numpy(): numpy cannot represent e.g. bfloat16.
        new_pos = new_pos.detach().to(device="cpu", dtype=torch.float64).numpy()
    # Also accepts plain nested lists / tuples.
    new_pos = np.asarray(new_pos, dtype=np.float64)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if new_pos.shape != (n_atoms, 3):
        raise ValueError(
            f"new_pos must have shape ({n_atoms}, 3), got {new_pos.shape}"
        )

    # Create new conformer
    new_conf = rdchem.Conformer(n_atoms)
    for i, (x, y, z) in enumerate(new_pos):
        new_conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))

    # Replace every existing conformer with the new one.
    mol.RemoveAllConformers()
    mol.AddConformer(new_conf, assignId=True)

    return mol
