from datasets import load_dataset

from rdkit import Chem
from rdkit.Chem import QED, Descriptors, Crippen
import re

# Load dataset

# Property functions
def calculate_qed(smiles):
    """Return the RDKit QED score of ``smiles``, or None when RDKit yields no molecule for it."""
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    return QED.qed(molecule)

def calculate_logp(smiles):
    """Return the Crippen logP of ``smiles``, or None when RDKit yields no molecule for it."""
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    return Crippen.MolLogP(molecule)

def calculate_mw(smiles):
    """Return the molecular weight (Descriptors.MolWt) of ``smiles``, or None when RDKit yields no molecule for it."""
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    return Descriptors.MolWt(molecule)


from rdkit.Chem import BRICS

import re

# Matches an atom-map suffix ":<digits>" only when it closes a bracket atom,
# e.g. "[CH2:2]" or "[4*:3]". Group 1 is everything before the colon.
_ATOM_MAP_IN_BRACKET = re.compile(r"(\[[^\]]*?):\d+\]")


def _strip_atom_map_numbers(smiles: str) -> str:
    """Remove atom-map numbers from bracket atoms only: ``[CH2:2]`` -> ``[CH2]``."""
    return _ATOM_MAP_IN_BRACKET.sub(r"\1]", smiles)


def cleanup_brics_fragment(fragment_smiles: str) -> str:
    """
    Remove atom-mapping indices (like :2) from bracketed atoms and return
    RDKit's SMILES for the result.

    BRICS attachment labels such as [4*] are preserved (a mapped dummy like
    [4*:3] becomes [4*]).

    Args:
      fragment_smiles: fragment SMILES, possibly carrying atom-map numbers.

    Returns:
      SMILES written by Chem.MolToSmiles(isomericSmiles=True) for the
      map-free fragment.

    Raises:
      ValueError: if the map-free text cannot be parsed by RDKit.
      Errors raised by Chem.SanitizeMol are not caught and propagate.
    """
    # 1) Strip atom maps. The pattern is anchored to bracket atoms, so a ':'
    #    outside brackets (an explicit aromatic bond such as "c:1") is kept.
    #    A global ":\d+" would delete ":1" there and break the ring closure.
    cleaned = _strip_atom_map_numbers(fragment_smiles)

    # 2) Parse without sanitizing first so parsing and sanitization failures
    #    are reported separately.
    mol = Chem.MolFromSmiles(cleaned, sanitize=False)
    if mol is None:
        raise ValueError(f"Could not parse fragment after cleanup: {cleaned}")

    # 3) Full sanitize to obtain a valid RDKit molecule.
    Chem.SanitizeMol(mol)

    # 4) Write the sanitized molecule back out as SMILES.
    return Chem.MolToSmiles(mol, isomericSmiles=True)

def cleanup_dummy_atoms(fragment_smiles: str) -> str:
    """
    Textually cap bracketed dummy atoms with hydrogen.

    Every bracketed dummy of the form [<digits>*] (including a bare [*]) is
    rewritten as [H]; the rest of the string is untouched.

    Example: [4*]CCC -> [H]CCC

    NOTE(review): the substitution is purely textual. The bond symbol next to
    the dummy is kept, so a double-bonded attachment such as "[7*]=CC" becomes
    "[H]=CC", which RDKit will reject. Confirm callers never pass those.
    """
    bracketed_dummy = re.compile(r"\[\d*\*\]")
    return bracketed_dummy.sub("[H]", fragment_smiles)

# def brics_fragments(smiles: str):



