"""RDKit-based molecular property computation.

This module provides functions for computing drug-likeness (QED), synthetic
accessibility (SAS), and physicochemical properties (LogP, penalized LogP,
molecular weight) commonly used in molecular optimization benchmarks.
"""

from __future__ import annotations

from typing import Sequence

import numpy as np
from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors
from rdkit.Chem.QED import qed

from moltenflow.utils.logging import get_logger

logger = get_logger(__name__)

# SA_Score lives in RDKit's Contrib tree rather than the core package, so it can
# be missing from some installs. Import it optionally and record the outcome in
# _HAS_SASCORER: compute_sas() checks this flag and raises ValueError when the
# scorer is absent, while properties that do not depend on SAS keep working.
# NOTE(review): some RDKit distributions place Contrib outside the Python
# package (e.g. under RDConfig.RDContribDir), where this import path would
# fail — confirm against the supported install methods.
try:
    from rdkit.Contrib.SA_Score import sascorer

    _HAS_SASCORER = True
except ImportError:
    _HAS_SASCORER = False
    # Warn once at import time; later SAS requests fail with a ValueError.
    logger.warning(
        "SA_Score not available from rdkit.Contrib. "
        "Install RDKit with Contrib support or use alternative SAS computation."
    )


def compute_qed(mol: Chem.Mol) -> float:
    """Return the Quantitative Estimate of Drug-likeness (QED) of ``mol``.

    Scores lie in [0, 1]; larger values mean a more drug-like molecule.

    Args:
        mol: RDKit molecule object

    Returns:
        QED score in range [0, 1]

    Raises:
        ValueError: If ``mol`` is None
    """
    if mol is not None:
        return float(qed(mol))
    raise ValueError("Cannot compute QED for None molecule")


def compute_sas(mol: Chem.Mol) -> float:
    """Return the Synthetic Accessibility Score (SAS) of ``mol``.

    Scores lie in [1, 10]; smaller values mean the molecule is easier to
    synthesize.

    Args:
        mol: RDKit molecule object

    Returns:
        SAS score in range [1, 10]

    Raises:
        ValueError: If ``mol`` is None, or if the optional SA_Score module
            could not be imported
    """
    if mol is None:
        raise ValueError("Cannot compute SAS for None molecule")
    if _HAS_SASCORER:
        score = sascorer.calculateScore(mol)
        return float(score)
    raise ValueError("SA_Score not available. Install RDKit with Contrib support.")


def compute_logp(mol: Chem.Mol) -> float:
    """Return the Crippen LogP (octanol/water partition coefficient) of ``mol``.

    Args:
        mol: RDKit molecule object

    Returns:
        LogP value; unbounded, typically about -3 to 7 for drug-like molecules

    Raises:
        ValueError: If ``mol`` is None
    """
    if mol is not None:
        logp_value = Crippen.MolLogP(mol)
        return float(logp_value)
    raise ValueError("Cannot compute LogP for None molecule")


def compute_plogp(mol: Chem.Mol) -> float:
    """Return the penalized LogP (pLogP) of ``mol``.

    Computed as::

        pLogP = LogP - SAS - ring_penalty
        ring_penalty = max(0, largest_ring_size - 6)

    so molecules that are hard to synthesize, or that contain rings with more
    than six atoms, are penalized. Ring sizes come from RDKit's ring
    perception (``GetRingInfo().AtomRings()``); a molecule without rings has
    no penalty.

    NOTE(review): this is the unnormalized variant; some reference
    implementations standardize each term with dataset statistics — confirm
    which variant downstream comparisons expect.

    Args:
        mol: RDKit molecule object

    Returns:
        Penalized LogP value (unbounded, can be negative)

    Raises:
        ValueError: If ``mol`` is None or SA_Score is not available
    """
    if mol is None:
        raise ValueError("Cannot compute pLogP for None molecule")

    # Keep the original evaluation order: LogP, then SAS, then ring perception.
    logp_term = compute_logp(mol)
    sas_term = compute_sas(mol)

    rings = mol.GetRingInfo().AtomRings()
    biggest_ring = max((len(atoms) for atoms in rings), default=0)
    penalty = max(0, biggest_ring - 6)

    return logp_term - sas_term - penalty


