# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from . import gsm8k, math, prime_math, prime_code

import traceback
from . import prime_math

def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None):
    do_check = extra_info is not None and extra_info.get('do_check', False)
    try:
        if 'mathvista' in data_source.lower():
            from . import mathvista
            res = mathvista.compute_score(solution_str, ground_truth, extra_info=extra_info)
        elif 'mmmu' in data_source.lower():
            from . import mmmu
            res = mmmu.compute_score(solution_str, ground_truth, extra_info=extra_info)
        elif 'dynamath' in data_source.lower():
            from . import dynamath
            res = dynamath.compute_score(solution_str, ground_truth, extra_info=extra_info)
        elif 'mathverse' in data_source.lower():
            from . import mathverse
            res = mathverse.compute_score(solution_str, ground_truth, extra_info=extra_info)
        elif 'slide' in data_source.lower() or 'mmlong' in data_source.lower() or 'dude' in data_source.lower() or "synth" in data_source.lower():
            from . import doc 
            # print('using doc compute score')
            res = doc._default_compute_score(solution_str, ground_truth, data_source)
            for k,v in res.items():
                if k not in {'score', 'acc'}:
                    res[k] = str(v)
            res['data_source'] = data_source
        elif 'multimodal' in data_source.lower() or 'virl' in data_source.lower():
            from . import multimodal_math
            is_test = 'test' in data_source.lower() 
            res = multimodal_math.compute_score(solution_str, ground_truth, extra_info=extra_info, check_format=False if 'multimodal' in data_source.lower() else True, is_test=is_test)
        else: # llm benchmarks
            print('using prime compute score')
            res = prime_math.compute_score(solution_str, str(ground_truth), do_check=do_check)
            
        if isinstance(res, dict):
            return res
        elif isinstance(res, (int, float, bool)):
            return float(res)
        else:
            return float(res[0])
    except Exception as e:
        print(f"[ERROR] Error in process_completion for task : {str(e)}")
        traceback.print_exc()  # 打印完整堆栈
        raise  # 重新抛出异常以便上层捕获