def _cap_attachment_points(fragment_smiles: str) -> str:
    """
    Replace every dummy atom (BRICS attachment point, e.g. [4*]) in a fragment
    with hydrogen and return the SMILES of the capped fragment.

    This works on the molecule rather than the text because deleting "[n*]"
    tokens from a SMILES string can leave invalid text. A branch dummy
    "C([4*])C" leaves "C()C", and a double-bonded dummy "[7*]=CC" leaves "=CC".

    Raises:
      ValueError: if ``fragment_smiles`` cannot be parsed.
    """
    mol = Chem.MolFromSmiles(fragment_smiles)
    if mol is None:
        raise ValueError(f"Could not parse fragment: {fragment_smiles}")

    rwmol = Chem.RWMol(mol)
    # (neighbour atom index, number of extra H) for dummies whose bond order > 1
    extra_h_requests = []
    for atom in rwmol.GetAtoms():
        if atom.GetAtomicNum() != 0:
            continue
        atom.SetAtomicNum(1)
        atom.SetIsotope(0)  # the [n*] label is an isotope; a labelled H would not be removed
        for bond in atom.GetBonds():
            order = int(bond.GetBondTypeAsDouble())
            if order > 1:
                # An H can hold only a single bond; move the rest of the bond
                # order onto additional hydrogens on the same neighbour.
                bond.SetBondType(Chem.BondType.SINGLE)
                bond.SetStereo(Chem.BondStereo.STEREONONE)
                extra_h_requests.append((bond.GetOtherAtomIdx(atom.GetIdx()), order - 1))

    # Structural edits happen only after the atom/bond loop above has finished.
    for neighbor_idx, count in extra_h_requests:
        for _ in range(count):
            h_idx = rwmol.AddAtom(Chem.Atom(1))
            rwmol.AddBond(neighbor_idx, h_idx, Chem.BondType.SINGLE)

    capped = rwmol.GetMol()
    Chem.SanitizeMol(capped)
    # RemoveHs folds the added hydrogens back into the heavy atoms' H counts.
    return Chem.MolToSmiles(Chem.RemoveHs(capped), isomericSmiles=True)


def brics_decomposition_connectivity(smiles: str):
    """
    Perform BRICS decomposition on a molecule given by SMILES and report which
    fragments were joined by the broken BRICS bonds.

    Args:
      smiles: SMILES of the molecule to decompose.

    Returns:
      cleaned_fragments: one SMILES per fragment, in Chem.GetMolFrags order.
        Atom-map numbers are removed and every BRICS attachment point ([n*])
        is capped with hydrogen.
      connections_list: list of tuples (fragA, fragB, bondType), one per broken
        BRICS bond joining two different fragments. fragA and fragB are the
        fragment SMILES produced by cleanup_brics_fragment, so they still carry
        their [n*] attachment labels. bondType is the label pair reported by
        BRICS.FindBRICSBonds for that bond.

    Raises:
      ValueError: if ``smiles`` (or a resulting fragment) cannot be parsed.
    """
    # 1) Convert SMILES to an RDKit molecule.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    # 2) Tag every atom with (original index + 1) as its atom-map number.
    #    The +1 is essential: map number 0 means "unmapped". That is also what
    #    the dummy atoms added by BreakBRICSBonds carry, so tagging with the raw
    #    index would make atom 0 indistinguishable from them.
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(atom.GetIdx() + 1)

    # 3) Identify BRICS bonds before breaking.
    #    Each entry is ((atomIdx1, atomIdx2), (BRICS_code1, BRICS_code2)).
    brics_bonds = list(BRICS.FindBRICSBonds(mol))

    # 4) Break exactly those bonds (no second bond search) and split into fragments.
    mol_broken = BRICS.BreakBRICSBonds(mol, bonds=brics_bonds)
    fragments = Chem.GetMolFrags(mol_broken, asMols=True, sanitizeFrags=True)

    # 5) Clean each fragment once, and record which fragment each original atom landed in.
    fragment_smiles_list = []  # map-free SMILES, [n*] attachment labels kept
    atom_to_fragment = {}      # original atom index -> fragment index
    for frag_idx, frag_mol in enumerate(fragments):
        raw_smiles = Chem.MolToSmiles(frag_mol, isomericSmiles=True)
        fragment_smiles_list.append(cleanup_brics_fragment(raw_smiles))
        for atom in frag_mol.GetAtoms():
            map_num = atom.GetAtomMapNum()
            if map_num:  # 0 = unmapped dummy atom added by BreakBRICSBonds
                atom_to_fragment[map_num - 1] = frag_idx

    # 6) Each broken BRICS bond connects the fragments holding its two end atoms.
    connections_list = []
    for (atom_a, atom_b), bond_type in brics_bonds:
        frag_a = atom_to_fragment.get(atom_a)
        frag_b = atom_to_fragment.get(atom_b)
        if frag_a is not None and frag_b is not None and frag_a != frag_b:
            connections_list.append(
                (fragment_smiles_list[frag_a], fragment_smiles_list[frag_b], bond_type)
            )

    # 7) Cap attachment points at the molecule level so the results are valid SMILES.
    cleaned_fragments = [_cap_attachment_points(frag) for frag in fragment_smiles_list]

    return cleaned_fragments, connections_list