def compute_molecular_weight(mol: Chem.Mol) -> float:
    """Return the average molecular weight of ``mol`` via ``Descriptors.MolWt``.

    Args:
        mol: RDKit molecule object

    Returns:
        Molecular weight in Daltons

    Raises:
        ValueError: If ``mol`` is None
    """
    if mol is not None:
        weight = Descriptors.MolWt(mol)
        return float(weight)
    raise ValueError("Cannot compute molecular weight for None molecule")


# Registry mapping each supported property name to the function computing it.
# compute_property() and compute_properties_batch() validate requested names
# against these keys; key order is the order shown in "Unknown property" errors.
PROPERTY_FUNCTIONS = dict(
    qed=compute_qed,
    sas=compute_sas,
    logp=compute_logp,
    plogp=compute_plogp,
    mw=compute_molecular_weight,
)


def compute_property(mol: Chem.Mol, property_name: str) -> float:
    """Look up ``property_name`` in the registry and compute it for ``mol``.

    Args:
        mol: RDKit molecule object
        property_name: Registry key of the property (qed, sas, logp, plogp, mw)

    Returns:
        Property value

    Raises:
        ValueError: If ``property_name`` is not registered, or if the selected
            property function rejects ``mol`` (e.g. ``mol`` is None)
    """
    compute_fn = PROPERTY_FUNCTIONS.get(property_name)
    if compute_fn is None:
        raise ValueError(
            f"Unknown property '{property_name}'. Available: {list(PROPERTY_FUNCTIONS.keys())}"
        )
    return compute_fn(mol)


def compute_properties_batch(
    smiles_list: Sequence[str],
    property_names: Sequence[str],
    return_valid_mask: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Compute multiple properties for a batch of SMILES strings.

    Each property is computed independently, so one failing property does not
    prevent the others from being filled in for the same molecule. The result
    therefore does not depend on the order of ``property_names``.

    - An invalid SMILES yields an all-NaN row.
    - An individual failed computation yields NaN in that cell only.

    Args:
        smiles_list: List of SMILES strings
        property_names: List of property names to compute (keys of
            ``PROPERTY_FUNCTIONS``)
        return_valid_mask: If True, also return boolean mask of valid molecules

    Returns:
        If return_valid_mask is False:
            float32 array of shape (n_molecules, n_properties) with NaN for
            failures
        If return_valid_mask is True:
            Tuple of (properties array, valid mask array). ``valid_mask[i]``
            is True only when SMILES ``i`` parsed and every requested
            property was computed successfully.

    Raises:
        ValueError: If any entry of ``property_names`` is not a known property.

    Example:
        >>> smiles = ["CCO", "c1ccccc1", "invalid"]
        >>> props = compute_properties_batch(smiles, ["qed", "sas"])
        >>> props.shape
        (3, 2)
    """
    # Validate every name up front so a typo fails fast instead of silently
    # producing an all-NaN result.
    for prop in property_names:
        if prop not in PROPERTY_FUNCTIONS:
            raise ValueError(
                f"Unknown property '{prop}'. Available: {list(PROPERTY_FUNCTIONS.keys())}"
            )

    # Resolve the property functions once instead of once per molecule.
    funcs = [PROPERTY_FUNCTIONS[prop] for prop in property_names]

    n_mols = len(smiles_list)
    properties = np.full((n_mols, len(funcs)), np.nan, dtype=np.float32)
    valid_mask = np.zeros(n_mols, dtype=bool)

    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Unparseable SMILES: leave the whole row NaN and the mask False.
            continue

        all_ok = True
        for j, (prop_name, func) in enumerate(zip(property_names, funcs)):
            # Best-effort per cell: one failure must not discard the others,
            # so catch here rather than around the whole row.
            try:
                properties[i, j] = func(mol)
            except Exception as e:
                all_ok = False
                logger.debug(f"Failed to compute {prop_name} for {smi}: {e}")
        valid_mask[i] = all_ok

    if return_valid_mask:
        return properties, valid_mask
    return properties


def has_sascorer() -> bool:
    """Report whether the optional SA_Score module was imported successfully.

    Returns:
        ``True`` if ``compute_sas`` can score molecules, ``False`` if it will
        raise because SA_Score is unavailable.
    """
    available = _HAS_SASCORER
    return available
