import re
from typing import List, Dict, Tuple, Set, Optional
from rdkit import Chem
from rdkit.Chem import BRICS, rdmolops
from ..representation import Motif, ConnectionSite


class BRICSSegmentation:
    """Split a molecule into ``Motif`` objects using RDKit's BRICS decomposition.

    Each fragment returned by ``BRICS.BRICSDecompose`` is re-parsed from its
    SMILES and wrapped in a ``Motif`` annotated with connection sites,
    descriptor values, ring information and functional-group labels.
    """

    # Human-readable labels keyed by an integer id. Not referenced inside this
    # class; kept as a public attribute for external users.
    CHEMICAL_ENVIRONMENTS = {
        1: "L1: C-C (sp3-sp3)",
        2: "L2: C-C (sp3-sp2)",
        3: "L3: C-C (sp2-sp2)",
        4: "L4: C-N (sp3-sp3)",
        5: "L5: C-N (sp3-sp2)",
        6: "L6: C-O (sp3-sp3)",
        7: "L7: C-O (sp3-sp2)",
        8: "L8: C-S (sp3-sp3)",
        9: "L9: C-S (sp3-sp2)",
        10: "L10: N-N",
        11: "L11: N-O",
        12: "L12: N-S",
        13: "L13: O-O",
        14: "L14: O-S",
        15: "L15: S-S",
        16: "L16: Aromatic-Aromatic"
    }

    # RDKit bond type -> the string names used in ``allowed_bond_types``.
    # Not referenced inside this class; kept as a public attribute.
    BOND_TYPE_MAPPING = {
        Chem.BondType.SINGLE: "SINGLE",
        Chem.BondType.DOUBLE: "DOUBLE",
        Chem.BondType.TRIPLE: "TRIPLE",
        Chem.BondType.AROMATIC: "AROMATIC"
    }

    # Heuristic maximum valence per atomic number (C, N, O, S, P, halogens).
    _MAX_VALENCE = {6: 4, 7: 3, 8: 2, 16: 6, 15: 5, 9: 1, 17: 1, 35: 1, 53: 1}
    # Fallback for every atomic number missing from ``_MAX_VALENCE``.
    _DEFAULT_MAX_VALENCE = 4

    # Hybridization enum -> suffix used in the site-type label.
    _HYBRIDIZATION_LABELS = {
        Chem.HybridizationType.SP: "sp",
        Chem.HybridizationType.SP2: "sp2",
        Chem.HybridizationType.SP3: "sp3"
    }

    # Functional-group name -> SMARTS. Insertion order defines the order of
    # the labels returned by ``_identify_functional_groups``.
    _FUNCTIONAL_GROUP_SMARTS = {
        'carbonyl': '[CX3]=[OX1]',
        'alcohol': '[OX2H]',
        'amine': '[NX3;H2,H1;!$(NC=O)]',
        'carboxyl': '[CX3](=O)[OX2H1]',
        'ester': '[CX3](=O)[OX2H0]',
        'amide': '[CX3](=[OX1])[NX3]',
        # The nitrile carbon carries two connections (its substituent and the
        # nitrogen), hence X2; X1 would only match a bare cyanide carbon.
        'nitrile': '[CX2]#[NX1]',
        # Matches six-membered rings made only of aromatic carbons.
        'aromatic_ring': 'c1ccccc1'
    }

    # Lazily built list of (name, compiled query); see
    # ``_get_functional_group_queries``.
    _compiled_functional_groups = None

    def __init__(self, min_fragment_size: int = 2):
        """Create a segmenter.

        Args:
            min_fragment_size: Minimum heavy-atom count a BRICS fragment must
                have to be kept as a motif.
        """
        self.min_fragment_size = min_fragment_size

    def segment_molecule(self, mol: Chem.Mol, mol_id: str = "mol") -> List[Motif]:
        """Decompose ``mol`` into motifs.

        Args:
            mol: Molecule to segment; ``None`` yields an empty list.
            mol_id: Prefix used to build motif identifiers.

        Returns:
            One motif per retained BRICS fragment, named
            ``"{mol_id}_motif_{i}"`` where ``i`` is the fragment's position in
            the decomposition output (so indices may have gaps when fragments
            are dropped). When no BRICS bond exists, or every fragment is
            filtered out, a single motif ``"{mol_id}_0"`` wrapping the whole
            molecule is returned instead.
        """
        if mol is None:
            return []

        # Only the existence of a cleavable bond matters here, so stop at the
        # first one instead of materialising them all.
        has_brics_bond = next(iter(BRICS.FindBRICSBonds(mol)), None) is not None
        if not has_brics_bond:
            return [self._create_single_motif(mol, f"{mol_id}_0")]

        fragments = BRICS.BRICSDecompose(mol, returnMols=True)
        motifs = []

        for i, fragment in enumerate(fragments):
            if fragment.GetNumHeavyAtoms() < self.min_fragment_size:
                continue
            motif = self._create_motif_from_fragment(fragment, f"{mol_id}_motif_{i}")
            if motif:
                motifs.append(motif)

        if not motifs:
            # Nothing survived filtering/sanitisation: fall back to the whole
            # molecule so callers always get at least one motif.
            motifs = [self._create_single_motif(mol, f"{mol_id}_0")]

        return motifs

    def _create_single_motif(self, mol: Chem.Mol, motif_id: str) -> Motif:
        """Build a ``Motif`` for ``mol`` with all derived annotations."""
        return Motif(
            motif_id=motif_id,
            smiles=Chem.MolToSmiles(mol),
            mol=mol,
            connection_sites=self._extract_connection_sites(mol),
            properties=self._calculate_properties(mol),
            is_aromatic=self._has_aromatic_atoms(mol),
            ring_info=self._get_ring_info(mol),
            functional_groups=self._identify_functional_groups(mol)
        )

    def _create_motif_from_fragment(self, fragment: Chem.Mol, motif_id: str) -> Optional[Motif]:
        """Turn a BRICS fragment into a ``Motif``, or ``None`` if that fails.

        The fragment is round-tripped through SMILES so the motif is built
        from a freshly parsed molecule.
        """
        if fragment is None:
            return None

        try:
            sanitized_mol = Chem.MolFromSmiles(Chem.MolToSmiles(fragment))
            if sanitized_mol is None:
                return None

            return self._create_single_motif(sanitized_mol, motif_id)
        except Exception:
            # Deliberately best-effort: a fragment that cannot be converted or
            # annotated is skipped rather than aborting the whole segmentation.
            return None

    def _extract_connection_sites(self, mol: Chem.Mol) -> List[ConnectionSite]:
        """Return a ``ConnectionSite`` for every atom that can accept a bond.

        Site ids are consecutive integers starting at 0, assigned in atom
        iteration order.
        """
        candidate_atoms = [
            atom for atom in mol.GetAtoms()
            if self._is_potential_connection_site(atom)
        ]
        return [
            ConnectionSite(
                site_id=site_id,
                atom_idx=atom.GetIdx(),
                site_type=self._get_site_type(atom),
                chemical_environment=self._get_chemical_environment(atom),
                allowed_bond_types=self._get_allowed_bond_types(atom),
                is_aromatic=atom.GetIsAromatic()
            )
            for site_id, atom in enumerate(candidate_atoms)
        ]

    def _is_potential_connection_site(self, atom: Chem.Atom) -> bool:
        """Return True when the atom's total valence is below its heuristic maximum.

        Hydrogens are never connection sites.
        """
        if atom.GetAtomicNum() == 1:  # Hydrogen
            return False

        # NOTE(review): atomic numbers missing from the valence table
        # (including 0) use the default of 4 - confirm this is intended for
        # attachment-point dummy atoms.
        return atom.GetTotalValence() < self._get_max_valence(atom.GetAtomicNum())

    def _get_max_valence(self, atomic_num: int) -> int:
        """Return the heuristic maximum valence for ``atomic_num`` (default 4)."""
        return self._MAX_VALENCE.get(atomic_num, self._DEFAULT_MAX_VALENCE)

    def _get_site_type(self, atom: Chem.Atom) -> str:
        """Return ``"<symbol>_<hybridization>"``, e.g. ``"C_sp3"``.

        Hybridizations other than SP, SP2 and SP3 are reported as ``"unknown"``.
        """
        hybrid_str = self._HYBRIDIZATION_LABELS.get(atom.GetHybridization(), "unknown")
        return f"{atom.GetSymbol()}_{hybrid_str}"

    def _get_chemical_environment(self, atom: Chem.Atom) -> str:
        """Return the atom symbol followed by its sorted neighbour symbols in parentheses."""
        neighbors = sorted(neighbor.GetSymbol() for neighbor in atom.GetNeighbors())
        return f"{atom.GetSymbol()}({''.join(neighbors)})"

    def _get_allowed_bond_types(self, atom: Chem.Atom) -> Set[str]:
        """Return the names of the bond orders the atom still has room for.

        A bond of order ``k`` is allowed when at least ``k`` valence units are
        free. DOUBLE is limited to C, N, O and S; TRIPLE to C and N.
        ``"AROMATIC"`` is added for aromatic atoms regardless of free valence.
        """
        atomic_num = atom.GetAtomicNum()
        free_valence = self._get_max_valence(atomic_num) - atom.GetTotalValence()

        allowed_bonds = set()

        if free_valence >= 1:
            allowed_bonds.add("SINGLE")

        if free_valence >= 2 and atomic_num in (6, 7, 8, 16):  # C, N, O, S
            allowed_bonds.add("DOUBLE")

        if free_valence >= 3 and atomic_num in (6, 7):  # C, N
            allowed_bonds.add("TRIPLE")

        if atom.GetIsAromatic():
            allowed_bonds.add("AROMATIC")

        return allowed_bonds

    def _calculate_properties(self, mol: Chem.Mol) -> Dict[str, float]:
        """Return a dict of RDKit descriptor values for ``mol``."""
        from rdkit.Chem import Descriptors

        return {
            'molecular_weight': Descriptors.MolWt(mol),
            'num_heavy_atoms': mol.GetNumHeavyAtoms(),
            'num_heteroatoms': Descriptors.NumHeteroatoms(mol),
            'num_rotatable_bonds': Descriptors.NumRotatableBonds(mol),
            'tpsa': Descriptors.TPSA(mol),
            'logp': Descriptors.MolLogP(mol),
            'num_rings': Descriptors.RingCount(mol),
            'num_aromatic_rings': Descriptors.NumAromaticRings(mol)
        }

    def _has_aromatic_atoms(self, mol: Chem.Mol) -> bool:
        """Return True if at least one atom of ``mol`` is flagged aromatic."""
        return any(atom.GetIsAromatic() for atom in mol.GetAtoms())

    def _get_ring_info(self, mol: Chem.Mol) -> Dict:
        """Return the rings of ``mol`` as atom-index lists, plus their count and sizes."""
        ring_info = mol.GetRingInfo()
        atom_rings = [list(ring) for ring in ring_info.AtomRings()]
        return {
            'rings': atom_rings,
            'num_rings': ring_info.NumRings(),
            'ring_sizes': [len(ring) for ring in atom_rings]
        }

    @classmethod
    def _get_functional_group_queries(cls) -> List[Tuple[str, Chem.Mol]]:
        """Return ``(name, compiled SMARTS query)`` pairs, compiled once per class."""
        # Look in the class's own __dict__ so a subclass that overrides the
        # SMARTS table does not reuse the parent's compiled queries.
        cached = cls.__dict__.get("_compiled_functional_groups")
        if cached is None:
            cached = [
                (group_name, Chem.MolFromSmarts(smarts))
                for group_name, smarts in cls._FUNCTIONAL_GROUP_SMARTS.items()
            ]
            cls._compiled_functional_groups = cached
        return cached

    def _identify_functional_groups(self, mol: Chem.Mol) -> List[str]:
        """Return the names of the functional-group patterns that match ``mol``.

        Names are listed in the order they are declared in
        ``_FUNCTIONAL_GROUP_SMARTS``; each name appears at most once.
        """
        return [
            group_name
            for group_name, query in self._get_functional_group_queries()
            if mol.HasSubstructMatch(query)
        ]