from typing import List, Dict, Optional, Tuple, Set
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors, AllChem
import numpy as np


class MolecularUtils:
    """Namespace of stateless RDKit helpers for common cheminformatics tasks.

    All public methods are ``@staticmethod``s. Most of them are best-effort:
    on RDKit failure they return ``None``, an empty container, an ``'error'``
    entry, or the unchanged input rather than raising (see each method).
    """

    # SMARTS for identify_functional_groups(); results are reported in this
    # order. Reminders for the non-obvious patterns:
    #   * X<n> is the total connection count *including* hydrogens, so the
    #     carbon of R-C#N is X2.
    #   * Alcohols require an sp3 carbon so phenols and carboxylic acids are
    #     not also reported as alcohols.
    #   * Amines exclude amide nitrogens (!$(NC=O)) and require a neutral N
    #     (+0) so nitro groups are not reported as tertiary amines.
    #   * Sulfides require two carbon neighbours so thiols are excluded.
    #   * Phosphates accept ester (O-R) and anionic (O-) oxygens, not only
    #     free P-OH groups.
    #   * Pyrrole-type ring nitrogens use [nX3] so N-substituted rings
    #     (e.g. N-methylindole) match as well as N-H ones.
    _FUNCTIONAL_GROUP_SMARTS: Dict[str, str] = {
        'alcohol': '[OX2H][CX4]',
        'phenol': '[OX2H][cX3]',
        'ether': '[OX2]([CX4])[CX4]',
        'aldehyde': '[CX3H1](=O)[#6]',
        'ketone': '[CX3](=[OX1])([#6])[#6]',
        'carboxylic_acid': '[CX3](=O)[OX2H1]',
        'ester': '[CX3](=O)[OX2H0]',
        'amide': '[CX3](=[OX1])[NX3]',
        'amine_primary': '[NX3;H2;+0;!$(NC=O)]',
        'amine_secondary': '[NX3;H1;+0;!$(NC=O)]',
        'amine_tertiary': '[NX3;H0;+0;!$(NC=O)]',
        'nitrile': '[CX2]#[NX1]',
        'sulfide': '[SX2H0]([#6])[#6]',
        'sulfoxide': '[SX3](=O)',
        'sulfone': '[SX4](=O)(=O)',
        'phosphate': '[PX4](=[OX1])([OX2,OX1-])([OX2,OX1-])[OX2,OX1-]',
        'aromatic_ring': 'c1ccccc1',
        'pyridine': 'c1ccncc1',
        'imidazole': 'c1c[nX3]cn1',
        'thiazole': 'c1cscn1',
        'oxazole': 'c1cocn1',
        'furan': 'c1ccoc1',
        'thiophene': 'c1ccsc1',
        'pyrrole': 'c1cc[nX3]c1',
        'indole': 'c1ccc2[nX3]ccc2c1',
    }

    # Compiled query molecules for _FUNCTIONAL_GROUP_SMARTS; filled lazily on
    # first use by _get_functional_group_queries().
    _functional_group_queries: Optional[Dict[str, Chem.Mol]] = None

    @staticmethod
    def smiles_to_mol(smiles: str) -> Optional[Chem.Mol]:
        """Parse a SMILES string into a sanitized RDKit molecule.

        Args:
            smiles: SMILES string to parse.

        Returns:
            The molecule, or None if parsing/sanitization fails or RDKit
            rejects the argument outright (e.g. a non-string).
        """
        try:
            # MolFromSmiles sanitizes by default and returns None on failure,
            # so no separate SanitizeMol call is needed.
            return Chem.MolFromSmiles(smiles)
        except Exception:
            return None

    @staticmethod
    def mol_to_smiles(mol: Chem.Mol, canonical: bool = True) -> Optional[str]:
        """Serialize a molecule to SMILES.

        Args:
            mol: Molecule to serialize.
            canonical: If True (default) emit canonical SMILES; if False keep
                the molecule's atom order.

        Returns:
            The SMILES string, or None if RDKit fails (e.g. ``mol`` is None).
        """
        try:
            # MolToSmiles defaults to canonical=True, so the flag must be
            # forwarded explicitly for canonical=False to have any effect.
            return Chem.MolToSmiles(mol, canonical=canonical)
        except Exception:
            return None

    @staticmethod
    def calculate_molecular_properties(mol: Chem.Mol) -> Dict[str, float]:
        """Compute a fixed set of physicochemical descriptors.

        Args:
            mol: Molecule to describe.

        Returns:
            A dict of descriptor name -> value (counts are ints) plus
            ``'lipinski_violations'`` (number of MW > 500, logP > 5,
            HBD > 5, HBA > 10 violations). Returns ``{}`` when ``mol`` is
            None and ``{'error': <message>}`` if any descriptor raises.
        """
        if mol is None:
            return {}

        try:
            logp = Descriptors.MolLogP(mol)
            properties = {
                'molecular_weight': Descriptors.MolWt(mol),
                'logp': logp,
                'tpsa': Descriptors.TPSA(mol),
                'num_heavy_atoms': mol.GetNumHeavyAtoms(),
                'num_heteroatoms': Descriptors.NumHeteroatoms(mol),
                'num_rotatable_bonds': Descriptors.NumRotatableBonds(mol),
                'num_hbd': Descriptors.NumHDonors(mol),
                'num_hba': Descriptors.NumHAcceptors(mol),
                'num_rings': Descriptors.RingCount(mol),
                'num_aromatic_rings': Descriptors.NumAromaticRings(mol),
                'num_saturated_rings': Descriptors.NumSaturatedRings(mol),
                'num_aliphatic_rings': Descriptors.NumAliphaticRings(mol),
                # SlogP is the Wildman-Crippen logP, i.e. the same value as
                # MolLogP; the key is kept for callers that read it. (RDKit has
                # no SlogP_VSA0 -- its SlogP_VSA bins are SlogP_VSA1..12.)
                'slogp': logp,
                'molar_refractivity': Descriptors.MolMR(mol),
                'fraction_sp3': Descriptors.FractionCsp3(mol),
                'qed': Descriptors.qed(mol),
            }

            # Count of Lipinski Rule-of-Five violations.
            properties['lipinski_violations'] = (
                int(properties['molecular_weight'] > 500) +
                int(properties['logp'] > 5) +
                int(properties['num_hbd'] > 5) +
                int(properties['num_hba'] > 10)
            )

            return properties

        except Exception as e:
            return {'error': str(e)}

    @staticmethod
    def check_drug_likeness(mol: Chem.Mol) -> Dict[str, bool]:
        """Evaluate Lipinski, Veber and lead-likeness filters.

        Args:
            mol: Molecule to evaluate.

        Returns:
            ``{'lipinski_compliant', 'veber_compliant', 'lead_like',
            'drug_like'}`` booleans, or ``{'error': True}`` if ``mol`` is
            None or the descriptors could not be computed.
        """
        props = MolecularUtils.calculate_molecular_properties(mol)

        # An empty dict means mol was None; an 'error' key means RDKit failed.
        if not props or 'error' in props:
            return {'error': True}

        # Lipinski's Rule of Five. NOTE: the 150 Da lower bound is an extra
        # project-specific constraint, not part of Lipinski's original rule,
        # and is not counted in 'lipinski_violations'.
        lipinski_pass = (
            150 <= props['molecular_weight'] <= 500 and
            props['logp'] <= 5 and
            props['num_hbd'] <= 5 and
            props['num_hba'] <= 10
        )

        # Veber's oral-bioavailability rule.
        veber_pass = (
            props['num_rotatable_bonds'] <= 10 and
            props['tpsa'] <= 140
        )

        # Lead-like criteria.
        lead_like = (
            150 <= props['molecular_weight'] <= 350 and
            -2 <= props['logp'] <= 3 and
            props['num_rotatable_bonds'] <= 7
        )

        return {
            'lipinski_compliant': lipinski_pass,
            'veber_compliant': veber_pass,
            'lead_like': lead_like,
            'drug_like': lipinski_pass and veber_pass,
        }

    @staticmethod
    def _get_functional_group_queries() -> Dict[str, Chem.Mol]:
        """Return compiled _FUNCTIONAL_GROUP_SMARTS queries, compiling once."""
        queries = MolecularUtils._functional_group_queries
        if queries is None:
            queries = {}
            for group_name, smarts in MolecularUtils._FUNCTIONAL_GROUP_SMARTS.items():
                pattern = Chem.MolFromSmarts(smarts)
                # MolFromSmarts returns None for an unparsable pattern; skip it
                # instead of failing every lookup.
                if pattern is not None:
                    queries[group_name] = pattern
            MolecularUtils._functional_group_queries = queries
        return queries

    @staticmethod
    def identify_functional_groups(mol: Chem.Mol) -> List[str]:
        """List the named functional groups present in ``mol``.

        Args:
            mol: Molecule to search.

        Returns:
            Names from _FUNCTIONAL_GROUP_SMARTS that have at least one
            substructure match, in that dict's order; ``[]`` if ``mol`` is
            None.
        """
        if mol is None:
            return []

        functional_groups = []
        for group_name, pattern in MolecularUtils._get_functional_group_queries().items():
            try:
                if mol.HasSubstructMatch(pattern):
                    functional_groups.append(group_name)
            except Exception:
                # Best-effort: one failing match must not hide the others.
                continue

        return functional_groups

    @staticmethod
    def calculate_fingerprint(mol: Chem.Mol, fp_type: str = 'morgan', **kwargs) -> Optional[np.ndarray]:
        """Compute a bit-vector fingerprint as a numpy array.

        Args:
            mol: Molecule to fingerprint.
            fp_type: ``'morgan'`` (kwargs ``radius`` = 2, ``nbits`` = 2048),
                ``'rdkit'`` (kwarg ``nbits`` = 2048) or ``'maccs'``.

        Returns:
            The fingerprint bits as an array, or None if ``mol`` is None,
            ``fp_type`` is unknown, or RDKit fails.
        """
        if mol is None:
            return None

        try:
            if fp_type == 'morgan':
                radius = kwargs.get('radius', 2)
                nbits = kwargs.get('nbits', 2048)
                fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits)
                return np.array(fp)

            elif fp_type == 'rdkit':
                nbits = kwargs.get('nbits', 2048)
                fp = Chem.RDKFingerprint(mol, fpSize=nbits)
                return np.array(fp)

            elif fp_type == 'maccs':
                fp = rdMolDescriptors.GetMACCSKeysFingerprint(mol)
                return np.array(fp)

            else:
                # Caught below: unknown types are reported as None.
                raise ValueError(f"Unknown fingerprint type: {fp_type}")

        except Exception:
            return None

    @staticmethod
    def calculate_tanimoto_similarity(fp1: np.ndarray, fp2: np.ndarray) -> float:
        """Tanimoto (Jaccard) similarity of two binary fingerprints.

        Inputs are treated as bit sets: any non-zero element counts as a set
        bit, so integer, boolean, float or list inputs all work.

        Returns:
            |A & B| / |A | B| as a float; 0.0 if either input is None or the
            lengths differ; 1.0 if both fingerprints are all zeros.
        """
        if fp1 is None or fp2 is None:
            return 0.0

        bits1 = np.asarray(fp1, dtype=bool)
        bits2 = np.asarray(fp2, dtype=bool)
        if len(bits1) != len(bits2):
            return 0.0

        union = np.count_nonzero(bits1 | bits2)
        if union == 0:
            return 1.0  # Both fingerprints are all zeros

        intersection = np.count_nonzero(bits1 & bits2)
        return float(intersection / union)

    @staticmethod
    def get_scaffold(mol: Chem.Mol) -> Optional[str]:
        """Return the Bemis-Murcko scaffold of ``mol`` as SMILES.

        Returns None if ``mol`` is None or RDKit fails. For molecules without
        rings the scaffold is empty (presumably serialized as ``''`` -- confirm
        against the RDKit version in use).
        """
        if mol is None:
            return None
        try:
            from rdkit.Chem.Scaffolds import MurckoScaffold
            scaffold = MurckoScaffold.GetScaffoldForMol(mol)
            return Chem.MolToSmiles(scaffold) if scaffold else None
        except Exception:
            return None

    @staticmethod
    def calculate_synthetic_accessibility(mol: Chem.Mol) -> Optional[float]:
        """Return a rough 0-1 simplicity score derived from Bertz complexity.

        NOTE: despite the name this is NOT the Ertl & Schuffenhauer SA score.
        It is ``(1000 - BertzCT) / 1000`` clamped to [0, 1]: higher means a
        less complex (presumably easier to make) molecule, and 0.0 means
        BertzCT >= 1000.

        Returns:
            The score, or None if ``mol`` is None or RDKit fails.
        """
        if mol is None:
            return None
        try:
            # BertzCT is provided by Descriptors (from GraphDescriptors);
            # rdMolDescriptors has no BertzCT function.
            complexity = Descriptors.BertzCT(mol)
        except Exception:
            return None
        return max(0.0, min(1.0, (1000 - complexity) / 1000))

    @staticmethod
    def validate_molecule(mol: Chem.Mol) -> Dict[str, bool]:
        """Check that a molecule sanitizes, uses common elements and is sized sensibly.

        The molecule passed in is not modified; sanitization runs on a copy.

        Returns:
            ``{'valid': True}`` on success, otherwise ``{'valid': False,
            'reason': <str>}`` (note: 'reason' is a string despite the
            annotation).
        """
        if mol is None:
            return {'valid': False, 'reason': 'Molecule is None'}

        validation = {'valid': True}

        try:
            # Check for valid valences on a copy so the caller's molecule is
            # never mutated by validation.
            candidate = Chem.Mol(mol)
            Chem.SanitizeMol(candidate)
        except Exception as e:
            validation['valid'] = False
            validation['reason'] = f'Sanitization failed: {str(e)}'
            return validation

        # Check for reasonable atom types
        allowed_atoms = {1, 6, 7, 8, 9, 15, 16, 17, 35, 53}  # H, C, N, O, F, P, S, Cl, Br, I
        for atom in candidate.GetAtoms():
            if atom.GetAtomicNum() not in allowed_atoms:
                validation['valid'] = False
                validation['reason'] = f'Unusual atom type: {atom.GetSymbol()}'
                return validation

        # Check molecular size (heavy atoms only)
        num_atoms = candidate.GetNumHeavyAtoms()
        if num_atoms > 100:
            validation['valid'] = False
            validation['reason'] = 'Molecule too large'
        elif num_atoms < 2:
            validation['valid'] = False
            validation['reason'] = 'Molecule too small'

        return validation

    @staticmethod
    def neutralize_charges(mol: Chem.Mol) -> Optional[Chem.Mol]:
        """Return a copy of ``mol`` with neutralizable charges removed.

        Positively charged atoms that carry hydrogens lose a proton and
        negatively charged atoms gain one. Atoms bonded to an oppositely
        charged neighbour (nitro groups, N-oxides, other charge-separated
        forms) and cations without hydrogens (e.g. quaternary ammonium) are
        left charged, because zeroing their charge would create invalid
        valences.

        Returns:
            The neutralized copy; None if ``mol`` is None; the original
            ``mol`` unchanged if RDKit fails.
        """
        if mol is None:
            return None

        try:
            pattern = Chem.MolFromSmarts(
                '[+1!h0!$([*]~[-1,-2,-3,-4]),-1!$([*]~[+1,+2,+3,+4])]'
            )
            neutral = Chem.Mol(mol)  # never mutate the caller's molecule
            for (atom_idx,) in neutral.GetSubstructMatches(pattern):
                atom = neutral.GetAtomWithIdx(atom_idx)
                charge = atom.GetFormalCharge()
                h_count = atom.GetTotalNumHs()
                atom.SetFormalCharge(0)
                # Remove a proton from a cation (+1) / add one to an anion (-1).
                atom.SetNumExplicitHs(h_count - charge)
                atom.UpdatePropertyCache()
            return neutral

        except Exception:
            return mol

    @staticmethod
    def add_hydrogens(mol: Chem.Mol) -> Optional[Chem.Mol]:
        """Return ``Chem.AddHs(mol)``; None for None input, ``mol`` on failure."""
        if mol is None:
            return None

        try:
            return Chem.AddHs(mol)
        except Exception:
            return mol

    @staticmethod
    def remove_hydrogens(mol: Chem.Mol) -> Optional[Chem.Mol]:
        """Return ``Chem.RemoveHs(mol)``; None for None input, ``mol`` on failure."""
        if mol is None:
            return None

        try:
            return Chem.RemoveHs(mol)
        except Exception:
            return mol