from rdkit.Chem import Descriptors
from rdkit import Chem
### ===================数据验证器/奖励函数===================
import re


def is_brackets_balanced(s):
    """检查字符串中的括号是否平衡"""
    stack = []
    for char in s:
        if char in '([':
            stack.append(char)
        elif char == ')':
            if not stack or stack.pop() != '(':
                return False
        elif char == ']':
            if not stack or stack.pop() != '[':
                return False
    return len(stack) == 0

def clean_smiles(smiles):
    """清理SMILES字符串，确保括号平衡，删除多余的右括号和标点符号"""
    if not smiles:
        return smiles
    
    # 去除结尾的标点符号和右括号
    while smiles and (smiles[-1] in '.,;:' or (smiles[-1] in ')]' and not is_brackets_balanced(smiles))):
        smiles = smiles[:-1]
    
    return smiles

def extract_smiles(text,len_pre):
    # 方法1：从"正确的产物SMILES应如下："后提取
    # pattern1 = r'正确的产物SMILES应如下：\s*\n\s*([A-Za-z0-9\(\)\[\]\+=\-#@\\\/\.\*]+)'
    # match1 = re.search(pattern1, text)
    # if match1:
    #     return match1.group(1)
    
    # 方法2：查找最可能的SMILES字符串
    # 寻找包含常见SMILES特征的长字符串
    # pattern = r'([A-Z][a-z]?(?:[0-9]+)?(?:[=\(\)\[\]\\\/\.]|[A-Z][a-z]?|[0-9])+)'
    pattern = r'([A-Za-z][a-z]?(?:[0-9]+)?(?:[=\(\)\[\]\+=\-#@\\\/\.\*\^\~\&]|[A-Za-z][a-z]?|[0-9])+)'
    
    matches = re.findall(pattern, text)
    
    # 过滤可能的SMILES字符串
    candidates = []
    for match in matches:
        carbon_count = len(re.findall(r'C(?![a-z])', match))
        oxygen_count = len(re.findall(r'O(?![a-z])', match))
        # SMILES通常包含特定模式如环闭合数字、化学键或括号  防止匹配上英文单词
        
            
        if (len(match) >= 6 and  # 降低长度限制
            (re.search(r'[A-Za-z][a-z]?[0-9]', match) or  # 环闭合
             re.search(r'[=#]', match) or  # 双键,三键
             carbon_count >= 3 or  # 降低碳原子数要求
             re.search(r'\([^)]+\)', match) or  # 带括号的分支
             # 添加新规则: 匹配含多元素的简单链状分子
             re.search(r'[CNOPS][A-Z]', match))):  # 常见有机元素连接
            
            match=match.split('.')
            if isinstance(match,list):
                for i in match:
                    if i:
                        candidates.append(i)
                # candidates.extend(match)
            else:
                candidates.append(match)
            
    #  (C₂H₅)₃SiH，  -C₆H₄Br  是分子式，不用提取
    
    
    smiles_tag_pattern = r'<SMILES>(.*?)</SMILES>'
    # 如果在<SMILES>标签内找到内容，优先返回
    tag_match = re.search(smiles_tag_pattern, text, re.DOTALL)  
    if tag_match:
        # 从标签中提取并清理SMILES字符串
        tagged_smiles = tag_match.group(1).strip()
        if tagged_smiles and tagged_smiles not in candidates:
            # 确保不重复添加
            candidates.append(tagged_smiles)
    
            
    if candidates:  
        # 返回所有候选项而不是只返回最长的
        final_candidates = []
        for i in candidates:
            i = clean_smiles(i)
            # final_candidates.append(i)
            
            if  i and i not in final_candidates and Chem.MolFromSmiles(i):  # 验证SMILES的准确性
                final_candidates.append(i)
        
        if not final_candidates:
            return []  # 或者返回默认值，如 [""]
        
        max_length = max([len(i) for i in final_candidates])
        final_candidates = [i for i in final_candidates if len(i) >= max_length * len_pre]
        return final_candidates
    # c1ccccc1),c2ccccc2)   N-C-S-C-S-S-C-c1ccccc1
    return "未找到SMILES字符串"

