import contextlib
import io
from typing import Optional

from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import QED, Crippen
from rdkit.Chem import Descriptors

class RDKitEvaluator:
    """Check a SMILES string against simple property constraints using RDKit.

    Both public entry points return a single human-readable string: either
    ``'All constraints are satisfied.'``, an ``'Invalid SMILES ...'`` message,
    or one or more violation sentences separated by single spaces.
    """

    # Message returned when every requested constraint passes.
    _SATISFIED = 'All constraints are satisfied.'

    # evaluate_exact() scales ``range_size`` by these factors to get the MW and
    # LogP tolerances (QED uses ``range_size`` unscaled). With the default
    # range_size=0.02 that is +/-2 for MW and +/-0.12 for LogP.
    _MW_TOLERANCE_SCALE = 100
    _LOGP_TOLERANCE_SCALE = 6

    # Set once Chem.WrapLogs() has been called (it affects process-wide state,
    # so it is tracked on this class rather than per instance).
    _logs_wrapped = False

    def _has_element(self, smiles: str, element_symbol: str = "F") -> bool:
        """Return True if ``smiles`` parses and contains an atom of ``element_symbol``.

        Invalid SMILES yield False. The match is against RDKit's atom symbol,
        so it is case-sensitive (e.g. ``"Cl"``, not ``"CL"``).
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False  # Invalid SMILES
        return self._mol_has_element(mol, element_symbol)

    @staticmethod
    def _mol_has_element(mol, element_symbol: str) -> bool:
        """Return True if any atom of the already-parsed ``mol`` has ``element_symbol``."""
        return any(atom.GetSymbol() == element_symbol for atom in mol.GetAtoms())

    def _get_smiles_parse_error(self, smiles: str) -> Optional[str]:
        """Return RDKit's parse error text for ``smiles``, or None if it parses.

        The returned text may be empty if RDKit emitted nothing capturable.

        Side effects (process-wide): sets the RDKit logger level to ERROR,
        which disables lower-level RDKit log output, and on first use calls
        ``Chem.WrapLogs()`` so RDKit log output also reaches Python's
        ``sys.stderr``.
        """
        # RDKit logs through the C++ std::cerr stream by default, which a
        # Python-level sys.stderr swap (redirect_stderr) cannot see. WrapLogs
        # routes the logs through Python's sys.stderr so they can be captured.
        if not RDKitEvaluator._logs_wrapped:
            Chem.WrapLogs()
            RDKitEvaluator._logs_wrapped = True

        # Ensure error-level messages (which carry the parse error) are enabled.
        RDLogger.logger().setLevel(RDLogger.ERROR)

        buffer = io.StringIO()
        with contextlib.redirect_stderr(buffer):
            mol = Chem.MolFromSmiles(smiles)

        if mol is not None:
            return None
        return buffer.getvalue().strip()

    def _parse_or_error(self, smiles: str):
        """Parse ``smiles``; return ``(mol, None)`` on success or ``(None, message)`` on failure."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            parse_error = self._get_smiles_parse_error(smiles)
            return None, f"Invalid SMILES (RDKit parse failed). Reason: {parse_error or 'Unknown error.'}"
        return mol, None

    @staticmethod
    def _compute_properties(mol):
        """Return ``(qed, logp, mol_weight)`` for ``mol`` via QED.qed, Crippen.MolLogP, Descriptors.MolWt."""
        return QED.qed(mol), Crippen.MolLogP(mol), Descriptors.MolWt(mol)

    def evaluate(self, smiles: str, constraints: dict) -> str:
        """Check ``smiles`` against min/max property bounds.

        Args:
            smiles: SMILES string to evaluate.
            constraints: optional numeric keys ``qed_min``, ``qed_max``,
                ``logp_min``, ``logp_max``, ``mw_min``, ``mw_max`` and an
                optional ``must_contain`` element symbol (e.g. ``"F"``).
                Missing or ``None`` bounds are not checked; ``0`` is a real
                bound.

        Returns:
            ``'All constraints are satisfied.'``, an ``'Invalid SMILES ...'``
            message, or the violation sentences joined by single spaces.
        """
        mol, error = self._parse_or_error(smiles)
        if error is not None:
            return error

        qed_val, logp_val, mw_val = self._compute_properties(mol)

        qed_min = constraints.get('qed_min')
        qed_max = constraints.get('qed_max')
        logp_min = constraints.get('logp_min')
        logp_max = constraints.get('logp_max')
        mw_min = constraints.get('mw_min')
        mw_max = constraints.get('mw_max')
        must_contain = constraints.get('must_contain')

        messages = []
        # Compare with None rather than truthiness so a bound of 0 is enforced.
        if qed_min is not None and qed_val < qed_min:
            messages.append(f"The generated molecule's QED is {qed_val}, lower than the constraint.")
        if qed_max is not None and qed_val > qed_max:
            messages.append(f"The generated molecule's QED is {qed_val}, higher than the constraint.")

        if logp_min is not None and logp_val < logp_min:
            messages.append(f"The generated molecule's LogP is {logp_val}, lower than the constraint.")
        if logp_max is not None and logp_val > logp_max:
            messages.append(f"The generated molecule's LogP is {logp_val}, higher than the constraint.")

        # Required element (an empty/None symbol means no requirement).
        if must_contain and not self._mol_has_element(mol, must_contain):
            messages.append(f"Molecule does not contain {must_contain}.")

        if mw_max is not None and mw_val > mw_max:
            messages.append(f"Molecular Weight {mw_val} is larger than the constraint.")
        if mw_min is not None and mw_val < mw_min:
            messages.append(f"Molecular Weight {mw_val} is less than the constraint.")

        if not messages:
            return self._SATISFIED
        return ' '.join(messages)

    def evaluate_exact(self, smiles: str, constraints: dict, range_size=0.02) -> str:
        """Check ``smiles`` against target property values with a tolerance.

        Args:
            smiles: SMILES string to evaluate.
            constraints: optional numeric targets ``qed``, ``logp``, ``mw``
                and an optional ``must_contain`` element symbol. Missing or
                ``None`` targets are not checked; ``0`` is a real target.
            range_size: tolerance. A property passes when it is within
                ``range_size`` of its QED target, ``range_size * 100`` of its
                MW target, or ``range_size * 6`` of its LogP target.

        Returns:
            ``'All constraints are satisfied.'`` when every given target is
            within tolerance and ``must_contain`` (if given) is present.
            Otherwise, every deviation from a given target (including ones
            inside the tolerance) plus any missing-element message, joined by
            single spaces. Invalid SMILES return an ``'Invalid SMILES ...'``
            message.
        """
        mol, error = self._parse_or_error(smiles)
        if error is not None:
            return error

        qed_val, logp_val, mw_val = self._compute_properties(mol)

        qed = constraints.get('qed')
        logp = constraints.get('logp')
        mw = constraints.get('mw')
        must_contain = constraints.get('must_contain')

        messages = []
        # Compare with None rather than truthiness so a target of 0 is enforced.
        if qed is not None and qed_val < qed:
            messages.append(f"The generated molecule's QED is {qed_val}, {qed - qed_val} lower than the constraint.")
        if qed is not None and qed_val > qed:
            messages.append(f"The generated molecule's QED is {qed_val}, {qed_val - qed} higher than the constraint.")

        if logp is not None and logp_val < logp:
            messages.append(f"The generated molecule's LogP is {logp_val}, {logp - logp_val} lower than the constraint.")
        if logp is not None and logp_val > logp:
            messages.append(f"The generated molecule's LogP is {logp_val}, {logp_val - logp} higher than the constraint.")

        # Required element (an empty/None symbol means no requirement).
        element_ok = not must_contain or self._mol_has_element(mol, must_contain)
        if not element_ok:
            messages.append(f"Molecule does not contain {must_contain}.")

        if mw is not None and mw_val > mw:
            messages.append(f"Molecular Weight {mw_val} is {mw_val - mw} larger than the constraint.")
        if mw is not None and mw_val < mw:
            messages.append(f"Molecular Weight {mw_val} is {mw - mw_val} less than the constraint.")

        # An unspecified target imposes no constraint, so it counts as in range.
        within_qed = qed is None or abs(qed - qed_val) <= range_size
        within_logp = logp is None or abs(logp_val - logp) <= range_size * self._LOGP_TOLERANCE_SCALE
        within_mw = mw is None or abs(mw - mw_val) <= range_size * self._MW_TOLERANCE_SCALE

        # The tolerance may excuse small numeric deviations, never a missing element.
        if not messages or (element_ok and within_qed and within_logp and within_mw):
            return self._SATISFIED
        return ' '.join(messages)


