from rdkit import Chem
from rdkit.Chem import Descriptors

def mol_prop(mol, prop):
    try:
        mol = Chem.MolFromSmiles(mol)
    except Exception as e:
        return 'Chem.MolFromSmiles ERROR ：{}'.format(e)
    
    if mol is None:
        return 'SMILES ERROR'
    
    if prop == 'num_single_bonds':
        return sum([bond.GetBondType() == Chem.rdchem.BondType.SINGLE for bond in mol.GetBonds()])
    elif prop == 'num_double_bonds':
        return sum([bond.GetBondType() == Chem.rdchem.BondType.DOUBLE for bond in mol.GetBonds()])
    elif prop == 'num_triple_bonds':
        return sum([bond.GetBondType() == Chem.rdchem.BondType.TRIPLE for bond in mol.GetBonds()])
    elif prop == 'num_aromatic_bonds':
        return sum([bond.GetBondType() == Chem.rdchem.BondType.AROMATIC for bond in mol.GetBonds()])
    elif prop == 'num_rotatable_bonds': # rotatable bonds
        return Descriptors.NumRotatableBonds(mol)
    else:
        raise ValueError(f'Property {prop} not supported')

### ===================数据验证器/奖励函数===================
import re

def tool_catch_anwser(all_anwsers):
    r"""
    从模型回答中找到\boxed{}中的最终答案，若有多个，则使用', '分隔。
    
    参数：
    1. all_anwsers：待提取的答案，可能包含多个答案
    
    返回：
    1. extracted_value：提取后的答案，可能是一个字符串或列表
    """
    pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
    match = re.findall(pattern, all_anwsers)
    

    extracted_value = ''
    if match:
        extracted_value = match[-1] 
    else:
        print("没有找到\\boxed{}中的答案，返回整个回答")
        extracted_value = all_anwsers  # 返回最后一个答案
    
    if "<SMILES>" in extracted_value and "</SMILES>" in extracted_value:
        extracted_value = extracted_value.split("<SMILES>")[1].split("</SMILES>")[0]

    if isinstance(extracted_value, list):
        extracted_value = ', '.join(extracted_value)
    print("提取的答案:", extracted_value)
    return extracted_value       


def chemistry_tomgbench_bondnumgroup_qa_reward(
    question: str, 
    answer_content: str, 
    ground_truth: str,
):
    '''
    奖励函数，用于检查模型生成的分子结构是否与参考答案在化学键数量上匹配。
    
    技术方案：
        从模型回答中提取出分子结构(SMILES格式)，然后比较该分子与参考答案中
        各类化学键(单键、双键、三键、可旋转键、芳香键)的数量是否一致。如果
        所有化学键数量都匹配，则认为答案正确。
        
    参数：
        question: str, 问题内容（当前函数中未使用）
        answer_content: str, 模型输出的答案，应包含\boxed{}标记的SMILES字符串
        ground_truth: str, 参考答案，应为正确的SMILES字符串
        
    返回：
        bool，如果模型生成的分子与参考答案中各类化学键的数量完全一致，则返回True，
        否则返回False。任何一种化学键数量不匹配或SMILES解析错误都会返回False。
    
    '''
    original_mol = tool_catch_anwser(answer_content)  # 等待检测的分子
    bonds_type = ['single', 'double', 'triple', 'rotatable', 'aromatic']
    # 单键，双键，三键，可旋转键，芳香键
    flag = True
    bondnums_dict = {}
    for bond in bonds_type:  # 统计参考答案分子中的化学键数目
        bondnums_dict[bond] = mol_prop(ground_truth, "num_" + bond + "_bonds")
    
    for bond in bonds_type:
        truth_nums = bondnums_dict.get(bond, 0)  # 如果没有这个基团，默认是0
        if 'ERROR' in str(truth_nums):
            flag = False
        
        org_nums = mol_prop(original_mol, "num_" + bond + "_bonds")
        print(f"化学键种类：{bond}，待检测分子中的化学键数目: {org_nums}, 问题中要求的化学键数目: {truth_nums}")
        
        if  org_nums != truth_nums:
            flag = False
    return flag

flag = chemistry_tomgbench_bondnumgroup_qa_reward(
    question = """{question}""",
    answer_content = {answer_content},  # 待检测的分子
    ground_truth = {ground_truth}
)
print('\n==========>chemistry_tomgbench_bondnumgroup_qa_reward:',flag,'\n\n')
    
