import json
import re
import sys
import ast
from itertools import combinations
from sklearn.metrics import brier_score_loss
import random
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from chem_metrics import mol_opt_evaluater, smiles_similarity, calculate_kendall_tau, compute_classification_metrics
from chem_metrics.eval_mol_edit import check_edit_add_valid, check_edit_del_valid, check_edit_sub_valid
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def parse_list(s):
    """Return every (optionally negative) integer found in *s*, in order.

    Digits are matched as maximal runs, so a decimal such as ``"1.5"``
    yields two integers (``1`` and ``5``).
    """
    return list(map(int, re.findall(r'-?\d+', s)))

def preprocess_text(text):
    """Best-effort rewrite of a JSON-like snippet into something closer to JSON.

    The substitutions run strictly in the order listed below; later rules
    operate on the output of earlier ones.  The result is not guaranteed
    to be valid JSON.
    """

    def _requote_value(m):
        # Leave already-quoted strings and JSON primitives untouched; wrap
        # everything else in double quotes.
        value = m.group(1).strip()
        already_quoted = re.match(r'^["\'].*["\']$', value)
        is_primitive = re.match(r'^(true|false|null|\d+(\.\d+)?)(\s*[,}\n])?$', value)
        if already_quoted or is_primitive:
            return f': {value}'
        return f': "{value}"'

    rules = (
        # Single-quoted keys -> double-quoted keys.
        (r"'([^']*)':", r'"\1":'),
        # Single-quoted values -> double-quoted values.
        (r":\s*'([^']*)'", r': "\1"'),
        # Values that are not quoted at all.
        (r':\s*([^\[{"][^,\}\n]*)', _requote_value),
        # Trailing comma before a closing brace.
        (r',(\s*})', r'\1'),
        # Trailing comma before a closing bracket.
        (r',(\s*])', r'\1'),
        # Unquoted key names (simple case).
        (r'(\w+):', r'"\1":'),
        # Collapse newlines and runs of whitespace into one space.
        (r'\s+', ' '),
        # Python-style literals -> JSON literals.
        (r'\bTrue\b', 'true'),
        (r'\bFalse\b', 'false'),
        (r'\bNone\b', 'null'),
    )
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text

def AnswerExtract(answer_text, dataset_name, stop_reasoning_tokens=["</think>", 'assistantfinal']):
    """Extract the final answer from a raw model response.

    Text up to and including the last occurrence of each marker in
    *stop_reasoning_tokens* is discarded first.

    * ``'ChemCoTDataset'``: returns the last brace-delimited object that
      parses as JSON after ``preprocess_text``; ``""`` when objects were
      found but none parsed; the stripped last line when the text contains
      no such object at all.
    * ``'BioProBench'``: returns the stripped text between
      ``[ANSWER_START]`` and ``[ANSWER_END]``; when the pair is incomplete,
      whatever follows the last start tag / precedes the first end tag.

    Raises:
        NotImplementedError: for any other *dataset_name*.
    """
    for marker in stop_reasoning_tokens:
        if marker in answer_text:
            answer_text = answer_text.split(marker)[-1].strip()

    if dataset_name == 'BioProBench':
        tagged = re.search(r"\[ANSWER_START\](.*?)\[ANSWER_END\]", answer_text, re.DOTALL)
        if tagged:
            return tagged.group(1).strip()
        # Fallback: tolerate a missing start or end tag.
        after_start = answer_text.split("[ANSWER_START]")[-1]
        before_end = after_start.split("[ANSWER_END]")[0]
        return before_end.strip()

    if dataset_name == 'ChemCoTDataset':
        # Brace-delimited objects without nested braces; quoted strings may
        # contain escaped characters.
        candidates = re.findall(r'\{(?:[^{}"]|"(?:[^"\\]|\\.)*")*\}', answer_text)
        if not candidates:
            return answer_text.strip().split('\n')[-1].strip()

        parsed_objects = []
        for candidate in candidates:
            try:
                parsed_objects.append(json.loads(preprocess_text(candidate)))
            except json.JSONDecodeError:
                continue
        return parsed_objects[-1] if parsed_objects else ""

    raise NotImplementedError(f"Unknown dataset: {dataset_name}")


