from rdkit import Chem

def mol_prop(mol, prop):
    try:
        mol = Chem.MolFromSmiles(mol)
    except Exception as e:
        return 'Chem.MolFromSmiles ERROR ：{}'.format(e)
    
    if mol is None:
        return 'SMILES ERROR'
    
    # Remove the "num_" prefix, if present
    if prop.startswith("num_"):
        prop = prop[4:]
    
    # Define a dictionary of all SMARTS patterns
    smarts_patterns = {
        "benzene_ring": '[cR1]1[cR1][cR1][cR1][cR1][cR1]1',
        "hydroxyl": '[OX2H]',
        "anhydride": '[CX3](=[OX1])[OX2][CX3](=[OX1])',
        "aldehyde": '[CX3H1](=O)[#6]',
        "ketone": '[#6][CX3](=O)[#6]',
        "carboxyl": '[CX3](=O)[OX2H1]',
        "ester": '[#6][CX3](=O)[OX2H0][#6]',
        "amide": '[NX3][CX3](=[OX1])[#6]',
        "amine": '[NX3;H2,H1;!$(NC=O)]',
        "nitro": '[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]',
        "halo": '[F,Cl,Br,I]',
        "thioether": '[SX2][CX4]',
        "nitrile": '[NX1]#[CX2]',
        "thiol": '[#16X2H]',
        "sulfide": '[#16X2H0]',  # Require special handling
        "disulfide": '[#16X2H0][#16X2H0]',
        "sulfoxide": '[$([#16X3]=[OX1]),$([#16X3+][OX1-])]',
        "sulfone": '[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]',
        "borane": '[BX3]'
    }
    
    # Check if the attribute exists in the dictionary
    if prop not in smarts_patterns:
        raise ValueError(f'Property {prop} not supported')
        
    # Get SMARTS patterns
    smarts = smarts_patterns[prop]
    
    # Attributes requiring special handling
    if prop == "sulfide":
        matches = mol.GetSubstructMatches(Chem.MolFromSmarts(smarts))
        exception = mol.GetSubstructMatches(Chem.MolFromSmarts(smarts_patterns["disulfide"]))
        return len(matches) - len(exception)
    
    # Standard processing
    matches = mol.GetSubstructMatches(Chem.MolFromSmarts(smarts))
    return len(matches)

### ===================数据验证器/奖励函数===================
import re

def tool_catch_anwser(all_anwsers):
    r"""
    从模型回答中找到\boxed{}中的最终答案，若有多个，则使用', '分隔。
    
    参数：
    1. all_anwsers：待提取的答案，可能包含多个答案
    
    返回：
    1. extracted_value：提取后的答案，可能是一个字符串或列表
    """
    pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
    match = re.findall(pattern, all_anwsers)
    

    extracted_value = ''
    if match:
        extracted_value = match[-1]  
    else:
        print("没有找到\\boxed{}中的答案，返回整个回答")
        extracted_value = all_anwsers  # 返回最后一个答案
    
    if "<SMILES>" in extracted_value and "</SMILES>" in extracted_value:
        extracted_value = extracted_value.split("<SMILES>")[1].split("</SMILES>")[0]

    if isinstance(extracted_value, list):
        extracted_value = ', '.join(extracted_value)
    print("提取的答案:", extracted_value)
    return extracted_value     

def chemistry_tomgbench_functionalgroup_qa_reward(
    question: str, 
    answer_content: str, 
    ground_truth: str,
):
    '''
    奖励函数，用于检查模型生成的分子结构是否与参考答案在功能团数量上匹配。可以使用在functionalgroup、AddComponent、DelComponent、SubComponent四个任务上
    
    技术方案：
        从模型回答中提取出分子结构(SMILES格式)，然后比较该分子与参考答案中
        各种功能团(如苯环、羟基、酸酐、醛基、酮基等)的数量是否一致。如果
        所有功能团数量都匹配，则认为答案正确。
        
    参数：
        question: str, 问题内容（当前函数中未使用）
        answer_content: str, 模型输出的答案，应包含\boxed{}标记的SMILES字符串
        ground_truth: str, 参考答案，应为正确的SMILES字符串
        
    返回：
        bool，如果模型生成的分子与参考答案中各功能团的数量完全一致，则返回True，
        否则返回False。任何一种功能团数量不匹配或SMILES解析错误都会返回False。
    
    '''
    original_mol = tool_catch_anwser(answer_content)  # 等待检测的分子
    functional_group  = ['benzene rings', 'hydroxyl', 'anhydride', 'aldehyde', 'ketone', 'carboxyl', 'ester', 'amide', 'amine', 'nitro', 'halo', 'nitrile', 'thiol', 'sulfide', 'disulfide', 'sulfoxide', 'sulfone', 'borane']
    # 苯环（benzene rings），羟基（hydroxyl），酸酐（anhydride），醛基（aldehyde），酮基（ketone），羧基（carboxyl），酯基（ester），酰胺基（amide），氨基（amine），硝基（nitro），卤代基（halo），氰基（nitrile），巯基（thiol），硫醚（sulfide），二硫化物（disulfide），亚砜基（sulfoxide），砜基（sulfone），硼烷（borane）
    flag = True
    functional_group_dict = {}
    for group in functional_group:  # 统计参考答案分子中的基团
        if group == "benzene rings":
            org_nums = mol_prop(ground_truth, "num_benzene_ring")
        else:
            org_nums = mol_prop(ground_truth, "num_" + group)
        
        functional_group_dict[group] = org_nums
            
    print('FunctionalGroup_dict:', functional_group_dict)
    for group in functional_group:
        truth_nums = functional_group_dict.get(group, 0) # 如果没有这个基团，默认是0
        if 'ERROR' in str(truth_nums):
            flag = False
            
        if group == "benzene rings":
            org_nums = mol_prop(original_mol, "num_benzene_ring")
            print(f"{group} 待检测分子中的基团数目: {org_nums}, 参考答案中的基团数目: {truth_nums}")
            
        else:
            org_nums = mol_prop(original_mol, "num_" + group)
            print(f"{group} 待检测分子中的基团数目: {org_nums}, 参考答案中的基团数目: {truth_nums}")
            
        if org_nums != truth_nums:
            flag = False
                # break
    return flag

flag = chemistry_tomgbench_functionalgroup_qa_reward(
    question="""{question}""",
    answer_content={answer_content},  # 待检测的分子
    ground_truth={ground_truth}
)

print('\n==========>chemistry_tomgbench_functionalgroup_qa_reward:',flag,'\n\n')
