from rdkit import Chem

ATOMIC_NUMBERS = {
    'carbon': 6, 'nitrogen': 7, 'oxygen': 8, 'fluorine': 9,
    'phosphorus': 15, 'sulfur': 16, 'chlorine': 17, 'bromine': 35,
    'iodine': 53, 'boron': 5, 'silicon': 14, 'selenium': 34,
    'tellurium': 52, 'arsenic': 33, 'antimony': 51, 'bismuth': 83,
    'polonium': 84
}

def mol_prop(mol, prop):
    try:
        mol = Chem.MolFromSmiles(mol)
    except Exception as e:
        return 'Chem.MolFromSmiles ERROR ：{}'.format(e)
    
    if mol is None:
        return 'SMILES ERROR'
    
    if prop.startswith('num_'):
        element = prop[4:]  
        if element in ATOMIC_NUMBERS:
            return sum(atom.GetAtomicNum() == ATOMIC_NUMBERS[element] for atom in mol.GetAtoms())
    
    raise ValueError(f'Property {prop} not supported')

### ===================数据验证器/奖励函数===================
import re

def tool_catch_anwser(all_anwsers):
    r"""
    从模型回答中找到\boxed{}中的最终答案，若有多个，则使用', '分隔。
    
    参数：
    1. all_anwsers：待提取的答案，可能包含多个答案
    
    返回：
    1. extracted_value：提取后的答案，可能是一个字符串或列表
    """
    pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
    match = re.findall(pattern, all_anwsers)
    

    extracted_value = ''
    if match:
        extracted_value = match[-1]
    else:
        print("没有找到\\boxed{}中的答案，返回整个回答")
        extracted_value = all_anwsers  # 返回最后一个答案
    
    if "<SMILES>" in extracted_value and "</SMILES>" in extracted_value:
        extracted_value = extracted_value.split("<SMILES>")[1].split("</SMILES>")[0]

    if isinstance(extracted_value, list):
        extracted_value = ', '.join(extracted_value)
    print("提取的答案:", extracted_value)
    return extracted_value       


def chemistry_tomgbench_atomnumgroup_qa_reward(
    question: str, 
    answer_content: str, 
    ground_truth: str,
):
    '''
    奖励函数，用于检查模型生成的分子结构是否与参考答案在原子数量上匹配。
    
    技术方案：
        从模型回答中提取出分子结构(SMILES格式)，然后比较该分子与参考答案中
        各类原子(碳、氧、氮等)的数量是否一致。如果所有原子数量都匹配，则认为
        答案正确。
        
    参数：
        question: str, 问题内容（当前函数中未使用）
        answer_content: str, 模型输出的答案，应包含\boxed{}标记的SMILES字符串
        ground_truth: str, 参考答案，应为正确的SMILES字符串
        
    返回：
        bool，如果模型生成的分子与参考答案中各类原子的数量完全一致，则返回True，
        否则返回False。任何一种原子数量不匹配或SMILES解析错误都会返回False。
    
    '''
    original_mol = tool_catch_anwser(answer_content)  # 等待检测的分子
    atom_type = ['carbon', 'oxygen', 'nitrogen', 'sulfur', 'fluorine', 'chlorine', 'bromine', 'iodine', 'phosphorus', 'boron', 'silicon', 'selenium', 'tellurium', 'arsenic', 'antimony', 'bismuth', 'polonium']
    # 碳（carbon），氧（oxygen），氮（nitrogen），硫（sulfur），氟（fluorine），氯（chlorine），溴（bromine），碘（iodine），磷（phosphorus），硼（boron），硅（silicon），硒（selenium），碲（tellurium），砷（arsenic），锑（antimony），铋（bismuth），钋（polonium）
    
    flag = True
    atomnums_dict = {}
    for atom in atom_type:  # 统计参考答案分子中的原子个数
        org_nums = mol_prop(ground_truth, "num_" + atom)
        atomnums_dict[atom] = org_nums
        
    
    for atom in atom_type:
        truth_nums = atomnums_dict.get(atom, 0)  # 如果没有这个原子，默认是0
        if 'ERROR' in str(truth_nums):
            flag = False
            
        org_nums = mol_prop(original_mol, "num_" + atom)
        print(f"{atom}:", org_nums, truth_nums)
        if org_nums != truth_nums:
            flag = False
                
    return flag
       
flag = chemistry_tomgbench_atomnumgroup_qa_reward(
    question = """{question}""",
    answer_content = {answer_content},
    ground_truth = {ground_truth}
)
print('\n==========>chemistry_tomgbench_functchemistry_tomgbench_atomnumgroup_qa_rewardionalgroup_qa_reward:',flag,'\n\n')
    