class BioProBenchEval:
    """Score model responses on BioProBench records.

    Every record is a dict whose ``'task'`` is one of ``'ERR'``, ``'PQA'``,
    ``'ORD'`` or ``'GEN'``.  Scoring a record writes ``'eval_result'`` and
    ``'ground_truth'`` back into that record in place.

    The scoring methods take either ``preds=None`` (score every response in
    ``record['outputs']``) or a single raw response string.
    """

    def __init__(self, data=None):
        self.dataset_name = 'BioProBench'
        self.data = data

    @staticmethod
    def _safe_mean(values):
        """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
        return sum(values) / len(values) if values else 0.0

    def _extract_preds(self, idx, preds):
        """Return the extracted answers for record *idx* as a list."""
        if preds is None:
            return [AnswerExtract(p, self.dataset_name) for p in self.data[idx]['outputs']]
        return [AnswerExtract(preds, self.dataset_name)]

    def classify(self, data):
        """Label every record with a ``'difficulty'`` based on its score.

        The score is computed over ``record['outputs']``: ``>= 0.6`` is
        ``'Easy'``, ``>= 0.2`` is ``'Medium'``, anything lower is ``'Hard'``.

        Returns:
            The same *data* object, mutated in place.

        Raises:
            NotImplementedError: for a record with an unknown task.
        """
        self.data = data
        scorers = {'ERR': self.ERR, 'PQA': self.PQA, 'ORD': self.ORD, 'GEN': self.GEN}
        for i, record in tqdm(enumerate(data), total=len(data), desc="Classifying data in BioProBench"):
            task = record['task']
            if task not in scorers:
                raise NotImplementedError(f"Unknown task: {task}")
            result = scorers[task](i)

            if result >= 0.6:
                record['difficulty'] = 'Easy'
            elif result >= 0.2:
                record['difficulty'] = 'Medium'
            else:
                record['difficulty'] = 'Hard'

        return self.data

    def evaluate(self, data, setting="train"):
        """Score ``record['output']`` for every record and print per-task metrics.

        Args:
            data: list of record dicts; mutated in place.
            setting: ``"train"`` reports accuracies only; any other value
                additionally reports the ORD failure rate and Kendall tau
                and the PQA Brier score.

        Returns:
            The same *data* object with ``'eval_result'`` filled in.

        Raises:
            NotImplementedError: for a record with an unknown task.
        """
        self.data = data  # new data
        is_train = setting == "train"
        PQA_result = {'acc': [], 'conf': []}
        ORD_result = {'preds': [], 'gts': [], 'acc': [], 'fail': []}
        ERR_result = {'preds': [], 'gts': [], 'acc': []}
        GEN_result = []
        for i, record in tqdm(enumerate(self.data), total=len(self.data)):
            task = record['task']
            if task == 'ERR':
                self.ERR(i, record['output'], setting=setting)
                ERR_result['acc'].append(record['eval_result']['acc'])
                ERR_result['preds'].append(record['eval_result']['pred'])
                ERR_result['gts'].append(record['ground_truth'])

            elif task == 'PQA':
                self.PQA(i, record['output'], setting=setting)
                PQA_result['acc'].append(record['eval_result']['acc'])
                if not is_train:
                    PQA_result['conf'].append(record['eval_result']['conf'])

            elif task == 'ORD':
                self.ORD(i, record['output'], setting=setting)
                if is_train:
                    ORD_result['acc'].append(record['eval_result']['acc'])
                    continue
                if record['eval_result']['fail'] > 0:
                    # Unparseable orderings only count towards the failure
                    # rate; they are left out of accuracy and Kendall tau.
                    ORD_result['fail'].append(record['eval_result']['fail'])
                    continue
                assert isinstance(record['eval_result']['indices'], list)
                ORD_result['preds'].append(record['eval_result']['indices'])
                ORD_result['gts'].append(record['ground_truth'])
                ORD_result['acc'].append(record['eval_result']['acc'])
                ORD_result['fail'].append(record['eval_result']['fail'])

            elif task == 'GEN':
                GEN_result.append(self.GEN(i, record['output']))
            else:
                raise NotImplementedError(f"Unknown task: {task}")

        ERR_metrics = compute_classification_metrics(ERR_result['preds'], ERR_result['gts'])
        print("ERR: ", ERR_metrics)

        # ORD -- means fall back to 0.0 when no record contributed, instead
        # of raising ZeroDivisionError.
        Acc = self._safe_mean(ORD_result['acc'])
        if is_train:
            print("ORD: ", f"Acc {Acc}")
        else:
            Fail = self._safe_mean(ORD_result['fail'])
            # The printed "Bs" label carries the Kendall tau value.
            if ORD_result['gts']:
                Bs = calculate_kendall_tau(ORD_result['gts'], ORD_result['preds'])
            else:
                Bs = float('nan')
            print("ORD: ", f"Acc {Acc}", f"Fail {Fail}", f"Bs {Bs}")

        # PQA
        acc = self._safe_mean(PQA_result['acc'])
        if is_train:
            print("PQA: ", acc)
        else:
            if PQA_result['acc']:
                # Confidences are percentages; cap at 100 and rescale to [0, 1].
                probs = np.array([min(conf, 100) for conf in PQA_result['conf']]) / 100
                bs = brier_score_loss(PQA_result['acc'], probs)
            else:
                bs = float('nan')
            print("PQA: ", acc, bs)

        # GEN
        bleu = np.mean(GEN_result)
        print("GEN: ", bleu)

        return self.data

    def load_data(self):
        """Return the stored data.

        Raises:
            ValueError: if no data has been set.
        """
        if self.data is None:
            raise ValueError("Data has not been loaded yet.")
        return self.data

    def ERR(self, idx, preds=None, setting='train'):
        """Score an ERR record: does the response say 'true' when ``is_correct`` is true?

        A response counts as ``True`` when the extracted answer contains
        the substring ``'true'`` (case-insensitive).

        Returns:
            Mean accuracy over the scored responses.
        """
        gt = self.data[idx]['is_correct']  # ground truth
        preds = ['true' in p.lower() for p in self._extract_preds(idx, preds)]
        acc = [p == gt for p in preds]
        mean_acc = sum(acc) / len(acc)

        if setting == 'train':
            outputs = acc
        else:
            # Outside training exactly one response is scored per record.
            assert len(preds) == 1
            outputs = acc[0]

        self.data[idx]['eval_result'] = {"outputs": outputs, "acc": mean_acc, 'pred': preds[0]}
        self.data[idx]['ground_truth'] = gt
        return mean_acc

    def PQA(self, idx, preds=None, setting='train'):
        """Score a PQA record by checking that the answer contains ``record['answer']``.

        Responses may look like ``"<answer> & <confidence>"``.  In the
        ``'train'`` setting only the answer part is used; otherwise the
        confidence (first integer in the trailing part, default 100) is
        stored as ``eval_result['conf']``.

        Returns:
            Mean accuracy over the scored responses.
        """
        gt = self.data[idx]['answer']  # ground truth
        preds = self._extract_preds(idx, preds)

        if setting == 'train':
            answers = [p.split("&")[0].strip() for p in preds]
            hits = [(gt in a) or (a == gt) for a in answers]
            self.data[idx]['eval_result'] = {"outputs": hits, "acc": sum(hits) / len(hits)}
            self.data[idx]['ground_truth'] = gt
            return sum(hits) / len(hits)

        answers, confidences = [], []
        for p in preds:
            # Prefer the explicit '&' separator; otherwise treat the first
            # and last space-separated tokens as answer and confidence.
            parts = p.split("&") if "&" in p else p.split(" ")
            answer, confidence = parts[0].strip(), parts[-1].strip()

            confidence_match = re.search(r"\d+", confidence)
            confidences.append(int(confidence_match.group()) if confidence_match else 100)
            answers.append(answer.strip())

        acc = [int((gt in a) or (a == gt)) for a in answers]
        assert len(confidences) == 1
        self.data[idx]['eval_result'] = {"outputs": acc, "acc": sum(acc) / len(acc), "conf": confidences[0]}
        self.data[idx]['ground_truth'] = gt

        return sum(acc) / len(acc)

    def ORD(self, idx, preds=None, setting="train"):
        """Score an ORD record: the predicted index order must rebuild ``correct_steps``.

        Each answer is parsed as a Python literal.  It must be a
        permutation of ``range(len(correct_steps))``; the indices are then
        mapped through ``wrong_steps`` and compared with ``correct_steps``.
        Answers that fail to parse or are not a permutation become ``[]``.

        Returns:
            Mean exact-match accuracy over the scored responses.
        """
        gt = self.data[idx]['correct_steps']  # ground truth
        wrong_steps = self.data[idx]['wrong_steps']
        preds = self._extract_preds(idx, preds)

        indices = []
        for p in preds:
            try:
                indice = ast.literal_eval(p)
                if not isinstance(indice, list) and not isinstance(indice, tuple):
                    indice = [indice]
                elif isinstance(indice, tuple):
                    indice = [x[0] for x in indice]
            except Exception:
                # Anything that cannot be parsed counts as a failed answer.
                indices.append([])
                continue
            try:
                if set(indice) != set(range(len(gt))):
                    indices.append([])
                    continue
            except Exception as e:
                print(indice)
                raise e
            indices.append([wrong_steps[i] for i in indice])

        matches = [p == gt for p in indices]
        if setting == 'train':
            self.data[idx]['eval_result'] = {"outputs": matches, "acc": sum(matches) / len(matches)}
        else:
            fail = 1 if indices[0] == [] else 0
            assert len(matches) == 1
            self.data[idx]['eval_result'] = {"outputs": matches[0], "acc": sum(matches) / len(matches), 'fail': fail, 'indices': indices[0]}
        self.data[idx]['ground_truth'] = gt

        return sum(matches) / len(matches)

    def GEN(self, idx, preds=None):
        """Score a GEN record with smoothed bigram BLEU against ``record['output']``.

        Returns:
            Mean BLEU over the scored responses.
        """
        # The module only imports names from nltk.translate.bleu_score, so
        # the tokenizer has to be brought into scope here.
        import nltk

        # NOTE(review): ``evaluate`` passes ``record['output']`` as the
        # prediction, and the same field is the reference here -- confirm
        # which field is meant to hold the reference text.
        gt = self.data[idx]['output']  # ground truth
        preds = self._extract_preds(idx, preds)

        gt_tokens = nltk.word_tokenize(gt.lower())
        preds_tokens = [nltk.word_tokenize(p.lower()) for p in preds]

        smoothing = SmoothingFunction().method1
        bleus = [
            sentence_bleu([gt_tokens], p_tokens, weights=(0.5, 0.5), smoothing_function=smoothing)
            for p_tokens in preds_tokens
        ]

        self.data[idx]['eval_result'] = {"outputs": bleus, "acc": sum(bleus) / len(bleus)}
        self.data[idx]['ground_truth'] = gt

        return sum(bleus) / len(bleus)

            
class ChemCoTEval:
    """Score model responses on ChemCoTDataset records.

    Every record is a dict with a ``'task'`` (``'mol_und'``, ``'mol_edit'``,
    ``'reaction'`` or ``'mol_opt'``) and a ``'subtask'``.  Scoring a record
    writes ``'eval_result'`` and ``'ground_truth'`` back into it in place.

    The scoring methods take either ``preds=None`` (score every response in
    ``record['outputs']``) or a single raw response string.
    """

    def __init__(self, data=None):
        self.dataset_name = 'ChemCoTDataset'
        self.data = data

    def _extract_preds(self, idx, preds):
        """Return the extracted answers for record *idx* as a list."""
        if preds is None:
            return [AnswerExtract(p, self.dataset_name) for p in self.data[idx]['outputs']]
        return [AnswerExtract(preds, self.dataset_name)]

    @staticmethod
    def _pick_answer(p, keys=()):
        """Reduce one extracted answer to a single value.

        Non-dict answers are stringified.  For a dict: the only value if it
        has one entry, else the first of *keys* that is present, else the
        first value, else ``""`` for an empty dict.
        """
        if not isinstance(p, dict):
            return str(p)
        if len(p) == 1:
            return list(p.values())[0]
        for key in keys:
            if key in p:
                return p[key]
        return list(p.values())[0] if p else ""

    def classify(self, data):
        """Label every record with a ``'difficulty'`` based on its score.

        The score is computed over ``record['outputs']``: ``>= 0.6`` is
        ``'Easy'``, ``>= 0.2`` is ``'Medium'``, anything lower is ``'Hard'``.

        Returns:
            The same *data* object, mutated in place.

        Raises:
            NotImplementedError: for a record with an unknown task.
        """
        self.data = data
        scorers = {
            'mol_und': self.mol_und,
            'mol_edit': self.mol_edit,
            'reaction': self.reaction,
            'mol_opt': self.mol_opt,
        }
        for i, record in tqdm(enumerate(data), total=len(data), desc="Classifying data in ChemCoTDataset"):
            task = record['task']
            if task not in scorers:
                raise NotImplementedError(f"Unknown task: {task}")
            result = scorers[task](i)

            if result >= 0.6:
                record['difficulty'] = 'Easy'
            elif result >= 0.2:
                record['difficulty'] = 'Medium'
            else:
                record['difficulty'] = 'Hard'

        return self.data

    def evaluate(self, data, setting="train"):
        """Score ``record['output']`` for every record and print per-subtask means.

        Args:
            data: list of record dicts; mutated in place.
            setting: ``"train"`` scores reactions with ``reaction``; any
                other value uses ``reaction_eval``.  The value is also
                forwarded to ``mol_und``.

        Returns:
            The same *data* object with ``'eval_result'`` filled in.

        Raises:
            NotImplementedError: for a record with an unknown task.
        """
        self.data = data  # new data

        mol_und_result = {
            'fg_count': [], 'Murcko_scaffold': [], 'ring_count': [], 'ring_system_scaffold': [], 'equivalence': [],
        }
        mol_edit_result = {
            'add': [], 'delete': [], 'sub': [],
        }
        reaction_result = {'nepp': [], 'rcr': [], 'mechsel': [], 'retro': [], 'fs-M': [], 'fs-B': []}
        opt_metrics = ("improve", "scaffold_hard", "scaffold_soft")
        opt_subtasks = {d['subtask'] for d in self.data if d['task'] == 'mol_opt'}
        mol_opt_result = {metric: {k: [] for k in opt_subtasks} for metric in opt_metrics}

        for i, d in tqdm(enumerate(self.data), total=len(self.data)):
            task = d['task']
            if task == 'mol_und':
                result = self.mol_und(i, d['output'], setting=setting)
                mol_und_result[d['subtask']].append(result)
            elif task == 'mol_edit':
                result = self.mol_edit(i, d['output'])
                mol_edit_result[d['subtask']].append(result)
            elif task == 'reaction':
                if setting == "train":
                    result = self.reaction(i, d['output'])
                else:
                    result = self.reaction_eval(i, d['output'])
                if d['subtask'] == 'fs':
                    fs_acc = d['eval_result']['acc']
                    if isinstance(fs_acc, dict):
                        reaction_result['fs-M'].append(fs_acc['major'])
                        if 'byproduct' in fs_acc:
                            reaction_result['fs-B'].append(fs_acc['byproduct'])
                    else:
                        # ``reaction`` stores one float; count it as the
                        # major-product score, as ``reaction_eval`` does
                        # when only one ground truth is available.
                        reaction_result['fs-M'].append(fs_acc)
                else:
                    reaction_result[d['subtask']].append(result)
            elif task == 'mol_opt':
                self.mol_opt(i, d['output'])
                outputs = d['eval_result']['outputs']
                for metric in opt_metrics:
                    mol_opt_result[metric][d['subtask']].append(np.mean([x[metric] for x in outputs]))
            else:
                raise NotImplementedError(f"Unknown task: {task}")

        for k, v in mol_und_result.items():
            print(f"mol_und {k}: {np.mean(v)}")
        for k, v in mol_edit_result.items():
            print(f"mol_edit {k}: {np.mean(v)}")
        reaction_result = {k: np.mean(v) for k, v in reaction_result.items()}
        print(f"reaction: {reaction_result}")
        for metric in opt_metrics:
            mol_opt_result[metric] = {k: np.mean(v) for k, v in mol_opt_result[metric].items()}
        print(f"mol_opt improve: {mol_opt_result['improve']}")
        print(f"mol_opt scaffold_hard: {mol_opt_result['scaffold_hard']}")
        print(f"mol_opt scaffold_soft: {mol_opt_result['scaffold_soft']}")

        return self.data

    def mol_und(self, idx, preds=None, setting="train"):
        """Score a molecule-understanding record.

        The ground truth is taken from ``meta['gt']``, then ``record['gt']``;
        for ``ring_system_scaffold`` it is otherwise read from the JSON in
        ``record['struct_cot']`` and normalised to ``"Yes"`` / ``"No"``.

        Scoring per subtask: substring match for ``ring_system_scaffold``
        and ``equivalence``; substring match for the count subtasks in the
        ``'train'`` setting and absolute error otherwise; the second value
        returned by ``scaffold_consistency`` for ``Murcko_scaffold``.

        Returns:
            Mean score over the scored responses.

        Raises:
            NotImplementedError: for an unknown subtask.
        """
        # Key holding the answer in a multi-key JSON response, per subtask.
        pred_key_dict = dict(
            fg_count="count", Murcko_scaffold='Output Scaffold', ring_count='count', ring_system_scaffold='output'
        )
        gt_key_dict = dict(
            fg_count="gt", Murcko_scaffold='gt', ring_count='gt', ring_system_scaffold=''
        )

        record = self.data[idx]
        task = record['subtask']
        meta = json.loads(record['meta'])

        if 'gt' in meta:
            gt = meta['gt']
        elif 'gt' in record:
            gt = record['gt']
        elif task != "ring_system_scaffold":
            gt = meta[gt_key_dict[task]]
        else:
            # ring_system_scaffold only
            if meta['molecule'] not in record['struct_cot']:
                # No ground truth can be derived for this record; it is
                # scored 1.0 and its ground truth recorded as None.
                record['eval_result'] = {"outputs": [], "acc": 1.0}
                record['ground_truth'] = None
                return 1.0

            record['struct_cot'] = record['struct_cot'].split("```json\\n")[-1].split("\\n```")[0].strip().replace("\\n", "\n").replace('\\"', '\"')
            gt = json.loads(record['struct_cot'])["output"]
            if 'yes' in gt.lower(): gt = "Yes"
            elif 'no' in gt.lower(): gt = "No"
            else: raise Exception(f"No ground truth: {gt}")

        preds = self._extract_preds(idx, preds)

        # Reduce each extracted answer (JSON dict or plain text) to one value.
        answer_keys = (pred_key_dict[task],) if task in pred_key_dict else ()
        results = [self._pick_answer(p, answer_keys) for p in preds]

        if task in ['ring_system_scaffold', 'equivalence']:
            acc = [1 if str(r).lower() == gt.lower() or gt.lower() in str(r).lower() else 0 for r in results]
        elif task in ["fg_count", "ring_count"]:
            acc = [1 if str(r) == str(gt) or str(gt) in str(r) else 0 for r in results]
            if setting != "train":
                # Outside training report the absolute error; answers that
                # are not plain digit strings count as 0.
                results = [float(r) if str(r).isdigit() else 0 for r in results]
                acc = [np.mean(np.abs(float(r) - float(gt))) for r in results]
        elif task == 'Murcko_scaffold':
            # Only this subtask needs the evaluator, so build it here.
            prop_evaluator = mol_opt_evaluater(prop='qed')
            acc = [prop_evaluator.scaffold_consistency(src_mol_list=[gt], tgt_mol_list=[r])[1] for r in results]
        else:
            raise NotImplementedError(f"Unknown task: {task}")

        record['eval_result'] = {"outputs": acc, "acc": sum(acc)/len(acc)}
        record['ground_truth'] = gt

        return sum(acc)/len(acc)

    def mol_edit(self, idx, preds=None):
        """Score a molecule-editing record (subtask ``add``, ``delete`` or ``sub``).

        The source molecule and functional groups come from ``meta`` when
        possible and are otherwise parsed out of ``record['query']``.

        Returns:
            Fraction of responses accepted by the matching ``check_edit_*`` helper.

        Raises:
            NotImplementedError: for an unknown subtask.
        """
        record = self.data[idx]
        task = record['subtask']  # add, delete, sub
        try:
            gt = json.loads(record['meta'])['molecule']  # source
            if "." in gt:
                assert gt[-1] == '.' and len(gt.split(".")) == 2, gt
                gt = gt.replace(".", '')
        except Exception:
            gt = record['query'].split("Input Molecule: ")[-1].split(",")[0].strip()

        preds = self._extract_preds(idx, preds)
        results = [self._pick_answer(p, ('result', 'output')) for p in preds]

        group_a, group_b = None, None
        if task == 'add':
            try:
                group_a = json.loads(record['meta'])['added_group'].replace('.', '')
            except Exception:
                group_a = record['query'].split("Functional Group to add: ")[-1].replace('.', '').strip()
            acc = [1 if check_edit_add_valid(src=gt, tgt=r, group=group_a) else 0 for r in results]
        elif task == 'delete':
            try:
                group_a = json.loads(record['meta'])['removed_group'].replace('.', '')
            except Exception:
                group_a = record['query'].split("Functional Group to delete: ")[-1].replace('.', '').strip()
            acc = [1 if check_edit_del_valid(src=gt, tgt=r, group=group_a) else 0 for r in results]
        elif task == 'sub':
            try:
                group_a = json.loads(record['meta'])['added_group'].replace('.', '')
                group_b = json.loads(record['meta'])['removed_group'].replace('.', '')
            except Exception:
                group_a = record['query'].split("Functional Group to add: ")[-1].replace('.', '').strip()
                group_b = record['query'].split("Functional Group to delete: ")[-1].split(",")[0].replace('.', '').strip()
            acc = [1 if check_edit_sub_valid(src=gt, tgt=r, add_group=group_a, remove_group=group_b) else 0 for r in results]
        else:
            raise NotImplementedError(f"Unknown task: {task}")

        record['eval_result'] = {"outputs": acc, "acc": sum(acc)/len(acc), "preds": results}
        record['ground_truth'] = gt

        return sum(acc)/len(acc)

    def reaction_eval(self, idx, preds=None):
        """Score a reaction record outside the training setting.

        The ground truth is ``meta['gt']`` when present.  Otherwise ``fs``
        records use ``[Major Product, Byproduct(s)]`` parsed from
        ``record['gt']`` and the other subtasks use ``record['gt']`` as is.

        For ``fs`` the stored ``eval_result['acc']`` is a dict with
        ``'major'`` and, when available, ``'byproduct'``.

        Returns:
            Mean score over the scored responses (major product for ``fs``).

        Raises:
            NotImplementedError: for an unknown subtask.
        """
        record = self.data[idx]
        subtask = record['subtask']
        if subtask not in ('fs', 'mechsel', 'rcr', 'nepp', 'retro'):
            raise NotImplementedError(f"Unknown task: {subtask}")

        # ``meta`` is a JSON string: use the substring test only as a cheap
        # pre-check, then confirm that 'gt' is an actual key.
        meta_text = record['meta']
        meta = json.loads(meta_text) if 'gt' in meta_text else {}
        if 'gt' in meta:
            gt = meta['gt']
        elif subtask == 'fs':
            fs_gt = json.loads(record['gt'])
            gt = [fs_gt['Major Product'], fs_gt['Byproduct(s)']]
        else:
            gt = record['gt']

        preds = self._extract_preds(idx, preds)

        # With a two-part ground truth every answer becomes [major, byproduct].
        paired = subtask == 'fs' and isinstance(gt, list)
        results = []
        for p in preds:
            if not paired:
                results.append(self._pick_answer(p, ('result',)))
            elif isinstance(p, dict):
                fallback = list(p.values())[0] if len(p) > 0 else ""
                major = p['Major Product'] if 'Major Product' in p else fallback
                byproduct = p['Byproduct(s)'] if 'Byproduct(s)' in p else fallback
                results.append([str(major), str(byproduct)])
            else:
                # Plain text: slice the two fields out of the raw string.
                major = p.split("Major Product")[-1].split(",")[0].strip().replace('\"', '').replace(":", '')
                byproduct = p.split("Byproduct(s)")[-1].split("}")[0].strip().replace('\"', '').replace(":", '').replace(",", '')
                results.append([major, byproduct])

        if subtask == 'fs':
            if not isinstance(gt, list):
                # The train dataset only carries one of the two products.
                acc = [smiles_similarity(src_mol_list=[gt], tgt_mol_list=[r]) for r in results]
                record['eval_result'] = {"outputs": {"major": acc}, "acc": {"major": sum(acc)/len(acc)}, "preds": results}
                record['ground_truth'] = gt
                return sum(acc)/len(acc)

            major_acc = [smiles_similarity(src_mol_list=[gt[0]], tgt_mol_list=[r[0]]) for r in results]
            if not isinstance(gt[1], list) and len(gt[1]) > 0:
                # Compare each '.'-separated ground-truth byproduct with
                # every predicted fragment and average.
                byproduct_acc = [
                    np.mean(
                        [smiles_similarity(src_mol_list=[g]*len(r[1].split(".")), tgt_mol_list=r[1].split(".")) for g in gt[1].split(".")]
                    ) for r in results]
                record['eval_result'] = {"outputs": {"major": major_acc, "byproduct": byproduct_acc}, "acc": {"major": sum(major_acc)/len(major_acc), "byproduct": sum(byproduct_acc)/len(byproduct_acc)}, "preds": results}
            else:
                record['eval_result'] = {"outputs": {"major": major_acc}, "acc": {"major": sum(major_acc)/len(major_acc)}, "preds": results}
            record['ground_truth'] = gt
            return sum(major_acc)/len(major_acc)

        if subtask == 'mechsel':
            acc = [1 if str(r).lower() == gt.lower() or gt.lower() in str(r).lower() else 0 for r in results]
        else:
            # rcr, nepp, retro
            acc = [np.mean([smiles_similarity(src_mol_list=[g]*len(r.split(".")), tgt_mol_list=r.split(".")) for g in gt.split(".")]) for r in results]

        record['eval_result'] = {"outputs": acc, "acc": sum(acc)/len(acc), "preds": results}
        record['ground_truth'] = gt
        return sum(acc)/len(acc)

    def reaction(self, idx, preds=None):
        """Score a reaction record against ``meta['gt']``.

        Each '.'-separated ground-truth fragment is compared with every
        predicted fragment via ``smiles_similarity`` and the results are
        averaged.

        Returns:
            Mean score over the scored responses.
        """
        gt = json.loads(self.data[idx]['meta'])['gt']

        preds = self._extract_preds(idx, preds)
        results = [self._pick_answer(p, ('result',)) for p in preds]

        acc = [np.mean([smiles_similarity(src_mol_list=[g]*len(r.split(".")), tgt_mol_list=r.split(".")) for g in gt.split(".")]) for r in results]
        self.data[idx]['eval_result'] = {"outputs": acc, "acc": sum(acc)/len(acc), "preds": results}
        self.data[idx]['ground_truth'] = gt

        return sum(acc)/len(acc)

    def mol_opt(self, idx, preds=None):
        """Score a molecule-optimisation record for the property named by its subtask.

        For every response the property improvement and the two scaffold
        consistency values are stored; the record-level ``'acc'`` is the
        mean ``scaffold_soft`` value.

        Returns:
            Mean ``scaffold_soft`` over the scored responses.

        Raises:
            NotImplementedError: if the subtask is not a supported property.
        """
        # Subtask name -> property name understood by the evaluator.
        prop_dict = dict(logp='logp', solubility='solubility', qed="qed",  drd='drd2', jnk='jnk3', gsk='gsk3b')

        record = self.data[idx]
        prop = record['subtask']
        if prop not in prop_dict:
            # Fail with a clear message rather than a ZeroDivisionError later.
            raise NotImplementedError(f"Unknown task: {prop}")

        try:
            gt = json.loads(record['meta'])['molecule']
        except Exception:
            gt = record['query'].split("Source Molecule: ")[-1].replace(".", '').strip()

        preds = self._extract_preds(idx, preds)
        results = [self._pick_answer(p, ('result', 'Final Target Molecule')) for p in preds]

        accs = []
        for r in results:
            prop_evaluator = mol_opt_evaluater(prop=prop_dict[prop])
            improve_scores = prop_evaluator.property_improvement(src_mol_list=[gt], tgt_mol_list=[r], total_num=1)
            scaffold_hard, scaffold_soft = prop_evaluator.scaffold_consistency(src_mol_list=[gt], tgt_mol_list=[r])

            acc = {
                prop: {
                    "improve": improve_scores['success_rate'],
                    "scaffold_hard": scaffold_hard,
                    "scaffold_soft": scaffold_soft,
                }
            }
            accs.append({
                "improve": sum([v['improve'] for v in acc.values()])/len(acc),
                "scaffold_hard": sum([v['scaffold_hard'] for v in acc.values()])/len(acc),
                "scaffold_soft": sum([v['scaffold_soft'] for v in acc.values()])/len(acc),
                "acc": acc
            })

        record['eval_result'] = {"outputs": accs, "acc": sum([v['scaffold_soft'] for v in accs])/len(accs), "preds": results}
        record['ground_truth'] = gt

        return sum([v['scaffold_soft'] for v in accs])/len(accs)


if __name__ == "__main__":
    # Script entry point: label every record of one dataset with a
    # difficulty, print the label counts and write the records back out.

    dataname = "BioProBench"

    if dataname == "BioProBench":
        # load dataset
        bio_path = "/data/khfeng/project/scithink/outputs/qwen3-8b/BioProBench.jsonl"
        # Read inside a context manager so the handle is closed before the
        # same path is reopened for writing below.
        with open(bio_path) as f:
            data = [json.loads(line) for line in f]

        evaluator = BioProBenchEval()
        data = evaluator.classify(data)  # assign difficulty labels

        difficulty = [d['difficulty'] for d in data]
        difficulty = {k: difficulty.count(k) for k in set(difficulty)}
        print(difficulty)

        # # save data with medium and hard difficulty
        # data = [d for d in data if d['difficulty'] != 'Easy']

        # The input file is overwritten in place, one JSON object per line.
        with open(bio_path, "w") as f:
            for d in data:
                f.write(json.dumps(d, ensure_ascii=False) + '\n')

    elif dataname == "ChemCoTDataset":
        # load dataset
        with open("/data/khfeng/project/scithink/outputs/qwen3-8b/ChemCoTDataset-5n.jsonl") as f:
            data = [json.loads(line) for line in f]

        evaluator = ChemCoTEval()
        data = evaluator.classify(data)  # assign difficulty labels

        difficulty = [d['difficulty'] for d in data]
        difficulty = {k: difficulty.count(k) for k in set(difficulty)}
        print(difficulty)

        # # save data with medium and hard difficulty
        # data = [d for d in data if d['difficulty'] != 'Easy']

        with open("/data/khfeng/project/scithink/outputs/qwen3-8b/ChemCoTDataset-classify.json", "w") as f:
            json.dump(data, f, indent=4, ensure_ascii=False)

        
        