def tool_catch_anwser(all_anwsers):
    r"""
    从模型回答中找到\boxed{}中的最终答案，若有多个，则使用', '分隔。
    
    参数：
    1. all_anwsers：待提取的答案，可能包含多个答案
    
    返回：
    1. extracted_value：提取后的答案，可能是一个字符串或列表
    """
    pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
    match = re.findall(pattern, all_anwsers)
    

    extracted_value = ''
    if match:
        extracted_value = match[-1]   
    else:
        print("没有找到\\boxed{}中的答案，返回整个回答")
        extracted_value = all_anwsers # 返回最后一个答案
    
    if "<SMILES>" in extracted_value and "</SMILES>" in extracted_value:
        extracted_value = extracted_value.split("<SMILES>")[1].split("</SMILES>")[0]

    if isinstance(extracted_value, list):
        extracted_value = ', '.join(extracted_value)
    print("提取的答案:", extracted_value)
    return extracted_value 

def chemistry_tomgbench_logp_qa_reward(
    question: str, 
    answer_content: str, 
    ground_truth: str,
):
    '''
    用于检查模型生成的分子是否与参考答案在LogP值的变化方向上一致。
    
    技术方案：
        从问题中提取出原始分子结构(SMILES格式)，从模型回答中提取生成的分子结构，
        然后计算这两个分子与参考答案分子的LogP值。比较模型生成的分子与原始分子的
        LogP值变化方向（增加或减少）是否与参考答案分子和原始分子的LogP值变化方向一致。
        
    参数：
        question: str, 问题内容，包含原始分子的SMILES字符串
        answer_content: str, 模型输出的答案，应包含\boxed{}标记的SMILES字符串
        ground_truth: str, 参考答案，应为正确的SMILES字符串
        
    返回：
        bool，如果模型生成的分子与原始分子的LogP值变化方向与参考答案分子与原始分子的
        LogP值变化方向一致，则返回True，否则返回False。任何SMILES解析错误或LogP计算
        错误都会返回False。
    '''
    try:
        original_mol = extract_smiles(question,1)[0]  # 提取问题中的分子（原始分子）  SMILES最短长度设置为1
    except Exception as e:
        print(f"提取原始分子({original_mol})时发生错误: {e}")
        return False
    
    try:
        model_gen_mol = tool_catch_anwser(answer_content)  # 提取模型生成的分子
    except Exception as e:
        print(f"提取模型生成分子({model_gen_mol})时发生错误: {e}")
        return False
    try:
        original_mol = Chem.MolFromSmiles(original_mol)
        original_mol_attribute = Descriptors.MolLogP(original_mol)
    except Exception as e:
        print(f"原始分子({original_mol})转为标准式，并且计算LogP时发生错误: {e}")
        return False
    
    try:
        model_gen_mol = Chem.MolFromSmiles(model_gen_mol)
        model_gen_mol_attribute = Descriptors.MolLogP(model_gen_mol)
    except Exception as e:
        print(f"模型生成分子转为标准式，并且计算LogP时发生错误: {e}")
        return False
    
    try:
        ground_truth_mol = Chem.MolFromSmiles(ground_truth) # 参考答案分子
        ground_truth_mol_attribute = Descriptors.MolLogP(ground_truth_mol)
    except Exception as e:
        print(f"参考答案分子转为标准式，并且计算LogP时发生错误: {e}")
        return False
    
    print(f"原始分子LogP值:{original_mol_attribute}, 修改后的分子LogP值:{model_gen_mol_attribute}, 参考答案分子LogP值:{ground_truth_mol_attribute}")

    return (ground_truth_mol_attribute>=original_mol_attribute) == (model_gen_mol_attribute>=original_mol_attribute)

result = chemistry_tomgbench_logp_qa_reward(
    question = """{question}""",
    answer_content = {answer_content},
    ground_truth = {ground_truth}
)
print(f"chemistry_tomgbench_logp_qa_reward检测结果: {result}")
   
   