import json
from rdkit import Chem
import random
class Retriever:
    """Filter a JSON molecule dataset by element content and property ranges.

    The dataset file must contain a JSON array of objects.  Each object is
    read through ``record["smiles"]`` and ``record["properties"]``, where the
    properties mapping is looked up under the keys ``"QED"``, ``"LogP"`` and
    ``"Molecular Weight"``.
    """

    # (constraint name, key inside record["properties"]).  The constraint name
    # is the stem used to build constraint keys such as "qed_min" or "logp".
    _PROPERTY_KEYS = (
        ("qed", "QED"),
        ("logp", "LogP"),
        ("mw", "Molecular Weight"),
    )

    def __init__(self, dataset):
        """Load the whole dataset into memory.

        Args:
            dataset: Path of a UTF-8 encoded JSON file.
        """
        with open(dataset, 'r', encoding='utf-8') as file:
            self.data = json.load(file)

    def _has_element(self, smiles: str, element_symbol) -> bool:
        """Return True if the molecule contains an atom of ``element_symbol``.

        The comparison against ``atom.GetSymbol()`` is an exact string match.
        SMILES strings that RDKit cannot parse yield False.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False  # Invalid SMILES
        return any(atom.GetSymbol() == element_symbol for atom in mol.GetAtoms())

    def _matching(self, must, ranges):
        """Yield the records that satisfy every range and the element filter.

        Args:
            must: Element symbol that has to be present; any falsy value
                (``''`` or ``None``) disables the element filter.
            ranges: Mapping ``property key -> (low, high)`` of inclusive
                bounds.  Only properties listed here are checked.

        A record that lacks a value for a constrained property is skipped,
        because it cannot be shown to satisfy the constraint.
        """
        for record in self.data:
            props = record.get("properties", {})

            in_range = True
            for prop_key, (low, high) in ranges.items():
                value = props.get(prop_key)
                if value is None or not (low <= value <= high):
                    in_range = False
                    break
            if not in_range:
                continue

            # The numeric checks run first because they are cheap; parsing the
            # SMILES is only done for records that already passed them.
            if must and not self._has_element(record.get("smiles", ""), must):
                continue

            yield record

    def retrieve(self, constraints):
        """Return the dataset records whose properties lie in the given ranges.

        Args:
            constraints: Mapping with the optional keys ``must_contain``
                (element symbol) and ``qed_min``/``qed_max``,
                ``logp_min``/``logp_max``, ``mw_min``/``mw_max`` (inclusive
                bounds).  A bound that is absent or ``None`` is treated as
                unbounded on that side.

        Returns:
            List of the matching records (the full dataset entries), in
            dataset order.
        """
        inf = float("inf")
        ranges = {}
        for name, prop_key in self._PROPERTY_KEYS:
            low = constraints.get(name + "_min")
            high = constraints.get(name + "_max")
            if low is None and high is None:
                continue  # property is unconstrained
            ranges[prop_key] = (
                -inf if low is None else low,
                inf if high is None else high,
            )
        return list(self._matching(constraints.get("must_contain"), ranges))

    def retrieve_exact(self, constraints, range_size, mw_tolerance=30):
        """Return SMILES of molecules whose properties are near target values.

        Args:
            constraints: Mapping with the optional keys ``must_contain``
                (element symbol) and the targets ``qed``, ``logp`` and ``mw``.
                A target that is absent or ``None`` leaves that property
                unconstrained.
            range_size: Half-width of the accepted QED window; the LogP
                window uses ``range_size * 5``.
            mw_tolerance: Half-width of the accepted molecular-weight window.

        Returns:
            List of SMILES strings of the matching records, in dataset order.
        """
        tolerances = {
            "qed": range_size,
            "logp": range_size * 5,
            "mw": mw_tolerance,
        }
        ranges = {}
        for name, prop_key in self._PROPERTY_KEYS:
            target = constraints.get(name)
            if target is None:
                continue  # property is unconstrained
            ranges[prop_key] = (target - tolerances[name], target + tolerances[name])
        return [
            record.get("smiles", "")
            for record in self._matching(constraints.get("must_contain"), ranges)
        ]
