# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py

import random
import signal
from contextlib import contextmanager
from symeval import EvaluatorMath, EvaluatorMathBatch
from typing import List, Union

from uniform_eval.tasks.base import Task


class TimeoutException(Exception):
    """Raised by ``time_limit`` when the guarded block exceeds its time budget."""


@contextmanager
def time_limit(seconds):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    The limit is enforced with ``SIGALRM`` driven by a real-time interval
    timer. It therefore only works on POSIX platforms and must be entered
    from the main thread (``signal.signal`` raises ``ValueError`` otherwise).

    Args:
        seconds: Time budget for the body. Ints and floats are both accepted;
            0 arms no timer.

    Raises:
        TimeoutException: If the timer fires while the body is running.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install our handler but remember the previous one. Without restoring it,
    # any later SIGALRM in the process would still raise TimeoutException.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    # setitimer accepts fractional seconds; signal.alarm only takes ints.
    signal.setitimer(signal.ITIMER_REAL, seconds)

    try:
        yield
    finally:
        # Disarm the timer first so a late alarm cannot reach the restored handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # getsignal/signal return None when the old handler was not installed
        # from Python; it cannot be re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


def extract_ans(solution, *, timeout=1):
    """Extract the final answer from a model-generated solution string.

    Delegates to ``EvaluatorMath().extract_ans`` and runs it inside
    ``time_limit`` so that a pathological input cannot stall scoring.

    Args:
        solution: Raw solution text to extract an answer from.
        timeout: Time budget in seconds, forwarded unchanged to
            ``time_limit`` (default 1, matching the previous hard-coded value).

    Returns:
        Whatever ``EvaluatorMath.extract_ans`` returns, or ``''`` if
        extraction timed out or raised ``ValueError``.
    """
    try:
        with time_limit(timeout):
            return EvaluatorMath().extract_ans(solution)
    except (TimeoutException, ValueError):
        # Best-effort: an unparseable or too-slow solution simply yields no answer.
        return ''

@Task.register_reward("custom_math")
def compute_score(solution_strs: List[str], ground_truths: List[str]) -> List[float]:
    """Score a batch of solutions against reference answers.

    The final answer of each solution is extracted with ``extract_ans``. The
    batch is then compared to ``ground_truths`` with
    ``EvaluatorMathBatch(use_tqdm=False, timeout=1).batch_eq``. Each sample is
    printed for manual inspection with probability 1/64.

    Args:
        solution_strs: Model-generated solution texts.
        ground_truths: Reference answers, aligned index-by-index with
            ``solution_strs``.

    Returns:
        ``float(equal)`` for each comparison result produced by ``batch_eq``.
        Presumably 1.0 for a match and 0.0 otherwise; this depends on
        symeval's return values.
    """
    extracted_answers = [extract_ans(solution_str) for solution_str in solution_strs]
    # Materialize once. The results are iterated twice (sample logging, then
    # scoring), and a one-shot iterator would otherwise leave no scores behind.
    equals = list(
        EvaluatorMathBatch(use_tqdm=False, timeout=1).batch_eq(
            ref_answers=ground_truths, pred_answers=extracted_answers
        )
    )

    for solution_str, ground_truth, equal in zip(solution_strs, ground_truths, equals):
        # Sample roughly 1 in 64 items so training logs stay readable.
        if random.randint(1, 64) == 1:
            print('=' * 30)
            print('-' * 5 + ' Answer ' + '-' * 5)
            print(ground_truth)
            print('-' * 5 + ' Solution string ' + '-' * 5)
            print(solution_str)
            print('-' * 5 + ' Equal ' + '-' * 5)
            print(equal)
            print('=' * 30)

    return [float(equal) for equal in equals]
